aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw0217@gmail.com>2018-03-29 11:12:41 -0700
committerGravatar Jianwei Xie <xiejw0217@gmail.com>2018-03-29 11:12:41 -0700
commit4336e0c05b68c5e978a06cf4c4ae3459a64ba94d (patch)
tree4fc3e23a853d80aeb22777165a9fc4b9c0b06a76
parent86868a156860877fc6e8c3393baf4942b6b7dbd4 (diff)
parent63dffd5a3bc4e94e74cb140cbf7a68e0e5644ad6 (diff)
Resolve conflicts.
-rw-r--r--tensorflow/BUILD298
-rw-r--r--tensorflow/c/BUILD19
-rw-r--r--tensorflow/c/eager/BUILD2
-rw-r--r--tensorflow/c/eager/c_api.cc223
-rw-r--r--tensorflow/c/eager/c_api_test.cc4
-rw-r--r--tensorflow/c/eager/tape.h21
-rw-r--r--tensorflow/c/python_api.cc26
-rw-r--r--tensorflow/c/python_api.h7
-rw-r--r--tensorflow/cc/BUILD12
-rw-r--r--tensorflow/cc/framework/cc_op_gen_test.cc5
-rw-r--r--tensorflow/cc/framework/scope.cc3
-rw-r--r--tensorflow/cc/saved_model/BUILD15
-rw-r--r--tensorflow/cc/saved_model/python/BUILD12
-rw-r--r--tensorflow/cc/tools/BUILD15
-rw-r--r--tensorflow/compiler/aot/BUILD14
-rw-r--r--tensorflow/compiler/aot/codegen_test.cc3
-rw-r--r--tensorflow/compiler/aot/tests/BUILD14
-rw-r--r--tensorflow/compiler/aot/tfcompile_main.cc2
-rw-r--r--tensorflow/compiler/jit/BUILD14
-rw-r--r--tensorflow/compiler/jit/graphcycles/BUILD14
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD14
-rw-r--r--tensorflow/compiler/jit/legacy_flags/BUILD14
-rw-r--r--tensorflow/compiler/jit/ops/BUILD14
-rw-r--r--tensorflow/compiler/jit/xla_tensor_info.h6
-rw-r--r--tensorflow/compiler/plugin/BUILD14
-rw-r--r--tensorflow/compiler/tests/BUILD14
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py14
-rw-r--r--tensorflow/compiler/tf2xla/BUILD14
-rw-r--r--tensorflow/compiler/tf2xla/cc/BUILD14
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD14
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cast_op.cc45
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD14
-rw-r--r--tensorflow/compiler/tf2xla/ops/BUILD14
-rw-r--r--tensorflow/compiler/xla/BUILD12
-rw-r--r--tensorflow/compiler/xla/client/BUILD14
-rw-r--r--tensorflow/compiler/xla/client/client.cc62
-rw-r--r--tensorflow/compiler/xla/client/client.h46
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc20
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h9
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD14
-rw-r--r--tensorflow/compiler/xla/client/xla_client/BUILD15
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc150
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h71
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc4
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_computation.cc5
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_computation.h5
-rw-r--r--tensorflow/compiler/xla/legacy_flags/BUILD14
-rw-r--r--tensorflow/compiler/xla/python/BUILD12
-rw-r--r--tensorflow/compiler/xla/service/BUILD32
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc57
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc85
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h15
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc90
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc24
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc108
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc108
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc37
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc7
-rw-r--r--tensorflow/compiler/xla/service/interpreter/BUILD11
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/service.cc28
-rw-r--r--tensorflow/compiler/xla/service/service.h16
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc6
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h3
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc4
-rw-r--r--tensorflow/compiler/xla/service_interface.h8
-rw-r--r--tensorflow/compiler/xla/tests/BUILD16
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc26
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc22
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h6
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc19
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc99
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/tools/BUILD14
-rw-r--r--tensorflow/compiler/xla/tools/parser/BUILD14
-rw-r--r--tensorflow/compiler/xla/xla.proto9
-rw-r--r--tensorflow/contrib/BUILD12
-rw-r--r--tensorflow/contrib/all_reduce/BUILD13
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce.py7
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce_test.py6
-rw-r--r--tensorflow/contrib/android/BUILD14
-rw-r--r--tensorflow/contrib/autograph/utils/builtins.py10
-rw-r--r--tensorflow/contrib/batching/BUILD11
-rw-r--r--tensorflow/contrib/batching/test_util/BUILD11
-rw-r--r--tensorflow/contrib/batching/util/BUILD12
-rw-r--r--tensorflow/contrib/bayesflow/BUILD12
-rw-r--r--tensorflow/contrib/boosted_trees/BUILD9
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD9
-rw-r--r--tensorflow/contrib/boosted_trees/lib/BUILD11
-rw-r--r--tensorflow/contrib/boosted_trees/proto/BUILD11
-rw-r--r--tensorflow/contrib/boosted_trees/resources/BUILD11
-rw-r--r--tensorflow/contrib/cloud/BUILD12
-rw-r--r--tensorflow/contrib/cloud/kernels/BUILD15
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc5
-rw-r--r--tensorflow/contrib/cluster_resolver/BUILD13
-rw-r--r--tensorflow/contrib/cmake/external/grpc.cmake2
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt1
-rw-r--r--tensorflow/contrib/cmake/python_protos.txt1
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake7
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake1
-rw-r--r--tensorflow/contrib/coder/BUILD11
-rw-r--r--tensorflow/contrib/compiler/BUILD12
-rw-r--r--tensorflow/contrib/copy_graph/BUILD12
-rw-r--r--tensorflow/contrib/crf/BUILD12
-rw-r--r--tensorflow/contrib/cudnn_rnn/BUILD12
-rw-r--r--tensorflow/contrib/data/BUILD14
-rw-r--r--tensorflow/contrib/data/__init__.py5
-rw-r--r--tensorflow/contrib/data/kernels/BUILD11
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc7
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py15
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD12
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py61
-rw-r--r--tensorflow/contrib/decision_trees/proto/BUILD8
-rw-r--r--tensorflow/contrib/deprecated/BUILD12
-rw-r--r--tensorflow/contrib/distributions/BUILD27
-rw-r--r--tensorflow/contrib/distributions/__init__.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py531
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py108
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py333
-rw-r--r--tensorflow/contrib/eager/proto/BUILD11
-rw-r--r--tensorflow/contrib/eager/python/BUILD14
-rw-r--r--tensorflow/contrib/eager/python/checkpointable_utils_test.py33
-rw-r--r--tensorflow/contrib/eager/python/examples/spinn/spinn_test.py13
-rw-r--r--tensorflow/contrib/estimator/BUILD43
-rw-r--r--tensorflow/contrib/estimator/__init__.py3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py323
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py207
-rw-r--r--tensorflow/contrib/factorization/BUILD13
-rw-r--r--tensorflow/contrib/factorization/examples/BUILD11
-rw-r--r--tensorflow/contrib/factorization/kernels/BUILD11
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops.py10
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops_test.py8
-rw-r--r--tensorflow/contrib/feature_column/BUILD12
-rw-r--r--tensorflow/contrib/ffmpeg/BUILD12
-rw-r--r--tensorflow/contrib/ffmpeg/default/BUILD12
-rw-r--r--tensorflow/contrib/framework/BUILD12
-rw-r--r--tensorflow/contrib/fused_conv/BUILD12
-rw-r--r--tensorflow/contrib/gan/BUILD12
-rw-r--r--tensorflow/contrib/gdr/BUILD12
-rw-r--r--tensorflow/contrib/graph_editor/BUILD12
-rw-r--r--tensorflow/contrib/grid_rnn/BUILD12
-rw-r--r--tensorflow/contrib/hooks/BUILD11
-rw-r--r--tensorflow/contrib/hvx/clock_cycle_profiling/BUILD12
-rw-r--r--tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD11
-rwxr-xr-xtensorflow/contrib/image/BUILD12
-rw-r--r--tensorflow/contrib/input_pipeline/BUILD11
-rw-r--r--tensorflow/contrib/input_pipeline/kernels/BUILD11
-rw-r--r--tensorflow/contrib/integrate/BUILD11
-rw-r--r--tensorflow/contrib/kafka/BUILD14
-rw-r--r--tensorflow/contrib/keras/BUILD12
-rw-r--r--tensorflow/contrib/kernel_methods/BUILD12
-rw-r--r--tensorflow/contrib/kfac/BUILD12
-rw-r--r--tensorflow/contrib/kfac/examples/BUILD12
-rw-r--r--tensorflow/contrib/kfac/examples/tests/BUILD12
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/BUILD12
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD12
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py8
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py16
-rw-r--r--tensorflow/contrib/labeled_tensor/BUILD11
-rw-r--r--tensorflow/contrib/layers/BUILD12
-rw-r--r--tensorflow/contrib/layers/kernels/BUILD11
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py5
-rw-r--r--tensorflow/contrib/learn/BUILD12
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/BUILD12
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/base.py35
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/run_config.py11
-rw-r--r--tensorflow/contrib/legacy_seq2seq/BUILD12
-rw-r--r--tensorflow/contrib/libsvm/BUILD12
-rw-r--r--tensorflow/contrib/linalg/BUILD12
-rw-r--r--tensorflow/contrib/linear_optimizer/BUILD11
-rw-r--r--tensorflow/contrib/lite/BUILD15
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/examples/label_image/BUILD12
-rw-r--r--tensorflow/contrib/lite/java/BUILD12
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/BUILD12
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD12
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/BUILD12
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD12
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD13
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected_test.cc57
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD12
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h43
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h169
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h22
-rw-r--r--tensorflow/contrib/lite/kernels/mean.cc62
-rw-r--r--tensorflow/contrib/lite/kernels/mean_test.cc81
-rw-r--r--tensorflow/contrib/lite/models/BUILD12
-rw-r--r--tensorflow/contrib/lite/models/smartreply/BUILD12
-rw-r--r--tensorflow/contrib/lite/nnapi/BUILD12
-rw-r--r--tensorflow/contrib/lite/python/BUILD12
-rw-r--r--tensorflow/contrib/lite/schema/BUILD12
-rw-r--r--tensorflow/contrib/lite/testing/BUILD12
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py3
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc3
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc17
-rw-r--r--tensorflow/contrib/lite/toco/BUILD14
-rw-r--r--tensorflow/contrib/lite/toco/args.h3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/swap_elementwise_binary.cc175
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD23
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/swap_elementwise_binary_test.cc89
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc9
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD12
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD12
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD12
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc11
-rw-r--r--tensorflow/contrib/lite/toco/toco_saved_model.cc9
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc1
-rw-r--r--tensorflow/contrib/lite/tools/BUILD12
-rw-r--r--tensorflow/contrib/lookup/BUILD12
-rw-r--r--tensorflow/contrib/losses/BUILD12
-rw-r--r--tensorflow/contrib/makefile/BUILD9
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt6
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt1
-rw-r--r--tensorflow/contrib/memory_stats/BUILD12
-rw-r--r--tensorflow/contrib/meta_graph_transform/BUILD12
-rw-r--r--tensorflow/contrib/metrics/BUILD11
-rw-r--r--tensorflow/contrib/model_pruning/BUILD12
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/BUILD12
-rw-r--r--tensorflow/contrib/mpi_collectives/BUILD12
-rw-r--r--tensorflow/contrib/nccl/BUILD12
-rw-r--r--tensorflow/contrib/nearest_neighbor/BUILD12
-rw-r--r--tensorflow/contrib/nn/BUILD11
-rw-r--r--tensorflow/contrib/opt/BUILD11
-rw-r--r--tensorflow/contrib/periodic_resample/BUILD12
-rw-r--r--tensorflow/contrib/predictor/BUILD12
-rw-r--r--tensorflow/contrib/quantization/BUILD12
-rw-r--r--tensorflow/contrib/quantize/BUILD12
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py6
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py18
-rw-r--r--tensorflow/contrib/receptive_field/BUILD12
-rw-r--r--tensorflow/contrib/reduce_slice_ops/BUILD12
-rw-r--r--tensorflow/contrib/remote_fused_graph/pylib/BUILD12
-rw-r--r--tensorflow/contrib/resampler/BUILD11
-rw-r--r--tensorflow/contrib/rnn/BUILD13
-rw-r--r--tensorflow/contrib/saved_model/BUILD12
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/BUILD6
-rw-r--r--tensorflow/contrib/seq2seq/BUILD12
-rw-r--r--tensorflow/contrib/session_bundle/BUILD12
-rw-r--r--tensorflow/contrib/session_bundle/example/BUILD13
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle_test.cc30
-rw-r--r--tensorflow/contrib/session_bundle/signature_test.cc68
-rw-r--r--tensorflow/contrib/signal/BUILD12
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py2
-rw-r--r--tensorflow/contrib/slim/BUILD12
-rw-r--r--tensorflow/contrib/slim/python/slim/data/BUILD12
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/BUILD12
-rw-r--r--tensorflow/contrib/solvers/BUILD13
-rw-r--r--tensorflow/contrib/sparsemax/BUILD12
-rw-r--r--tensorflow/contrib/specs/BUILD12
-rw-r--r--tensorflow/contrib/staging/BUILD12
-rw-r--r--tensorflow/contrib/stat_summarizer/BUILD12
-rw-r--r--tensorflow/contrib/stateless/BUILD12
-rw-r--r--tensorflow/contrib/summary/BUILD12
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD14
-rw-r--r--tensorflow/contrib/tensor_forest/hybrid/BUILD12
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/BUILD5
-rw-r--r--tensorflow/contrib/tensor_forest/proto/BUILD8
-rw-r--r--tensorflow/contrib/tensorboard/BUILD12
-rw-r--r--tensorflow/contrib/tensorboard/db/BUILD6
-rw-r--r--tensorflow/contrib/tensorrt/BUILD12
-rw-r--r--tensorflow/contrib/testing/BUILD12
-rw-r--r--tensorflow/contrib/text/BUILD11
-rw-r--r--tensorflow/contrib/tfprof/BUILD12
-rw-r--r--tensorflow/contrib/timeseries/BUILD12
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD12
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD12
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py21
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD12
-rw-r--r--tensorflow/contrib/tpu/BUILD15
-rw-r--r--tensorflow/contrib/tpu/profiler/BUILD16
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py2
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto16
-rw-r--r--tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py2
-rw-r--r--tensorflow/contrib/tpu/proto/BUILD11
-rw-r--r--tensorflow/contrib/tpu/python/profiler/__init__.py1
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py32
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py12
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py79
-rw-r--r--tensorflow/contrib/training/BUILD12
-rw-r--r--tensorflow/contrib/util/BUILD12
-rw-r--r--tensorflow/contrib/verbs/BUILD12
-rw-r--r--tensorflow/core/BUILD66
-rw-r--r--tensorflow/core/api_def/BUILD12
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt87
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateEnsemble.pbtxt23
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesDeserializeEnsemble.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesEnsembleResourceHandleOp.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesGetEnsembleStates.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeStatsSummary.pbtxt56
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt41
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesSerializeEnsemble.pbtxt23
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt69
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt82
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IsBoostedTreesEnsembleInitialized.pbtxt17
-rw-r--r--tensorflow/core/common_runtime/buf_rendezvous.cc166
-rw-r--r--tensorflow/core/common_runtime/buf_rendezvous.h103
-rw-r--r--tensorflow/core/common_runtime/buf_rendezvous_test.cc197
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.cc114
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.h70
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr_test.cc98
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc666
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h209
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local_test.cc151
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local.cc108
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local.h88
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local_test.cc148
-rw-r--r--tensorflow/core/common_runtime/device_resolver_local.cc49
-rw-r--r--tensorflow/core/common_runtime/device_resolver_local.h48
-rw-r--r--tensorflow/core/common_runtime/device_resolver_local_test.cc87
-rw-r--r--tensorflow/core/common_runtime/eager/BUILD41
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc134
-rw-r--r--tensorflow/core/common_runtime/eager/execute.h41
-rw-r--r--tensorflow/core/common_runtime/eager/execute_node.h88
-rw-r--r--tensorflow/core/debug/BUILD15
-rw-r--r--tensorflow/core/distributed_runtime/BUILD12
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD12
-rw-r--r--tensorflow/core/framework/collective.cc120
-rw-r--r--tensorflow/core/framework/collective.h308
-rw-r--r--tensorflow/core/framework/common_shape_fns.h2
-rw-r--r--tensorflow/core/framework/op_kernel.h1
-rw-r--r--tensorflow/core/framework/resource_mgr.h15
-rw-r--r--tensorflow/core/grappler/BUILD12
-rw-r--r--tensorflow/core/grappler/clusters/BUILD12
-rw-r--r--tensorflow/core/grappler/costs/BUILD13
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc4
-rw-r--r--tensorflow/core/grappler/costs/measuring_cost_estimator.cc23
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc309
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h14
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc452
-rw-r--r--tensorflow/core/grappler/costs/utils.cc4
-rw-r--r--tensorflow/core/grappler/inputs/BUILD12
-rw-r--r--tensorflow/core/grappler/op_types.cc2
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD26
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc85
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h2
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc77
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper.cc19
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper_test.cc71
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer_test.cc9
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer_test.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc8
-rw-r--r--tensorflow/core/grappler/utils.cc8
-rw-r--r--tensorflow/core/grappler/utils/BUILD12
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc8
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.h4
-rw-r--r--tensorflow/core/kernels/BUILD19
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD12
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD89
-rw-r--r--tensorflow/core/kernels/boosted_trees/boosted_trees.proto113
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc263
-rw-r--r--tensorflow/core/kernels/boosted_trees/resource_ops.cc189
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc301
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.h221
-rw-r--r--tensorflow/core/kernels/boosted_trees/stats_ops.cc296
-rw-r--r--tensorflow/core/kernels/boosted_trees/training_ops.cc219
-rw-r--r--tensorflow/core/kernels/cwise_op_log.cc4
-rw-r--r--tensorflow/core/kernels/cwise_ops.h21
-rw-r--r--tensorflow/core/kernels/data/BUILD33
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc5
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.cc46
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.h71
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner_test.cc82
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc13
-rw-r--r--tensorflow/core/kernels/data/sql/BUILD12
-rw-r--r--tensorflow/core/kernels/fuzzing/BUILD12
-rw-r--r--tensorflow/core/kernels/hexagon/BUILD12
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h2
-rw-r--r--tensorflow/core/kernels/list_kernels.cc1
-rw-r--r--tensorflow/core/kernels/list_kernels.h3
-rw-r--r--tensorflow/core/kernels/neon/BUILD12
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc1
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h7
-rw-r--r--tensorflow/core/kernels/snapshot_op_gpu.cu.cc3
-rw-r--r--tensorflow/core/kernels/xent_op.cc4
-rw-r--r--tensorflow/core/lib/db/BUILD6
-rw-r--r--tensorflow/core/lib/io/inputbuffer_test.cc3
-rw-r--r--tensorflow/core/lib/io/recordio_test.cc3
-rw-r--r--tensorflow/core/ops/array_ops.cc123
-rw-r--r--tensorflow/core/ops/array_ops_test.cc6
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc319
-rw-r--r--tensorflow/core/ops/compat/BUILD15
-rw-r--r--tensorflow/core/ops/nn_ops.cc7
-rw-r--r--tensorflow/core/platform/cloud/BUILD14
-rw-r--r--tensorflow/core/platform/default/build_config/BUILD17
-rw-r--r--tensorflow/core/platform/hadoop/BUILD12
-rw-r--r--tensorflow/core/platform/s3/BUILD12
-rw-r--r--tensorflow/core/profiler/BUILD15
-rw-r--r--tensorflow/core/profiler/internal/BUILD14
-rw-r--r--tensorflow/core/profiler/internal/advisor/BUILD15
-rw-r--r--tensorflow/core/util/ctc/BUILD12
-rw-r--r--tensorflow/core/util/tensor_bundle/BUILD15
-rw-r--r--tensorflow/docs_src/community/contributing.md64
-rw-r--r--tensorflow/docs_src/community/groups.md17
-rw-r--r--tensorflow/docs_src/community/index.md95
-rw-r--r--tensorflow/docs_src/community/leftnav_files5
-rw-r--r--tensorflow/docs_src/community/lists.md35
-rw-r--r--tensorflow/docs_src/community/welcome.md71
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md5
-rw-r--r--tensorflow/examples/adding_an_op/BUILD12
-rw-r--r--tensorflow/examples/android/BUILD16
-rw-r--r--tensorflow/examples/benchmark/BUILD6
-rw-r--r--tensorflow/examples/get_started/regression/BUILD12
-rw-r--r--tensorflow/examples/how_tos/reading_data/BUILD12
-rw-r--r--tensorflow/examples/image_retraining/BUILD12
-rw-r--r--tensorflow/examples/label_image/BUILD16
-rw-r--r--tensorflow/examples/learn/BUILD12
-rw-r--r--tensorflow/examples/multibox_detector/BUILD14
-rw-r--r--tensorflow/examples/saved_model/BUILD13
-rw-r--r--tensorflow/examples/speech_commands/BUILD12
-rw-r--r--tensorflow/examples/tutorials/estimators/BUILD12
-rw-r--r--tensorflow/examples/tutorials/layers/BUILD12
-rw-r--r--tensorflow/examples/tutorials/mnist/BUILD12
-rw-r--r--tensorflow/examples/tutorials/monitors/BUILD12
-rw-r--r--tensorflow/examples/tutorials/word2vec/BUILD11
-rw-r--r--tensorflow/examples/wav_to_spectrogram/BUILD14
-rw-r--r--tensorflow/go/op/wrappers.go1788
-rw-r--r--tensorflow/java/BUILD12
-rw-r--r--tensorflow/python/BUILD37
-rw-r--r--tensorflow/python/__init__.py2
-rw-r--r--tensorflow/python/client/tf_session.i1
-rw-r--r--tensorflow/python/data/BUILD12
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD12
-rw-r--r--tensorflow/python/data/ops/BUILD12
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py2
-rw-r--r--tensorflow/python/data/util/BUILD12
-rw-r--r--tensorflow/python/debug/BUILD12
-rw-r--r--tensorflow/python/eager/BUILD15
-rw-r--r--tensorflow/python/eager/backprop.py33
-rw-r--r--tensorflow/python/eager/backprop_test.py47
-rw-r--r--tensorflow/python/eager/context.py2
-rw-r--r--tensorflow/python/eager/core_test.py8
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc60
-rw-r--r--tensorflow/python/estimator/BUILD60
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py736
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py799
-rw-r--r--tensorflow/python/estimator/estimator.py341
-rw-r--r--tensorflow/python/estimator/estimator_lib.py4
-rw-r--r--tensorflow/python/estimator/run_config.py3
-rw-r--r--tensorflow/python/feature_column/BUILD12
-rw-r--r--tensorflow/python/framework/function.py6
-rw-r--r--tensorflow/python/framework/function_test.py9
-rw-r--r--tensorflow/python/framework/importer_test.py34
-rw-r--r--tensorflow/python/framework/ops.py36
-rw-r--r--tensorflow/python/framework/python_op_gen.cc2
-rw-r--r--tensorflow/python/framework/python_op_gen_main.cc4
-rw-r--r--tensorflow/python/framework/test_util.py10
-rw-r--r--tensorflow/python/grappler/cluster_test.py16
-rwxr-xr-xtensorflow/python/keras/BUILD12
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py90
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/network.py9
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/sequential.py31
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/model_subclassing_test.py130
-rw-r--r--tensorflow/python/kernel_tests/BUILD14
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/BUILD76
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/__init__.py0
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py926
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py228
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py289
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py1465
-rw-r--r--tensorflow/python/kernel_tests/distributions/BUILD12
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py24
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD12
-rw-r--r--tensorflow/python/kernel_tests/random/BUILD12
-rw-r--r--tensorflow/python/layers/base.py2
-rw-r--r--tensorflow/python/ops/array_ops.py5
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py160
-rw-r--r--tensorflow/python/ops/distributions/BUILD12
-rw-r--r--tensorflow/python/ops/init_ops.py9
-rw-r--r--tensorflow/python/ops/linalg/BUILD12
-rw-r--r--tensorflow/python/ops/losses/BUILD12
-rw-r--r--tensorflow/python/ops/nn_ops.py8
-rw-r--r--tensorflow/python/ops/nn_test.py51
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py5
-rw-r--r--tensorflow/python/ops/script_ops.py6
-rw-r--r--tensorflow/python/profiler/BUILD15
-rw-r--r--tensorflow/python/profiler/internal/BUILD15
-rw-r--r--tensorflow/python/saved_model/BUILD12
-rw-r--r--tensorflow/python/tools/BUILD14
-rw-r--r--tensorflow/python/training/checkpointable.py1
-rw-r--r--tensorflow/python/training/device_setter.py13
-rw-r--r--tensorflow/python/training/distribute.py75
-rw-r--r--tensorflow/python/training/monitored_session.py4
-rw-r--r--tensorflow/python/training/optimizer.py174
-rw-r--r--tensorflow/python/training/saver.py20
-rw-r--r--tensorflow/python/training/saver_test.py32
-rw-r--r--tensorflow/python/training/slot_creator.py8
-rw-r--r--tensorflow/python/util/nest.py90
-rw-r--r--tensorflow/python/util/nest_test.py156
-rw-r--r--tensorflow/python/util/util.cc374
-rw-r--r--tensorflow/python/util/util.h51
-rw-r--r--tensorflow/python/util/util.i9
-rw-r--r--tensorflow/stream_executor/BUILD6
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc87
-rw-r--r--tensorflow/stream_executor/cuda/cudnn_version.cc42
-rw-r--r--tensorflow/stream_executor/cuda/cudnn_version.h51
-rw-r--r--tensorflow/stream_executor/cuda/cudnn_version_test.cc75
-rw-r--r--tensorflow/stream_executor/kernel.cc3
-rw-r--r--tensorflow/stream_executor/lib/str_util.h2
-rw-r--r--tensorflow/tools/api/generator/BUILD12
-rw-r--r--tensorflow/tools/api/golden/BUILD12
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt54
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt54
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.pbtxt8
-rw-r--r--tensorflow/tools/api/lib/BUILD12
-rw-r--r--tensorflow/tools/api/tests/BUILD12
-rw-r--r--tensorflow/tools/benchmark/BUILD9
-rw-r--r--tensorflow/tools/build_info/BUILD15
-rwxr-xr-xtensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh1
-rw-r--r--tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh2
-rw-r--r--tensorflow/tools/common/BUILD11
-rw-r--r--tensorflow/tools/compatibility/BUILD15
-rw-r--r--tensorflow/tools/dist_test/server/BUILD12
-rw-r--r--tensorflow/tools/docker/BUILD12
-rw-r--r--tensorflow/tools/docker/notebooks/BUILD12
-rw-r--r--tensorflow/tools/docs/BUILD11
-rw-r--r--tensorflow/tools/git/BUILD15
-rw-r--r--tensorflow/tools/graph_transforms/BUILD11
-rw-r--r--tensorflow/tools/mlpbtxt/BUILD12
-rw-r--r--tensorflow/tools/pip_package/BUILD1
-rw-r--r--tensorflow/tools/proto_text/BUILD15
-rw-r--r--tensorflow/tools/quantization/BUILD12
-rw-r--r--tensorflow/tools/test/BUILD12
-rw-r--r--tensorflow/user_ops/BUILD12
-rw-r--r--tensorflow/workspace.bzl17
-rw-r--r--third_party/examples/eager/spinn/spinn.py29
-rw-r--r--third_party/hadoop/BUILD12
-rw-r--r--third_party/mkl/BUILD11
-rw-r--r--third_party/mkl/mkl.BUILD6
-rw-r--r--third_party/mpi/BUILD12
-rw-r--r--third_party/sycl/BUILD12
-rw-r--r--third_party/sycl/sycl/BUILD12
551 files changed, 18615 insertions, 5890 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 31e64793de..823393ebdf 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -401,19 +401,6 @@ package_group(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "g3doc/sitemap.md",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
@@ -433,291 +420,6 @@ py_library(
],
)
-filegroup(
- name = "all_opensource_files",
- data = [
- ":all_files",
- "//tensorflow/c:all_files",
- "//tensorflow/cc:all_files",
- "//tensorflow/cc/saved_model:all_files",
- "//tensorflow/cc/saved_model/python:all_files",
- "//tensorflow/cc/tools:all_files",
- "//tensorflow/compiler/aot:all_files",
- "//tensorflow/compiler/aot/tests:all_files",
- "//tensorflow/compiler/jit:all_files",
- "//tensorflow/compiler/jit/graphcycles:all_files",
- "//tensorflow/compiler/jit/kernels:all_files",
- "//tensorflow/compiler/jit/legacy_flags:all_files",
- "//tensorflow/compiler/jit/ops:all_files",
- "//tensorflow/compiler/plugin:all_files",
- "//tensorflow/compiler/tests:all_files",
- "//tensorflow/compiler/tf2xla:all_files",
- "//tensorflow/compiler/tf2xla/cc:all_files",
- "//tensorflow/compiler/tf2xla/kernels:all_files",
- "//tensorflow/compiler/tf2xla/lib:all_files",
- "//tensorflow/compiler/tf2xla/ops:all_files",
- "//tensorflow/compiler/xla:all_files",
- "//tensorflow/compiler/xla/client:all_files",
- "//tensorflow/compiler/xla/client/lib:all_files",
- "//tensorflow/compiler/xla/client/xla_client:all_files",
- "//tensorflow/compiler/xla/legacy_flags:all_files",
- "//tensorflow/compiler/xla/python:all_files",
- "//tensorflow/compiler/xla/service:all_files",
- "//tensorflow/compiler/xla/service/cpu:all_files",
- "//tensorflow/compiler/xla/service/gpu:all_files",
- "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files",
- "//tensorflow/compiler/xla/service/interpreter:all_files",
- "//tensorflow/compiler/xla/service/llvm_ir:all_files",
- "//tensorflow/compiler/xla/tests:all_files",
- "//tensorflow/compiler/xla/tools:all_files",
- "//tensorflow/compiler/xla/tools/parser:all_files",
- "//tensorflow/contrib:all_files",
- "//tensorflow/contrib/all_reduce:all_files",
- "//tensorflow/contrib/android:all_files",
- "//tensorflow/contrib/autograph:all_files",
- "//tensorflow/contrib/autograph/converters:all_files",
- "//tensorflow/contrib/autograph/impl:all_files",
- "//tensorflow/contrib/autograph/pyct:all_files",
- "//tensorflow/contrib/autograph/pyct/static_analysis:all_files",
- "//tensorflow/contrib/autograph/utils:all_files",
- "//tensorflow/contrib/batching:all_files",
- "//tensorflow/contrib/bayesflow:all_files",
- "//tensorflow/contrib/boosted_trees:all_files",
- "//tensorflow/contrib/boosted_trees/estimator_batch:all_files",
- "//tensorflow/contrib/boosted_trees/lib:all_files",
- "//tensorflow/contrib/boosted_trees/proto:all_files",
- "//tensorflow/contrib/boosted_trees/resources:all_files",
- "//tensorflow/contrib/cloud:all_files",
- "//tensorflow/contrib/cloud/kernels:all_files",
- "//tensorflow/contrib/cluster_resolver:all_files",
- "//tensorflow/contrib/coder:all_files",
- "//tensorflow/contrib/compiler:all_files",
- "//tensorflow/contrib/copy_graph:all_files",
- "//tensorflow/contrib/crf:all_files",
- "//tensorflow/contrib/cudnn_rnn:all_files",
- "//tensorflow/contrib/data:all_files",
- "//tensorflow/contrib/data/kernels:all_files",
- "//tensorflow/contrib/data/python/kernel_tests:all_files",
- "//tensorflow/contrib/data/python/ops:all_files",
- "//tensorflow/contrib/decision_trees/proto:all_files",
- "//tensorflow/contrib/deprecated:all_files",
- "//tensorflow/contrib/distributions:all_files",
- "//tensorflow/contrib/eager/proto:all_files",
- "//tensorflow/contrib/eager/python:all_files",
- "//tensorflow/contrib/estimator:all_files",
- "//tensorflow/contrib/factorization:all_files",
- "//tensorflow/contrib/factorization/examples:all_files",
- "//tensorflow/contrib/factorization/kernels:all_files",
- "//tensorflow/contrib/feature_column:all_files",
- "//tensorflow/contrib/ffmpeg:all_files",
- "//tensorflow/contrib/ffmpeg/default:all_files",
- "//tensorflow/contrib/framework:all_files",
- "//tensorflow/contrib/fused_conv:all_files",
- "//tensorflow/contrib/gan:all_files",
- "//tensorflow/contrib/gdr:all_files",
- "//tensorflow/contrib/graph_editor:all_files",
- "//tensorflow/contrib/grid_rnn:all_files",
- "//tensorflow/contrib/hooks:all_files",
- "//tensorflow/contrib/hvx/clock_cycle_profiling:all_files",
- "//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files",
- "//tensorflow/contrib/image:all_files",
- "//tensorflow/contrib/input_pipeline:all_files",
- "//tensorflow/contrib/input_pipeline/kernels:all_files",
- "//tensorflow/contrib/integrate:all_files",
- "//tensorflow/contrib/keras:all_files",
- "//tensorflow/contrib/kernel_methods:all_files",
- "//tensorflow/contrib/kfac:all_files",
- "//tensorflow/contrib/kfac/examples:all_files",
- "//tensorflow/contrib/kfac/examples/tests:all_files",
- "//tensorflow/contrib/kfac/python/kernel_tests:all_files",
- "//tensorflow/contrib/kfac/python/ops:all_files",
- "//tensorflow/contrib/labeled_tensor:all_files",
- "//tensorflow/contrib/layers:all_files",
- "//tensorflow/contrib/layers/kernels:all_files",
- "//tensorflow/contrib/learn:all_files",
- "//tensorflow/contrib/learn/python/learn/datasets:all_files",
- "//tensorflow/contrib/legacy_seq2seq:all_files",
- "//tensorflow/contrib/libsvm:all_files",
- "//tensorflow/contrib/linalg:all_files",
- "//tensorflow/contrib/linear_optimizer:all_files",
- "//tensorflow/contrib/lite:all_files",
- "//tensorflow/contrib/lite/java:all_files",
- "//tensorflow/contrib/lite/java/demo/app/src/main:all_files",
- "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files",
- "//tensorflow/contrib/lite/java/src/main/native:all_files",
- "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files",
- "//tensorflow/contrib/lite/kernels:all_files",
- "//tensorflow/contrib/lite/kernels/internal:all_files",
- "//tensorflow/contrib/lite/models/smartreply:all_files",
- "//tensorflow/contrib/lite/nnapi:all_files",
- "//tensorflow/contrib/lite/python:all_files",
- "//tensorflow/contrib/lite/schema:all_files",
- "//tensorflow/contrib/lite/testing:all_files",
- "//tensorflow/contrib/lite/toco:all_files",
- "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files",
- "//tensorflow/contrib/lite/toco/python:all_files",
- "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files",
- "//tensorflow/contrib/lite/toco/tflite:all_files",
- "//tensorflow/contrib/lite/tools:all_files",
- "//tensorflow/contrib/lookup:all_files",
- "//tensorflow/contrib/losses:all_files",
- "//tensorflow/contrib/makefile:all_files",
- "//tensorflow/contrib/memory_stats:all_files",
- "//tensorflow/contrib/meta_graph_transform:all_files",
- "//tensorflow/contrib/metrics:all_files",
- "//tensorflow/contrib/model_pruning:all_files",
- "//tensorflow/contrib/model_pruning/examples/cifar10:all_files",
- "//tensorflow/contrib/nccl:all_files",
- "//tensorflow/contrib/nearest_neighbor:all_files",
- "//tensorflow/contrib/nn:all_files",
- "//tensorflow/contrib/opt:all_files",
- "//tensorflow/contrib/periodic_resample:all_files",
- "//tensorflow/contrib/predictor:all_files",
- "//tensorflow/contrib/quantize:all_files",
- "//tensorflow/contrib/receptive_field:all_files",
- "//tensorflow/contrib/reduce_slice_ops:all_files",
- "//tensorflow/contrib/remote_fused_graph/pylib:all_files",
- "//tensorflow/contrib/resampler:all_files",
- "//tensorflow/contrib/rnn:all_files",
- "//tensorflow/contrib/saved_model:all_files",
- "//tensorflow/contrib/saved_model/cc/saved_model:all_files",
- "//tensorflow/contrib/seq2seq:all_files",
- "//tensorflow/contrib/session_bundle:all_files",
- "//tensorflow/contrib/session_bundle/example:all_files",
- "//tensorflow/contrib/signal:all_files",
- "//tensorflow/contrib/slim:all_files",
- "//tensorflow/contrib/slim/python/slim/data:all_files",
- "//tensorflow/contrib/slim/python/slim/nets:all_files",
- "//tensorflow/contrib/solvers:all_files",
- "//tensorflow/contrib/sparsemax:all_files",
- "//tensorflow/contrib/specs:all_files",
- "//tensorflow/contrib/staging:all_files",
- "//tensorflow/contrib/stat_summarizer:all_files",
- "//tensorflow/contrib/stateless:all_files",
- "//tensorflow/contrib/summary:all_files",
- "//tensorflow/contrib/tensor_forest:all_files",
- "//tensorflow/contrib/tensor_forest/hybrid:all_files",
- "//tensorflow/contrib/tensor_forest/kernels/v4:all_files",
- "//tensorflow/contrib/tensor_forest/proto:all_files",
- "//tensorflow/contrib/tensorboard:all_files",
- "//tensorflow/contrib/tensorboard/db:all_files",
- "//tensorflow/contrib/tensorrt:all_files",
- "//tensorflow/contrib/testing:all_files",
- "//tensorflow/contrib/text:all_files",
- "//tensorflow/contrib/tfprof:all_files",
- "//tensorflow/contrib/timeseries:all_files",
- "//tensorflow/contrib/timeseries/examples:all_files",
- "//tensorflow/contrib/timeseries/python/timeseries:all_files",
- "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files",
- "//tensorflow/contrib/tpu:all_files",
- "//tensorflow/contrib/tpu/profiler:all_files",
- "//tensorflow/contrib/tpu/proto:all_files",
- "//tensorflow/contrib/training:all_files",
- "//tensorflow/contrib/util:all_files",
- "//tensorflow/contrib/verbs:all_files",
- "//tensorflow/core:all_files",
- "//tensorflow/core/api_def:all_files",
- "//tensorflow/core/common_runtime/eager:all_files",
- "//tensorflow/core/debug:all_files",
- "//tensorflow/core/distributed_runtime:all_files",
- "//tensorflow/core/distributed_runtime/rpc:all_files",
- "//tensorflow/core/grappler:all_files",
- "//tensorflow/core/grappler/clusters:all_files",
- "//tensorflow/core/grappler/costs:all_files",
- "//tensorflow/core/grappler/inputs:all_files",
- "//tensorflow/core/grappler/optimizers:all_files",
- "//tensorflow/core/grappler/utils:all_files",
- "//tensorflow/core/kernels:all_files",
- "//tensorflow/core/kernels/batching_util:all_files",
- "//tensorflow/core/kernels/data:all_files",
- "//tensorflow/core/kernels/data/sql:all_files",
- "//tensorflow/core/kernels/fuzzing:all_files",
- "//tensorflow/core/kernels/hexagon:all_files",
- "//tensorflow/core/kernels/neon:all_files",
- "//tensorflow/core/lib/db:all_files",
- "//tensorflow/core/ops/compat:all_files",
- "//tensorflow/core/platform/cloud:all_files",
- "//tensorflow/core/platform/default/build_config:all_files",
- "//tensorflow/core/platform/hadoop:all_files",
- "//tensorflow/core/platform/s3:all_files",
- "//tensorflow/core/profiler:all_files",
- "//tensorflow/core/profiler/internal:all_files",
- "//tensorflow/core/profiler/internal/advisor:all_files",
- "//tensorflow/core/util/ctc:all_files",
- "//tensorflow/core/util/tensor_bundle:all_files",
- "//tensorflow/examples/adding_an_op:all_files",
- "//tensorflow/examples/android:all_files",
- "//tensorflow/examples/benchmark:all_files",
- "//tensorflow/examples/get_started/regression:all_files",
- "//tensorflow/examples/how_tos/reading_data:all_files",
- "//tensorflow/examples/image_retraining:all_files",
- "//tensorflow/examples/label_image:all_files",
- "//tensorflow/examples/learn:all_files",
- "//tensorflow/examples/multibox_detector:all_files",
- "//tensorflow/examples/saved_model:all_files",
- "//tensorflow/examples/speech_commands:all_files",
- "//tensorflow/examples/tutorials/estimators:all_files",
- "//tensorflow/examples/tutorials/layers:all_files",
- "//tensorflow/examples/tutorials/mnist:all_files",
- "//tensorflow/examples/tutorials/monitors:all_files",
- "//tensorflow/examples/tutorials/word2vec:all_files",
- "//tensorflow/examples/wav_to_spectrogram:all_files",
- "//tensorflow/go:all_files",
- "//tensorflow/java:all_files",
- "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files",
- "//tensorflow/java/src/main/native:all_files",
- "//tensorflow/python:all_files",
- "//tensorflow/python/data:all_files",
- "//tensorflow/python/data/kernel_tests:all_files",
- "//tensorflow/python/data/ops:all_files",
- "//tensorflow/python/data/util:all_files",
- "//tensorflow/python/debug:all_files",
- "//tensorflow/python/eager:all_files",
- "//tensorflow/python/estimator:all_files",
- "//tensorflow/python/feature_column:all_files",
- "//tensorflow/python/keras:all_files",
- "//tensorflow/python/kernel_tests:all_files",
- "//tensorflow/python/kernel_tests/distributions:all_files",
- "//tensorflow/python/kernel_tests/linalg:all_files",
- "//tensorflow/python/kernel_tests/random:all_files",
- "//tensorflow/python/kernel_tests/testdata:all_files",
- "//tensorflow/python/ops/distributions:all_files",
- "//tensorflow/python/ops/linalg:all_files",
- "//tensorflow/python/ops/losses:all_files",
- "//tensorflow/python/profiler:all_files",
- "//tensorflow/python/profiler/internal:all_files",
- "//tensorflow/python/saved_model:all_files",
- "//tensorflow/python/tools:all_files",
- "//tensorflow/tools/api/generator:all_files",
- "//tensorflow/tools/api/golden:all_files",
- "//tensorflow/tools/api/lib:all_files",
- "//tensorflow/tools/api/tests:all_files",
- "//tensorflow/tools/benchmark:all_files",
- "//tensorflow/tools/build_info:all_files",
- "//tensorflow/tools/ci_build/gpu_build:all_files",
- "//tensorflow/tools/common:all_files",
- "//tensorflow/tools/compatibility:all_files",
- "//tensorflow/tools/dist_test/server:all_files",
- "//tensorflow/tools/docker:all_files",
- "//tensorflow/tools/docker/notebooks:all_files",
- "//tensorflow/tools/docs:all_files",
- "//tensorflow/tools/git:all_files",
- "//tensorflow/tools/graph_transforms:all_files",
- "//tensorflow/tools/mlpbtxt:all_files",
- "//tensorflow/tools/proto_text:all_files",
- "//tensorflow/tools/quantization:all_files",
- "//tensorflow/tools/test:all_files",
- "//tensorflow/user_ops:all_files",
- "//third_party/eigen3:all_files",
- "//third_party/fft2d:all_files",
- "//third_party/flatbuffers:all_files",
- "//third_party/hadoop:all_files",
- "//third_party/sycl:all_files",
- "//third_party/sycl/sycl:all_files",
- ],
- visibility = ["//visibility:public"],
-)
-
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 426f97b844..2367014cd0 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -34,6 +34,8 @@ filegroup(
exclude = [
"c_api_experimental.cc",
"c_api_experimental.h",
+ "python_api.cc",
+ "python_api.h",
"*test*",
],
),
@@ -281,20 +283,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
+ # TODO(b/74620627): remove when _USE_C_SHAPES is removed
+ "//tensorflow/python:cpp_shape_inference_proto_cc",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 8df7b56623..a2d96357ac 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -30,6 +30,8 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:execute",
+ "//tensorflow/core/common_runtime/eager:execute_node",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index eaeb2fd07a..028865d360 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -33,6 +33,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
+#include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -434,39 +436,8 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
namespace {
-// TODO(apassos) move to TensorHandle
-tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal(
- tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name,
- TF_Status* status) {
- status->status = ctx->context.GetStatus();
- if (!status->status.ok()) {
- return nullptr;
- }
- tensorflow::Device* dstd = ctx->context.HostCPU();
- if (device_name != nullptr && strlen(device_name) > 0) {
- status->status =
- ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
- if (!status->status.ok()) return nullptr;
- }
- if (ctx->context.Async()) {
- // Note that `h` may not be currently ready. However execution order will
- // make sure that `h` is ready before the copy is actually done.
- tensorflow::CopyToDeviceNode* node =
- new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context);
- tensorflow::TensorHandle* output = node->dst();
- // Note that calling Add makes `node` accessible by the EagerExecutor
- // thread. So further accesses need to be thread-safe.
- ctx->context.ExecutorAdd(node);
- return output;
- } else {
- tensorflow::TensorHandle* output = nullptr;
- status->status = h->CopyToDevice(&ctx->context, dstd, &output);
- return output;
- }
-}
-
tensorflow::Status ValidateInputTypeAndPlacement(
- TFE_Context* ctx, tensorflow::Device* host_device,
+ tensorflow::EagerContext* ctx, tensorflow::Device* host_device,
tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel) {
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
@@ -483,8 +454,8 @@ tensorflow::Status ValidateInputTypeAndPlacement(
const tensorflow::Device* actual_device =
handle_device == nullptr ? host_device : handle_device;
if (expected_device != actual_device) {
- switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
- case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
+ switch (ctx->GetDevicePlacementPolicy()) {
+ case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32:
// TODO(xpan): See if we could bubble python related error up
// to python level.
if (handle->dtype == tensorflow::DT_INT32) {
@@ -493,7 +464,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
break;
}
TF_FALLTHROUGH_INTENDED;
- case TFE_DEVICE_PLACEMENT_EXPLICIT:
+ case tensorflow::DEVICE_PLACEMENT_EXPLICIT:
return tensorflow::errors::InvalidArgument(
"Tensors on conflicting devices:"
" cannot compute ",
@@ -505,7 +476,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
" or transparently copied by using tfe.enable_eager_execution("
"tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
" may slow down your model");
- case TFE_DEVICE_PLACEMENT_WARN:
+ case tensorflow::DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->name << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
@@ -513,17 +484,14 @@ tensorflow::Status ValidateInputTypeAndPlacement(
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
- case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing.
+ case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing.
break;
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
- TF_Status* s = TF_NewStatus();
- tensorflow::TensorHandle* copied_tensor =
- TFE_TensorHandleCopyToDevice_Internal(
- handle, ctx, expected_device->name().c_str(), s);
- tensorflow::Status status = s->status;
- TF_DeleteStatus(s);
+ tensorflow::TensorHandle* copied_tensor = nullptr;
+ tensorflow::Status status = tensorflow::EagerCopyToDevice(
+ handle, ctx, expected_device->name().c_str(), &copied_tensor);
if (!status.ok()) {
if (copied_tensor != nullptr) copied_tensor->Unref();
return tensorflow::errors::Internal(
@@ -574,145 +542,6 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
return nullptr;
}
-tensorflow::Status Execute(
- TFE_Context* ctx, tensorflow::Device* device,
- const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>&
- op_inputs,
- tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
- tensorflow::TensorHandle** retvals, int num_retvals) {
- if (!ctx->context.SoftPlacement() && device == nullptr) {
- device = ctx->context.HostCPU();
- }
-
- if (device == nullptr) {
- // TODO(apassos) debug how the assignment below might return a different
- // device from the one requested above.
- device = kernel->device();
- }
-
- std::vector<tensorflow::Tensor> outputs(1);
- const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
- output_memory_types = &kernel->kernel()->output_memory_types();
- std::vector<tensorflow::Tensor> inputs(op_inputs.size());
- for (int i = 0; i < op_inputs.size(); ++i) {
- const tensorflow::Tensor* input_tensor = nullptr;
- TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
- inputs[i] = *input_tensor;
- }
- // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
- // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def.
- // But knowledge of the implementation
- // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
- // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
- // This is quite subtle. Re-work things to make this better? (Would it make
- // sense for FunctionLibraryRuntime to ensure thread-safe access to
- // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
- // for ops which are a part of functions.
- // TODO(agarwal): change Run to take vector of handles ?
- TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
- if (maybe_stats != nullptr) {
- maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
- maybe_stats->all_start_micros());
- tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
- if (ctx->context.ShouldStoreMetadata()) {
- auto* step_stats = ctx->context.RunMetadataProto()->mutable_step_stats();
- // Lazily initialize the RunMetadata with information about all devices if
- // this is the first call.
- while (step_stats->dev_stats_size() < ctx->context.devices()->size()) {
- step_stats->add_dev_stats();
- }
- // Find the current device's index.
- int device_idx = 0;
- for (int i = 0; i < ctx->context.devices()->size(); ++i) {
- if (ctx->context.devices()->at(i) == device) {
- device_idx = i;
- break;
- }
- }
- // Populate the device stats for this device.
- auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
- dev_stats->set_device(device->name());
- *dev_stats->add_node_stats() = *maybe_stats;
- }
- }
- DCHECK_EQ(num_retvals, outputs.size());
- tensorflow::Device* op_device = IsCPU(device) ? nullptr : device;
- for (int i = 0; i < num_retvals; ++i) {
- tensorflow::Device* d = op_device;
- if (d != nullptr && output_memory_types != nullptr &&
- (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
- d = nullptr;
- }
- if (retvals[i] == nullptr) {
- retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device);
- } else {
- retvals[i]->SetTensorAndDevice(outputs[i], d, op_device);
- }
- }
- return tensorflow::Status::OK();
-}
-
-// TODO(agarwal): move EagerExecutor and EagerNode related code to a separate
-// file.
-class ExecuteNode : public tensorflow::EagerNode {
- public:
- ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel,
- tensorflow::NodeExecStats* maybe_stats,
- const tensorflow::DataTypeVector& output_dtypes,
- TFE_TensorHandle** retvals, int num_retvals)
- : tensorflow::EagerNode(op->ctx->context.NextId()),
- ctx_(op->ctx),
- op_device_(op->device),
- inputs_(op->inputs),
- kernel_(kernel),
- maybe_stats_(maybe_stats),
- retvals_(num_retvals) {
- for (auto handle : inputs_) {
- handle->Ref();
- }
- TFE_Context* ctx = op->ctx;
- for (int i = 0; i < num_retvals; ++i) {
- tensorflow::TensorHandle* h =
- new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context);
- h->Ref();
- retvals[i] = new TFE_TensorHandle(h);
- retvals_[i] = h;
- }
- }
-
- ~ExecuteNode() override {
- for (auto handle : inputs_) {
- handle->Unref();
- }
- for (auto handle : retvals_) {
- handle->Unref();
- }
- }
-
- tensorflow::Status Run() override {
- const tensorflow::Status status =
- Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
- retvals_.begin(), retvals_.size());
- if (status.ok()) {
- return status;
- } else {
- return tensorflow::Status(
- status.code(),
- tensorflow::strings::StrCat("Got error, \"", status.error_message(),
- "\" while executing kernel ",
- kernel_->kernel()->def().DebugString()));
- }
- }
-
- private:
- TFE_Context* ctx_;
- tensorflow::Device* op_device_;
- tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
- tensorflow::KernelAndDevice* kernel_;
- std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_;
- tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals_;
-};
-
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
@@ -1037,8 +866,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// device from the one requested above.
device = kernel->device();
}
- status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(),
- device, op, kernel->kernel());
+ status->status = ValidateInputTypeAndPlacement(
+ &ctx->context, ctx->context.HostCPU(), device, op, kernel->kernel());
if (!status->status.ok()) return;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
if (ctx->context.ShouldStoreMetadata()) {
@@ -1053,18 +882,27 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// Note that for async mode, execution order will make sure that all
// input handles are ready before executing them.
// TODO(agarwal): Consider executing "cheap" kernels inline for performance.
- tensorflow::EagerNode* node =
- new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes,
- retvals, *num_retvals);
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
+ *num_retvals);
+ tensorflow::uint64 id = op->ctx->context.NextId();
+ for (int i = 0; i < *num_retvals; ++i) {
+ tensorflow::TensorHandle* h =
+ new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context);
+ retvals[i] = new TFE_TensorHandle(h);
+ handle_retvals[i] = h;
+ }
+ tensorflow::EagerNode* node = new tensorflow::ExecuteNode(
+ id, &op->ctx->context, op->device, op->inputs, kernel,
+ maybe_stats.release(), output_dtypes, handle_retvals);
ctx->context.ExecutorAdd(node);
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
nullptr);
- status->status =
- Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(),
- handle_retvals.data(), *num_retvals);
+ status->status = tensorflow::EagerExecute(
+ &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(),
+ handle_retvals.data(), *num_retvals);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
@@ -1075,8 +913,9 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
- tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal(
- h->handle, ctx, device_name, status);
+ tensorflow::TensorHandle* handle;
+ status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
+ device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
}
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 2268aba90d..d88a6c1dda 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -688,12 +688,12 @@ TEST(CAPI, Execute_Min_CPU) {
TFE_DeleteOp(minOp);
TFE_DeleteTensorHandle(input);
TFE_DeleteTensorHandle(axis);
- TFE_DeleteContext(ctx, status);
- ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float output[2] = {0};
EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index c7bd3bdafd..97c323b872 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -601,23 +601,28 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
}
CHECK(state.op_tape.empty());
result->reserve(source_tensor_ids.size());
+ gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
for (auto is : source_tensor_ids) {
auto grad_it = gradients.find(is);
if (grad_it == gradients.end()) {
result->push_back(nullptr);
} else {
- if (grad_it->second.size() == 1) {
- result->push_back(grad_it->second[0]);
- } else {
- result->push_back(vspace.AggregateGradients(grad_it->second));
+ if (grad_it->second.size() > 1) {
+ Gradient* grad = vspace.AggregateGradients(grad_it->second);
+ grad_it->second.clear();
+ grad_it->second.push_back(grad);
}
- gradients.erase(grad_it);
+ result->push_back(grad_it->second[0]);
+ used_gradient_ids.insert(is);
}
}
- VLOG(1) << "Final gradients size: " << gradients.size();
+ VLOG(1) << "Final gradients size: "
+ << gradients.size() - used_gradient_ids.size();
for (auto grad_pair : gradients) {
- for (const auto& g : grad_pair.second) {
- vspace.DeleteGradient(g);
+ if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
+ for (const auto& g : grad_pair.second) {
+ vspace.DeleteGradient(g);
+ }
}
}
return Status::OK();
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index cd604538f1..93155998b8 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/python_api.h"
#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
namespace tensorflow {
@@ -109,4 +110,29 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
+std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+ Node* node = &output.oper->node;
+ CppShapeInferenceResult::HandleData handle_data;
+ handle_data.set_is_set(true);
+ {
+ mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ CHECK(ic != nullptr);
+ CHECK_LT(output.index, ic->num_outputs());
+ const auto* shapes_and_types =
+ ic->output_handle_shapes_and_types(output.index);
+ if (shapes_and_types == nullptr) return "";
+
+ for (const auto& p : *shapes_and_types) {
+ auto* out_shape_and_type = handle_data.add_shape_and_type();
+ ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
+ out_shape_and_type->set_dtype(p.dtype);
+ }
+ }
+ string result;
+ handle_data.SerializeToString(&result);
+ return result;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 13b680b3a2..2d4c8cd9ed 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_PYTHON_API_H_
#define TENSORFLOW_C_PYTHON_API_H_
+#include <string>
+
#include "tensorflow/c/c_api.h"
// These functions can be removed without notice. They exist to facilitate some
@@ -51,6 +53,11 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
// the graph after the session has been made aware of them.
void ExtendSession(TF_Session* session, TF_Status* status);
+// Returns the serialized CppShapeInferenceResult::HandleData proto for
+// `output` if its a resource tensor, or otherwise returns the empty string.
+// TODO(b/74620627): remove when _USE_C_SHAPES is removed
+std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 9060c19e9d..079e063d3e 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -620,18 +620,6 @@ tf_cc_binary(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "queue_runner",
srcs = ["training/queue_runner.cc"],
diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc
index 1e0f2d241b..5d9dfd95a5 100644
--- a/tensorflow/cc/framework/cc_op_gen_test.cc
+++ b/tensorflow/cc/framework/cc_op_gen_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -61,12 +62,12 @@ op {
)";
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(s.contains(expected))
+ EXPECT_TRUE(str_util::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
- EXPECT_FALSE(s.contains(expected))
+ EXPECT_FALSE(str_util::StrContains(s, expected))
<< "'" << s << "' contains '" << expected << "'";
}
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 7164249262..c143b97833 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -218,7 +219,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
for (const string& entry : node_constraints) {
StringPiece s(entry);
- if (s.Consume(kColocationGroupPrefix)) {
+ if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
current_constraints.insert(s.ToString());
}
}
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index d29ad3ebcb..06a3be18e0 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -94,18 +94,3 @@ filegroup(
"testdata/half_plus_two/**",
]),
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD
index f5fbc75edc..6f04ebdc55 100644
--- a/tensorflow/cc/saved_model/python/BUILD
+++ b/tensorflow/cc/saved_model/python/BUILD
@@ -7,18 +7,6 @@ package(
default_visibility = ["//visibility:public"],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc")
tf_py_clif_cc(
diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD
index f413a5cc52..6f1c873540 100644
--- a/tensorflow/cc/tools/BUILD
+++ b/tensorflow/cc/tools/BUILD
@@ -41,18 +41,3 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index ffa2d08829..fa03b1f3c2 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -250,17 +250,3 @@ exports_files([
"benchmark_main.template", # used by tf_library(...,gen_benchmark=True)
"test.cc", # used by tf_library(...,gen_test=True)
])
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 972b7d51ec..2642536c4f 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -33,7 +34,7 @@ namespace {
void ExpectErrorContains(const Status& status, StringPiece str) {
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 28aab6eb61..b053dad1b5 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -182,17 +182,3 @@ tf_cc_test(
"//third_party/eigen3",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index e2f01179d4..8ea014c2ee 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -55,7 +55,7 @@ const char kUsageHeader[] =
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
- if (StringPiece(fname).ends_with(".pbtxt")) {
+ if (str_util::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 8e505da622..9ea246ffdc 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -365,20 +365,6 @@ tf_cc_test(
],
)
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD
index 15507b3851..676f71a75a 100644
--- a/tensorflow/compiler/jit/graphcycles/BUILD
+++ b/tensorflow/compiler/jit/graphcycles/BUILD
@@ -27,17 +27,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 616a7f8f15..00a6f4075f 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -41,17 +41,3 @@ cc_library(
],
alwayslink = 1,
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
index 9cd66fc13c..5d211f4d73 100644
--- a/tensorflow/compiler/jit/legacy_flags/BUILD
+++ b/tensorflow/compiler/jit/legacy_flags/BUILD
@@ -63,17 +63,3 @@ cc_library(
"//tensorflow/core:lib",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index e5787ca4c8..c9e46bc147 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -17,17 +17,3 @@ cc_library(
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/jit/xla_tensor_info.h b/tensorflow/compiler/jit/xla_tensor_info.h
index fbd6ad770f..a02f642c37 100644
--- a/tensorflow/compiler/jit/xla_tensor_info.h
+++ b/tensorflow/compiler/jit/xla_tensor_info.h
@@ -71,6 +71,12 @@ class XlaTensorInfoManager : public AllocatorWrapper {
// Creates a new XlaTensorInfoManager, delegating all DeallocateRaw calls to
// allocator.
XlaTensorInfoManager(Allocator* allocator) : AllocatorWrapper(allocator) {}
+ ~XlaTensorInfoManager() {
+ // Destroy the tensor info hashtable under the lock, to ensure all accesses
+ // to the hashtable are properly sequenced.
+ mutex_lock lock(lock_);
+ tensor_infos_.clear();
+ }
// Returns the XlaTensorInfo for the given device memory pointer or nullptr if
// none exists.
diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD
index da4bc44c7a..238fd15166 100644
--- a/tensorflow/compiler/plugin/BUILD
+++ b/tensorflow/compiler/plugin/BUILD
@@ -49,17 +49,3 @@ cc_library(
"//tensorflow/compiler/jit:xla_device",
],
)
-
-#-----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 1c5a8f8e69..edabdc218a 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -835,17 +835,3 @@ tf_xla_py_test(
"//tensorflow/python:platform_test",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 3d3e112f48..a8ab235378 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -600,6 +600,20 @@ class UnaryOpsTest(XLATestCase):
src,
expected=dst)
+ def testBitcast(self):
+ self._assertOpOutputMatchesExpected(
+ lambda x: array_ops.bitcast(x, dtypes.int32),
+ np.array([1, 0x3f800000], np.int32),
+ expected=np.array([1, 0x3f800000], np.int32))
+ self._assertOpOutputMatchesExpected(
+ lambda x: array_ops.bitcast(x, dtypes.float32),
+ np.array([1, 0x3f800000], np.int32),
+ expected=np.array([1e-45, 1.0], np.float32))
+ self._assertOpOutputMatchesExpected(
+ lambda x: array_ops.bitcast(x, dtypes.int32),
+ np.array([1e-45, 1.0], np.float32),
+ expected=np.array([1, 0x3f800000], np.int32))
+
def testInvertPermutation(self):
self._assertOpOutputMatchesExpected(
array_ops.invert_permutation,
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index eb20ca501c..8c33bf179c 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -462,17 +462,3 @@ cc_library(
"//tensorflow/core:protos_all_cc",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index 311dddca94..c30bb9cacd 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -51,17 +51,3 @@ cc_library(
"//tensorflow/core:protos_all_cc",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 0bbfe86de3..f1bc7d6af4 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -217,17 +217,3 @@ cc_library(
],
alwayslink = 1,
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
index 43a6a747c6..c52b2dcb7e 100644
--- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
@@ -62,5 +62,50 @@ class CastOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Cast"), CastOp);
+class BitcastOp : public XlaOpKernel {
+ public:
+ explicit BitcastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &src_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("type", &dst_dtype_));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+ xla::ComputationDataHandle input = ctx->Input(0);
+ xla::ComputationDataHandle output;
+
+ if (src_dtype_ == dst_dtype_) {
+ output = input;
+ } else {
+ // The only complex type in XLA is C64, so error out if the bitcast has a
+ // complex source or destination type and the bitcast is not trivial.
+ OP_REQUIRES(ctx,
+ !xla::primitive_util::IsComplexType(src_type_) &&
+ !xla::primitive_util::IsComplexType(dst_type_),
+ errors::Unimplemented("Complex types not supported."));
+ // XLA bitcast requires that the bit-width of the source and destination
+ // matches, and currently only the simple lowering is performed.
+ OP_REQUIRES(ctx,
+ xla::primitive_util::BitWidth(src_type_) ==
+ xla::primitive_util::BitWidth(dst_type_),
+ errors::Unimplemented(
+ "Only bitcasts between equally sized types supported."));
+ output = builder->BitcastConvertType(input, dst_type_);
+ }
+
+ ctx->SetOutput(0, output);
+ }
+
+ protected:
+ DataType src_dtype_, dst_dtype_;
+ xla::PrimitiveType src_type_, dst_type_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BitcastOp);
+};
+
+REGISTER_XLA_OP(Name("Bitcast"), BitcastOp);
+
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 488fda74bf..344773c8c5 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -140,17 +140,3 @@ cc_library(
"//tensorflow/core:lib",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD
index 98f72b3792..aeb743a663 100644
--- a/tensorflow/compiler/tf2xla/ops/BUILD
+++ b/tensorflow/compiler/tf2xla/ops/BUILD
@@ -39,17 +39,3 @@ tf_gen_op_wrapper_py(
":sendrecv_ops",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index cd13db4d30..751777222f 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -654,18 +654,6 @@ tf_cc_test(
# -----------------------------------------------------------------------------
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code.
cc_header_only_library(
name = "xla_headers_lib",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 5094e5ce67..a299c2afd4 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -214,17 +214,3 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 5ce3c45528..c4c8894374 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -276,7 +276,12 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
if (execution_profile != nullptr) {
*execution_profile = response.profile();
- // TODO(b/74197823): Get execution stats for the graph and VLOG(1) them.
+ if (VLOG_IS_ON(1)) {
+ TF_ASSIGN_OR_RETURN(
+ auto execution_stats,
+ ExecutionStatsAsString(computation, response.profile()));
+ VLOG(1) << execution_stats;
+ }
}
return MakeUnique<GlobalData>(stub_, response.output());
@@ -317,6 +322,12 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
return std::move(outputs);
}
+StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
+ tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) {
+ return Unimplemented(
+ "ExecuteParallel is not yet implemented for XlaComputation.");
+}
+
StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
int64 device_count) {
if (device_count < 1) {
@@ -393,6 +404,27 @@ StatusOr<ComputationStats> Client::GetComputationStats(
return response.stats();
}
+StatusOr<ComputationStats> Client::GetComputationStats(
+ const XlaComputation& computation,
+ const DebugOptions& debug_options) const {
+ ComputationGraphStatsRequest request;
+
+ // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
+ *request.mutable_computation() = computation.proto();
+ *request.mutable_debug_options() = debug_options;
+ ComputationStatsResponse response;
+
+ VLOG(1) << "making computation graph stats request";
+ Status s = stub_->GetComputationGraphStats(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ CHECK(response.has_stats());
+ return response.stats();
+}
+
StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
const Computation& computation) {
GetComputationShapeRequest request;
@@ -410,6 +442,12 @@ StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
return WrapUnique(response.release_program_shape());
}
+StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
+ const XlaComputation& computation) {
+ TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
+ return MakeUnique<ProgramShape>(result);
+}
+
StatusOr<Shape> Client::GetShape(const GlobalData& data) {
GetShapeRequest request;
*request.mutable_data() = data.handle();
@@ -448,6 +486,28 @@ StatusOr<string> Client::ExecutionStatsAsString(
return string("[Execution Statistics] not available.");
}
+StatusOr<string> Client::ExecutionStatsAsString(
+ const XlaComputation& computation, const ExecutionProfile& profile) {
+ TF_ASSIGN_OR_RETURN(
+ auto computation_stats,
+ GetComputationStats(computation,
+ legacy_flags::GetDebugOptionsFromFlags()));
+ int64 total_flops =
+ computation_stats.flop_count() + computation_stats.transcendental_count();
+ if (profile.compute_time_ns() > 0) {
+ int64 nanoseconds = profile.compute_time_ns();
+ int64 cycle_count = profile.compute_cycle_count();
+ double gflops = total_flops / nanoseconds;
+ return tensorflow::strings::StrCat(
+ "[Execution Statistics] flop count: ", computation_stats.flop_count(),
+ ", transcendental count: ", computation_stats.transcendental_count(),
+ ", compute execution time: ", nanoseconds, " nsec",
+ ", compute cycles: ", cycle_count, ", performance: ", gflops,
+ "gflop/s");
+ }
+ return string("[Execution Statistics] not available.");
+}
+
StatusOr<ChannelHandle> Client::CreateChannelHandle() {
CreateChannelHandleRequest request;
CreateChannelHandleResponse response;
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index ec87646ebf..05d707dab1 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -99,6 +99,36 @@ class Client {
StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
tensorflow::gtl::ArraySlice<ComputationInstance> computations);
+ // A struct to represent a computation instance to be executed.
+ // * If execution_options.device_handles is not empty, the computation is
+ // executed on the devices associated with the handles by partitioning the
+ // computation based on the attached sharding attributes. Otherwise, a
+ // device is chosen by the service.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ struct XlaComputationInstance {
+ const XlaComputation& computation;
+ std::vector<GlobalData*> arguments;
+ ExecutionOptions execution_options;
+ ExecutionProfile* execution_profile;
+
+ XlaComputationInstance(const XlaComputation& computation,
+ std::vector<GlobalData*> arguments,
+ ExecutionOptions execution_options,
+ ExecutionProfile* execution_profile)
+ : computation(computation),
+ arguments(std::move(arguments)),
+ execution_options(execution_options),
+ execution_profile(execution_profile) {}
+ };
+
+ // Executes a list XlaComputationInstances and returns global data produced
+ // from each computation.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
+ tensorflow::gtl::ArraySlice<XlaComputationInstance> computations);
+
// Requests device_count device handles available on the target. The returned
// device handles are used to specify the devices to execute the computations
// (see ExecuteParallel) or to transfer data (see TransferToServer or
@@ -175,6 +205,13 @@ class Client {
StatusOr<ComputationStats> GetComputationStats(
const Computation& computation, const DebugOptions& debug_options) const;
+ // Retrieves the statistics of the given computation.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<ComputationStats> GetComputationStats(
+ const XlaComputation& computation,
+ const DebugOptions& debug_options) const;
+
// Returns the Shape of the given array specified by 'data'. The shape
// includes the Layout of the array as it is stored on the service.
StatusOr<Shape> GetShape(const GlobalData& data);
@@ -184,6 +221,13 @@ class Client {
StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape(
const Computation& computation);
+ // As above, but returns the shape of the provided computation (parameter
+ // types/names and return type).
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape(
+ const XlaComputation& computation);
+
// Creates a channel handle that can be used to transfer data between
// two computations via a pair of Send and Recv instructions.
StatusOr<ChannelHandle> CreateChannelHandle();
@@ -197,6 +241,8 @@ class Client {
// ExecutionProfile returned from an execution of the computation.
StatusOr<string> ExecutionStatsAsString(const Computation& computation,
const ExecutionProfile& profile);
+ StatusOr<string> ExecutionStatsAsString(const XlaComputation& computation,
+ const ExecutionProfile& profile);
ServiceInterface* stub_; // Stub that this client is connected on.
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 39d02f0863..4d3b0ee0d6 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -253,26 +253,6 @@ StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
return std::move(*response.mutable_program_shape());
}
-ComputationDataHandle ComputationBuilder::CheckShape(
- const ComputationDataHandle& operand, const Shape& expected_shape) {
- std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie();
- CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
- << "want " << ShapeUtil::HumanString(expected_shape) << " got "
- << ShapeUtil::HumanString(*actual_shape);
- return operand;
-}
-
-void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) {
- std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie();
- std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie();
- VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals "
- << ShapeUtil::HumanString(*rhs_shape);
- CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
- << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
- << ShapeUtil::HumanString(*rhs_shape);
-}
-
ComputationDataHandle ComputationBuilder::Slice(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 2141ebc206..019c6f3afb 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -104,15 +104,6 @@ class ComputationBuilder {
// Retrieves the (inferred) result for the current computation's shape.
StatusOr<ProgramShape> GetProgramShape();
- // Checks that the operand has the given expected shape. Returns the operand
- // if yes, fails with a CHECK error if no.
- ComputationDataHandle CheckShape(const ComputationDataHandle& operand,
- const Shape& expected_shape);
-
- // Checks that the lhs and rhs results have the same shape.
- void CheckSameShape(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs);
-
// Enqueues a constant with the value of the given literal onto the
// computation.
ComputationDataHandle ConstantLiteral(const Literal& literal);
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index fca2bf2688..d02972f2c0 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -48,17 +48,3 @@ cc_library(
"//tensorflow/core:lib",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index cc5f551c9c..b1dba16856 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -70,22 +70,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/core:test",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index fcaf393b6b..e51a8b14c0 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -128,6 +128,18 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape() {
return GetProgramShape(&root_id);
}
+XlaComputation XlaBuilder::BuildAndNoteError() {
+ DCHECK(parent_builder_ != nullptr);
+ auto build_status = Build();
+ if (!build_status.ok()) {
+ parent_builder_->NoteError(
+ AddStatus(build_status.status(),
+ tensorflow::strings::StrCat("error from: ", name_)));
+ return {};
+ }
+ return build_status.ConsumeValueOrDie();
+}
+
StatusOr<XlaComputation> XlaBuilder::Build() {
if (!first_error_.ok()) {
string backtrace;
@@ -357,10 +369,12 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
}
c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
- TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
- ShapeInference::InferCallShape(
- operand_shape_ptrs,
- /*to_apply=*/computation.GetProgramShape()));
+ TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
+ computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferCallShape(operand_shape_ptrs,
+ /*to_apply=*/called_program_shape));
// Add called computation.
instr.add_called_computation_ids(
@@ -491,11 +505,40 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
}
XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ std::vector<const Shape*> operand_shape_ptrs;
+ std::vector<Shape> operand_shapes;
+ for (const XlaOp& e : elements) {
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(e));
+ operand_shapes.push_back(shape);
+ }
+ c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferVariadicOpShape(
+ HloOpcode::kTuple, operand_shape_ptrs));
+ return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
+ }());
}
XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
+ if (!ShapeUtil::IsTuple(tuple_shape)) {
+ return InvalidArgument(
+ "Operand to GetTupleElement() is not a tuple; got %s",
+ ShapeUtil::HumanString(tuple_shape).c_str());
+ }
+ *instr.mutable_shape() =
+ ShapeUtil::GetTupleElementShape(tuple_shape, index);
+
+ instr.set_tuple_index(index);
+
+ return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
+ {tuple_data});
+ }());
}
XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
@@ -916,6 +959,99 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
return UnimplementedOp();
}
+StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand,
+ int64 num_parameters) {
+ return Unimplemented("IsConstant is not implemented.");
+}
+
+StatusOr<std::unique_ptr<Literal>> XlaBuilder::ComputeConstant(
+ const XlaOp& operand, const Layout* output_layout,
+ tensorflow::gtl::ArraySlice<Literal> parameters) {
+ return Unimplemented("ComputeConstant is not implemented");
+}
+
+std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
+ const string& computation_name) {
+ auto sub_builder = MakeUnique<XlaBuilder>(computation_name);
+ sub_builder->parent_builder_ = this;
+ sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
+ return sub_builder;
+}
+
+Status XlaBuilder::SetReturnValue(const XlaOp& operand) {
+ return Unimplemented("SetReturnValue is not implemented.");
+}
+
+/* static */ ConvolutionDimensionNumbers
+XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
+ ConvolutionDimensionNumbers dimension_numbers;
+ dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
+ dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
+ dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
+ dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
+ dimension_numbers.set_kernel_output_feature_dimension(
+ kConvKernelOutputDimension);
+ dimension_numbers.set_kernel_input_feature_dimension(
+ kConvKernelInputDimension);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ dimension_numbers.add_input_spatial_dimensions(i + 2);
+ dimension_numbers.add_kernel_spatial_dimensions(i + 2);
+ dimension_numbers.add_output_spatial_dimensions(i + 2);
+ }
+ return dimension_numbers;
+}
+
+/* static */ Status XlaBuilder::Validate(
+ const ConvolutionDimensionNumbers& dnum) {
+ if (dnum.input_spatial_dimensions_size() < 2) {
+ return FailedPrecondition("input spacial dimension < 2: %d",
+ dnum.input_spatial_dimensions_size());
+ }
+ if (dnum.kernel_spatial_dimensions_size() < 2) {
+ return FailedPrecondition("kernel spacial dimension < 2: %d",
+ dnum.kernel_spatial_dimensions_size());
+ }
+ if (dnum.output_spatial_dimensions_size() < 2) {
+ return FailedPrecondition("output spacial dimension < 2: %d",
+ dnum.output_spatial_dimensions_size());
+ }
+
+ if (std::set<int64>(
+ {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
+ dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
+ .size() != 4) {
+ return FailedPrecondition(
+ "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
+ "%lld)",
+ dnum.input_batch_dimension(), dnum.input_feature_dimension(),
+ dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
+ }
+ if (std::set<int64>({dnum.kernel_output_feature_dimension(),
+ dnum.kernel_input_feature_dimension(),
+ dnum.kernel_spatial_dimensions(0),
+ dnum.kernel_spatial_dimensions(1)})
+ .size() != 4) {
+ return FailedPrecondition(
+ "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
+ "%lld)",
+ dnum.kernel_output_feature_dimension(),
+ dnum.kernel_input_feature_dimension(),
+ dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
+ }
+ if (std::set<int64>({dnum.output_batch_dimension(),
+ dnum.output_feature_dimension(),
+ dnum.output_spatial_dimensions(0),
+ dnum.output_spatial_dimensions(1)})
+ .size() != 4) {
+ return FailedPrecondition(
+ "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
+ "%lld)",
+ dnum.output_batch_dimension(), dnum.output_feature_dimension(),
+ dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
+ }
+ return Status::OK();
+}
+
StatusOr<XlaOp> XlaBuilder::AddInstruction(
HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
@@ -957,7 +1093,7 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
}
XlaOp XlaBuilder::UnimplementedOp() {
- NoteError(Unimplemented("Op not yet implemented"));
+ NoteError(Unimplemented("Op not implemented"));
return {};
}
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index c5c35159e0..f66feb93ce 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -335,6 +335,26 @@ class XlaBuilder {
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers);
+ // Default dimension numbers used for a 2D convolution.
+ static constexpr int64 kConvBatchDimension = 0;
+ static constexpr int64 kConvFeatureDimension = 1;
+ static constexpr int64 kConvFirstSpatialDimension = 2;
+ static constexpr int64 kConvSecondSpatialDimension = 3;
+ static constexpr int64 kConvKernelOutputDimension = 0;
+ static constexpr int64 kConvKernelInputDimension = 1;
+ static constexpr int64 kConvKernelFirstSpatialDimension = 2;
+ static constexpr int64 kConvKernelSecondSpatialDimension = 3;
+
+ // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
+ // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
+ // the kernel operand
+ // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
+ static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
+ int num_spatial_dims = 2);
+
+ // Returns an error if the convolution dimension numbers have conflicts.
+ static Status Validate(const ConvolutionDimensionNumbers& dnum);
+
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
@@ -711,10 +731,59 @@ class XlaBuilder {
const XlaOp& grad_output, float epsilon,
int64 feature_index);
+ // Computes the value of a constant indicated by a XlaOp using a non-optimized
+ // interpreter on the host.
+ //
+ // The operand must represent a constant value, which in this case
+ // means that it must not statically depend on any parameter of the
+ // computation that is being built other then the ones specified on the
+ // parameter list. The parameters in the list will be indexed by their
+ // parameter id property so the number of parameters specified should be at
+ // least as many as the largest used parameter index.
+ //
+ // `IsConstant` can be used to test whether a computation is a compile-time
+ // constant without evaluation it. `ComputeConstant` only succeeds for
+ // computations where `IsConstant` returns true.
+ //
+ // This functionality can be useful when translating a computation
+ // into XLA where something that looked dynamic is required by
+ // XLA to be specified as a constant. E.g. the source
+ // computation (outside of XLA) may include a dynamic
+ // computation of the shape of something and ComputeConstant lets
+ // you determine what the value of that computation is in the case
+ // where the value can be determined at compile time.
+ //
+ // If output_layout is non-null, then the output of the computation
+ // will be stored using that layout.
+ StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ const XlaOp& operand, const Layout* output_layout = nullptr,
+ tensorflow::gtl::ArraySlice<Literal> parameters = {});
+
+ // Returns a new XlaBuilder whose resultant Computation is used only by this
+ // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
+ // behavior as the parent.
+ std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
+
+ // Modifies the computation being built so that executions of it will return
+ // the value associated with operand, rather than the last expression enqueued
+ // on the XlaBuilder. Any subsequent operations added to the XlaBuilder will
+ // not have any effect unless SetReturnValue is called again.
+ Status SetReturnValue(const XlaOp& operand);
+
// Builds the computation with the requested operations, or returns a non-ok
// status.
StatusOr<XlaComputation> Build();
+ // Builds the computation with the requested operations, or notes an error in
+ // the parent XlaBuilder and returns an empty computation if building failed.
+ // This function is intended to be used where the returned XlaComputation is
+ // only used by the parent XlaBuilder and hence further operation on the
+ // returned XlaComputation will simply be error'ed out if an error occurred
+ // while building this computation. If the built computation is to be used by
+ // a XlaBuilder other than the parent XlaBuilder then Build() should be used
+ // instead.
+ XlaComputation BuildAndNoteError();
+
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous
// XlaOp and inform the user of the error that occurred while
@@ -814,6 +883,8 @@ class XlaBuilder {
// Mode bit that indicates whether to die when a first error is encountered.
bool die_immediately_on_error_ = false;
+
+ XlaBuilder* parent_builder_{nullptr};
};
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
index 85d4227ba4..ce984564d0 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -39,7 +40,8 @@ class XlaBuilderTest : public ::testing::Test {
TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build());
const HloModuleProto& proto = computation.proto();
TF_ASSIGN_OR_RETURN(const auto& config,
- HloModule::CreateModuleConfigFromProto(proto));
+ HloModule::CreateModuleConfigFromProto(
+ proto, legacy_flags::GetDebugOptionsFromFlags()));
return HloModule::CreateFromProto(proto, config);
}
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
index 3681792eee..a6752c6010 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
@@ -17,9 +17,12 @@ limitations under the License.
#include <utility>
+#include "tensorflow/compiler/xla/status_macros.h"
+
namespace xla {
-const ProgramShape& XlaComputation::GetProgramShape() const {
+StatusOr<ProgramShape> XlaComputation::GetProgramShape() const {
+ TF_RET_CHECK(proto_.has_program_shape());
return proto_.program_shape();
}
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h
index 5b89747fdd..2a3c695266 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h
@@ -29,6 +29,8 @@ namespace xla {
// TODO(b/74197823): Replace xla::Computation with this one.
class XlaComputation {
public:
+ XlaComputation() : unique_id_(-1) {}
+
XlaComputation(const XlaComputation&) = delete;
XlaComputation& operator=(const XlaComputation&) = delete;
@@ -38,7 +40,8 @@ class XlaComputation {
// Returns the "program shape" (parameter and return shapes) for this
// computation.
- const ProgramShape& GetProgramShape() const;
+ StatusOr<ProgramShape> GetProgramShape() const;
+
const HloModuleProto& proto() const { return proto_; }
private:
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD
index 0a9725db0a..89353448e2 100644
--- a/tensorflow/compiler/xla/legacy_flags/BUILD
+++ b/tensorflow/compiler/xla/legacy_flags/BUILD
@@ -75,17 +75,3 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index e2972f0601..0517a5502e 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -72,15 +72,3 @@ tf_py_wrap_cc(
"//tensorflow/compiler/xla/service:cpu_plugin",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index da16976d06..b7d1bf64d0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -285,6 +285,23 @@ cc_library(
],
)
+tf_cc_test(
+ name = "dfs_hlo_visitor_with_default_test",
+ srcs = ["dfs_hlo_visitor_with_default_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_runner",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "hlo_reachability",
srcs = ["hlo_reachability.cc"],
@@ -1580,6 +1597,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2633,17 +2651,3 @@ cc_library(
"//tensorflow/core:lib",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index f9fabd8a35..0e4624fd69 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1731,18 +1731,29 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
function));
}
- VLOG(10) << "Considering folding Pad: " << operand->ToString()
- << "\ninto reduce-window: " << reduce_window->ToString();
-
// This optimization folds a pad op into reduce_window.
- if (operand->opcode() != HloOpcode::kPad) {
+ HloInstruction* pad;
+ const HloInstruction* convert = nullptr;
+ if (operand->opcode() == HloOpcode::kPad) {
+ pad = operand;
+ } else if (operand->opcode() == HloOpcode::kConvert &&
+ operand->operand(0)->opcode() == HloOpcode::kPad) {
+ convert = operand;
+ pad = operand->mutable_operand(0);
+ } else {
VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
return Status::OK();
}
+ VLOG(10) << "Considering folding Pad: " << pad->ToString()
+ << "\ninto reduce-window: " << reduce_window->ToString()
+ << (convert != nullptr ? tensorflow::strings::StrCat(
+ "\nvia convert: ", convert->ToString())
+ : "");
+
// Do not fold interior padding into ReduceWindow since the backends do not
// support it.
- const PaddingConfig& pad_config = operand->padding_config();
+ const PaddingConfig& pad_config = pad->padding_config();
if (HasInteriorPadding(pad_config)) {
VLOG(10) << "Not folding pad into reduce-window due to interior padding.";
return Status::OK();
@@ -1750,14 +1761,27 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
// If reduce_window already has padding, the pad value of the pad op and the
// init value of reduce_window must match to allow folding the pad.
- const HloInstruction* pad_value = operand->operand(1);
+ const HloInstruction* pad_value = pad->operand(1);
const HloInstruction* reduce_init_value = reduce_window->operand(1);
if (pad_value != reduce_init_value) {
+ auto literals_are_equivalent = [&] {
+ auto& pad_literal = pad_value->literal();
+ auto& reduce_init_literal = reduce_init_value->literal();
+ if (pad_literal == reduce_init_literal) {
+ return true;
+ }
+ auto converted_pad_literal = pad_literal.ConvertToShape(
+ reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+ if (!converted_pad_literal.ok()) {
+ return false;
+ }
+ return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+ };
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
if (pad_value->opcode() != HloOpcode::kConstant ||
reduce_init_value->opcode() != HloOpcode::kConstant ||
- pad_value->literal() != reduce_init_value->literal()) {
+ !literals_are_equivalent()) {
VLOG(10) << "Not folding pad into reduce-window due to different pad "
"values.";
return Status::OK();
@@ -1766,7 +1790,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
// If the pad puts a single non-identity value in each window that we're
// reducing, then this is a broadcast.
- HloInstruction* pad_operand = operand->mutable_operand(0);
+ HloInstruction* pad_operand = pad->mutable_operand(0);
auto is_effective_broadcast = [&] {
if (window_util::HasStride(window)) {
VLOG(10) << "Window has stride.";
@@ -1810,6 +1834,18 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
VLOG(10) << "Found window covers a single unpadded element.";
return true;
};
+
+ HloInstruction* new_reduce_window_operand;
+ if (convert != nullptr) {
+ new_reduce_window_operand =
+ computation_->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(pad_operand->shape(),
+ convert->shape().element_type()),
+ pad_operand));
+ } else {
+ new_reduce_window_operand = pad_operand;
+ }
+
if (is_effective_broadcast()) {
VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast.";
auto fadd = [this](std::unique_ptr<HloInstruction> x) {
@@ -1818,7 +1854,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return ReplaceWithNewInstruction(
reduce_window, HloInstruction::CreateBroadcastSequence(
/*output_shape=*/reduce_window->shape(),
- /*operand=*/pad_operand, fadd));
+ /*operand=*/new_reduce_window_operand, fadd));
}
// Carry out the folding of the pad into reduce_window.
@@ -1835,10 +1871,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
window_dim.set_padding_high(window_dim.padding_high() +
pad_dim.edge_padding_high());
}
+
return ReplaceWithNewInstruction(
reduce_window, HloInstruction::CreateReduceWindow(
/*shape=*/reduce_window->shape(),
- /*operand=*/pad_operand,
+ /*operand=*/new_reduce_window_operand,
/*init_value=*/reduce_window->mutable_operand(1),
/*window=*/new_window,
/*reduce_computation=*/function));
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3b80a827bf..20c549562d 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2338,6 +2338,91 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
}
+// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
+// ReduceWindow(Convert(op), x).
+TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
+ HloModule module(TestName());
+ HloComputation::Builder builder(TestName());
+
+ // Create operand to the pad.
+ HloInstruction* parameter =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));
+
+ // Create the pad.
+ PaddingConfig padding = MakeNoPaddingConfig(4);
+ padding.mutable_dimensions(1)->set_edge_padding_low(1);
+ padding.mutable_dimensions(3)->set_edge_padding_high(2);
+
+ HloInstruction* pad_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
+
+ HloInstruction* convert =
+ builder.AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(pad->shape(), F32), pad));
+
+ // Create add computation.
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module.AddEmbeddedComputation(builder.Build());
+ }
+
+ // Create the reduce-window.
+ Window window;
+ for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+ auto* dim = window.add_dimensions();
+ dim->set_size(1);
+ dim->set_padding_low(10);
+ dim->set_padding_high(100);
+ dim->set_window_dilation(1);
+ dim->set_base_dilation(1);
+ }
+ const Shape reduce_window_shape =
+ ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
+ HloInstruction* reduce_init_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction* reduce_window =
+ builder.AddInstruction(HloInstruction::CreateReduceWindow(
+ reduce_window_shape, convert, reduce_init_value, window,
+ add_computation));
+
+ // Build the computation and run the simplifier.
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, reduce_window);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+ // Running simplification again should not result in any further changes.
+ ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+
+ // Verify the result
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant()));
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
+ << ShapeUtil::HumanString(root->shape()) << " vs "
+ << ShapeUtil::HumanString(reduce_window_shape);
+ EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
+ EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
+ EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
+ EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
+ EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
+ EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
+ EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
+ EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
+}
+
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
HloComputation::Builder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 0faa9e9c41..966e2d0fc5 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -916,17 +916,3 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index ecda5288ee..240faebe62 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -35,6 +35,12 @@ class HloInstruction;
// DfsHloVisitor with default action based on the HloInstruction being visited.
// Users should not use this class directly, but use the type aliases
// DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
+//
+// Do *not* add an override to this class if the opcode is covered by
+// HandleElementwiseUnary/Binary. These opcode handlers dispatch to
+// HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
+// here will break passes which rely on the HandleElementwiseUnary/Binary
+// handling these opcodes.
template <typename HloInstructionPtr>
class DfsHloVisitorWithDefaultBase
: public DfsHloVisitorBase<HloInstructionPtr> {
@@ -70,12 +76,6 @@ class DfsHloVisitorWithDefaultBase
Status HandleConcatenate(HloInstructionPtr concatenate) override {
return DefaultAction(concatenate);
}
- Status HandleConvert(HloInstructionPtr convert) override {
- return DefaultAction(convert);
- }
- Status HandleCopy(HloInstructionPtr copy) override {
- return DefaultAction(copy);
- }
Status HandleSelect(HloInstructionPtr select) override {
return DefaultAction(select);
}
@@ -91,9 +91,6 @@ class DfsHloVisitorWithDefaultBase
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
- Status HandleCompare(HloInstructionPtr compare) override {
- return DefaultAction(compare);
- }
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
}
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc
new file mode 100644
index 0000000000..825e1436f0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class DfsHloVisitorWithDefaultTest : public HloTestBase {};
+
+TEST_F(DfsHloVisitorWithDefaultTest, DefaultElementwiseTest) {
+ // Verify that HandleElementwiseBinary and HandleElementwiseUnary are called
+ // on the appropriate HLO ops (elementwise binary/unary ops).
+
+ class ElementwiseTestVisitor : public DfsHloVisitorWithDefault {
+ public:
+ Status DefaultAction(HloInstruction* hlo) override {
+ // The HLO should be neither an elementwise unary nor binary op. These
+ // cases are handled in HandleElementwiseBinary/Unary.
+ TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 2))
+ << hlo->ToString();
+ TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 1))
+ << hlo->ToString();
+ return Status::OK();
+ }
+
+ Status HandleElementwiseBinary(HloInstruction* hlo) override {
+ // HLO should be elementwise binary.
+ TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 2)
+ << hlo->ToString();
+ return Status::OK();
+ }
+ Status HandleElementwiseUnary(HloInstruction* hlo) override {
+ // HLO should be elementwise unary.
+ TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 1)
+ << hlo->ToString();
+ return Status::OK();
+ }
+ };
+
+ // HLO module contains are arbitrary mix of elementwise and non-elementwise
+ // operations.
+ const string& hlo_string = R"(
+HloModule TestModule
+
+ENTRY TestComputation {
+ arg = f32[] parameter(0)
+ tuple = (f32[]) tuple(arg)
+ gte = f32[] get-tuple-element(tuple), index=0
+ abs = f32[] abs(arg)
+ add = f32[] add(arg, gte)
+ broadcast = f32[42] broadcast(add), dimensions={}
+ slice = f32[0] slice(broadcast), slice={[1:2]}
+ copy = f32[] copy(arg)
+ eq = pred[] equal-to(arg, gte)
+ neg = f32[] negate(arg)
+ ROOT convert = f64[] convert(f32[] arg)
+})";
+ std::unique_ptr<HloModule> module =
+ HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())
+ .ConsumeValueOrDie();
+ ElementwiseTestVisitor visitor;
+ TF_EXPECT_OK(module->entry_computation()->Accept(&visitor));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 93b2f2a474..f1707442fe 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -700,17 +700,3 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 04b37d913e..28f9344795 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -267,16 +267,22 @@ StatusOr<std::unique_ptr<ShapedBuffer>> GpuExecutable::ExecuteOnStream(
++i) {
const BufferAllocation& allocation = assignment_->GetAllocation(i);
if (allocation.is_entry_computation_parameter()) {
- // The caller must give us a buffer for ShapeIndex {} of every parameter.
- // It can optionally give us a buffer for other ShapeIndices, but we
- // ignore them: Because we can't rely on these sub-buffers' addresses
- // being available, our generated code can't use them. Instead, it must
- // chase pointers starting at the tuple root.
- if (allocation.param_shape_index().empty()) {
- auto param_no = allocation.parameter_number();
- buffer_allocations_builder.RegisterBuffer(
- i, arguments[param_no]->root_buffer());
+ auto param_no = allocation.parameter_number();
+ se::DeviceMemoryBase buffer =
+ arguments[param_no]->buffer(allocation.param_shape_index());
+
+ // All top-level buffers and sub-buffers must have an explicit, non-null
+ // pointer, except for zero-sized buffers, which may be null.
+ if (buffer.is_null() && buffer.size() > 0) {
+ return FailedPrecondition(
+ "Cannot run XLA computation because pointer to (sub-)buffer at "
+ "index %s of parameter %lld was null. All pointers to "
+ "(sub-)buffers must not be null, unless the (sub-)buffer has zero "
+ "elements.",
+ allocation.param_shape_index().ToString().c_str(), param_no);
}
+
+ buffer_allocations_builder.RegisterBuffer(i, buffer);
}
}
se::StreamExecutor* executor = run_options->stream()->parent();
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 199e6b7874..d29cc21ab1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -145,37 +145,6 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
}
-// Tries to get a Slice for the given instruction at the given index, but
-// returns nullopt if we might not know the slice's address at runtime without
-// dereferencing a containing tuple.
-//
-// In particular, when XLA accepts a parameter of tuple type, the caller has the
-// option of telling XLA what are the values inside of the tuple, or just giving
-// XLA a pointer to the top-level tuple and letting us chase the pointers on the
-// GPU. We therefore cannot rely having these pointers to parameter sub-buffers
-// being present when we run the program.
-optional<BufferAllocation::Slice> GetKnownAtRuntimeSlice(
- const HloInstruction* instr, const ShapeIndex& index,
- const BufferAssignment& buffer_assn) {
- auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index);
- if (!maybe_slice.ok()) {
- return nullopt;
- }
- // BufferAllocation gives a slice and alloc to every buffer accessed by XLA,
- // but we don't necessarily know the runtime address of sub-buffers of input
- // parameters.
- const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie();
- const BufferAllocation* alloc = slice.allocation();
- if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() &&
- !alloc->param_shape_index().empty()) {
- return nullopt;
- }
-
- // Otherwise, we will know the address of this slice at runtime without having
- // to dereference a tuple.
- return slice;
-}
-
} // namespace
IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
@@ -206,7 +175,7 @@ bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment,
return hlo.opcode() == HloOpcode::kCopy &&
hlo.operand(0)->opcode() == HloOpcode::kConstant &&
ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
- GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value();
+ buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok();
}
bool ImplementedAsDeviceToDeviceMemcpy(
@@ -216,13 +185,13 @@ bool ImplementedAsDeviceToDeviceMemcpy(
//
// 1. `hlo` is a kCopy instruction.
// 2. `hlo` and its operand have the same shape (thus the same layout too).
- // 3. The operand to `hlo` has a buffer assignment (constants do not, for
- // instance) which means the source buffer also resides on the device.
+ // 3. `hlo` and its operand have a statically-known buffer assignment
+ // (constants do not, for instance), which means the source buffer also
+ // resides on the device.
return hlo.opcode() == HloOpcode::kCopy &&
ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
- GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() &&
- GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment)
- .has_value();
+ buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() &&
+ buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok();
}
} // namespace
@@ -1959,49 +1928,54 @@ GetHloBufferSlices(const HloInstruction* hlo,
-> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
// Simple, common case: Is the buffer for instr known at runtime? If so,
// we're done.
- auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, ShapeIndex()}};
+ auto slice = buffer_assn.GetUniqueSlice(instr, index);
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), ShapeIndex()}};
}
- // If we don't know the buffer for instr at index, see if we know the buffer
- // for instr at index without its last element. If so, we can dynamically
- // find the buffer for instr by dereferencing a pointer in that buffer.
- // Continue looking this way until we run out of elements in 'index'.
- ShapeIndex new_index = index;
- ShapeIndex gte_indices;
- while (!new_index.empty()) {
- gte_indices.push_front(new_index.back());
- new_index.pop_back();
- auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, gte_indices}};
- }
- }
-
- // If *that* didn't work, walk up any bitcasts that we might see. These
- // must appear before any GTE instructions, because it's illegal to bitcast
- // to a tuple type.
+ // If that didn't work, walk up any bitcasts that we might see. These must
+ // appear before any GTE instructions, because it's illegal to bitcast to a
+ // tuple type.
const HloInstruction* parent = instr;
while (parent->opcode() == HloOpcode::kBitcast) {
parent = parent->operand(0);
- auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, gte_indices}};
+ auto slice = buffer_assn.GetUniqueSlice(parent, {});
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), ShapeIndex()}};
}
}
- // Finally, check whether instr is a GTE instruction. If it is, see if we
- // can get a buffer for its parent, and continue walking up parents until we
- // find a defined buffer or we hit something that's not a GTE.
+ // Check whether instr is a GTE instruction. If it is, see if we can get a
+ // buffer for its parent, and continue walking up parents until we find a
+ // defined buffer or we hit something that's not a GTE.
+ ShapeIndex gte_indices;
while (parent->opcode() == HloOpcode::kGetTupleElement) {
gte_indices.push_front(parent->tuple_index());
parent = parent->operand(0);
- auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn);
- if (slice.has_value()) {
- return {{*slice, gte_indices}};
+ auto slice = buffer_assn.GetUniqueSlice(parent, {});
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), gte_indices}};
+ }
+ }
+
+ // Finally, if we don't know the buffer for instr at index, see if we know
+ // the buffer for instr at index without its last element. If so, we can
+ // dynamically find the buffer for instr by dereferencing a pointer in that
+ // buffer. Continue looking this way until we run out of elements in
+ // 'index'.
+ //
+ // We can almost always get a buffer without resorting to this. The only
+ // exception is for cases where the relevant sub-buffer is truly unknowable,
+ // for example the sub-buffer of a tuple-shaped select.
+ ShapeIndex new_index = index;
+ while (!new_index.empty()) {
+ gte_indices.push_front(new_index.back());
+ new_index.pop_back();
+ auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
+ if (slice.ok()) {
+ return {{slice.ValueOrDie(), gte_indices}};
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
index f4c4dcdafd..86c4ac18b0 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -68,17 +68,3 @@ tf_cc_test(
"@llvm//:support",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 4ec2ef27bf..44e4f75f75 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -379,20 +380,101 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
}
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
- auto rhs_instruction = convolution->operand(1);
+ auto lhs = convolution->operand(0);
+ auto rhs = convolution->operand(1);
+ Window window = convolution->window();
+ const auto& result_shape = convolution->shape();
+ const Shape& lhs_shape = lhs->shape();
+ const Shape& rhs_shape = rhs->shape();
+
const auto& dnums = convolution->convolution_dimension_numbers();
- const int64 output_features =
- convolution->shape().dimensions(dnums.output_feature_dimension());
-
- // For each output element, we do one fma per element in the kernel at some
- // given output feature index.
- const int64 fmas_per_output_element =
- output_features > 0
- ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features
- : 0;
- const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
- current_properties_[kFlopsKey] =
- output_elements * fmas_per_output_element * kFmaFlops;
+
+ const int64 input_batch_dim = dnums.input_batch_dimension();
+ const int64 input_feature_dim = dnums.input_feature_dimension();
+ const int64 output_feature_dim = dnums.output_feature_dimension();
+ const int64 input_feature =
+ ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
+ const int64 output_feature =
+ ShapeUtil::GetDimension(result_shape, output_feature_dim);
+ const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);
+
+ DimensionVector kernel_limits;
+ DimensionVector output_limits;
+ DimensionVector input_limits;
+ if (window.dimensions().empty()) {
+ window = window_util::MakeWindow({1});
+ kernel_limits.push_back(1);
+ output_limits.push_back(1);
+ input_limits.push_back(1);
+ } else {
+ for (int64 spatial_dimension = 0;
+ spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
+ // Spatial dimension number for kernel (rhs).
+ const int64 kernel_spatial_dim =
+ dnums.kernel_spatial_dimensions(spatial_dimension);
+ const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
+ kernel_limits.push_back(kernel_limit);
+
+ // Spatial dimension number for output.
+ const int64 output_spatial_dim =
+ dnums.output_spatial_dimensions(spatial_dimension);
+ const int64 output_limit = result_shape.dimensions(output_spatial_dim);
+ output_limits.push_back(output_limit);
+
+ // Spatial dimension number for input (lhs).
+ const int64 input_spatial_dim =
+ dnums.input_spatial_dimensions(spatial_dimension);
+ const int64 input_limit = lhs_shape.dimensions(input_spatial_dim);
+ input_limits.push_back(input_limit);
+ }
+ }
+
+ DimensionVector valid_position_counts;
+
+ // Loop over each spatial dimension.
+ for (int64 spatial_dimension = 0;
+ spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
+ int64 valid_position_count = 0;
+ // Loop over each point in the kernel.
+ for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
+ ++kernel_idx) {
+ // Loop over each point in the output.
+ for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension];
+ ++output_idx) {
+ // Calculate lhs (input) index without taking base dilation into
+ // account.
+ const auto& window_dim = window.dimensions(spatial_dimension);
+ const int64 undilated_index = output_idx * window_dim.stride() -
+ window_dim.padding_low() +
+ kernel_idx * window_dim.window_dilation();
+
+ // Calculate the actual lhs (input) index after dilation. Avoid the
+ // division as an optimization.
+ const int64 lhs_spatial_index =
+ window_dim.base_dilation() > 1
+ ? undilated_index / window_dim.base_dilation()
+ : undilated_index;
+
+ // Skip if the lhs (input) index is to be dilated.
+ if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) {
+ continue;
+ }
+
+ // Skip if input index is not in bound.
+ if (lhs_spatial_index < 0 ||
+ lhs_spatial_index >= input_limits[spatial_dimension]) {
+ continue;
+ }
+
+ valid_position_count += 1;
+ }
+ }
+ valid_position_counts.push_back(valid_position_count);
+ }
+
+ const int64 fma_count =
+ input_feature * output_feature * batch * Product(valid_position_counts);
+ current_properties_[kFlopsKey] = fma_count * kFmaFlops;
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 3b289c240a..3d055b327e 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -186,12 +186,14 @@ TEST_F(HloCostAnalysisTest, Map) {
TEST_F(HloCostAnalysisTest, Convolution) {
ComputationBuilder builder(client_, "convolution");
auto input = builder.Parameter(
- 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
- /*x_dim=*/20}),
+ 0,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
+ /*x_dim=*/20}),
"input");
auto kernel = builder.Parameter(
- 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
- /*x_dim=*/3}),
+ 1,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
"kernel");
auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid);
@@ -440,5 +442,32 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
+TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
+ ComputationBuilder builder(client_, "BaseDilatedConvolution");
+ auto input = builder.Parameter(
+ 0,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
+ /*x_dim=*/20}),
+ "input");
+ auto kernel = builder.Parameter(
+ 1,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
+ "kernel");
+
+ auto result = builder.ConvGeneralDilated(
+ input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}},
+ /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers(2));
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ EXPECT_EQ(analysis.flop_count(), 1472);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 1dc72355cf..25702dc65e 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -823,7 +823,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// Otherwise, print e.g. "%constant.42 (s32[100])".
string constant_name;
- if (tensorflow::StringPiece(constant->name()).starts_with("constant")) {
+ if (tensorflow::str_util::StartsWith(constant->name(), "constant")) {
constant_name = constant->name();
} else {
constant_name = StrCat("constant ", constant->name());
@@ -1041,8 +1041,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
// The HLO instruction name contains usually the opcode, e.g. "%add.42" is
// an add instruction. In this case we render just the name.
- if (tensorflow::StringPiece(instr->name())
- .starts_with(HloOpcodeString(instr->opcode()))) {
+ if (tensorflow::str_util::StartsWith(instr->name(),
+ HloOpcodeString(instr->opcode()))) {
return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
}
string extended_opcode =
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 595c531ccf..08b9a29aed 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -295,12 +295,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
- const HloModuleProto& module) {
+ const HloModuleProto& module, const DebugOptions& debug_options) {
TF_RET_CHECK(module.has_program_shape())
<< "No program shape found in the proto";
const auto& program_shape = module.program_shape();
HloModuleConfig module_config(program_shape);
+ module_config.set_debug_options(debug_options);
// The module config is constructed with default layouts regardless of what is
// passed in via the ProgramShape. Set the layouts to the appropriate values.
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 755bbd359f..9f7f25202b 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -172,7 +172,7 @@ class HloModule {
// Creates and returns an HloModuleConfig with an appropriate program shape
// for the HLO module in the given proto.
static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
- const HloModuleProto& module);
+ const HloModuleProto& module, const DebugOptions& debug_options);
// Outlines the given expression from the given computation.
// instructions_to_outline contains the instructions that form the expression.
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index e5b1c2efa3..ec7d8210a7 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -52,10 +52,9 @@ namespace {
// Creates an HloModule from the given proto.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
const HloProto& proto, const DebugOptions& debug_options) {
- TF_ASSIGN_OR_RETURN(
- HloModuleConfig config,
- HloModule::CreateModuleConfigFromProto(proto.hlo_module()));
- config.set_debug_options(debug_options);
+ TF_ASSIGN_OR_RETURN(HloModuleConfig config,
+ HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
+ debug_options));
TF_ASSIGN_OR_RETURN(auto module,
HloModule::CreateFromProto(proto.hlo_module(), config));
return std::move(module);
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 0db3863f24..4550548495 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -120,14 +120,3 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 37261ed1e6..f1e7fc2953 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -169,17 +169,3 @@ cc_library(
"@llvm//:core",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 1d379f0d03..ca8071b7bb 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -837,6 +837,11 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
return tensorflow::Status::OK();
}
+tensorflow::Status Service::ExecuteGraphParallel(
+ const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) {
+ return Unimplemented("execute-graph-parallel is not yet implemented");
+}
+
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) {
const int64 available_device_count = execute_backend_->device_count();
@@ -1445,6 +1450,29 @@ tensorflow::Status Service::GetComputationStats(
return tensorflow::Status::OK();
}
+tensorflow::Status Service::GetComputationGraphStats(
+ const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
+ HloModuleConfig config;
+ config.set_debug_options(arg->debug_options());
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(arg->computation(), config));
+
+ hlo_graph_dumper::MaybeDumpHloModule(*module,
+ "computation statistics subject");
+
+ // Run HLO analysis to get the computation statistics.
+ HloCostAnalysis analysis(
+ execute_backend_->compiler()->ShapeSizeBytesFunction());
+
+ TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
+
+ ComputationStats stats;
+ stats.set_flop_count(analysis.flop_count());
+ stats.set_transcendental_count(analysis.transcendental_count());
+ *result->mutable_stats() = stats;
+ return tensorflow::Status::OK();
+}
+
template <typename RequestT, typename ResponseT>
tensorflow::Status Service::AddInstruction(
const RequestT* arg, ResponseT* result,
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 773f0a642d..ebe4a2e043 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -126,6 +126,15 @@ class Service : public ServiceInterface {
tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
ExecuteParallelResponse* result) override;
+ // Executes one or more computations in parallel with the provided global data
+ // passed as immutable arguments. Returns global data output for each
+ // computation.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ tensorflow::Status ExecuteGraphParallel(
+ const ExecuteGraphParallelRequest* arg,
+ ExecuteParallelResponse* result) override;
+
// Requests one or more device handles from the target.
//
// When N device handles are requested and the number of replicas is R, at
@@ -224,6 +233,13 @@ class Service : public ServiceInterface {
const ComputationStatsRequest* arg,
ComputationStatsResponse* result) override;
+ // Retrieves the statistics of a computation.
+ //
+ // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+ tensorflow::Status GetComputationGraphStats(
+ const ComputationGraphStatsRequest* arg,
+ ComputationStatsResponse* result) override;
+
// Snapshots the current state of a computation handle into a serializable
// protocol buffer form, so it can be loaded via
// LoadComputationSnapshot.
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 36456d552d..77e12d3602 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1070,6 +1070,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (const HloInstruction* operand : operands) {
operand_shapes.push_back(&operand->shape());
}
+ return InferVariadicOpShape(opcode, operand_shapes);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
+ HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
operand_shapes);
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 88830e6d25..9da2c99b41 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -85,6 +85,9 @@ class ShapeInference {
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ static StatusOr<Shape> InferVariadicOpShape(
+ HloOpcode opcode,
tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
// Infers the shape produced by applying the given mapping computation shape
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 0dca30a804..fcdb2e01fb 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -1284,8 +1284,8 @@ StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
}
- if (tensorflow::StringPiece(custom_call_request.call_target_name())
- .starts_with("$")) {
+ if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(),
+ "$")) {
return InvalidArgument(
"Invalid custom_call_target \"%s\": Call targets that start with '$' "
"are reserved for internal use.",
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index d8235113dd..32aae64973 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -60,6 +60,10 @@ class ServiceInterface {
virtual tensorflow::Status ExecuteParallel(
const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0;
+ virtual tensorflow::Status ExecuteGraphParallel(
+ const ExecuteGraphParallelRequest* arg,
+ ExecuteParallelResponse* result) = 0;
+
virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) = 0;
@@ -72,6 +76,10 @@ class ServiceInterface {
virtual tensorflow::Status GetComputationStats(
const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0;
+ virtual tensorflow::Status GetComputationGraphStats(
+ const ComputationGraphStatsRequest* arg,
+ ComputationStatsResponse* result) = 0;
+
virtual tensorflow::Status GetComputationShape(
const GetComputationShapeRequest* arg,
GetComputationShapeResponse* result) = 0;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5ab25f2264..e337669aeb 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1011,6 +1011,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1958,17 +1960,3 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 28ab965499..af8af99c79 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -69,6 +69,17 @@ class BatchNormalizationTest
CHECK_EQ(kY, input_array_.width());
}
+ ComputationDataHandle CheckShape(ComputationBuilder* b,
+ const ComputationDataHandle& operand,
+ const Shape& expected_shape) const {
+ std::unique_ptr<Shape> actual_shape =
+ b->GetShape(operand).ConsumeValueOrDie();
+ CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
+ << "want " << ShapeUtil::HumanString(expected_shape) << " got "
+ << ShapeUtil::HumanString(*actual_shape);
+ return operand;
+ }
+
static constexpr int64 kSamples = 3;
static constexpr int64 kX = 1;
static constexpr int64 kY = 1;
@@ -164,14 +175,15 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
ComputationBuilder builder(client_, "batch_normalize_per_spec");
auto input_activations =
- builder.CheckShape(builder.ConstantLiteral(input_literal_),
- ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
+ CheckShape(&builder, builder.ConstantLiteral(input_literal_),
+ ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
auto gamma = builder.ConstantR1<float>({1.0, 1.0});
auto beta = builder.ConstantR1<float>({0.0, 0.0});
Computation add = CreateScalarAddComputation(F32, &builder);
// Reduce all dimensions except dimension 1.
Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
- auto sum = builder.CheckShape(
+ auto sum = CheckShape(
+ &builder,
builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
@@ -187,14 +199,16 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
auto activation_deviations = builder.Sub(input_activations, set_means,
/*broadcast_dimensions=*/{1});
auto dev_squares = builder.SquareF32(activation_deviations);
- auto sum_of_squares = builder.CheckShape(
+ auto sum_of_squares = CheckShape(
+ &builder,
builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
auto variance = builder.Div(sum_of_squares, count);
auto standard_deviation = builder.SqrtF32(variance);
- auto standard_deviation_above_epsilon = builder.CheckShape(
- builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
+ auto standard_deviation_above_epsilon =
+ CheckShape(&builder, builder.Gt(standard_deviation, epsilon),
+ ShapeUtil::MakeShape(PRED, {2}));
auto gt_eps = builder.Select(standard_deviation_above_epsilon,
standard_deviation, epsilon2);
auto normalization_factors = builder.ReciprocalF32(gt_eps);
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index ec95a68ead..4a9faef1dc 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -441,8 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
EXPECT_EQ(expected, actual->GetR1U8AsString());
}
+template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
@@ -453,8 +454,9 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
LiteralTestUtil::ExpectEqual(expected, *actual);
}
+template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
@@ -619,4 +621,20 @@ template void ClientLibraryTestBase::ComputeAndCompareLiteral(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout);
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ XlaBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+ XlaBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 5ff200be03..be90f14c8e 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -217,11 +217,13 @@ class ClientLibraryTestBase : public ::testing::Test {
// Convenience method for running a built computation, transferring the
// result, and comparing it to the expected tuple literal.
+ template <typename BuilderT>
void ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename BuilderT>
void ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
+ BuilderT* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
// Convenience method for running a built computation and comparing the result
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index ec2c580670..e5a03b49ad 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -167,8 +168,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) {
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
- EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
- .contains("depends on a parameter"))
+ EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(),
+ "depends on a parameter"))
<< value.status();
}
}
@@ -183,8 +184,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
- EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
- .contains("depends on a parameter"))
+ EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(),
+ "depends on a parameter"))
<< value.status();
}
}
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 09b1dd283e..7b994a4c17 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -54,6 +54,25 @@ using TypesF16F32F64CF64 =
#error "Situation not handled yet"
#endif
+// Check that we can safely pass an input tuple's elements to a dot operation.
+TEST_F(DotOperationTest, DotOfInputTupleElem) {
+ ComputationBuilder builder(client_, TestName());
+
+ ComputationDataHandle param;
+ auto param_data = CreateParameterAndTransferLiteral(
+ 0,
+ *Literal::MakeTuple({Literal::CreateR2<float>({{1, 2}, {3, 4}}).get(),
+ Literal::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ "arg0", &builder, &param);
+ auto lhs = builder.GetTupleElement(param, 0);
+ auto rhs = builder.GetTupleElement(param, 1);
+ builder.Dot(lhs, rhs);
+
+ ComputeAndCompareLiteral(&builder,
+ *Literal::CreateR2<float>({{19, 22}, {43, 50}}),
+ {param_data.get()});
+}
+
template <typename T>
class DotOperationTest_F16F32F64CF64 : public DotOperationTest {};
TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index fa60af4b6a..098be6d7aa 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -41,7 +43,7 @@ class TupleTest : public ClientLibraryTestBase {
// Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest, TupleConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
@@ -54,13 +56,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
Literal::CreateR1<float>(constant_vector).get(),
Literal::CreateR2<float>(constant_matrix).get()});
- auto result = builder.ConstantLiteral(*value);
+ builder.ConstantLiteral(*value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest, TupleScalarConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
@@ -68,13 +70,13 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
Literal::CreateR0<float>(constant_scalar2).get()});
- auto result = builder.ConstantLiteral(*value);
+ builder.ConstantLiteral(*value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreate) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
@@ -82,9 +84,9 @@ XLA_TEST_F(TupleTest, TupleCreate) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar),
- builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
+ builder.Tuple({builder.ConstantR0<float>(constant_scalar),
+ builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
auto expected =
Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
@@ -95,9 +97,9 @@ XLA_TEST_F(TupleTest, TupleCreate) {
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(),
@@ -107,15 +109,15 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
// Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
- ComputationBuilder builder(client_, TestName());
- auto result = builder.Tuple({});
+ XlaBuilder builder(TestName());
+ builder.Tuple({});
auto expected = Literal::MakeTuple({});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElement) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -123,23 +125,23 @@ XLA_TEST_F(TupleTest, GetTupleElement) {
};
auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
builder.ConstantR2<float>(constant_matrix)});
- auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ builder.GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto tuple_data = builder.Tuple(
{builder.ConstantR1<float>({}),
builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
- auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ builder.GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto value = builder.ConstantR1<float>({4.5f});
builder.GetTupleElement(value, 1);
auto result_status = builder.Build();
@@ -152,7 +154,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
// Extracts both elements from a tuple with GetTupleElement and then adds them
// together.
XLA_TEST_F(TupleTest, AddTupleElements) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -164,22 +166,22 @@ XLA_TEST_F(TupleTest, AddTupleElements) {
auto matrix_element = builder.GetTupleElement(tuple_data, 1);
auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
- auto result = builder.Add(matrix_element, vector_element,
- /*broadcast_dimensions=*/{1});
+ builder.Add(matrix_element, vector_element,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{2.f, 4.f, 6.f}, // row 0
{5.f, 7.f, 9.f}, // row 1
});
- ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3}));
- ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3}));
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// Extracts both elements from a tuple and then puts them into a new tuple in
// the opposite order.
XLA_TEST_F(TupleTest, TupleGTEToTuple) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -187,8 +189,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
};
auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
builder.ConstantR2<float>(constant_matrix)});
- auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
- builder.GetTupleElement(tuple_data, 0)});
+ builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+ builder.GetTupleElement(tuple_data, 0)});
auto expected =
Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(),
Literal::CreateR1<float>(constant_vector).get()});
@@ -196,8 +198,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
- ComputationBuilder b(client_, TestName());
- ComputationDataHandle v1, v2;
+ XlaBuilder b(TestName());
+ XlaOp v1, v2;
for (bool direction : {false, true}) {
std::unique_ptr<GlobalData> v1_data =
@@ -210,7 +212,7 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v2_gt = b.Gt(v2, v1); // true
auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true}
auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false}
- auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
+ b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
auto expected =
Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
Literal::CreateR0<bool>(!direction).get()});
@@ -237,7 +239,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
// \ (tuple10)-- /
// \ / \ /
// -----(GTE 0)-- --(GTE 1)----------
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
@@ -257,8 +259,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
auto addvectors = builder.Add(vector_from_01, vector_from_10);
auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
- auto result = builder.Add(addmatrices, addvectors,
- /*broadcast_dimensions=*/{1});
+ builder.Add(addmatrices, addvectors,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{4.f, 8.f, 12.f}, // row 0
@@ -269,7 +271,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
// Tests a selection between tuples with "false" path taken.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -278,8 +280,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
auto tuple21 = builder.Tuple(
{builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
- auto select =
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
Literal::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
@@ -314,7 +315,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
// Tests a selection between tuples with "true" path taken.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -323,8 +324,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
auto tuple21 = builder.Tuple(
{builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
- auto select =
- builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(),
Literal::CreateR1<float>(vec2).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
@@ -333,7 +333,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
// Tests a selection between tuples but the final result is an element of the
// tuple, not the whole tuple.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -344,7 +344,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
auto select =
builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
- auto element = builder.GetTupleElement(select, 0);
+ builder.GetTupleElement(select, 0);
ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
}
@@ -368,7 +368,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
// / --(GTE 1)--
// /
// (tuple 21)
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -384,8 +384,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
auto select2 =
builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
- auto result = builder.Add(builder.GetTupleElement(select2, 0),
- builder.GetTupleElement(select2, 1));
+ builder.Add(builder.GetTupleElement(select2, 0),
+ builder.GetTupleElement(select2, 1));
ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
@@ -394,7 +394,7 @@ XLA_TEST_F(TupleTest,
DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
// Similar to SelectBetweenTuples, but the constants are shared between the
// input tuples.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -403,19 +403,18 @@ XLA_TEST_F(TupleTest,
auto tuple12 = builder.Tuple({c1, c2});
auto tuple21 = builder.Tuple({c2, c1});
- auto select =
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
Literal::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto inner_tuple = builder.Tuple(
{builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
- auto outer_tuple =
- builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
+ builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0});
auto expected_s = Literal::CreateR0<float>(42.0);
@@ -429,7 +428,7 @@ XLA_TEST_F(TupleTest, NestedTuples) {
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {3});
Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
@@ -460,7 +459,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
}
XLA_TEST_F(TupleTest, ComplexTuples) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape c64r0 = ShapeUtil::MakeShape(C64, {});
Shape c64r1 = ShapeUtil::MakeShape(C64, {2});
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 24b9f37a80..ff3418a128 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -294,7 +295,8 @@ XLA_TEST_F(HloProfileTest,
auto while_body_profile_start =
std::find_if(profile_output_lines.begin(), profile_output_lines.end(),
[](tensorflow::StringPiece s) {
- return s.starts_with("Execution profile for body");
+ return tensorflow::str_util::StartsWith(
+ s, "Execution profile for body");
});
ASSERT_NE(while_body_profile_start, profile_output_lines.end());
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 2e55f609d1..0bc4045a54 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -223,17 +223,3 @@ tf_cc_binary(
"//tensorflow/core:lib",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD
index 97aacf6b39..0fa4b98d0a 100644
--- a/tensorflow/compiler/xla/tools/parser/BUILD
+++ b/tensorflow/compiler/xla/tools/parser/BUILD
@@ -70,17 +70,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-# -----------------------------------------------------------------------------
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index edf1b07af8..5cb18113e5 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -299,6 +299,11 @@ message ComputationStatsRequest {
DebugOptions debug_options = 2;
}
+message ComputationGraphStatsRequest {
+ HloModuleProto computation = 1;
+ DebugOptions debug_options = 2;
+}
+
message ComputationStatsResponse {
ComputationStats stats = 1;
}
@@ -355,6 +360,10 @@ message ExecuteParallelRequest {
repeated ExecuteRequest requests = 1;
}
+message ExecuteGraphParallelRequest {
+ repeated ExecuteGraphRequest requests = 1;
+}
+
message ExecuteResponse {
GlobalDataHandle output = 1;
ExecutionProfile profile = 2;
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index fb81b50fe8..c211ad8b9b 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -174,15 +174,3 @@ cc_library(
"//conditions:default": [],
}),
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD
index 8dff93b4f8..62d1b1cf07 100644
--- a/tensorflow/contrib/all_reduce/BUILD
+++ b/tensorflow/contrib/all_reduce/BUILD
@@ -45,16 +45,3 @@ tf_py_test(
"//tensorflow/python:state_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "g3doc/sitemap.md",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 6658f0d9c1..8add2aacff 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -38,16 +38,15 @@ def _flatten_tensors(tensors):
shape: the original shape of each element of input tensors
Raises:
- ValueError: tensors are empty or non-isomorphic.
+ ValueError: tensors are empty or non-isomorphic or have unknown shape.
"""
if not tensors:
raise ValueError("tensors cannot be empty")
shape = tensors[0].shape
for tensor in tensors:
shape = shape.merge_with(tensor.shape)
- if shape.ndims is None:
- raise ValueError("At least one of the tensors in 'tensors' must have "
- "statically known rank.")
+ if not shape.is_fully_defined():
+ raise ValueError("Tensors must have statically known shape.")
if len(shape) != 1:
reshaped = []
for t in tensors:
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index 47bab0a367..b3f5d92259 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -36,6 +36,12 @@ from tensorflow.python.platform import tf_logging
class AllReduceTest(test_util.TensorFlowTestCase):
+ def testFlattenTensorsShapesDefined(self):
+ x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
+ with self.assertRaisesRegexp(ValueError,
+ "must have statically known shape"):
+ ar._flatten_tensors([x, x])
+
def testRingPermutations(self):
# 0 devices
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, [])
diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD
index 4bff3c27d2..60306ebdc6 100644
--- a/tensorflow/contrib/android/BUILD
+++ b/tensorflow/contrib/android/BUILD
@@ -38,20 +38,6 @@ cc_library(
alwayslink = 1,
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "bin/**",
- "gen/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# JAR with Java bindings to TF.
android_library(
name = "android_tensorflow_inference_java",
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
index 4ab32ee47d..c6af0e4d13 100644
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ b/tensorflow/contrib/autograph/utils/builtins.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
import six
from tensorflow.contrib.autograph.utils import py_func
@@ -97,7 +99,13 @@ def dynamic_print(*values):
if all(map(is_tf_print_compatible, values)):
return logging_ops.Print(1, values)
- return py_func.wrap_py_func(print, None, values, use_dummy_return=True)
+
+ def flushed_print(*vals):
+ print(*vals)
+ sys.stdout.flush()
+
+ return py_func.wrap_py_func(
+ flushed_print, None, values, use_dummy_return=True)
def dynamic_dataset(iterated):
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index ee67909133..d65c990c87 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -112,14 +112,3 @@ py_test(
"//tensorflow/python:script_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD
index 6db627faad..7cb2d8079b 100644
--- a/tensorflow/contrib/batching/test_util/BUILD
+++ b/tensorflow/contrib/batching/test_util/BUILD
@@ -8,17 +8,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
-
cc_library(
name = "fake_clock_env",
testonly = 1,
diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD
index 2a84a7712a..8f81b6702f 100644
--- a/tensorflow/contrib/batching/util/BUILD
+++ b/tensorflow/contrib/batching/util/BUILD
@@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "**/google_*",
- ],
- ),
-)
-
cc_library(
name = "periodic_function_dynamic",
hdrs = ["periodic_function.h"],
diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD
index a55029b314..5a2d7f6a3c 100644
--- a/tensorflow/contrib/bayesflow/BUILD
+++ b/tensorflow/contrib/bayesflow/BUILD
@@ -57,15 +57,3 @@ cuda_py_test(
"//tensorflow/python:random_seed",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD
index 6fdcd0f996..ddeda0079c 100644
--- a/tensorflow/contrib/boosted_trees/BUILD
+++ b/tensorflow/contrib/boosted_trees/BUILD
@@ -14,15 +14,6 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = ["**/OWNERS"],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
package_group(name = "friends")
cc_library(
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index dcd235f876..17e20c4b31 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -10,15 +10,6 @@ package(
load("//tensorflow:tensorflow.bzl", "py_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- include = ["**/*"],
- exclude = ["**/OWNERS"],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "init_py",
srcs = ["__init__.py"],
diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD
index 131bd48562..3028c22817 100644
--- a/tensorflow/contrib/boosted_trees/lib/BUILD
+++ b/tensorflow/contrib/boosted_trees/lib/BUILD
@@ -15,17 +15,6 @@ load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# Utils
cc_library(
diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD
index 9a61e163eb..b07f0a4314 100644
--- a/tensorflow/contrib/boosted_trees/proto/BUILD
+++ b/tensorflow/contrib/boosted_trees/proto/BUILD
@@ -4,17 +4,6 @@ exports_files(["LICENSE"])
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_proto_library(
name = "learner_proto",
srcs = [
diff --git a/tensorflow/contrib/boosted_trees/resources/BUILD b/tensorflow/contrib/boosted_trees/resources/BUILD
index 9fc101612f..c065186845 100644
--- a/tensorflow/contrib/boosted_trees/resources/BUILD
+++ b/tensorflow/contrib/boosted_trees/resources/BUILD
@@ -9,17 +9,6 @@ package(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "stamped_resource",
hdrs = ["stamped_resource.h"],
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD
index fe8bd072af..f3a75e8688 100644
--- a/tensorflow/contrib/cloud/BUILD
+++ b/tensorflow/contrib/cloud/BUILD
@@ -14,18 +14,6 @@ load(
"tf_py_test",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_gen_op_libs(
op_lib_names = ["bigquery_reader_ops"],
deps = [
diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD
index 56f930a9a8..ff46f0daa8 100644
--- a/tensorflow/contrib/cloud/kernels/BUILD
+++ b/tensorflow/contrib/cloud/kernels/BUILD
@@ -20,20 +20,6 @@ load(
"tf_proto_library",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- include = [
- "**/*",
- ],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_kernel_library(
name = "bigquery_reader_ops",
srcs = ["bigquery_reader_ops.cc"],
@@ -73,6 +59,7 @@ tf_cc_test(
],
deps = [
":bigquery_table_accessor",
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc
index e9b79a066d..7416eb19d3 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/cloud/http_request_fake.h"
#include "tensorflow/core/platform/test.h"
@@ -28,8 +29,8 @@ constexpr char kTestProject[] = "test-project";
constexpr char kTestDataset[] = "test-dataset";
constexpr char kTestTable[] = "test-table";
-bool HasSubstr(const string& base, const string& substr) {
- bool ok = StringPiece(base).contains(substr);
+bool HasSubstr(StringPiece base, StringPiece substr) {
+ bool ok = str_util::StrContains(base, substr);
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
return ok;
}
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
index 1a124eca36..c239e6f8f9 100644
--- a/tensorflow/contrib/cluster_resolver/BUILD
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -10,19 +10,6 @@ package(
licenses(["notice"]) # Apache 2.0
-filegroup(
- name = "all_files",
- srcs = glob(
- include = [
- "**/*",
- ],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
-
py_library(
name = "cluster_resolver_pip",
srcs = [
diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake
index cc218e8ab8..abfc69243e 100644
--- a/tensorflow/contrib/cmake/external/grpc.cmake
+++ b/tensorflow/contrib/cmake/external/grpc.cmake
@@ -17,7 +17,7 @@ include (ExternalProject)
set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include)
set(GRPC_URL https://github.com/grpc/grpc.git)
set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc)
-set(GRPC_TAG 575bda39755b98d1f7099406bb57a6e3b2074874)
+set(GRPC_TAG bd6bdf93279a39a8cd92978fd7c9d14eccd98fc2)
if(WIN32)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 112b690511..cc7d791042 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -79,6 +79,7 @@ tensorflow/python/keras/_impl/keras/preprocessing
tensorflow/python/keras/_impl/keras/utils
tensorflow/python/keras/_impl/keras/wrappers
tensorflow/python/kernel_tests
+tensorflow/python/kernel_tests/boosted_trees
tensorflow/python/kernel_tests/distributions
tensorflow/python/kernel_tests/linalg
tensorflow/python/kernel_tests/random
diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt
index c03c0c80fe..0c80d529af 100644
--- a/tensorflow/contrib/cmake/python_protos.txt
+++ b/tensorflow/contrib/cmake/python_protos.txt
@@ -1,4 +1,5 @@
tensorflow/core
+tensorflow/core/kernels/boosted_trees
tensorflow/core/profiler
tensorflow/python
tensorflow/contrib/boosted_trees/proto
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index d6712aa2b4..092a48bc6b 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -15,8 +15,9 @@
set(tf_op_lib_names
"audio_ops"
"array_ops"
- "batch_ops"
+ "batch_ops"
"bitwise_ops"
+ "boosted_trees_ops"
"candidate_sampling_ops"
"checkpoint_ops"
"control_flow_ops"
@@ -28,7 +29,7 @@ set(tf_op_lib_names
"image_ops"
"io_ops"
"linalg_ops"
- "list_ops"
+ "list_ops"
"lookup_ops"
"logging_ops"
"manip_ops"
@@ -48,7 +49,7 @@ set(tf_op_lib_names
"state_ops"
"stateless_random_ops"
"string_ops"
- "summary_ops"
+ "summary_ops"
"training_ops"
)
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 31e715b654..b776307924 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -319,6 +319,7 @@ GENERATE_PYTHON_OP_LIB("audio_ops")
GENERATE_PYTHON_OP_LIB("array_ops")
GENERATE_PYTHON_OP_LIB("batch_ops")
GENERATE_PYTHON_OP_LIB("bitwise_ops")
+GENERATE_PYTHON_OP_LIB("boosted_trees_ops")
GENERATE_PYTHON_OP_LIB("math_ops")
GENERATE_PYTHON_OP_LIB("functional_ops")
GENERATE_PYTHON_OP_LIB("candidate_sampling_ops")
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index b86a8f1ec2..92f2ab6dea 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -284,6 +284,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561
+ "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py" # Segfaults on Windows.
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD
index ec3d550b70..ce12e38248 100644
--- a/tensorflow/contrib/coder/BUILD
+++ b/tensorflow/contrib/coder/BUILD
@@ -154,14 +154,3 @@ tf_py_test(
],
main = "python/ops/coder_ops_test.py",
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index 388d8e6ed6..bcee0b04c8 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -46,15 +46,3 @@ cuda_py_test(
],
xla_enabled = True,
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/copy_graph/BUILD b/tensorflow/contrib/copy_graph/BUILD
index 8ec706df74..fa44c4d54e 100644
--- a/tensorflow/contrib/copy_graph/BUILD
+++ b/tensorflow/contrib/copy_graph/BUILD
@@ -41,15 +41,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/crf/BUILD b/tensorflow/contrib/crf/BUILD
index 7aad4abdb9..5c1a17df4f 100644
--- a/tensorflow/contrib/crf/BUILD
+++ b/tensorflow/contrib/crf/BUILD
@@ -40,15 +40,3 @@ cuda_py_tests(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index fa86ad38c9..8b5d13f725 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -123,15 +123,3 @@ cuda_py_test(
"requires_cudnn5",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 9e25a77d9f..35312f06b3 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -44,17 +44,3 @@ tf_custom_op_library(
tf_gen_op_libs(
op_lib_names = ["dataset_ops"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- include = [
- "**/*",
- ],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 766721d8d2..17048314a4 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -32,6 +32,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@group_by_window
@@ignore_errors
@@make_batched_features_dataset
+@@make_csv_dataset
@@make_saveable_from_iterator
@@map_and_batch
@@padded_batch_and_drop_remainder
@@ -70,6 +71,7 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
+from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
from tensorflow.contrib.data.python.ops.readers import SqlDataset
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
@@ -82,3 +84,6 @@ from tensorflow.python.ops.parsing_ops import parse_single_example_v2 as parse_s
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
+
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
index c87da7dfaa..83ada6fb67 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/contrib/data/kernels/BUILD
@@ -61,14 +61,3 @@ cc_library(
"@protobuf_archive//:protobuf_headers",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 79d1fc3494..2afb8dbbf4 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -314,6 +314,7 @@ class FunctionBufferResourceHandleOp : public OpKernel {
source_device, target_device, func_args, thread_pool_size_);
return Status::OK();
}));
+ core::ScopedUnref s(buffer);
OP_REQUIRES_OK(ctx, buffer->Instantiate());
initialized_ = true;
}
@@ -373,25 +374,27 @@ class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer),
done);
- core::ScopedUnref s(buffer);
if (buffer->Finished()) {
+ buffer->Unref();
ctx->SetStatus(errors::OutOfRange("end_of_sequence"));
done();
return;
}
FunctionBufferCallback callback =
- [ctx, done](const BufferElement& buffer_element) {
+ [ctx, buffer, done](const BufferElement& buffer_element) {
Status s = buffer_element.status;
if (!s.ok()) {
ctx->SetStatus(s);
+ buffer->Unref();
done();
return;
}
for (size_t i = 0; i < buffer_element.value.size(); ++i) {
ctx->set_output(i, buffer_element.value[i]);
}
+ buffer->Unref();
done();
};
buffer->MaybeGet(std::move(callback));
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 0b3bf63f79..0f4c9e48cf 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -513,17 +513,3 @@ tf_py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- include = [
- "**/*",
- ],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index d0131896a1..6002cc73c8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -104,6 +104,21 @@ class GroupByWindowTest(test.TestCase):
self.assertAllEqual([0, 0, 0], sess.run(get_next))
self.assertAllEqual([1], sess.run(get_next))
+ def testEmpty(self):
+ iterator = (
+ dataset_ops.Dataset.range(4).apply(
+ grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Window size must be greater than zero, but got 0."):
+ print(sess.run(get_next))
+
def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 647620eb84..236792bb98 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -183,15 +183,3 @@ py_library(
"//tensorflow/python/data/util:sparse",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 95edca6cdd..9a48aa02fb 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -18,9 +18,11 @@ from __future__ import division
from __future__ import print_function
import csv
+from math import ceil
import numpy as np
+from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
@@ -176,6 +178,9 @@ def make_csv_dataset(
shuffle_buffer_size=10000,
shuffle_seed=None,
prefetch_buffer_size=1,
+ num_parallel_reads=1,
+ num_parallel_parser_calls=2,
+ sloppy=False,
default_float_type=dtypes.float32,
num_rows_for_inference=100,
):
@@ -231,6 +236,15 @@ def make_csv_dataset(
prefetch_buffer_size: An int specifying the number of feature batches to
prefetch for performance improvement. Recommended value is the number of
batches consumed per training step.
+ num_parallel_reads: Number of threads used to read CSV records from files.
+ If >1, the results will be interleaved.
+ num_parallel_parser_calls: Number of parallel invocations of the CSV parsing
+ function on CSV records.
+ sloppy: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`. Note that if the seed is set, then order
+ of elements after shuffling is deterministic). Defaults to `False`.
default_float_type: Either `tf.float32` or `tf.float64`. If defaults are
not provided, float-like strings are interpreted to be this type.
num_rows_for_inference: Number of rows of a file to use for type inference
@@ -247,11 +261,16 @@ def make_csv_dataset(
Raises:
ValueError: If any of the arguments is malformed.
"""
- filenames = _get_file_names(file_pattern, shuffle)
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Clean arguments; figure out column names and defaults
if comment is not None and len(comment) != 1:
raise ValueError("`comment` arg must be a single-character string or None")
- # Clean arguments; figure out column names and defaults
if column_names is None:
if not header:
raise ValueError("Cannot infer column names without a header line.")
@@ -272,7 +291,6 @@ def make_csv_dataset(
filenames, len(column_names), field_delim, use_quote_delim, na_value,
header, comment, default_float_type, num_rows_for_inference)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
if label_name is not None and label_name not in column_names:
raise ValueError("`label_name` provided must be one of the columns.")
@@ -311,16 +329,31 @@ def make_csv_dataset(
return features, label
return features
- # TODO(rachelim): interleave records from files for better shuffling
- dataset = dataset.flat_map(filename_to_dataset)
- # TODO(rachelim): use fused shuffle_and_repeat for perf
- if shuffle:
+ # Read files sequentially or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+ if num_epochs != 1 and shuffle:
+ # Use shuffle_and_repeat for perf
+ dataset = dataset.apply(
+ shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
+ shuffle_seed))
+ elif shuffle:
dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
- if num_epochs != 1:
+ elif num_epochs != 1:
dataset = dataset.repeat(num_epochs)
- dataset = dataset.batch(batch_size)
- dataset = dataset.map(decode_csv)
+ # Use map_and_batch for perf
+ # TODO(b/76425672): use num_parallel_calls for better performance tuning when
+ # that is added
+ dataset = dataset.apply(
+ batching.map_and_batch(
+ map_func=decode_csv,
+ batch_size=batch_size,
+ num_parallel_batches=int(
+ ceil(num_parallel_parser_calls / batch_size))))
+
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
@@ -416,12 +449,10 @@ def make_batched_features_dataset(file_pattern,
`Tensor` or `SparseTensor` objects.
"""
# Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
if shuffle:
- dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=True)
- else:
- # TODO(b/73959787): Use Dataset.list_files() once ordering is deterministic.
- filenames = _get_file_names(file_pattern, shuffle)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
# Read `Example` records from files as tensor objects.
if reader_args is None:
diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD
index ae3847b8b6..3b50a48336 100644
--- a/tensorflow/contrib/decision_trees/proto/BUILD
+++ b/tensorflow/contrib/decision_trees/proto/BUILD
@@ -13,14 +13,6 @@ load(
"tf_pyclif_proto_library",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_proto_library(
name = "generic_tree_model",
srcs = ["generic_tree_model.proto"],
diff --git a/tensorflow/contrib/deprecated/BUILD b/tensorflow/contrib/deprecated/BUILD
index 3dfbbf5527..401527f1e7 100644
--- a/tensorflow/contrib/deprecated/BUILD
+++ b/tensorflow/contrib/deprecated/BUILD
@@ -30,15 +30,3 @@ py_test(
"//tensorflow/python:logging_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 1c381cc354..de08eb491b 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -457,6 +457,20 @@ cuda_py_test(
)
cuda_py_test(
+ name = "batch_reshape_test",
+ size = "small",
+ srcs = ["python/kernel_tests/batch_reshape_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "sample_stats_test",
size = "medium",
srcs = ["python/kernel_tests/sample_stats_test.py"],
@@ -486,6 +500,7 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
+ shard_count = 4,
tags = [
"manual",
"noasan",
@@ -745,18 +760,6 @@ cuda_py_test(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# === Bijector Tests ==========================================================
cuda_py_test(
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 61c411271d..4d4489468d 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -24,6 +24,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops.autoregressive import *
+from tensorflow.contrib.distributions.python.ops.batch_reshape import *
from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.cauchy import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
@@ -96,9 +97,10 @@ _allowed_symbols = [
'ReparameterizationType',
'Distribution',
'Autoregressive',
- 'Binomial',
+ 'BatchReshape',
'Bernoulli',
'Beta',
+ 'Binomial',
'BetaWithSoftplusConcentration',
'Categorical',
'Chi2',
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
new file mode 100644
index 0000000000..4d2f40e27f
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -0,0 +1,531 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for BatchReshape."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib
+from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib
+from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.platform import test
+
+
+class _BatchReshapeTest(object):
+
+ def make_wishart(self, dims, new_batch_shape, old_batch_shape):
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+
+ scale = self.dtype([
+ [[1., 0.5],
+ [0.5, 1.]],
+ [[0.5, 0.25],
+ [0.25, 0.75]],
+ ])
+ scale = np.reshape(np.concatenate([scale, scale], axis=0),
+ old_batch_shape + [dims, dims])
+ scale_ph = array_ops.placeholder_with_default(
+ scale, shape=scale.shape if self.is_static_shape else None)
+ wishart = wishart_lib.WishartFull(df=5, scale=scale_ph)
+ reshape_wishart = batch_reshape_lib.BatchReshape(
+ distribution=wishart,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True)
+
+ return wishart, reshape_wishart
+
+ def test_matrix_variate_sample_and_log_prob(self):
+ dims = 2
+ new_batch_shape = [4]
+ old_batch_shape = [2, 2]
+ wishart, reshape_wishart = self.make_wishart(
+ dims, new_batch_shape, old_batch_shape)
+
+ batch_shape = reshape_wishart.batch_shape_tensor()
+ event_shape = reshape_wishart.event_shape_tensor()
+
+ expected_sample_shape = [3, 1] + new_batch_shape + [dims, dims]
+ x = wishart.sample([3, 1], seed=42)
+ expected_sample = array_ops.reshape(x, expected_sample_shape)
+ actual_sample = reshape_wishart.sample([3, 1], seed=42)
+
+ expected_log_prob_shape = [3, 1] + new_batch_shape
+ expected_log_prob = array_ops.reshape(
+ wishart.log_prob(x), expected_log_prob_shape)
+ actual_log_prob = reshape_wishart.log_prob(expected_sample)
+
+ with self.test_session() as sess:
+ [
+ batch_shape_,
+ event_shape_,
+ expected_sample_, actual_sample_,
+ expected_log_prob_, actual_log_prob_,
+ ] = sess.run([
+ batch_shape,
+ event_shape,
+ expected_sample, actual_sample,
+ expected_log_prob, actual_log_prob,
+ ])
+
+ self.assertAllEqual(new_batch_shape, batch_shape_)
+ self.assertAllEqual([dims, dims], event_shape_)
+ self.assertAllClose(expected_sample_, actual_sample_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_log_prob_, actual_log_prob_,
+ atol=0., rtol=1e-6)
+ if not self.is_static_shape:
+ return
+ self.assertAllEqual(new_batch_shape, reshape_wishart.batch_shape)
+ self.assertAllEqual([dims, dims], reshape_wishart.event_shape)
+ self.assertAllEqual(expected_sample_shape, actual_sample.shape)
+ self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape)
+
+ def test_matrix_variate_stats(self):
+ dims = 2
+ new_batch_shape = [4]
+ old_batch_shape = [2, 2]
+ wishart, reshape_wishart = self.make_wishart(
+ dims, new_batch_shape, old_batch_shape)
+
+ expected_scalar_stat_shape = new_batch_shape
+ expected_matrix_stat_shape = new_batch_shape + [dims, dims]
+
+ expected_entropy = array_ops.reshape(
+ wishart.entropy(), expected_scalar_stat_shape)
+ actual_entropy = reshape_wishart.entropy()
+
+ expected_mean = array_ops.reshape(
+ wishart.mean(), expected_matrix_stat_shape)
+ actual_mean = reshape_wishart.mean()
+
+ expected_mode = array_ops.reshape(
+ wishart.mode(), expected_matrix_stat_shape)
+ actual_mode = reshape_wishart.mode()
+
+ expected_stddev = array_ops.reshape(
+ wishart.stddev(), expected_matrix_stat_shape)
+ actual_stddev = reshape_wishart.stddev()
+
+ expected_variance = array_ops.reshape(
+ wishart.variance(), expected_matrix_stat_shape)
+ actual_variance = reshape_wishart.variance()
+
+ with self.test_session() as sess:
+ [
+ expected_entropy_, actual_entropy_,
+ expected_mean_, actual_mean_,
+ expected_mode_, actual_mode_,
+ expected_stddev_, actual_stddev_,
+ expected_variance_, actual_variance_,
+ ] = sess.run([
+ expected_entropy, actual_entropy,
+ expected_mean, actual_mean,
+ expected_mode, actual_mode,
+ expected_stddev, actual_stddev,
+ expected_variance, actual_variance,
+ ])
+
+ self.assertAllClose(expected_entropy_, actual_entropy_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_mean_, actual_mean_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_mode_, actual_mode_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_stddev_, actual_stddev_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_variance_, actual_variance_,
+ atol=0., rtol=1e-6)
+ if not self.is_static_shape:
+ return
+ self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape)
+ self.assertAllEqual(expected_matrix_stat_shape, actual_mean.shape)
+ self.assertAllEqual(expected_matrix_stat_shape, actual_mode.shape)
+ self.assertAllEqual(expected_matrix_stat_shape, actual_stddev.shape)
+ self.assertAllEqual(expected_matrix_stat_shape, actual_variance.shape)
+
+ def make_normal(self, new_batch_shape, old_batch_shape):
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+
+ scale = self.dtype(0.5 + np.arange(
+ np.prod(old_batch_shape)).reshape(old_batch_shape))
+ scale_ph = array_ops.placeholder_with_default(
+ scale, shape=scale.shape if self.is_static_shape else None)
+ normal = normal_lib.Normal(loc=self.dtype(0), scale=scale_ph)
+ reshape_normal = batch_reshape_lib.BatchReshape(
+ distribution=normal,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True)
+ return normal, reshape_normal
+
+ def test_scalar_variate_sample_and_log_prob(self):
+ new_batch_shape = [2, 2]
+ old_batch_shape = [4]
+
+ normal, reshape_normal = self.make_normal(
+ new_batch_shape, old_batch_shape)
+
+ batch_shape = reshape_normal.batch_shape_tensor()
+ event_shape = reshape_normal.event_shape_tensor()
+
+ expected_sample_shape = new_batch_shape
+ x = normal.sample(seed=52)
+ expected_sample = array_ops.reshape(x, expected_sample_shape)
+ actual_sample = reshape_normal.sample(seed=52)
+
+ expected_log_prob_shape = new_batch_shape
+ expected_log_prob = array_ops.reshape(
+ normal.log_prob(x), expected_log_prob_shape)
+ actual_log_prob = reshape_normal.log_prob(expected_sample)
+
+ with self.test_session() as sess:
+ [
+ batch_shape_,
+ event_shape_,
+ expected_sample_, actual_sample_,
+ expected_log_prob_, actual_log_prob_,
+ ] = sess.run([
+ batch_shape,
+ event_shape,
+ expected_sample, actual_sample,
+ expected_log_prob, actual_log_prob,
+ ])
+ self.assertAllEqual(new_batch_shape, batch_shape_)
+ self.assertAllEqual([], event_shape_)
+ self.assertAllClose(expected_sample_, actual_sample_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_log_prob_, actual_log_prob_,
+ atol=0., rtol=1e-6)
+ if not self.is_static_shape:
+ return
+ self.assertAllEqual(new_batch_shape, reshape_normal.batch_shape)
+ self.assertAllEqual([], reshape_normal.event_shape)
+ self.assertAllEqual(expected_sample_shape, actual_sample.shape)
+ self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape)
+
+ def test_scalar_variate_stats(self):
+ new_batch_shape = [2, 2]
+ old_batch_shape = [4]
+
+ normal, reshape_normal = self.make_normal(new_batch_shape, old_batch_shape)
+
+ expected_scalar_stat_shape = new_batch_shape
+
+ expected_entropy = array_ops.reshape(
+ normal.entropy(), expected_scalar_stat_shape)
+ actual_entropy = reshape_normal.entropy()
+
+ expected_mean = array_ops.reshape(
+ normal.mean(), expected_scalar_stat_shape)
+ actual_mean = reshape_normal.mean()
+
+ expected_mode = array_ops.reshape(
+ normal.mode(), expected_scalar_stat_shape)
+ actual_mode = reshape_normal.mode()
+
+ expected_stddev = array_ops.reshape(
+ normal.stddev(), expected_scalar_stat_shape)
+ actual_stddev = reshape_normal.stddev()
+
+ expected_variance = array_ops.reshape(
+ normal.variance(), expected_scalar_stat_shape)
+ actual_variance = reshape_normal.variance()
+
+ with self.test_session() as sess:
+ [
+ expected_entropy_, actual_entropy_,
+ expected_mean_, actual_mean_,
+ expected_mode_, actual_mode_,
+ expected_stddev_, actual_stddev_,
+ expected_variance_, actual_variance_,
+ ] = sess.run([
+ expected_entropy, actual_entropy,
+ expected_mean, actual_mean,
+ expected_mode, actual_mode,
+ expected_stddev, actual_stddev,
+ expected_variance, actual_variance,
+ ])
+ self.assertAllClose(expected_entropy_, actual_entropy_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_mean_, actual_mean_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_mode_, actual_mode_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_stddev_, actual_stddev_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_variance_, actual_variance_,
+ atol=0., rtol=1e-6)
+ if not self.is_static_shape:
+ return
+ self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape)
+ self.assertAllEqual(expected_scalar_stat_shape, actual_mean.shape)
+ self.assertAllEqual(expected_scalar_stat_shape, actual_mode.shape)
+ self.assertAllEqual(expected_scalar_stat_shape, actual_stddev.shape)
+ self.assertAllEqual(expected_scalar_stat_shape, actual_variance.shape)
+
+ def make_mvn(self, dims, new_batch_shape, old_batch_shape):
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+
+ scale = np.ones(old_batch_shape + [dims], self.dtype)
+ scale_ph = array_ops.placeholder_with_default(
+ scale, shape=scale.shape if self.is_static_shape else None)
+ mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
+ reshape_mvn = batch_reshape_lib.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True)
+ return mvn, reshape_mvn
+
+ def test_vector_variate_sample_and_log_prob(self):
+ dims = 3
+ new_batch_shape = [2, 1]
+ old_batch_shape = [2]
+ mvn, reshape_mvn = self.make_mvn(
+ dims, new_batch_shape, old_batch_shape)
+
+ batch_shape = reshape_mvn.batch_shape_tensor()
+ event_shape = reshape_mvn.event_shape_tensor()
+
+ expected_sample_shape = [3] + new_batch_shape + [dims]
+ x = mvn.sample(3, seed=62)
+ expected_sample = array_ops.reshape(x, expected_sample_shape)
+ actual_sample = reshape_mvn.sample(3, seed=62)
+
+ expected_log_prob_shape = [3] + new_batch_shape
+ expected_log_prob = array_ops.reshape(
+ mvn.log_prob(x), expected_log_prob_shape)
+ actual_log_prob = reshape_mvn.log_prob(expected_sample)
+
+ with self.test_session() as sess:
+ [
+ batch_shape_,
+ event_shape_,
+ expected_sample_, actual_sample_,
+ expected_log_prob_, actual_log_prob_,
+ ] = sess.run([
+ batch_shape,
+ event_shape,
+ expected_sample, actual_sample,
+ expected_log_prob, actual_log_prob,
+ ])
+ self.assertAllEqual(new_batch_shape, batch_shape_)
+ self.assertAllEqual([dims], event_shape_)
+ self.assertAllClose(expected_sample_, actual_sample_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_log_prob_, actual_log_prob_,
+ atol=0., rtol=1e-6)
+ if not self.is_static_shape:
+ return
+ self.assertAllEqual(new_batch_shape, reshape_mvn.batch_shape)
+ self.assertAllEqual([dims], reshape_mvn.event_shape)
+ self.assertAllEqual(expected_sample_shape, actual_sample.shape)
+ self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape)
+
+ def test_vector_variate_stats(self):
+ dims = 3
+ new_batch_shape = [2, 1]
+ old_batch_shape = [2]
+ mvn, reshape_mvn = self.make_mvn(
+ dims, new_batch_shape, old_batch_shape)
+
+ expected_scalar_stat_shape = new_batch_shape
+
+ expected_entropy = array_ops.reshape(
+ mvn.entropy(), expected_scalar_stat_shape)
+ actual_entropy = reshape_mvn.entropy()
+
+ expected_vector_stat_shape = new_batch_shape + [dims]
+
+ expected_mean = array_ops.reshape(
+ mvn.mean(), expected_vector_stat_shape)
+ actual_mean = reshape_mvn.mean()
+
+ expected_mode = array_ops.reshape(
+ mvn.mode(), expected_vector_stat_shape)
+ actual_mode = reshape_mvn.mode()
+
+ expected_stddev = array_ops.reshape(
+ mvn.stddev(), expected_vector_stat_shape)
+ actual_stddev = reshape_mvn.stddev()
+
+ expected_variance = array_ops.reshape(
+ mvn.variance(), expected_vector_stat_shape)
+ actual_variance = reshape_mvn.variance()
+
+ expected_matrix_stat_shape = new_batch_shape + [dims, dims]
+
+ expected_covariance = array_ops.reshape(
+ mvn.covariance(), expected_matrix_stat_shape)
+ actual_covariance = reshape_mvn.covariance()
+
+ with self.test_session() as sess:
+ [
+ expected_entropy_, actual_entropy_,
+ expected_mean_, actual_mean_,
+ expected_mode_, actual_mode_,
+ expected_stddev_, actual_stddev_,
+ expected_variance_, actual_variance_,
+ expected_covariance_, actual_covariance_,
+ ] = sess.run([
+ expected_entropy, actual_entropy,
+ expected_mean, actual_mean,
+ expected_mode, actual_mode,
+ expected_stddev, actual_stddev,
+ expected_variance, actual_variance,
+ expected_covariance, actual_covariance,
+ ])
+ self.assertAllClose(expected_entropy_, actual_entropy_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_mean_, actual_mean_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_mode_, actual_mode_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_stddev_, actual_stddev_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_variance_, actual_variance_,
+ atol=0., rtol=1e-6)
+ self.assertAllClose(expected_covariance_, actual_covariance_,
+ atol=0., rtol=1e-6)
+ if not self.is_static_shape:
+ return
+ self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape)
+ self.assertAllEqual(expected_vector_stat_shape, actual_mean.shape)
+ self.assertAllEqual(expected_vector_stat_shape, actual_mode.shape)
+ self.assertAllEqual(expected_vector_stat_shape, actual_stddev.shape)
+ self.assertAllEqual(expected_vector_stat_shape, actual_variance.shape)
+ self.assertAllEqual(expected_matrix_stat_shape, actual_covariance.shape)
+
+ def test_bad_reshape_size(self):
+ dims = 2
+ new_batch_shape = [2, 3]
+ old_batch_shape = [2] # 2 != 2*3
+
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+
+ scale = np.ones(old_batch_shape + [dims], self.dtype)
+ scale_ph = array_ops.placeholder_with_default(
+ scale, shape=scale.shape if self.is_static_shape else None)
+ mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
+
+ if self.is_static_shape:
+ with self.assertRaisesRegexp(
+ ValueError, (r"`batch_shape` size \(6\) must match "
+ r"`distribution\.batch_shape` size \(2\)")):
+ batch_reshape_lib.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True)
+
+ else:
+ with self.test_session():
+ with self.assertRaisesOpError(r"`batch_shape` size must match "
+ r"`distributions.batch_shape` size"):
+ batch_reshape_lib.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True).sample().eval()
+
+ def test_non_positive_shape(self):
+ dims = 2
+ new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check.
+ old_batch_shape = [2]
+
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+
+ scale = np.ones(old_batch_shape + [dims], self.dtype)
+ scale_ph = array_ops.placeholder_with_default(
+ scale, shape=scale.shape if self.is_static_shape else None)
+ mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
+
+ if self.is_static_shape:
+ with self.assertRaisesRegexp(ValueError, r".*must be positive.*"):
+ batch_reshape_lib.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True)
+
+ else:
+ with self.test_session():
+ with self.assertRaisesOpError(r".*must be positive.*"):
+ batch_reshape_lib.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True).sample().eval()
+
+ def test_non_vector_shape(self):
+ dims = 2
+ new_batch_shape = 2
+ old_batch_shape = [2]
+
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+
+ scale = np.ones(old_batch_shape + [dims], self.dtype)
+ scale_ph = array_ops.placeholder_with_default(
+ scale, shape=scale.shape if self.is_static_shape else None)
+ mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
+
+ if self.is_static_shape:
+ with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"):
+ batch_reshape_lib.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True)
+
+ else:
+ with self.test_session():
+ with self.assertRaisesOpError(r".*must be a vector.*"):
+ batch_reshape_lib.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape_ph,
+ validate_args=True).sample().eval()
+
+
+class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase):
+
+ dtype = np.float32
+ is_static_shape = True
+
+
+class BatchReshapeDynamicTest(_BatchReshapeTest, test.TestCase):
+
+ dtype = np.float64
+ is_static_shape = False
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
index 3548ac1807..0400c80c29 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
@@ -22,39 +22,75 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import statistical_testing as st
from tensorflow.python.framework import errors
-from tensorflow.python.ops import check_ops
from tensorflow.python.platform import test
class StatisticalTestingTest(test.TestCase):
def test_dkwm_design_mean_one_sample_soundness(self):
- numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
+ thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.]
- with self.test_session() as sess:
- for ff in rates:
- for fp in rates:
- sufficient_n = st.min_num_samples_for_dkwm_mean_test(
- numbers, 0., 1., false_fail_rate=ff, false_pass_rate=fp)
- detectable_d = st.min_discrepancy_of_true_means_detectable_by_dkwm(
- sufficient_n, 0., 1., false_fail_rate=ff, false_pass_rate=fp)
- sess.run(check_ops.assert_less_equal(detectable_d, numbers))
+ false_fail_rates, false_pass_rates = np.meshgrid(rates, rates)
+ false_fail_rates = false_fail_rates.flatten().astype(np.float32)
+ false_pass_rates = false_pass_rates.flatten().astype(np.float32)
+
+ detectable_discrepancies = []
+ for false_pass_rate, false_fail_rate in zip(
+ false_pass_rates, false_fail_rates):
+ sufficient_n = st.min_num_samples_for_dkwm_mean_test(
+ thresholds, low=0., high=1., false_fail_rate=false_fail_rate,
+ false_pass_rate=false_pass_rate)
+ detectable_discrepancies.append(
+ st.min_discrepancy_of_true_means_detectable_by_dkwm(
+ sufficient_n, low=0., high=1., false_fail_rate=false_fail_rate,
+ false_pass_rate=false_pass_rate))
+
+ detectable_discrepancies_ = self.evaluate(detectable_discrepancies)
+ for discrepancies, false_pass_rate, false_fail_rate in zip(
+ detectable_discrepancies_, false_pass_rates, false_fail_rates):
+ below_threshold = discrepancies <= thresholds
+ self.assertAllEqual(
+ np.ones_like(below_threshold, np.bool), below_threshold,
+ msg='false_pass_rate({}), false_fail_rate({})'.format(
+ false_pass_rate, false_fail_rate))
def test_dkwm_design_mean_two_sample_soundness(self):
- numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
+ thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10]
rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.]
- with self.test_session() as sess:
- for ff in rates:
- for fp in rates:
- (sufficient_n1,
- sufficient_n2) = st.min_num_samples_for_dkwm_mean_two_sample_test(
- numbers, 0., 1., 0., 1.,
- false_fail_rate=ff, false_pass_rate=fp)
- d_fn = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample
- detectable_d = d_fn(
- sufficient_n1, 0., 1., sufficient_n2, 0., 1.,
- false_fail_rate=ff, false_pass_rate=fp)
- sess.run(check_ops.assert_less_equal(detectable_d, numbers))
+ false_fail_rates, false_pass_rates = np.meshgrid(rates, rates)
+ false_fail_rates = false_fail_rates.flatten().astype(np.float32)
+ false_pass_rates = false_pass_rates.flatten().astype(np.float32)
+
+ detectable_discrepancies = []
+ for false_pass_rate, false_fail_rate in zip(
+ false_pass_rates, false_fail_rates):
+ [
+ sufficient_n1,
+ sufficient_n2
+ ] = st.min_num_samples_for_dkwm_mean_two_sample_test(
+ thresholds, low1=0., high1=1., low2=0., high2=1.,
+ false_fail_rate=false_fail_rate,
+ false_pass_rate=false_pass_rate)
+
+ detectable_discrepancies.append(
+ st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
+ n1=sufficient_n1,
+ low1=0.,
+ high1=1.,
+ n2=sufficient_n2,
+ low2=0.,
+ high2=1.,
+ false_fail_rate=false_fail_rate,
+ false_pass_rate=false_pass_rate))
+
+ detectable_discrepancies_ = self.evaluate(detectable_discrepancies)
+ for discrepancies, false_pass_rate, false_fail_rate in zip(
+ detectable_discrepancies_, false_pass_rates, false_fail_rates):
+ below_threshold = discrepancies <= thresholds
+ self.assertAllEqual(
+ np.ones_like(below_threshold, np.bool), below_threshold,
+ msg='false_pass_rate({}), false_fail_rate({})'.format(
+ false_pass_rate, false_fail_rate))
def test_true_mean_confidence_interval_by_dkwm_one_sample(self):
rng = np.random.RandomState(seed=0)
@@ -105,16 +141,16 @@ class StatisticalTestingTest(test.TestCase):
def test_dkwm_mean_two_sample_assertion(self):
rng = np.random.RandomState(seed=0)
- num_samples = 15000
+ num_samples = 4000
- # 15000 samples is chosen to be enough to find discrepancies of
- # size 0.1 or more with assurance 1e-6, as confirmed here:
+ # 4000 samples is chosen to be enough to find discrepancies of
+ # size 0.2 or more with assurance 1e-6, as confirmed here:
with self.test_session() as sess:
d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
num_samples, 0., 1., num_samples, 0., 1.,
false_fail_rate=1e-6, false_pass_rate=1e-6)
d = sess.run(d)
- self.assertLess(d, 0.1)
+ self.assertLess(d, 0.2)
# Test that the test assertion agrees that the standard
# uniform distribution has the same mean as itself.
@@ -124,6 +160,15 @@ class StatisticalTestingTest(test.TestCase):
sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6))
+ def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self):
+ rng = np.random.RandomState(seed=0)
+ num_samples = 4000
+ samples1 = rng.uniform(size=num_samples).astype(np.float32)
+
+ # As established above, 4000 samples is enough to find discrepancies
+ # of size 0.2 or more with assurance 1e-6.
+
+ with self.test_session() as sess:
# Test that the test assertion confirms that the mean of the
# standard uniform distribution is different from the mean of beta(2, 1).
beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32)
@@ -133,6 +178,15 @@ class StatisticalTestingTest(test.TestCase):
beta_high_samples, 0., 1.,
false_fail_rate=1e-6))
+ def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self):
+ rng = np.random.RandomState(seed=0)
+ num_samples = 4000
+ samples1 = rng.uniform(size=num_samples).astype(np.float32)
+
+ # As established above, 4000 samples is enough to find discrepancies
+ # of size 0.2 or more with assurance 1e-6.
+
+ with self.test_session() as sess:
# Test that the test assertion confirms that the mean of the
# standard uniform distribution is different from the mean of beta(1, 2).
beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
new file mode 100644
index 0000000000..c7ee9b2117
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -0,0 +1,333 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The BatchReshape distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import distribution as distribution_lib
+
+
+__all__ = [
+ "BatchReshape",
+]
+
+
+class BatchReshape(distribution_lib.Distribution):
+ """The Batch-Reshaping distribution.
+
+ This "meta-distribution" reshapes the batch dimensions of another
+ distribution.
+
+ Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support
+ `-1` for flattening.
+
+ #### Examples
+
+ ```python
+ tfd = tf.contrib.distributions
+
+ dtype = np.float32
+ dims = 2
+ new_batch_shape = [1, 2, 3]
+ old_batch_shape = [6]
+
+ scale = np.ones(old_batch_shape + [dims], dtype)
+ mvn = tfd.MultivariateNormalDiag(scale_diag=scale)
+ reshape_mvn = tfd.BatchReshape(
+ distribution=mvn,
+ batch_shape=new_batch_shape,
+ validate_args=True)
+
+ reshape_mvn.batch_shape
+ # ==> [1, 2, 3]
+
+ x = reshape_mvn.sample(sample_shape=[4, 5])
+ x.shape
+ # ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims]
+
+ reshape_mvn.log_prob(x).shape
+ # ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape
+ ```
+
+ """
+
+ def __init__(self,
+ distribution,
+ batch_shape,
+ validate_args=False,
+ allow_nan_stats=True,
+ name=None):
+ """Construct BatchReshape distribution.
+
+ Args:
+ distribution: The base distribution instance to reshape. Typically an
+ instance of `Distribution`.
+ batch_shape: Positive `int`-like vector-shaped `Tensor` representing the
+ new shape of the batch dimensions.
+ validate_args: Python `bool`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: The name to give Ops created by the initializer.
+ Default value: `"BatchReshape" + distribution.name`.
+
+ Raises:
+ ValueError: if `batch_shape` is not a vector.
+ ValueError: if `batch_shape` has non-positive elements.
+ ValueError: if `batch_shape` size is not the same as a
+ `distribution.batch_shape` size.
+ """
+ parameters = locals()
+ name = name or "BatchReshape" + distribution.name
+ self._distribution = distribution
+ with ops.name_scope(name, values=[batch_shape]) as name:
+ self._batch_shape_ = ops.convert_to_tensor(
+ batch_shape,
+ dtype=dtypes.int32,
+ name="batch_shape")
+ self._batch_shape_static = tensor_util.constant_value(self._batch_shape_)
+ if self._batch_shape_static is not None:
+ self._batch_shape_static = np.int32(self._batch_shape_static)
+ self._runtime_assertions = make_runtime_assertions(
+ self._distribution,
+ self._batch_shape_,
+ validate_args,
+ self._batch_shape_static)
+ super(BatchReshape, self).__init__(
+ dtype=self._distribution.dtype,
+ reparameterization_type=self._distribution.reparameterization_type,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=(
+ [self._batch_shape_] +
+ self._distribution._graph_parents), # pylint: disable=protected-access
+ name=name)
+
+ @property
+ def distribution(self):
+ return self._distribution
+
+ def _batch_shape_tensor(self):
+ with ops.control_dependencies(self._runtime_assertions):
+ return array_ops.identity(self._batch_shape_)
+
+ def _batch_shape(self):
+ return tensor_shape.TensorShape(self._batch_shape_static)
+
+ def _event_shape_tensor(self):
+ with ops.control_dependencies(self._runtime_assertions):
+ return array_ops.identity(self.distribution.event_shape_tensor())
+
+ def _event_shape(self):
+ return self.distribution.event_shape
+
+ def _sample_n(self, n, seed=None):
+ with ops.control_dependencies(self._runtime_assertions):
+ x = self.distribution.sample(sample_shape=n, seed=seed)
+ new_shape = array_ops.concat([
+ [n],
+ self.batch_shape_tensor(),
+ self.event_shape_tensor(),
+ ], axis=0)
+ return array_ops.reshape(x, new_shape)
+
+ def _log_prob(self, x):
+ return self._call_reshape_input_output(
+ self.distribution.log_prob, x)
+
+ def _prob(self, x):
+ return self._call_reshape_input_output(
+ self.distribution.prob, x)
+
+ def _log_cdf(self, x):
+ return self._call_reshape_input_output(
+ self.distribution.log_cdf, x)
+
+ def _cdf(self, x):
+ return self._call_reshape_input_output(
+ self.distribution.cdf, x)
+
+ def _log_survival_function(self, x):
+ return self._call_reshape_input_output(
+ self.distribution.log_survival_function, x)
+
+ def _survival_function(self, x):
+ return self._call_reshape_input_output(
+ self.distribution.survival_function, x)
+
+ def _entropy(self):
+ return self._call_and_reshape_output(
+ self.distribution.entropy,
+ [],
+ [tensor_shape.scalar()])
+
+ def _mean(self):
+ return self._call_and_reshape_output(self.distribution.mean)
+
+ def _mode(self):
+ return self._call_and_reshape_output(self.distribution.mode)
+
+ def _stddev(self):
+ return self._call_and_reshape_output(self.distribution.stddev)
+
+ def _variance(self):
+ return self._call_and_reshape_output(self.distribution.variance)
+
+ def _covariance(self):
+ return self._call_and_reshape_output(
+ self.distribution.covariance,
+ [self.event_shape_tensor()]*2,
+ [self.event_shape]*2)
+
+ def _sample_shape(self, x):
+ """Computes graph and static `sample_shape`."""
+ x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims)
+ event_ndims = (array_ops.size(self.event_shape_tensor())
+ if self.event_shape.ndims is None
+ else self.event_shape.ndims)
+ batch_ndims = (array_ops.size(self.batch_shape_tensor())
+ if self.batch_shape.ndims is None
+ else self.batch_shape.ndims)
+ sample_ndims = x_ndims - batch_ndims - event_ndims
+ if isinstance(sample_ndims, int):
+ static_sample_shape = x.shape[:sample_ndims]
+ else:
+ static_sample_shape = tensor_shape.TensorShape(None)
+ if static_sample_shape.is_fully_defined():
+ sample_shape = np.int32(static_sample_shape.as_list())
+ else:
+ sample_shape = array_ops.shape(x)[:sample_ndims]
+ return sample_shape, static_sample_shape
+
+ def _call_reshape_input_output(self, fn, x):
+ """Calls `fn`, appropriately reshaping its input `x` and output."""
+ with ops.control_dependencies(self._runtime_assertions):
+ sample_shape, static_sample_shape = self._sample_shape(x)
+ old_shape = array_ops.concat([
+ sample_shape,
+ self.distribution.batch_shape_tensor(),
+ self.event_shape_tensor(),
+ ], axis=0)
+ result = fn(array_ops.reshape(x, old_shape))
+ new_shape = array_ops.concat([
+ sample_shape,
+ self.batch_shape_tensor(),
+ ], axis=0)
+ result = array_ops.reshape(result, new_shape)
+ if (static_sample_shape.ndims is not None and
+ self.batch_shape.ndims is not None):
+ new_shape = static_sample_shape.concatenate(self.batch_shape)
+ result.set_shape(result.shape.merge_with(new_shape))
+ return result
+
+ def _call_and_reshape_output(
+ self,
+ fn,
+ event_shape_list=None,
+ static_event_shape_list=None):
+ """Calls `fn` and appropriately reshapes its output."""
+ with ops.control_dependencies(self._runtime_assertions):
+ if event_shape_list is None:
+ event_shape_list = [self._event_shape_tensor()]
+ if static_event_shape_list is None:
+ static_event_shape_list = [self.event_shape]
+ new_shape = array_ops.concat(
+ [self.batch_shape_tensor()] + event_shape_list,
+ axis=0)
+ result = array_ops.reshape(fn(), new_shape)
+ if (self.batch_shape.ndims is not None and
+ self.event_shape.ndims is not None):
+ event_shape = tensor_shape.TensorShape([])
+ for rss in static_event_shape_list:
+ event_shape = event_shape.concatenate(rss)
+ static_shape = result.shape.merge_with(
+ self.batch_shape.concatenate(event_shape))
+ result.set_shape(static_shape)
+ return result
+
+
+def make_runtime_assertions(
+ distribution,
+ batch_shape,
+ validate_args,
+ batch_shape_static):
+ """Helper to __init__ which makes or raises assertions."""
+ runtime_assertions = []
+
+ if batch_shape.shape.ndims is not None:
+ if batch_shape.shape.ndims != 1:
+ raise ValueError("`batch_shape` must be a vector "
+ "(saw rank: {}).".format(
+ batch_shape.shape.ndims))
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_rank(
+ batch_shape,
+ 1,
+ message="`batch_shape` must be a vector.",
+ name="assert_batch_shape_is_vector"),
+ ]
+
+ batch_size_static = np.prod(batch_shape_static)
+ dist_batch_size_static = (
+ None if not distribution.batch_shape.is_fully_defined()
+ else np.prod(distribution.batch_shape).value)
+
+ if batch_size_static is not None and dist_batch_size_static is not None:
+ if batch_size_static != dist_batch_size_static:
+ raise ValueError("`batch_shape` size ({}) must match "
+ "`distribution.batch_shape` size ({}).".format(
+ batch_size_static,
+ dist_batch_size_static))
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_equal(
+ math_ops.reduce_prod(batch_shape),
+ math_ops.reduce_prod(distribution.batch_shape_tensor()),
+ message=("`batch_shape` size must match "
+ "`distributions.batch_shape` size."),
+ name="assert_batch_size"),
+ ]
+
+ if batch_shape_static is not None:
+ if np.any(batch_shape_static < 1):
+ raise ValueError("`batch_shape` elements must be positive "
+ "(i.e., larger than zero).")
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_positive(
+ batch_shape,
+ message=("`batch_shape` elements must be positive "
+ "(i.e., larger than zero)."),
+ name="assert_batch_shape_positive")
+ ]
+
+ return runtime_assertions
diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD
index aedfec8924..b016d2dcb5 100644
--- a/tensorflow/contrib/eager/proto/BUILD
+++ b/tensorflow/contrib/eager/proto/BUILD
@@ -4,17 +4,6 @@ exports_files(["LICENSE"])
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_proto_library(
name = "checkpointable_object_graph_proto",
srcs = [
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 80176397c0..edb9130266 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -80,6 +80,7 @@ cuda_py_test(
"//tensorflow/python/data",
"//tensorflow/python/eager:test",
],
+ tags = ["noguitar"],
)
py_library(
@@ -276,16 +277,3 @@ cuda_py_test(
"notsan",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "g3doc/sitemap.md",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
index a8c47d76d1..5e1b64728a 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras._impl.keras.engine import sequential
from tensorflow.python.keras._impl.keras.engine import training
from tensorflow.python.layers import core
from tensorflow.python.ops import control_flow_ops
@@ -1036,6 +1037,38 @@ class CheckpointingTests(test.TestCase):
beta1_power, _ = optimizer._get_beta_accumulators()
self.assertAllEqual(3., self.evaluate(beta1_power))
+ @test_util.run_in_graph_and_eager_modes()
+ def test_sequential(self):
+ model = sequential.Sequential()
+ checkpoint = checkpointable_utils.Checkpoint(model=model)
+ model.add(core.Dense(4))
+ second_dense = core.Dense(5)
+ model.add(second_dense)
+ model(constant_op.constant([[1.]]))
+ checkpoint.restore(None).initialize_or_restore()
+ self.evaluate(second_dense.bias.assign(
+ constant_op.constant([1., 2., 3., 4., 5.])))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.evaluate(second_dense.bias.assign(
+ constant_op.constant([5., 6., 7., 8., 9.])))
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([1., 2., 3., 4., 5.], self.evaluate(second_dense.bias))
+
+ deferred_sequential = sequential.Sequential()
+ deferred_sequential_checkpoint = checkpointable_utils.Checkpoint(
+ model=deferred_sequential)
+ status = deferred_sequential_checkpoint.restore(save_path)
+ deferred_sequential.add(core.Dense(4))
+ deferred_sequential(constant_op.constant([[1.]]))
+ deferred_second_dense = core.Dense(5)
+ deferred_sequential.add(deferred_second_dense)
+ deferred_sequential(constant_op.constant([[1.]]))
+ status.run_restore_ops()
+ self.assertAllEqual([1., 2., 3., 4., 5.],
+ self.evaluate(deferred_second_dense.bias))
+
class TemplateTests(test.TestCase):
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
index 6673653418..9adf47d505 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -173,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
right_in.append(tf.random_normal((1, size * 2)))
tracking.append(tf.random_normal((1, tracker_size * 2)))
- out = reducer(left_in, right_in=right_in, tracking=tracking)
+ out = reducer(left_in, right_in, tracking=tracking)
self.assertEqual(batch_size, len(out))
self.assertEqual(tf.float32, out[0].dtype)
self.assertEqual((1, size * 2), out[0].shape)
@@ -227,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
for _ in range(2):
- out1, out2 = tracker(bufs, stacks=stacks)
+ out1, out2 = tracker(bufs, stacks)
self.assertIsNone(out2)
self.assertEqual(batch_size, len(out1))
self.assertEqual(tf.float32, out1[0].dtype)
@@ -260,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual(tf.int64, transitions.dtype)
self.assertEqual((num_transitions, 1), transitions.shape)
- out = s(buffers, transitions=transitions, training=True)
+ out = s(buffers, transitions, training=True)
self.assertEqual(tf.float32, out.dtype)
self.assertEqual((1, embedding_dims), out.shape)
@@ -286,15 +286,12 @@ class SpinnTest(test_util.TensorFlowTestCase):
vocab_size)
# Invoke model under non-training mode.
- logits = model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
# Invoke model under training model.
- logits = model(prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=True)
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index c846343d6d..2be62c9438 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -9,23 +9,12 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "estimator_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
+ ":boosted_trees",
":dnn",
":dnn_linear_combined",
":extenders",
@@ -39,6 +28,36 @@ py_library(
)
py_library(
+ name = "boosted_trees",
+ srcs = ["python/estimator/boosted_trees.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:boosted_trees",
+ ],
+)
+
+py_test(
+ name = "boosted_trees_test",
+ size = "medium",
+ srcs = ["python/estimator/boosted_trees_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan",
+ ],
+ deps = [
+ ":boosted_trees",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/feature_column",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "dnn",
srcs = ["python/estimator/dnn.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 6b9f9575b6..d2fc2c4bfa 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
@@ -44,6 +45,8 @@ _allowed_symbols = [
'DNNEstimator',
'DNNLinearCombinedEstimator',
'LinearEstimator',
+ 'boosted_trees_classifier_train_in_memory',
+ 'boosted_trees_regressor_train_in_memory',
'call_logit_fn',
'dnn_logit_fn_builder',
'linear_logit_fn_builder',
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
new file mode 100644
index 0000000000..5880164519
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -0,0 +1,323 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Boosted Trees estimators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+
+
+class _BoostedTreesEstimator(estimator.Estimator):
+ """An Estimator for Tensorflow Boosted Trees models."""
+
+ def __init__(self,
+ feature_columns,
+ n_batches_per_layer,
+ head,
+ model_dir=None,
+ weight_column=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ config=None):
+ """Initializes a `BoostedTreesEstimator` instance.
+
+ Args:
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`.
+ n_batches_per_layer: the number of batches to collect statistics per
+ layer.
+ head: the `Head` instance defined for Estimator.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator
+ to continue training a previously saved model.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to downweight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+ then weight_column.normalizer_fn is applied on it to get weight tensor.
+ n_trees: number trees to be created.
+ max_depth: maximum depth of the tree to grow.
+ learning_rate: shrinkage parameter to be used when a tree added to the
+ model.
+ l1_regularization: regularization multiplier applied to the absolute
+ weights of the tree leafs.
+ l2_regularization: regularization multiplier applied to the square weights
+ of the tree leafs.
+ tree_complexity: regularization factor to penalize trees with more leaves.
+ config: `RunConfig` object to configure the runtime settings.
+ """
+ # TODO(youngheek): param validations.
+
+ # HParams for the model.
+ tree_hparams = canned_boosted_trees.TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity)
+
+ def _model_fn(features, labels, mode, config):
+ return canned_boosted_trees._bt_model_fn( # pylint: disable=protected-access
+ features, labels, mode, head, feature_columns, tree_hparams,
+ n_batches_per_layer, config)
+
+ super(_BoostedTreesEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+def boosted_trees_classifier_train_in_memory(
+ train_input_fn,
+ feature_columns,
+ model_dir=None,
+ n_classes=canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT,
+ weight_column=None,
+ label_vocabulary=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ config=None,
+ train_hooks=None):
+ """Trains a boosted tree classifier with in memory dataset.
+
+ Example:
+
+ ```python
+ bucketized_feature_1 = bucketized_column(
+ numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+ bucketized_feature_2 = bucketized_column(
+ numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+ def input_fn_train():
+ dataset = create-dataset-from-training-data
+ # Don't use repeat or cache, since it is assumed to be one epoch
+ # This is either tf.data.Dataset, or a tuple of feature dict and label.
+ return dataset
+
+ classifier = boosted_trees_classifier_train_in_memory(
+ train_input_fn,
+ feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_trees=100,
+ ... <some other params>
+ )
+
+ def input_fn_eval():
+ ...
+ return dataset
+
+ metrics = classifier.evaluate(input_fn=input_fn_eval, steps=10)
+ ```
+
+ Args:
+ train_input_fn: the input function returns a dataset containing a single
+ epoch of *unbatched* features and labels.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator
+ to continue training a previously saved model.
+ n_classes: number of label classes. Default is binary classification.
+ Multiclass support is not yet implemented.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to downweight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+ then weight_column.normalizer_fn is applied on it to get weight tensor.
+ label_vocabulary: A list of strings represents possible label values. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. If it is not given, that means labels are
+ already encoded as integer or float within [0, 1] for `n_classes=2` and
+ encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
+ Also there will be errors if vocabulary is not provided and labels are
+ string.
+ n_trees: number trees to be created.
+ max_depth: maximum depth of the tree to grow.
+ learning_rate: shrinkage parameter to be used when a tree added to the
+ model.
+ l1_regularization: regularization multiplier applied to the absolute
+ weights of the tree leafs.
+ l2_regularization: regularization multiplier applied to the square weights
+ of the tree leafs.
+ tree_complexity: regularization factor to penalize trees with more leaves.
+ config: `RunConfig` object to configure the runtime settings.
+ train_hooks: a list of Hook instances to be passed to estimator.train().
+
+ Returns:
+ a `BoostedTreesClassifier` instance created with the given arguments and
+ trained with the data loaded up on memory from the input_fn.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ # pylint: disable=protected-access
+ # TODO(nponomareva): Support multi-class cases.
+ if n_classes == canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT:
+ n_classes = 2
+ head, closed_form = (
+ canned_boosted_trees._create_classification_head_and_closed_form(
+ n_classes, weight_column, label_vocabulary=label_vocabulary))
+
+ # HParams for the model.
+ tree_hparams = canned_boosted_trees.TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity)
+
+ def _model_fn(features, labels, mode, config):
+ return canned_boosted_trees._bt_model_fn(
+ features,
+ labels,
+ mode,
+ head,
+ feature_columns,
+ tree_hparams,
+ n_batches_per_layer=1,
+ config=config,
+ closed_form_grad_and_hess_fn=closed_form,
+ train_in_memory=True)
+
+ in_memory_classifier = estimator.Estimator(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
+
+ in_memory_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+
+ return in_memory_classifier
+ # pylint: enable=protected-access
+
+
+def boosted_trees_regressor_train_in_memory(
+ train_input_fn,
+ feature_columns,
+ model_dir=None,
+ label_dimension=canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT,
+ weight_column=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ config=None,
+ train_hooks=None):
+ """Trains a boosted tree regressor with in memory dataset.
+
+ Example:
+
+ ```python
+ bucketized_feature_1 = bucketized_column(
+ numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+ bucketized_feature_2 = bucketized_column(
+ numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+ def input_fn_train():
+ dataset = create-dataset-from-training-data
+ # Don't use repeat or cache, since it is assumed to be one epoch
+ # This is either tf.data.Dataset, or a tuple of feature dict and label.
+ return dataset
+
+ regressor = boosted_trees_regressor_train_in_memory(
+ train_input_fn,
+ feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_trees=100,
+ ... <some other params>
+ )
+
+ def input_fn_eval():
+ ...
+ return dataset
+
+ metrics = regressor.evaluate(input_fn=input_fn_eval, steps=10)
+ ```
+
+ Args:
+ train_input_fn: the input function returns a dataset containing a single
+ epoch of *unbatched* features and labels.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator
+ to continue training a previously saved model.
+ label_dimension: Number of regression targets per example.
+ Multi-dimensional support is not yet implemented.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to downweight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+ then weight_column.normalizer_fn is applied on it to get weight tensor.
+ n_trees: number trees to be created.
+ max_depth: maximum depth of the tree to grow.
+ learning_rate: shrinkage parameter to be used when a tree added to the
+ model.
+ l1_regularization: regularization multiplier applied to the absolute
+ weights of the tree leafs.
+ l2_regularization: regularization multiplier applied to the square weights
+ of the tree leafs.
+ tree_complexity: regularization factor to penalize trees with more leaves.
+ config: `RunConfig` object to configure the runtime settings.
+ train_hooks: a list of Hook instances to be passed to estimator.train().
+
+ Returns:
+ a `BoostedTreesClassifier` instance created with the given arguments and
+ trained with the data loaded up on memory from the input_fn.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ # pylint: disable=protected-access
+ # TODO(nponomareva): Extend it to multi-dimension cases.
+ if label_dimension == canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT:
+ label_dimension = 1
+ head = canned_boosted_trees._create_regression_head(label_dimension,
+ weight_column)
+
+ # HParams for the model.
+ tree_hparams = canned_boosted_trees.TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity)
+
+ def _model_fn(features, labels, mode, config):
+ return canned_boosted_trees._bt_model_fn(
+ features,
+ labels,
+ mode,
+ head,
+ feature_columns,
+ tree_hparams,
+ n_batches_per_layer=1,
+ config=config,
+ train_in_memory=True)
+
+ in_memory_regressor = estimator.Estimator(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
+
+ in_memory_regressor.train(input_fn=train_input_fn, hooks=train_hooks)
+
+ return in_memory_regressor
+ # pylint: enable=protected-access
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
new file mode 100644
index 0000000000..e99a87f3b3
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -0,0 +1,207 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.estimator.python.estimator import boosted_trees
+from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import checkpoint_utils
+
+NUM_FEATURES = 3
+
+BUCKET_BOUNDARIES = [-2., .5, 12.] # Boundaries for all the features.
+INPUT_FEATURES = np.array(
+ [
+ [12.5, 1.0, -2.001, -2.0001, -1.999], # feature_0 quantized:[3,2,0,0,1]
+ [2.0, -3.0, 0.5, 0.0, 0.4995], # feature_1 quantized:[2,0,2,1,1]
+ [3.0, 20.0, 50.0, -100.0, 102.75], # feature_2 quantized:[2,3,3,0,3]
+ ],
+ dtype=np.float32)
+CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]]
+REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]]
+FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)}
+
+
+def _make_train_input_fn(is_classification):
+ """Makes train input_fn for classification/regression."""
+
+ def _input_fn():
+ features = dict(FEATURES_DICT)
+ if is_classification:
+ labels = CLASSIFICATION_LABELS
+ else:
+ labels = REGRESSION_LABELS
+ return features, labels
+
+ return _input_fn
+
+
+class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ for i in range(NUM_FEATURES)
+ }
+
+ def _assert_checkpoint(self, model_dir, expected_global_step):
+ self.assertEqual(expected_global_step,
+ checkpoint_utils.load_variable(model_dir,
+ ops.GraphKeys.GLOBAL_STEP))
+
+ def testTrainAndEvaluateEstimator(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5)
+
+ # It will stop after 10 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ self._assert_checkpoint(est.model_dir, 11)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 0.913176)
+
+ def testInferEstimator(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ head=self._head)
+
+ # It will stop after 5 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(train_input_fn, steps=num_steps)
+ self._assert_checkpoint(est.model_dir, 6)
+
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertEquals(5, len(predictions))
+ self.assertAllClose([0.703549], predictions[0]['predictions'])
+ self.assertAllClose([0.266539], predictions[1]['predictions'])
+ self.assertAllClose([0.256479], predictions[2]['predictions'])
+ self.assertAllClose([1.088732], predictions[3]['predictions'])
+ self.assertAllClose([1.901732], predictions[4]['predictions'])
+
+
+class BoostedTreesClassifierTrainInMemoryTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ for i in range(NUM_FEATURES)
+ }
+
+ def _assert_checkpoint(self, model_dir, expected_global_step):
+ self.assertEqual(expected_global_step,
+ checkpoint_utils.load_variable(model_dir,
+ ops.GraphKeys.GLOBAL_STEP))
+
+ def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_classifier_train_in_memory(
+ train_input_fn=train_input_fn,
+ feature_columns=self._feature_columns,
+ n_trees=1,
+ max_depth=5)
+ # It will stop after 5 steps because of the max depth and num trees.
+ self._assert_checkpoint(est.model_dir, 6)
+
+ # Check eval.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+
+ # Check predict that all labels are correct.
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertEquals(5, len(predictions))
+ self.assertAllClose([0], predictions[0]['class_ids'])
+ self.assertAllClose([1], predictions[1]['class_ids'])
+ self.assertAllClose([1], predictions[2]['class_ids'])
+ self.assertAllClose([0], predictions[3]['class_ids'])
+ self.assertAllClose([0], predictions[4]['class_ids'])
+
+
+class BoostedTreesRegressorTrainInMemoryTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ for i in range(NUM_FEATURES)
+ }
+
+ def _assert_checkpoint(self, model_dir, expected_global_step):
+ self.assertEqual(expected_global_step,
+ checkpoint_utils.load_variable(model_dir,
+ ops.GraphKeys.GLOBAL_STEP))
+
+ def testRegressorTrainInMemoryAndEvalAndInfer(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_regressor_train_in_memory(
+ train_input_fn=train_input_fn,
+ feature_columns=self._feature_columns,
+ n_trees=1,
+ max_depth=5)
+ # It will stop after 5 steps because of the max depth and num trees.
+ self._assert_checkpoint(est.model_dir, 6)
+
+ # Check eval.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.2136638)
+
+ # Validate predictions.
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertEquals(5, len(predictions))
+ self.assertAllClose([0.703549], predictions[0]['predictions'])
+ self.assertAllClose([0.266539], predictions[1]['predictions'])
+ self.assertAllClose([0.256479], predictions[2]['predictions'])
+ self.assertAllClose([1.088732], predictions[3]['predictions'])
+ self.assertAllClose([1.901732], predictions[4]['predictions'])
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index ad8568ad44..0a648d5d40 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -347,16 +347,3 @@ cuda_py_test(
],
main = "python/kernel_tests/masked_matmul_benchmark.py",
)
-
-# All files
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD
index bbe842bd5c..363baa121a 100644
--- a/tensorflow/contrib/factorization/examples/BUILD
+++ b/tensorflow/contrib/factorization/examples/BUILD
@@ -21,14 +21,3 @@ tf_py_test(
],
tags = ["notsan"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD
index 44eab56011..ea8b9a17a2 100644
--- a/tensorflow/contrib/factorization/kernels/BUILD
+++ b/tensorflow/contrib/factorization/kernels/BUILD
@@ -67,14 +67,3 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 3f3e3e0f25..811fa89bc3 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -218,11 +218,11 @@ class WALSModel(object):
- When set to None, w_ij = unobserved_weight, which simplifies to ALS.
Note that col_weights must also be set to "None" in this case.
- If it is a list of lists of non-negative real numbers, it needs to be
- in the form of \\([[w_0, w_1, ...], [w_k, ... ], [...]]\\), with the
- number of inner lists matching the number of row factor shards and the
- elements in each inner list are the weights for the rows of the
- corresponding row factor shard. In this case, \\(w_ij\\) =
- unobserved_weight + row_weights[i] * col_weights[j].
+ in the form of [[w_0, w_1, ...], [w_k, ... ], [...]], with the number of
+ inner lists matching the number of row factor shards and the elements in
+ each inner list are the weights for the rows of the corresponding row
+ factor shard. In this case, w_ij = unobserved_weight +
+ row_weights[i] * col_weights[j].
- If this is a single non-negative real number, this value is used for
all row weights and \\(w_ij\\) = unobserved_weight + row_weights *
col_weights[j].
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
index 002f9cfbdd..bb5140aeb3 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -283,8 +283,8 @@ class WalsModelTest(test.TestCase):
# Test column projection.
# Using the specified projection weights for the 3 column feature vectors.
- # This is expected to reproduce the same column factors in the model as the
- # weights and feature vectors are identical to that used in model
+ # This is expected to reproduce the same column factors in the model as
+ # the weights and feature vectors are identical to that used in model
# training.
projected_cols = wals_model.project_col_factors(
sp_input=sp_feeder,
@@ -462,8 +462,8 @@ class WalsModelTest(test.TestCase):
# Test column projection.
# Using the specified projection weights for the 2 column feature vectors.
- # This is expected to reproduce the same column factors in the model as the
- # weights and feature vectors are identical to that used in model
+ # This is expected to reproduce the same column factors in the model as
+ # the weights and feature vectors are identical to that used in model
# training.
projected_cols = wals_model.project_col_factors(
sp_input=sp_feeder,
diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD
index 3614b2b15a..aab7d0c9e8 100644
--- a/tensorflow/contrib/feature_column/BUILD
+++ b/tensorflow/contrib/feature_column/BUILD
@@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "feature_column_py",
srcs = ["__init__.py"],
diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD
index eccce99071..f7b3273a4d 100644
--- a/tensorflow/contrib/ffmpeg/BUILD
+++ b/tensorflow/contrib/ffmpeg/BUILD
@@ -180,15 +180,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD
index 6b455567d7..59bad8982d 100644
--- a/tensorflow/contrib/ffmpeg/default/BUILD
+++ b/tensorflow/contrib/ffmpeg/default/BUILD
@@ -74,15 +74,3 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index ac043fda06..b1c8ad49ea 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -321,15 +321,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index ce37672895..0eb6889db1 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -157,15 +157,3 @@ cuda_py_test(
"requires_cudnn6",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 0eb0e3cbe2..9e56d3c039 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -544,15 +544,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD
index 707ae25d48..e534fdc177 100644
--- a/tensorflow/contrib/gdr/BUILD
+++ b/tensorflow/contrib/gdr/BUILD
@@ -10,18 +10,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD
index 967ad2fc09..1711100e3a 100644
--- a/tensorflow/contrib/graph_editor/BUILD
+++ b/tensorflow/contrib/graph_editor/BUILD
@@ -39,18 +39,6 @@ py_library(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "match",
srcs = ["tests/match.py"],
diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD
index d601a1ec6f..d0b4464066 100644
--- a/tensorflow/contrib/grid_rnn/BUILD
+++ b/tensorflow/contrib/grid_rnn/BUILD
@@ -41,15 +41,3 @@ cuda_py_tests(
"//tensorflow/python:variables",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD
index 1b528d7afc..d65b2d6026 100644
--- a/tensorflow/contrib/hooks/BUILD
+++ b/tensorflow/contrib/hooks/BUILD
@@ -23,14 +23,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD
index 324035100d..e39c60b252 100644
--- a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD
+++ b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD
@@ -13,18 +13,6 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_cc_binary(
name = "clock_cycle_profiling",
testonly = 1,
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
index 909dc396a3..0081fb6177 100644
--- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
+++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
@@ -10,17 +10,6 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
-
tf_cc_binary(
name = "hvx_ops_support_checker",
testonly = 1,
diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD
index 79eb3762ed..da450480b3 100755
--- a/tensorflow/contrib/image/BUILD
+++ b/tensorflow/contrib/image/BUILD
@@ -384,15 +384,3 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD
index 9d6b4d5d87..0e34315db4 100644
--- a/tensorflow/contrib/input_pipeline/BUILD
+++ b/tensorflow/contrib/input_pipeline/BUILD
@@ -114,14 +114,3 @@ tf_cc_tests(
"//tensorflow/core:testlib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/input_pipeline/kernels/BUILD b/tensorflow/contrib/input_pipeline/kernels/BUILD
index f20a6e38d4..797605b8fe 100644
--- a/tensorflow/contrib/input_pipeline/kernels/BUILD
+++ b/tensorflow/contrib/input_pipeline/kernels/BUILD
@@ -17,14 +17,3 @@ cc_library(
],
alwayslink = 1,
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/integrate/BUILD b/tensorflow/contrib/integrate/BUILD
index 66948c1ea1..0b7d64f4ed 100644
--- a/tensorflow/contrib/integrate/BUILD
+++ b/tensorflow/contrib/integrate/BUILD
@@ -42,14 +42,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD
index 1c3974871c..3913c9dc7a 100644
--- a/tensorflow/contrib/kafka/BUILD
+++ b/tensorflow/contrib/kafka/BUILD
@@ -119,17 +119,3 @@ tf_py_test(
"notap",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- include = [
- "**/*",
- ],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index 7e0019ce4a..7a4cab20d1 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -52,15 +52,3 @@ py_library(
"//tensorflow/python/keras",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD
index eff7dfeb4c..87c2dcd89b 100644
--- a/tensorflow/contrib/kernel_methods/BUILD
+++ b/tensorflow/contrib/kernel_methods/BUILD
@@ -90,15 +90,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD
index 9a5759bf14..b719046b37 100644
--- a/tensorflow/contrib/kfac/BUILD
+++ b/tensorflow/contrib/kfac/BUILD
@@ -24,15 +24,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
index 89965eda37..7dd40c19c5 100644
--- a/tensorflow/contrib/kfac/examples/BUILD
+++ b/tensorflow/contrib/kfac/examples/BUILD
@@ -58,15 +58,3 @@ py_library(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD
index ce7da95c12..ede7f183fe 100644
--- a/tensorflow/contrib/kfac/examples/tests/BUILD
+++ b/tensorflow/contrib/kfac/examples/tests/BUILD
@@ -50,15 +50,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 146ae8b7e2..f73c24f8fb 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -155,15 +155,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index d721ad08af..b897fd68a0 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -244,15 +244,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index b04bf76a88..e0d9cb5ea9 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -861,12 +861,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
self._strides)
- inputs, grads_list = self._process_data(grads_list)
-
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
(inputs, self._filter_shape, self._padding, self._strides,
@@ -1391,7 +1391,7 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
inputs, grads_list = self._process_data(grads_list)
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
self._strides)
self._input_factor = self._layer_collection.make_or_get_factor(
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 353e1c6abb..0d40d265a1 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -336,12 +336,16 @@ class FisherFactor(object):
new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers)
- # I have no idea if the TPU code below is still correct since I don't know
- # what it actually does. Also, this code is not present in some of the
- # other versions of make_covariance_update_op. Does it matter?
- # Synchronize value across all TPU cores.
+ # Compute average of 'new_cov' across all TPU cores. On a TPU, each
+ # instance of 'new_cov' will be based on a different minibatch. This ensures
+ # that by the end of assign_moving_average(), all TPU cores see the same
+ # value for self._cov.
+ #
+ # Other implementations of make_covariance_update_op() that accumulate
+ # statistics in other variables should mimic this behavior.
if utils.on_tpu():
new_cov = utils.cross_replica_mean(new_cov)
+
return moving_averages.assign_moving_average(
self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
@@ -1398,6 +1402,10 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
/ float(self._num_towers))
+ # See comments in FisherFactor.make_covariance_update_op() for details.
+ if utils.on_tpu():
+ new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1)
+
op2 = moving_averages.assign_moving_average(
self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD
index 894e6f6946..18b265ae80 100644
--- a/tensorflow/contrib/labeled_tensor/BUILD
+++ b/tensorflow/contrib/labeled_tensor/BUILD
@@ -213,14 +213,3 @@ py_test(
"//tensorflow/python:math_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 852d06e1e3..4be55468db 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -390,15 +390,3 @@ py_test(
"//tensorflow/python:variables",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/layers/kernels/BUILD b/tensorflow/contrib/layers/kernels/BUILD
index e407a9ce01..7aae09ff3e 100644
--- a/tensorflow/contrib/layers/kernels/BUILD
+++ b/tensorflow/contrib/layers/kernels/BUILD
@@ -18,14 +18,3 @@ cc_library(
],
alwayslink = 1,
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 0b38c0c3fd..e49589ddf6 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -33,6 +33,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.layers import base
@@ -660,7 +661,9 @@ def _force_data_dependency(first_compute, then_compute):
if x.get_shape().ndims is None:
raise ValueError("Rank of Tensor %s must be known" % x)
ndims = x.get_shape().ndims
- return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), [])
+ begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32)
+ size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32)
+ return array_ops.reshape(array_ops.slice(x, begin, size), [])
first_compute_sum = math_ops.add_n(
[_first_element(x) for x in first_compute if x is not None])
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 16f80a876f..ba55365c14 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -874,15 +874,3 @@ py_binary(
"//tensorflow/python:platform",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/learn/python/learn/datasets/BUILD b/tensorflow/contrib/learn/python/learn/datasets/BUILD
index 8bf372841d..2c7215bba3 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/BUILD
+++ b/tensorflow/contrib/learn/python/learn/datasets/BUILD
@@ -44,18 +44,6 @@ py_binary(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_test(
name = "base_test",
size = "small",
diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py
index 3b5c9b97c0..4676eedb20 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/base.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/base.py
@@ -139,15 +139,48 @@ def retry(initial_delay,
Args:
initial_delay: the initial delay.
+ max_delay: the maximum delay allowed (actual max is
+ max_delay * (1 + jitter).
factor: each subsequent retry, the delay is multiplied by this value.
(must be >= 1).
jitter: to avoid lockstep, the returned delay is multiplied by a random
number between (1-jitter) and (1+jitter). To add a 20% jitter, set
jitter = 0.2. Must be < 1.
+ is_retriable: (optional) a function that takes an Exception as an argument
+ and returns true if retry should be applied.
+
+ Returns:
+ A function that wraps another function to automatically retry it.
+ """
+ return _internal_retry(
+ initial_delay=initial_delay,
+ max_delay=max_delay,
+ factor=factor,
+ jitter=jitter,
+ is_retriable=is_retriable)
+
+
+def _internal_retry(initial_delay,
+ max_delay,
+ factor=2.0,
+ jitter=0.25,
+ is_retriable=None):
+ """Simple decorator for wrapping retriable functions, for internal use only.
+
+ Args:
+ initial_delay: the initial delay.
max_delay: the maximum delay allowed (actual max is
max_delay * (1 + jitter).
+ factor: each subsequent retry, the delay is multiplied by this value.
+ (must be >= 1).
+ jitter: to avoid lockstep, the returned delay is multiplied by a random
+ number between (1-jitter) and (1+jitter). To add a 20% jitter, set
+ jitter = 0.2. Must be < 1.
is_retriable: (optional) a function that takes an Exception as an argument
and returns true if retry should be applied.
+
+ Returns:
+ A function that wraps another function to automatically retry it.
"""
if factor < 1:
raise ValueError('factor must be >= 1; was %f' % (factor,))
@@ -195,7 +228,7 @@ def _is_retriable(e):
@deprecated(None, 'Please use urllib or similar directly.')
-@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
+@_internal_retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable)
def urlretrieve_with_retry(url, filename=None):
return urllib.request.urlretrieve(url, filename)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
index 1d161093de..f3500bf56f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
@@ -290,8 +290,15 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
Note - using this argument, it is easy to provide settings which break
otherwise perfectly good models. Use with care.
"""
- super(RunConfig, self).__init__(
- master=master, evaluation_master=evaluation_master)
+ # Neither parent class calls super().__init__(), so here we have to
+ # manually call their __init__() methods.
+ ClusterConfig.__init__(
+ self, master=master, evaluation_master=evaluation_master)
+ # For too long this code didn't call:
+ # core_run_config.RunConfig.__init__(self)
+ # so instead of breaking compatibility with that assumption, we
+ # just manually initialize this field:
+ self._distribute = None
gpu_options = config_pb2.GPUOptions(
per_process_gpu_memory_fraction=gpu_memory_fraction)
diff --git a/tensorflow/contrib/legacy_seq2seq/BUILD b/tensorflow/contrib/legacy_seq2seq/BUILD
index 1fa55132b1..8c2c4fd29c 100644
--- a/tensorflow/contrib/legacy_seq2seq/BUILD
+++ b/tensorflow/contrib/legacy_seq2seq/BUILD
@@ -60,15 +60,3 @@ cuda_py_tests(
],
tags = ["noasan"], # times out b/63678675
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD
index df96402a4f..4dccb9be7c 100644
--- a/tensorflow/contrib/libsvm/BUILD
+++ b/tensorflow/contrib/libsvm/BUILD
@@ -88,15 +88,3 @@ tf_py_test(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
index 359255374d..a7812f74d1 100644
--- a/tensorflow/contrib/linalg/BUILD
+++ b/tensorflow/contrib/linalg/BUILD
@@ -61,15 +61,3 @@ cuda_py_test(
shard_count = 4,
tags = ["noasan"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD
index cea3627ed5..5b89c6cef9 100644
--- a/tensorflow/contrib/linear_optimizer/BUILD
+++ b/tensorflow/contrib/linear_optimizer/BUILD
@@ -138,14 +138,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 18efa64507..ac269d540a 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -271,18 +271,3 @@ cc_test(
# ],
# }),
#)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "downloads",
- "examples",
- "gen",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index d7993e60cc..17b791e4e2 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -79,6 +79,7 @@ typedef enum {
kTfLiteBuiltinBidirectionalSequenceLstm = 52,
kTfLiteBuiltinCast = 53,
kTfLiteBuiltinPrelu = 54,
+ kTfLiteBuiltinMaximum = 55,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD
index 959347b549..9322e186a2 100644
--- a/tensorflow/contrib/lite/examples/label_image/BUILD
+++ b/tensorflow/contrib/lite/examples/label_image/BUILD
@@ -69,15 +69,3 @@ cc_library(
# "//testing/base/public:gunit",
# ],
# )
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index f52d6ba6c5..7f7a2632dd 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -167,15 +167,3 @@ tflite_jni_binary(
"//tensorflow/contrib/lite/java/src/main/native",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
index 5eb749aae6..d6fbef9cc9 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -27,15 +27,3 @@ android_binary(
"@androidsdk//com.android.support:support-v4-25.2.0",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD
index dd0cd6c98f..ce68160b68 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD
@@ -10,15 +10,3 @@ exports_files(
],
),
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD
index 3571182ca9..4399ed2025 100644
--- a/tensorflow/contrib/lite/java/src/main/native/BUILD
+++ b/tensorflow/contrib/lite/java/src/main/native/BUILD
@@ -95,15 +95,3 @@ exports_files(
"version_script.lds",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
index 2b4f37bc6c..b524246d43 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
@@ -16,15 +16,3 @@ android_library(
"//tensorflow/contrib/lite/java:tensorflowlite_java",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 48021aea47..df0f3cbeb0 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -723,6 +723,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
@@ -923,16 +924,4 @@ tf_cc_test(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index a0f766c4f4..87413000a9 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -19,12 +19,25 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_FULLY_CONNECTED_REF();
+TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT();
+TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT();
+TfLiteRegistration* Register_FULLY_CONNECTED_PIE();
+
+} // namespace builtin
+} // namespace ops
+
namespace {
using ::testing::ElementsAre;
@@ -119,7 +132,8 @@ static float fully_connected_golden_output[] = {
class BaseFullyConnectedOpModel : public SingleOpModel {
public:
// TODO(ahentz): test different activation types too.
- BaseFullyConnectedOpModel(int units, int batches, const TensorData& input,
+ BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
+ int batches, const TensorData& input,
const TensorData& output = {TensorType_FLOAT32})
: batches_(batches), units_(units) {
int total_input_size = 1;
@@ -149,6 +163,8 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_FULLY_CONNECTED, registration);
BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
}
@@ -208,10 +224,25 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
}
};
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
+ {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
+ {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
+ {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
+});
+
+class FullyConnectedOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// TODO(ahentz): add more small tests like this one, focused on making sure the
// calculations are correct.
-TEST(FullyConnectedOpTest, SimpleTest) {
- FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}});
+TEST_P(FullyConnectedOpTest, SimpleTest) {
+ FloatFullyConnectedOpModel m(GetRegistration(), 3, 2,
+ {TensorType_FLOAT32, {2, 10}});
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
@@ -229,9 +260,9 @@ TEST(FullyConnectedOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
-TEST(FullyConnectedOpTest, SimpleTestQuantized) {
+TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
QuantizedFullyConnectedOpModel m(
- 3, 2,
+ GetRegistration(), 3, 2,
/*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
/*output=*/{TensorType_UINT8, {}, -127, 128});
@@ -261,7 +292,8 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of
// batches. All we care is that the input can be evenly distributed in
// batches. In this case, we need the input to have multiples of '2'.
- FloatFullyConnectedOpModel m(/*units=*/3,
+ FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
+ /*units=*/3,
/*batches=*/2,
/*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
m.SetWeights({
@@ -284,9 +316,9 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) {
}));
}
-TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
+TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
QuantizedFullyConnectedOpModel m(
- 3, 2,
+ GetRegistration(), 3, 2,
/*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
/*output=*/{TensorType_UINT8, {}, -127, 128});
@@ -312,10 +344,15 @@ TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
}
+INSTANTIATE_TEST_CASE_P(
+ FullyConnectedOpTest, FullyConnectedOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
// to debug errors and doesn't necessarily test all the important details.
-TEST(FullyConnectedOpTest, BlackBoxTest) {
- FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}});
+TEST_P(FullyConnectedOpTest, BlackBoxTest) {
+ FloatFullyConnectedOpModel m(GetRegistration(), 16, 2,
+ {TensorType_FLOAT32, {2, 8}});
m.SetWeights(
{0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636,
-0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504,
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index aa3957bee1..167c0f1fde 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -431,15 +431,3 @@ cc_library(
)
exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"])
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index c71b070680..0f78e0f728 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1694,12 +1694,11 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
#ifdef __aarch64__
- // Call kernel optimized for depthwise convolutions using 3x3 filters,
- // stride = 1, no padding, depth_multiplier = 1 and depth a multiple of 16.
- if (filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
- (stride_width == 1 || stride_width == 2) &&
- (stride_height == 1 || stride_height == 2) && pad_width == 0 &&
- pad_height == 0 && (input_depth % 16) == 0) {
+ // Call kernel optimized for depthwise convolutions using 3x3 filters if
+ // parameters are supported.
+ if (Fast3by3FilterKernelSupported(input_dims, filter_dims, stride_width,
+ stride_height, pad_width, pad_height,
+ depth_multiplier, output_dims)) {
DepthwiseConv3by3FilterDepth16(
input_data, input_dims, input_offset, filter_data, filter_dims,
filter_offset, bias_data, bias_dims, stride_width, stride_height,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 9dc76e7608..a349892076 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -440,6 +440,47 @@ struct ConvKernel3x3FilterDepth16<1, 1> {
}
};
+inline bool Fast3by3FilterKernelSupported(const Dims<4>& input_dims,
+ const Dims<4>& filter_dims,
+ int stride_width, int stride_height,
+ int pad_width, int pad_height,
+ int depth_multiplier,
+ const Dims<4>& output_dims) {
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = ArraySize(input_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+
+ bool supported = filter_width == 3 && filter_height == 3 &&
+ depth_multiplier == 1 &&
+ (stride_width == 1 || stride_width == 2) &&
+ (stride_height == 1 || stride_height == 2) &&
+ pad_width == 0 && pad_height == 0 && (input_depth % 16) == 0;
+
+ if (!supported) {
+ return false;
+ }
+
+ // Handle case where padding is zero but type is not kValid. This would
+ // require special boundary case handling that is not supported yet.
+
+ const int out_x = output_width - 1;
+ const int out_y = output_height - 1;
+
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+
+ const int in_x_end = in_x_origin + filter_width;
+ const int in_y_end = in_y_origin + filter_height;
+
+ // Supported only if filter on the right and bottom boundary lies completely
+ // within the input.
+ return in_x_end <= input_width && in_y_end <= input_height;
+}
+
inline void DepthwiseConv3by3FilterDepth16(
const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
@@ -634,7 +675,7 @@ inline void DepthwiseConv3by3FilterDepth16(
// Handle the rest of the right side.
for (; out_x < output_width; out_x++) {
// This code path can only be reached if we're handling >1 x outputs
- // at a time or support padding.
+ // at a time or support kSame padding.
}
}
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index e079ff3f4c..4661004d09 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -324,6 +324,22 @@ void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
}
}
+inline void optimized_ops_preload_l1_stream(const uint8* ptr) {
+#ifdef GEMMLOWP_ARM_64
+ asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+#else
+ gemmlowp::Prefetch(ptr);
+#endif
+}
+
+inline void optimized_ops_preload_l1_keep(const uint8* ptr) {
+#ifdef GEMMLOWP_ARM_64
+ asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+#else
+ gemmlowp::Prefetch(ptr);
+#endif
+}
+
#ifdef GEMMLOWP_NEON
// In the common case of batch size 1, a fully-connected node degenerates
// to a matrix*vector product. LSTM cells contain a fully-connected node;
@@ -516,6 +532,124 @@ inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
}
#endif
+#ifdef GEMMLOWP_NEON
+inline void GEMVForLstmCellWithSymmetricRange(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 accum_multiplier,
+ int accum_shift, int16* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("GEMVForLstmCellWithSymmetricRange");
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
+ ArraySize(output_dims, 3),
+ 1);
+ const int input_size = input_dims.strides[3];
+ const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ // This special fast path for quantized LSTM cells does not try to support
+ // odd sizes that we haven't encountered in any LSTM cell, that would
+ // require special code (that would go untested until any LSTM cell
+ // exercises it). We just guard our assumptions about size evenness with
+ // the following assertions.
+ TFLITE_DCHECK(!(output_size % 4));
+ TFLITE_DCHECK(!(input_size % 8));
+ const int32* bias_ptr = bias_data;
+ int16* output_ptr = output_data;
+ const uint8x16_t signbit = vdupq_n_u8(0x80);
+ for (int in = 0; in < input_size; in += 32) {
+ optimized_ops_preload_l1_keep(input_data + in);
+ }
+ for (int out = 0; out < output_size; out += 4) {
+ const uint8* weights_ptr_0 = weights_data + out * input_size;
+ const uint8* weights_ptr_1 = weights_ptr_0 + 1 * input_size;
+ const uint8* weights_ptr_2 = weights_ptr_0 + 2 * input_size;
+ const uint8* weights_ptr_3 = weights_ptr_0 + 3 * input_size;
+
+ int32x4_t acc_0 = vdupq_n_s32(0);
+ int32x4_t acc_1 = vdupq_n_s32(0);
+ int32x4_t acc_2 = vdupq_n_s32(0);
+ int32x4_t acc_3 = vdupq_n_s32(0);
+ int in = 0;
+ const int kReadAhead = 256;
+ // Handle 16 levels of depth at a time.
+ for (; in < input_size; in += 16) {
+ int8x16_t weights_val_0 =
+ vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_0)));
+ int8x16_t weights_val_1 =
+ vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_1)));
+ int8x16_t weights_val_2 =
+ vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_2)));
+ int8x16_t weights_val_3 =
+ vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_3)));
+ int8x16_t input_val =
+ vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(input_data + in)));
+ int16x8_t acc16_0 =
+ vmull_s8(vget_low_s8(weights_val_0), vget_low_s8(input_val));
+ int16x8_t acc16_1 =
+ vmull_s8(vget_low_s8(weights_val_1), vget_low_s8(input_val));
+ int16x8_t acc16_2 =
+ vmull_s8(vget_low_s8(weights_val_2), vget_low_s8(input_val));
+ int16x8_t acc16_3 =
+ vmull_s8(vget_low_s8(weights_val_3), vget_low_s8(input_val));
+ acc16_0 = vmlal_s8(acc16_0, vget_high_s8(weights_val_0),
+ vget_high_s8(input_val));
+ acc16_1 = vmlal_s8(acc16_1, vget_high_s8(weights_val_1),
+ vget_high_s8(input_val));
+ acc16_2 = vmlal_s8(acc16_2, vget_high_s8(weights_val_2),
+ vget_high_s8(input_val));
+ acc16_3 = vmlal_s8(acc16_3, vget_high_s8(weights_val_3),
+ vget_high_s8(input_val));
+ acc_0 = vpadalq_s16(acc_0, acc16_0);
+ acc_1 = vpadalq_s16(acc_1, acc16_1);
+ acc_2 = vpadalq_s16(acc_2, acc16_2);
+ acc_3 = vpadalq_s16(acc_3, acc16_3);
+ weights_ptr_0 += 16;
+ weights_ptr_1 += 16;
+ weights_ptr_2 += 16;
+ weights_ptr_3 += 16;
+ optimized_ops_preload_l1_stream(weights_ptr_0 + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_ptr_1 + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_ptr_2 + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_ptr_3 + kReadAhead);
+ }
+ // Horizontally reduce accumulators
+ int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
+ pairwise_reduced_acc_2, pairwise_reduced_acc_3;
+ pairwise_reduced_acc_0 =
+ vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
+ pairwise_reduced_acc_1 =
+ vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
+ pairwise_reduced_acc_2 =
+ vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
+ pairwise_reduced_acc_3 =
+ vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
+ const int32x2_t reduced_lo =
+ vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
+ const int32x2_t reduced_hi =
+ vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
+ int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
+ // Add bias values.
+ int32x4_t bias_vec = vld1q_s32(bias_ptr);
+ bias_ptr += 4;
+ reduced = vaddq_s32(reduced, bias_vec);
+ int left_shift = accum_shift > 0 ? accum_shift : 0;
+ int right_shift = accum_shift > 0 ? 0 : -accum_shift;
+ reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
+ // Multiply by the fixed-point multiplier.
+ reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
+ // Rounding-shift-right.
+ using gemmlowp::RoundingDivideByPOT;
+ reduced = RoundingDivideByPOT(reduced, right_shift);
+ // Narrow values down to 16 bit signed.
+ const int16x4_t res16 = vqmovn_s32(reduced);
+ vst1_s16(output_ptr, res16);
+ output_ptr += 4;
+ }
+}
+#endif
+
inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
const float* weights_data,
const Dims<4>& weights_dims, const float* bias_data,
@@ -559,14 +693,6 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims,
output_data, output_dims);
}
-inline void preload_l1_stream(const uint8* ptr) {
-#ifdef GEMMLOWP_ARM_64
- asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
-#else
- gemmlowp::Prefetch(ptr);
-#endif
-}
-
#ifdef USE_NEON
inline void FullyConnectedAsGEMV(
const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
@@ -587,10 +713,10 @@ inline void FullyConnectedAsGEMV(
const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
static constexpr int kPeel = 4;
for (int k = 0; k < input_size; k += 64) {
- preload_l1_stream(input_data + k);
+ optimized_ops_preload_l1_stream(input_data + k);
}
for (int k = 0; k < kPeel * input_size; k += 64) {
- preload_l1_stream(filter_data + k);
+ optimized_ops_preload_l1_stream(filter_data + k);
}
TFLITE_DCHECK(!(output_size % kPeel));
const int32* bias_ptr = bias_data;
@@ -609,7 +735,7 @@ inline void FullyConnectedAsGEMV(
for (int k = 0; k < kPeel; k++) {
const uint8* filter_ptr = filter_data + in + (out + k) * input_size;
filter_val_u8[k] = vld1q_u8(filter_ptr);
- preload_l1_stream(filter_ptr + 64);
+ optimized_ops_preload_l1_stream(filter_ptr + 64);
}
int16x8_t input_val[2];
const uint8x8_t low = vget_low_u8(input_val_u8);
@@ -834,13 +960,22 @@ inline void FullyConnected(
// the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
// is explained in the function comment above.
#ifdef GEMMLOWP_NEON
- if (batches == 1 && !(output_depth % 4) && !(accum_depth % 8) &&
- input_offset == -128 && output_activation_min == -32768 &&
+ if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
output_activation_max == 32767) {
- GEMVForLstmCell(input_data, input_dims, filter_data, filter_dims,
- filter_offset, bias_data_int32, bias_dims,
- output_multiplier, -output_shift, output_data, output_dims);
- return;
+ if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 16)) {
+ GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
+ filter_dims, bias_data_int32, bias_dims,
+ output_multiplier, -output_shift,
+ output_data, output_dims);
+ return;
+ }
+ if (!(output_depth % 4) && !(accum_depth % 8)) {
+ GEMVForLstmCell(input_data, input_dims, filter_data, filter_dims,
+ filter_offset, bias_data_int32, bias_dims,
+ output_multiplier, -output_shift, output_data,
+ output_dims);
+ return;
+ }
}
#endif
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> weights_matrix(
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index c053ff17ec..3575974ae9 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3184,19 +3184,20 @@ inline void Exp(const T* input_data, const size_t num_elements,
}
}
-template <typename T>
-inline void Mean(T* input_data, const int* input_dims, const int input_num_dims,
+template <typename T, typename U>
+inline bool Mean(T* input_data, const int* input_dims, const int input_num_dims,
T* output_data, const int* output_dims,
const int output_num_dims, const int* axis,
const int num_axis_dimensions, bool keep_dims, int* temp_index,
- int* resolved_axis) {
+ int* resolved_axis, U* temp_sum) {
// resets output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
num_outputs *= static_cast<size_t>(output_dims[idx]);
}
for (size_t idx = 0; idx < num_outputs; ++idx) {
- output_data[idx] = 0;
+ output_data[idx] = T();
+ temp_sum[idx] = U();
}
// resets temp index.
for (int idx = 0; idx < input_num_dims; ++idx) {
@@ -3229,19 +3230,24 @@ inline void Mean(T* input_data, const int* input_dims, const int input_num_dims,
size_t output_offset =
ReducedOutputOffset(input_num_dims, input_dims, temp_index,
num_resolved_axis, resolved_axis);
- output_data[output_offset] += input_data[input_offset];
+ temp_sum[output_offset] += static_cast<U>(input_data[input_offset]);
}
// takes average by num of elements added to get mean.
size_t num_elements_in_axis = 1;
for (int idx = 0; idx < num_resolved_axis; ++idx) {
- num_elements_in_axis *= static_cast<size_t>(input_dims[resolved_axis[idx]]);
+ size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
+ if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
+ return false;
+ }
+ num_elements_in_axis *= current;
}
if (num_elements_in_axis > 0) {
for (size_t idx = 0; idx < num_outputs; ++idx) {
- output_data[idx] = static_cast<T>(static_cast<float>(output_data[idx]) /
- num_elements_in_axis);
+ output_data[idx] =
+ static_cast<T>(temp_sum[idx] / static_cast<U>(num_elements_in_axis));
}
}
+ return true;
}
template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc
index aff19581ea..047bdd1039 100644
--- a/tensorflow/contrib/lite/kernels/mean.cc
+++ b/tensorflow/contrib/lite/kernels/mean.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -48,7 +49,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Creates two temp tensors to store index and axis for internal
// implementation only.
auto* scratch_tensor_index = new int;
- context->AddTensors(context, 2, scratch_tensor_index);
+ context->AddTensors(context, 3, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -64,6 +65,14 @@ TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context,
return context->ResizeTensor(context, resolved_axis, axis_size);
}
+// Resizes the temp tensor that stores temp sum of reduced elements.
+TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context,
+ TfLiteTensor* temp_sum) {
+ TfLiteIntArray* size = TfLiteIntArrayCreate(1);
+ size->data[0] = static_cast<int>(NumElements(op_context->output));
+ return context->ResizeTensor(context, temp_sum, size);
+}
+
// Resizes output array based on the input size and resolved axis.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
MeanContext* op_context) {
@@ -135,7 +144,7 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
// Creates a temp index to iterate through input data.
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]];
scratch_tensor->type = kTfLiteInt32;
@@ -149,6 +158,25 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
node->temporaries->data[1] = *scratch_tensor_index + 1;
TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
resolved_axis->type = kTfLiteInt32;
+ // Creates a temp tensor to store temp sums when calculating mean.
+ node->temporaries->data[2] = *scratch_tensor_index + 2;
+ TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]];
+ switch (op_context->input->type) {
+ case kTfLiteFloat32:
+ temp_sum->type = kTfLiteFloat32;
+ break;
+ case kTfLiteInt32:
+ temp_sum->type = kTfLiteInt64;
+ break;
+ case kTfLiteInt64:
+ temp_sum->type = kTfLiteInt64;
+ break;
+ case kTfLiteUInt8:
+ temp_sum->type = kTfLiteInt32;
+ break;
+ default:
+ return kTfLiteError;
+ }
return kTfLiteOk;
}
@@ -160,16 +188,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
+ TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]];
// Leaves work to Eval if axis is not constant; else resizes output.
if (!IsConstantTensor(op_context.axis)) {
SetTensorToDynamic(op_context.output);
SetTensorToDynamic(resolved_axis);
+ SetTensorToDynamic(temp_sum);
return kTfLiteOk;
}
resolved_axis->allocation_type = kTfLiteArenaRw;
TF_LITE_ENSURE_OK(context,
ResizeTempAxis(context, &op_context, resolved_axis));
- return ResizeOutputTensor(context, &op_context);
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ temp_sum->allocation_type = kTfLiteArenaRw;
+ return ResizeTempSum(context, &op_context, temp_sum);
}
template <KernelType kernel_type>
@@ -178,14 +210,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]];
TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
+ TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]];
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context,
ResizeTempAxis(context, &op_context, resolved_axis));
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
}
-#define TF_LITE_MEAN(kernel_type, data_type) \
+#define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \
kernel_type::Mean<>( \
GetTensorData<data_type>(op_context.input), \
op_context.input->dims->data, op_context.input->dims->size, \
@@ -193,21 +227,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
op_context.output->dims->data, op_context.output->dims->size, \
GetTensorData<int>(op_context.axis), num_axis, \
op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
+ GetTensorData<int>(resolved_axis), \
+ GetTensorData<temp_data_type>(temp_sum))
if (kernel_type == kReference) {
switch (op_context.input->type) {
case kTfLiteFloat32:
- TF_LITE_MEAN(reference_ops, float);
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, float, float));
break;
case kTfLiteInt32:
- TF_LITE_MEAN(reference_ops, int);
- break;
- case kTfLiteUInt8:
- TF_LITE_MEAN(reference_ops, uint8_t);
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int, int64_t));
break;
case kTfLiteInt64:
- TF_LITE_MEAN(reference_ops, int64_t);
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t));
+ break;
+ case kTfLiteUInt8:
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
+ op_context.output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
+ op_context.output->params.zero_point);
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
break;
default:
return kTfLiteError;
@@ -216,7 +255,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
#undef TF_LITE_MEAN
return kTfLiteOk;
}
-
} // namespace mean
TfLiteRegistration* Register_MEAN_REF() {
diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc
index 2d6d4bc2da..79c9957f76 100644
--- a/tensorflow/contrib/lite/kernels/mean_test.cc
+++ b/tensorflow/contrib/lite/kernels/mean_test.cc
@@ -37,8 +37,15 @@ class BaseMeanOpModel : public SingleOpModel {
return ExtractVector<T>(output_);
}
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+ int Input() { return input_; }
+
protected:
int input_;
int axis_;
@@ -142,56 +149,64 @@ TEST(DynamicFloatMeanOpTest, Scale) {
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
}
+// for quantized Add, the error shouldn't exceed step
+float GetTolerance(int min, int max) { return (max - min) / 255.0; }
+
TEST(ConstUint8MeanOpTest, NotKeepDims) {
- std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24};
- MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}},
- {4}, {1, 0, -3, -3}, false);
- m.SetInput(data);
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
- EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({12, 13}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {0.4, 0.4}, kQuantizedTolerance)));
}
TEST(ConstUint8MeanOpTest, KeepDims) {
- std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24};
- MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}},
- {2}, {0, 2}, true);
- m.SetInput(data);
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
- EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({10, 12, 14}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
}
TEST(DynamicUint8MeanOpTest, NotKeepDims) {
- std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24};
- MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}},
- {TensorType_INT32, {4}}, false);
- std::initializer_list<int> axis = {1, 0, -3, -3};
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::initializer_list<float> data = {1.3, -4.8, -3.6, 0.24};
+ MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::initializer_list<int> axis = {1};
m.SetAxis(axis);
- m.SetInput(data);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
- EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({12, 13}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.75, -1.68}, kQuantizedTolerance)));
}
TEST(DynamicUint8MeanOpTest, KeepDims) {
- std::initializer_list<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8,
- 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24};
- MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}},
- {TensorType_INT32, {2}}, true);
- std::initializer_list<int> axis = {0, 2};
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::initializer_list<float> data = {11.14, -0.14, 7.423, 0.879};
+ MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::initializer_list<int> axis = {0};
m.SetAxis(axis);
- m.SetInput(data);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
m.Invoke();
- EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
- EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({10, 12, 14}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance)));
}
} // namespace
diff --git a/tensorflow/contrib/lite/models/BUILD b/tensorflow/contrib/lite/models/BUILD
index 6a1255b586..efa47b06fa 100644
--- a/tensorflow/contrib/lite/models/BUILD
+++ b/tensorflow/contrib/lite/models/BUILD
@@ -12,15 +12,3 @@ load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
exports_files(glob([
"testdata/*",
]))
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD
index 733c3f4c7f..a82d1f2eb6 100644
--- a/tensorflow/contrib/lite/models/smartreply/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/BUILD
@@ -86,15 +86,3 @@ cc_test(
"@com_google_googletest//:gtest",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/nnapi/BUILD b/tensorflow/contrib/lite/nnapi/BUILD
index 402f1e949b..467a2b7a7b 100644
--- a/tensorflow/contrib/lite/nnapi/BUILD
+++ b/tensorflow/contrib/lite/nnapi/BUILD
@@ -11,15 +11,3 @@ cc_library(
],
linkopts = ["-ldl"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 411d5c0d27..e70aa51298 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -118,15 +118,3 @@ py_library(
":convert_saved_model",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index da65ec659c..246ec85fe4 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -70,16 +70,4 @@ cc_test(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 12b7b3c350..62f20638ba 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -374,16 +374,4 @@ tf_cc_test(
}),
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 68bce19aa3..8045052452 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -909,12 +909,11 @@ def make_gather_tests(zip_path):
# TODO(mgubin): add string tests when they are supported by Toco.
# TODO(mgubin): add tests for Nd indices when they are supported by
# TfLite.
- # TODO(mgubin): add tests for axis != 0 when it is supported by TfLite.
"params_dtype": [tf.float32, tf.int32],
"params_shape": [[10], [1, 2, 20]],
"indices_dtype": [tf.int32],
"indices_shape": [[3], [5]],
- "axis": [0], # axis!=0 is GatherV2
+ "axis": [0, 1],
}]
def build_graph(parameters):
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index e9d505a76d..6697b86e79 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -91,6 +91,9 @@ std::map<string, string> kBrokenTests = {
// PRelu only supports 4D input with (1, 1, channels) 3D alpha now.
{R"(^\/prelu.*shared_axes=\[1\])", "75975192"},
+
+ // No support for axis!=0 in GatherV2.
+ {R"(^\/gather.*axis=1)", "76910444"},
};
// Allows test data to be unzipped into a temporary directory and makes
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 613223f3d4..c399f4f2b7 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -56,12 +56,16 @@ void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
class TfLiteDriver::Expectation {
public:
- Expectation() { data_.raw = nullptr; }
+ Expectation() {
+ data_.raw = nullptr;
+ num_elements_ = 0;
+ }
~Expectation() { delete[] data_.raw; }
template <typename T>
void SetData(const string& csv_values) {
const auto& values = testing::Split<T>(csv_values, ",");
- data_.raw = new char[values.size() * sizeof(T)];
+ num_elements_ = values.size();
+ data_.raw = new char[num_elements_ * sizeof(T)];
SetTensorData(values, &data_);
}
@@ -88,7 +92,13 @@ class TfLiteDriver::Expectation {
constexpr double kRelativeThreshold = 1e-2f;
constexpr double kAbsoluteThreshold = 1e-4f;
- int tensor_size = tensor.bytes / sizeof(T);
+ size_t tensor_size = tensor.bytes / sizeof(T);
+
+ if (tensor_size != num_elements_) {
+ std::cerr << "Expected a tensor with " << num_elements_
+ << " elements, got " << tensor_size << std::endl;
+ return false;
+ }
bool good_output = true;
for (int i = 0; i < tensor_size; ++i) {
@@ -115,6 +125,7 @@ class TfLiteDriver::Expectation {
}
TfLitePtrUnion data_;
+ size_t num_elements_;
};
TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {}
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 102740ee47..bba61627f9 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -143,6 +143,7 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
+ "//tensorflow/cc/saved_model:tag_constants",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
@@ -279,6 +280,7 @@ cc_library(
"graph_transformations/resolve_tensorflow_switch.cc",
"graph_transformations/resolve_tensorflow_tile.cc",
"graph_transformations/resolve_transpose_attributes.cc",
+ "graph_transformations/swap_elementwise_binary.cc",
"graph_transformations/unfuse_activation_functions.cc",
"graph_transformations/unpartition_embedding_lookup.cc",
"graph_transformations/unroll_batch_matmul.cc",
@@ -418,15 +420,3 @@ tf_cc_test(
"@com_google_googletest//:gtest_main",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 7b71792ff7..52c789293c 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -26,6 +26,7 @@ limitations under the License.
#endif
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
+#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
@@ -220,7 +221,7 @@ struct ParsedTocoFlags {
Arg<string> output_file;
Arg<string> input_format = Arg<string>("TENSORFLOW_GRAPHDEF");
Arg<string> output_format = Arg<string>("TFLITE");
- Arg<string> savedmodel_tagset;
+ Arg<string> savedmodel_tagset = Arg<string>(tensorflow::kSavedModelTagServe);
// TODO(aselle): command_line_flags doesn't support doubles
Arg<float> default_ranges_min = Arg<float>(0.);
Arg<float> default_ranges_max = Arg<float>(0.);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 640afc7c74..1291825c8e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -180,6 +180,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather)
DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero)
+DECLARE_GRAPH_TRANSFORMATION(SwapElementwiseBinary)
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 5cc82da5d5..7c97ef0d31 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -332,6 +332,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
case OperatorType::kPad:
case OperatorType::kGather:
case OperatorType::kTranspose:
+ case OperatorType::kMean:
changed = HardcodeMinMaxFromFirstInput(model, op);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 9679ea0a77..9fcc95e1fe 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -52,7 +52,7 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kStridedSlice ||
type == OperatorType::kDepthToSpace ||
type == OperatorType::kLstmCell || type == OperatorType::kGather ||
- type == OperatorType::kTranspose;
+ type == OperatorType::kTranspose || type == OperatorType::kMean;
}
template <ArrayDataType A>
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/swap_elementwise_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/swap_elementwise_binary.cc
new file mode 100644
index 0000000000..ecbce58d16
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/swap_elementwise_binary.cc
@@ -0,0 +1,175 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ShapesAllowSwapping(const string& input_array_name,
+ const string& const_array_name, Model* model) {
+ const Array& input_array = model->GetOrCreateArray(input_array_name);
+ const Array& const_array = model->GetOrCreateArray(const_array_name);
+ // Wait until these shapes have been resolved.
+ if (!input_array.has_shape() || !const_array.has_shape()) {
+ return false;
+ }
+
+ // Currently swapping is not handled for scalar const_array, though that could
+ // be done once there is a test model.
+ if (RequiredBufferSizeForShape(input_array.shape()) !=
+ RequiredBufferSizeForShape(const_array.shape())) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+// Swaps:
+// Input
+// \
+// (Reshape Op) Const
+// \ /
+// (Add/Sub/Mul/Div op)
+// |
+// Output
+//
+// To:
+//
+// Input Const
+// \ /
+// (Add/Sub/Mul/Div op)
+// |
+// (Reshape Op)
+// |
+// Output
+//
+// This can allow Add/Mul ops from batch normalization to be folded into an
+// Input op from a FullyConnected layer.
+bool SwapElementwiseBinary::Run(Model* model, std::size_t op_index) {
+ const auto element_wise_op_it = model->operators.begin() + op_index;
+ std::unique_ptr<Operator>& element_wise_op = *element_wise_op_it;
+ DCHECK(element_wise_op);
+
+ switch (element_wise_op->type) {
+ case OperatorType::kAdd:
+ case OperatorType::kSub:
+ case OperatorType::kMul:
+ case OperatorType::kDiv:
+ break;
+ default:
+ return false;
+ }
+
+ int reshape_input = -1;
+ Operator* op = GetOpWithOutput(*model, element_wise_op->inputs[0]);
+ if (!op) {
+ return false;
+ }
+
+ if (op->type == OperatorType::kTensorFlowReshape) {
+ reshape_input = 0;
+ } else {
+ op = GetOpWithOutput(*model, element_wise_op->inputs[1]);
+ if (!op || op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+ reshape_input = 1;
+ }
+
+ int const_input = (reshape_input == 0) ? 1 : 0;
+ const string& const_input_array = element_wise_op->inputs[const_input];
+ if (!IsConstantParameterArray(*model, const_input_array)) {
+ return false;
+ }
+
+ // Do not fold division if denominator is not constant.
+ if (element_wise_op->type != OperatorType::kDiv && const_input != 1) {
+ return false;
+ }
+
+ const auto reshape_it =
+ FindOpWithOutput(*model, element_wise_op->inputs[reshape_input]);
+ // Note: we take copies of the tensor names here, instead of const-refs as we
+ // may overwrite the original names.
+ const string reshape_input_name = (*reshape_it)->inputs[0];
+ const string intermediate_name = (*reshape_it)->outputs[0];
+ const string element_wise_output_name = element_wise_op->outputs[0];
+
+ // Check the reshape op input and const op have their shapes resolved.
+ if (!ShapesAllowSwapping(reshape_input_name, const_input_array, model)) {
+ return false;
+ }
+
+ int count_ops_consuming_output = CountOpsWithInput(*model, intermediate_name);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not exchanging element-wise function with %s because it is "
+ "consumed by more than 1 other operator",
+ LogName(**reshape_it));
+ return false;
+ }
+
+ // If the element_wise_op was originally producing an output_array we can't
+ // swap as otherwise the output array would change. It'd be nice to still be
+ // able to swap but if code is relying on the fetch names instead of array
+ // indices this won't work.
+ for (int i = 0; i < model->flags.output_arrays_size(); ++i) {
+ if (model->flags.output_arrays(i) == element_wise_op->outputs[0]) {
+ AddMessageF(
+ "Not exchanging activation function with %s to preserve output array "
+ "name %s",
+ LogName(**reshape_it), element_wise_op->outputs[0]);
+ return false;
+ }
+ }
+
+ // Rewire by changing inputs, including all consumers.
+ // TODO(b/76086261): Replace with new utility function.
+ Operator* consumer = GetFirstOpWithInput(*model, element_wise_output_name);
+ while (consumer) {
+ for (int i = 0; i < consumer->inputs.size(); ++i) {
+ if (consumer->inputs[i] == element_wise_output_name) {
+ consumer->inputs[i] = intermediate_name;
+ }
+ }
+ consumer = GetFirstOpWithInput(*model, element_wise_output_name);
+ }
+ element_wise_op->inputs[reshape_input] = reshape_input_name;
+ (*reshape_it)->inputs[0] = element_wise_output_name;
+
+ // Clear shapes; this will allow shape propagation to fix the sizes for us.
+ model->GetOrCreateArray(element_wise_output_name).clear_shape();
+
+ // Finally, swap operators. Note that this only works when there are no other
+ // direct descendents of the reshape operator.
+ element_wise_op.swap(*reshape_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index 2f94f9cd8a..a2008ddbdb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -19,8 +19,8 @@ tf_cc_test(
)
tf_cc_test(
- name = "lstm_utils_test",
- srcs = ["lstm_utils_test.cc"],
+ name = "swap_elementwise_binary_test",
+ srcs = ["swap_elementwise_binary_test.cc"],
deps = [
"//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
@@ -29,14 +29,13 @@ tf_cc_test(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
+tf_cc_test(
+ name = "lstm_utils_test",
+ srcs = ["lstm_utils_test.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_googletest//:gtest_main",
+ ],
)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/swap_elementwise_binary_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/swap_elementwise_binary_test.cc
new file mode 100644
index 0000000000..c3778017f3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/swap_elementwise_binary_test.cc
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+int ShapeCount(const std::vector<int>& size) {
+ CHECK(size.size());
+ int count = 1;
+ for (int dim : size) {
+ count *= dim;
+ }
+ return count;
+}
+
+// Adds a new parameter array to the model.
+void AddConstArray(const string& name, const float* data,
+ const std::vector<int>& size, Model* model) {
+ Array& array = model->GetOrCreateArray(name);
+ array.data_type = ArrayDataType::kFloat;
+ Shape* shape = array.mutable_shape();
+ *(shape->mutable_dims()) = size;
+
+ auto& buffer = array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
+ buffer.data.resize(ShapeCount(size));
+ std::copy(data, data + ShapeCount(size), buffer.data.data());
+}
+
+} // namespace
+
+TEST(SwapElementwiseBinaryTest, SwapsReshape) {
+ Model model;
+ const float parameters[2][4] = {{0., 1., 2., 3.}, {10., 11., 12., 13.}};
+
+ AddConstArray("before_reshape", parameters[0], {2, 2}, &model);
+ AddConstArray("add_vector", parameters[1], {1, 4}, &model);
+
+ auto reshape_op = absl::make_unique<TensorFlowReshapeOperator>();
+ reshape_op->shape = {1, 4};
+ reshape_op->inputs = {"before_reshape"};
+ reshape_op->outputs = {"after_reshape"};
+ Array& reshape_array = model.GetOrCreateArray("after_reshape");
+ *(reshape_array.mutable_shape()) = {1, 4};
+
+ auto add_op = absl::make_unique<AddOperator>();
+ add_op->inputs = {"after_reshape", "add_vector"};
+ add_op->outputs = {"add"};
+ Array& add_array = model.GetOrCreateArray("add");
+ *(add_array.mutable_shape()) = {1, 4};
+
+ model.operators.push_back(std::move(reshape_op));
+ model.operators.push_back(std::move(add_op));
+
+ auto transformation = absl::make_unique<toco::SwapElementwiseBinary>();
+ ASSERT_TRUE(transformation->Run(&model, 1));
+
+ Operator* op = GetOpWithOutput(model, "add");
+ ASSERT_NE(nullptr, op);
+ ASSERT_EQ(OperatorType::kAdd, op->type);
+ ASSERT_EQ(2, op->inputs.size());
+ for (const string& input : op->inputs) {
+ EXPECT_TRUE(IsConstantParameterArray(model, input))
+ << input << " is not const input";
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index b844e0b948..c26e4bddff 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1343,13 +1343,16 @@ void ConvertFloorOperator(const NodeDef& node,
void ConvertGatherOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "Gather");
- CheckInputsCount(node, tf_import_flags, 2);
+ CHECK(node.op() == "Gather" || node.op() == "GatherV2");
+ if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2);
+ if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3);
const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
auto* op = new GatherOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
+ // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we
+ // should read it an pass it on to the TF Lite Interpreter.
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
}
@@ -2119,7 +2122,7 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
ConvertCastOperator(node, tf_import_flags, model);
} else if (node.op() == "Floor") {
ConvertFloorOperator(node, tf_import_flags, model);
- } else if (node.op() == "Gather") {
+ } else if (node.op() == "Gather" || node.op() == "GatherV2") {
ConvertGatherOperator(node, tf_import_flags, model);
} else if (node.op() == "ResizeBilinear") {
ConvertResizeBilinearOperator(node, tf_import_flags, model);
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
index 86d91bd3be..6c4f8e12cd 100644
--- a/tensorflow/contrib/lite/toco/python/BUILD
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -60,15 +60,3 @@ tf_py_test(
],
tags = ["no_pip"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD
index 0c1a1141fc..336e94de1e 100644
--- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD
@@ -88,15 +88,3 @@ cc_library(
"//tensorflow/core:protos_all_cc",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index 9d3e1daf12..e0191801a0 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -137,15 +137,3 @@ tf_cc_test(
"@flatbuffers",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 2771959970..335b496dcc 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -300,6 +300,17 @@ void Export(const Model& model, bool allow_custom_ops,
std::set<string> error_summary;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
&builder, &error_summary);
+ const string fake_quant_operation_name = "FAKE_QUANT";
+ if (error_summary.count(fake_quant_operation_name) != 0) {
+ LOG(ERROR)
+ << fake_quant_operation_name
+ << " operation was not converted. If running quantized make sure you "
+ "are passing --inference_type=QUANTIZED_UINT8 and values for "
+ "--std_values and --mean_values.";
+ // Remove the fake quant operation from the errors, since it shouldn't
+ // be provided a custom implementation.
+ error_summary.erase(fake_quant_operation_name);
+ }
if (!allow_custom_ops && !error_summary.empty()) {
LOG(QFATAL) << "Some of the operators in the model are not supported by "
"the standard TensorFlow Lite runtime. If you have a custom "
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc
index 91a742b9e0..26f55a66c7 100644
--- a/tensorflow/contrib/lite/toco/toco_saved_model.cc
+++ b/tensorflow/contrib/lite/toco/toco_saved_model.cc
@@ -35,9 +35,12 @@ const tensorflow::SavedModelBundle* LoadSavedModel(
<< "Model is not saved in the supported SavedModel format.\n";
// Gets the tags identifying the MetaGraphDef from the command line arguments.
- QCHECK(parsed_toco_flags.savedmodel_tagset.specified())
- << "Missing required flag --savedmodel_tagset.\n";
- const string tags_str = parsed_toco_flags.savedmodel_tagset.value();
+ string tags_str;
+ if (parsed_toco_flags.savedmodel_tagset.specified()) {
+ tags_str = parsed_toco_flags.savedmodel_tagset.value();
+ } else {
+ tags_str = parsed_toco_flags.savedmodel_tagset.default_value();
+ }
auto tags = absl::StrSplit(tags_str, ',');
// Loads MetaGraphDef.
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 30dd6fab9e..41ea1481bc 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -90,6 +90,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowTile);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
+ transformations->Add(new SwapElementwiseBinary);
transformations->Add(new IdentifyDilatedConv);
transformations->Add(new IdentifyL2Normalization);
transformations->Add(new IdentifyL2Pool);
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index b5abbc0712..44fde69a1e 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -91,18 +91,6 @@ cc_library(
deps = ["//tensorflow/contrib/lite:framework"],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "verifier",
srcs = ["verifier.cc"],
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index 8ca03f4193..02b4f80252 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -47,15 +47,3 @@ tf_py_test(
],
grpc_enabled = True,
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD
index 5694211521..728f75f8ef 100644
--- a/tensorflow/contrib/losses/BUILD
+++ b/tensorflow/contrib/losses/BUILD
@@ -97,15 +97,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/makefile/BUILD b/tensorflow/contrib/makefile/BUILD
index 701eeb44fe..1abb46f4d4 100644
--- a/tensorflow/contrib/makefile/BUILD
+++ b/tensorflow/contrib/makefile/BUILD
@@ -3,12 +3,3 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = ["**/OWNERS"],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 7a7683c953..b6acf71b9d 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -228,6 +228,11 @@ tensorflow/core/kernels/cast_op_impl_int64.cc
tensorflow/core/kernels/cast_op_impl_int8.cc
tensorflow/core/kernels/cast_op_impl_uint16.cc
tensorflow/core/kernels/cast_op_impl_uint8.cc
+tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+tensorflow/core/kernels/boosted_trees/resource_ops.cc
+tensorflow/core/kernels/boosted_trees/resources.cc
+tensorflow/core/kernels/boosted_trees/stats_ops.cc
+tensorflow/core/kernels/boosted_trees/training_ops.cc
tensorflow/core/kernels/bias_op.cc
tensorflow/core/kernels/bcast_ops.cc
tensorflow/core/kernels/batch_norm_op.cc
@@ -286,6 +291,7 @@ tensorflow/core/ops/data_flow_ops.cc
tensorflow/core/ops/ctc_ops.cc
tensorflow/core/ops/control_flow_ops.cc
tensorflow/core/ops/candidate_sampling_ops.cc
+tensorflow/core/ops/boosted_trees_ops.cc
tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/array_grad.cc
tensorflow/core/kernels/spacetobatch_functor.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index d569bde637..1f254692d7 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -18,6 +18,7 @@ tensorflow/core/protobuf/device_properties.proto
tensorflow/core/protobuf/rewriter_config.proto
tensorflow/core/protobuf/tensor_bundle.proto
tensorflow/core/lib/core/error_codes.proto
+tensorflow/core/kernels/boosted_trees/boosted_trees.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/variable.proto
tensorflow/core/framework/types.proto
diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD
index 72424c32e7..63843b993c 100644
--- a/tensorflow/contrib/memory_stats/BUILD
+++ b/tensorflow/contrib/memory_stats/BUILD
@@ -79,15 +79,3 @@ cuda_py_test(
"//tensorflow/python:random_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/meta_graph_transform/BUILD b/tensorflow/contrib/meta_graph_transform/BUILD
index 4b5b1c3e15..24400789f8 100644
--- a/tensorflow/contrib/meta_graph_transform/BUILD
+++ b/tensorflow/contrib/meta_graph_transform/BUILD
@@ -59,15 +59,3 @@ filegroup(
"**/*.py",
]),
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index e90c525113..5ca42f41c1 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -97,14 +97,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index ca3f13479e..f50575b2cf 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -125,15 +125,3 @@ py_library(
":rnn_cells",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD
index e7848adcc5..30ea912222 100644
--- a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD
@@ -68,15 +68,3 @@ py_binary(
"//tensorflow/contrib/model_pruning:pruning",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD
index 9f9802b8fe..a7be92a35e 100644
--- a/tensorflow/contrib/mpi_collectives/BUILD
+++ b/tensorflow/contrib/mpi_collectives/BUILD
@@ -126,15 +126,3 @@ tf_py_test(
],
tags = ["manual"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 94d01efee1..6cbfd03881 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -141,15 +141,3 @@ cuda_py_test(
"notap",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/nearest_neighbor/BUILD b/tensorflow/contrib/nearest_neighbor/BUILD
index 9500c18b1d..6fa7624467 100644
--- a/tensorflow/contrib/nearest_neighbor/BUILD
+++ b/tensorflow/contrib/nearest_neighbor/BUILD
@@ -111,15 +111,3 @@ tf_py_test(
"//tensorflow/python:client_testlib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD
index 5543eb6c6e..ef7ab22646 100644
--- a/tensorflow/contrib/nn/BUILD
+++ b/tensorflow/contrib/nn/BUILD
@@ -98,14 +98,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index bacf15bbd6..c57c5e3f29 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -265,14 +265,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD
index bd9078ae76..6ca7fe8b6e 100644
--- a/tensorflow/contrib/periodic_resample/BUILD
+++ b/tensorflow/contrib/periodic_resample/BUILD
@@ -95,18 +95,6 @@ py_test(
# )
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "custom_op_sources",
srcs = glob(
[
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD
index a80f060b91..36e21af618 100644
--- a/tensorflow/contrib/predictor/BUILD
+++ b/tensorflow/contrib/predictor/BUILD
@@ -8,18 +8,6 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "predictor",
srcs = ["__init__.py"],
diff --git a/tensorflow/contrib/quantization/BUILD b/tensorflow/contrib/quantization/BUILD
index c19a31afb2..2de10e8fae 100644
--- a/tensorflow/contrib/quantization/BUILD
+++ b/tensorflow/contrib/quantization/BUILD
@@ -49,15 +49,3 @@ filegroup(
"**/*.py",
]),
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 0b76296204..b9918fdee1 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -246,15 +246,3 @@ py_test(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 5750be6f4c..4a8f8a04cc 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -134,9 +134,9 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
match.output_tensor)
- if nodes_modified_count != 1:
- raise ValueError(
- 'Unexpected inputs to op: %s' % match.output_tensor.name)
+ if nodes_modified_count == 0:
+ raise ValueError('Folding batch norms failed, %s had no outputs.' %
+ match.output_tensor.name)
def _FindFusedBatchNorms(graph):
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 019d123a68..2889016a84 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -305,7 +305,8 @@ def _FindLayersToQuantize(graph):
# the output of the final BiasAdd must be quantized. So we treat the BiasAdd
# as the 'activation_op' in the _LayerMatch, to ensure that it's output is
# quantized.
- final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern)
+ final_layer_matcher = graph_matcher.GraphMatcher(
+ graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
for match_result in final_layer_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
@@ -463,11 +464,16 @@ def _InsertQuantOp(context,
lambda: inputs,
name=name_prefix + '/delayed_quant')
- nodes_modified_count = graph_editor.reroute_ts(
- [quant], [inputs], can_modify=consumers)
- if nodes_modified_count != len(consumers):
- raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
- [consumer.name for consumer in consumers]))
+ if consumers:
+ tensors_modified_count = graph_editor.reroute_ts(
+ [quant], [inputs], can_modify=consumers)
+ # Some operations can have multiple output tensors going to the same
+ # consumer. Since consumers is a set, we need to ensure that
+ # tensors_modified_count is greater than or equal to the length of the set
+ # of consumers.
+ if tensors_modified_count < len(consumers):
+ raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
+ [consumer.name for consumer in consumers]))
def _GetContextFromOp(op):
diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD
index e975aeaea7..9325a14745 100644
--- a/tensorflow/contrib/receptive_field/BUILD
+++ b/tensorflow/contrib/receptive_field/BUILD
@@ -106,15 +106,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/reduce_slice_ops/BUILD b/tensorflow/contrib/reduce_slice_ops/BUILD
index b31f4488f5..02b3d66e46 100644
--- a/tensorflow/contrib/reduce_slice_ops/BUILD
+++ b/tensorflow/contrib/reduce_slice_ops/BUILD
@@ -101,15 +101,3 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD
index 27f0a7f58f..996b55f9b8 100644
--- a/tensorflow/contrib/remote_fused_graph/pylib/BUILD
+++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD
@@ -48,15 +48,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD
index f0ecc8b85a..48345d7030 100644
--- a/tensorflow/contrib/resampler/BUILD
+++ b/tensorflow/contrib/resampler/BUILD
@@ -85,14 +85,3 @@ cuda_py_test(
"//tensorflow/python:array_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 7e5e35d0b5..43c0f75955 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -321,19 +321,6 @@ tf_cc_test(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "tools/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_gen_op_libs(
op_lib_names = [
"lstm_ops",
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index 245fe07f2b..faad40d335 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -81,15 +81,3 @@ py_test(
"//tensorflow/python/saved_model:utils",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
index ea4da80ba3..3c616c555b 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
@@ -49,9 +49,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(["*"]),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index ab80c68b1a..a62069a252 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -211,15 +211,3 @@ cuda_py_test(
"//tensorflow/python:variables",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD
index 75a753ed89..31717305e7 100644
--- a/tensorflow/contrib/session_bundle/BUILD
+++ b/tensorflow/contrib/session_bundle/BUILD
@@ -17,18 +17,6 @@ load(
"tf_cc_test",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "g3doc/sitemap.md",
- ],
- ),
-)
-
# TODO(b/32673259): add a test to continuously validate these files.
filegroup(
name = "session_bundle_half_plus_two",
diff --git a/tensorflow/contrib/session_bundle/example/BUILD b/tensorflow/contrib/session_bundle/example/BUILD
index dbbae01f36..9a56eab431 100644
--- a/tensorflow/contrib/session_bundle/example/BUILD
+++ b/tensorflow/contrib/session_bundle/example/BUILD
@@ -10,19 +10,6 @@ exports_files(["LICENSE"])
# vardef("PYTHON_BIN_PATH", "/usr/bin/python")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "g3doc/sitemap.md",
- ],
- ),
- visibility = ["//visibility:public"],
-)
-
py_binary(
name = "export_half_plus_two",
srcs = [
diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc
index 6d997bac9e..612623ae30 100644
--- a/tensorflow/contrib/session_bundle/session_bundle_test.cc
+++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
@@ -239,8 +240,8 @@ TEST(LoadSessionBundleFromPath, BasicTestRunOptionsThreadPoolInvalid) {
// Expect failed session run calls with invalid run-options.
EXPECT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Invalid inter_op_thread_pool: 2"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Invalid inter_op_thread_pool: 2"))
<< status.error_message();
}
@@ -314,8 +315,8 @@ TEST_F(SessionBundleTest, ServingGraphEmpty) {
});
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
EXPECT_FALSE(status_.ok());
- EXPECT_TRUE(StringPiece(status_.error_message())
- .contains("Expected exactly one serving GraphDef"))
+ EXPECT_TRUE(str_util::StrContains(status_.error_message(),
+ "Expected exactly one serving GraphDef"))
<< status_.error_message();
}
@@ -330,8 +331,9 @@ TEST_F(SessionBundleTest, ServingGraphAnyIncorrectType) {
});
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
EXPECT_FALSE(status_.ok());
- EXPECT_TRUE(StringPiece(status_.error_message())
- .contains("Expected Any type_url for: tensorflow.GraphDef"))
+ EXPECT_TRUE(
+ str_util::StrContains(status_.error_message(),
+ "Expected Any type_url for: tensorflow.GraphDef"))
<< status_.error_message();
}
@@ -347,7 +349,8 @@ TEST_F(SessionBundleTest, ServingGraphAnyValueCorrupted) {
});
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
EXPECT_FALSE(status_.ok());
- EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack"))
+ EXPECT_TRUE(
+ str_util::StrContains(status_.error_message(), "Failed to unpack"))
<< status_.error_message();
}
@@ -362,9 +365,9 @@ TEST_F(SessionBundleTest, AssetFileAnyIncorrectType) {
});
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
EXPECT_FALSE(status_.ok());
- EXPECT_TRUE(
- StringPiece(status_.error_message())
- .contains("Expected Any type_url for: tensorflow.serving.AssetFile"))
+ EXPECT_TRUE(str_util::StrContains(
+ status_.error_message(),
+ "Expected Any type_url for: tensorflow.serving.AssetFile"))
<< status_.error_message();
}
@@ -380,7 +383,8 @@ TEST_F(SessionBundleTest, AssetFileAnyValueCorrupted) {
});
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
EXPECT_FALSE(status_.ok());
- EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack"))
+ EXPECT_TRUE(
+ str_util::StrContains(status_.error_message(), "Failed to unpack"))
<< status_.error_message();
}
@@ -395,8 +399,8 @@ TEST_F(SessionBundleTest, InitOpTooManyValues) {
});
status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
EXPECT_FALSE(status_.ok());
- EXPECT_TRUE(StringPiece(status_.error_message())
- .contains("Expected exactly one serving init op"))
+ EXPECT_TRUE(str_util::StrContains(status_.error_message(),
+ "Expected exactly one serving init op"))
<< status_.error_message();
}
diff --git a/tensorflow/contrib/session_bundle/signature_test.cc b/tensorflow/contrib/session_bundle/signature_test.cc
index 741b7fde9b..b1ff55552e 100644
--- a/tensorflow/contrib/session_bundle/signature_test.cc
+++ b/tensorflow/contrib/session_bundle/signature_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
@@ -33,8 +34,8 @@ namespace tensorflow {
namespace serving {
namespace {
-static bool HasSubstr(const string& base, const string& substr) {
- bool ok = StringPiece(base).contains(substr);
+static bool HasSubstr(StringPiece base, StringPiece substr) {
+ bool ok = str_util::StrContains(base, substr);
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
return ok;
}
@@ -69,8 +70,8 @@ TEST(GetClassificationSignature, MissingSignature) {
ClassificationSignature signature;
const Status status = GetClassificationSignature(meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Expected a classification signature"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Expected a classification signature"))
<< status.error_message();
}
@@ -86,8 +87,8 @@ TEST(GetClassificationSignature, WrongSignatureType) {
ClassificationSignature signature;
const Status status = GetClassificationSignature(meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Expected a classification signature"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Expected a classification signature"))
<< status.error_message();
}
@@ -122,8 +123,8 @@ TEST(GetNamedClassificationSignature, MissingSignature) {
const Status status =
GetNamedClassificationSignature("foo", meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Missing signature named \"foo\""))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Missing signature named \"foo\""))
<< status.error_message();
}
@@ -141,9 +142,9 @@ TEST(GetNamedClassificationSignature, WrongSignatureType) {
const Status status =
GetNamedClassificationSignature("foo", meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(
- StringPiece(status.error_message())
- .contains("Expected a classification signature for name \"foo\""))
+ EXPECT_TRUE(str_util::StrContains(
+ status.error_message(),
+ "Expected a classification signature for name \"foo\""))
<< status.error_message();
}
@@ -176,8 +177,8 @@ TEST(GetRegressionSignature, MissingSignature) {
RegressionSignature signature;
const Status status = GetRegressionSignature(meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Expected a regression signature"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Expected a regression signature"))
<< status.error_message();
}
@@ -193,8 +194,8 @@ TEST(GetRegressionSignature, WrongSignatureType) {
RegressionSignature signature;
const Status status = GetRegressionSignature(meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Expected a regression signature"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Expected a regression signature"))
<< status.error_message();
}
@@ -227,8 +228,8 @@ TEST(GetNamedSignature, MissingSignature) {
Signature signature;
const Status status = GetNamedSignature("foo", meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Missing signature named \"foo\""))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Missing signature named \"foo\""))
<< status.error_message();
}
@@ -370,7 +371,7 @@ TEST(RunClassification, RunNotOk) {
const Status status = RunClassification(signature, input_tensor, &session,
&classes_tensor, nullptr);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "Data is gone"))
<< status.error_message();
}
@@ -386,7 +387,8 @@ TEST(RunClassification, TooManyOutputs) {
const Status status = RunClassification(signature, input_tensor, &session,
&classes_tensor, nullptr);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message()).contains("Expected 1 output"))
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "Expected 1 output"))
<< status.error_message();
}
@@ -402,8 +404,9 @@ TEST(RunClassification, WrongBatchOutputs) {
const Status status = RunClassification(signature, input_tensor, &session,
&classes_tensor, nullptr);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Input batch size did not match output batch size"))
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(),
+ "Input batch size did not match output batch size"))
<< status.error_message();
}
@@ -449,7 +452,7 @@ TEST_F(RunRegressionTest, RunNotOk) {
const Status status =
RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "Data is gone"))
<< status.error_message();
}
@@ -460,8 +463,9 @@ TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) {
const Status status =
RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Input batch size did not match output batch size"))
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(),
+ "Input batch size did not match output batch size"))
<< status.error_message();
}
@@ -488,7 +492,7 @@ TEST(GetSignatures, MissingSignature) {
const auto status = GetSignatures(meta_graph_def, &read_signatures);
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Expected exactly one"))
+ str_util::StrContains(status.error_message(), "Expected exactly one"))
<< status.error_message();
}
@@ -502,9 +506,9 @@ TEST(GetSignatures, WrongProtoInAny) {
Signatures read_signatures;
const auto status = GetSignatures(meta_graph_def, &read_signatures);
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Expected Any type_url for: "
- "tensorflow.serving.Signatures"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Expected Any type_url for: "
+ "tensorflow.serving.Signatures"))
<< status.error_message();
}
@@ -519,7 +523,7 @@ TEST(GetSignatures, JunkInAny) {
Signatures read_signatures;
const auto status = GetSignatures(meta_graph_def, &read_signatures);
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
- EXPECT_TRUE(StringPiece(status.error_message()).contains("Failed to unpack"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "Failed to unpack"))
<< status.error_message();
}
@@ -567,7 +571,7 @@ TEST(GetSignatures, MultipleSignaturesNotOK) {
const auto status = GetSignatures(meta_graph_def, &read_signatures);
EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
EXPECT_TRUE(
- StringPiece(status.error_message()).contains("Expected exactly one"))
+ str_util::StrContains(status.error_message(), "Expected exactly one"))
<< status.error_message();
}
@@ -641,8 +645,8 @@ TEST(GetGenericSignature, WrongSignatureType) {
const Status status =
GetGenericSignature("generic_bindings", meta_graph_def, &signature);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(StringPiece(status.error_message())
- .contains("Expected a generic signature:"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "Expected a generic signature:"))
<< status.error_message();
}
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index a83fc20596..fdecceff52 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -130,15 +130,3 @@ cuda_py_tests(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
index 1c052354b8..64cc8c7ea5 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -338,7 +338,7 @@ class FrameTest(test.TestCase):
def test_constant_folding(self):
"""frame should be constant foldable for constant inputs."""
- for pad_end in [False, True]:
+ for pad_end in [True, False]:
g = ops.Graph()
with g.as_default():
frame_length, frame_step = 32, 16
diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD
index c2f106c2b2..516e3ea073 100644
--- a/tensorflow/contrib/slim/BUILD
+++ b/tensorflow/contrib/slim/BUILD
@@ -178,15 +178,3 @@ py_test(
"//tensorflow/python:summary",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/slim/python/slim/data/BUILD b/tensorflow/contrib/slim/python/slim/data/BUILD
index 5daabbd62e..dc12e67fc6 100644
--- a/tensorflow/contrib/slim/python/slim/data/BUILD
+++ b/tensorflow/contrib/slim/python/slim/data/BUILD
@@ -193,15 +193,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD
index 7f03aaf085..8bbdf96384 100644
--- a/tensorflow/contrib/slim/python/slim/nets/BUILD
+++ b/tensorflow/contrib/slim/python/slim/nets/BUILD
@@ -317,15 +317,3 @@ py_test(
"//tensorflow/python:variables",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/solvers/BUILD b/tensorflow/contrib/solvers/BUILD
index 87b67486ad..5247288d54 100644
--- a/tensorflow/contrib/solvers/BUILD
+++ b/tensorflow/contrib/solvers/BUILD
@@ -93,16 +93,3 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
)
-
-# All files
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD
index fcfaa2aba4..b729fff261 100644
--- a/tensorflow/contrib/sparsemax/BUILD
+++ b/tensorflow/contrib/sparsemax/BUILD
@@ -65,15 +65,3 @@ cuda_py_tests(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/specs/BUILD b/tensorflow/contrib/specs/BUILD
index 084953a0a2..055b04db8a 100644
--- a/tensorflow/contrib/specs/BUILD
+++ b/tensorflow/contrib/specs/BUILD
@@ -60,15 +60,3 @@ tf_py_test(
"//tensorflow/python:variables",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/staging/BUILD b/tensorflow/contrib/staging/BUILD
index bc4a289468..0c86f3db1d 100644
--- a/tensorflow/contrib/staging/BUILD
+++ b/tensorflow/contrib/staging/BUILD
@@ -6,18 +6,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "staging",
srcs = ["__init__.py"],
diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD
index 5fd02efbf6..d4096751c4 100644
--- a/tensorflow/contrib/stat_summarizer/BUILD
+++ b/tensorflow/contrib/stat_summarizer/BUILD
@@ -32,15 +32,3 @@ tf_py_test(
"//tensorflow/python:variables",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
index 6e259e1d32..dcbef2881d 100644
--- a/tensorflow/contrib/stateless/BUILD
+++ b/tensorflow/contrib/stateless/BUILD
@@ -38,15 +38,3 @@ cuda_py_test(
"//tensorflow/python:random_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index 80563c5e15..fda1367b15 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -83,18 +83,6 @@ py_library(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# NOTE: target cannot be testonly because it needs to be in the pip
# package. Sigh.
py_library(
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 1e4cc3f095..11a59ec22b 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -16,20 +16,6 @@ package(default_visibility = ["//visibility:public"])
exports_files(["LICENSE"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "kernels/v4/*",
- "proto/*",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# ---------------------------------- V2 ops ------------------------------------------#
filegroup(
name = "v2_op_sources",
diff --git a/tensorflow/contrib/tensor_forest/hybrid/BUILD b/tensorflow/contrib/tensor_forest/hybrid/BUILD
index a2a3b485f6..b7185e09c7 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/BUILD
+++ b/tensorflow/contrib/tensor_forest/hybrid/BUILD
@@ -12,18 +12,6 @@ package(default_visibility = ["//visibility:public"])
exports_files(["LICENSE"])
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "custom_op_sources",
srcs = glob(
["core/ops/*.cc"],
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
index 794b76d858..b1b1559383 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
@@ -11,11 +11,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-filegroup(
- name = "all_files",
- srcs = glob(["**/*"]),
-)
-
DECISION_TREE_RESOURCE_DEPS = [
":decision_node_evaluator",
":input_data",
diff --git a/tensorflow/contrib/tensor_forest/proto/BUILD b/tensorflow/contrib/tensor_forest/proto/BUILD
index 1cfef44af1..04fd6a9839 100644
--- a/tensorflow/contrib/tensor_forest/proto/BUILD
+++ b/tensorflow/contrib/tensor_forest/proto/BUILD
@@ -6,14 +6,6 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
package(default_visibility = ["//visibility:public"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_proto_library(
name = "fertile_stats_proto",
srcs = ["fertile_stats.proto"],
diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD
index d833744d0c..f4efd9717d 100644
--- a/tensorflow/contrib/tensorboard/BUILD
+++ b/tensorflow/contrib/tensorboard/BUILD
@@ -88,15 +88,3 @@ py_test(
"//tensorflow/python:platform",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD
index 4175d8adb5..3f6b4cdc9a 100644
--- a/tensorflow/contrib/tensorboard/db/BUILD
+++ b/tensorflow/contrib/tensorboard/db/BUILD
@@ -135,9 +135,3 @@ tf_cc_binary(
"//tensorflow/core/lib/db:sqlite",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(["*"]),
- visibility = ["//tensorflow:__pkg__"],
-)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 906cc3f034..2f316767b3 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -272,15 +272,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/testing/BUILD b/tensorflow/contrib/testing/BUILD
index 0be6aa755b..8a40e111d7 100644
--- a/tensorflow/contrib/testing/BUILD
+++ b/tensorflow/contrib/testing/BUILD
@@ -22,15 +22,3 @@ py_library(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD
index 698fdd830f..38d91f7e49 100644
--- a/tensorflow/contrib/text/BUILD
+++ b/tensorflow/contrib/text/BUILD
@@ -111,14 +111,3 @@ py_test(
"//tensorflow/python:training",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/tfprof/BUILD b/tensorflow/contrib/tfprof/BUILD
index 28adce71d4..e7f4ebdd36 100644
--- a/tensorflow/contrib/tfprof/BUILD
+++ b/tensorflow/contrib/tfprof/BUILD
@@ -20,15 +20,3 @@ py_library(
"//tensorflow/python/profiler:tfprof_logger",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/timeseries/BUILD b/tensorflow/contrib/timeseries/BUILD
index 6ba069778c..f2b8786a52 100644
--- a/tensorflow/contrib/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/BUILD
@@ -31,15 +31,3 @@ py_library(
"//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index bb86ecb220..40cf9147b3 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -106,15 +106,3 @@ py_test(
"//tensorflow/python/estimator:estimator_py",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index ed3ed4c0e1..55a25e39fe 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -442,15 +442,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index f4304f2560..51d0c0ca3f 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -126,6 +126,27 @@ class TimeSeriesRegressorTest(test.TestCase):
signatures=signatures,
session=sess)
+ # Test cold starting
+ batch_numpy_times = numpy.tile(
+ numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
+ batch_numpy_values = numpy.ones([10, 30, 1])
+ state = saved_model_utils.cold_start_filter(
+ signatures=signatures,
+ session=sess,
+ features={
+ feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
+ feature_keys.FilteringFeatures.VALUES: batch_numpy_values
+ }
+ )
+ predict_times = numpy.tile(
+ numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1))
+ predictions = saved_model_utils.predict_continuation(
+ continue_from=state,
+ times=predict_times,
+ signatures=signatures,
+ session=sess)
+ self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
+
def test_fit_restore_fit_ar_regressor(self):
def _estimator_fn(model_dir):
return estimators.ARRegressor(
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 3d7e615290..4cf6bbcfd4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -154,8 +154,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
no_state_features = {
k: v for k, v in features.items()
if not k.startswith(feature_keys.State.STATE_PREFIX)}
- cold_filtering_outputs = self.create_loss(
- no_state_features, estimator_lib.ModeKeys.EVAL)
+ # Ignore any state management when cold-starting. The model's default
+ # start state is replicated across the batch.
+ cold_filtering_outputs = self.model.define_loss(
+ features=no_state_features, mode=estimator_lib.ModeKeys.EVAL)
return estimator_lib.EstimatorSpec(
mode=estimator_lib.ModeKeys.PREDICT,
export_outputs={
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
index c86d06e923..ca25ccd2b8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
@@ -268,15 +268,3 @@ py_library(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index eea19e9465..3e32a7a85c 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -119,6 +119,8 @@ py_library(
srcs = ["python/profiler/__init__.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_pb2_grpc",
+ "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_proto_py",
"//tensorflow/contrib/tpu/profiler:trace_events_proto_py",
"//tensorflow/python:util",
],
@@ -281,16 +283,3 @@ tf_py_test(
"//tensorflow/python:framework_test_lib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- include = [
- "**/*",
- ],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD
index 0a52d0b13b..1c32993e8e 100644
--- a/tensorflow/contrib/tpu/profiler/BUILD
+++ b/tensorflow/contrib/tpu/profiler/BUILD
@@ -6,18 +6,6 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_cc")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_proto_library(
name = "tpu_profiler_proto",
srcs = ["tpu_profiler.proto"],
@@ -127,7 +115,5 @@ py_library(
srcs = ["tpu_profiler_analysis_pb2_grpc.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = [
- ":tpu_profiler_analysis_proto_py",
- ],
+ deps = [":tpu_profiler_analysis_proto_py"],
)
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index a730d6142d..0b78cf8695 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -76,7 +76,7 @@ def main(unused_argv=None):
else:
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.TPUClusterResolver(
- tpu_names=[FLAGS.tpu_name],
+ [FLAGS.tpu_name],
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project))
service_addr = tpu_cluster_resolver.get_master()
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 20ed7419fd..590db2c376 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -199,10 +199,22 @@ message HostOpsPerTpuStep {
map<int32, int32> step_diffs = 5;
}
+message HostOpsDetailsPerCore {
+ // Map from core id to HostOpsPerTpuStep.
+ map<int32, HostOpsPerTpuStep> core_map = 1;
+}
+
+message HostOpsDetailsPerHost {
+ // Map from hostname to a map from core id to HostOpsPerTpuStep.
+ map<string, HostOpsDetailsPerCore> host_map = 1;
+}
+
// Result proto for the host ops for all TPU steps.
message HostOpsResult {
- // A sequence of HostOpsPerTpuStep (one for each TPU step)
- repeated HostOpsPerTpuStep host_op_sequence = 1;
+ reserved 1; // (was repeated HostOpsPerTpuStep host_op_sequence)
+ // A sequence of records with one for each TPU step. Each record
+ // is a map from hostname to a map from core id to HostOpsPerTpuStep.
+ repeated HostOpsDetailsPerHost hostops_details = 2;
}
// Result proto for TfStatsHelper.
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py
index c28fef22a9..8f51488288 100644
--- a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py
+++ b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis_pb2_grpc.py
@@ -22,7 +22,7 @@ from __future__ import print_function
import grpc
-from third_party.tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2 as third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2
+from tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2 as third__party_dot_tensorflow_dot_contrib_dot_tpu_dot_profiler_dot_tpu__profiler__analysis__pb2
class TPUProfileAnalysisStub(object):
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index e166098567..fcfbbe1a21 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -4,17 +4,6 @@ exports_files(["LICENSE"])
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_proto_library(
name = "tpu_embedding_config_proto",
srcs = [
diff --git a/tensorflow/contrib/tpu/python/profiler/__init__.py b/tensorflow/contrib/tpu/python/profiler/__init__.py
index bde13f0527..15ce6aceec 100644
--- a/tensorflow/contrib/tpu/python/profiler/__init__.py
+++ b/tensorflow/contrib/tpu/python/profiler/__init__.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
+from tensorflow.contrib.tpu.profiler.tpu_profiler_analysis_pb2 import *
from tensorflow.contrib.tpu.profiler.trace_events_pb2 import *
# pylint: enable=wildcard-import,unused-import
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 38b5ea2310..cc1a7fd801 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -35,10 +35,16 @@ _TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
_NUM_CORES_PER_HOST = 8
-
# pylint: enable=protected-access
+class InputPipelineConfig(object):
+ r"""Please see the definition of these values in TPUConfig."""
+ PER_SHARD_V1 = 1
+ PER_HOST_V1 = 2
+ PER_HOST_V2 = 3
+
+
# TODO(b/72511246) Provide a simplified api to configure model parallelism.
class TPUConfig(
collections.namedtuple('TPUConfig', [
@@ -68,13 +74,16 @@ class TPUConfig(
partitioned across 4 cores which span two cores in both x and y
coordinates. Please refer to @{tf.contrib.tpu.Topology} for the
geometry of a TPU mesh.
- per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host
- rather than Per-Core. With Per-Host input pipeline deployment, `input_fn`
- is invoked once on each host. With Per-Core input pipeline deployment, it
- is invoked once for each core. To be precise, with a global batch size
- `train_batch_size` in `TPUEstimator` constructor, the batch size for each
- shard is `train_batch_size` // #hosts. With Per-Core input pipeline
- deployment, the shard batch size is `train_batch_size` // #cores.
+ per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
+ `input_fn` is invoked per-host rather than per-core. With per-host input
+ pipeline configuration, `input_fn` is invoked once on each host. With the
+ per-core input pipeline configuration, it is invoked once for each core.
+ With a global batch size `train_batch_size` in `TPUEstimator` constructor,
+ the batch size for each shard is `train_batch_size` // #hosts in the
+ `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
+ `train_batch_size` // #cores. With the per-core input pipeline
+ configuration, the shard batch size is also `train_batch_size` // #cores.
+ Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
esoteric cluster configurations, you may need to specify the job name as a
@@ -117,6 +126,13 @@ class TPUConfig(
raise ValueError('computation_shape elements can only be 1 or 2; got '
'computation_shape={}'.format(computation_shape))
+ # per_host_input_for_training may be True, False, or integer in [1..3].
+ # Map legacy values (True, False) to numeric values.
+ if per_host_input_for_training is False:
+ per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
+ elif per_host_input_for_training is True:
+ per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
+
# Check initial_infeed_sleep_secs.
if initial_infeed_sleep_secs:
util_lib.check_positive_integer(initial_infeed_sleep_secs,
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 3bac2db77e..fbc1173e49 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -24,6 +24,7 @@ import copy
import numpy as np
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
+from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.platform import tf_logging as logging
@@ -205,7 +206,13 @@ class _TPUContext(object):
"""Return true if input_fn is invoked per-core (other than per-host)."""
mode = self._assert_mode()
return (mode == model_fn_lib.ModeKeys.TRAIN and
- not self._config.tpu_config.per_host_input_for_training)
+ (self._config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.PER_SHARD_V1))
+
+ def is_input_per_host_with_iterators(self):
+ """Return true if input_fn should be run in the per-host v2 config."""
+ return (self._config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.PER_HOST_V2)
def is_running_on_cpu(self, is_export_mode=False):
"""Determines whether the input_fn and model_fn should be invoked on CPU.
@@ -271,7 +278,8 @@ class _TPUContext(object):
return global_batch_size
# On TPU
- if self.is_input_sharded_per_core():
+ if self.is_input_sharded_per_core() or (
+ self.is_input_per_host_with_iterators()):
# We prohibit per core input sharding for the model parallelism case,
# therefore it is safe to use num_cores here.
return global_batch_size // self.num_cores
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 152f8c8c69..fa56708f44 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -740,6 +740,61 @@ def generate_per_host_enqueue_ops_fn_for_host(
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+def generate_per_host_v2_enqueue_ops_fn_for_host(
+ ctx, input_fn, inputs_structure_recorder, device, host_id):
+ """Generates infeed enqueue ops for per-host input_fn on a single host."""
+ del host_id # unused
+ captured_infeed_queue = _CapturedObject()
+ hooks = []
+
+ with ops.device(device):
+ inputs = _Inputs.from_input_fn(input_fn())
+
+ is_dataset = inputs.is_dataset
+ if not is_dataset:
+ raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
+ 'input pipeline configuration.')
+ if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
+ # TODO(b/XXX): Add predict support for PER_HOST_V2
+ raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.')
+
+ hooks.append(inputs.dataset_initializer_hook())
+
+ def enqueue_ops_fn():
+ """Generates the per_host enqueue ops."""
+ control_deps = []
+ per_host_sharded_inputs = []
+ num_replicas_per_host = ctx.num_of_replicas_per_host
+ with ops.device(device):
+ if not inputs.is_dataset:
+ raise TypeError('`input_fn` must return a `Dataset` for this mode.')
+ for _ in range(num_replicas_per_host):
+ # Use control dependencies to ensure a deterministic ordering.
+ with ops.control_dependencies(control_deps):
+ features, labels = inputs.features_and_labels() # Calls get_next()
+
+ inputs_structure_recorder.validate_and_record_structure(
+ features, labels)
+ flattened_inputs = (
+ inputs_structure_recorder.flatten_features_and_labels(
+ features, labels))
+
+ control_deps.extend(flattened_inputs)
+ per_host_sharded_inputs.append(flattened_inputs)
+
+ infeed_queue = tpu_feed.InfeedQueue(
+ number_of_tuple_elements=len(per_host_sharded_inputs[0]))
+ captured_infeed_queue.capture(infeed_queue)
+ infeed_queue.set_configuration_from_sharded_input_tensors(
+ per_host_sharded_inputs)
+
+ per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
+ per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
+ return per_host_enqueue_ops
+
+ return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+
+
class _InputPipeline(object):
"""`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
@@ -975,10 +1030,17 @@ class _InputPipeline(object):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
- enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
- generate_per_host_enqueue_ops_fn_for_host(
- self._ctx, self._input_fn, self._inputs_structure_recorder,
- self._batch_axis, host_device, host_id))
+ if self._ctx.is_input_per_host_with_iterators():
+ enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+ generate_per_host_v2_enqueue_ops_fn_for_host(
+ self._ctx, self._input_fn,
+ self._inputs_structure_recorder, host_device, host_id))
+ else:
+ enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+ generate_per_host_enqueue_ops_fn_for_host(
+ self._ctx, self._input_fn,
+ self._inputs_structure_recorder, self._batch_axis,
+ host_device, host_id))
all_hooks.extend(hooks)
# NOTE(xiejw): We dispatch here based on the return type of the
@@ -1724,7 +1786,7 @@ class TPUEstimator(estimator_lib.Estimator):
labels to match up with the corresponding images. If None is supplied,
and per_host_input_for_training is True, batches will be sharded based
on the major dimension. If tpu_config.per_host_input_for_training is
- False, batch_axis is ignored.
+ False or `PER_HOST_V2`, batch_axis is ignored.
Raises:
ValueError: `params` has reserved keys already.
@@ -1744,7 +1806,8 @@ class TPUEstimator(estimator_lib.Estimator):
raise ValueError('`train_batch_size` cannot be `None`')
util_lib.check_positive_integer(train_batch_size, 'train_batch_size')
- if (not config.tpu_config.per_host_input_for_training and
+ if (config.tpu_config.per_host_input_for_training is
+ tpu_config.InputPipelineConfig.PER_SHARD_V1 and
config.tpu_config.computation_shape):
raise ValueError(
'Model parallelism only supports per host input for training. '
@@ -2362,6 +2425,10 @@ class _Inputs(object):
def features_and_labels(self):
"""Gets `features` and `labels`."""
if self.is_dataset:
+ if self._iterator is None:
+ raise RuntimeError('Internal error: Must call dataset_initializer_hook '
+ 'before calling features_and_labels(). Please file '
+ 'a bug!')
return _Inputs._parse_inputs(self._iterator.get_next())
return (self._features, self._labels)
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index 6ae2f38252..4d2bfd3e43 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -308,18 +308,6 @@ py_test(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD
index 6c766e4f1c..d9ccda8e89 100644
--- a/tensorflow/contrib/util/BUILD
+++ b/tensorflow/contrib/util/BUILD
@@ -75,15 +75,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
index 80a5d07ea4..9720fd6e86 100644
--- a/tensorflow/contrib/verbs/BUILD
+++ b/tensorflow/contrib/verbs/BUILD
@@ -12,18 +12,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
exports_files(["LICENSE"])
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1d11410332..614e06cf83 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -144,11 +144,14 @@ load(
"tf_cuda_tests_tags",
"if_static",
)
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
+exports_files(["ops/ops.pbtxt"])
+
# -----------------------------------------------------------------------------
# Public targets
@@ -379,13 +382,13 @@ cc_library(
)
cc_library(
- name = "session_message",
- srcs = ["util/session_message.cc"],
- hdrs = ["util/session_message.h"],
+ name = "stacktrace",
+ srcs = glob(["platform/*/stacktrace.h"]),
+ hdrs = ["platform/stacktrace.h"],
deps = [
- ":framework",
- ":lib",
- ":protos_all_cc",
+ ":abi",
+ ":lib_platform",
+ "//tensorflow/core/platform/default/build_config:stacktrace",
],
)
@@ -394,8 +397,20 @@ cc_library(
srcs = ["platform/stacktrace_handler.cc"],
hdrs = ["platform/stacktrace_handler.h"],
deps = [
- ":lib",
+ ":abi",
":lib_platform",
+ ":stacktrace",
+ ],
+)
+
+cc_library(
+ name = "session_message",
+ srcs = ["util/session_message.cc"],
+ hdrs = ["util/session_message.h"],
+ deps = [
+ ":framework",
+ ":lib",
+ ":protos_all_cc",
],
)
@@ -443,6 +458,7 @@ tf_cuda_library(
"framework/attr_value_util.h",
"framework/bfloat16.h",
"framework/cancellation.h",
+ "framework/collective.h",
"framework/common_shape_fns.h",
"framework/control_flow.h", # TODO(josh11b): Make internal?
"framework/dataset.h",
@@ -613,6 +629,7 @@ tf_gen_op_libs(
op_lib_names = [
"batch_ops",
"bitwise_ops",
+ "boosted_trees_ops",
"candidate_sampling_ops",
"checkpoint_ops",
"control_flow_ops",
@@ -725,6 +742,7 @@ cc_library(
":audio_ops_op_lib",
":batch_ops_op_lib",
":bitwise_ops_op_lib",
+ ":boosted_trees_ops_op_lib",
":candidate_sampling_ops_op_lib",
":checkpoint_ops_op_lib",
":control_flow_ops_op_lib",
@@ -866,6 +884,7 @@ cc_library(
"//tensorflow/core/kernels:audio",
"//tensorflow/core/kernels:batch_kernels",
"//tensorflow/core/kernels:bincount_op",
+ "//tensorflow/core/kernels:boosted_trees_ops",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:checkpoint_ops",
"//tensorflow/core/kernels:control_flow_ops",
@@ -924,6 +943,9 @@ cc_library(
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:mkl_aggregate_ops",
+ ]) + if_cuda([
+ "//tensorflow/core/grappler/optimizers:gpu_swapping_kernels",
+ "//tensorflow/core/grappler/optimizers:gpu_swapping_ops",
]),
)
@@ -1624,6 +1646,7 @@ cc_library(
"platform/**/env_time.cc",
"platform/**/cuda_libdevice_path.cc",
"platform/**/device_tracer.cc",
+ "platform/abi.cc",
"platform/variant_coding.cc",
"platform/**/variant_cord_coding.cc",
],
@@ -1635,6 +1658,7 @@ cc_library(
"platform/**/stream_executor.h",
"platform/**/env_time.cc",
"platform/**/device_tracer.cc",
+ "platform/abi.cc",
"platform/variant_coding.cc",
"platform/**/variant_cord_coding.cc",
] +
@@ -1648,6 +1672,7 @@ cc_library(
deps = tf_additional_lib_deps() + [
":lib_hash_crc32c_accelerate_internal",
":lib_proto_parsing",
+ ":abi",
"//third_party/eigen3",
"//tensorflow/core/platform/default/build_config:platformlib",
"@snappy",
@@ -2157,6 +2182,11 @@ tf_cuda_library(
CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
"common_runtime/bfc_allocator.h",
+ "common_runtime/collective_executor_mgr.h",
+ "common_runtime/collective_param_resolver_local.h",
+ "common_runtime/collective_rma_local.h",
+ "common_runtime/device_resolver_local.h",
+ "common_runtime/buf_rendezvous.h",
"common_runtime/build_graph_options.h",
"common_runtime/constant_folding.h",
"common_runtime/copy_tensor.h",
@@ -2195,7 +2225,11 @@ tf_cuda_library(
"common_runtime/accumulate_n_optimizer.cc",
"common_runtime/allocator_retry.cc",
"common_runtime/bfc_allocator.cc",
+ "common_runtime/buf_rendezvous.cc",
"common_runtime/build_graph_options.cc",
+ "common_runtime/collective_executor_mgr.cc",
+ "common_runtime/collective_param_resolver_local.cc",
+ "common_runtime/collective_rma_local.cc",
"common_runtime/constant_folding.cc",
"common_runtime/copy_tensor.cc",
"common_runtime/costmodel_manager.cc",
@@ -2203,6 +2237,7 @@ tf_cuda_library(
"common_runtime/device.cc",
"common_runtime/device_factory.cc",
"common_runtime/device_mgr.cc",
+ "common_runtime/device_resolver_local.cc",
"common_runtime/device_set.cc",
"common_runtime/executor.cc",
"common_runtime/function.cc",
@@ -2810,6 +2845,11 @@ tf_cc_tests(
name = "higher_level_tests",
size = "small",
srcs = [
+ "common_runtime/buf_rendezvous_test.cc",
+ "common_runtime/collective_executor_mgr_test.cc",
+ "common_runtime/collective_param_resolver_local_test.cc",
+ "common_runtime/collective_rma_local_test.cc",
+ "common_runtime/device_resolver_local_test.cc",
"common_runtime/device_set_test.cc",
"common_runtime/optimization_registry_test.cc",
"common_runtime/pending_counts_test.cc",
@@ -3820,18 +3860,6 @@ cc_library(
# -----------------------------------------------------------------------------
# Google-internal targets go here (must be at the end).
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
alias(
name = "android_srcs_no_runtime",
actual = ":mobile_srcs_no_runtime",
diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD
index 58dbac4e8e..19d6438809 100644
--- a/tensorflow/core/api_def/BUILD
+++ b/tensorflow/core/api_def/BUILD
@@ -18,18 +18,6 @@ load(
)
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "base_api_def",
srcs = glob(["base_api/*"]),
visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
new file mode 100644
index 0000000000..b1921e3507
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
@@ -0,0 +1,87 @@
+op {
+ graph_op_name: "BoostedTreesCalculateBestGainsPerFeature"
+ visibility: HIDDEN
+ in_arg {
+ name: "node_id_range"
+ description: <<END
+A Rank 1 tensor (shape=[2]) to specify the range [first, last] of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1]+1)` (Note that the last index node_id_range[1] is inclusive).
+END
+ }
+ in_arg {
+ name: "stats_summary_list"
+ description: <<END
+A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
+END
+ }
+ out_arg {
+ name: "node_ids_list"
+ description: <<END
+An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "gains_list"
+ description: <<END
+An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "thresholds_list"
+ description: <<END
+An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "left_node_contribs_list"
+ description: <<END
+A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "right_node_contribs_list"
+ description: <<END
+A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
+END
+ }
+ attr {
+ name: "l1"
+ description: <<END
+l1 regularization factor on leaf weights, per instance based.
+END
+ }
+ attr {
+ name: "l2"
+ description: <<END
+l2 regularization factor on leaf weights, per instance based.
+END
+ }
+ attr {
+ name: "tree_complexity"
+ description: <<END
+adjustment to the gain, per leaf based.
+END
+ }
+ attr {
+ name: "max_splits"
+ description: <<END
+the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred from the size of `stats_summary_list`; the number of total features.
+END
+ }
+ summary: "Calculates gains for each feature and returns the best possible split information for the feature."
+ description: <<END
+The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
+
+It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
+
+In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
+
+The length of output lists are all of the same length, `num_features`.
+The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateEnsemble.pbtxt
new file mode 100644
index 0000000000..aee73b910f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateEnsemble.pbtxt
@@ -0,0 +1,23 @@
+op {
+ graph_op_name: "BoostedTreesCreateEnsemble"
+ visibility: HIDDEN
+ in_arg {
+ name: "tree_ensemble_handle"
+ description: <<END
+Handle to the tree ensemble resource to be created.
+END
+ }
+ in_arg {
+ name: "stamp_token"
+ description: <<END
+Token to use as the initial value of the resource stamp.
+END
+ }
+ in_arg {
+ name: "tree_ensemble_serialized"
+ description: <<END
+Serialized proto of the tree ensemble.
+END
+ }
+ summary: "Creates a tree ensemble model and returns a handle to it."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesDeserializeEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesDeserializeEnsemble.pbtxt
new file mode 100644
index 0000000000..b1602ba045
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesDeserializeEnsemble.pbtxt
@@ -0,0 +1,26 @@
+op {
+ graph_op_name: "BoostedTreesDeserializeEnsemble"
+ visibility: HIDDEN
+ in_arg {
+ name: "tree_ensemble_handle"
+ description: <<END
+Handle to the tree ensemble.
+END
+ }
+ in_arg {
+ name: "stamp_token"
+ description: <<END
+Token to use as the new value of the resource stamp.
+END
+ }
+ in_arg {
+ name: "tree_ensemble_serialized"
+ description: <<END
+Serialized proto of the ensemble.
+END
+ }
+ summary: "Deserializes a serialized tree ensemble config and replaces current tree"
+ description: <<END
+ensemble.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesEnsembleResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesEnsembleResourceHandleOp.pbtxt
new file mode 100644
index 0000000000..1bce5639a2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesEnsembleResourceHandleOp.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "BoostedTreesEnsembleResourceHandleOp"
+ visibility: HIDDEN
+ summary: "Creates a handle to a BoostedTreesEnsembleResource"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesGetEnsembleStates.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesGetEnsembleStates.pbtxt
new file mode 100644
index 0000000000..ef45a92498
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesGetEnsembleStates.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "BoostedTreesGetEnsembleStates"
+ visibility: HIDDEN
+ in_arg {
+ name: "tree_ensemble_handle"
+ description: <<END
+Handle to the tree ensemble.
+END
+ }
+ out_arg {
+ name: "stamp_token"
+ description: <<END
+Stamp token of the tree ensemble resource.
+END
+ }
+ out_arg {
+ name: "num_trees"
+ description: <<END
+The number of trees in the tree ensemble resource.
+END
+ }
+ out_arg {
+ name: "num_finalized_trees"
+ description: <<END
+The number of trees that were finished successfully.
+END
+ }
+ out_arg {
+ name: "num_attempted_layers"
+ description: <<END
+The number of layers we attempted to build (but not necessarily succeeded).
+END
+ }
+ summary: "Retrieves the tree ensemble resource stamp token."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeStatsSummary.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeStatsSummary.pbtxt
new file mode 100644
index 0000000000..dc0856c900
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeStatsSummary.pbtxt
@@ -0,0 +1,56 @@
+op {
+ graph_op_name: "BoostedTreesMakeStatsSummary"
+ visibility: HIDDEN
+ in_arg {
+ name: "node_ids"
+ description: <<END
+int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
+END
+ }
+ in_arg {
+ name: "gradients"
+ description: <<END
+float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
+END
+ }
+ in_arg {
+ name: "hessians"
+ description: <<END
+float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
+END
+ }
+ in_arg {
+ name: "bucketized_features_list"
+ description: <<END
+int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
+END
+ }
+ out_arg {
+ name: "stats_summary"
+ description: <<END
+output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
+END
+ }
+ attr {
+ name: "max_splits"
+ description: <<END
+int; the maximum number of splits possible in the whole tree.
+END
+ }
+ attr {
+ name: "num_buckets"
+ description: <<END
+int; equals to the maximum possible value of bucketized feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+int; inferred from the size of bucketized_features_list; the number of features.
+END
+ }
+ summary: "Makes the summary of accumulated stats for the batch."
+ description: <<END
+The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt
new file mode 100644
index 0000000000..b23e77a1fa
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt
@@ -0,0 +1,41 @@
+op {
+ graph_op_name: "BoostedTreesPredict"
+ visibility: HIDDEN
+ in_arg {
+ name: "bucketized_features"
+ description: <<END
+A list of rank 1 Tensors containing bucket id for each
+feature.
+END
+ }
+ out_arg {
+ name: "logits"
+ description: <<END
+Output rank 2 Tensor containing logits for each example.
+END
+ }
+ attr {
+ name: "num_bucketized_features"
+ description: <<END
+Inferred.
+END
+ }
+ attr {
+ name: "logits_dimension"
+ description: <<END
+scalar, dimension of the logits, to be used for partial logits
+shape.
+END
+ }
+ attr {
+ name: "max_depth"
+ description: <<END
+scalar, max depth of trees. To be used for parallelization costs.
+END
+ }
+ summary: "Runs multiple additive regression ensemble predictors on input instances and"
+ description: <<END
+computes the logits. It is designed to be used during prediction.
+It traverses all the trees and calculates the final score for each instance.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesSerializeEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesSerializeEnsemble.pbtxt
new file mode 100644
index 0000000000..c0b3688d8a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesSerializeEnsemble.pbtxt
@@ -0,0 +1,23 @@
+op {
+ graph_op_name: "BoostedTreesSerializeEnsemble"
+ visibility: HIDDEN
+ in_arg {
+ name: "tree_ensemble_handle"
+ description: <<END
+Handle to the tree ensemble.
+END
+ }
+ out_arg {
+ name: "stamp_token"
+ description: <<END
+Stamp token of the tree ensemble resource.
+END
+ }
+ out_arg {
+ name: "tree_ensemble_serialized"
+ description: <<END
+Serialized proto of the ensemble.
+END
+ }
+ summary: "Serializes the tree ensemble to a proto."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt
new file mode 100644
index 0000000000..7203d3cb58
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "BoostedTreesTrainingPredict"
+ visibility: HIDDEN
+ in_arg {
+ name: "cached_tree_ids"
+ description: <<END
+Rank 1 Tensor containing cached tree ids which is the starting
+tree of prediction.
+END
+ }
+ in_arg {
+ name: "cached_node_ids"
+ description: <<END
+Rank 1 Tensor containing cached node id which is the starting
+node of prediction.
+END
+ }
+ in_arg {
+ name: "bucketized_features"
+ description: <<END
+A list of rank 1 Tensors containing bucket id for each
+feature.
+END
+ }
+ out_arg {
+ name: "partial_logits"
+ description: <<END
+Rank 2 Tensor containing logits update (with respect to cached
+values stored) for each example.
+END
+ }
+ out_arg {
+ name: "tree_ids"
+ description: <<END
+Rank 1 Tensor containing new tree ids for each example.
+END
+ }
+ out_arg {
+ name: "node_ids"
+ description: <<END
+Rank 1 Tensor containing new node ids in the new tree_ids.
+END
+ }
+ attr {
+ name: "num_bucketized_features"
+ description: <<END
+Inferred.
+END
+ }
+ attr {
+ name: "logits_dimension"
+ description: <<END
+scalar, dimension of the logits, to be used for partial logits
+shape.
+END
+ }
+ attr {
+ name: "max_depth"
+ description: <<END
+scalar, max depth of trees. To be used for parallelization costs.
+END
+ }
+ summary: "Runs multiple additive regression ensemble predictors on input instances and"
+ description: <<END
+computes the update to cached logits. It is designed to be used during training.
+It traverses the trees starting from cached tree id and cached node id and
+calculates the updates to be pushed to the cache.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt
new file mode 100644
index 0000000000..00f8953875
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt
@@ -0,0 +1,82 @@
+op {
+ graph_op_name: "BoostedTreesUpdateEnsemble"
+ visibility: HIDDEN
+ in_arg {
+ name: "tree_ensemble_handle"
+ description: <<END
+Handle to the ensemble variable.
+END
+ }
+ in_arg {
+ name: "feature_ids"
+ description: <<END
+Rank 1 tensor with ids for each feature. This is the real id of
+the feature that will be used in the split.
+END
+ }
+ in_arg {
+ name: "node_ids"
+ description: <<END
+List of rank 1 tensors representing the nodes for which this feature
+has a split.
+END
+ }
+ in_arg {
+ name: "gains"
+ description: <<END
+List of rank 1 tensors representing the gains for each of the feature's
+split.
+END
+ }
+ in_arg {
+ name: "thresholds"
+ description: <<END
+List of rank 1 tensors representing the thesholds for each of the
+feature's split.
+END
+ }
+ in_arg {
+ name: "left_node_contribs"
+ description: <<END
+List of rank 2 tensors with left leaf contribs for each of
+the feature's splits. Will be added to the previous node values to constitute
+the values of the left nodes.
+END
+ }
+ in_arg {
+ name: "right_node_contribs"
+ description: <<END
+List of rank 2 tensors with right leaf contribs for each
+of the feature's splits. Will be added to the previous node values to constitute
+the values of the right nodes.
+END
+ }
+ attr {
+ name: "max_depth"
+ description: <<END
+Max depth of the tree to build.
+END
+ }
+ attr {
+ name: "learning_rate"
+ description: <<END
+shrinkage const for each new tree.
+END
+ }
+ attr {
+ name: "pruning_mode"
+ description: <<END
+0-No pruning, 1-Pre-pruning, 2-Post-pruning.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+Number of features that have best splits returned. INFERRED.
+END
+ }
+ summary: "Updates the tree ensemble by either adding a layer to the last tree being grown"
+ description: <<END
+or by starting a new tree.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesEnsembleInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesEnsembleInitialized.pbtxt
new file mode 100644
index 0000000000..d54b7ef32a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesEnsembleInitialized.pbtxt
@@ -0,0 +1,17 @@
+op {
+ graph_op_name: "IsBoostedTreesEnsembleInitialized"
+ visibility: HIDDEN
+ in_arg {
+ name: "tree_ensemble_handle"
+ description: <<END
+Handle to the tree ensemble resouce.
+END
+ }
+ out_arg {
+ name: "is_initialized"
+ description: <<END
+output boolean on whether it is initialized or not.
+END
+ }
+ summary: "Checks whether a tree ensemble has been initialized."
+}
diff --git a/tensorflow/core/common_runtime/buf_rendezvous.cc b/tensorflow/core/common_runtime/buf_rendezvous.cc
new file mode 100644
index 0000000000..b57eb2943a
--- /dev/null
+++ b/tensorflow/core/common_runtime/buf_rendezvous.cc
@@ -0,0 +1,166 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+
+namespace tensorflow {
+
+BufRendezvous::~BufRendezvous() {
+ mutex_lock l(mu_);
+ if (!hook_table_.empty()) {
+ PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
+ &hook_table_);
+ }
+}
+
+void BufRendezvous::StartAbort(const Status& s) {
+ CHECK(!s.ok());
+ HookTable dummy_table;
+ {
+ mutex_lock l(mu_);
+ status_.Update(s);
+ hook_table_.swap(dummy_table);
+ }
+ PurgeTable(s, &dummy_table);
+}
+
+void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
+ for (auto& it : *table) {
+ Hook* h = it.second;
+ if (h->cons_cb != nullptr) {
+ h->cons_cb(s, nullptr);
+ }
+ if (h->prod_cb != nullptr) {
+ h->prod_cb(s);
+ }
+ delete h;
+ }
+ table->clear();
+}
+
+string BufRendezvous::Hook::DebugString() const {
+ return strings::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
+ ", ctx:", reinterpret_cast<uint64>(prod_ctx),
+ ", val:", reinterpret_cast<uint64>(prod_value),
+ ", pcb:", reinterpret_cast<uint64>(&prod_cb),
+ ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
+}
+
+void BufRendezvous::ProvideBuf(const string& key, Device* dev,
+ DeviceContext* dev_ctx, const Tensor* v,
+ const AllocatorAttributes& attr,
+ const ProducerCallback& done) {
+ Hook* h = nullptr;
+ Status providebuf_status;
+ do {
+ mutex_lock l(mu_);
+ if (!status_.ok()) {
+ providebuf_status = status_;
+ break;
+ } else {
+ auto it = hook_table_.find(key);
+ if (it == hook_table_.end()) {
+ h = new Hook;
+ it = hook_table_.insert(std::make_pair(key, h)).first;
+ } else {
+ if (it->second->prod_cb != nullptr) {
+ providebuf_status = errors::Internal(
+ "BufRendezvous::ProvideBuf already called for key ", key);
+ break;
+ }
+ h = it->second;
+ }
+ // Populate Hook with all of the prod values.
+ h->prod_dev = dev;
+ h->prod_ctx = dev_ctx;
+ h->prod_value = v;
+ h->prod_attr = attr;
+ h->prod_cb = done;
+ // If consumer is waiting, kick off right away, removing Hook from table.
+ if (h->cons_cb != nullptr) {
+ hook_table_.erase(it);
+ } else {
+ h = nullptr;
+ }
+ }
+ } while (false);
+ if (h) {
+ h->cons_cb(Status::OK(), h);
+ }
+ if (!providebuf_status.ok()) {
+ done(providebuf_status);
+ }
+}
+
+void BufRendezvous::ConsumeBuf(const string& key,
+ const ConsumerCallback& done) {
+ Hook* existing_hook = nullptr;
+ Status consumebuf_status;
+ do {
+ mutex_lock l(mu_);
+ if (!status_.ok()) {
+ consumebuf_status = status_;
+ break;
+ }
+ auto it = hook_table_.find(key);
+ if (it != hook_table_.end()) {
+ // Prepare to consume immediately.
+ if (it->second->cons_cb) {
+ consumebuf_status =
+ errors::Internal("Second consumer arrived for key ", key);
+ break;
+ }
+ existing_hook = it->second;
+ hook_table_.erase(it);
+ existing_hook->cons_cb = done;
+ } else {
+ // Hang consumer callback on the Hook.
+ Hook* h = new Hook;
+ hook_table_[key] = h;
+ h->cons_cb = done;
+ return;
+ }
+ } while (false);
+ if (existing_hook) {
+ existing_hook->cons_cb(Status::OK(), existing_hook);
+ return;
+ }
+ if (!consumebuf_status.ok()) {
+ done(consumebuf_status, nullptr);
+ return;
+ }
+}
+
+/*static*/
+void BufRendezvous::DoneWithHook(Hook* h) {
+ h->prod_cb(Status::OK());
+ delete h;
+}
+
+void BufRendezvous::LogContents() {
+ mutex_lock l(mu_);
+ LOG(INFO) << strings::StrCat("BufRendezvous ",
+ strings::Hex(reinterpret_cast<uint64>(this)),
+ " step_id=", step_id_, " current contents:");
+ for (auto it : hook_table_) {
+ LOG(INFO) << it.first << ":" << it.second->DebugString();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/buf_rendezvous.h b/tensorflow/core/common_runtime/buf_rendezvous.h
new file mode 100644
index 0000000000..e94e88b323
--- /dev/null
+++ b/tensorflow/core/common_runtime/buf_rendezvous.h
@@ -0,0 +1,103 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#define TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+
+#include <functional>
+#include <string>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+class Device;
+class DeviceContext;
+class Tensor;
+
+// EXPERIMENTAL: RDMA oriented producer/consumer rendezvous on a local
+// Tensor value for which DMAHelper::CanUseDMA() is true, i.e. dense
+// numeric types. Similar to Rendezvous but never owns a Ref on the
+// tensor, instead it uses an explicit callback to the producer when
+// the consumer side is finished with the value. This allows the
+// producer to perform in-place updates on the source buffer or to take
+// other actions that depend on knowing the consumer has passed a certain
+// execution point.
+class BufRendezvous {
+ public:
+ explicit BufRendezvous(uint64 step_id) : step_id_(step_id) {}
+
+ ~BufRendezvous();
+
+ // Inform all all waiting parties that this BufRendezvous is defunct
+ // because of an error Status interrupting the Step.
+ void StartAbort(const Status& s);
+
+ struct Hook;
+ // Provided by the consumer to be called when access to the buffer
+ // is available. If the Status arg is not OK, then hook will not
+ // be populated. Ownership of Hook passes to consumer with the
+ // callback.
+ typedef std::function<void(const Status&, Hook*)> ConsumerCallback;
+ // Provided by the producer to be called when the consumer has finished
+ // reading the buffer and will no longer access it.
+ typedef std::function<void(const Status&)> ProducerCallback;
+
+ struct Hook {
+ Device* prod_dev;
+ DeviceContext* prod_ctx;
+ const Tensor* prod_value;
+ AllocatorAttributes prod_attr;
+ ProducerCallback prod_cb;
+ ConsumerCallback cons_cb;
+ Hook()
+ : prod_dev(nullptr),
+ prod_ctx(nullptr),
+ prod_value(nullptr),
+ prod_cb(nullptr),
+ cons_cb(nullptr) {}
+ string DebugString() const;
+ };
+
+ // Called to advertise availability of a Tensor value corresponding
+ // to key. That value must stay valid until done is called.
+ void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx,
+ const Tensor* v, const AllocatorAttributes& attr,
+ const ProducerCallback& done);
+
+ // Called to request access to a Tensor value corresponding to key.
+ // Consumer is provide with a Hook as soon as availble.
+ void ConsumeBuf(const string& key, const ConsumerCallback& done);
+
+ // Consumer must call this function when it's done reading the Hook provided
+ // by the ConsumerCallback. This function will invoke the producer callback
+ // and then delete h.
+ static void DoneWithHook(Hook* h);
+
+ // Write the current contents of the table to the INFO log.
+ void LogContents();
+
+ protected:
+ const uint64 step_id_;
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+ typedef gtl::FlatMap<string, Hook*> HookTable;
+ HookTable hook_table_ GUARDED_BY(mu_);
+
+ void PurgeTable(const Status& s, HookTable* table);
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
diff --git a/tensorflow/core/common_runtime/buf_rendezvous_test.cc b/tensorflow/core/common_runtime/buf_rendezvous_test.cc
new file mode 100644
index 0000000000..0e798235bf
--- /dev/null
+++ b/tensorflow/core/common_runtime/buf_rendezvous_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class BufRendezvousTest : public ::testing::Test {
+ protected:
+ BufRendezvousTest() {
+ br_.reset(new BufRendezvous(123));
+ fake_dev_ptr_ = reinterpret_cast<Device*>(512LLU);
+ fake_dev_ctx_ = reinterpret_cast<DeviceContext*>(1024LLU);
+ a_ = Tensor(DT_FLOAT, TensorShape({24}));
+ b_ = Tensor(DT_FLOAT, TensorShape({24}));
+ }
+
+ Device* fake_dev_ptr_ = nullptr;
+ DeviceContext* fake_dev_ctx_ = nullptr;
+ Tensor a_;
+ Tensor b_;
+ AllocatorAttributes aa_;
+ std::unique_ptr<BufRendezvous> br_;
+};
+
+TEST_F(BufRendezvousTest, CorrectUseProducerFirst) {
+ Status prod_status;
+ Status cons_status;
+ bool prod_callback_called = false;
+ bool cons_callback_called = false;
+ Notification note;
+ br_->ProvideBuf(
+ "key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [&note, &prod_status, &prod_callback_called](const Status& s) {
+ prod_status = s;
+ prod_callback_called = true;
+ note.Notify();
+ });
+ EXPECT_FALSE(prod_callback_called);
+ br_->ConsumeBuf("key0", [this, &cons_status, &cons_callback_called](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_callback_called = true;
+ ASSERT_TRUE(h != nullptr);
+ EXPECT_EQ(h->prod_dev, fake_dev_ptr_);
+ EXPECT_EQ(h->prod_ctx, fake_dev_ctx_);
+ EXPECT_EQ(h->prod_value, &a_);
+ br_->DoneWithHook(h);
+ });
+ EXPECT_TRUE(cons_callback_called);
+ note.WaitForNotification();
+ EXPECT_TRUE(prod_callback_called);
+ TF_EXPECT_OK(cons_status);
+ TF_EXPECT_OK(prod_status);
+}
+
+TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) {
+ Status prod_status;
+ Status cons_status;
+ bool prod_callback_called = false;
+ bool cons_callback_called = false;
+ Notification note;
+ br_->ConsumeBuf("key0", [this, &cons_status, &cons_callback_called](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_callback_called = true;
+ ASSERT_TRUE(h != nullptr);
+ EXPECT_EQ(h->prod_dev, fake_dev_ptr_);
+ EXPECT_EQ(h->prod_ctx, fake_dev_ctx_);
+ EXPECT_EQ(h->prod_value, &a_);
+ br_->DoneWithHook(h);
+ });
+ EXPECT_FALSE(cons_callback_called);
+ br_->ProvideBuf(
+ "key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [&note, &prod_status, &prod_callback_called](const Status& s) {
+ prod_status = s;
+ prod_callback_called = true;
+ note.Notify();
+ });
+ EXPECT_TRUE(cons_callback_called);
+ note.WaitForNotification();
+ EXPECT_TRUE(prod_callback_called);
+ TF_EXPECT_OK(cons_status);
+ TF_EXPECT_OK(prod_status);
+}
+
+TEST_F(BufRendezvousTest, ErrorDuplicatePut) {
+ bool prod_callback_called = false;
+ br_->ProvideBuf("key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [this, &prod_callback_called](const Status& s) {
+ prod_callback_called = true;
+ });
+ Status bad_status;
+ Notification note;
+ br_->ProvideBuf("key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [&bad_status, &note](const Status& s) {
+ bad_status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_FALSE(bad_status.ok());
+ EXPECT_EQ("BufRendezvous::ProvideBuf already called for key key0",
+ bad_status.error_message());
+ EXPECT_FALSE(prod_callback_called);
+ br_.reset();
+}
+
+TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) {
+ Status cons_status;
+ br_->ConsumeBuf(
+ "key0", [this, &cons_status](const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ EXPECT_EQ(h, nullptr);
+ });
+ EXPECT_TRUE(cons_status.ok());
+ br_.reset();
+ EXPECT_FALSE(cons_status.ok());
+ EXPECT_EQ("Delete called on non-empty BufRendezvous",
+ cons_status.error_message());
+}
+
+TEST_F(BufRendezvousTest, AbortNonEmpty) {
+ Status cons_status;
+ Status prod_status;
+ Notification prod_note;
+ Notification cons_note;
+ br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_note.Notify();
+ });
+ br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [this, &prod_note, &prod_status](const Status& s) {
+ prod_status = s;
+ prod_note.Notify();
+ });
+ br_->StartAbort(errors::Internal("Falling sky detected"));
+ prod_note.WaitForNotification();
+ cons_note.WaitForNotification();
+ EXPECT_FALSE(prod_status.ok());
+ EXPECT_EQ(prod_status.error_message(), "Falling sky detected");
+ EXPECT_FALSE(cons_status.ok());
+ EXPECT_EQ(cons_status.error_message(), "Falling sky detected");
+}
+
+TEST_F(BufRendezvousTest, AbortEmpty) {
+ br_->StartAbort(errors::Internal("Falling sky detected"));
+}
+
+TEST_F(BufRendezvousTest, UseAfterAbort) {
+ br_->StartAbort(errors::Internal("Falling sky detected"));
+ Status cons_status;
+ Status prod_status;
+ Notification prod_note;
+ Notification cons_note;
+ br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
+ const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ cons_note.Notify();
+ });
+ br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
+ [this, &prod_note, &prod_status](const Status& s) {
+ prod_status = s;
+ prod_note.Notify();
+ });
+ prod_note.WaitForNotification();
+ cons_note.WaitForNotification();
+ EXPECT_FALSE(prod_status.ok());
+ EXPECT_EQ(prod_status.error_message(), "Falling sky detected");
+ EXPECT_FALSE(cons_status.ok());
+ EXPECT_EQ(cons_status.error_message(), "Falling sky detected");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
new file mode 100644
index 0000000000..a5c4946e58
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -0,0 +1,114 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/build_graph_options.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+namespace {
+// TODO(tucker): Temporary class just until a real CollectiveExecutor
+// implementation is submitted in a later CL.
+class DummyCollectiveExecutor : public CollectiveExecutor {
+ public:
+ explicit DummyCollectiveExecutor(CollectiveExecutorMgr* ce_mgr)
+ : CollectiveExecutor(ce_mgr) {}
+
+ ~DummyCollectiveExecutor() override {}
+
+ void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key, Device* to_device,
+ DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override {
+ done(errors::Internal("Unimplemented"));
+ }
+
+ void PostToPeer(const string& peer_device, const string& peer_task,
+ const string& key, Device* from_device,
+ DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr,
+ const Tensor* from_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override {
+ done(errors::Internal("Unimplemented"));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DummyCollectiveExecutor);
+};
+} // namespace
+
+CollectiveExecutorMgr::CollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ ParamResolverInterface* param_resolver)
+ : dev_mgr_(dev_mgr),
+ dev_resolver_(dev_resolver),
+ param_resolver_(param_resolver) {}
+
+CollectiveExecutorMgr::~CollectiveExecutorMgr() {
+ for (auto iter : executor_table_) {
+ iter.second->Unref();
+ }
+}
+
+CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
+ CollectiveExecutor* ce = nullptr;
+ {
+ mutex_lock l(exec_mu_);
+ auto it = executor_table_.find(step_id);
+ if (it != executor_table_.end()) {
+ ce = it->second;
+ } else {
+ ce = new DummyCollectiveExecutor(this);
+ executor_table_[step_id] = ce;
+ }
+ ce->Ref();
+ }
+ return ce;
+}
+
+void CollectiveExecutorMgr::Cleanup(int64 step_id) {
+ CollectiveExecutor* ce = nullptr;
+ {
+ mutex_lock l(exec_mu_);
+ auto it = executor_table_.find(step_id);
+ if (it != executor_table_.end()) {
+ ce = it->second;
+ executor_table_.erase(it);
+ }
+ }
+ if (ce) ce->Unref();
+}
+
+void CollectiveExecutorMgr::GetStepSequenceAsync(
+ const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
+ const StatusCallback& done) {
+ done(errors::Internal(
+ "CollectiveExecutorMgr does not implement GetStepSequence."));
+}
+
+void CollectiveExecutorMgr::RefreshStepIdSequenceAsync(
+ int64 graph_key, const StatusCallback& done) {
+ done(errors::Internal(
+ "CollectiveExecutorMgr does not implement RefreshStepIdSequence."));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
new file mode 100644
index 0000000000..4b42e2b4d1
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class ConfigProto;
+class DeviceMgr;
+
+class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
+ public:
+ CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ ParamResolverInterface* param_resolver);
+
+ virtual ~CollectiveExecutorMgr();
+
+ CollectiveExecutor* FindOrCreate(int64 step_id) override;
+
+ void Cleanup(int64 step_id) override;
+
+ ParamResolverInterface* GetParamResolver() const override {
+ return param_resolver_.get();
+ }
+
+ DeviceResolverInterface* GetDeviceResolver() const override {
+ return dev_resolver_.get();
+ }
+
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ const StatusCallback& done) override;
+
+ void RefreshStepIdSequenceAsync(int64 graph_key,
+ const StatusCallback& done) override;
+
+ int64 NextStepId(int64 graph_key) override {
+ return CollectiveExecutor::kInvalidId;
+ }
+
+ void RetireStepId(int64 graph_key, int64 step_id) override {}
+
+ protected:
+ const DeviceMgr* dev_mgr_;
+ std::unique_ptr<DeviceResolverInterface> dev_resolver_;
+ std::unique_ptr<ParamResolverInterface> param_resolver_;
+ CollectiveRemoteAccess* remote_access_;
+ string task_name_;
+ mutex exec_mu_;
+ // Map from step_id to CollectiveExecutor
+ gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
new file mode 100644
index 0000000000..34c9163d6a
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class CollectiveExecutorMgrTest : public ::testing::Test {
+ protected:
+ CollectiveExecutorMgrTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ string task_name = "/job:localhost/replica:0/task:0";
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ cme_.reset(new CollectiveExecutorMgr(
+ cp, device_mgr_.get(), drl,
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name)));
+ }
+
+ std::unique_ptr<CollectiveExecutorMgr> cme_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
+TEST_F(CollectiveExecutorMgrTest, FindOrCreate) {
+ CollectiveExecutor::Handle* h =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_TRUE(h->get());
+ CollectiveExecutor::Handle* h2 =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_EQ(h->get(), h2->get());
+ CollectiveExecutor* ce = h->get();
+ delete h;
+ delete h2;
+ CollectiveExecutor::Handle h3(cme_->FindOrCreate(1), true);
+ EXPECT_EQ(ce, h3.get());
+ cme_->Cleanup(1);
+}
+
+TEST_F(CollectiveExecutorMgrTest, StepSequenceRelated) {
+ EXPECT_EQ(CollectiveExecutor::kInvalidId, cme_->NextStepId(123));
+ Notification ss_note;
+ Status ss_status;
+ cme_->RefreshStepIdSequenceAsync(
+ 123, [this, &ss_status, &ss_note](const Status& s) {
+ ss_status = s;
+ ss_note.Notify();
+ });
+ ss_note.WaitForNotification();
+ EXPECT_FALSE(ss_status.ok());
+ EXPECT_EQ(ss_status.error_message(),
+ "CollectiveExecutorMgr does not implement RefreshStepIdSequence.");
+ Notification gs_note;
+ Status gs_status;
+ GetStepSequenceRequest* req = nullptr;
+ GetStepSequenceResponse* resp = nullptr;
+ cme_->GetStepSequenceAsync(req, resp,
+ [this, &gs_status, &gs_note](const Status& s) {
+ gs_status = s;
+ gs_note.Notify();
+ });
+ gs_note.WaitForNotification();
+ EXPECT_FALSE(gs_status.ok());
+ EXPECT_EQ(gs_status.error_message(),
+ "CollectiveExecutorMgr does not implement GetStepSequence.");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
new file mode 100644
index 0000000000..b34950b2f4
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -0,0 +1,666 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+namespace tensorflow {
+
+CollectiveParamResolverLocal::CollectiveParamResolverLocal(
+ const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+ const string& task_name)
+ : dev_mgr_(dev_mgr), dev_resolver_(dev_resolver), task_name_(task_name) {}
+
+void CollectiveParamResolverLocal::CompleteGroupAsync(
+ const CompleteGroupRequest* request, CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr, const StatusCallback& done) {
+ done(
+ errors::Internal("CompleteGroup is not implemented by "
+ "CollectiveParamResolverLocal which is "
+ "intended only for non-distributed deployment."));
+}
+
+void CollectiveParamResolverLocal::CompleteGroupLocal(
+ const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
+ VLOG(1) << "CompleteGroupLocal " << cp << ": " << cp->ToString();
+ std::vector<StatusCallback> to_be_called;
+ GroupRec* gr = nullptr;
+ {
+ mutex_lock l(group_mu_);
+ auto it = group_table_.find(cp->group.group_key);
+ if (it == group_table_.end()) {
+ gr = new GroupRec;
+ gr->group.group_key = cp->group.group_key;
+ gr->group.group_size = cp->group.group_size;
+ gr->group.device_type = cp->group.device_type;
+ group_table_[gr->group.group_key].reset(gr);
+ VLOG(2) << "New group_key=" << gr->group.group_key
+ << " group_size=" << gr->group.group_size;
+ } else {
+ gr = it->second.get();
+ }
+ }
+ Status status;
+ {
+ mutex_lock gr_lock(gr->mu);
+ if (!gr->device_set.empty()) {
+ // Check for consistency with existing GroupRec.
+ if (cp->group.device_type != gr->group.device_type) {
+ status = errors::Internal(
+ "Collective Op ", cp->name, " is assigned to device ", device,
+ " with type ", cp->group.device_type.type_string(),
+ " and group_key ", cp->group.group_key, " but that group has type ",
+ gr->group.device_type.type_string());
+ } else if (cp->group.group_size != gr->group.group_size) {
+ status = errors::Internal(
+ "Collective Op ", cp->name, " has group_size ",
+ cp->group.group_size, " and group_key", cp->group.group_key,
+ " but that group has size ", gr->group.group_size);
+ }
+ }
+ if (status.ok()) {
+ // Insert device if not already present.
+ auto it = gr->device_set.find(device);
+ if (it == gr->device_set.end()) {
+ if (gr->device_set.size() == gr->group.group_size) {
+ // The group is already full.
+ status = errors::Internal(
+ "Collective Op ", cp->name, " is assigned to device ", device,
+ " and group_key ", cp->group.group_key,
+ " but that group doesn't contain that device.");
+ } else {
+ // This is a new device that has not yet joined the group.
+ gr->device_set.insert(device);
+ gr->device_list.push_back(device);
+ DeviceNameUtils::ParsedName parsed_device;
+ DeviceNameUtils::ParseFullName(device, &parsed_device);
+ string task_name = strings::StrCat("/job:", parsed_device.job,
+ "/replica:", parsed_device.replica,
+ "/task:", parsed_device.task);
+ gr->task_set.insert(task_name);
+ gr->task_list.push_back(task_name);
+ gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
+ VLOG(1) << "group_key=" << gr->group.group_key
+ << " group_size=" << gr->group.group_size
+ << " dev_set=" << gr->device_set.size();
+ }
+ }
+ }
+
+ if (status.ok()) {
+ // If the group is not yet complete, queue to wait for it.
+ VLOG(2) << "group_size " << gr->group.group_size << " set size "
+ << gr->device_set.size() << " gr " << gr;
+
+ if (gr->device_set.size() < gr->group.group_size) {
+ gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
+ return;
+ }
+ CHECK_EQ(gr->device_set.size(), gr->group.group_size);
+ if (!gr->waiting.empty()) {
+ std::swap(to_be_called, gr->waiting);
+ }
+ }
+ }
+ done(status, gr);
+ for (int i = 0; i < to_be_called.size(); ++i) {
+ to_be_called[i](Status::OK());
+ }
+}
+
+namespace {
+
+struct DevRec {
+ string task;
+ string device;
+ int original_rank;
+ int local_rank;
+ int global_rank;
+ const DeviceLocality* locality;
+};
+typedef std::unordered_map<string, DevRec> TaskDeviceMap;
+typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
+
+// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
+GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
+ const std::vector<DeviceLocality>& localities) {
+ GlobalDeviceMap gdm;
+ CHECK_EQ(ip.device_names.size(), ip.task_names.size());
+ CHECK_EQ(ip.device_names.size(), localities.size());
+ for (int i = 0; i < ip.device_names.size(); ++i) {
+ TaskDeviceMap& tdm = gdm[ip.task_names[i]];
+ DevRec* dr = &tdm[ip.device_names[i]];
+ dr->task = ip.task_names[i];
+ dr->device = ip.device_names[i];
+ dr->original_rank = i;
+ dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap.
+ dr->global_rank = 0; // Will be populated later by EstablishGlobalRank.
+ dr->locality = &localities[i];
+ }
+ return gdm;
+}
+
+void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
+ CHECK_GT(tdm->size(), 0); // Should never be called with 0 devices
+ int least_rank = -1;
+ string next_device;
+ std::set<string> selected;
+ // Starting device is one with the least initial rank.
+ for (const auto& it : *tdm) {
+ if (least_rank < 0 || it.second.original_rank < least_rank) {
+ least_rank = it.second.original_rank;
+ next_device = it.second.device;
+ }
+ }
+ CHECK_GE(least_rank, 0);
+ DeviceNameUtils::ParsedName parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
+ // NOTE: InterconnectLink has only a device_id, nothing more, so for
+ // the time being if there's more than one device at a task we
+ // assume they're all GPUs.
+
+ int next_rank = 0;
+ while (true) {
+ selected.insert(next_device);
+ DevRec* dr = &(*tdm)[next_device];
+ dr->local_rank = next_rank;
+ ++next_rank;
+ if (selected.size() == tdm->size()) {
+ break;
+ }
+ // For the present time we assume Locality links only cover GPUs.
+ // For multiple CPUs, just take them in order.
+ const InterconnectLink* best_link = nullptr;
+ if (parsed_name.type == "GPU") {
+ for (const InterconnectLink& il : dr->locality->links().link()) {
+ parsed_name.id = il.device_id();
+ string endpoint_device =
+ DeviceNameUtils::ParsedNameToString(parsed_name);
+ if (selected.find(endpoint_device) != selected.end()) {
+ continue;
+ }
+ if (best_link == nullptr || il.strength() > best_link->strength()) {
+ best_link = &il;
+ }
+ }
+ }
+ if (best_link != nullptr) {
+ // Follow the best edge
+ parsed_name.id = best_link->device_id();
+ next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
+ } else {
+ // No good edges, alas. Pick the lowest initial rank among remaining
+ // devices.
+ least_rank = -1;
+ for (const auto& it : *tdm) {
+ if (selected.find(it.second.device) != selected.end()) {
+ continue;
+ }
+ if (least_rank < 0 || it.second.original_rank < least_rank) {
+ least_rank = it.second.original_rank;
+ next_device = it.second.device;
+ }
+ }
+ CHECK_GE(least_rank, 0);
+ }
+ }
+}
+
+// The first time a shared CollectiveParams is established for a
+// shared set of instances we compute a good rank order for all the
+// devices in the group, that is appropriate for a ring algorithm.
+// This order need not be the same across different instance groups
+// sharing the same device group where there is more than one good
+// order.
+GlobalDeviceMap EstablishGlobalRank(
+ CollectiveParams* cp, const std::vector<DeviceLocality>& localities) {
+ VLOG(1) << "EstablishGlobalRank";
+ GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities);
+ for (auto& iter : gdm) {
+ TaskDeviceMap& tdm = iter.second;
+ OrderTaskDeviceMap(&tdm);
+ }
+ // Connect the global rank order by the order in which tasks first appear.
+ std::set<string> ordered_tasks;
+ int next_rank = 0;
+ for (int i = 0; i < cp->instance.task_names.size(); ++i) {
+ const string& task_name = cp->instance.task_names[i];
+ if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
+ continue;
+ }
+ ordered_tasks.insert(task_name);
+ TaskDeviceMap* tdm = &gdm[task_name];
+ for (auto& it : *tdm) {
+ it.second.global_rank = it.second.local_rank + next_rank;
+ }
+ next_rank += tdm->size();
+ }
+ return gdm;
+}
+
+// Sort cp->instance.device_names lexicographically, but do by first
+// computing a reordering permutation so we can keep cp->instance.task_names
+// in corresponding order.
+void SortDevicesAndTasks(CollectiveParams* cp) {
+ VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance;
+ CHECK(cp);
+ CHECK_EQ(cp->group.group_size, cp->instance.device_names.size());
+ CHECK_EQ(cp->group.group_size, cp->instance.task_names.size());
+ std::vector<int> perm(cp->group.group_size);
+ // TODO(tucker): substitute std::iota when the windows build supports it.
+ // std::iota(perm.begin(), perm.end(), 0);
+ for (int i = 0; i < perm.size(); ++i) {
+ perm[i] = i;
+ }
+ std::sort(perm.begin(), perm.end(), [cp](const int& a, const int& b) {
+ return cp->instance.device_names[a] < cp->instance.device_names[b];
+ });
+ std::vector<string> new_devs;
+ std::vector<string> new_tasks;
+ new_devs.reserve(cp->group.group_size);
+ new_tasks.reserve(cp->group.group_size);
+ for (int pi : perm) {
+ new_devs.push_back(cp->instance.device_names[pi]);
+ new_tasks.push_back(cp->instance.task_names[pi]);
+ }
+ cp->instance.device_names = std::move(new_devs);
+ cp->instance.task_names = std::move(new_tasks);
+ VLOG(1) << "Modified device_names on " << cp;
+}
+
+// Establish the requested number of subdivision permutations based on the
+// ring order implicit in the device order.
+void GenerateSubdivPerms(const string& device, int source_rank,
+ CollectiveParams* cp) {
+ CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
+ cp->instance.impl_details.subdiv_permutations.resize(
+ cp->instance.impl_details.subdiv_offsets.size());
+ // Each subdiv permutation is a ring formed by rotating each
+ // single-task subsequence of devices by an offset. This makes most
+ // sense when each task has the same number of devices but we can't
+ // depend on that being the case so we'll compute something that
+ // works in any case.
+
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(cp->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &cp->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < cp->group.group_size; ++di) {
+ if (cp->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &cp->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
+
+ // Generate a ring permutation for each requested offset.
+ CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
+ VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
+ << &cp->instance.impl_details.subdiv_permutations;
+ cp->instance.impl_details.subdiv_permutations.resize(
+ cp->instance.impl_details.subdiv_offsets.size());
+ cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
+ ++sdi) {
+ std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int offset = cp->instance.impl_details.subdiv_offsets[sdi];
+ int prior_dev_count = 0;
+ for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
+ for (int di = 0; di < dev_per_task[ti]; ++di) {
+ int offset_di = (di + offset) % dev_per_task[ti];
+ int permuted_di = prior_dev_count + offset_di;
+ perm.push_back(permuted_di);
+ if (cp->instance.device_names[prior_dev_count + di] == device) {
+ CHECK_EQ(prior_dev_count + di, cp->default_rank);
+ cp->subdiv_rank[sdi] = permuted_di;
+ }
+ }
+ prior_dev_count += dev_per_task[ti];
+ }
+ CHECK_EQ(cp->group.group_size, perm.size());
+ }
+
+ if (cp->instance.type == BROADCAST_COLLECTIVE) {
+ CHECK_GE(source_rank, 0);
+ cp->subdiv_source_rank.resize(
+ cp->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0; sdi < cp->subdiv_source_rank.size(); ++sdi) {
+ for (int j = 0; j < cp->group.group_size; ++j) {
+ if (cp->instance.impl_details.subdiv_permutations[sdi][j] ==
+ source_rank) {
+ cp->subdiv_source_rank[sdi] = j;
+ break;
+ }
+ }
+ CHECK_GE(cp->subdiv_source_rank[sdi], 0);
+ }
+ }
+
+ if (VLOG_IS_ON(1)) {
+ // Log the computed ring order for each subdiv.
+ string buf;
+ for (int sdi = 0;
+ sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) {
+ buf = strings::StrCat("Subdiv ", sdi, " device order:\n");
+ for (int di = 0;
+ di < cp->instance.impl_details.subdiv_permutations[sdi].size();
+ ++di) {
+ int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
+ strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
+ }
+ strings::StrAppend(&buf, " subdiv_offsets: ");
+ for (auto o : cp->instance.impl_details.subdiv_offsets)
+ strings::StrAppend(&buf, o, " ");
+ strings::StrAppend(&buf, " SubdivRank: ");
+ for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
+ VLOG(1) << buf;
+ }
+ }
+}
+
+} // namespace
+
+void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
+ CollectiveParams* cp) {
+ cp->task.is_local.resize(cp->group.group_size, false);
+ for (int i = 0; i < cp->group.group_size; ++i) {
+ cp->task.is_local[i] = (cp->instance.task_names[i] == task_name);
+ }
+}
+
+void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
+ CollectiveParams* cp) {
+ CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp;
+ for (int i = 0; i < cp->group.group_size; ++i) {
+ if (cp->instance.device_names[i] == device) {
+ cp->default_rank = i;
+ break;
+ }
+ }
+}
+
+Status CollectiveParamResolverLocal::InitInstanceSharedParams(
+ GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
+ VLOG(1) << "InitInstanceSharedParams " << ir;
+ ir->shared.instance = cp->instance;
+ {
+ mutex_lock gl(gr->mu);
+ ir->shared.group = gr->group;
+ ir->shared.instance.device_names.assign(gr->device_list.begin(),
+ gr->device_list.end());
+ ir->shared.instance.task_names.assign(gr->task_list.begin(),
+ gr->task_list.end());
+ VLOG(2) << "Initialized names for instance: "
+ << ir->shared.instance.ToString();
+ }
+ ir->shared.default_rank = -1;
+
+ // Sort devce_names lexicographcally, keeping task_names in
+ // corresponding order.
+ SortDevicesAndTasks(&ir->shared);
+
+ // Get Locality data for all devices.
+
+ // Set is_local and task_names in *shared prior to invoking
+ // GetDeviceLocalitiesAsync. In a distributed context this function can be
+ // called by a derived class, some of the devices may be non-local and
+ // GetDeviceLocalitiesAsync will use those fields to launch RPCs.
+ CompleteTaskIsLocal(task_name_, &ir->shared);
+ std::vector<DeviceLocality> localities;
+ Notification note;
+ Status status;
+ dev_resolver_->GetDeviceLocalitiesAsync(ir->shared.instance, &localities,
+ [&note, &status](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ if (status.ok()) {
+ CompleteDefaultRanking(gr, cp, ir, localities);
+ }
+ return status;
+}
+
+void CollectiveParamResolverLocal::CompleteDefaultRanking(
+ GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
+ const std::vector<DeviceLocality>& localities) {
+ // Establish an instance-specific default rank order for devices
+ // based on localities. This rank order should be a good ring
+ // order, if possible.
+ GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, localities);
+ // Reflect the new global ranking on shared
+ size_t num_devices = ir->shared.group.group_size;
+ std::vector<string> new_device_names(num_devices, "");
+ std::vector<string> new_task_names(num_devices, "");
+ for (const auto& git : gdm) {
+ const TaskDeviceMap& tdm = git.second;
+ for (const auto& tit : tdm) {
+ const DevRec& dr = tit.second;
+ new_device_names[dr.global_rank] =
+ ir->shared.instance.device_names[dr.original_rank];
+ new_task_names[dr.global_rank] =
+ ir->shared.instance.task_names[dr.original_rank];
+ }
+ }
+
+ ir->shared.instance.device_names = new_device_names;
+ ir->shared.instance.task_names = new_task_names;
+ if (VLOG_IS_ON(2)) {
+ string buf;
+ for (const auto& d : cp->instance.device_names)
+ strings::StrAppend(&buf, "\n", d);
+ VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf;
+ }
+}
+
+void CollectiveParamResolverLocal::CallbackWithStatus(
+ const InstanceRecCallback& done, InstanceRec* irec) {
+ Status s;
+ {
+ mutex_lock l(irec->out_mu);
+ s = irec->status;
+ }
+ done(s, irec);
+}
+
+void CollectiveParamResolverLocal::FindInstanceRec(
+ GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
+ InstanceRec* irec = nullptr;
+ bool exit_outside_locks = false;
+ {
+ mutex_lock l(instance_mu_);
+ auto it = instance_table_.find(cp->instance.instance_key);
+ if (it != instance_table_.end()) {
+ irec = it->second.get();
+ {
+ mutex_lock l(irec->in_mu);
+ if (irec->is_init) {
+ exit_outside_locks = true;
+ } else {
+ irec->init_waiters.push_back([this, gr, cp, done](InstanceRec* irec) {
+ CallbackWithStatus(done, irec);
+ });
+ return;
+ }
+ }
+ } else {
+ // Create new InstanceRec.
+ irec = new InstanceRec;
+ instance_table_[cp->instance.instance_key].reset(irec);
+ }
+ }
+ if (exit_outside_locks) {
+ CallbackWithStatus(done, irec);
+ return;
+ }
+ // Initialize the new InstanceRec while holding out_mu.
+ {
+ mutex_lock il(irec->out_mu);
+ irec->known.resize(cp->group.group_size, false);
+ irec->status = InitInstanceSharedParams(gr, cp, irec);
+ }
+ // Prepare to invoke any waiters that accumlated during initialization.
+ std::vector<IRConsumer> init_waiters;
+ {
+ mutex_lock tl(instance_mu_);
+ {
+ mutex_lock l(irec->in_mu);
+ irec->is_init = true;
+ if (!irec->init_waiters.empty()) {
+ std::swap(init_waiters, irec->init_waiters);
+ }
+ }
+ }
+ CallbackWithStatus(done, irec);
+ for (auto& f : init_waiters) {
+ f(irec);
+ }
+}
+
+void CollectiveParamResolverLocal::CompleteParamsAsync(
+ const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
+ const StatusCallback& done) {
+ VLOG(1) << "CompleteParams " << device << " for " << cp << ": "
+ << cp->ToString();
+ CompleteGroupLocal(
+ device, cp, [this, device, cp, done](const Status& s, GroupRec* gr) {
+ if (s.ok()) {
+ CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
+ } else {
+ done(s);
+ }
+ });
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceAsync(
+ const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr, const StatusCallback& done) {
+ done(
+ errors::Internal("CompleteInstance is not implemented by "
+ "CollectiveParamResolverLocal which is "
+ "intended only for non-distributed deployment."));
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceLocal(
+ const string& device, GroupRec* gr, CollectiveParams* cp, bool is_source,
+ const StatusCallback& done) {
+ VLOG(1) << "CompleteInstanceLocal " << device
+ << " instance_key: " << cp->instance.instance_key << " gr " << gr;
+
+ // Populate the group portion of *cp from *gr. Most of it should already
+ // match.
+ DCHECK_EQ(cp->group.group_key, gr->group.group_key);
+ DCHECK_EQ(cp->group.group_size, gr->group.group_size);
+ DCHECK_EQ(cp->group.device_type, gr->group.device_type);
+ cp->group = gr->group;
+
+ // Get the shared InstanceRec for this instance.
+ FindInstanceRec(gr, cp,
+ [this, device, gr, cp, is_source, done](const Status& s,
+ InstanceRec* ir) {
+ if (s.ok()) {
+ CompleteInstanceFromInitializedIRec(device, gr, cp, ir,
+ is_source, done);
+ } else {
+ done(s);
+ }
+ });
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
+ const string& device, GroupRec* gr, CollectiveParams* cp, InstanceRec* ir,
+ bool is_source, const StatusCallback& done) {
+ // Populate the fields common across instance.
+ {
+ mutex_lock l(ir->out_mu);
+ // custom operator= does a deep copy.
+ cp->instance = ir->shared.instance;
+ }
+ // Populate the fields common across task, also default_rank.
+ SetDefaultRank(device, cp);
+ CompleteTaskIsLocal(task_name_, cp);
+ // If broadcast, may need to wait for source discovery.
+ if (cp->instance.type == BROADCAST_COLLECTIVE) {
+ CompleteInstanceSource(ir, cp, is_source,
+ [this, ir, device, cp, done](InstanceRec* irec) {
+ CHECK_EQ(ir, irec);
+ Status s;
+ int source_rank;
+ {
+ mutex_lock l(irec->out_mu);
+ s = irec->status;
+ source_rank = ir->source_rank;
+ }
+ if (s.ok()) {
+ GenerateSubdivPerms(device, source_rank, cp);
+ }
+ done(s);
+ });
+ return;
+ } else {
+ GenerateSubdivPerms(device, 0, cp);
+ }
+ done(Status::OK());
+}
+
+void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
+ CollectiveParams* cp,
+ bool is_source,
+ const IRConsumer& f) {
+ std::vector<IRConsumer> ready_waiters;
+ {
+ mutex_lock l(ir->out_mu);
+ CHECK_EQ(cp->group.group_size, ir->known.size());
+ CHECK_GE(cp->default_rank, 0);
+ if (!ir->known[cp->default_rank]) {
+ ir->known[cp->default_rank] = true;
+ ++ir->known_count;
+ if (is_source) {
+ if (ir->source_rank >= 0) {
+ ir->status = errors::Internal("Instance ", cp->instance.instance_key,
+ " already has source ", ir->source_rank,
+ ", recevied second claim from ",
+ cp->default_rank);
+ } else {
+ ir->source_rank = cp->default_rank;
+ }
+ }
+ }
+ if (ir->known_count < ir->shared.group.group_size) {
+ ir->known_waiters.push_back(f);
+ return;
+ }
+ CHECK_EQ(ir->known_count, ir->shared.group.group_size);
+ CHECK_GE(ir->source_rank, 0);
+ if (!ir->known_waiters.empty()) {
+ ready_waiters = std::move(ir->known_waiters);
+ }
+ }
+ f(ir);
+ for (auto& f : ready_waiters) {
+ f(ir);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
new file mode 100644
index 0000000000..ff3415b0a9
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -0,0 +1,209 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class CompleteGroupRequest;
+class CompleteGroupResponse;
+class CompleteInstanceRequest;
+class CompleteInstanceResponse;
+class DeviceMgr;
+
+// Implements ParamResolverInterface for a single-task context.
+// It also implements the functionality necessary to serve as the
+// group leader for param resolution in a multi-task context.
+class CollectiveParamResolverLocal : public ParamResolverInterface {
+ public:
+ CollectiveParamResolverLocal(const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ const string& task_name);
+
+ ~CollectiveParamResolverLocal() override {}
+
+ void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ void CompleteGroupAsync(const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ void CompleteInstanceAsync(const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) override;
+
+ protected:
+ // Used to complete/verify CollGroup.
+ struct GroupRec {
+ CollGroupParams group;
+ mutex mu;
+ Status status GUARDED_BY(mu);
+ std::set<string> device_set GUARDED_BY(mu);
+ std::vector<string> device_list GUARDED_BY(mu);
+ std::set<string> task_set GUARDED_BY(mu);
+ std::vector<string> task_list GUARDED_BY(mu);
+ std::vector<StatusCallback> waiting GUARDED_BY(mu);
+ };
+
+ // Finds the GroupRec that corresponds to cp->group_key.
+ // Also populates cp->group from that group_rec.
+ // Will wait until GroupRec is fully populated or an error arises before
+ // calling done. Callback GroupRec* arg is only valid if status is ok.
+ // Ownership of GroupRec stays with this object and does not pass to the
+ // callback.
+ typedef std::function<void(const Status& s, GroupRec* gr)> GroupRecCallback;
+ void CompleteGroupLocal(const string& device, CollectiveParams* cp,
+ const GroupRecCallback& done)
+ LOCKS_EXCLUDED(group_mu_);
+
+ // Used to complete/verify CollInstance.
+ struct InstanceRec;
+ typedef std::function<void(InstanceRec*)> IRConsumer;
+ struct InstanceRec {
+ // This structure has two mutexes so that a possibly long
+ // initialization can be done without holding the instance_mu_
+ // table lock the whole time (which can cause an excessive number
+ // of threads to block on it), and because the compiler may not
+ // permit mutex locks to be taken in more than one order.
+ //
+ // out_mu guards access to most of the fields.
+ // in_mu guards access to a queue of comsumer callbacks wanting to
+ // read the fields guarded by out_mu.
+ //
+ // The in_mu should be locked only while holding instance_mu_; the
+ // out_mu should be locked only while not holding
+ // instance_mu_.
+ //
+ // When is_init is false (the initial value) any potential user
+ // other than the creator should queue a callback on init_waiters.
+ // As soon as the shared member of this structure is fully
+ // initialized is_init will be set true and those callbacks will
+ // be invoked.
+ //
+ // Once inserted in the table this structure will never be replaced
+ // so users can capture the pointer while holding instance_mu_,
+ // drop that lock, then take a lock on out_mu before
+ // reading/modifying its values.
+ mutex in_mu;
+ bool is_init GUARDED_BY(in_mu);
+ std::vector<IRConsumer> init_waiters GUARDED_BY(in_mu);
+
+ // Values to be shared by all instances, constant after initialization.
+ mutex out_mu;
+ CollectiveParams shared GUARDED_BY(out_mu);
+ // If an error occurs during initialization this structure stays in
+ // the table with a non-OK status. Purging the table and restarting
+ // needs to be done at a higher level.
+ Status status GUARDED_BY(out_mu);
+
+ // These fields are used to count the instances that have called
+ // in and become known while resolving broadcast source identity.
+ int source_rank GUARDED_BY(out_mu);
+ int known_count GUARDED_BY(out_mu);
+ std::vector<bool> known GUARDED_BY(out_mu);
+ std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu);
+
+ InstanceRec() : is_init(false), source_rank(-1), known_count(0) {}
+ };
+
+ // Find the InstanceRec with the same instance_key as cp. If it doesn't
+ // already exist, create and initialize from gr and cp.
+ //
+ // Precondition: *gr must be a complete GroupRec, i.e. the value set
+ // by CompleteGroupLocal. *cp must be populated with all the fields
+ // required by InitInstanceSharedParams. Ownership of InstanceRec stays
+ // with this object and does not pass to the callback.
+ typedef std::function<void(const Status& s, InstanceRec* ir)>
+ InstanceRecCallback;
+ void FindInstanceRec(GroupRec* gr, CollectiveParams* cp,
+ const InstanceRecCallback& done)
+ LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+ // Populate *ir with device membership from gr, then initialize to be specific
+ // to cp->instance_key, i.e. order the devices and tasks.
+ //
+ // Preconditions:
+ // cp is populated with all DeviceLocalities
+ Status InitInstanceSharedParams(GroupRec* gr, const CollectiveParams* cp,
+ InstanceRec* ir)
+ EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
+
+ // Establishes the final order of ir->shared.instance.device_names and
+ // ir->shared.instance.task_names by considering localities of all devices.
+ void CompleteDefaultRanking(GroupRec* gr, const CollectiveParams* cp,
+ InstanceRec* ir,
+ const std::vector<DeviceLocality>& localities)
+ EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu);
+
+ // Finish populating *cp.
+ // Precondition: *gr has been fully populated by CompleteGroupLocal.
+ void CompleteInstanceLocal(const string& device, GroupRec* gr,
+ CollectiveParams* cp, bool is_source,
+ const StatusCallback& done)
+ LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
+
+ // Finish populating *cp from fully initialized *ir.
+ // Precondition: *gr and *ir are fully populated.
+ void CompleteInstanceFromInitializedIRec(const string& device, GroupRec* gr,
+ CollectiveParams* cp,
+ InstanceRec* ir, bool is_source,
+ const StatusCallback& done)
+ LOCKS_EXCLUDED(ir->out_mu);
+
+ // Complete source data for a broadcast instance.
+ // Precondition: *cp has complete group data and default_rank.
+ void CompleteInstanceSource(InstanceRec* ir, CollectiveParams* cp,
+ bool is_source, const IRConsumer& f)
+ LOCKS_EXCLUDED(ir->out_mu);
+
+ // If cp.device_names contains only devices local to this process
+ // populates *localities, else returns an error.
+ Status GetLocalDeviceLocalities(const CollectiveParams& cp,
+ std::vector<DeviceLocality>* localities);
+
+ // Sets CollTaskParams.is_local and CollectiveParams.default_rank.
+ // Precondition: cp->device_names is fully populated and in final order.
+ void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp);
+
+ // Sets cp->instance_default_rank according to location of device in
+ // current ordering of cp->instance.device_names.
+ void SetDefaultRank(const string& device, CollectiveParams* cp);
+
+ // Helper to grab status under lock, invoke callback out of lock.
+ void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
+ LOCKS_EXCLUDED(irec->out_mu);
+
+ const DeviceMgr* dev_mgr_;
+ DeviceResolverInterface* dev_resolver_;
+ string task_name_;
+ mutex group_mu_;
+ gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
+ GUARDED_BY(group_mu_);
+ mutex instance_mu_;
+ gtl::FlatMap<int32, std::unique_ptr<InstanceRec>> instance_table_
+ GUARDED_BY(instance_mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
new file mode 100644
index 0000000000..4e3c7125f2
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class CollectiveParamResolverLocalTest : public ::testing::Test {
+ protected:
+ CollectiveParamResolverLocalTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ string task_name = "/job:localhost/replica:0/task:0";
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+ prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
+ task_name));
+ }
+
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<DeviceResolverLocal> drl_;
+ std::unique_ptr<CollectiveParamResolverLocal> prl_;
+};
+
+TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
+ CollectiveParams cps[NUM_DEVS];
+ Status statuses[NUM_DEVS];
+ Notification note[NUM_DEVS];
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ CollectiveParams* cp = &cps[i];
+ cp->group.group_key = 1;
+ cp->group.group_size = 3;
+ cp->group.device_type = DeviceType("CPU");
+ cp->group.num_tasks = 1;
+ cp->instance.instance_key = 7;
+ cp->instance.type = REDUCTION_COLLECTIVE;
+ cp->instance.data_type = DataType(DT_FLOAT);
+ cp->instance.shape = TensorShape({5});
+ cp->instance.device_names.push_back(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i));
+ cp->instance.impl_details.subdiv_offsets.push_back(0);
+ cp->is_source = false;
+ Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
+ prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+ nullptr /*CancellationManager*/,
+ [this, &statuses, &note, i](const Status& s) {
+ statuses[i] = s;
+ note[i].Notify();
+ });
+ });
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ note[i].WaitForNotification();
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ TF_ASSERT_OK(statuses[i]);
+ ASSERT_EQ(cps[i].instance.device_names.size(), 3);
+ for (int j = 0; j < NUM_DEVS; ++j) {
+ EXPECT_EQ(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
+ cps[i].instance.device_names[j]);
+ EXPECT_TRUE(cps[i].task.is_local[j]);
+ }
+ EXPECT_EQ(cps[i].subdiv_rank[0], i);
+ EXPECT_EQ(cps[i].subdiv_source_rank.size(), 0);
+ EXPECT_FALSE(cps[i].is_source);
+ EXPECT_EQ(cps[i].default_rank, i);
+ }
+}
+
+TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
+ CollectiveParams cps[NUM_DEVS];
+ Status statuses[NUM_DEVS];
+ Notification note[NUM_DEVS];
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ CollectiveParams* cp = &cps[i];
+ cp->group.group_key = 1;
+ cp->group.group_size = 3;
+ cp->group.device_type = DeviceType("CPU");
+ cp->group.num_tasks = 1;
+ cp->instance.instance_key = 3;
+ cp->instance.type = BROADCAST_COLLECTIVE;
+ cp->instance.data_type = DataType(DT_FLOAT);
+ cp->instance.shape = TensorShape({5});
+ cp->instance.device_names.push_back(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i));
+ cp->instance.impl_details.subdiv_offsets.push_back(0);
+ cp->is_source = (i == 1);
+ Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
+ prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
+ nullptr /*CancellationManager*/,
+ [this, &statuses, &note, i](const Status& s) {
+ statuses[i] = s;
+ note[i].Notify();
+ });
+ });
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ note[i].WaitForNotification();
+ }
+ for (int i = 0; i < NUM_DEVS; ++i) {
+ TF_ASSERT_OK(statuses[i]);
+ ASSERT_EQ(cps[i].instance.device_names.size(), 3);
+ for (int j = 0; j < NUM_DEVS; ++j) {
+ EXPECT_EQ(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
+ cps[i].instance.device_names[j]);
+ EXPECT_TRUE(cps[i].task.is_local[j]);
+ }
+ ASSERT_GT(cps[i].subdiv_rank.size(), 0);
+ EXPECT_EQ(cps[i].subdiv_rank[0], i);
+ ASSERT_GT(cps[i].subdiv_source_rank.size(), 0);
+ EXPECT_EQ(cps[i].subdiv_source_rank[0], 1);
+ EXPECT_EQ(cps[i].is_source, (i == 1));
+ EXPECT_EQ(cps[i].default_rank, i);
+ }
+}
+
+// TEST_F(CollectiveParamResolverLocalTest,
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
new file mode 100644
index 0000000000..ad9b32ce35
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+
+namespace tensorflow {
+
+void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
+ buf_rendezvous_.StartAbort(s);
+}
+
+void CollectiveRemoteAccessLocal::RecvFromPeer(
+ const string& peer_device, const string& peer_task, bool peer_is_local,
+ const string& key, Device* to_device, DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality, const StatusCallback& done) {
+ VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
+ << key;
+ if (!peer_is_local) {
+ done(
+ errors::Internal("CollectiveRemoteAccessLocal::RecvFromPeer "
+ "called with peer_is_local=false"));
+ return;
+ }
+ buf_rendezvous_.ConsumeBuf(
+ key, [this, to_tensor, to_device_ctx, to_device, to_alloc_attr, done](
+ const Status& s, BufRendezvous::Hook* hook) {
+ if (!s.ok()) {
+ done(s);
+ delete hook;
+ } else {
+ int64 recv_bytes = to_tensor->TotalBytes();
+ CHECK_EQ(recv_bytes, hook->prod_value->TotalBytes());
+ MemCpyAsync(hook->prod_ctx, // src DeviceContext
+ to_device_ctx, // dst DeviceContext
+ hook->prod_dev, // src Device
+ to_device, // dst Device
+ hook->prod_attr, // src AllocatorAttributes
+ to_alloc_attr, // dst AllocatorAttributes
+ hook->prod_value, // src Tensor*
+ to_tensor, // dst Tensor*
+ [hook, done](const Status& s) {
+ done(s);
+ hook->prod_cb(s);
+ delete hook;
+ });
+ }
+ });
+}
+
+void CollectiveRemoteAccessLocal::PostToPeer(
+ const string& peer_device, const string& peer_task, const string& key,
+ Device* from_device, DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor,
+ const DeviceLocality& client_locality, const StatusCallback& done) {
+ VLOG(1) << "PostToPeer " << this << " key " << key
+ << " step_id_=" << step_id_;
+ buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor,
+ from_alloc_attr, done);
+}
+
+/*static*/
+void CollectiveRemoteAccessLocal::MemCpyAsync(
+ DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
+ Device* dst_dev, const AllocatorAttributes& src_attr,
+ const AllocatorAttributes& dst_attr, const Tensor* src, Tensor* dst,
+ const StatusCallback& done) {
+ // We want a real copy to happen, i.e. the bytes inside of src should be
+ // transferred to the buffer backing dst. If src and dst are on different
+ // devices then CopyTensor::ViaDMA will do just that. But if they're both
+ // the same CPU, then it will actually just reset dst to point to src.
+ // Since this routine is used for copying between devices and within a
+ // device, we need to detect and bypass the wrong-semantics case.
+ const DeviceType src_device_type(
+ src_attr.on_host() ? DEVICE_CPU : src_dev->attributes().device_type());
+ const DeviceType dst_device_type(
+ dst_attr.on_host() ? DEVICE_CPU : dst_dev->attributes().device_type());
+ const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
+ const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
+ if (non_cpu_src) CHECK(src_dev_ctx);
+ if (non_cpu_dst) CHECK(dst_dev_ctx);
+ if (non_cpu_src || non_cpu_dst) {
+ CopyTensor::ViaDMA("", // edge name (non-existent)
+ src_dev_ctx, dst_dev_ctx, src_dev, dst_dev, src_attr,
+ dst_attr, src, dst, done);
+ } else {
+ int64 bytes = src->TotalBytes();
+ DCHECK_EQ(dst->TotalBytes(), bytes);
+ memcpy(DMAHelper::base(dst), DMAHelper::base(src), bytes);
+ done(Status::OK());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
new file mode 100644
index 0000000000..d25dd5f04a
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/rendezvous.h"
+
+namespace tensorflow {
+
+// Basic implementation of PerStepCollectiveRemoteAccess.
+class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
+ public:
+ CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ int64 step_id)
+ : dev_mgr_(dev_mgr),
+ dev_resolver_(dev_resolver),
+ buf_rendezvous_(step_id),
+ step_id_(step_id) {}
+
+ virtual ~CollectiveRemoteAccessLocal() {}
+
+ void StartAbort(const Status& s);
+
+ void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key, Device* to_device,
+ DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override;
+
+ void PostToPeer(const string& peer_device, const string& peer_task,
+ const string& key, Device* from_device,
+ DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr,
+ const Tensor* from_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override;
+
+ void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done) override {
+ dev_resolver_->GetDeviceLocalitiesAsync(ci_params, localities, done);
+ }
+
+ void GetLocalityAsync(const string& device, const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) override {
+ dev_resolver_->GetLocalityAsync(device, task, locality, done);
+ }
+
+ void ClearTask(const string& task) override {
+ dev_resolver_->ClearTask(task);
+ }
+
+ // Copy utility that always copies bytes from src to dst even if
+ // they are on the same device, unlike CopyTensor::ViaDMA which will
+ // just change the dst buffer pointer in that case.
+ static void MemCpyAsync(DeviceContext* src_dev_ctx,
+ DeviceContext* dst_dev_ctx, Device* src_dev,
+ Device* dst_dev, const AllocatorAttributes& src_attr,
+ const AllocatorAttributes& dst_attr,
+ const Tensor* src, Tensor* dst,
+ const StatusCallback& done);
+
+ protected:
+ const DeviceMgr* dev_mgr_; // not owned
+ DeviceResolverInterface* dev_resolver_; // not owned
+ BufRendezvous buf_rendezvous_;
+ int64 step_id_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
new file mode 100644
index 0000000000..dcd4272d96
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -0,0 +1,148 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+static const int kStepId = 123;
+
+class CollectiveRemoteAccessLocalTest : public ::testing::Test {
+ protected:
+ const string kTaskName = "/job:localhost/replica:0/task:0";
+
+ CollectiveRemoteAccessLocalTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+ prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
+ kTaskName));
+ rma_.reset(new CollectiveRemoteAccessLocal(device_mgr_.get(), drl_.get(),
+ kStepId));
+ }
+
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<DeviceResolverLocal> drl_;
+ std::unique_ptr<CollectiveParamResolverLocal> prl_;
+ std::unique_ptr<CollectiveRemoteAccessLocal> rma_;
+};
+
+TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) {
+ Device* cpu0 = nullptr;
+ AllocatorAttributes attr;
+ DeviceLocality dev_locality;
+ TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0));
+ Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
+ Notification recv_note;
+ Status recv_status;
+ rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/,
+ "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
+ attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+ [this, &recv_note, &recv_status](const Status& s) {
+ recv_status = s;
+ recv_note.Notify();
+ });
+ Tensor source_tensor(DT_FLOAT, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ source_tensor.flat<float>()(i) = i / 2;
+ }
+ // Tensors have distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+ Notification send_note;
+ Status send_status;
+ rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0",
+ cpu0 /*from_device*/, nullptr /*from_device_ctx*/,
+ attr /*to_alloc_attr*/, &source_tensor, dev_locality,
+ [this, &send_note, &send_status](const Status& s) {
+ send_status = s;
+ send_note.Notify();
+ });
+ recv_note.WaitForNotification();
+ send_note.WaitForNotification();
+ TF_EXPECT_OK(recv_status);
+ TF_EXPECT_OK(send_status);
+ // Sink tensor gets the source tensor values.
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(sink_tensor.flat<float>()(i), i / 2);
+ }
+ // And still has distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+}
+
+TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
+ Device* cpu2 = nullptr;
+ AllocatorAttributes attr;
+ DeviceLocality dev_locality;
+ TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:2", &cpu2));
+ Tensor sink_tensor(DT_FLOAT, TensorShape({8}));
+ Notification recv_note;
+ Status recv_status;
+ rma_->RecvFromPeer(kTaskName + "/device:CPU:1", kTaskName, true /*is_local*/,
+ "key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/,
+ attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+ [this, &recv_note, &recv_status](const Status& s) {
+ recv_status = s;
+ recv_note.Notify();
+ });
+ Tensor source_tensor(DT_FLOAT, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ source_tensor.flat<float>()(i) = i / 2;
+ }
+ // Tensors have distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+ Device* cpu1 = nullptr;
+ TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:1", &cpu1));
+ Notification send_note;
+ Status send_status;
+ rma_->PostToPeer(kTaskName + "/device:CPU:2", kTaskName, "key_0",
+ cpu1 /*from_device*/, nullptr /*from_device_ctx*/,
+ attr /*to_alloc_attr*/, &source_tensor, dev_locality,
+ [this, &send_note, &send_status](const Status& s) {
+ send_status = s;
+ send_note.Notify();
+ });
+ recv_note.WaitForNotification();
+ send_note.WaitForNotification();
+ TF_EXPECT_OK(recv_status);
+ TF_EXPECT_OK(send_status);
+ // Sink tensor gets the source tensor values.
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(sink_tensor.flat<float>()(i), i / 2);
+ }
+ // And still has distinct storage.
+ EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc
new file mode 100644
index 0000000000..17ef4a2284
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_resolver_local.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+namespace tensorflow {
+
+void DeviceResolverLocal::GetDeviceLocalitiesAsync(
+ const CollInstanceParams& ci_params,
+ std::vector<DeviceLocality>* localities, const StatusCallback& done) {
+ localities->clear();
+ for (const string& device_name : ci_params.device_names) {
+ Device* dev;
+ Status s = dev_mgr_->LookupDevice(device_name, &dev);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ localities->push_back(dev->attributes().locality());
+ }
+ done(Status::OK());
+}
+
+void DeviceResolverLocal::GetLocalityAsync(const string& device,
+ const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) {
+ Device* dev;
+ Status s = dev_mgr_->LookupDevice(device, &dev);
+ if (s.ok()) {
+ *locality = dev->attributes().locality();
+ }
+ done(s);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h
new file mode 100644
index 0000000000..098eccdf84
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_resolver_local.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+
+namespace tensorflow {
+class DeviceMgr;
+
+// Implements DeviceResolverInterface in a single-task context.
+class DeviceResolverLocal : public DeviceResolverInterface {
+ public:
+ DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {}
+
+ virtual ~DeviceResolverLocal() {}
+
+ void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done) override;
+
+ void GetLocalityAsync(const string& device, const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) override;
+
+ void ClearTask(const string& task) override {}
+
+ protected:
+ const DeviceMgr* dev_mgr_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/device_resolver_local_test.cc b/tensorflow/core/common_runtime/device_resolver_local_test.cc
new file mode 100644
index 0000000000..f5a6471ff7
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_resolver_local_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+#define NUM_DEVS 3
+
+class DeviceResolverLocalTest : public ::testing::Test {
+ protected:
+ DeviceResolverLocalTest() {
+ ConfigProto cp;
+ SessionOptions options;
+ string task_name = "/job:localhost/replica:0/task:0";
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
+ }
+
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<DeviceResolverLocal> drl_;
+};
+
+TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesKnown) {
+ CollectiveParams cp;
+ std::vector<DeviceLocality> localities;
+ cp.instance.device_names.push_back(
+ "/job:localhost/replica:0/task:0/device:CPU:1");
+ cp.instance.device_names.push_back(
+ "/job:localhost/replica:0/task:0/device:CPU:2");
+ Notification note;
+ Status status;
+ drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
+ [this, &note, &status](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ TF_EXPECT_OK(status);
+ EXPECT_EQ(2, localities.size());
+}
+
+TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesUnknown) {
+ CollectiveParams cp;
+ std::vector<DeviceLocality> localities;
+ // In some builds there may be 1 GPU, but there should never be 9.
+ cp.instance.device_names.push_back(
+ "/job:localhost/replica:0/task:0/device:GPU:9");
+ Notification note;
+ Status status;
+ drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
+ [this, &note, &status](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_FALSE(status.ok());
+ EXPECT_EQ(0, localities.size());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index a619cac9a4..941a0e61c7 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -135,17 +135,34 @@ tf_cc_test(
],
)
-# -----------------------------------------------------------------------------
-# Google-internal targets.
+cc_library(
+ name = "execute",
+ srcs = ["execute.cc"],
+ hdrs = ["execute.h"],
+ deps = [
+ ":context",
+ ":copy_to_device_node",
+ ":kernel_and_device",
+ ":tensor_handle",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
+cc_library(
+ name = "execute_node",
+ hdrs = ["execute_node.h"],
+ deps = [
+ ":context",
+ ":eager_executor",
+ ":execute",
+ ":kernel_and_device",
+ ":tensor_handle",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
)
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
new file mode 100644
index 0000000000..4f16e42568
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -0,0 +1,134 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/execute.h"
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+Status EagerExecute(EagerContext* ctx, Device* device,
+ const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
+ KernelAndDevice* kernel, NodeExecStats* maybe_stats,
+ TensorHandle** retvals, int num_retvals) {
+ if (!ctx->SoftPlacement() && device == nullptr) {
+ device = ctx->HostCPU();
+ }
+
+ if (device == nullptr) {
+ // TODO(apassos) debug how the assignment below might return a different
+ // device from the one requested above.
+ device = kernel->device();
+ }
+
+ std::vector<Tensor> outputs(1);
+ const MemoryTypeVector* output_memory_types = nullptr;
+ output_memory_types = &kernel->kernel()->output_memory_types();
+ std::vector<Tensor> inputs(op_inputs.size());
+ for (int i = 0; i < op_inputs.size(); ++i) {
+ const Tensor* input_tensor = nullptr;
+ TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
+ inputs[i] = *input_tensor;
+ }
+ // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
+ // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def.
+ // But knowledge of the implementation
+ // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
+ // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
+ // This is quite subtle. Re-work things to make this better? (Would it make
+ // sense for FunctionLibraryRuntime to ensure thread-safe access to
+ // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
+ // for ops which are a part of functions.
+ // TODO(agarwal): change Run to take vector of handles ?
+ TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
+ if (maybe_stats != nullptr) {
+ maybe_stats->set_op_end_rel_micros(Env::Default()->NowMicros() -
+ maybe_stats->all_start_micros());
+ mutex_lock ml(*ctx->MetadataMu());
+ if (ctx->ShouldStoreMetadata()) {
+ auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
+ // Lazily initialize the RunMetadata with information about all devices if
+ // this is the first call.
+ while (step_stats->dev_stats_size() < ctx->devices()->size()) {
+ step_stats->add_dev_stats();
+ }
+ // Find the current device's index.
+ int device_idx = 0;
+ for (int i = 0; i < ctx->devices()->size(); ++i) {
+ if (ctx->devices()->at(i) == device) {
+ device_idx = i;
+ break;
+ }
+ }
+ // Populate the device stats for this device.
+ auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
+ dev_stats->set_device(device->name());
+ *dev_stats->add_node_stats() = *maybe_stats;
+ }
+ }
+ DCHECK_EQ(num_retvals, outputs.size());
+ Device* op_device = device;
+ for (int i = 0; i < num_retvals; ++i) {
+ Device* d = op_device;
+ if (d != nullptr && output_memory_types != nullptr &&
+ (*output_memory_types)[i] == HOST_MEMORY) {
+ d = nullptr;
+ }
+ if (retvals[i] == nullptr) {
+ retvals[i] = new TensorHandle(outputs[i], d, op_device);
+ } else {
+ retvals[i]->SetTensorAndDevice(outputs[i], d, op_device);
+ }
+ }
+ return Status::OK();
+}
+
+Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
+ const char* device_name, TensorHandle** result) {
+ TF_RETURN_IF_ERROR(ctx->GetStatus());
+ Device* dstd = ctx->HostCPU();
+ if (device_name != nullptr && strlen(device_name) > 0) {
+ TF_RETURN_IF_ERROR(ctx->device_mgr()->LookupDevice(device_name, &dstd));
+ }
+ if (ctx->Async()) {
+ // Note that `h` may not be currently ready. However execution order will
+ // make sure that `h` is ready before the copy is actually done.
+ CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
+ TensorHandle* output = node->dst();
+ // Note that calling Add makes `node` accessible by the EagerExecutor
+ // thread. So further accesses need to be thread-safe.
+ ctx->ExecutorAdd(node);
+ *result = output;
+ return Status::OK();
+ } else {
+ TF_RETURN_IF_ERROR(h->CopyToDevice(ctx, dstd, result));
+ return Status::OK();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.h b/tensorflow/core/common_runtime/eager/execute.h
new file mode 100644
index 0000000000..0f6ad031e1
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/execute.h
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+
+// Low-level utility to execute the kernel specified by kernel on device device,
+// with the inputs op_inputs, in the context ctx.
+Status EagerExecute(EagerContext* ctx, Device* device,
+ const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
+ KernelAndDevice* kernel, NodeExecStats* maybe_stats,
+ TensorHandle** retvals, int num_retvals);
+
+// Low-level utility to copy a tensor handle from one device to another.
+Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
+ const char* device_name, TensorHandle** result);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_
diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h
new file mode 100644
index 0000000000..93018dd969
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/execute_node.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+
+class ExecuteNode : public EagerNode {
+ public:
+ ExecuteNode(uint64 id, EagerContext* ctx, Device* op_device,
+ const tensorflow::gtl::InlinedVector<TensorHandle*, 4>& inputs,
+ KernelAndDevice* kernel, NodeExecStats* maybe_stats,
+ const DataTypeVector& output_dtypes,
+ const tensorflow::gtl::InlinedVector<TensorHandle*, 2>& retvals)
+ : EagerNode(id),
+ ctx_(ctx),
+ op_device_(op_device),
+ inputs_(inputs),
+ kernel_(kernel),
+ maybe_stats_(maybe_stats),
+ retvals_(retvals) {
+ for (auto handle : inputs_) {
+ handle->Ref();
+ }
+ for (auto handle : retvals_) {
+ handle->Ref();
+ }
+ }
+
+ ~ExecuteNode() override {
+ for (auto handle : inputs_) {
+ handle->Unref();
+ }
+ for (auto handle : retvals_) {
+ handle->Unref();
+ }
+ }
+
+ tensorflow::Status Run() override {
+ const Status status =
+ EagerExecute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
+ retvals_.begin(), retvals_.size());
+ if (status.ok()) {
+ return status;
+ } else {
+ return Status(status.code(),
+ strings::StrCat("Got error, \"", status.error_message(),
+ "\" while executing kernel ",
+ kernel_->kernel()->def().DebugString()));
+ }
+ }
+
+ private:
+ tensorflow::EagerContext* ctx_;
+ tensorflow::Device* op_device_;
+ tensorflow::gtl::InlinedVector<TensorHandle*, 4> inputs_;
+ tensorflow::KernelAndDevice* kernel_;
+ std::unique_ptr<NodeExecStats> maybe_stats_;
+ tensorflow::gtl::InlinedVector<TensorHandle*, 2> retvals_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index f6fe9edb02..5fab740e92 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -339,18 +339,3 @@ cc_library(
# ],
# visibility = ["//visibility:public"],
# )
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 434626bd2d..b07cb8cdcb 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 9dae1b9859..9c655bfa31 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -6,18 +6,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
new file mode 100644
index 0000000000..a26f2c2f31
--- /dev/null
+++ b/tensorflow/core/framework/collective.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/collective.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+string CollGroupParams::ToString() const {
+ return strings::StrCat("CollGroupParams {group_key=", group_key,
+ " group_size=", group_size,
+ " device_type=", device_type.type_string(),
+ " num_tasks=", num_tasks, "}");
+}
+
+CollInstanceParams& CollInstanceParams::operator=(
+ const CollInstanceParams& other) {
+ if (this != &other) {
+ instance_key = other.instance_key;
+ type = other.type;
+ data_type = other.data_type;
+ shape = other.shape;
+ device_names.clear();
+ device_names.assign(other.device_names.begin(), other.device_names.end());
+ task_names.assign(other.task_names.begin(), other.task_names.end());
+ impl_details.subdiv_offsets.assign(
+ other.impl_details.subdiv_offsets.begin(),
+ other.impl_details.subdiv_offsets.end());
+ impl_details.subdiv_permutations.clear();
+ for (auto p : other.impl_details.subdiv_permutations) {
+ impl_details.subdiv_permutations.push_back(
+ std::vector<int>(p.begin(), p.end()));
+ }
+ impl_details.subdiv_source_rank.assign(
+ other.impl_details.subdiv_source_rank.begin(),
+ other.impl_details.subdiv_source_rank.end());
+ }
+ return *this;
+}
+
+string CollInstanceParams::ToString() const {
+ string v = strings::StrCat("CollInstanceParams { instance_key=", instance_key,
+ " type=", type, " data_type=", data_type,
+ " shape=", shape.DebugString(), " devices {");
+ for (const auto& d : device_names) {
+ strings::StrAppend(&v, d, ",");
+ }
+ strings::StrAppend(&v, "} task_names={");
+ for (const auto& n : task_names) {
+ strings::StrAppend(&v, n, ", ");
+ }
+ strings::StrAppend(&v, "}, subdiv_offsets={");
+ for (const auto& d : impl_details.subdiv_offsets) {
+ strings::StrAppend(&v, d, ",");
+ }
+ strings::StrAppend(&v, "}, subdiv_perms={");
+ for (const auto& p : impl_details.subdiv_permutations) {
+ strings::StrAppend(&v, "{");
+ for (const auto& i : p) {
+ strings::StrAppend(&v, i, ",");
+ }
+ strings::StrAppend(&v, "}"); // one subdiv
+ }
+ strings::StrAppend(&v, "}"); // all subdivs
+ return v;
+}
+
+string CollTaskParams::ToString() const {
+ string v = strings::StrCat("CollTaskParams {is_local={");
+ for (const auto& b : is_local) {
+ strings::StrAppend(&v, static_cast<int>(b), ",");
+ }
+ strings::StrAppend(&v, "}}");
+ return v;
+}
+
+string CollectiveParams::ToString() const {
+ string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
+ strings::StrAppend(&v, " ", instance.ToString());
+ strings::StrAppend(&v, " ", task.ToString());
+ strings::StrAppend(&v, " default_rank=", default_rank,
+ " is_source=", is_source, " subdiv_rank={");
+ for (const auto& r : subdiv_rank) {
+ strings::StrAppend(&v, r, ",");
+ }
+ if (!subdiv_source_rank.empty()) {
+ strings::StrAppend(&v, " subdiv_rank={");
+ for (const auto& r : subdiv_source_rank) {
+ strings::StrAppend(&v, r, ",");
+ }
+ strings::StrAppend(&v, "}");
+ }
+ strings::StrAppend(&v, "}}");
+ return v;
+}
+
+/*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
+ OpKernelContext* ctx) {
+ return ctx->params_;
+}
+
+/*static*/
+int64 CollectiveExecutor::kInvalidId = -1;
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
new file mode 100644
index 0000000000..362d345133
--- /dev/null
+++ b/tensorflow/core/framework/collective.h
@@ -0,0 +1,308 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+#define TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+class BufRendezvous;
+class CancellationManager;
+class CompleteGroupRequest;
+class CompleteGroupResponse;
+class CompleteInstanceRequest;
+class CompleteInstanceResponse;
+class DeviceLocality;
+class GetStepSequenceRequest;
+class GetStepSequenceResponse;
+class Op;
+class Tensor;
+
+// Types of supported collective operations.
+enum CollectiveType {
+ REDUCTION_COLLECTIVE = 0,
+ BROADCAST_COLLECTIVE,
+ UNDEFINED_COLLECTIVE,
+};
+
+// Data common to all members of a device group.
+// All members share the same device set but its order is
+// particular to an instance so it is stored there.
+struct CollGroupParams {
+ int32 group_key;
+ int32 group_size;
+ DeviceType device_type;
+ int32 num_tasks; // number of distinct tasks in group
+ string ToString() const;
+ CollGroupParams() : device_type(DEVICE_CPU) {}
+};
+
+// The best implementation of a collective op depends on many factors
+// including the number of devices involved, the topology of
+// interconnects between them and the sizes of inputs. This structure
+// is used in generating and representing data movement choreography
+// for each specific algorithm, hence it does not have a single, fixed
+// interpretation. On first execution the runtime will update this
+// structure with decisions that will guide all subsequent executions.
+struct CollImplDetails {
+ std::vector<std::vector<int>> subdiv_permutations;
+ std::vector<int> subdiv_offsets;
+ // broadcast only: rank of source in each subdiv
+ std::vector<int> subdiv_source_rank;
+};
+
+// Data common to all members of a collective instance.
+struct CollInstanceParams {
+ int32 instance_key; // Identifies all participating graph nodes.
+ CollectiveType type;
+ DataType data_type;
+ TensorShape shape;
+ // Fully qualified name of device for each member, in default rank order.
+ std::vector<string> device_names;
+ // Task name prefix of corresponding device name.
+ std::vector<string> task_names;
+ CollImplDetails impl_details;
+ string ToString() const;
+ CollInstanceParams& operator=(const struct CollInstanceParams& other);
+};
+
+// Data common to all instance members in the same task.
+struct CollTaskParams {
+ // True for devices that are local to the process, i.e. no RPC needed.
+ std::vector<bool> is_local;
+ string ToString() const;
+};
+
+// Unique to a single CollectiveOp node.
+struct CollectiveParams {
+ CollGroupParams group;
+ CollInstanceParams instance;
+ CollTaskParams task;
+
+ string name; // node name used only for log or error messages
+ int default_rank; // index of this op within device_names
+ bool is_source; // broadcast only
+ // Rank of this device in each subdivision permutation.
+ std::vector<int> subdiv_rank;
+ std::vector<int> subdiv_source_rank;
+ const Tensor* in_tensor; // kernel input
+ Tensor* out_tensor; // kernel output
+ std::unique_ptr<OpKernel> merge_op; // reduction only
+ std::unique_ptr<OpKernel> final_op; // reduction only
+ OpKernelContext* op_context;
+ string ToString() const;
+};
+
+class CollectiveExecutor;
+
+// Interface that provides resolution of device localities.
+class DeviceResolverInterface {
+ public:
+ virtual ~DeviceResolverInterface() {}
+
+ // Collects DeviceLocality protobufs from all of the devices identified
+ // in 'col_params'.
+ virtual void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
+ std::vector<DeviceLocality>* localities,
+ const StatusCallback& done) = 0;
+
+ // Populate *locality with the DeviceLocality of the specified
+ // device.
+ virtual void GetLocalityAsync(const string& device, const string& task,
+ DeviceLocality* locality,
+ const StatusCallback& done) = 0;
+
+ // Clear the cache of device data belonging
+ // to the specified task.
+ virtual void ClearTask(const string& task) = 0;
+};
+
+// Interface that provides resolution of shared CollectiveParams fields.
+class ParamResolverInterface {
+ public:
+ virtual ~ParamResolverInterface() {}
+
+ // Called by each collective op at first execution in order to fill out
+ // the CollectiveParams structure with data gathered from the full
+ // (maybe distributed) collection of peer nodes.
+ virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) = 0;
+
+ // Used within a distributed implementation to discover/verify
+ // data shared across a device group.
+ virtual void CompleteGroupAsync(const CompleteGroupRequest* request,
+ CompleteGroupResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) = 0;
+
+ // Used within a distributed implementation to discover/verify data
+ // shared across an instance group.
+ virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
+ CompleteInstanceResponse* response,
+ CancellationManager* cancel_mgr,
+ const StatusCallback& done) = 0;
+};
+
+// Graphs which utilize Collective Ops in a common instance must
+// execute with identical step_ids even if they are disjoint graphs
+// run by otherwise independent tasks. This interface supplies
+// coordinated step_ids to use in such cases.
+class StepSequenceInterface {
+ public:
+ virtual ~StepSequenceInterface() {}
+
+ // Used with a distributed implementation to coordinate step_id
+ // sequences across tasks.
+ virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ const StatusCallback& done) = 0;
+
+ // Refresh the local per-graph_key step_id sequence from collective
+ // group leader, if applicable.
+ virtual void RefreshStepIdSequenceAsync(int64 graph_key,
+ const StatusCallback& done) = 0;
+
+ // Returns the the step_id that should be used for initiating a new execution
+ // on the specified graph. May return the same step_id multiple times if
+ // RetireStepId or RefreshStepIdReservation is not called.
+ virtual int64 NextStepId(int64 graph_key) = 0;
+
+ // Reports that execution of the given step has completed successfully.
+ // Should be called immediately after a step completes with OK status,
+ // prior to calling NextStepId(). If the step fails, don't call.
+ virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
+};
+
+// Interface that provides access to per-step CollectiveExecutor
+// instances and various distributed resolution capabilities.
+class CollectiveExecutorMgrInterface : public StepSequenceInterface {
+ public:
+ virtual ~CollectiveExecutorMgrInterface() {}
+
+ // Returns the step-specific CollectiveExecutor, creating if one does not
+ // already exist. The caller assumes ownership of one Ref on the object.
+ virtual CollectiveExecutor* FindOrCreate(int64 step_id) = 0;
+
+ // If there is a CollectiveExecutor for step_id, remove it from the
+ // table.
+ virtual void Cleanup(int64 step_id) = 0;
+
+ virtual ParamResolverInterface* GetParamResolver() const = 0;
+
+ virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
+};
+
+// Interface that a Collective Op implementation uses to exchange data
+// with peers. Note that data exchange is currently limited to types
+// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
+// types.
+class PeerAccessInterface {
+ public:
+ virtual ~PeerAccessInterface() {}
+
+ virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key,
+ Device* to_device, DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr,
+ Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) = 0;
+
+ virtual void PostToPeer(const string& peer_device, const string& peer_task,
+ const string& key, Device* from_device,
+ DeviceContext* from_device_ctx,
+ const AllocatorAttributes& from_alloc_attr,
+ const Tensor* from_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) = 0;
+};
+
+class PerStepCollectiveRemoteAccess;
+
+// A step-specific object that can execute a collective operation completely
+// described by a CollectiveParams object.
+class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
+ public:
+ virtual void StartAbort(const Status& s) {}
+
+ virtual void ExecuteAsync(OpKernelContext* ctx,
+ const CollectiveParams& col_params,
+ const string& exec_key, StatusCallback done) {
+ done(errors::Internal(
+ "A collective Op has been called in a context in which "
+ "a CollectiveExecutor has not been provided."));
+ }
+
+ virtual void CompleteParamsAsync(const string& device, CollectiveParams* cp,
+ CancellationManager* cancel_mgr,
+ StatusCallback done) {
+ cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
+ }
+
+ virtual PerStepCollectiveRemoteAccess* remote_access() { return nullptr; }
+
+ // Used to designate an invalid group or instance key.
+ static int64 kInvalidId;
+
+ // Lexically scoped handle for Ref.
+ class Handle {
+ public:
+ explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
+ if (!inherit_ref) ce->Ref();
+ }
+ ~Handle() { ce_->Unref(); }
+ CollectiveExecutor* get() const { return ce_; }
+
+ private:
+ CollectiveExecutor* ce_;
+ };
+
+ protected:
+ explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
+ : cem_(cem) {}
+
+ // For use only by derived classes
+ static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
+ CollectiveExecutorMgrInterface* cem_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
+};
+
+// Interface of a helper object that provices a CollectiveExecutor with
+// all of the remote access it needs.
+class CollectiveRemoteAccess : public PeerAccessInterface,
+ public DeviceResolverInterface {
+ public:
+ virtual ~CollectiveRemoteAccess() {}
+};
+
+// A per-step version of CollectiveRemoteAccess that cleans up outstanding
+// communications in case step execution is abandoned.
+class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
+ public:
+ virtual ~PerStepCollectiveRemoteAccess() {}
+ virtual void StartAbort(const Status& s) = 0;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 7230e0f09c..789746b403 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -272,7 +272,7 @@ Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index);
// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.
inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
- return BroadcastBinaryOpOutputShapeFn(c, 0);
+ return BroadcastBinaryOpOutputShapeFn(c, 0);
}
// Shape function for random operations.
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 5ccd45efc9..2d97160830 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1101,6 +1101,7 @@ class OpKernelContext {
void NotifyUseOfPersistentTensor(const Tensor& tensor);
Status status_;
+ friend class CollectiveExecutor; // for access to params_
Params* params_; // not owned
mutable mutex mu_; // mutable so const accessors can acquire the lock
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 9a458431e7..c84ea3b034 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -319,14 +319,13 @@ class IsResourceInitialized : public OpKernel {
// specified type. The type will be a part of the generated op name.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
-#define REGISTER_RESOURCE_HANDLE_OP(Type) \
- REGISTER_OP(#Type "HandleOp") \
- .Attr("container: string = ''") \
- .Attr("shared_name: string = ''") \
- .Output("resource: resource") \
- .SetIsStateful() \
- .SetShapeFn(tensorflow::shape_inference::ScalarShape) \
- .Doc("Creates a handle to a " #Type)
+#define REGISTER_RESOURCE_HANDLE_OP(Type) \
+ REGISTER_OP(#Type "HandleOp") \
+ .Attr("container: string = ''") \
+ .Attr("shared_name: string = ''") \
+ .Output("resource: resource") \
+ .SetIsStateful() \
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
// Utility op kernel to produce a handle to a resource of type T.
template <typename T>
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index 2ca9b720ee..9dcc6765f5 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -3,18 +3,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "op_types",
srcs = ["op_types.cc"],
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index b653f902e8..9ecf5a6cf7 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -8,18 +8,6 @@ load(
"tf_cuda_tests_tags",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
config_setting(
name = "xsmm",
licenses = ["notice"],
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 5336df1f51..33949319d5 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -7,18 +7,6 @@ load(
)
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "graph_properties_testdata",
srcs = glob([
"graph_properties_testdata/*.pbtxt",
@@ -55,6 +43,7 @@ cc_library(
":utils",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 817247e379..a5fd79447d 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
@@ -251,8 +252,7 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
}
bool IsQueue(const Node& node) {
- StringPiece type(node.type_string());
- return type.ends_with("QueueV2");
+ return str_util::EndsWith(node.type_string(), "QueueV2");
}
// Returns true if the node is an Enter op AND its input is a Queue.
diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
index ea4320687a..833205ac6f 100644
--- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include "tensorflow/core/framework/cost_graph.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/robust_stats.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -52,6 +53,8 @@ Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) {
Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
CostGraphDef* cost_graph,
Costs* costs) const {
+ const bool running_simulation = (cluster_->type() == "virtual");
+
std::vector<double> times(measurement_steps_);
BlockingCounter barrier(measurement_steps_);
@@ -80,9 +83,23 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
}
const Costs::MicroSeconds finish = Env::Default()->NowMicros();
- const double time = (finish - start).count() * 1e3;
- times[step] = time;
-
+ if (running_simulation) {
+ // When running simulation, return the estimated runtime, not the time it
+ // takes to run the simulation.
+ double time = 0.0;
+ for (const DeviceStepStats& stepstats :
+ metadata.step_stats().dev_stats()) {
+ for (const NodeExecStats& node_stats : stepstats.node_stats()) {
+ const double completion_time =
+ node_stats.all_end_rel_micros() + node_stats.all_start_micros();
+ time = std::max(time, completion_time * 1e3);
+ }
+ }
+ times[step] = time;
+ } else {
+ const double time = (finish - start).count() * 1e3;
+ times[step] = time;
+ }
if (cost_graph && (step + 1 == measurement_steps_)) {
metadata.mutable_cost_graph()->Swap(cost_graph);
}
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index fdbc61f3f1..0f6307cfdf 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -50,6 +50,12 @@ constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
constexpr char kGatherV2[] = "GatherV2";
constexpr char kSlice[] = "Slice";
+constexpr char kMaxPool[] = "MaxPool";
+constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
+constexpr char kAvgPool[] = "AvgPool";
+constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
+constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
+constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
static const Costs::Duration kMinComputeTime(1);
@@ -71,14 +77,39 @@ Padding GetPadding(const OpInfo& op_features) {
return Padding::SAME; // Default padding.
}
+bool IsTraining(const OpInfo& op_info) {
+ if (op_info.attr().find("is_training") != op_info.attr().end() &&
+ op_info.attr().at("is_training").b()) {
+ return true;
+ }
+ return false;
+}
+
+// TODO(dyoon): support non-4D tensors in the c ost functions of convolution
+// related ops (Conv, Pool, BatchNorm, and their backprops) and the related
+// helper functions.
std::vector<int64> GetStrides(const OpInfo& op_features) {
if (op_features.attr().find("strides") != op_features.attr().end()) {
const auto strides = op_features.attr().at("strides").list().i();
+ CHECK(strides.size() == 4) << "Attr strides is not a length-4 vector: "
+ << op_features.DebugString();
return {strides[0], strides[1], strides[2], strides[3]};
}
return {1, 1, 1, 1};
}
+std::vector<int64> GetKernelSize(const OpInfo& op_info) {
+ if (op_info.attr().find("ksize") != op_info.attr().end()) {
+ const auto ksize = op_info.attr().at("ksize").list().i();
+ CHECK(ksize.size() == 4)
+ << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
+ return {ksize[0], ksize[1], ksize[2], ksize[3]};
+ }
+ // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
+ // {1, 1, 1, 1} in that case.
+ return {1, 1, 1, 1};
+}
+
int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
const Padding& padding) {
// Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
@@ -171,9 +202,12 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ // TODO(76227186): re-enable with output size check & test
+ /*
{kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
{kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
+ */
{kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -190,7 +224,15 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
- {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)}};
+ {kSize, wrap(&OpLevelCostEstimator::PredictMetadata)},
+ {kMaxPool, wrap(&OpLevelCostEstimator::PredictMaxPool)},
+ {kMaxPoolGrad, wrap(&OpLevelCostEstimator::PredictMaxPoolGrad)},
+ {kAvgPool, wrap(&OpLevelCostEstimator::PredictAvgPool)},
+ {kAvgPoolGrad, wrap(&OpLevelCostEstimator::PredictAvgPoolGrad)},
+ {kFusedBatchNorm, wrap(&OpLevelCostEstimator::PredictFusedBatchNorm)},
+ {kFusedBatchNormGrad,
+ wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
+ };
#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
@@ -255,6 +297,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{"QuantizedAdd", EIGEN_COST(scalar_sum_op<float>)},
{"QuantizedMul", EIGEN_COST(scalar_product_op<float>)},
{"RealDiv", EIGEN_COST(scalar_quotient_op<float>)},
+ {"ReluGrad", EIGEN_COST(scalar_max_op<float>)},
{"SquareDifference", 1},
{"Sub", EIGEN_COST(scalar_difference_op<float>)},
{"TruncateDiv", EIGEN_COST(scalar_quotient_op<float>)},
@@ -1041,5 +1084,269 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
return costs;
}
+/* static */
+OpLevelCostEstimator::ConvolutionDimensions
+OpLevelCostEstimator::OpDimensionsFromInputs(
+ const TensorShapeProto& original_image_shape, const OpInfo& op_info,
+ bool* found_unknown_shapes) {
+ VLOG(2) << "op features: " << op_info.DebugString();
+ VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
+ auto image_shape =
+ MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
+ VLOG(2) << "Image shape: " << image_shape.DebugString();
+
+ int x_index, y_index, channel_index;
+ const string& data_format = GetDataFormat(op_info);
+ if (data_format == "NCHW") {
+ x_index = 2;
+ y_index = 3;
+ channel_index = 1;
+ } else {
+ x_index = 1;
+ y_index = 2;
+ channel_index = 3;
+ }
+ int64 batch = image_shape.dim(0).size();
+ int64 ix = image_shape.dim(x_index).size();
+ int64 iy = image_shape.dim(y_index).size();
+ int64 iz = image_shape.dim(channel_index).size();
+
+ // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
+ // {1, 1, 1, 1} in that case.
+ std::vector<int64> ksize = GetKernelSize(op_info);
+ int64 kx = ksize[x_index];
+ int64 ky = ksize[y_index];
+
+ std::vector<int64> strides = GetStrides(op_info);
+ int64 sx = strides[x_index];
+ int64 sy = strides[y_index];
+ const auto padding = GetPadding(op_info);
+
+ int64 ox = GetOutputSize(ix, kx, sx, padding);
+ int64 oy = GetOutputSize(iy, ky, sy, padding);
+ int64 oz = iz;
+
+ OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
+ batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
+ return conv_dims;
+}
+
+Costs OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+ // kx * ky - 1 comparisons per output (kx * xy > 1)
+ // or 1 copy per output (kx * k1 = 1).
+ int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
+ int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
+
+ double total_input_size = 0;
+ if (dims.ky >= dims.sy) {
+ total_input_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ } else { // dims.ky < dims.sy
+ // Vertical stride is larger than vertical kernel; assuming row-major
+ // format, skip unnecessary rows (or read every kx rows per sy rows, as the
+ // others are not used for output).
+ const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
+ total_input_size =
+ data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
+ }
+ const double total_output_size =
+ CalculateOutputSize(op_info, &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictMaxPoolGrad(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ // y: op_info.inputs(1)
+ // y_grad: op_info.inputs(2)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ int64 ops = 0;
+ if (dims.kx == 1 && dims.ky == 1) {
+ // 1x1 window. No need to know which input was max.
+ ops = dims.batch * dims.ix * dims.iy * dims.iz;
+ } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
+ // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
+ ops = dims.batch * dims.iz *
+ (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
+ } else {
+ // Overlapping window: initialize with zeros, re-run maxpool, then
+ // accumulate y_gad to proper x_grad locations.
+ ops = dims.batch * dims.iz *
+ (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
+ }
+
+ // Just read x and y_grad; no need to read y as we assume MaxPoolGrad re-run
+ // MaxPool internally.
+ double total_input_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ total_input_size +=
+ CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
+ // Write x_grad; size equal to x.
+ const double total_output_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ // kx * ky - 1 additions and 1 multiplication per output.
+ int64 ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
+
+ double total_input_size = 0;
+ if (dims.ky >= dims.sy) {
+ total_input_size =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ } else { // dims.ky < dims.sy
+ // vertical stride is larger than vertical kernel; assuming row-major
+ // format, skip unnecessary rows (or read every kx rows per sy rows, as the
+ // others are not used for output).
+ const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
+ total_input_size =
+ data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
+ }
+ const double total_output_size =
+ CalculateOutputSize(op_info, &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictAvgPoolGrad(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ // y_grad: op_info.inputs(1)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+
+ int64 ops = 0;
+ if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
+ // Non-overlapping window.
+ ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
+ } else {
+ // Overlapping window.
+ ops = dims.batch * dims.iz *
+ (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
+ }
+
+ const double total_input_size =
+ CalculateInputSize(op_info, &found_unknown_shapes);
+ const double total_output_size =
+ CalculateOutputSize(op_info, &found_unknown_shapes);
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size, op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictFusedBatchNorm(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // x: op_info.inputs(0)
+ // scale: op_info.inputs(1)
+ // offset: op_info.inputs(2)
+ // mean: op_info.inputs(3) --> only for inference
+ // variance: op_info.inputs(4) --> only for inference
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(0).shape(), op_info, &found_unknown_shapes);
+ const bool is_training = IsTraining(op_info);
+
+ int64 ops = 0;
+ const auto rsqrt_cost = Eigen::internal::functor_traits<
+ Eigen::internal::scalar_rsqrt_op<float>>::Cost;
+ if (is_training) {
+ ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
+ } else {
+ ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
+ }
+
+ const double size_nhwc =
+ CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
+ const double size_c =
+ CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
+ double total_input_size = 0.0;
+ double total_internal_read_size = 0.0;
+ double total_output_size = 0.0;
+ if (is_training) {
+ total_input_size = size_nhwc + size_c * 2;
+ total_output_size = size_nhwc + size_c * 4;
+ total_internal_read_size = size_nhwc;
+ } else {
+ total_input_size = size_nhwc + size_c * 4;
+ total_output_size = size_nhwc;
+ }
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size + total_internal_read_size,
+ op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
+Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
+ const OpContext& op_context) const {
+ bool found_unknown_shapes = false;
+ const auto& op_info = op_context.op_info;
+ // y_backprop: op_info.inputs(0)
+ // x: op_info.inputs(1)
+ // scale: op_info.inputs(2)
+ // mean: op_info.inputs(3)
+ // variance or inverse of variance: op_info.inputs(4)
+ ConvolutionDimensions dims = OpDimensionsFromInputs(
+ op_info.inputs(1).shape(), op_info, &found_unknown_shapes);
+
+ int64 ops = 0;
+ const auto rsqrt_cost = Eigen::internal::functor_traits<
+ Eigen::internal::scalar_rsqrt_op<float>>::Cost;
+ ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
+
+ const double size_nhwc =
+ CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
+ const double size_c =
+ CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
+ double total_input_size = size_nhwc * 2 + size_c * 2;
+ double total_internal_read_size = size_nhwc;
+ double total_output_size = size_nhwc * 1 + size_c * 2;
+
+ Costs costs = PredictOpCountBasedCost(
+ ops, total_input_size + total_output_size + total_internal_read_size,
+ op_info);
+ costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
+ return costs;
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 1b3babb206..fcbecbb6dc 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -145,6 +145,12 @@ class OpLevelCostEstimator {
Costs PredictBatchMatMul(const OpContext& op_context) const;
Costs PredictMetadata(const OpContext& op_context) const;
Costs PredictGatherOrSlice(const OpContext& op_context) const;
+ Costs PredictMaxPool(const OpContext& op_context) const;
+ Costs PredictMaxPoolGrad(const OpContext& op_context) const;
+ Costs PredictAvgPool(const OpContext& op_context) const;
+ Costs PredictAvgPoolGrad(const OpContext& op_context) const;
+ Costs PredictFusedBatchNorm(const OpContext& op_context) const;
+ Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
@@ -156,9 +162,15 @@ class OpLevelCostEstimator {
}
}
+ // For convolution and its grad ops.
static ConvolutionDimensions ConvolutionDimensionsFromInputs(
const TensorShapeProto& original_image_shape,
- const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
+ const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
+ bool* found_unknown_shapes);
+
+ // For Pooling, FusedBatchNorm, and their grad ops.
+ static ConvolutionDimensions OpDimensionsFromInputs(
+ const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
protected:
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index f2a9615dfb..56915ed821 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -169,6 +171,130 @@ OpContext DescribeBiasAdd(int size1, int size2) {
return op_context;
}
+int GetOutputSize(const int x, const int k, const int s,
+ const string& padding) {
+ if (padding == "SAME") {
+ return (x + s - 1) / s;
+ } else {
+ return (x - k + s) / s;
+ }
+}
+
+std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
+ const std::vector<int>& ksize,
+ const std::vector<int>& strides,
+ const string& data_format,
+ const string& padding) {
+ // h, w, and c indices: default with NHWC.
+ int h_index = 1;
+ int w_index = 2;
+ int c_index = 3;
+ if (data_format == "NCHW") {
+ h_index = 2;
+ w_index = 3;
+ c_index = 1;
+ }
+ // Extract parameters.
+ int n = input[0];
+ int h = input[h_index];
+ int w = input[w_index];
+ int c = input[c_index];
+ int sx = strides[h_index];
+ int sy = strides[w_index];
+ int kx = ksize[h_index];
+ int ky = ksize[w_index];
+
+ // Output activation size: default with VALID padding.
+ int ho = GetOutputSize(h, kx, sx, padding);
+ int wo = GetOutputSize(w, ky, sy, padding);
+
+ std::vector<int> output;
+ if (data_format == "NHWC") {
+ output = {n, ho, wo, c};
+ } else {
+ output = {n, c, ho, wo};
+ }
+ return output;
+}
+
+OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
+ const std::vector<int>& ksize,
+ const std::vector<int>& strides,
+ const string& data_format, const string& padding) {
+ OpContext op_context;
+ auto& op_info = op_context.op_info;
+ SetCpuDevice(&op_info);
+ op_info.set_op(op_name);
+
+ const std::vector<int> y =
+ GetPoolingOutputSize(x, ksize, strides, data_format, padding);
+ if (op_name == "AvgPool" || op_name == "MaxPool") {
+ // input: x, output: y.
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
+ } else if (op_name == "AvgPoolGrad") {
+ // input: x, y_grad, output: x_grad.
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
+ } else if (op_name == "MaxPoolGrad") {
+ // input: x, y, y_grad, output: x_grad.
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
+ DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
+ }
+ auto* attr = op_info.mutable_attr();
+ SetAttrValue(data_format, &(*attr)["data_format"]);
+ SetAttrValue(padding, &(*attr)["padding"]);
+ SetAttrValue(strides, &(*attr)["strides"]);
+ SetAttrValue(ksize, &(*attr)["ksize"]);
+ return op_context;
+}
+
+OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
+ const std::vector<int>& x,
+ const string& data_format) {
+ // First, get MaxPool op info with unit stride and unit window.
+ OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
+ {1, 1, 1, 1}, data_format, "SAME");
+ auto& op_info = op_context.op_info;
+ // Override op name.
+ if (is_grad) {
+ op_info.set_op("FusedBatchNormGrad");
+ } else {
+ op_info.set_op("FusedBatchNorm");
+ }
+
+ // Add additional input output tensors.
+ if (is_grad) {
+ DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
+ }
+ int num_1d_inputs = is_grad ? 3 : 4;
+ for (int i = 0; i < num_1d_inputs; i++) {
+ auto* tensor = op_info.add_inputs();
+ auto* shape = tensor->mutable_shape();
+ shape->add_dim()->set_size(x[3]);
+ tensor->set_dtype(DT_FLOAT);
+ }
+ for (int i = 0; i < 4; i++) {
+ auto* tensor = op_info.add_outputs();
+ auto* shape = tensor->mutable_shape();
+ shape->add_dim()->set_size(x[3]);
+ tensor->set_dtype(DT_FLOAT);
+ }
+
+ // Delete unnecessary attr.
+ auto* attr = op_context.op_info.mutable_attr();
+ attr->erase("ksize");
+ attr->erase("strides");
+ attr->erase("padding");
+
+ // Additional attrs for FusedBatchNorm.
+ SetAttrValue(is_training, &(*attr)["is_training"]);
+
+ return op_context;
+}
} // namespace
class OpLevelCostEstimatorTest : public ::testing::Test {
@@ -192,43 +318,90 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
estimator_.compute_memory_overlap_ = value;
}
+ void ValidateOpDimensionsFromImputs(const int n, const int h, const int w,
+ const int c, const int kx, const int ky,
+ const int sx, const int sy,
+ const string& data_format,
+ const string& padding) {
+ OpContext op_context;
+ int ho;
+ int wo;
+ if (data_format == "NHWC") {
+ op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
+ {1, sx, sy, 1}, "NHWC", padding);
+ ho = op_context.op_info.outputs(0).shape().dim(1).size();
+ wo = op_context.op_info.outputs(0).shape().dim(2).size();
+ } else {
+ op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
+ {1, 1, sx, sy}, "NCHW", padding);
+ ho = op_context.op_info.outputs(0).shape().dim(2).size();
+ wo = op_context.op_info.outputs(0).shape().dim(3).size();
+ }
+
+ bool found_unknown_shapes;
+ auto dims = OpLevelCostEstimator::OpDimensionsFromInputs(
+ op_context.op_info.inputs(0).shape(), op_context.op_info,
+ &found_unknown_shapes);
+ Padding padding_enum;
+ if (padding == "VALID") {
+ padding_enum = Padding::VALID;
+ } else {
+ padding_enum = Padding::SAME;
+ }
+ EXPECT_EQ(n, dims.batch);
+ EXPECT_EQ(h, dims.ix);
+ EXPECT_EQ(w, dims.iy);
+ EXPECT_EQ(c, dims.iz);
+ EXPECT_EQ(kx, dims.kx);
+ EXPECT_EQ(ky, dims.ky);
+ EXPECT_EQ(sx, dims.sx);
+ EXPECT_EQ(sy, dims.sy);
+ EXPECT_EQ(ho, dims.ox);
+ EXPECT_EQ(wo, dims.oy);
+ EXPECT_EQ(c, dims.oz);
+ EXPECT_EQ(padding_enum, dims.padding);
+ }
+
OpLevelCostEstimator estimator_;
};
+// TODO(76227186): re-enable with output size check & test
+/*
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
- OpContext op_context;
- SetCpuDevice(&op_context.op_info);
- op_context.op_info.set_op("Gather");
-
- // Huge first input shouldn't affect Gather execution and memory costs.
- DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
- DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
- DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
-
- auto cost = estimator_.PredictCosts(op_context);
- EXPECT_EQ(Costs::Duration(130), cost.memory_time);
- EXPECT_EQ(Costs::Duration(16), cost.compute_time);
- EXPECT_EQ(Costs::Duration(146), cost.execution_time);
- EXPECT_FALSE(cost.inaccurate);
+OpContext op_context;
+SetCpuDevice(&op_context.op_info);
+op_context.op_info.set_op("Gather");
+
+// Huge first input shouldn't affect Gather execution and memory costs.
+DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
+DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
+
+auto cost = estimator_.PredictCosts(op_context);
+EXPECT_EQ(Costs::Duration(130), cost.memory_time);
+EXPECT_EQ(Costs::Duration(16), cost.compute_time);
+EXPECT_EQ(Costs::Duration(146), cost.execution_time);
+EXPECT_FALSE(cost.inaccurate);
}
TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
- OpContext op_context;
- SetCpuDevice(&op_context.op_info);
- op_context.op_info.set_op("Slice");
-
- // Huge first input shouldn't affect Slice execution and memory costs.
- DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
- DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
- DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
- DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
-
- auto cost = estimator_.PredictCosts(op_context);
- EXPECT_EQ(Costs::Duration(81), cost.memory_time);
- EXPECT_EQ(Costs::Duration(10), cost.compute_time);
- EXPECT_EQ(Costs::Duration(91), cost.execution_time);
- EXPECT_FALSE(cost.inaccurate);
+OpContext op_context;
+SetCpuDevice(&op_context.op_info);
+op_context.op_info.set_op("Slice");
+
+// Huge first input shouldn't affect Slice execution and memory costs.
+DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
+
+auto cost = estimator_.PredictCosts(op_context);
+EXPECT_EQ(Costs::Duration(81), cost.memory_time);
+EXPECT_EQ(Costs::Duration(10), cost.compute_time);
+EXPECT_EQ(Costs::Duration(91), cost.execution_time);
+EXPECT_FALSE(cost.inaccurate);
}
+*/
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
auto cost = PredictCosts(DescribeBiasAdd(1000, 10));
@@ -440,5 +613,226 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
}
}
+TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
+ std::vector<string> paddings = {"VALID", "SAME"};
+ std::vector<string> formats = {"NHWC", "NCHW"};
+ for (const auto& p : paddings) {
+ for (const auto& f : formats) {
+ // n, h, w, c, kx, ky, sx, sy, data_format, padding.
+ ValidateOpDimensionsFromImputs(10, 20, 20, 100, 3, 3, 2, 2, f, p);
+ ValidateOpDimensionsFromImputs(10, 20, 20, 100, 1, 1, 3, 3, f, p);
+ ValidateOpDimensionsFromImputs(10, 200, 200, 100, 5, 5, 3, 3, f, p);
+ ValidateOpDimensionsFromImputs(10, 14, 14, 3840, 3, 3, 2, 2, f, p);
+ }
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
+ auto predict_max_pool = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context = DescribePoolingOp(
+ "MaxPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ // Typical 3xz3 window with 2x2 stride.
+ auto costs = predict_max_pool(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1075200), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(307200), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_max_pool(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_max_pool(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(561792), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(56448), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
+ auto predict_max_pool_grad = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context =
+ DescribePoolingOp("MaxPoolGrad", {n, in, in, c}, {1, k, k, 1},
+ {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ // Typical 3xz3 window with 2x2 stride.
+ auto costs = predict_max_pool_grad(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1996800), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(614400), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_max_pool_grad(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1536000), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(153600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_max_pool_grad(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(1514112), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(210048), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
+ auto predict_avg_pool = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context = DescribePoolingOp(
+ "AvgPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ // Typical 3xz3 window with 2x2 stride.
+ auto costs = predict_avg_pool(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1113600), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(345600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_avg_pool(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_avg_pool(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(580608), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(75264), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
+ auto predict_avg_pool_grad = [this](const int n, const int in, const int c,
+ const int k, const int s,
+ const string& padding) -> Costs {
+ OpContext op_context =
+ DescribePoolingOp("AvgPoolGrad", {n, in, in, c}, {1, k, k, 1},
+ {1, s, s, 1}, "NHWC", padding);
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ // Typical 3xz3 window with 2x2 stride.
+ auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1920000), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
+ auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME");
+ EXPECT_EQ(Costs::Duration(1574400), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+ {
+ // 2x2 window with 3x3 stride.
+ auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID");
+ EXPECT_EQ(Costs::Duration(1476480), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
+ auto predict_fused_bn = [this](const int n, const int in, const int c,
+ const bool is_training) -> Costs {
+ OpContext op_context = DescribeFusedBatchNorm(
+ is_training, /*is_grad=*/false, {n, in, in, c}, "NHWC");
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/true);
+ EXPECT_EQ(Costs::Duration(614737), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(153706), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(461031), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/true);
+ EXPECT_EQ(Costs::Duration(204913), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(51236), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(153677), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/false);
+ EXPECT_EQ(Costs::Duration(384154), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(76800), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(307354), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/false);
+ EXPECT_EQ(Costs::Duration(128052), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(25600), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(102452), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
+
+TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
+ auto predict_fused_bn_grad = [this](const int n, const int in,
+ const int c) -> Costs {
+ OpContext op_context = DescribeFusedBatchNorm(
+ /*is_training=*/false, /*is_grad=*/true, {n, in, in, c}, "NHWC");
+ return estimator_.PredictCosts(op_context);
+ };
+
+ {
+ auto costs = predict_fused_bn_grad(10, 20, 96);
+ EXPECT_EQ(Costs::Duration(1037050), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(422496), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(614554), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+
+ {
+ auto costs = predict_fused_bn_grad(128, 7, 384);
+ EXPECT_EQ(Costs::Duration(6503809), costs.execution_time);
+ EXPECT_EQ(Costs::Duration(2649677), costs.compute_time);
+ EXPECT_EQ(Costs::Duration(3854132), costs.memory_time);
+ EXPECT_FALSE(costs.inaccurate);
+ }
+}
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 076945d5c6..f318e3911c 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -212,8 +212,8 @@ DeviceProperties GetDeviceInfo(const string& device_str) {
CudaGpuId cuda_gpu_id;
Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
if (!s.ok()) {
- LOG(ERROR) << s;
- return unknown;
+ // We are probably running simulation without linking cuda libraries.
+ cuda_gpu_id = CudaGpuId(parsed.id);
}
return GetLocalGPUInfo(cuda_gpu_id);
} else if (parsed.type == "CPU") {
diff --git a/tensorflow/core/grappler/inputs/BUILD b/tensorflow/core/grappler/inputs/BUILD
index b683216590..ffa204028c 100644
--- a/tensorflow/core/grappler/inputs/BUILD
+++ b/tensorflow/core/grappler/inputs/BUILD
@@ -2,18 +2,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "utils",
srcs = [
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 1a6751befc..c31ac9b59c 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -309,6 +309,8 @@ bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
+bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
+
bool IsSquaredDifference(const NodeDef& node) {
return node.op() == "SquaredDifference";
}
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 1ec1cd46e3..39affcbc24 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -121,6 +121,7 @@ bool IsSoftsignGrad(const NodeDef& node);
bool IsSplit(const NodeDef& node);
bool IsSplitV(const NodeDef& node);
bool IsSqrtGrad(const NodeDef& node);
+bool IsSquare(const NodeDef& node);
bool IsSquaredDifference(const NodeDef& node);
bool IsSqueeze(const NodeDef& node);
bool IsStackOp(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 601984fcfd..2c365c467c 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -12,18 +12,6 @@ load(
"tf_protos_grappler",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "static_schedule",
srcs = ["static_schedule.cc"],
@@ -373,6 +361,7 @@ tf_kernel_library(
srcs = [
"gpu_swapping_kernels.cc",
],
+ visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
@@ -385,6 +374,7 @@ cc_library(
srcs = [
"gpu_swapping_ops.cc",
],
+ visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
@@ -418,10 +408,7 @@ cc_library(
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/utils:traversal",
- ] + if_cuda([
- ":gpu_swapping_kernels",
- ":gpu_swapping_ops",
- ]),
+ ],
)
tf_cuda_only_cc_test(
@@ -429,6 +416,8 @@ tf_cuda_only_cc_test(
srcs = ["memory_optimizer_test.cc"],
tags = ["no_cuda_on_cpu_tap"], # Do not re-enable again without actually testing.
deps = [
+ ":gpu_swapping_kernels",
+ ":gpu_swapping_ops",
":memory_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:ops",
@@ -630,7 +619,10 @@ cc_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:graph_optimizer",
],
@@ -638,7 +630,6 @@ cc_library(
tf_cuda_cc_test(
name = "debug_stripper_test",
- size = "small",
srcs = ["debug_stripper_test.cc"],
deps = [
":debug_stripper",
@@ -646,6 +637,7 @@ tf_cuda_cc_test(
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:grappler_test",
],
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 23e21855c8..5dd0b6f4b0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1089,7 +1089,8 @@ namespace {
bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) {
const std::unordered_set<string> op_types_to_traverse = {
- node.op(), "Identity", "IdentityN", "Reshape"};
+ node.op(), "Identity", "IdentityN", "Reshape",
+ "ExpandDims", "Enter", "Switch", "Merge"};
int node_idx = graph_view.index(node.name());
std::set<int> node_fanout;
graph_view.DepthFirstSearch(op_types_to_traverse, node_idx, &node_fanout);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 792f675043..ad3edc144a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -158,7 +158,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
ArithmeticOptimizer optimizer;
GraphDef output;
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
EXPECT_EQ(1, tensors_expected.size());
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -176,7 +176,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
EXPECT_EQ("c1", new_div.input(0));
EXPECT_EQ("c1", new_div.input(1));
- auto tensors = EvaluateNodes(output, item.fetch);
+ auto tensors = EvaluateNodes(output, item.fetch, {});
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index bdec73e69e..c3f8a1ce22 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -109,33 +109,18 @@ class DeviceSimple : public DeviceBase {
};
template <typename T>
-bool AllValuesAre(const TensorProto& tensor, const T& value) {
- // TensorProto represents the content of the tensor in either <type>_val or
- // tensor_content.
- typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
- checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));
- if (!tensor_values->empty()) {
- for (const T& tensor_value : *tensor_values) {
- if (tensor_value != value) {
- return false;
- }
- }
- return true;
+bool AllValuesAre(const TensorProto& proto, const T& value) {
+ Tensor tensor;
+ if (!tensor.FromProto(proto)) {
+ return false;
}
- const auto tensor_content_size = tensor.tensor_content().size();
- if (tensor_content_size > 0) {
- CHECK_EQ(0, tensor_content_size % sizeof(T));
- std::vector<T> raw_values(tensor_content_size / sizeof(T));
- port::CopyToArray(tensor.tensor_content(),
- reinterpret_cast<char*>(raw_values.data()));
- for (int i = 0; i < tensor_content_size / sizeof(T); ++i) {
- if (raw_values[i] != value) {
- return false;
- }
+ auto values = tensor.flat<T>();
+ for (int i = 0; i < tensor.NumElements(); ++i) {
+ if (values(i) != value) {
+ return false;
}
- return true;
}
- return false;
+ return true;
}
// Add new_input as a control input to node if it does not already depend on it.
@@ -825,17 +810,23 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
t->set_dtype(type);
*t->mutable_tensor_shape() = shape;
switch (type) {
- SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
- SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
- SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
- SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
- SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
- SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
- SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
- SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
- SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
- SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
- SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
+ case DT_HALF:
+ t->add_half_val(static_cast<Eigen::half>(value).x);
+ break;
+ case DT_BFLOAT16:
+ t->add_half_val(static_cast<bfloat16>(value).value);
+ break;
+ SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
+ SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
+ SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
+ SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
+ SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
default:
return errors::InvalidArgument("Unsupported type: ", type);
}
@@ -1388,8 +1379,8 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
}
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
- // TODO(rmlarsen): Make DT_HALF case compile.
- // IS_ONES_CASE(DT_HALF);
+ IS_ONES_CASE(DT_HALF);
+ IS_ONES_CASE(DT_BFLOAT16);
IS_ONES_CASE(DT_FLOAT);
IS_ONES_CASE(DT_DOUBLE);
IS_ONES_CASE(DT_COMPLEX64);
@@ -1423,8 +1414,8 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
}
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
- // TODO(rmlarsen): Make DT_HALF case compile.
- // IS_ZEROS_CASE(DT_HALF);
+ IS_ZEROS_CASE(DT_HALF);
+ IS_ZEROS_CASE(DT_BFLOAT16);
IS_ZEROS_CASE(DT_FLOAT);
IS_ZEROS_CASE(DT_DOUBLE);
IS_ZEROS_CASE(DT_COMPLEX64);
@@ -1511,9 +1502,8 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
}
Status ConstantFolding::ReplaceOperationWithConstant(
- double value, const TensorShapeProto& shape, NodeDef* node,
- GraphDef* graph) {
- AttrValue dtype_attr = node->attr().at("T");
+ double value, const AttrValue& dtype_attr, const TensorShapeProto& shape,
+ NodeDef* node, GraphDef* graph) {
AttrValue tensor_attr;
TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
shape, &tensor_attr));
@@ -1544,6 +1534,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
// Remove Shuffle or Reverse op over scalar values.
if (use_shape_info &&
+ !properties->GetInputProperties(node->name()).empty() &&
(IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
const auto& shape =
properties->GetInputProperties(node->name())[0].shape();
@@ -1947,8 +1938,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined()) {
- TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(0, output_shape, node,
- optimized_graph));
+ AttrValue dtype_attr;
+ if (node->op() == "SparseMatMul") {
+ dtype_attr.set_type(DT_FLOAT);
+ } else {
+ dtype_attr = node->attr().at("T");
+ }
+ TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+ 0, dtype_attr, output_shape, node, optimized_graph));
continue;
}
// Even if an input shape is only partially known, we may known that it
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index b6645d335e..f8a9e90d62 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -83,7 +83,7 @@ class ConstantFolding : public GraphOptimizer {
void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
GraphDef* graph);
void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
- Status ReplaceOperationWithConstant(double value,
+ Status ReplaceOperationWithConstant(double value, const AttrValue& dtype_attr,
const TensorShapeProto& shape,
NodeDef* node, GraphDef* graph);
void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 6340565bcd..e0ff9b17b1 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -28,7 +28,59 @@ namespace tensorflow {
namespace grappler {
namespace {
-class ConstantFoldingTest : public GrapplerTest {};
+class ConstantFoldingTest : public GrapplerTest {
+ protected:
+ template <DataType DTYPE>
+ void SimpleNeutralElementTest() {
+ typedef typename EnumToDataType<DTYPE>::Type T;
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Tensor zeros_t(DTYPE, TensorShape({2, 2}));
+ Tensor ones_t(DTYPE, TensorShape({2, 2}));
+ Tensor x_t(DTYPE, TensorShape({2, 2}));
+ for (int i = 0; i < 4; ++i) {
+ zeros_t.flat<T>()(i) = T(0);
+ ones_t.flat<T>()(i) = T(1);
+ x_t.flat<T>()(i) = T(i + 1);
+ }
+ Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
+ Output ones = ops::Const(s.WithOpName("ones"), ones_t);
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"mul1", "mul2"};
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+ EXPECT_EQ(5, output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ const string& name = node.name();
+ if (name == "mul1") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "mul2") {
+ EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ }
+ }
+ auto tensors_expected =
+ EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}});
+ auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}});
+ EXPECT_EQ(2, tensors_expected.size());
+ EXPECT_EQ(2, tensors.size());
+ for (int i = 0; i < 2; ++i) {
+ test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+ }
+ }
+};
TEST_F(ConstantFoldingTest, SimpleFolding) {
// Build a simple graph with a few trivially prunable ops.
@@ -322,6 +374,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
}
}
+TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
+ SimpleNeutralElementTest<DT_HALF>();
+ SimpleNeutralElementTest<DT_BFLOAT16>();
+}
+
TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
@@ -614,7 +671,8 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
GrapplerItem item;
item.fetch.push_back("e");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
@@ -641,6 +699,9 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
}
}
EXPECT_EQ(1, found);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
@@ -714,7 +775,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
GrapplerItem item;
item.fetch.push_back("i2");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
@@ -733,6 +795,9 @@ TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
EXPECT_EQ("^p2", node.input(1));
}
}
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
@@ -1742,6 +1807,12 @@ TEST_F(ConstantFoldingTest, LargeConstant) {
EXPECT_EQ(2, found);
EXPECT_GT(1024 * 1024, output.ByteSizeLong());
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) {
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 461f1aa2fb..0e058e3435 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -14,16 +14,33 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
- // TODO(haoliang): Let's remove assertions here.
*output = item.graph;
+ for (NodeDef& node : *output->mutable_node()) {
+ if (IsAssert(node)) {
+ // Convert this node into a no-op.
+ node.set_op("NoOp");
+ node.clear_attr();
+ // Convert all its inputs into control dependency, which will then
+ // be optimized away by dependency optimizer.
+ for (string& inp : *node.mutable_input()) {
+ if (!IsControlInput(inp)) {
+ inp = AsControlDependency(inp);
+ }
+ }
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index d2cabc0798..aacd55f136 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,16 +28,78 @@ namespace {
class DebugStripperTest : public GrapplerTest {};
-// TODO(haoliang): Add tests for different removal operations.
TEST_F(DebugStripperTest, OutputEqualToInput) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto c = ops::Const(s.WithOpName("c"), 0, {});
+ constexpr char device[] = "/device:CPU:0";
GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.graph = test::function::GDef(
+ {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}},
+ device),
+ test::function::NDef("y", "XTimesTwo", {"x"}, {}, device),
+ test::function::NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, device)},
+ {});
DebugStripper optimizer;
GraphDef output;
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ CompareGraphs(item.graph, output);
+}
+
+TEST_F(DebugStripperTest, StripAssertFromGraph) {
+ constexpr char device[] = "/device:CPU:0";
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {test::function::NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}},
+ device),
+ test::function::NDef("y", "Placeholder", {}, {{"dtype", DT_FLOAT}},
+ device),
+ test::function::NDef("GreaterEqual", "GreaterEqual", {"x", "y"},
+ {{"T", DT_FLOAT}}, device),
+ test::function::NDef("Assert", "Assert", {"GreaterEqual"},
+ {{"T", DT_FLOAT}}, device),
+ test::function::NDef("z", "Add", {"x", "y", "^Assert"}, {}, device)},
+ {});
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(device, node.device());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "y") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(device, node.device());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "GreaterEqual") {
+ count++;
+ EXPECT_EQ("GreaterEqual", node.op());
+ EXPECT_EQ(device, node.device());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
+ } else if (node.name() == "Assert") {
+ count++;
+ EXPECT_EQ("NoOp", node.op());
+ EXPECT_EQ(device, node.device());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("^GreaterEqual", node.input(0));
+ EXPECT_EQ(0, node.attr_size());
+ } else if (node.name() == "z") {
+ count++;
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ(device, node.device());
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
+ EXPECT_EQ("^Assert", node.input(2));
+ }
+ }
+ EXPECT_EQ(5, count);
}
} // namespace
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 52a1118080..deb2fabded 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -414,8 +414,9 @@ TEST_F(FunctionOptimizerTest, SymbolicGradients) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out1", "out2"});
- std::vector<Tensor> optimized = EvaluateNodes(output, {"out1", "out2"});
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"out1", "out2"}, {});
+ std::vector<Tensor> optimized = EvaluateNodes(output, {"out1", "out2"}, {});
test::ExpectTensorEqual<float>(expected[0], optimized[0]);
test::ExpectTensorEqual<float>(expected[1], optimized[1]);
}
@@ -478,8 +479,8 @@ TEST_F(FunctionOptimizerTest, SymbolicGradientsIdentity) {
EXPECT_EQ("Identity", output.node(i).op());
}
- std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out"});
- std::vector<Tensor> optimized = EvaluateNodes(output, {"out"});
+ std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out"}, {});
+ std::vector<Tensor> optimized = EvaluateNodes(output, {"out"}, {});
test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index 9595936e9e..a1f80802dd 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -426,7 +426,7 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) {
EXPECT_EQ(4, count);
std::vector<string> fetch = {"a", "b", "c", "e"};
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(4, tensors.size());
for (int i = 0; i < tensors[0].NumElements(); ++i) {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 47ec16226b..ad655db727 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -102,6 +102,10 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new FunctionOptimizer(cfg_.function_optimization())));
}
+ if (cfg_.debug_stripper() == RewriterConfig::ON) {
+ optimizers.push_back(
+ std::unique_ptr<GraphOptimizer>(new DebugStripper()));
+ }
if (cfg_.constant_folding() != RewriterConfig::OFF) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new ConstantFolding(cfg_.constant_folding(), cpu_device_)));
@@ -138,10 +142,6 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new AutoParallel(cfg_.auto_parallel().num_replicas())));
}
- if (cfg_.debug_stripper() == RewriterConfig::ON) {
- optimizers.push_back(
- std::unique_ptr<GraphOptimizer>(new DebugStripper()));
- }
} else {
const std::set<string> available_optimizers = {
"pruning", "function", "constfold", "layout",
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 829bfe9e31..86a6d5000d 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -33,8 +33,8 @@ namespace {
template <typename T>
bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
using RealType = typename Eigen::NumTraits<T>::Real;
- if (value > std::numeric_limits<RealType>::max() ||
- value < std::numeric_limits<RealType>::min()) {
+ if (value > static_cast<double>(std::numeric_limits<RealType>::max()) ||
+ value < static_cast<double>(std::numeric_limits<RealType>::min())) {
return false;
}
tensor->flat<T>()(0) = static_cast<T>(value);
@@ -473,8 +473,8 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
"Expected scalar tensor, got num_elements = ", tensor->NumElements());
}
switch (dtype) {
- // TODO(rmlarsen): Handle DT_HALF.
- // HANDLE_CASE(DT_HALF);
+ HANDLE_CASE(DT_HALF);
+ HANDLE_CASE(DT_BFLOAT16);
HANDLE_CASE(DT_BOOL);
HANDLE_CASE(DT_FLOAT);
HANDLE_CASE(DT_DOUBLE);
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 939031c44b..baf24c2505 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -2,18 +2,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "scc",
srcs = ["scc.cc"],
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index ee126f4955..910b0acaef 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -41,11 +41,17 @@ GrapplerTest::GrapplerTest() {
std::vector<Tensor> GrapplerTest::EvaluateNodes(
const GraphDef& graph, const std::vector<string>& node_names) const {
+ return EvaluateNodes(graph, node_names, {});
+}
+
+std::vector<Tensor> GrapplerTest::EvaluateNodes(
+ const GraphDef& graph, const std::vector<string>& node_names,
+ const std::vector<std::pair<string, Tensor>>& inputs) const {
std::unique_ptr<tensorflow::Session> session(NewSession(options_));
TF_CHECK_OK(session->Create(graph));
RunOptions run_options;
std::vector<Tensor> output_tensors;
- TF_CHECK_OK(session->Run(run_options, {}, node_names, node_names,
+ TF_CHECK_OK(session->Run(run_options, inputs, node_names, node_names,
&output_tensors, nullptr));
TF_CHECK_OK(session->Close());
return output_tensors;
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h
index e0c67381a4..3bc7bea454 100644
--- a/tensorflow/core/grappler/utils/grappler_test.h
+++ b/tensorflow/core/grappler/utils/grappler_test.h
@@ -37,6 +37,10 @@ class GrapplerTest : public ::testing::Test {
std::vector<Tensor> EvaluateNodes(
const GraphDef& graph, const std::vector<string>& node_names) const;
+ std::vector<Tensor> EvaluateNodes(
+ const GraphDef& graph, const std::vector<string>& node_names,
+ const std::vector<std::pair<string, Tensor>>& inputs) const;
+
std::vector<Tensor> EvaluateFetchNodes(const GrapplerItem& item) const;
NodeDef* AddNode(const string& name, const string& op,
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index b469c01881..d2a2cdd13d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6096,6 +6096,13 @@ cc_library(
],
)
+tf_kernel_library(
+ name = "boosted_trees_ops",
+ deps = [
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_ops",
+ ],
+)
+
cc_library(
name = "captured_function",
hdrs = ["captured_function.h"],
@@ -6147,18 +6154,6 @@ tf_kernel_library(
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
# Library to link with when compiling the cwise_op kernels directly,
# e.g. for selective registration.
# should not be linked by projects that also link the cwise_op library.
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 4397410a5c..de05c647d6 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "**/google_*",
- ],
- ),
-)
-
cc_library(
name = "periodic_function_dynamic",
srcs = ["periodic_function.cc"],
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
new file mode 100644
index 0000000000..62327dfe1d
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -0,0 +1,89 @@
+# Description:
+# OpKernels for boosted trees ops.
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library",
+)
+
+tf_proto_library(
+ name = "boosted_trees_proto",
+ srcs = ["boosted_trees.proto"],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+tf_kernel_library(
+ name = "prediction_ops",
+ srcs = ["prediction_ops.cc"],
+ deps = [
+ ":resource_ops",
+ ":resources",
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "resources",
+ srcs = ["resources.cc"],
+ hdrs = ["resources.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "resource_ops",
+ srcs = ["resource_ops.cc"],
+ deps = [
+ ":resources",
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "stats_ops",
+ srcs = ["stats_ops.cc"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_kernel_library(
+ name = "training_ops",
+ srcs = ["training_ops.cc"],
+ deps = [
+ ":resources",
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "boosted_trees_ops",
+ deps = [
+ ":prediction_ops",
+ ":resource_ops",
+ ":stats_ops",
+ ":training_ops",
+ ],
+)
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
new file mode 100644
index 0000000000..106ceedc00
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -0,0 +1,113 @@
+syntax = "proto3";
+
+package tensorflow.boosted_trees;
+option cc_enable_arenas = true;
+option java_outer_classname = "BoostedTreesProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Node describes a node in a tree.
+message Node {
+ oneof node {
+ Leaf leaf = 1;
+ BucketizedSplit bucketized_split = 2;
+ }
+ NodeMetadata metadata = 777;
+}
+
+// NodeMetadata encodes metadata associated with each node in a tree.
+message NodeMetadata {
+ // The gain associated with this node.
+ float gain = 1;
+
+ // The original leaf node before this node was split.
+ Leaf original_leaf = 2;
+}
+
+// Leaves can either hold dense or sparse information.
+message Leaf {
+ oneof leaf {
+ // See third_party/tensorflow/contrib/decision_trees/
+ // proto/generic_tree_model.proto
+ // for a description of how vector and sparse_vector might be used.
+ Vector vector = 1;
+ SparseVector sparse_vector = 2;
+ }
+ float scalar = 3;
+}
+
+message Vector {
+ repeated float value = 1;
+}
+
+message SparseVector {
+ repeated int32 index = 1;
+ repeated float value = 2;
+}
+
+message BucketizedSplit {
+ // Float feature column and split threshold describing
+ // the rule feature <= threshold.
+ int32 feature_id = 1;
+ int32 threshold = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
+// Tree describes a list of connected nodes.
+// Node 0 must be the root and can carry any payload including a leaf
+// in the case of representing the bias.
+// Note that each node id is implicitly its index in the list of nodes.
+message Tree {
+ repeated Node nodes = 1;
+}
+
+message TreeMetadata {
+ // Number of layers grown for this tree.
+ int32 num_layers_grown = 2;
+
+ // Whether the tree is finalized in that no more layers can be grown.
+ bool is_finalized = 3;
+
+ // If tree was finalized and post pruning happened, it is possible that cache
+ // still refers to some nodes that were deleted or that the node ids changed
+ // (e.g. node id 5 became node id 2 due to pruning of the other branch).
+ // The mapping below allows us to understand where the old ids now map to and
+ // how the values should be adjusted due to post-pruning.
+ // The size of the list should be equal to the number of nodes in the tree
+ // before post-pruning happened.
+ // If the node was pruned, it will have new_node_id equal to the id of a node
+ // that this node was collapsed into. For a node that didn't get pruned, it is
+ // possible that its id still changed, so new_node_id will have the
+ // corresponding id in the pruned tree.
+ // If post-pruning didn't happen, or it did and it had no effect (e.g. no
+ // nodes got pruned), this list will be empty.
+ repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;
+
+ message PostPruneNodeUpdate {
+ int32 new_node_id = 1;
+ float logit_change = 2;
+ }
+}
+
+message GrowingMetadata {
+ // Number of trees that we have attempted to build. After pruning, these
+ // trees might have been removed.
+ int64 num_trees_attempted = 1;
+ // Number of layers that we have attempted to build. After pruning, these
+ // layers might have been removed.
+ int64 num_layers_attempted = 2;
+}
+
+// TreeEnsemble describes an ensemble of decision trees.
+message TreeEnsemble {
+ repeated Tree trees = 1;
+ repeated float tree_weights = 2;
+
+ repeated TreeMetadata tree_metadata = 3;
+ // Metadata that is used during the training.
+ GrowingMetadata growing_metadata = 4;
+}
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
new file mode 100644
index 0000000000..b13a450546
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -0,0 +1,263 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+// The Op used during training time to get the predictions so far with the
+// current ensemble being built.
+// Expect some logits are cached from the previous step and passed through
+// to be reused.
+class BoostedTreesTrainingPredictOp : public OpKernel {
+ public:
+ explicit BoostedTreesTrainingPredictOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
+ &num_bucketized_features_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("logits_dimension", &logits_dimension_));
+ OP_REQUIRES(context, logits_dimension_ == 1,
+ errors::InvalidArgument(
+ "Currently only one dimensional outputs are supported."));
+ OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ BoostedTreesEnsembleResource* resource;
+ // Get the resource.
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &resource));
+ // Release the reference to the resource once we're done using it.
+ core::ScopedUnref unref_me(resource);
+
+ // Get the inputs.
+ OpInputList bucketized_features_list;
+ OP_REQUIRES_OK(context, context->input_list("bucketized_features",
+ &bucketized_features_list));
+ std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
+ batch_bucketized_features.reserve(bucketized_features_list.size());
+ for (const Tensor& tensor : bucketized_features_list) {
+ batch_bucketized_features.emplace_back(tensor.vec<int32>());
+ }
+ const int batch_size = batch_bucketized_features[0].size();
+
+ const Tensor* cached_tree_ids_t;
+ OP_REQUIRES_OK(context,
+ context->input("cached_tree_ids", &cached_tree_ids_t));
+ const auto cached_tree_ids = cached_tree_ids_t->vec<int32>();
+
+ const Tensor* cached_node_ids_t;
+ OP_REQUIRES_OK(context,
+ context->input("cached_node_ids", &cached_node_ids_t));
+ const auto cached_node_ids = cached_node_ids_t->vec<int32>();
+
+ // Allocate outputs.
+ Tensor* output_partial_logits_t = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("partial_logits",
+ {batch_size, logits_dimension_},
+ &output_partial_logits_t));
+ auto output_partial_logits = output_partial_logits_t->matrix<float>();
+
+ Tensor* output_tree_ids_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("tree_ids", {batch_size},
+ &output_tree_ids_t));
+ auto output_tree_ids = output_tree_ids_t->vec<int32>();
+
+ Tensor* output_node_ids_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("node_ids", {batch_size},
+ &output_node_ids_t));
+ auto output_node_ids = output_node_ids_t->vec<int32>();
+
+ // Indicate that the latest tree was used.
+ const int32 latest_tree = resource->num_trees() - 1;
+
+ if (latest_tree < 0) {
+ // Ensemble was empty. Nothing changes.
+ output_node_ids = cached_node_ids;
+ output_tree_ids = cached_tree_ids;
+ // All the predictions are zeros.
+ output_partial_logits.setZero();
+ } else {
+ output_tree_ids.setConstant(latest_tree);
+ auto do_work = [&resource, &batch_bucketized_features, &cached_tree_ids,
+ &cached_node_ids, &output_partial_logits,
+ &output_node_ids, batch_size,
+ latest_tree](int32 start, int32 end) {
+ for (int32 i = start; i < end; ++i) {
+ int32 tree_id = cached_tree_ids(i);
+ int32 node_id = cached_node_ids(i);
+ float partial_tree_logit = 0.0;
+
+ // If the tree was pruned, returns the node id into which the
+ // current_node_id was pruned, as well the correction of the cached
+ // logit prediction.
+ resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
+ &partial_tree_logit);
+
+ // Logic in the loop adds the cached node value again if it is a leaf.
+ // If it is not a leaf anymore we need to subtract the old node's
+ // value. The following logic handles both of these cases.
+ partial_tree_logit -= resource->node_value(tree_id, node_id);
+ float partial_all_logit = 0.0;
+ while (true) {
+ if (resource->is_leaf(tree_id, node_id)) {
+ partial_tree_logit += resource->node_value(tree_id, node_id);
+
+ // Tree is done
+ partial_all_logit +=
+ resource->GetTreeWeight(tree_id) * partial_tree_logit;
+ partial_tree_logit = 0.0;
+ // Stop if it was the latest tree.
+ if (tree_id == latest_tree) {
+ break;
+ }
+ // Move onto other trees.
+ ++tree_id;
+ node_id = 0;
+ } else {
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ }
+ }
+ output_node_ids(i) = node_id;
+ output_partial_logits(i, 0) = partial_all_logit;
+ }
+ };
+ // Assume we will not go over more than one full tree. 4 is a magic
+ // number.
+ const int64 cost = 4 * max_depth_;
+ thread::ThreadPool* const worker_threads =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ Shard(worker_threads->NumThreads(), worker_threads, batch_size,
+ /*cost_per_unit=*/cost, do_work);
+ }
+ }
+
+ private:
+ int32 logits_dimension_; // the size of the output prediction vector.
+ int32 num_bucketized_features_; // Indicates the number of features.
+ int32 max_depth_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU),
+ BoostedTreesTrainingPredictOp);
+
+// The Op to get the predictions at the evaluation/inference time.
+class BoostedTreesPredictOp : public OpKernel {
+ public:
+ explicit BoostedTreesPredictOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
+ &num_bucketized_features_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("logits_dimension", &logits_dimension_));
+ OP_REQUIRES(context, logits_dimension_ == 1,
+ errors::InvalidArgument(
+ "Currently only one dimensional outputs are supported."));
+ OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ BoostedTreesEnsembleResource* resource;
+ // Get the resource.
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &resource));
+ // Release the reference to the resource once we're done using it.
+ core::ScopedUnref unref_me(resource);
+
+ // Get the inputs.
+ OpInputList bucketized_features_list;
+ OP_REQUIRES_OK(context, context->input_list("bucketized_features",
+ &bucketized_features_list));
+ std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
+ batch_bucketized_features.reserve(bucketized_features_list.size());
+ for (const Tensor& tensor : bucketized_features_list) {
+ batch_bucketized_features.emplace_back(tensor.vec<int32>());
+ }
+ const int batch_size = batch_bucketized_features[0].size();
+
+ // Allocate outputs.
+ Tensor* output_logits_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "logits", {batch_size, logits_dimension_},
+ &output_logits_t));
+ auto output_logits = output_logits_t->matrix<float>();
+
+ const int32 latest_tree = resource->num_trees() - 1;
+
+ auto do_work = [&resource, &batch_bucketized_features, &output_logits,
+ batch_size, latest_tree](int32 start, int32 end) {
+ for (int32 i = start; i < end; ++i) {
+ float tree_logit = 0.0;
+ int32 tree_id = 0;
+ int32 node_id = 0;
+ while (true) {
+ if (resource->is_leaf(tree_id, node_id)) {
+ tree_logit += resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+
+ // Stop if it was the latest tree.
+ if (tree_id == latest_tree) {
+ break;
+ }
+ // Move onto other trees.
+ ++tree_id;
+ node_id = 0;
+ } else {
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ }
+ }
+ output_logits(i, 0) = tree_logit;
+ }
+ };
+ const int64 cost = (latest_tree + 1) * max_depth_;
+ thread::ThreadPool* const worker_threads =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ Shard(worker_threads->NumThreads(), worker_threads, batch_size,
+ /*cost_per_unit=*/cost, do_work);
+ }
+
+ private:
+ int32
+ logits_dimension_; // Indicates the size of the output prediction vector.
+ int32 num_bucketized_features_; // Indicates the number of features.
+ int32 max_depth_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
+ BoostedTreesPredictOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resource_ops.cc b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
new file mode 100644
index 0000000000..f49242d856
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+
+namespace tensorflow {
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource);
+
+REGISTER_KERNEL_BUILDER(
+ Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU),
+ IsResourceInitialized<BoostedTreesEnsembleResource>);
+
+// Creates a tree ensemble resource.
+class BoostedTreesCreateEnsembleOp : public OpKernel {
+ public:
+ explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Get the stamp token.
+ const Tensor* stamp_token_t;
+ OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
+ int64 stamp_token = stamp_token_t->scalar<int64>()();
+
+ // Get the tree ensemble proto.
+ const Tensor* tree_ensemble_serialized_t;
+ OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
+ &tree_ensemble_serialized_t));
+ std::unique_ptr<BoostedTreesEnsembleResource> result(
+ new BoostedTreesEnsembleResource());
+ if (!result->InitFromSerialized(
+ tree_ensemble_serialized_t->scalar<string>()(), stamp_token)) {
+ result->Unref();
+ OP_REQUIRES(
+ context, false,
+ errors::InvalidArgument("Unable to parse tree ensemble proto."));
+ }
+
+ // Only create one, if one does not exist already. Report status for all
+ // other exceptions.
+ auto status =
+ CreateResource(context, HandleFromInput(context, 0), result.release());
+ if (status.code() != tensorflow::error::ALREADY_EXISTS) {
+ OP_REQUIRES_OK(context, status);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU),
+ BoostedTreesCreateEnsembleOp);
+
+// Op for retrieving some model states (needed for training).
+class BoostedTreesGetEnsembleStatesOp : public OpKernel {
+ public:
+ explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Looks up the resource.
+ BoostedTreesEnsembleResource* tree_ensemble_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &tree_ensemble_resource));
+ tf_shared_lock l(*tree_ensemble_resource->get_mutex());
+ core::ScopedUnref unref_me(tree_ensemble_resource);
+
+ // Sets the outputs.
+ const int num_trees = tree_ensemble_resource->num_trees();
+ const int num_finalized_trees =
+ (num_trees <= 0 ||
+ tree_ensemble_resource->IsTreeFinalized(num_trees - 1))
+ ? num_trees
+ : num_trees - 1;
+ const int num_attempted_layers =
+ tree_ensemble_resource->GetNumLayersAttempted();
+
+ // growing_metadata
+ Tensor* output_stamp_token_t = nullptr;
+ Tensor* output_num_trees_t = nullptr;
+ Tensor* output_num_finalized_trees_t = nullptr;
+ Tensor* output_num_attempted_layers_t = nullptr;
+
+ OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
+ &output_stamp_token_t));
+ OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(),
+ &output_num_trees_t));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(2, TensorShape(),
+ &output_num_finalized_trees_t));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(3, TensorShape(),
+ &output_num_attempted_layers_t));
+
+ output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
+ output_num_trees_t->scalar<int32>()() = num_trees;
+ output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees;
+ output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU),
+ BoostedTreesGetEnsembleStatesOp);
+
+// Op for serializing a model.
+class BoostedTreesSerializeEnsembleOp : public OpKernel {
+ public:
+ explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ BoostedTreesEnsembleResource* tree_ensemble_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &tree_ensemble_resource));
+ tf_shared_lock l(*tree_ensemble_resource->get_mutex());
+ core::ScopedUnref unref_me(tree_ensemble_resource);
+ Tensor* output_stamp_token_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
+ &output_stamp_token_t));
+ output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
+ Tensor* output_proto_t = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, TensorShape(), &output_proto_t));
+ output_proto_t->scalar<string>()() =
+ tree_ensemble_resource->SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU),
+ BoostedTreesSerializeEnsembleOp);
+
+// Op for deserializing a tree ensemble variable from a checkpoint.
+class BoostedTreesDeserializeEnsembleOp : public OpKernel {
+ public:
+ explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ BoostedTreesEnsembleResource* tree_ensemble_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &tree_ensemble_resource));
+ mutex_lock l(*tree_ensemble_resource->get_mutex());
+ core::ScopedUnref unref_me(tree_ensemble_resource);
+
+ // Get the stamp token.
+ const Tensor* stamp_token_t;
+ OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
+ int64 stamp_token = stamp_token_t->scalar<int64>()();
+
+ // Get the tree ensemble proto.
+ const Tensor* tree_ensemble_serialized_t;
+ OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
+ &tree_ensemble_serialized_t));
+ // Deallocate all the previous objects on the resource.
+ tree_ensemble_resource->Reset();
+ OP_REQUIRES(
+ context,
+ tree_ensemble_resource->InitFromSerialized(
+ tree_ensemble_serialized_t->scalar<string>()(), stamp_token),
+ errors::InvalidArgument("Unable to parse tree ensemble proto."));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU),
+ BoostedTreesDeserializeEnsembleOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
new file mode 100644
index 0000000000..2ea12c522c
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -0,0 +1,301 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+int32 BoostedTreesEnsembleResource::next_node(
+ const int32 tree_id, const int32 node_id, const int32 index_in_batch,
+ const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ const auto& split = node.bucketized_split();
+ if (bucketized_features[split.feature_id()](index_in_batch) <=
+ split.threshold()) {
+ return split.left_id();
+ } else {
+ return split.right_id();
+ }
+}
+
+float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
+ const int32 node_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ if (node.node_case() == boosted_trees::Node::kLeaf) {
+ return node.leaf().scalar();
+ } else {
+ return node.metadata().original_leaf().scalar();
+ }
+}
+
+void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
+ tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
+ tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
+
+ const int n_trees = num_trees();
+
+ if (n_trees <= 0 ||
+ // Checks if we are building the first layer of the dummy empty tree
+ ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
+ (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
+ tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
+ tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
+ }
+}
+
+// Add a tree to the ensemble and returns a new tree_id.
+int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
+ const int32 new_tree_id = tree_ensemble_->trees_size();
+ auto* node = tree_ensemble_->add_trees()->add_nodes();
+ node->mutable_leaf()->set_scalar(0.0);
+ tree_ensemble_->add_tree_weights(weight);
+ tree_ensemble_->add_tree_metadata();
+
+ return new_tree_id;
+}
+
+void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
+ const int32 tree_id, const int32 node_id, const int32 feature_id,
+ const int32 threshold, const float gain, const float left_contrib,
+ const float right_contrib, int32* left_node_id, int32* right_node_id) {
+ auto* tree = tree_ensemble_->mutable_trees(tree_id);
+ auto* node = tree->mutable_nodes(node_id);
+ DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
+ float prev_node_value = node->leaf().scalar();
+ *left_node_id = tree->nodes_size();
+ *right_node_id = *left_node_id + 1;
+ auto* left_node = tree->add_nodes();
+ auto* right_node = tree->add_nodes();
+ if (node_id != 0) {
+ // Save previous leaf value if it is not the first leaf in the tree.
+ node->mutable_metadata()->mutable_original_leaf()->Swap(
+ node->mutable_leaf());
+ }
+ node->mutable_metadata()->set_gain(gain);
+ auto* new_split = node->mutable_bucketized_split();
+ new_split->set_feature_id(feature_id);
+ new_split->set_threshold(threshold);
+ new_split->set_left_id(*left_node_id);
+ new_split->set_right_id(*right_node_id);
+ // TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
+ left_node->mutable_leaf()->set_scalar(prev_node_value + left_contrib);
+ right_node->mutable_leaf()->set_scalar(prev_node_value + right_contrib);
+}
+
+void BoostedTreesEnsembleResource::Reset() {
+ // Reset stamp.
+ set_stamp(-1);
+
+ // Clear tree ensemle.
+ arena_.Reset();
+ CHECK_EQ(0, arena_.SpaceAllocated());
+ tree_ensemble_ =
+ protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
+}
+
+void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree) {
+ // No-op if tree is empty.
+ auto* tree = tree_ensemble_->mutable_trees(current_tree);
+ int32 num_nodes = tree->nodes_size();
+ if (num_nodes == 0) {
+ return;
+ }
+
+ std::vector<int32> nodes_to_delete;
+ // If a node was pruned, we need to save the change of the prediction from
+ // this node to its parent, as well as the parent id.
+ std::vector<std::pair<int32, float>> nodes_changes;
+ nodes_changes.reserve(num_nodes);
+ for (int32 i = 0; i < num_nodes; ++i) {
+ nodes_changes.emplace_back(i, 0.0);
+ }
+ // Prune the tree recursively starting from the root. Each node that has
+ // negative gain and only leaf children will be pruned recursively up from
+ // the bottom of the tree. This method returns the list of nodes pruned, and
+ // updates the nodes in the tree not to refer to those pruned nodes.
+ RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
+ &nodes_changes);
+
+ if (nodes_to_delete.empty()) {
+ // No pruning happened, and no post-processing needed.
+ return;
+ }
+
+ // Sort node ids so they are in asc order.
+ std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
+
+ // We need to
+ // - update split left and right children ids with new indices
+ // - actually remove the nodes that need to be removed
+ // - save the information about pruned node so we could recover the
+ // predictions from cache. Build a map for old node index=>new node index.
+ // nodes_to_delete contains nodes who's indices should be skipped, in
+ // ascending order. Save the information about new indices into meta.
+ std::map<int32, int32> old_to_new_ids;
+ int32 new_index = 0;
+ int32 index_for_deleted = 0;
+ auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
+ ->mutable_post_pruned_nodes_meta();
+
+ for (int32 i = 0; i < num_nodes; ++i) {
+ if (index_for_deleted < nodes_to_delete.size() &&
+ i == nodes_to_delete[index_for_deleted]) {
+ // Node i will get removed,
+ ++index_for_deleted;
+ // Update meta info that will allow us to use cached predictions from
+ // those nodes.
+ int32 new_id;
+ float logit_change;
+ CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_change);
+ auto* meta = post_prune_meta->Add();
+ meta->set_new_node_id(old_to_new_ids[new_id]);
+ meta->set_logit_change(logit_change);
+ } else {
+ old_to_new_ids[i] = new_index++;
+ auto* meta = post_prune_meta->Add();
+ // Update meta info that will allow us to use cached predictions from
+ // those nodes.
+ meta->set_new_node_id(old_to_new_ids[i]);
+ meta->set_logit_change(0.0);
+ }
+ }
+ index_for_deleted = 0;
+ int32 i = 0;
+ protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
+ new_nodes.Reserve(old_to_new_ids.size());
+ for (auto node : *(tree->mutable_nodes())) {
+ if (index_for_deleted < nodes_to_delete.size() &&
+ i == nodes_to_delete[index_for_deleted]) {
+ ++index_for_deleted;
+ ++i;
+ continue;
+ } else {
+ if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
+ node.mutable_bucketized_split()->set_left_id(
+ old_to_new_ids[node.bucketized_split().left_id()]);
+ node.mutable_bucketized_split()->set_right_id(
+ old_to_new_ids[node.bucketized_split().right_id()]);
+ }
+ *new_nodes.Add() = std::move(node);
+ }
+ ++i;
+ }
+ // Replace all the nodes in a tree with the ones we keep.
+ *tree->mutable_nodes() = std::move(new_nodes);
+
+ // Note that if the whole tree got pruned, we will end up with one node.
+ // We can't remove that tree because it will cause problems with cache.
+}
+
+void BoostedTreesEnsembleResource::GetPostPruneCorrection(
+ const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
+ float* logit_update) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
+ DCHECK_LT(
+ initial_node_id,
+ tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
+ const auto& meta =
+ tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
+ initial_node_id);
+ *current_node_id = meta.new_node_id();
+ *logit_update += meta.logit_change();
+ }
+}
+
+bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
+ const int32 tree_id, const int32 node_id) const {
+ const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ const int32 left_id = node.bucketized_split().left_id();
+ const int32 right_id = node.bucketized_split().right_id();
+ return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
+}
+
+// For each pruned node, finds the leaf where it finally ended up and
+// calculates the total update from that pruned node prediction.
+void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
+ const int32 start_node_id,
+ const std::vector<std::pair<int32, float>>& nodes_change, int32* parent_id,
+ float* change) const {
+ *change = 0.0;
+ int32 node_id = start_node_id;
+ int32 parent = nodes_change[node_id].first;
+
+ while (parent != node_id) {
+ (*change) += nodes_change[node_id].second;
+ node_id = parent;
+ parent = nodes_change[node_id].first;
+ }
+ *parent_id = parent;
+}
+
+void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
+ const int32 tree_id, const int32 node_id,
+ std::vector<int32>* nodes_to_delete,
+ std::vector<std::pair<int32, float>>* nodes_meta) {
+ auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
+ DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
+ // Base case when we reach a leaf.
+ if (node->node_case() == boosted_trees::Node::kLeaf) {
+ return;
+ }
+
+ // Traverse node children first and recursively prune their sub-trees.
+ RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
+ nodes_to_delete, nodes_meta);
+ RecursivelyDoPostPrunePreparation(tree_id,
+ node->bucketized_split().right_id(),
+ nodes_to_delete, nodes_meta);
+
+ // Two conditions must be satisfied to prune the node:
+ // 1- The split gain is negative.
+ // 2- After depth-first pruning, the node only has leaf children.
+ const auto& node_metadata = node->metadata();
+ if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
+ const int32 left_id = node->bucketized_split().left_id();
+ const int32 right_id = node->bucketized_split().right_id();
+
+ // Save children that need to be deleted.
+ nodes_to_delete->push_back(left_id);
+ nodes_to_delete->push_back(right_id);
+
+ // Change node back into leaf.
+ *node->mutable_leaf() = node_metadata.original_leaf();
+ const float parent_value = node_value(tree_id, node_id);
+
+ // Save the old values of weights of children.
+ (*nodes_meta)[left_id].first = node_id;
+ (*nodes_meta)[left_id].second = parent_value - node_value(tree_id, left_id);
+
+ (*nodes_meta)[right_id].first = node_id;
+ (*nodes_meta)[right_id].second =
+ parent_value - node_value(tree_id, right_id);
+
+ // Clear gain for leaf node.
+ node->clear_metadata();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h
new file mode 100644
index 0000000000..c82588b950
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/resources.h
@@ -0,0 +1,221 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// A StampedResource is a resource that has a stamp token associated with it.
+// Before reading from or applying updates to the resource, the stamp should
+// be checked to verify that the update is not stale.
+class StampedResource : public ResourceBase {
+ public:
+ StampedResource() : stamp_(-1) {}
+
+ bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; }
+
+ int64 stamp() const { return stamp_; }
+ void set_stamp(int64 stamp) { stamp_ = stamp; }
+
+ private:
+ int64 stamp_;
+};
+
+// Keep a tree ensemble in memory for efficient evaluation and mutation.
+class BoostedTreesEnsembleResource : public StampedResource {
+ public:
+ // Constructor.
+ BoostedTreesEnsembleResource()
+ : tree_ensemble_(
+ protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
+ &arena_)) {}
+
+ string DebugString() override {
+ return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
+ "]");
+ }
+
+ bool InitFromSerialized(const string& serialized, const int64 stamp_token) {
+ CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
+ if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
+ set_stamp(stamp_token);
+ return true;
+ }
+ return false;
+ }
+
+ string SerializeAsString() const {
+ return tree_ensemble_->SerializeAsString();
+ }
+
+ int32 num_trees() const { return tree_ensemble_->trees_size(); }
+
+ // Find the next node to which the example (specified by index_in_batch)
+ // traverses down from the current node indicated by tree_id and node_id.
+ // Args:
+ // tree_id: the index of the tree in the ensemble.
+ // node_id: the index of the node within the tree.
+ // index_in_batch: the index of the example within the batch (relevant to
+ // the index of the row to read in each bucketized_features).
+ // bucketized_features: vector of feature Vectors.
+ int32 next_node(
+ const int32 tree_id, const int32 node_id, const int32 index_in_batch,
+ const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const;
+
+ float node_value(const int32 tree_id, const int32 node_id) const;
+
+ int32 GetNumLayersGrown(const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
+ }
+
+ void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
+ new_num_layers);
+ }
+
+ void UpdateGrowingMetadata() const;
+
+ int32 GetNumLayersAttempted() {
+ return tree_ensemble_->growing_metadata().num_layers_attempted();
+ }
+
+ bool is_leaf(const int32 tree_id, const int32 node_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ return node.node_case() == boosted_trees::Node::kLeaf;
+ }
+
+ int32 feature_id(const int32 tree_id, const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().feature_id();
+ }
+
+ int32 bucket_threshold(const int32 tree_id, const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().threshold();
+ }
+
+ int32 left_id(const int32 tree_id, const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().left_id();
+ }
+
+ int32 right_id(const int32 tree_id, const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().right_id();
+ }
+
+ // Add a tree to the ensemble and returns a new tree_id.
+ int32 AddNewTree(const float weight);
+
+ // Grows the tree by adding a split and leaves.
+ void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id,
+ const int32 feature_id, const int32 threshold,
+ const float gain, const float left_contrib,
+ const float right_contrib, int32* left_node_id,
+ int32* right_node_id);
+
+ // Retrieves tree weights and returns as a vector.
+ // It involves a copy, so should be called only sparingly (like once per
+ // iteration, not per example).
+ std::vector<float> GetTreeWeights() const {
+ return {tree_ensemble_->tree_weights().begin(),
+ tree_ensemble_->tree_weights().end()};
+ }
+
+ float GetTreeWeight(const int32 tree_id) const {
+ return tree_ensemble_->tree_weights(tree_id);
+ }
+
+ float IsTreeFinalized(const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).is_finalized();
+ }
+
+ float IsTreePostPruned(const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id)
+ .post_pruned_nodes_meta_size() > 0;
+ }
+
+ void SetIsFinalized(const int32 tree_id, const bool is_finalized) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
+ is_finalized);
+ }
+
+ // Sets the weight of i'th tree.
+ void SetTreeWeight(const int32 tree_id, const float weight) {
+ DCHECK_GE(tree_id, 0);
+ DCHECK_LT(tree_id, num_trees());
+ tree_ensemble_->set_tree_weights(tree_id, weight);
+ }
+
+ // Resets the resource and frees the protos in arena.
+ // Caller needs to hold the mutex lock while calling this.
+ virtual void Reset();
+
+ void PostPruneTree(const int32 current_tree);
+
+ // For a given node, returns the id in a pruned tree, as well as correction
+ // to the cached prediction that should be applied. If tree was not
+ // post-pruned, current_node_id will be equal to initial_node_id and logit
+ // update will be equal to zero.
+ void GetPostPruneCorrection(const int32 tree_id, const int32 initial_node_id,
+ int32* current_node_id,
+ float* logit_update) const;
+ mutex* get_mutex() { return &mu_; }
+
+ private:
+ // Helper method to check whether a node is a terminal node in that it
+ // only has leaf nodes as children.
+ bool IsTerminalSplitNode(const int32 tree_id, const int32 node_id) const;
+
+ // For each pruned node, finds the leaf where it finally ended up and
+ // calculates the total update from that pruned node prediction.
+ void CalculateParentAndLogitUpdate(
+ const int32 start_node_id,
+ const std::vector<std::pair<int32, float>>& nodes_change,
+ int32* parent_id, float* change) const;
+
+ // Helper method to collect the information to be used to prune some nodes in
+ // the tree.
+ void RecursivelyDoPostPrunePreparation(
+ const int32 tree_id, const int32 node_id,
+ std::vector<int32>* nodes_to_delete,
+ std::vector<std::pair<int32, float>>* nodes_meta);
+
+ protected:
+ protobuf::Arena arena_;
+ mutex mu_;
+ boosted_trees::TreeEnsemble* tree_ensemble_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
new file mode 100644
index 0000000000..33fdab6a86
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -0,0 +1,296 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+namespace {
+const float kEps = 1e-15;
+} // namespace
+
+class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
+ public:
+ explicit BoostedTreesCalculateBestGainsPerFeatureOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("l1", &l1_));
+ OP_REQUIRES_OK(context, context->GetAttr("l2", &l2_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("tree_complexity", &tree_complexity_));
+ OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
+ OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // node_id_range
+ const Tensor* node_id_range_t;
+ OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
+ const auto node_id_range = node_id_range_t->vec<int32>();
+ int32 node_id_first = node_id_range(0);
+ int32 node_id_last = node_id_range(1); // inclusive.
+ // stats_summary_list
+ OpInputList stats_summary_list;
+ OP_REQUIRES_OK(context, context->input_list("stats_summary_list",
+ &stats_summary_list));
+ const int64 num_buckets = stats_summary_list[0].dim_size(1);
+ std::vector<TTypes<float, 3>::ConstTensor> stats_summary;
+ stats_summary.reserve(stats_summary_list.size());
+ for (const auto& tensor : stats_summary_list) {
+ stats_summary.emplace_back(tensor.tensor<float, 3>());
+ }
+
+ // Allocate output lists of tensors:
+ OpOutputList output_node_ids_list;
+ OP_REQUIRES_OK(
+ context, context->output_list("node_ids_list", &output_node_ids_list));
+ OpOutputList output_gains_list;
+ OP_REQUIRES_OK(context,
+ context->output_list("gains_list", &output_gains_list));
+ OpOutputList output_thresholds_list;
+ OP_REQUIRES_OK(context, context->output_list("thresholds_list",
+ &output_thresholds_list));
+ OpOutputList output_left_node_contribs_list;
+ OP_REQUIRES_OK(context,
+ context->output_list("left_node_contribs_list",
+ &output_left_node_contribs_list));
+ OpOutputList output_right_node_contribs_list;
+ OP_REQUIRES_OK(context,
+ context->output_list("right_node_contribs_list",
+ &output_right_node_contribs_list));
+
+ // Get the best split info per node for each feature.
+ for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+ std::vector<float> cum_grad;
+ std::vector<float> cum_hess;
+ cum_grad.reserve(num_buckets);
+ cum_hess.reserve(num_buckets);
+
+ std::vector<int32> output_node_ids;
+ std::vector<float> output_gains;
+ std::vector<int32> output_thresholds;
+ std::vector<float> output_left_node_contribs;
+ std::vector<float> output_right_node_contribs;
+ for (int node_id = node_id_first; node_id <= node_id_last; ++node_id) {
+ // Calculate gains.
+ cum_grad.clear();
+ cum_hess.clear();
+ float total_grad = 0.0;
+ float total_hess = 0.0;
+ for (int bucket = 0; bucket < num_buckets; ++bucket) {
+ // TODO(nponomareva): Consider multi-dimensional gradients/hessians.
+ total_grad += stats_summary[feature_idx](node_id, bucket, 0);
+ total_hess += stats_summary[feature_idx](node_id, bucket, 1);
+ cum_grad.push_back(total_grad);
+ cum_hess.push_back(total_hess);
+ }
+ float best_gain = std::numeric_limits<float>::lowest();
+ float best_bucket = 0;
+ float best_contrib_for_left = 0.0;
+ float best_contrib_for_right = 0.0;
+ // Parent gain.
+ float parent_gain;
+ float unused;
+ CalculateWeightsAndGains(total_grad, total_hess, &unused, &parent_gain);
+
+ for (int bucket = 0; bucket < num_buckets; ++bucket) {
+ const float cum_grad_bucket = cum_grad[bucket];
+ const float cum_hess_bucket = cum_hess[bucket];
+ // Left child.
+ float contrib_for_left;
+ float gain_for_left;
+ CalculateWeightsAndGains(cum_grad_bucket, cum_hess_bucket,
+ &contrib_for_left, &gain_for_left);
+ // Right child.
+ float contrib_for_right;
+ float gain_for_right;
+ CalculateWeightsAndGains(total_grad - cum_grad_bucket,
+ total_hess - cum_hess_bucket,
+ &contrib_for_right, &gain_for_right);
+
+ if (gain_for_left + gain_for_right > best_gain) {
+ best_gain = gain_for_left + gain_for_right;
+ best_bucket = bucket;
+ best_contrib_for_left = contrib_for_left;
+ best_contrib_for_right = contrib_for_right;
+ }
+ } // for bucket
+ output_node_ids.push_back(node_id);
+ // Remove the parent gain for the parent node.
+ output_gains.push_back(best_gain - parent_gain);
+ output_thresholds.push_back(best_bucket);
+ output_left_node_contribs.push_back(best_contrib_for_left);
+ output_right_node_contribs.push_back(best_contrib_for_right);
+ } // for node_id
+ const int num_nodes = output_node_ids.size();
+ // output_node_ids
+ Tensor* output_node_ids_t;
+ OP_REQUIRES_OK(context,
+ output_node_ids_list.allocate(feature_idx, {num_nodes},
+ &output_node_ids_t));
+ auto output_node_ids_vec = output_node_ids_t->vec<int32>();
+ // output_gains
+ Tensor* output_gains_t;
+ OP_REQUIRES_OK(context, output_gains_list.allocate(
+ feature_idx, {num_nodes}, &output_gains_t));
+ auto output_gains_vec = output_gains_t->vec<float>();
+ // output_thresholds
+ Tensor* output_thresholds_t;
+ OP_REQUIRES_OK(context,
+ output_thresholds_list.allocate(feature_idx, {num_nodes},
+ &output_thresholds_t));
+ auto output_thresholds_vec = output_thresholds_t->vec<int32>();
+ // output_left_node_contribs
+ Tensor* output_left_node_contribs_t;
+ OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate(
+ feature_idx, {num_nodes, 1},
+ &output_left_node_contribs_t));
+ auto output_left_node_contribs_matrix =
+ output_left_node_contribs_t->matrix<float>();
+ // output_right_node_contribs
+ Tensor* output_right_node_contribs_t;
+ OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate(
+ feature_idx, {num_nodes, 1},
+ &output_right_node_contribs_t));
+ auto output_right_node_contribs_matrix =
+ output_right_node_contribs_t->matrix<float>();
+ // Sets output tensors from vectors.
+ for (int i = 0; i < num_nodes; ++i) {
+ output_node_ids_vec(i) = output_node_ids[i];
+ // Adjust the gains to penalize by tree complexity.
+ output_gains_vec(i) = output_gains[i] - tree_complexity_;
+ output_thresholds_vec(i) = output_thresholds[i];
+ // Logits are 1-dimensional for now.
+ // TODO(nponomareva): Consider multi-dimensional logits.
+ output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
+ output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
+ }
+ } // for f
+ }
+
+ private:
+ void CalculateWeightsAndGains(const float g, const float h, float* weight,
+ float* gain) {
+ //
+ // The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
+ // (g+l1*sgn(w))^2/(h+l2).
+ // This is because for each leaf we optimize
+ // 1/2(h+l2)*w^2+g*w+l1*abs(w)
+ float g_with_l1 = g;
+ // Apply L1 regularization.
+ // 1) Assume w>0 => w=-(g+l1)/(h+l2)=> g+l1 < 0 => g < -l1
+ // 2) Assume w<0 => w=-(g-l1)/(h+l2)=> g-l1 > 0 => g > l1
+ // For g from (-l1, l1), thus there is no solution => set to 0.
+ if (l1_ > 0) {
+ if (g > l1_) {
+ g_with_l1 -= l1_;
+ } else if (g < -l1_) {
+ g_with_l1 += l1_;
+ } else {
+ *weight = 0.0;
+ *gain = 0.0;
+ return;
+ }
+ }
+ // Apply L2 regularization.
+ if (h + l2_ <= kEps) {
+ // Avoid division by 0 or infinitesimal.
+ *weight = 0;
+ *gain = 0;
+ } else {
+ *weight = -g_with_l1 / (h + l2_);
+ *gain = -g_with_l1 * (*weight);
+ }
+ }
+
+ float l1_;
+ float l2_;
+ float tree_complexity_;
+ int max_splits_;
+ int num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
+ BoostedTreesCalculateBestGainsPerFeatureOp);
+
+class BoostedTreesMakeStatsSummaryOp : public OpKernel {
+ public:
+ explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
+ OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
+ OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // node_ids
+ const Tensor* node_ids_t;
+ OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
+ const auto node_ids = node_ids_t->vec<int32>();
+ // gradients
+ const Tensor* gradients_t;
+ OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
+ const auto gradients = gradients_t->matrix<float>();
+ // hessians
+ const Tensor* hessians_t;
+ OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+ const auto hessians = hessians_t->matrix<float>();
+ // bucketized_features
+ OpInputList bucketized_features_list;
+ OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
+ &bucketized_features_list));
+ std::vector<tensorflow::TTypes<int32>::ConstVec> bucketized_features;
+ bucketized_features.reserve(num_features_);
+ for (const Tensor& tensor : bucketized_features_list) {
+ bucketized_features.emplace_back(tensor.vec<int32>());
+ }
+
+ // Infer batch size.
+ const int64 batch_size = node_ids_t->dim_size(0);
+ // Allocate output stats tensor (Rank 4).
+ Tensor* output_stats_summary_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "stats_summary",
+ {num_features_, max_splits_, num_buckets_, 2},
+ &output_stats_summary_t));
+ auto output_stats_summary = output_stats_summary_t->tensor<float, 4>();
+ output_stats_summary.setZero();
+
+ // Partition by node, and then bucketize.
+ for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+ const auto& features = bucketized_features[feature_idx];
+ for (int i = 0; i < batch_size; ++i) {
+ const int32 node = node_ids(i);
+ const int32 bucket = features(i);
+ output_stats_summary(feature_idx, node, bucket, 0) += gradients(i, 0);
+ output_stats_summary(feature_idx, node, bucket, 1) += hessians(i, 0);
+ }
+ }
+ }
+
+ private:
+ int max_splits_;
+ int num_buckets_;
+ int num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU),
+ BoostedTreesMakeStatsSummaryOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
new file mode 100644
index 0000000000..b9ded4054a
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc
@@ -0,0 +1,219 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+
+namespace tensorflow {
+
+namespace {
+constexpr float kLayerByLayerTreeWeight = 1.0;
+
+// TODO(nponomareva, youngheek): consider using vector.
+struct SplitCandidate {
+ SplitCandidate() {}
+
+ // Index in the list of the feature ids.
+ int64 feature_idx;
+
+ // Index in the tensor of node_ids for the feature with idx feature_idx.
+ int64 candidate_idx;
+
+ float gain;
+};
+
+enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 };
+
+} // namespace
+
+class BoostedTreesUpdateEnsembleOp : public OpKernel {
+ public:
+ explicit BoostedTreesUpdateEnsembleOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
+ OP_REQUIRES_OK(context, context->GetAttr("learning_rate", &learning_rate_));
+ OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
+
+ int32 pruning_index;
+ OP_REQUIRES_OK(context, context->GetAttr("pruning_mode", &pruning_index));
+ pruning_mode_ = static_cast<PruningMode>(pruning_index);
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Get decision tree ensemble.
+ BoostedTreesEnsembleResource* ensemble_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &ensemble_resource));
+ core::ScopedUnref unref_me(ensemble_resource);
+ mutex_lock l(*ensemble_resource->get_mutex());
+ // Increase the ensemble stamp.
+ ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
+
+ // Read node ids, gains, thresholds and node contribs.
+ OpInputList node_ids_list;
+ OpInputList gains_list;
+ OpInputList thresholds_list;
+ OpInputList left_node_contribs;
+ OpInputList right_node_contribs;
+ OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list));
+ OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
+ OP_REQUIRES_OK(context,
+ context->input_list("thresholds", &thresholds_list));
+ OP_REQUIRES_OK(context, context->input_list("left_node_contribs",
+ &left_node_contribs));
+ OP_REQUIRES_OK(context, context->input_list("right_node_contribs",
+ &right_node_contribs));
+
+ const Tensor* feature_ids_t;
+ OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
+
+ auto feature_ids = feature_ids_t->vec<int32>();
+
+ // Find best splits for each active node.
+ std::map<int32, SplitCandidate> best_splits;
+ FindBestSplitsPerNode(context, node_ids_list, gains_list, &best_splits);
+
+ int32 current_tree =
+ UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
+
+ // No-op if no new splits can be considered.
+ if (best_splits.empty()) {
+ LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
+ return;
+ }
+
+ const int32 new_num_layers =
+ ensemble_resource->GetNumLayersGrown(current_tree) + 1;
+ VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
+ << current_tree << " of ensemble of " << current_tree + 1
+ << " trees.";
+ bool split_happened = false;
+ // Add the splits to the tree.
+ for (auto& split_entry : best_splits) {
+ const int32 node_id = split_entry.first;
+ const SplitCandidate& candidate = split_entry.second;
+
+ const int64 feature_idx = candidate.feature_idx;
+ const int64 candidate_idx = candidate.candidate_idx;
+
+ const int32 feature_id = feature_ids(feature_idx);
+ const int32 threshold =
+ thresholds_list[feature_idx].vec<int32>()(candidate_idx);
+ const float gain = gains_list[feature_idx].vec<float>()(candidate_idx);
+
+ if (pruning_mode_ == kPrePruning) {
+ // Don't consider negative splits if we're pre-pruning the tree.
+ // Note that zero-gain splits are acceptable.
+ if (gain < 0) {
+ continue;
+ }
+ }
+ // For now assume that the weights vectors are one dimensional.
+ // TODO(nponomareva): change here for multiclass.
+ const float left_contrib =
+ learning_rate_ *
+ left_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
+ const float right_contrib =
+ learning_rate_ *
+ right_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
+
+ // unused.
+ int32 left_node_id;
+ int32 right_node_id;
+
+ ensemble_resource->AddBucketizedSplitNode(
+ current_tree, node_id, feature_id, threshold, gain, left_contrib,
+ right_contrib, &left_node_id, &right_node_id);
+ split_happened = true;
+ }
+ if (split_happened) {
+ // Update growable tree metadata.
+ ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers);
+ // Finalize the tree if needed.
+ if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth_) {
+ ensemble_resource->SetIsFinalized(current_tree, true);
+ if (pruning_mode_ == kPostPruning) {
+ ensemble_resource->PostPruneTree(current_tree);
+ }
+ if (ensemble_resource->num_trees() > 0) {
+ // Create a dummy new tree with an empty node.
+ ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
+ }
+ }
+ }
+ }
+
+ private:
+ int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
+ BoostedTreesEnsembleResource* const ensemble_resource) {
+ int32 num_trees = ensemble_resource->num_trees();
+ int32 current_tree = num_trees - 1;
+
+ // Increment global attempt stats.
+ ensemble_resource->UpdateGrowingMetadata();
+
+ // Note we don't set tree weight to be equal to learning rate, since we
+ // apply learning rate to leaf weights instead, when doing layer-by-layer
+ // boosting.
+ if (num_trees <= 0) {
+ // Create a new tree with a no-op leaf.
+ current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
+ }
+ return current_tree;
+ }
+
+ // Helper method which effectively does a reduce over all split candidates
+ // and finds the best split for each node.
+ void FindBestSplitsPerNode(
+ OpKernelContext* const context, const OpInputList& node_ids_list,
+ const OpInputList& gains_list,
+ std::map<int32, SplitCandidate>* best_split_per_node) {
+ // Find best split per node going through every feature candidate.
+ for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+ const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
+ const auto& gains = gains_list[feature_idx].vec<float>();
+
+ for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
+ ++candidate_idx) {
+ // Get current split candidate.
+ const auto& node_id = node_ids(candidate_idx);
+ const auto& gain = gains(candidate_idx);
+
+ auto best_split_it = best_split_per_node->find(node_id);
+ SplitCandidate candidate;
+ candidate.feature_idx = feature_idx;
+ candidate.candidate_idx = candidate_idx;
+ candidate.gain = gain;
+
+ if (best_split_it == best_split_per_node->end() ||
+ gain > best_split_it->second.gain) {
+ (*best_split_per_node)[node_id] = candidate;
+ }
+ }
+ }
+ }
+
+ private:
+ int32 num_features_;
+ float learning_rate_;
+ int32 max_depth_;
+ PruningMode pruning_mode_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
+ BoostedTreesUpdateEnsembleOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc
index 98936e0f96..5d17c890cf 100644
--- a/tensorflow/core/kernels/cwise_op_log.cc
+++ b/tensorflow/core/kernels/cwise_op_log.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER5(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
- complex64, complex128);
+REGISTER6(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
+ bfloat16, complex64, complex128);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double);
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 06918075a4..a80905d145 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -27,27 +27,6 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
namespace Eigen {
-namespace numext {
-#if GOOGLE_CUDA
-template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<float> exp(
- const std::complex<float>& x) {
- auto com = ::expf(x.real());
- auto res_real = com * ::cosf(x.imag());
- auto res_imag = com * ::sinf(x.imag());
- return std::complex<float>(res_real, res_imag);
-}
-template <>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<double> exp(
- const std::complex<double>& x) {
- auto com = ::exp(x.real());
- auto res_real = com * ::cos(x.imag());
- auto res_imag = com * ::sin(x.imag());
- return std::complex<double>(res_real, res_imag);
-}
-#endif
-} // namespace numext
-
namespace internal {
template <typename T>
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 01754ec21a..8c4f0218ee 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -10,18 +10,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_kernel_library",
-)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
+ "tf_cc_test",
)
cc_library(
@@ -295,11 +284,31 @@ tf_kernel_library(
],
)
+cc_library(
+ name = "prefetch_autotuner",
+ srcs = ["prefetch_autotuner.cc"],
+ hdrs = ["prefetch_autotuner.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "prefetch_autotuner_test",
+ srcs = ["prefetch_autotuner_test.cc"],
+ deps = [
+ ":prefetch_autotuner",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
tf_kernel_library(
name = "prefetch_dataset_op",
srcs = ["prefetch_dataset_op.cc"],
deps = [
":dataset",
+ ":prefetch_autotuner",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 834c06bb93..46f43dd1b1 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -263,6 +263,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
}
const int64 window_size =
window_size_func_output[0].scalar<int64>()();
+ if (window_size <= 0) {
+ return errors::InvalidArgument(
+ "Window size must be greater than zero, but got ",
+ window_size, ".");
+ }
window_sizes_[key] = window_size;
}
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
new file mode 100644
index 0000000000..b3272f6bcd
--- /dev/null
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
+
+namespace tensorflow {
+
+PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
+ : buffer_limit_(initial_buffer_size) {
+ if (initial_buffer_size == kAutoTune) {
+ mode_ = Mode::kUpswing;
+ buffer_limit_ = 1;
+ }
+}
+
+void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
+ switch (mode_) {
+ case Mode::kDisabled:
+ return;
+ case Mode::kUpswing:
+ if (current_buffer_size == buffer_limit_) {
+ mode_ = Mode::kDownswing;
+ }
+ return;
+ case Mode::kDownswing:
+ if (current_buffer_size == 0) {
+ buffer_limit_ *= 2; // Increase the buffer size.
+ mode_ = Mode::kUpswing;
+ }
+ return;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h
new file mode 100644
index 0000000000..fa8a184072
--- /dev/null
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.h
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator.
+//
+// PrefetchAutotuner attempts to find the minimum buffer size such that there is
+// always at least 1 element in the prefetch queue every time the downstream
+// iterator calls GetNext().
+//
+// One common failure mode of input pipelines is being throughput bound. No
+// amount of prefetching can address that performance mode. In order to guard
+// against this condition, PrefetchAutotuner will only increase the buffer_limit
+// if the prefetching thread is able to successfully fill the buffer at its
+// current size.
+//
+// Note: in the current implementation, we never decrease the buffer_limit().
+// This should change in the future!
+//
+// PrefetchAutotuner is NOT thread safe.
+class PrefetchAutotuner {
+ public:
+ static const int64 kAutoTune = -1;
+
+ explicit PrefetchAutotuner(int64 initial_buffer_size);
+
+ int64 buffer_limit() const { return buffer_limit_; }
+
+ void RecordConsumption(size_t current_buffer_size);
+ void RecordEmpty() { RecordConsumption(0); }
+
+ private:
+ // PrefetchAutotuner operates as a state machine.
+ enum class Mode {
+ // Disables the autotuning.
+ kDisabled,
+
+ // We have increased the size of the buffer, and will transition to
+ // kDownswing if we successfully fill the buffer.
+ kUpswing,
+
+ // We have successfully filled a buffer of this size. If we ever block the
+ // downstream iterator, we should increase the buffer size.
+ kDownswing,
+ };
+
+ int64 buffer_limit_;
+ Mode mode_ = Mode::kDisabled;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
new file mode 100644
index 0000000000..2f573dfb35
--- /dev/null
+++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(PrefetchAutotuner, Disabled) {
+ PrefetchAutotuner t(2);
+ EXPECT_EQ(2, t.buffer_limit());
+ t.RecordConsumption(0);
+ t.RecordConsumption(2);
+ t.RecordConsumption(0);
+ t.RecordConsumption(2);
+ EXPECT_EQ(2, t.buffer_limit());
+}
+
+TEST(PrefetchAutotuner, Enabled) {
+ PrefetchAutotuner t(PrefetchAutotuner::kAutoTune);
+ EXPECT_EQ(1, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to increase.
+ EXPECT_EQ(1, t.buffer_limit());
+ t.RecordConsumption(1);
+ EXPECT_EQ(1, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to increase.
+ EXPECT_EQ(2, t.buffer_limit());
+ t.RecordConsumption(2);
+ EXPECT_EQ(2, t.buffer_limit());
+ t.RecordConsumption(1);
+ EXPECT_EQ(2, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to increase.
+ EXPECT_EQ(4, t.buffer_limit());
+ t.RecordConsumption(4);
+ EXPECT_EQ(4, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to increase.
+ EXPECT_EQ(8, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to stay the same!
+ EXPECT_EQ(8, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to stay the same!
+ EXPECT_EQ(8, t.buffer_limit());
+}
+
+TEST(PrefetchAutotuner, EnabledSteady) {
+ PrefetchAutotuner t(PrefetchAutotuner::kAutoTune);
+ EXPECT_EQ(1, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to increase.
+ EXPECT_EQ(1, t.buffer_limit());
+ t.RecordConsumption(1);
+ EXPECT_EQ(1, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to increase.
+ EXPECT_EQ(2, t.buffer_limit());
+ t.RecordConsumption(2);
+ EXPECT_EQ(2, t.buffer_limit());
+ t.RecordConsumption(0); // Expect buffer limit to increase.
+ EXPECT_EQ(4, t.buffer_limit());
+
+ // Never reach zero again.
+ std::vector<size_t> consumption_values = {2, 3, 1, 4, 1, 2, 3, 1};
+ for (int i = 0; i < consumption_values.size(); ++i) {
+ t.RecordConsumption(consumption_values[i]);
+ EXPECT_EQ(4, t.buffer_limit())
+ << "Failed at index " << i << " with value: " << consumption_values[i];
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 1c548a30d2..536de81fd8 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
namespace tensorflow {
@@ -37,7 +38,8 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
int64 buffer_size;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
- OP_REQUIRES(ctx, buffer_size > 0,
+ OP_REQUIRES(ctx,
+ buffer_size > 0 || buffer_size == PrefetchAutotuner::kAutoTune,
errors::InvalidArgument("buffer_size must be > 0"));
*output = new Dataset(ctx, input, buffer_size);
@@ -85,7 +87,8 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
+ auto_tuner_(params.dataset->buffer_size_) {}
~Iterator() override {
// Signal the prefetch thread to terminate it. We will then
@@ -113,6 +116,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
while (!cancelled_ && !prefetch_thread_finished_ && buffer_.empty()) {
+ auto_tuner_.RecordEmpty();
cond_var_.wait(l);
}
@@ -129,6 +133,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
if (s.ok()) {
*out_tensors = std::move(buffer_.front().value);
}
+ auto_tuner_.RecordConsumption(buffer_.size());
buffer_.pop_front();
*end_of_sequence = false;
@@ -242,7 +247,8 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
// 1. Wait for a slot in the buffer.
{
mutex_lock l(mu_);
- while (!cancelled_ && buffer_.size() == dataset()->buffer_size_) {
+ while (!cancelled_ &&
+ buffer_.size() == auto_tuner_.buffer_limit()) {
cond_var_.wait(l);
}
@@ -323,6 +329,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_;
+ PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
diff --git a/tensorflow/core/kernels/data/sql/BUILD b/tensorflow/core/kernels/data/sql/BUILD
index f4698bdaf7..dc59120875 100644
--- a/tensorflow/core/kernels/data/sql/BUILD
+++ b/tensorflow/core/kernels/data/sql/BUILD
@@ -7,18 +7,6 @@ package(
licenses(["notice"]) # Apache 2.0
-filegroup(
- name = "all_files",
- srcs = glob(
- include = ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "sql",
srcs = [
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index 9a7eca03ce..aab4b009b5 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -17,18 +17,6 @@ cc_library(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
load("//tensorflow/core/kernels/fuzzing:tf_ops_fuzz_target_lib.bzl", "tf_ops_fuzz_target_lib")
tf_ops_fuzz_target_lib("identity")
diff --git a/tensorflow/core/kernels/hexagon/BUILD b/tensorflow/core/kernels/hexagon/BUILD
index 7688305019..4870d9ae20 100644
--- a/tensorflow/core/kernels/hexagon/BUILD
+++ b/tensorflow/core/kernels/hexagon/BUILD
@@ -13,18 +13,6 @@ load(
"tf_kernel_library",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_cc_test(
name = "graph_transferer_test",
size = "small",
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index b16c76dc7f..edb779540f 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -92,6 +92,8 @@ class InitializableLookupTable : public LookupInterface {
//
// Then the iterator is exhausted, valid returns false and status returns
// Status::OutOfRange.
+ //
+ // This class is Thread-unsafe.
class InitTableIterator {
public:
InitTableIterator() {}
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index baf0a4abe4..9e7786f25e 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -112,6 +112,7 @@ bool TensorList::Decode(const VariantTensorData& data) {
dims.push_back(scratch);
}
}
+ element_shape = PartialTensorShape(dims);
return true;
}
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 9733883001..8af48f0a67 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -83,7 +83,8 @@ class TensorListStack : public OpKernel {
DataTypeString(l->element_dtype)));
OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
errors::InvalidArgument("Tried to stack elements from a list "
- "with non-fully-defined shape."));
+ "with non-fully-defined shape: ",
+ l->element_shape.DebugString()));
if (num_elements_ != -1) {
OP_REQUIRES(c, l->tensors.size() == num_elements_,
errors::InvalidArgument("Operation expected a list with ",
diff --git a/tensorflow/core/kernels/neon/BUILD b/tensorflow/core/kernels/neon/BUILD
index c3d24e50ef..313d40c082 100644
--- a/tensorflow/core/kernels/neon/BUILD
+++ b/tensorflow/core/kernels/neon/BUILD
@@ -12,18 +12,6 @@ load(
"tf_kernel_library",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_kernel_library(
name = "neon_depthwise_conv_op",
hdrs = [
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index e134e476f6..d1675f27dd 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -503,6 +503,7 @@ class ResourceGatherOp : public OpKernel {
void Compute(OpKernelContext* c) override {
Var* v = nullptr;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ core::ScopedUnref su(v);
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index d0703d7576..89abfe0eb1 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -31,6 +31,13 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
diff --git a/tensorflow/core/kernels/snapshot_op_gpu.cu.cc b/tensorflow/core/kernels/snapshot_op_gpu.cu.cc
index f1c0ed2eae..e4e3bd5220 100644
--- a/tensorflow/core/kernels/snapshot_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/snapshot_op_gpu.cu.cc
@@ -25,8 +25,7 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
// Definition of the GPU implementations declared in softsign_op.cc.
-#define DEFINE_GPU_KERNELS(T) \
- template struct functor::Snapshot<GPUDevice, T>;
+#define DEFINE_GPU_KERNELS(T) template struct functor::Snapshot<GPUDevice, T>;
TF_CALL_POD_TYPES(DEFINE_GPU_KERNELS);
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index ebd19c3d35..9a3612bd72 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -52,8 +52,8 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
OP_REQUIRES(context, bcast.IsValid(),
errors::InvalidArgument(
"logits and labels must be broadcastable: logits_size=",
- logits_in.shape().DebugString(), " labels_size=",
- labels_in.shape().DebugString()));
+ logits_in.shape().DebugString(),
+ " labels_size=", labels_in.shape().DebugString()));
shape_in = BCast::ToShape(bcast.output_shape());
}
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in),
diff --git a/tensorflow/core/lib/db/BUILD b/tensorflow/core/lib/db/BUILD
index 9ff87e8d66..ce09c2009a 100644
--- a/tensorflow/core/lib/db/BUILD
+++ b/tensorflow/core/lib/db/BUILD
@@ -42,9 +42,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(["*"]),
- visibility = ["//tensorflow:__pkg__"],
-)
diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc
index 6be1f819c2..3608008b30 100644
--- a/tensorflow/core/lib/io/inputbuffer_test.cc
+++ b/tensorflow/core/lib/io/inputbuffer_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -287,7 +288,7 @@ TEST(InputBuffer, Seek) {
EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(1, &read)));
EXPECT_TRUE(
- StringPiece(in.Seek(-1).ToString()).contains("negative position"));
+ str_util::StrContains(in.Seek(-1).ToString(), "negative position"));
}
}
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
index b7e51256a2..63235761d9 100644
--- a/tensorflow/core/lib/io/recordio_test.cc
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -218,7 +219,7 @@ TEST_F(RecordioTest, RandomRead) {
// Tests of all the error paths in log_reader.cc follow:
static void AssertHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(StringPiece(s).contains(expected))
+ EXPECT_TRUE(str_util::StrContains(s, expected))
<< s << " does not contain " << expected;
}
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index f97f1645a6..62ce70eb6b 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -178,46 +178,88 @@ Status SetOutputShapeForReshape(InferenceContext* c) {
c->set_output(0, out);
return Status::OK();
}
- DimensionHandle num_in_elems = c->NumElements(in);
- if (c->FullyDefined(out)) {
- DimensionHandle num_out_elems = c->NumElements(out);
- if (c->ValueKnown(num_in_elems) &&
- c->Value(num_in_elems) != c->Value(num_out_elems)) {
- return errors::InvalidArgument(
- "Cannot reshape a tensor with ", c->DebugString(num_in_elems),
- " elements to shape ", c->DebugString(out), " (",
- c->DebugString(num_out_elems), " elements)");
- }
- c->set_output(0, out);
- return Status::OK();
- }
- if (c->ValueKnown(num_in_elems)) {
+ if (c->RankKnown(out) && c->RankKnown(in)) {
// We don't know the number of output elements, but we can try to infer
// the missing dimension.
- int32 unknown_idx = -1;
bool too_many_unknown = false;
- DimensionHandle known_elems = c->MakeDim(1);
- for (int32 i = 0; i < c->Rank(out); ++i) {
- DimensionHandle dim = c->Dim(out, i);
- if (!c->ValueKnown(dim)) {
- if (unknown_idx >= 0) {
- too_many_unknown = true;
- break;
+ int32 out_unknown_idx = -1;
+
+ DimensionHandle known_out_elems = c->NumElements(out);
+ if (!c->ValueKnown(known_out_elems)) {
+ known_out_elems = c->MakeDim(1);
+ for (int32 i = 0; i < c->Rank(out); ++i) {
+ DimensionHandle dim = c->Dim(out, i);
+ if (!c->ValueKnown(dim)) {
+ if (out_unknown_idx >= 0) {
+ too_many_unknown = true;
+ break;
+ }
+ out_unknown_idx = i;
+ } else {
+ TF_RETURN_IF_ERROR(
+ c->Multiply(known_out_elems, dim, &known_out_elems));
}
- unknown_idx = i;
- } else {
- TF_RETURN_IF_ERROR(c->Multiply(known_elems, dim, &known_elems));
}
}
- if (!too_many_unknown && c->Value(known_elems) != 0) {
- DimensionHandle inferred_dim;
- TF_RETURN_IF_ERROR(c->Divide(num_in_elems, c->Value(known_elems),
- true /* evenly_divisible */, &inferred_dim));
- TF_RETURN_IF_ERROR(c->ReplaceDim(out, unknown_idx, inferred_dim, &out));
+ int32 in_unknown_idx = -1;
+ DimensionHandle known_in_elems = c->NumElements(in);
+ if (!c->ValueKnown(known_in_elems)) {
+ known_in_elems = c->MakeDim(1);
+ for (int32 i = 0; i < c->Rank(in); ++i) {
+ DimensionHandle dim = c->Dim(in, i);
+ if (!c->ValueKnown(dim)) {
+ if (in_unknown_idx >= 0) {
+ too_many_unknown = true;
+ break;
+ }
+ in_unknown_idx = i;
+ } else {
+ TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
+ }
+ }
}
- }
+ if (!too_many_unknown) {
+ if (in_unknown_idx < 0 && out_unknown_idx < 0) {
+ // Just check that the dimensions match.
+ if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
+ return errors::InvalidArgument(
+ "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
+ " elements to shape ", c->DebugString(out), " (",
+ c->DebugString(known_out_elems), " elements)");
+ }
+ } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
+ c->Value(known_out_elems) > 0) {
+ // Input fully known, infer the one missing output dim
+ DimensionHandle inferred_dim;
+ TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
+ true /* evenly_divisible */,
+ &inferred_dim));
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
+
+ } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
+ c->Value(known_in_elems) != 0) {
+ // Output fully known, infer the one missing input dim
+ DimensionHandle inferred_dim;
+ TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
+ true /* evenly_divisible */,
+ &inferred_dim));
+ DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
+ TF_RETURN_IF_ERROR(
+ c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
+ } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
+ // Exactly one unknown dimension in both input and output. These 2 are
+ // equal iff the known elements are equal.
+ if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
+ DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
+ }
+ }
+ }
+ }
c->set_output(0, out);
return Status::OK();
}
@@ -452,9 +494,9 @@ REGISTER_OP("SplitV")
const Tensor* size_splits = c->input_tensor(1);
if (rank == InferenceContext::kUnknownRank) {
// If the rank of input tensor is unknown, then return unknown shapes.
- output_shape = c->UnknownShape();
+ // Note that the shape of each output can be different.
for (int i = 0; i < num_outputs; ++i) {
- c->set_output(i, output_shape);
+ c->set_output(i, c->UnknownShape());
}
} else if (rank == 0) {
// Throw error if input is a scalar.
@@ -463,18 +505,19 @@ REGISTER_OP("SplitV")
// If split dimension is known, but the sizes are unknown, then
// only the split dimension is unknown
output_shape = input;
- TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
- c->Value(split_dimension),
- c->UnknownDim(), &output_shape));
for (int i = 0; i < num_outputs; ++i) {
+ TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape,
+ c->Value(split_dimension),
+ c->UnknownDim(), &output_shape));
c->set_output(i, output_shape);
}
} else if (size_splits == nullptr && !c->ValueKnown(split_dimension)) {
// If split dimension or tensor containing the split sizes is unknown,
- // then return unknown shapes of same rank as input.
- output_shape = c->UnknownShapeOfRank(rank);
+ // then return unknown shapes of same rank as input. Note that each
+ // output shape can be different since splitv doesn't always split
+ // tensors evenly.
for (int i = 0; i < num_outputs; ++i) {
- c->set_output(i, output_shape);
+ c->set_output(i, c->UnknownShapeOfRank(rank));
}
} else {
// Determine the output shape if split dimension and split sizes are
@@ -776,7 +819,7 @@ REGISTER_OP("ReverseV2")
}
if (axes_dense[canonical_axis]) {
return errors::InvalidArgument("axis ", canonical_axis,
- " specified more than once.");
+ " specified more than once.");
}
axes_dense[canonical_axis] = true;
}
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index cf5bb5ad84..b1463338fb 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -838,7 +838,7 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) {
// Unknown dimensions.
// Flatten:
new_shape = test::AsTensor<int32>({-1});
- INFER_OK(op, "[?];[1]", "[?]");
+ INFER_OK(op, "[?];[1]", "[d0_0]");
INFER_OK(op, "[2,2];[1]", "[4]");
// The first dimension is inferred:
new_shape = test::AsTensor<int32>({2, -1});
@@ -851,6 +851,10 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) {
new_shape = test::AsTensor<int32>({-1, -1, 2});
INFER_OK(op, "[8];[3]", "[?,?,2]");
+ // Symbolic shape propagation
+ new_shape = test::AsTensor<int32>({-1, 2, 3});
+ INFER_OK(op, "[?,2,3];[3]", "[d0_0,2,3]");
+
// Reshaping to a scalar.
new_shape = test::AsTensor<int32>({});
INFER_OK(op, "[1];[0]", "[]");
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
new file mode 100644
index 0000000000..297e94655f
--- /dev/null
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -0,0 +1,319 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
+
+REGISTER_OP("IsBoostedTreesEnsembleInitialized")
+ .Input("tree_ensemble_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
+ .Input("node_id_range: int32")
+ .Input("stats_summary_list: num_features * float32")
+ .Attr("l1: float")
+ .Attr("l2: float")
+ .Attr("tree_complexity: float")
+ .Attr("max_splits: int >= 1")
+ .Attr("num_features: int >= 1") // not passed but populated automatically.
+ .Output("node_ids_list: num_features * int32")
+ .Output("gains_list: num_features * float32")
+ .Output("thresholds_list: num_features * int32")
+ .Output("left_node_contribs_list: num_features * float32")
+ .Output("right_node_contribs_list: num_features * float32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Confirms the rank of the inputs and sets the shape of the outputs.
+ int max_splits;
+ int num_features;
+ float l1, l2, tree_complexity;
+ TF_RETURN_IF_ERROR(c->GetAttr("l1", &l1));
+ if (l1 < 0) {
+ return errors::InvalidArgument("l1 must be non-negative.");
+ }
+ TF_RETURN_IF_ERROR(c->GetAttr("l2", &l2));
+ if (l2 < 0) {
+ return errors::InvalidArgument("l2 must be non-negative.");
+ }
+ TF_RETURN_IF_ERROR(c->GetAttr("tree_complexity", &tree_complexity));
+ if (tree_complexity < 0) {
+ return errors::InvalidArgument("Tree complexity must be non-negative.");
+ }
+ TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ shape_inference::ShapeHandle node_id_range_shape;
+ shape_inference::ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
+ // Checks that all stats summary entries are of the same shape.
+ shape_inference::ShapeHandle summary_shape_base;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &summary_shape_base));
+ TF_RETURN_IF_ERROR(c->Merge(summary_shape_base,
+ c->MakeShape({max_splits, -1, 2}),
+ &unused_shape));
+ for (int i = 1; i < num_features; ++i) {
+ shape_inference::ShapeHandle summary_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 3, &summary_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(summary_shape_base, summary_shape, &unused_shape));
+ }
+ // Sets the output lists.
+ std::vector<shape_inference::ShapeHandle> output_shapes_vec(
+ num_features, c->MakeShape({-1}));
+ TF_RETURN_IF_ERROR(c->set_output("node_ids_list", output_shapes_vec));
+ TF_RETURN_IF_ERROR(c->set_output("gains_list", output_shapes_vec));
+ TF_RETURN_IF_ERROR(c->set_output("thresholds_list", output_shapes_vec));
+ std::vector<shape_inference::ShapeHandle> output_shapes_contribs(
+ num_features, c->MakeShape({-1, 1}));
+ TF_RETURN_IF_ERROR(
+ c->set_output("left_node_contribs_list", output_shapes_contribs));
+ TF_RETURN_IF_ERROR(
+ c->set_output("right_node_contribs_list", output_shapes_contribs));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesCreateEnsemble")
+ .Input("tree_ensemble_handle: resource")
+ .Input("stamp_token: int64")
+ .Input("tree_ensemble_serialized: string")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesDeserializeEnsemble")
+ .Input("tree_ensemble_handle: resource")
+ .Input("stamp_token: int64")
+ .Input("tree_ensemble_serialized: string")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesGetEnsembleStates")
+ .Input("tree_ensemble_handle: resource")
+ .Output("stamp_token: int64")
+ .Output("num_trees: int32")
+ .Output("num_finalized_trees: int32")
+ .Output("num_attempted_layers: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ c->set_output(3, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesMakeStatsSummary")
+ .Input("node_ids: int32")
+ .Input("gradients: float")
+ .Input("hessians: float")
+ .Input("bucketized_features_list: num_features * int32")
+ .Attr("max_splits: int >= 1")
+ .Attr("num_buckets: int >= 1")
+ .Attr("num_features: int >= 1")
+ .Output("stats_summary: float")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Sets the shape of the output as a Rank 4 Tensor.
+ int max_splits;
+ int num_buckets;
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
+ TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ shape_inference::ShapeHandle node_ids_shape;
+ shape_inference::ShapeHandle gradients_shape;
+ shape_inference::ShapeHandle hessians_shape;
+ shape_inference::ShapeHandle bucketized_feature_shape;
+ shape_inference::ShapeHandle unused_shape;
+ shape_inference::DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
+ c->Dim(gradients_shape, 0), &unused_dim));
+ TF_RETURN_IF_ERROR(
+ c->Merge(gradients_shape, hessians_shape, &unused_shape));
+ for (int f = 0; f < num_features; ++f) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(3 + f), 1, &bucketized_feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
+ c->Dim(bucketized_feature_shape, 0),
+ &unused_dim));
+ }
+ c->set_output(0,
+ c->MakeShape({num_features, max_splits, num_buckets, 2}));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesPredict")
+ .Input("tree_ensemble_handle: resource")
+ .Input("bucketized_features: num_bucketized_features * int32")
+ .Attr("num_bucketized_features: int >= 1")
+ .Attr("logits_dimension: int")
+ .Attr("max_depth: int >= 1")
+ .Output("logits: float")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle feature_shape;
+ int num_bucketized_features;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("num_bucketized_features", &num_bucketized_features));
+ shape_inference::ShapeHandle unused_input;
+ for (int i = 0; i < num_bucketized_features; ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
+ // Check that the shapes of all bucketized features are the same.
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
+ }
+
+ int logits_dimension;
+ TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
+ auto logits_shape =
+ c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
+ // Logits.
+ c->set_output(0, logits_shape);
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesSerializeEnsemble")
+ .Input("tree_ensemble_handle: resource")
+ .Output("stamp_token: int64")
+ .Output("tree_ensemble_serialized: string")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ c->set_output(1, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesTrainingPredict")
+ .Input("tree_ensemble_handle: resource")
+ .Input("cached_tree_ids: int32")
+ .Input("cached_node_ids: int32")
+ .Input("bucketized_features: num_bucketized_features * int32")
+ .Attr("num_bucketized_features: int >= 1")
+ .Attr("logits_dimension: int")
+ .Attr("max_depth: int >= 1")
+ .Output("partial_logits: float")
+ .Output("tree_ids: int32")
+ .Output("node_ids: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle feature_shape;
+ int num_bucketized_features;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("num_bucketized_features", &num_bucketized_features));
+
+ int max_depth;
+ TF_RETURN_IF_ERROR(c->GetAttr("max_depth", &max_depth));
+
+ shape_inference::ShapeHandle unused_input;
+ for (int i = 0; i < num_bucketized_features; ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 3), 1, &feature_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->input(i + 3), feature_shape, &unused_input));
+ }
+ // all inputs/outputs except logits should have same shape.
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
+ TF_RETURN_IF_ERROR(c->Merge(c->input(2), feature_shape, &unused_input));
+
+ int logits_dimension;
+ TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
+ auto logits_shape =
+ c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
+ // Partial logits.
+ c->set_output(0, logits_shape);
+ // Tree ids.
+ c->set_output(1, c->MakeShape({c->Dim(feature_shape, 0)}));
+ // Node ids.
+ c->set_output(2, c->MakeShape({c->Dim(feature_shape, 0)}));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesUpdateEnsemble")
+ .Input("tree_ensemble_handle: resource")
+ .Input("feature_ids: int32")
+ .Input("node_ids: num_features * int32")
+ .Input("gains: num_features * float")
+ .Input("thresholds: num_features * int32")
+ .Input("left_node_contribs: num_features * float")
+ .Input("right_node_contribs: num_features * float")
+ .Attr("max_depth: int >= 1")
+ .Attr("learning_rate: float")
+ .Attr("pruning_mode: int >=0")
+ .Attr("num_features: int >= 0") // Inferred.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle shape_handle;
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+
+ // Feature_ids, should be one for each feature.
+ shape_inference::ShapeHandle feature_ids_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
+
+ for (int i = 0; i < num_features; ++i) {
+ // Node ids.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
+ auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
+ auto shape_rank_2 = c->MakeShape({c->Dim(shape_handle, 0), 1});
+
+ // Gains.
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
+ // TODO(nponomareva): replace this with input("name",vector of shapes).
+ TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features + 2),
+ shape_rank_1, &shape_handle));
+ // Thresholds.
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
+ TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
+ shape_rank_1, &shape_handle));
+ // Left and right node contribs.
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(i + num_features * 3 + 2), 2, &shape_handle));
+ TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
+ shape_rank_2, &shape_handle));
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
+ TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
+ shape_rank_2, &shape_handle));
+ }
+ return Status::OK();
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/BUILD b/tensorflow/core/ops/compat/BUILD
index 6cdb1586bc..c613ab144f 100644
--- a/tensorflow/core/ops/compat/BUILD
+++ b/tensorflow/core/ops/compat/BUILD
@@ -57,18 +57,3 @@ tf_cc_binary(
"//tensorflow/core:lib",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index b9d5104857..704392fa53 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1071,7 +1071,12 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits")
}
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1));
- if (!c->RankKnown(c->output(1)) || c->Rank(c->output(1)) != 2) {
+ if (!c->RankKnown(c->output(1))) {
+ return errors::InvalidArgument(
+ "Shape must be broadcasted with rank 2, but is rank is unknown.");
+ }
+
+ if (c->Rank(c->output(1)) != 2) {
return errors::InvalidArgument(
"Shape must be broadcasted with rank 2, but is rank ",
c->Rank(c->output(1)));
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 21636641e7..3ee7be3c4e 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -14,20 +14,6 @@ load(
"if_windows",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- include = [
- "**/*",
- ],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "expiring_lru_cache",
hdrs = ["expiring_lru_cache.h"],
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index 2cd607edbe..447056eb4b 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -129,6 +129,11 @@ cc_library(
)
cc_library(
+ name = "stacktrace",
+ srcs = [],
+)
+
+cc_library(
name = "gif",
copts = tf_copts(),
deps = [
@@ -218,15 +223,3 @@ alias(
actual = ":mobile_srcs",
visibility = ["//visibility:public"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/core/platform/hadoop/BUILD b/tensorflow/core/platform/hadoop/BUILD
index 774a439855..7c38c399bd 100644
--- a/tensorflow/core/platform/hadoop/BUILD
+++ b/tensorflow/core/platform/hadoop/BUILD
@@ -12,18 +12,6 @@ load(
"tf_cc_test",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "hadoop_file_system",
srcs = ["hadoop_file_system.cc"],
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index 3a0ad2e9bd..21038cfeb1 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -13,18 +13,6 @@ load(
"tf_cc_test",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
tf_cc_binary(
name = "s3_file_system.so",
srcs = [
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index 5ce6f1046d..3d3203cdaa 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -4,21 +4,6 @@ package(
licenses(["notice"]) # Apache 2.0
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos")
diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD
index 05a798bff8..8dcfde9a2a 100644
--- a/tensorflow/core/profiler/internal/BUILD
+++ b/tensorflow/core/profiler/internal/BUILD
@@ -365,17 +365,3 @@ cc_library(
"//tensorflow/core:regexp_internal",
],
)
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/core/profiler/internal/advisor/BUILD b/tensorflow/core/profiler/internal/advisor/BUILD
index 40cfd1e12e..1fedb05ae3 100644
--- a/tensorflow/core/profiler/internal/advisor/BUILD
+++ b/tensorflow/core/profiler/internal/advisor/BUILD
@@ -73,18 +73,3 @@ tf_cc_test(
"//tensorflow/core/profiler/internal:tfprof_tf_testlib",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/core/util/ctc/BUILD b/tensorflow/core/util/ctc/BUILD
index 1521349e4d..317420204e 100644
--- a/tensorflow/core/util/ctc/BUILD
+++ b/tensorflow/core/util/ctc/BUILD
@@ -26,18 +26,6 @@ alias(
actual = ":mobile_srcs",
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "ctc",
deps = [
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 166bd0f659..648358606c 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -75,18 +75,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/docs_src/community/contributing.md b/tensorflow/docs_src/community/contributing.md
new file mode 100644
index 0000000000..b0960df435
--- /dev/null
+++ b/tensorflow/docs_src/community/contributing.md
@@ -0,0 +1,64 @@
+# Contributing to TensorFlow
+
+TensorFlow is an open-source project, and we welcome your participation
+and contribution. This page describes how to get involved.
+
+## Repositories
+
+The code for TensorFlow is hosted in the [TensorFlow GitHub
+organization](https://github.com/tensorflow). Multiple projects are located
+inside the organization, including:
+
+* [TensorFlow](https://github.com/tensorflow/tensorflow)
+* [Models](https://github.com/tensorflow/models)
+* [TensorBoard](https://github.com/tensorflow/tensorboard)
+* [TensorFlow.js](https://github.com/tensorflow/tfjs)
+* [TensorFlow Serving](https://github.com/tensorflow/serving)
+* [TensorFlow Documentation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/docs_src)
+
+## Contributor checklist
+
+* Before contributing to TensorFlow source code, please review the [contribution
+guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md).
+
+* Join the
+[developers@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/developers)
+mailing list, to coordinate and discuss with others contributing to TensorFlow.
+
+* For coding style conventions, read the @{$style_guide$TensorFlow Style Guide}.
+
+* Finally, review @{$documentation$Writing TensorFlow Documentation}, which
+ explains documentation conventions.
+
+You may also wish to review our guide to @{$benchmarks$defining and running benchmarks}.
+
+## Special Interest Groups
+
+To enable focused collaboration on particular areas of TensorFlow, we host
+Special Interest Groups (SIGs). SIGs do their work in public: if you want to
+join and contribute, review the work of the group, and get in touch with the
+relevant SIG leader.
+
+* **SIG Build** focuses on issues surrounding building, packaging, and
+ distribution of TensorFlow. [Mailing list](https://groups.google.com/a/tensorflow.org/forum/#!forum/build).
+
+* **SIG TensorBoard** furthers the development and direction of TensorBoard and its plugins.
+ [Mailing list](https://groups.google.com/a/tensorflow.org/forum/#!forum/tensorboard).
+
+* **SIG Rust** collaborates on the development of TensorFlow's Rust bindings.
+ [Mailing list](https://groups.google.com/a/tensorflow.org/forum/#!forum/rust).
+
+## Projects developed by the TensorFlow community
+
+The TensorFlow community has created many great projects around TensorFlow, including:
+
+* [Machine Learning with TensorFlow (Book & Code)](http://tensorflowbook.com)
+* [@jtoy's awesome "Awesome TensorFlow" list of awesome things](https://github.com/jtoy/awesome-tensorflow)
+* [TensorFlow tutorials](https://github.com/pkmital/tensorflow_tutorials)
+* [Caffe to TensorFlow model converter](https://github.com/ethereon/caffe-tensorflow)
+* [Bitfusion's` GPU-enabled AWS EC2 TensorFlow AMI](https://github.com/bitfusionio/amis/tree/master/awsmrkt-bfboost-ubuntu14-cuda75-tensorflow) ([Launch AMI](https://aws.amazon.com/marketplace/pp/B01EYKBEQ0))
+* [Operator Vectorization Library](https://github.com/opveclib/opveclib)
+* [Swift language bindings](https://github.com/PerfectlySoft/Perfect-TensorFlow)
+* [Sublime Tensorflow - A plugin for Sublime Text](https://github.com/baptisteArnaud/Sublime-Tensorflow)
+* [GPflow - Gaussian processes in TensorFlow](https://github.com/GPflow/GPflow)
+* [CS 20SI: Tensorflow for Deep Learning Research](https://web.stanford.edu/class/cs20si/) - please note, this course was designed with TensorFlow v0.12, so some of the notes may be out of date - but it's still a great resource.
diff --git a/tensorflow/docs_src/community/groups.md b/tensorflow/docs_src/community/groups.md
new file mode 100644
index 0000000000..d92f5775fa
--- /dev/null
+++ b/tensorflow/docs_src/community/groups.md
@@ -0,0 +1,17 @@
+# User Groups
+
+TensorFlow has communities around the world.
+
+## Asia
+
+* [TensorFlow Korea (TF-KR) User Group](https://www.facebook.com/groups/TensorFlowKR/) _(Korean language)_
+* [TensorFlow User Group Tokyo](https://tfug-tokyo.connpass.com/) _(Japanese Language)_
+* [Soleil Data Dojo](https://soleildatadojo.connpass.com/) _(Japanese language)_
+* [TensorFlow User Group Utsunomiya](https://tfug-utsunomiya.connpass.com/)
+
+
+## Europe
+
+* [TensorFlow Barcelona](https://www.meetup.com/Barcelona-Machine-Learning-Meetup/)
+* [TensorFlow Madrid](https://www.meetup.com/TensorFlow-Madrid/)
+
diff --git a/tensorflow/docs_src/community/index.md b/tensorflow/docs_src/community/index.md
index ebeff8493b..c08aeb7a97 100644
--- a/tensorflow/docs_src/community/index.md
+++ b/tensorflow/docs_src/community/index.md
@@ -1,18 +1,81 @@
# Community
-This section contains the following documents:
-
- * @{$welcome$Welcome to the TensorFlow Community}, which explains how
- you can get involved, where to report issues, and where to join
- like-minded TensorFlow enthusiasts online.
- * @{$roadmap$Roadmap}, which summarizes upcoming additions to TensorFlow.
- * @{$documentation$Writing TensorFlow Documentation}, which explains
- TensorFlow's documentation conventions. If you are modifying
- TensorFlow source code or documentation, please read this guide.
- * @{$style_guide$TensorFlow Style Guide}, which identifies coding style
- conventions that TensorFlow developers and users should follow.
- * @{$community/benchmarks$Benchmarks}, Benchmarks, a guide for defining and
- running a TensorFlow benchmark.
- * @{$security$Using TensorFlow Securely}, which explains TensorFlow's security
- model, a list of recent security reports, and information on how you can
- report a security vulnerability to the TensorFlow team.
+Welcome to the TensorFlow community! This page explains where to get help, and
+different ways to be part of the community. We are committed to fostering an
+open and welcoming environment, and request that you review our [code of
+conduct](https://github.com/tensorflow/tensorflow/blob/master/CODE_OF_CONDUCT.md).
+
+## Get Help
+
+### Technical Questions
+
+To ask or answer technical questions about TensorFlow, use [Stack
+Overflow](https://stackoverflow.com/questions/tagged/tensorflow). For example,
+ask or search about a particular error message you encountered during
+installation.
+
+### Bugs and Feature Requests
+
+To report bugs or make feature requests, file an issue on GitHub. Please choose
+the appropriate repository for the project. Major repositories include:
+
+ * [TensorFlow](https://github.com/tensorflow/tensorflow/issues)
+ * [TensorBoard](https://github.com/tensorflow/tensorboard/issues)
+ * [TensorFlow models](https://github.com/tensorflow/models/issues)
+
+### Security
+
+Before using TensorFlow, please take a look at our security model, list of
+recent security announcements, and ways you can report security issues to the
+TensorFlow team at the
+[Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) page on GitHub.
+
+## Stay Informed
+
+### Announcements Mailing List
+
+All major releases and important announcements are sent to
+[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
+We recommend that you join this list if you depend on TensorFlow in any way.
+
+### Development Roadmap
+
+The @{$roadmap$Roadmap} summarizes plans for upcoming additions to TensorFlow.
+
+### Social Media
+
+For news and updates from around the universe of TensorFlow projects, follow
+[@tensorflow](https://twitter.com/tensorflow) on Twitter.
+
+### YouTube
+
+Our [YouTube Channel](http://youtube.com/tensorflow/) focuses on machine learing
+and AI with TensorFlow. On it we have a number of new shows, including:
+
+- TensorFlow Meets: meet with community contributors to learn and share what they're doing
+- Ask TensorFlow: the team answers the best questions tagged #AskTensorFlow from social media
+- Coding TensorFlow: short bites with tips for success with TensorFlow
+
+
+## Community Support
+
+### Mailing Lists
+
+For general discussion about TensorFlow development and direction, please join
+the [TensorFlow discuss mailing
+list](https://groups.google.com/a/tensorflow.org/d/forum/discuss).
+
+A number of other mailing lists exist, focused on different project areas, which
+can be found at @{$lists$TensorFlow Mailing Lists}.
+
+### User Groups
+
+To meet with like-minded people local to you, check out the many
+@{$groups$TensorFlow user groups} around the world.
+
+
+## Contributing To TensorFlow
+
+We welcome contributions and collaboration on TensorFlow. For more information,
+please read [Contributing to TensorFlow](contributing.md).
+
diff --git a/tensorflow/docs_src/community/leftnav_files b/tensorflow/docs_src/community/leftnav_files
index af344506c7..0bd1f14de9 100644
--- a/tensorflow/docs_src/community/leftnav_files
+++ b/tensorflow/docs_src/community/leftnav_files
@@ -1,7 +1,8 @@
index.md
-welcome.md
roadmap.md
+contributing.md
+lists.md
+groups.md
documentation.md
style_guide.md
benchmarks.md
-security.md
diff --git a/tensorflow/docs_src/community/lists.md b/tensorflow/docs_src/community/lists.md
new file mode 100644
index 0000000000..dc9240030e
--- /dev/null
+++ b/tensorflow/docs_src/community/lists.md
@@ -0,0 +1,35 @@
+# Mailing Lists
+
+As a community, we do much of our collaboration on public mailing lists.
+Please note that if you're looking for help using TensorFlow, [Stack
+Overflow](https://stackoverflow.com/questions/tagged/tensorflow) and
+[GitHub issues](https://github.com/tensorflow/tensorflow/issues)
+are the best initial places to look. For more information,
+see [how to get help](/community/#get_help).
+
+## General TensorFlow lists
+
+* [announce](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce) - Low-volume announcements of new releases.
+* [discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) - General community discussion around TensorFlow.
+* [developers](https://groups.google.com/a/tensorflow.org/forum/#!forum/developers) - Discussion for developers contributing to TensorFlow.
+
+## Project-specific lists
+
+These projects inside the TensorFlow GitHub organization have lists dedicated to their communities:
+
+* [tensor2tensor](https://groups.google.com/forum/#!forum/tensor2tensor) - User
+ and peer support for Tensor2Tensor.
+
+## Special Interest Groups
+
+TensorFlow's [Special Interest
+Groups](/community/contributing#special_interest_groups) (SIGs) support
+community collaboration on particular project focuses. Members of these groups
+work together to build and support TensorFlow related projects.
+
+* [build](https://groups.google.com/a/tensorflow.org/forum/#!forum/build) -
+ Supporting SIG Build, for build, distribution and packaging of TensorFlow.
+* [tensorboard](https://groups.google.com/a/tensorflow.org/forum/#!forum/tensorboard) -
+ Supporting SIG TensorBoard, for plugin development and other contribution.
+* [rust](https://groups.google.com/a/tensorflow.org/forum/#!forum/rust) -
+ Supporting SIG Rust, for the Rust language bindings.
diff --git a/tensorflow/docs_src/community/welcome.md b/tensorflow/docs_src/community/welcome.md
deleted file mode 100644
index 6d0458e678..0000000000
--- a/tensorflow/docs_src/community/welcome.md
+++ /dev/null
@@ -1,71 +0,0 @@
-# Welcome to the TensorFlow Community
-
-TensorFlow is an open-source project. This page explains how to contribute,
-where to ask questions, and how to help each other.
-
-
-## Development
-
-The source code for TensorFlow is on
-[GitHub](https://github.com/tensorflow/tensorflow).
-
-Before contributing to TensorFlow source code, please review the
-[Contribution guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md).
-
-### Projects developed by the TensorFlow community
-
-The TensorFlow community has created many great projects around TensorFlow, including:
-
-* [Machine Learning with TensorFlow (Book & Code)](http://tensorflowbook.com)
-* [@jtoy's awesome "Awesome TensorFlow" list of awesome things](https://github.com/jtoy/awesome-tensorflow)
-* [TensorFlow tutorials](https://github.com/pkmital/tensorflow_tutorials)
-* [Caffe to TensorFlow model converter](https://github.com/ethereon/caffe-tensorflow)
-* [Bitfusion's` GPU-enabled AWS EC2 TensorFlow AMI](https://github.com/bitfusionio/amis/tree/master/awsmrkt-bfboost-ubuntu14-cuda75-tensorflow) ([Launch AMI](https://aws.amazon.com/marketplace/pp/B01EYKBEQ0))
-* [Rust language bindings](https://github.com/google/tensorflow-rust)
-* [Operator Vectorization Library](https://github.com/opveclib/opveclib)
-* [Swift language bindings](https://github.com/PerfectlySoft/Perfect-TensorFlow)
-* [Sublime Tensorflow - A plugin for Sublime Text](https://github.com/baptisteArnaud/Sublime-Tensorflow)
-* [Edward - A library for probabilistic modeling, inference, and criticism](http://edwardlib.org) ([Github](https://github.com/blei-lab/edward), [Forum](https://discourse.edwardlib.org))
-* [GPflow - Gaussian processes in TensorFlow](https://github.com/GPflow/GPflow)
-* [CS 20SI: Tensorflow for Deep Learning Research](https://web.stanford.edu/class/cs20si/) - Please note, this course was designed with TensorFlow v0.12, so some of the notes may be out of date - but it's still a great resource.
-
-## TensorFlow Communities Around the World
-
-Asia:
-
-* [TensorFlow Korea (TF-KR) User Group](https://www.facebook.com/groups/TensorFlowKR/) _(Korean language)_
-* [TensorFlow User Group Tokyo](https://tfug-tokyo.connpass.com/) _(Japanese Language)_
-* [Soleil Data Dojo](https://soleildatadojo.connpass.com/) _(Japanese language)_
-* [TensorFlow User Group Utsunomiya](https://tfug-utsunomiya.connpass.com/)
-
-
-Europe:
-
-* [TensorFlow Barcelona](https://www.meetup.com/Barcelona-Machine-Learning-Meetup/)
-* [TensorFlow Madrid](https://www.meetup.com/TensorFlow-Madrid/)
-
-
-
-## Support
-
-TensorFlow provides multiple communication paths. To pick the right path,
-please read the following list carefully:
-
- * For new release announcements and security updates, subscribe to
- [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
- * To ask or answer technical questions about TensorFlow, use
- [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).
- For example, ask or search Stack Overflow about a particular error message
- you encountered during installation.
- * To join general discussions about TensorFlow development and directions,
- please join the
- [TensorFlow discuss mailing list](https://groups.google.com/a/tensorflow.org/d/forum/discuss).
- For example, use this mailing list to learn about new features in
- upcoming releases of TensorFlow.
- * To report bugs or make feature requests, use the
- [TensorFlow issues tracker](https://github.com/tensorflow/tensorflow/issues)
- on GitHub. For example, use the issue tracker to request a
- new operation in TensorFlow.
- * To report vulnerabilities, please follow our
- [vulnerability disclosure guidelines](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md).
-
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 5e39e710a0..32f249cf10 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -241,13 +241,10 @@ See also
Clamps an operand to within the range between a minimum and maximum value.
-<b> `Clamp(computation, args...)` </b>
+<b> `Clamp(min, operand, max)` </b>
| Arguments | Type | Semantics |
| ------------- | ----------------------- | -------------------------------- |
-| `computation` | `Computation` | computation of type `T_0, T_1, |
-: : : ..., T_N -> S` with N parameters :
-: : : of arbitrary type :
| `min` | `ComputationDataHandle` | array of type T |
| `operand` | `ComputationDataHandle` | array of type T |
| `max` | `ComputationDataHandle` | array of type T |
diff --git a/tensorflow/examples/adding_an_op/BUILD b/tensorflow/examples/adding_an_op/BUILD
index b3ed6589ed..cf8054be6a 100644
--- a/tensorflow/examples/adding_an_op/BUILD
+++ b/tensorflow/examples/adding_an_op/BUILD
@@ -139,15 +139,3 @@ tf_cc_binary(
"//tensorflow/core:framework",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 1214647797..a088d7cf2f 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -101,22 +101,6 @@ filegroup(
# LINT.ThenChange(//tensorflow/examples/android/download-models.gradle)
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "bin/**",
- "gen/**",
- "gradleBuild/**",
- "libs/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
name = "java_files",
srcs = glob(["src/**/*.java"]),
)
diff --git a/tensorflow/examples/benchmark/BUILD b/tensorflow/examples/benchmark/BUILD
index c4bb0a5bd9..98611a9aad 100644
--- a/tensorflow/examples/benchmark/BUILD
+++ b/tensorflow/examples/benchmark/BUILD
@@ -23,9 +23,3 @@ tf_py_logged_benchmark(
name = "sample_logged_benchmark",
target = "//tensorflow/examples/benchmark:sample_benchmark",
)
-
-filegroup(
- name = "all_files",
- srcs = glob(["**/*"]),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/get_started/regression/BUILD b/tensorflow/examples/get_started/regression/BUILD
index 577b970c90..bee94d7d90 100644
--- a/tensorflow/examples/get_started/regression/BUILD
+++ b/tensorflow/examples/get_started/regression/BUILD
@@ -2,18 +2,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_test(
name = "test",
size = "medium",
diff --git a/tensorflow/examples/how_tos/reading_data/BUILD b/tensorflow/examples/how_tos/reading_data/BUILD
index 4a43585d53..64a054d371 100644
--- a/tensorflow/examples/how_tos/reading_data/BUILD
+++ b/tensorflow/examples/how_tos/reading_data/BUILD
@@ -54,15 +54,3 @@ py_binary(
"//tensorflow/examples/tutorials/mnist:input_data",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/image_retraining/BUILD b/tensorflow/examples/image_retraining/BUILD
index 9f9244a74c..ecd79a3b00 100644
--- a/tensorflow/examples/image_retraining/BUILD
+++ b/tensorflow/examples/image_retraining/BUILD
@@ -49,15 +49,3 @@ py_test(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD
index 2abbe9dacc..c50fd93d03 100644
--- a/tensorflow/examples/label_image/BUILD
+++ b/tensorflow/examples/label_image/BUILD
@@ -9,6 +9,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+exports_files(["data/grace_hopper.jpg"])
+
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
tf_cc_binary(
@@ -60,17 +62,3 @@ py_binary(
"//tensorflow:tensorflow_py",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "bin/**",
- "gen/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/learn/BUILD b/tensorflow/examples/learn/BUILD
index aba7f600b5..bdbcb0b163 100644
--- a/tensorflow/examples/learn/BUILD
+++ b/tensorflow/examples/learn/BUILD
@@ -152,15 +152,3 @@ sh_test(
"notap",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/multibox_detector/BUILD b/tensorflow/examples/multibox_detector/BUILD
index 91a5bfa51c..4f9908cd52 100644
--- a/tensorflow/examples/multibox_detector/BUILD
+++ b/tensorflow/examples/multibox_detector/BUILD
@@ -27,17 +27,3 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "bin/**",
- "gen/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/saved_model/BUILD b/tensorflow/examples/saved_model/BUILD
index 1cdf5ec6e1..ebefc6576d 100644
--- a/tensorflow/examples/saved_model/BUILD
+++ b/tensorflow/examples/saved_model/BUILD
@@ -8,19 +8,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "g3doc/sitemap.md",
- ],
- ),
- visibility = ["//visibility:public"],
-)
-
py_binary(
name = "saved_model_half_plus_two",
srcs = [
diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD
index 12479211c3..13bca34a86 100644
--- a/tensorflow/examples/speech_commands/BUILD
+++ b/tensorflow/examples/speech_commands/BUILD
@@ -245,15 +245,3 @@ tf_cc_binary(
"//tensorflow/core:protos_all_cc",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/tutorials/estimators/BUILD b/tensorflow/examples/tutorials/estimators/BUILD
index ecbc1a431d..bab609f208 100644
--- a/tensorflow/examples/tutorials/estimators/BUILD
+++ b/tensorflow/examples/tutorials/estimators/BUILD
@@ -20,15 +20,3 @@ py_binary(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/tutorials/layers/BUILD b/tensorflow/examples/tutorials/layers/BUILD
index f8a29c79c6..aad78b1840 100644
--- a/tensorflow/examples/tutorials/layers/BUILD
+++ b/tensorflow/examples/tutorials/layers/BUILD
@@ -19,15 +19,3 @@ py_binary(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index 6d4e67063d..aa1b2ec2db 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -132,15 +132,3 @@ py_test(
"//tensorflow:tensorflow_py",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/tutorials/monitors/BUILD b/tensorflow/examples/tutorials/monitors/BUILD
index 4220e8144d..1c49e3fe53 100644
--- a/tensorflow/examples/tutorials/monitors/BUILD
+++ b/tensorflow/examples/tutorials/monitors/BUILD
@@ -23,15 +23,3 @@ py_binary(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/examples/tutorials/word2vec/BUILD b/tensorflow/examples/tutorials/word2vec/BUILD
index bfcf459269..2e19c038bd 100644
--- a/tensorflow/examples/tutorials/word2vec/BUILD
+++ b/tensorflow/examples/tutorials/word2vec/BUILD
@@ -21,14 +21,3 @@ py_binary(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/examples/wav_to_spectrogram/BUILD b/tensorflow/examples/wav_to_spectrogram/BUILD
index c99870c686..cc8835728d 100644
--- a/tensorflow/examples/wav_to_spectrogram/BUILD
+++ b/tensorflow/examples/wav_to_spectrogram/BUILD
@@ -49,17 +49,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "bin/**",
- "gen/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 838f4f2301..a33703ad6f 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -1320,6 +1320,406 @@ func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf
return op.Output(0)
}
+// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
+//
+// This is typically used by gradient computations for a broadcasting operation.
+func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output, r1 tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BroadcastGradientArgs",
+ Input: []tf.Input{
+ s0, s1,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Returns the gradient of `Tile`.
+//
+// DEPRECATED at GraphDef version 3: TileGrad has been replaced with reduce_sum
+//
+// Since `Tile` takes an input and repeats the input `multiples` times
+// along each dimension, `TileGrad` takes in `multiples` and aggregates
+// each repeated tile of `input` into `output`.
+func TileGrad(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TileGrad",
+ Input: []tf.Input{
+ input, multiples,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Constructs a tensor by tiling a given tensor.
+//
+// This operation creates a new tensor by replicating `input` `multiples` times.
+// The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements,
+// and the values of `input` are replicated `multiples[i]` times along the 'i'th
+// dimension. For example, tiling `[a b c d]` by `[2]` produces
+// `[a b c d a b c d]`.
+//
+// Arguments:
+// input: 1-D or higher.
+// multiples: 1-D. Length must be the same as the number of dimensions in `input`
+func Tile(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Tile",
+ Input: []tf.Input{
+ input, multiples,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// StridedSliceGradAttr is an optional argument to StridedSliceGrad.
+type StridedSliceGradAttr func(optionalAttr)
+
+// StridedSliceGradBeginMask sets the optional begin_mask attribute to value.
+// If not specified, defaults to 0
+func StridedSliceGradBeginMask(value int64) StridedSliceGradAttr {
+ return func(m optionalAttr) {
+ m["begin_mask"] = value
+ }
+}
+
+// StridedSliceGradEndMask sets the optional end_mask attribute to value.
+// If not specified, defaults to 0
+func StridedSliceGradEndMask(value int64) StridedSliceGradAttr {
+ return func(m optionalAttr) {
+ m["end_mask"] = value
+ }
+}
+
+// StridedSliceGradEllipsisMask sets the optional ellipsis_mask attribute to value.
+// If not specified, defaults to 0
+func StridedSliceGradEllipsisMask(value int64) StridedSliceGradAttr {
+ return func(m optionalAttr) {
+ m["ellipsis_mask"] = value
+ }
+}
+
+// StridedSliceGradNewAxisMask sets the optional new_axis_mask attribute to value.
+// If not specified, defaults to 0
+func StridedSliceGradNewAxisMask(value int64) StridedSliceGradAttr {
+ return func(m optionalAttr) {
+ m["new_axis_mask"] = value
+ }
+}
+
+// StridedSliceGradShrinkAxisMask sets the optional shrink_axis_mask attribute to value.
+// If not specified, defaults to 0
+func StridedSliceGradShrinkAxisMask(value int64) StridedSliceGradAttr {
+ return func(m optionalAttr) {
+ m["shrink_axis_mask"] = value
+ }
+}
+
+// Returns the gradient of `StridedSlice`.
+//
+// Since `StridedSlice` cuts out pieces of its `input` which is size
+// `shape`, its gradient will have the same shape (which is passed here
+// as `shape`). The gradient will be zero in any element that the slice
+// does not select.
+//
+// Arguments are the same as StridedSliceGrad with the exception that
+// `dy` is the input gradient to be propagated and `shape` is the
+// shape of `StridedSlice`'s `input`.
+func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Output, strides tf.Output, dy tf.Output, optional ...StridedSliceGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StridedSliceGrad",
+ Input: []tf.Input{
+ shape, begin, end, strides, dy,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// StridedSliceAttr is an optional argument to StridedSlice.
+type StridedSliceAttr func(optionalAttr)
+
+// StridedSliceBeginMask sets the optional begin_mask attribute to value.
+//
+// value: a bitmask where a bit i being 1 means to ignore the begin
+// value and instead use the largest interval possible. At runtime
+// begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+// `[-1, n-1]` if `stride[i] < 0`
+// If not specified, defaults to 0
+func StridedSliceBeginMask(value int64) StridedSliceAttr {
+ return func(m optionalAttr) {
+ m["begin_mask"] = value
+ }
+}
+
+// StridedSliceEndMask sets the optional end_mask attribute to value.
+//
+// value: analogous to `begin_mask`
+// If not specified, defaults to 0
+func StridedSliceEndMask(value int64) StridedSliceAttr {
+ return func(m optionalAttr) {
+ m["end_mask"] = value
+ }
+}
+
+// StridedSliceEllipsisMask sets the optional ellipsis_mask attribute to value.
+//
+// value: a bitmask where bit `i` being 1 means the `i`th
+// position is actually an ellipsis. One bit at most can be 1.
+// If `ellipsis_mask == 0`, then an implicit ellipsis mask of `1 << (m+1)`
+// is provided. This means that `foo[3:5] == foo[3:5, ...]`. An ellipsis
+// implicitly creates as many range specifications as necessary to fully
+// specify the sliced range for every dimension. For example for a 4-dimensional
+// tensor `foo` the slice `foo[2, ..., 5:8]` implies `foo[2, :, :, 5:8]`.
+// If not specified, defaults to 0
+func StridedSliceEllipsisMask(value int64) StridedSliceAttr {
+ return func(m optionalAttr) {
+ m["ellipsis_mask"] = value
+ }
+}
+
+// StridedSliceNewAxisMask sets the optional new_axis_mask attribute to value.
+//
+// value: a bitmask where bit `i` being 1 means the `i`th
+// specification creates a new shape 1 dimension. For example
+// `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor.
+// If not specified, defaults to 0
+func StridedSliceNewAxisMask(value int64) StridedSliceAttr {
+ return func(m optionalAttr) {
+ m["new_axis_mask"] = value
+ }
+}
+
+// StridedSliceShrinkAxisMask sets the optional shrink_axis_mask attribute to value.
+//
+// value: a bitmask where bit `i` implies that the `i`th
+// specification should shrink the dimensionality. begin and end
+// must imply a slice of size 1 in the dimension. For example in
+// python one might do `foo[:, 3, :]` which would result in
+// `shrink_axis_mask` being 2.
+// If not specified, defaults to 0
+func StridedSliceShrinkAxisMask(value int64) StridedSliceAttr {
+ return func(m optionalAttr) {
+ m["shrink_axis_mask"] = value
+ }
+}
+
+// Return a strided slice from `input`.
+//
+// Note, most python users will want to use the Python `Tensor.__getitem__`
+// or `Variable.__getitem__` rather than this op directly.
+//
+// The goal of this op is to produce a new tensor with a subset of
+// the elements from the `n` dimensional `input` tensor. The subset is chosen using
+// a sequence of `m` sparse range specifications encoded into the arguments
+// of this function. Note, in some cases
+// `m` could be equal to `n`, but this need not be the case. Each
+// range specification entry can be one of the following:
+//
+// - An ellipsis (...). Ellipses are used to imply zero or more
+// dimensions of full-dimension selection and are produced using
+// `ellipsis_mask`. For example, `foo[...]` is the identity slice.
+//
+// - A new axis. This is used to insert a new shape=1 dimension and is
+// produced using `new_axis_mask`. For example, `foo[:, ...]` where
+// `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor.
+//
+//
+// - A range `begin:end:stride`. This is used to specify how much to choose from
+// a given dimension. `stride` can be any integer but 0. `begin` is an integer
+// which represents the index of the first value to select while `end` represents
+// the index of the last value to select. The number of values selected in each
+// dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`.
+// `begin` and `end` can be negative where `-1` is the last element, `-2` is
+// the second to last. `begin_mask` controls whether to replace the explicitly
+// given `begin` with an implicit effective value of `0` if `stride > 0` and
+// `-1` if `stride < 0`. `end_mask` is analogous but produces the number
+// required to create the largest open interval. For example, given a shape
+// `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do
+// not assume this is equivalent to `foo[0:-1]` which has an effective `begin`
+// and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the
+// first dimension of a tensor while dropping the last two (in the original
+// order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`.
+//
+// - A single index. This is used to keep only elements that have a given
+// index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a
+// shape `(6,)` tensor. This is encoded in `begin` and `end` and
+// `shrink_axis_mask`.
+//
+// Each conceptual range specification is encoded in the op's argument. This
+// encoding is best understand by considering a non-trivial example. In
+// particular,
+// `foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as
+//
+// ```
+// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0)
+// end = [2, 4, x, x, -3, x]
+// strides = [1, 1, x, x, -1, 1]
+// begin_mask = 1<<4 | 1 << 5 = 48
+// end_mask = 1<<5 = 32
+// ellipsis_mask = 1<<3 = 8
+// new_axis_mask = 1<<2 4
+// shrink_axis_mask = 1<<0
+// ```
+//
+// In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of
+// the slice becomes (2, 1, 5, 5, 2, 5).
+// Let us walk step by step through each argument specification.
+//
+// 1. The first argument in the example slice is turned into `begin = 1` and
+// `end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we
+// also set the appropriate bit in `shrink_axis_mask`.
+//
+// 2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have
+// zero bits contributed.
+//
+// 3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1
+// dimension in the final shape. Dummy values are contributed to begin,
+// end and stride, while the new_axis_mask bit is set.
+//
+// 4. `...` grab the full ranges from as many dimensions as needed to
+// fully specify a slice for every dimension of the input shape.
+//
+// 5. `:-3:-1` shows the use of negative indices. A negative index `i` associated
+// with a dimension that has shape `s` is converted to a positive index
+// `s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion
+// is done internally so begin, end and strides receive x, -3, and -1.
+// The appropriate begin_mask bit is set to indicate the start range is the
+// full range (ignoring the x).
+//
+// 6. `:` indicates that the entire contents of the corresponding dimension
+// is selected. This is equivalent to `::` or `0::1`. begin, end, and strides
+// receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
+// `end_mask` are also set.
+//
+// *Requirements*:
+// `0 != strides[i] for i in [0, m)`
+// `ellipsis_mask must be a power of two (only one ellipsis)`
+//
+// Arguments:
+//
+// begin: `begin[k]` specifies the offset into the `k`th range specification.
+// The exact dimension this corresponds to will be determined by context.
+// Out-of-bounds values will be silently clamped. If the `k`th bit of
+// `begin_mask` then `begin[k]` is ignored and the full range of the
+// appropriate dimension is used instead. Negative values causes indexing
+// to start from the highest element e.g. If `foo==[1,2,3]` then `foo[-1]==3`.
+// end: `end[i]` is like `begin` with the exception that `end_mask` is
+// used to determine full ranges.
+// strides: `strides[i]` specifies the increment in the `i`th specification
+// after extracting a given element. Negative indices will reverse
+// the original order. Out or range values are
+// clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0`
+func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, strides tf.Output, optional ...StridedSliceAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StridedSlice",
+ Input: []tf.Input{
+ input, begin, end, strides,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Return a slice from 'input'.
+//
+// The output tensor is a tensor with dimensions described by 'size'
+// whose values are extracted from 'input' starting at the offsets in
+// 'begin'.
+//
+// *Requirements*:
+// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)
+//
+// Arguments:
+//
+// begin: begin[i] specifies the offset into the 'i'th dimension of
+// 'input' to slice from.
+// size: size[i] specifies the number of elements of the 'i'th dimension
+// of 'input' to slice. If size[i] is -1, all remaining elements in dimension
+// i are included in the slice (i.e. this is equivalent to setting
+// size[i] = input.dim_size(i) - begin[i]).
+func Slice(scope *Scope, input tf.Output, begin tf.Output, size tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Slice",
+ Input: []tf.Input{
+ input, begin, size,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// SizeAttr is an optional argument to Size.
+type SizeAttr func(optionalAttr)
+
+// SizeOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func SizeOutType(value tf.DataType) SizeAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Returns the size of a tensor.
+//
+// This operation returns an integer representing the number of elements in
+// `input`.
+//
+// For example:
+//
+// ```
+// # 't' is [[[1, 1,, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]]
+// size(t) ==> 12
+// ```
+func Size(scope *Scope, input tf.Output, optional ...SizeAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Size",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns the complex conjugate of a complex number.
//
// Given a tensor `input` of complex numbers, this operation returns a tensor of
@@ -1796,6 +2196,116 @@ func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num
return op.Output(0)
}
+// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign.
+type ResourceStridedSliceAssignAttr func(optionalAttr)
+
+// ResourceStridedSliceAssignBeginMask sets the optional begin_mask attribute to value.
+// If not specified, defaults to 0
+func ResourceStridedSliceAssignBeginMask(value int64) ResourceStridedSliceAssignAttr {
+ return func(m optionalAttr) {
+ m["begin_mask"] = value
+ }
+}
+
+// ResourceStridedSliceAssignEndMask sets the optional end_mask attribute to value.
+// If not specified, defaults to 0
+func ResourceStridedSliceAssignEndMask(value int64) ResourceStridedSliceAssignAttr {
+ return func(m optionalAttr) {
+ m["end_mask"] = value
+ }
+}
+
+// ResourceStridedSliceAssignEllipsisMask sets the optional ellipsis_mask attribute to value.
+// If not specified, defaults to 0
+func ResourceStridedSliceAssignEllipsisMask(value int64) ResourceStridedSliceAssignAttr {
+ return func(m optionalAttr) {
+ m["ellipsis_mask"] = value
+ }
+}
+
+// ResourceStridedSliceAssignNewAxisMask sets the optional new_axis_mask attribute to value.
+// If not specified, defaults to 0
+func ResourceStridedSliceAssignNewAxisMask(value int64) ResourceStridedSliceAssignAttr {
+ return func(m optionalAttr) {
+ m["new_axis_mask"] = value
+ }
+}
+
+// ResourceStridedSliceAssignShrinkAxisMask sets the optional shrink_axis_mask attribute to value.
+// If not specified, defaults to 0
+func ResourceStridedSliceAssignShrinkAxisMask(value int64) ResourceStridedSliceAssignAttr {
+ return func(m optionalAttr) {
+ m["shrink_axis_mask"] = value
+ }
+}
+
+// Assign `value` to the sliced l-value reference of `ref`.
+//
+// The values of `value` are assigned to the positions in the variable
+// `ref` that are selected by the slice parameters. The slice parameters
+// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`.
+//
+// NOTE this op currently does not support broadcasting and so `value`'s
+// shape must be exactly the shape produced by the slice of `ref`.
+//
+// Returns the created operation.
+func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...ResourceStridedSliceAssignAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceStridedSliceAssign",
+ Input: []tf.Input{
+ ref, begin, end, strides, value,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// ArgMaxAttr is an optional argument to ArgMax.
+type ArgMaxAttr func(optionalAttr)
+
+// ArgMaxOutputType sets the optional output_type attribute to value.
+// If not specified, defaults to DT_INT64
+func ArgMaxOutputType(value tf.DataType) ArgMaxAttr {
+ return func(m optionalAttr) {
+ m["output_type"] = value
+ }
+}
+
+// Returns the index with the largest value across dimensions of a tensor.
+//
+// Note that in case of ties the identity of the return value is not guaranteed.
+//
+// Arguments:
+//
+// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
+// Describes which dimension of the input Tensor to reduce across. For vectors,
+// use dimension = 0.
+func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ArgMax",
+ Input: []tf.Input{
+ input, dimension,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns which elements of x are finite.
//
// @compatibility(numpy)
@@ -7514,6 +8024,75 @@ func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.O
return scope.AddOperation(opspec)
}
+// Return the shape of s0 op s1 with broadcast.
+//
+// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the
+// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors.
+func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BroadcastArgs",
+ Input: []tf.Input{
+ s0, s1,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// DataFormatDimMapAttr is an optional argument to DataFormatDimMap.
+type DataFormatDimMapAttr func(optionalAttr)
+
+// DataFormatDimMapSrcFormat sets the optional src_format attribute to value.
+//
+// value: source data format.
+// If not specified, defaults to "NHWC"
+func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr {
+ return func(m optionalAttr) {
+ m["src_format"] = value
+ }
+}
+
+// DataFormatDimMapDstFormat sets the optional dst_format attribute to value.
+//
+// value: destination data format.
+// If not specified, defaults to "NCHW"
+func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr {
+ return func(m optionalAttr) {
+ m["dst_format"] = value
+ }
+}
+
+// Returns the dimension index in the destination data format given the one in
+//
+// the source data format.
+//
+// Arguments:
+// x: A Tensor with each element as a dimension index in source data format.
+// Must be in the range [-4, 4).
+//
+// Returns A Tensor with each element as a dimension index in destination data format.
+func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DataFormatDimMap",
+ Input: []tf.Input{
+ x,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign.
type ResourceApplyPowerSignAttr func(optionalAttr)
@@ -8470,47 +9049,6 @@ func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Out
return scope.AddOperation(opspec)
}
-// SizeAttr is an optional argument to Size.
-type SizeAttr func(optionalAttr)
-
-// SizeOutType sets the optional out_type attribute to value.
-// If not specified, defaults to DT_INT32
-func SizeOutType(value tf.DataType) SizeAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Returns the size of a tensor.
-//
-// This operation returns an integer representing the number of elements in
-// `input`.
-//
-// For example:
-//
-// ```
-// # 't' is [[[1, 1,, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]]
-// size(t) ==> 12
-// ```
-func Size(scope *Scope, input tf.Output, optional ...SizeAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Size",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
type ResourceScatterNdUpdateAttr func(optionalAttr)
@@ -13934,124 +14472,6 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f
return op.Output(0), op.Output(1), op.Output(2)
}
-// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
-type WholeFileReaderV2Attr func(optionalAttr)
-
-// WholeFileReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A Reader that outputs the entire contents of a file as a value.
-//
-// To use, enqueue filenames in a Queue. The output of ReaderRead will
-// be a filename (key) and the contents of that file (value).
-//
-// Returns The handle to reference the Reader.
-func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "WholeFileReaderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Transforms a tf.Example proto (as a string) into typed tensors.
-//
-// Arguments:
-// serialized: A vector containing a batch of binary serialized Example protos.
-// dense_defaults: A list of Tensors (some may be empty), whose length matches
-// the length of `dense_keys`. dense_defaults[j] provides default values
-// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
-// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
-// The input type is inferred from dense_defaults[j], even when it's empty.
-// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
-// then the shape of dense_defaults[j] must match that of dense_shapes[j].
-// If dense_shapes[j] has an undefined major dimension (variable strides dense
-// feature), dense_defaults[j] must contain a single element:
-// the padding element.
-// num_sparse: The number of sparse features to be parsed from the example. This
-// must match the lengths of `sparse_keys` and `sparse_types`.
-// sparse_keys: A list of `num_sparse` strings.
-// The keys expected in the Examples' features associated with sparse values.
-// dense_keys: The keys expected in the Examples' features associated with dense
-// values.
-// sparse_types: A list of `num_sparse` types; the data types of data in each
-// Feature given in sparse_keys.
-// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// dense_shapes: The shapes of data in each Feature given in dense_keys.
-// The length of this list must match the length of `dense_keys`. The
-// number of elements in the Feature corresponding to dense_key[j] must
-// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
-// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
-// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
-// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
-// D1, .., DN), where M is the number of blocks of elements of length
-// D1 * .... * DN, in the input.
-func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
- opspec := tf.OpSpec{
- Type: "ParseSingleExample",
- Input: []tf.Input{
- serialized, tf.OutputList(dense_defaults),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- return sparse_indices, sparse_values, sparse_shapes, dense_values
-}
-
// QuantizedConv2DAttr is an optional argument to QuantizedConv2D.
type QuantizedConv2DAttr func(optionalAttr)
@@ -15280,31 +15700,6 @@ func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output,
return op.Output(0), op.Output(1)
}
-// Constructs a tensor by tiling a given tensor.
-//
-// This operation creates a new tensor by replicating `input` `multiples` times.
-// The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements,
-// and the values of `input` are replicated `multiples[i]` times along the 'i'th
-// dimension. For example, tiling `[a b c d]` by `[2]` produces
-// `[a b c d a b c d]`.
-//
-// Arguments:
-// input: 1-D or higher.
-// multiples: 1-D. Length must be the same as the number of dimensions in `input`
-func Tile(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Tile",
- Input: []tf.Input{
- input, multiples,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// TakeManySparseFromTensorsMapAttr is an optional argument to TakeManySparseFromTensorsMap.
type TakeManySparseFromTensorsMapAttr func(optionalAttr)
@@ -16729,6 +17124,203 @@ func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...Ran
return op.Output(0)
}
+// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize.
+type QuantizeAndDequantizeAttr func(optionalAttr)
+
+// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value.
+// If not specified, defaults to true
+func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["signed_input"] = value
+ }
+}
+
+// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value.
+// If not specified, defaults to 8
+func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["num_bits"] = value
+ }
+}
+
+// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value.
+// If not specified, defaults to false
+func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["range_given"] = value
+ }
+}
+
+// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value.
+// If not specified, defaults to 0
+func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["input_min"] = value
+ }
+}
+
+// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value.
+// If not specified, defaults to 0
+func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["input_max"] = value
+ }
+}
+
+// Use QuantizeAndDequantizeV2 instead.
+//
+// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2
+func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "QuantizeAndDequantize",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns locations of nonzero / true values in a tensor.
+//
+// This operation returns the coordinates of true elements in `condition`. The
+// coordinates are returned in a 2-D tensor where the first dimension (rows)
+// represents the number of true elements, and the second dimension (columns)
+// represents the coordinates of the true elements. Keep in mind, the shape of
+// the output tensor can vary depending on how many true values there are in
+// `condition`. Indices are output in row-major order.
+//
+// For example:
+//
+// ```
+// # 'input' tensor is [[True, False]
+// # [True, False]]
+// # 'input' has two true values, so output has two coordinates.
+// # 'input' has rank of 2, so coordinates have two indices.
+// where(input) ==> [[0, 0],
+// [1, 0]]
+//
+// # `condition` tensor is [[[True, False]
+// # [True, False]]
+// # [[False, True]
+// # [False, True]]
+// # [[False, False]
+// # [False, True]]]
+// # 'input' has 5 true values, so output has 5 coordinates.
+// # 'input' has rank of 3, so coordinates have three indices.
+// where(input) ==> [[0, 0, 0],
+// [0, 1, 0],
+// [1, 0, 1],
+// [1, 1, 1],
+// [2, 1, 1]]
+//
+// # `condition` tensor is [[[1.5, 0.0]
+// # [-0.5, 0.0]]
+// # [[0.0, 0.25]
+// # [0.0, 0.75]]
+// # [[0.0, 0.0]
+// # [0.0, 0.01]]]
+// # 'input' has 5 nonzero values, so output has 5 coordinates.
+// # 'input' has rank of 3, so coordinates have three indices.
+// where(input) ==> [[0, 0, 0],
+// [0, 1, 0],
+// [1, 0, 1],
+// [1, 1, 1],
+// [2, 1, 1]]
+//
+// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j]
+// # [0.0 + 0.5j, 0.0 + 0.0j]]
+// # [[0.0 + 0.0j, 0.25 + 1.5j]
+// # [0.0 + 0.0j, 0.75 + 0.0j]]
+// # [[0.0 + 0.0j, 0.0 + 0.0j]
+// # [0.0 + 0.0j, 0.01 + 0.0j]]]
+// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates.
+// # 'input' has rank of 3, so coordinates have three indices.
+// where(input) ==> [[0, 0, 0],
+// [0, 1, 0],
+// [1, 0, 1],
+// [1, 1, 1],
+// [2, 1, 1]]
+// ```
+func Where(scope *Scope, condition tf.Output) (index tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Where",
+ Input: []tf.Input{
+ condition,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// QueueDequeueV2Attr is an optional argument to QueueDequeueV2.
+type QueueDequeueV2Attr func(optionalAttr)
+
+// QueueDequeueV2TimeoutMs sets the optional timeout_ms attribute to value.
+//
+// value: If the queue is empty, this operation will block for up to
+// timeout_ms milliseconds.
+// Note: This option is not supported yet.
+// If not specified, defaults to -1
+func QueueDequeueV2TimeoutMs(value int64) QueueDequeueV2Attr {
+ return func(m optionalAttr) {
+ m["timeout_ms"] = value
+ }
+}
+
+// Dequeues a tuple of one or more tensors from the given queue.
+//
+// This operation has k outputs, where k is the number of components
+// in the tuples stored in the given queue, and output i is the ith
+// component of the dequeued tuple.
+//
+// N.B. If the queue is empty, this operation will block until an element
+// has been dequeued (or 'timeout_ms' elapses, if specified).
+//
+// Arguments:
+// handle: The handle to a queue.
+// component_types: The type of each component in a tuple.
+//
+// Returns One or more tensors that were dequeued as a tuple.
+func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataType, optional ...QueueDequeueV2Attr) (components []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"component_types": component_types}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "QueueDequeueV2",
+ Input: []tf.Input{
+ handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+ scope.UpdateErr("QueueDequeueV2", err)
+ return
+ }
+ return components
+}
+
// RandomUniformIntAttr is an optional argument to RandomUniformInt.
type RandomUniformIntAttr func(optionalAttr)
@@ -22891,199 +23483,6 @@ func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
return op.Output(0)
}
-// StridedSliceAttr is an optional argument to StridedSlice.
-type StridedSliceAttr func(optionalAttr)
-
-// StridedSliceBeginMask sets the optional begin_mask attribute to value.
-//
-// value: a bitmask where a bit i being 1 means to ignore the begin
-// value and instead use the largest interval possible. At runtime
-// begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
-// `[-1, n-1]` if `stride[i] < 0`
-// If not specified, defaults to 0
-func StridedSliceBeginMask(value int64) StridedSliceAttr {
- return func(m optionalAttr) {
- m["begin_mask"] = value
- }
-}
-
-// StridedSliceEndMask sets the optional end_mask attribute to value.
-//
-// value: analogous to `begin_mask`
-// If not specified, defaults to 0
-func StridedSliceEndMask(value int64) StridedSliceAttr {
- return func(m optionalAttr) {
- m["end_mask"] = value
- }
-}
-
-// StridedSliceEllipsisMask sets the optional ellipsis_mask attribute to value.
-//
-// value: a bitmask where bit `i` being 1 means the `i`th
-// position is actually an ellipsis. One bit at most can be 1.
-// If `ellipsis_mask == 0`, then an implicit ellipsis mask of `1 << (m+1)`
-// is provided. This means that `foo[3:5] == foo[3:5, ...]`. An ellipsis
-// implicitly creates as many range specifications as necessary to fully
-// specify the sliced range for every dimension. For example for a 4-dimensional
-// tensor `foo` the slice `foo[2, ..., 5:8]` implies `foo[2, :, :, 5:8]`.
-// If not specified, defaults to 0
-func StridedSliceEllipsisMask(value int64) StridedSliceAttr {
- return func(m optionalAttr) {
- m["ellipsis_mask"] = value
- }
-}
-
-// StridedSliceNewAxisMask sets the optional new_axis_mask attribute to value.
-//
-// value: a bitmask where bit `i` being 1 means the `i`th
-// specification creates a new shape 1 dimension. For example
-// `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor.
-// If not specified, defaults to 0
-func StridedSliceNewAxisMask(value int64) StridedSliceAttr {
- return func(m optionalAttr) {
- m["new_axis_mask"] = value
- }
-}
-
-// StridedSliceShrinkAxisMask sets the optional shrink_axis_mask attribute to value.
-//
-// value: a bitmask where bit `i` implies that the `i`th
-// specification should shrink the dimensionality. begin and end
-// must imply a slice of size 1 in the dimension. For example in
-// python one might do `foo[:, 3, :]` which would result in
-// `shrink_axis_mask` being 2.
-// If not specified, defaults to 0
-func StridedSliceShrinkAxisMask(value int64) StridedSliceAttr {
- return func(m optionalAttr) {
- m["shrink_axis_mask"] = value
- }
-}
-
-// Return a strided slice from `input`.
-//
-// Note, most python users will want to use the Python `Tensor.__getitem__`
-// or `Variable.__getitem__` rather than this op directly.
-//
-// The goal of this op is to produce a new tensor with a subset of
-// the elements from the `n` dimensional `input` tensor. The subset is chosen using
-// a sequence of `m` sparse range specifications encoded into the arguments
-// of this function. Note, in some cases
-// `m` could be equal to `n`, but this need not be the case. Each
-// range specification entry can be one of the following:
-//
-// - An ellipsis (...). Ellipses are used to imply zero or more
-// dimensions of full-dimension selection and are produced using
-// `ellipsis_mask`. For example, `foo[...]` is the identity slice.
-//
-// - A new axis. This is used to insert a new shape=1 dimension and is
-// produced using `new_axis_mask`. For example, `foo[:, ...]` where
-// `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor.
-//
-//
-// - A range `begin:end:stride`. This is used to specify how much to choose from
-// a given dimension. `stride` can be any integer but 0. `begin` is an integer
-// which represents the index of the first value to select while `end` represents
-// the index of the last value to select. The number of values selected in each
-// dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`.
-// `begin` and `end` can be negative where `-1` is the last element, `-2` is
-// the second to last. `begin_mask` controls whether to replace the explicitly
-// given `begin` with an implicit effective value of `0` if `stride > 0` and
-// `-1` if `stride < 0`. `end_mask` is analogous but produces the number
-// required to create the largest open interval. For example, given a shape
-// `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do
-// not assume this is equivalent to `foo[0:-1]` which has an effective `begin`
-// and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the
-// first dimension of a tensor while dropping the last two (in the original
-// order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`.
-//
-// - A single index. This is used to keep only elements that have a given
-// index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a
-// shape `(6,)` tensor. This is encoded in `begin` and `end` and
-// `shrink_axis_mask`.
-//
-// Each conceptual range specification is encoded in the op's argument. This
-// encoding is best understand by considering a non-trivial example. In
-// particular,
-// `foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as
-//
-// ```
-// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0)
-// end = [2, 4, x, x, -3, x]
-// strides = [1, 1, x, x, -1, 1]
-// begin_mask = 1<<4 | 1 << 5 = 48
-// end_mask = 1<<5 = 32
-// ellipsis_mask = 1<<3 = 8
-// new_axis_mask = 1<<2 4
-// shrink_axis_mask = 1<<0
-// ```
-//
-// In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of
-// the slice becomes (2, 1, 5, 5, 2, 5).
-// Let us walk step by step through each argument specification.
-//
-// 1. The first argument in the example slice is turned into `begin = 1` and
-// `end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we
-// also set the appropriate bit in `shrink_axis_mask`.
-//
-// 2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have
-// zero bits contributed.
-//
-// 3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1
-// dimension in the final shape. Dummy values are contributed to begin,
-// end and stride, while the new_axis_mask bit is set.
-//
-// 4. `...` grab the full ranges from as many dimensions as needed to
-// fully specify a slice for every dimension of the input shape.
-//
-// 5. `:-3:-1` shows the use of negative indices. A negative index `i` associated
-// with a dimension that has shape `s` is converted to a positive index
-// `s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion
-// is done internally so begin, end and strides receive x, -3, and -1.
-// The appropriate begin_mask bit is set to indicate the start range is the
-// full range (ignoring the x).
-//
-// 6. `:` indicates that the entire contents of the corresponding dimension
-// is selected. This is equivalent to `::` or `0::1`. begin, end, and strides
-// receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
-// `end_mask` are also set.
-//
-// *Requirements*:
-// `0 != strides[i] for i in [0, m)`
-// `ellipsis_mask must be a power of two (only one ellipsis)`
-//
-// Arguments:
-//
-// begin: `begin[k]` specifies the offset into the `k`th range specification.
-// The exact dimension this corresponds to will be determined by context.
-// Out-of-bounds values will be silently clamped. If the `k`th bit of
-// `begin_mask` then `begin[k]` is ignored and the full range of the
-// appropriate dimension is used instead. Negative values causes indexing
-// to start from the highest element e.g. If `foo==[1,2,3]` then `foo[-1]==3`.
-// end: `end[i]` is like `begin` with the exception that `end_mask` is
-// used to determine full ranges.
-// strides: `strides[i]` specifies the increment in the `i`th specification
-// after extracting a given element. Negative indices will reverse
-// the original order. Out or range values are
-// clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0`
-func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, strides tf.Output, optional ...StridedSliceAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StridedSlice",
- Input: []tf.Input{
- input, begin, end, strides,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// PriorityQueueV2Attr is an optional argument to PriorityQueueV2.
type PriorityQueueV2Attr func(optionalAttr)
@@ -23233,116 +23632,6 @@ func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (value
return values
}
-// ArgMaxAttr is an optional argument to ArgMax.
-type ArgMaxAttr func(optionalAttr)
-
-// ArgMaxOutputType sets the optional output_type attribute to value.
-// If not specified, defaults to DT_INT64
-func ArgMaxOutputType(value tf.DataType) ArgMaxAttr {
- return func(m optionalAttr) {
- m["output_type"] = value
- }
-}
-
-// Returns the index with the largest value across dimensions of a tensor.
-//
-// Note that in case of ties the identity of the return value is not guaranteed.
-//
-// Arguments:
-//
-// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
-// Describes which dimension of the input Tensor to reduce across. For vectors,
-// use dimension = 0.
-func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ArgMax",
- Input: []tf.Input{
- input, dimension,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign.
-type ResourceStridedSliceAssignAttr func(optionalAttr)
-
-// ResourceStridedSliceAssignBeginMask sets the optional begin_mask attribute to value.
-// If not specified, defaults to 0
-func ResourceStridedSliceAssignBeginMask(value int64) ResourceStridedSliceAssignAttr {
- return func(m optionalAttr) {
- m["begin_mask"] = value
- }
-}
-
-// ResourceStridedSliceAssignEndMask sets the optional end_mask attribute to value.
-// If not specified, defaults to 0
-func ResourceStridedSliceAssignEndMask(value int64) ResourceStridedSliceAssignAttr {
- return func(m optionalAttr) {
- m["end_mask"] = value
- }
-}
-
-// ResourceStridedSliceAssignEllipsisMask sets the optional ellipsis_mask attribute to value.
-// If not specified, defaults to 0
-func ResourceStridedSliceAssignEllipsisMask(value int64) ResourceStridedSliceAssignAttr {
- return func(m optionalAttr) {
- m["ellipsis_mask"] = value
- }
-}
-
-// ResourceStridedSliceAssignNewAxisMask sets the optional new_axis_mask attribute to value.
-// If not specified, defaults to 0
-func ResourceStridedSliceAssignNewAxisMask(value int64) ResourceStridedSliceAssignAttr {
- return func(m optionalAttr) {
- m["new_axis_mask"] = value
- }
-}
-
-// ResourceStridedSliceAssignShrinkAxisMask sets the optional shrink_axis_mask attribute to value.
-// If not specified, defaults to 0
-func ResourceStridedSliceAssignShrinkAxisMask(value int64) ResourceStridedSliceAssignAttr {
- return func(m optionalAttr) {
- m["shrink_axis_mask"] = value
- }
-}
-
-// Assign `value` to the sliced l-value reference of `ref`.
-//
-// The values of `value` are assigned to the positions in the variable
-// `ref` that are selected by the slice parameters. The slice parameters
-// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`.
-//
-// NOTE this op currently does not support broadcasting and so `value`'s
-// shape must be exactly the shape produced by the slice of `ref`.
-//
-// Returns the created operation.
-func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...ResourceStridedSliceAssignAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceStridedSliceAssign",
- Input: []tf.Input{
- ref, begin, end, strides, value,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2.
type QueueEnqueueV2Attr func(optionalAttr)
@@ -26224,6 +26513,124 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true
return op.Output(0), op.Output(1), op.Output(2)
}
+// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
+type WholeFileReaderV2Attr func(optionalAttr)
+
+// WholeFileReaderV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the entire contents of a file as a value.
+//
+// To use, enqueue filenames in a Queue. The output of ReaderRead will
+// be a filename (key) and the contents of that file (value).
+//
+// Returns The handle to reference the Reader.
+func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "WholeFileReaderV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Transforms a tf.Example proto (as a string) into typed tensors.
+//
+// Arguments:
+// serialized: A vector containing a batch of binary serialized Example protos.
+// dense_defaults: A list of Tensors (some may be empty), whose length matches
+// the length of `dense_keys`. dense_defaults[j] provides default values
+// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
+// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
+// The input type is inferred from dense_defaults[j], even when it's empty.
+// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
+// then the shape of dense_defaults[j] must match that of dense_shapes[j].
+// If dense_shapes[j] has an undefined major dimension (variable strides dense
+// feature), dense_defaults[j] must contain a single element:
+// the padding element.
+// num_sparse: The number of sparse features to be parsed from the example. This
+// must match the lengths of `sparse_keys` and `sparse_types`.
+// sparse_keys: A list of `num_sparse` strings.
+// The keys expected in the Examples' features associated with sparse values.
+// dense_keys: The keys expected in the Examples' features associated with dense
+// values.
+// sparse_types: A list of `num_sparse` types; the data types of data in each
+// Feature given in sparse_keys.
+// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// dense_shapes: The shapes of data in each Feature given in dense_keys.
+// The length of this list must match the length of `dense_keys`. The
+// number of elements in the Feature corresponding to dense_key[j] must
+// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
+// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
+// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
+// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
+// D1, .., DN), where M is the number of blocks of elements of length
+// D1 * .... * DN, in the input.
+func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
+ opspec := tf.OpSpec{
+ Type: "ParseSingleExample",
+ Input: []tf.Input{
+ serialized, tf.OutputList(dense_defaults),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ return sparse_indices, sparse_values, sparse_shapes, dense_values
+}
+
// Elementwise computes the bitwise AND of `x` and `y`.
//
// The result will have those bits set, that are set in both `x` and `y`. The
@@ -27480,410 +27887,3 @@ func UniqueV2(scope *Scope, x tf.Output, axis tf.Output, optional ...UniqueV2Att
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Return a slice from 'input'.
-//
-// The output tensor is a tensor with dimensions described by 'size'
-// whose values are extracted from 'input' starting at the offsets in
-// 'begin'.
-//
-// *Requirements*:
-// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)
-//
-// Arguments:
-//
-// begin: begin[i] specifies the offset into the 'i'th dimension of
-// 'input' to slice from.
-// size: size[i] specifies the number of elements of the 'i'th dimension
-// of 'input' to slice. If size[i] is -1, all remaining elements in dimension
-// i are included in the slice (i.e. this is equivalent to setting
-// size[i] = input.dim_size(i) - begin[i]).
-func Slice(scope *Scope, input tf.Output, begin tf.Output, size tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Slice",
- Input: []tf.Input{
- input, begin, size,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// StridedSliceGradAttr is an optional argument to StridedSliceGrad.
-type StridedSliceGradAttr func(optionalAttr)
-
-// StridedSliceGradBeginMask sets the optional begin_mask attribute to value.
-// If not specified, defaults to 0
-func StridedSliceGradBeginMask(value int64) StridedSliceGradAttr {
- return func(m optionalAttr) {
- m["begin_mask"] = value
- }
-}
-
-// StridedSliceGradEndMask sets the optional end_mask attribute to value.
-// If not specified, defaults to 0
-func StridedSliceGradEndMask(value int64) StridedSliceGradAttr {
- return func(m optionalAttr) {
- m["end_mask"] = value
- }
-}
-
-// StridedSliceGradEllipsisMask sets the optional ellipsis_mask attribute to value.
-// If not specified, defaults to 0
-func StridedSliceGradEllipsisMask(value int64) StridedSliceGradAttr {
- return func(m optionalAttr) {
- m["ellipsis_mask"] = value
- }
-}
-
-// StridedSliceGradNewAxisMask sets the optional new_axis_mask attribute to value.
-// If not specified, defaults to 0
-func StridedSliceGradNewAxisMask(value int64) StridedSliceGradAttr {
- return func(m optionalAttr) {
- m["new_axis_mask"] = value
- }
-}
-
-// StridedSliceGradShrinkAxisMask sets the optional shrink_axis_mask attribute to value.
-// If not specified, defaults to 0
-func StridedSliceGradShrinkAxisMask(value int64) StridedSliceGradAttr {
- return func(m optionalAttr) {
- m["shrink_axis_mask"] = value
- }
-}
-
-// Returns the gradient of `StridedSlice`.
-//
-// Since `StridedSlice` cuts out pieces of its `input` which is size
-// `shape`, its gradient will have the same shape (which is passed here
-// as `shape`). The gradient will be zero in any element that the slice
-// does not select.
-//
-// Arguments are the same as StridedSliceGrad with the exception that
-// `dy` is the input gradient to be propagated and `shape` is the
-// shape of `StridedSlice`'s `input`.
-func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Output, strides tf.Output, dy tf.Output, optional ...StridedSliceGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StridedSliceGrad",
- Input: []tf.Input{
- shape, begin, end, strides, dy,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns the gradient of `Tile`.
-//
-// DEPRECATED at GraphDef version 3: TileGrad has been replaced with reduce_sum
-//
-// Since `Tile` takes an input and repeats the input `multiples` times
-// along each dimension, `TileGrad` takes in `multiples` and aggregates
-// each repeated tile of `input` into `output`.
-func TileGrad(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "TileGrad",
- Input: []tf.Input{
- input, multiples,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize.
-type QuantizeAndDequantizeAttr func(optionalAttr)
-
-// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value.
-// If not specified, defaults to true
-func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["signed_input"] = value
- }
-}
-
-// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value.
-// If not specified, defaults to 8
-func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["num_bits"] = value
- }
-}
-
-// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value.
-// If not specified, defaults to false
-func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["range_given"] = value
- }
-}
-
-// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value.
-// If not specified, defaults to 0
-func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["input_min"] = value
- }
-}
-
-// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value.
-// If not specified, defaults to 0
-func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["input_max"] = value
- }
-}
-
-// Use QuantizeAndDequantizeV2 instead.
-//
-// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2
-func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "QuantizeAndDequantize",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// QueueDequeueV2Attr is an optional argument to QueueDequeueV2.
-type QueueDequeueV2Attr func(optionalAttr)
-
-// QueueDequeueV2TimeoutMs sets the optional timeout_ms attribute to value.
-//
-// value: If the queue is empty, this operation will block for up to
-// timeout_ms milliseconds.
-// Note: This option is not supported yet.
-// If not specified, defaults to -1
-func QueueDequeueV2TimeoutMs(value int64) QueueDequeueV2Attr {
- return func(m optionalAttr) {
- m["timeout_ms"] = value
- }
-}
-
-// Dequeues a tuple of one or more tensors from the given queue.
-//
-// This operation has k outputs, where k is the number of components
-// in the tuples stored in the given queue, and output i is the ith
-// component of the dequeued tuple.
-//
-// N.B. If the queue is empty, this operation will block until an element
-// has been dequeued (or 'timeout_ms' elapses, if specified).
-//
-// Arguments:
-// handle: The handle to a queue.
-// component_types: The type of each component in a tuple.
-//
-// Returns One or more tensors that were dequeued as a tuple.
-func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataType, optional ...QueueDequeueV2Attr) (components []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"component_types": component_types}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "QueueDequeueV2",
- Input: []tf.Input{
- handle,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
- scope.UpdateErr("QueueDequeueV2", err)
- return
- }
- return components
-}
-
-// Returns locations of nonzero / true values in a tensor.
-//
-// This operation returns the coordinates of true elements in `condition`. The
-// coordinates are returned in a 2-D tensor where the first dimension (rows)
-// represents the number of true elements, and the second dimension (columns)
-// represents the coordinates of the true elements. Keep in mind, the shape of
-// the output tensor can vary depending on how many true values there are in
-// `condition`. Indices are output in row-major order.
-//
-// For example:
-//
-// ```
-// # 'input' tensor is [[True, False]
-// # [True, False]]
-// # 'input' has two true values, so output has two coordinates.
-// # 'input' has rank of 2, so coordinates have two indices.
-// where(input) ==> [[0, 0],
-// [1, 0]]
-//
-// # `condition` tensor is [[[True, False]
-// # [True, False]]
-// # [[False, True]
-// # [False, True]]
-// # [[False, False]
-// # [False, True]]]
-// # 'input' has 5 true values, so output has 5 coordinates.
-// # 'input' has rank of 3, so coordinates have three indices.
-// where(input) ==> [[0, 0, 0],
-// [0, 1, 0],
-// [1, 0, 1],
-// [1, 1, 1],
-// [2, 1, 1]]
-//
-// # `condition` tensor is [[[1.5, 0.0]
-// # [-0.5, 0.0]]
-// # [[0.0, 0.25]
-// # [0.0, 0.75]]
-// # [[0.0, 0.0]
-// # [0.0, 0.01]]]
-// # 'input' has 5 nonzero values, so output has 5 coordinates.
-// # 'input' has rank of 3, so coordinates have three indices.
-// where(input) ==> [[0, 0, 0],
-// [0, 1, 0],
-// [1, 0, 1],
-// [1, 1, 1],
-// [2, 1, 1]]
-//
-// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j]
-// # [0.0 + 0.5j, 0.0 + 0.0j]]
-// # [[0.0 + 0.0j, 0.25 + 1.5j]
-// # [0.0 + 0.0j, 0.75 + 0.0j]]
-// # [[0.0 + 0.0j, 0.0 + 0.0j]
-// # [0.0 + 0.0j, 0.01 + 0.0j]]]
-// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates.
-// # 'input' has rank of 3, so coordinates have three indices.
-// where(input) ==> [[0, 0, 0],
-// [0, 1, 0],
-// [1, 0, 1],
-// [1, 1, 1],
-// [2, 1, 1]]
-// ```
-func Where(scope *Scope, condition tf.Output) (index tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Where",
- Input: []tf.Input{
- condition,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// DataFormatDimMapAttr is an optional argument to DataFormatDimMap.
-type DataFormatDimMapAttr func(optionalAttr)
-
-// DataFormatDimMapSrcFormat sets the optional src_format attribute to value.
-//
-// value: source data format.
-// If not specified, defaults to "NHWC"
-func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr {
- return func(m optionalAttr) {
- m["src_format"] = value
- }
-}
-
-// DataFormatDimMapDstFormat sets the optional dst_format attribute to value.
-//
-// value: destination data format.
-// If not specified, defaults to "NCHW"
-func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr {
- return func(m optionalAttr) {
- m["dst_format"] = value
- }
-}
-
-// Returns the dimension index in the destination data format given the one in
-//
-// the source data format.
-//
-// Arguments:
-// x: A Tensor with each element as a dimension index in source data format.
-// Must be in the range [-4, 4).
-//
-// Returns A Tensor with each element as a dimension index in destination data format.
-func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DataFormatDimMap",
- Input: []tf.Input{
- x,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Return the shape of s0 op s1 with broadcast.
-//
-// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the
-// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors.
-func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BroadcastArgs",
- Input: []tf.Input{
- s0, s1,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
-//
-// This is typically used by gradient computations for a broadcasting operation.
-func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output, r1 tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BroadcastGradientArgs",
- Input: []tf.Input{
- s0, s1,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 7296205e24..1be4c838f3 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -387,15 +387,3 @@ genrule(
cmd = "cp $< $@",
output_to_bindir = 1,
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ae7e3e73ae..d299389a77 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -73,6 +73,7 @@ py_library(
deps = [
":array_ops",
":bitwise_ops",
+ ":boosted_trees_ops",
":check_ops",
":client",
":client_testlib",
@@ -298,6 +299,7 @@ cc_library(
srcs = ["util/util.cc"],
hdrs = ["util/util.h"],
deps = [
+ ":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//util/python:python_headers",
@@ -1374,6 +1376,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "boosted_trees_ops_gen",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "summary_ops_gen",
visibility = ["//tensorflow:__subpackages__"],
deps = ["//tensorflow/core:summary_ops_op_lib"],
@@ -1623,6 +1633,19 @@ py_library(
)
py_library(
+ name = "boosted_trees_ops",
+ srcs = ["ops/boosted_trees_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":boosted_trees_ops_gen",
+ ":framework",
+ ":ops",
+ ":training",
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ ],
+)
+
+py_library(
name = "sets",
srcs = [
"ops/sets.py",
@@ -3144,6 +3167,8 @@ tf_proto_library(
srcs = ["framework/cpp_shape_inference.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
+ # TODO(b/74620627): remove when _USE_C_SHAPES is removed
+ visibility = ["//tensorflow:internal"],
)
py_test(
@@ -4382,18 +4407,6 @@ py_test(
],
)
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cuda_py_test(
name = "accumulate_n_benchmark",
size = "large",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 3346937904..ab1d01a835 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -98,6 +98,8 @@ from tensorflow.python.summary import summary
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
+# Import boosted trees ops to make sure the ops are registered (but unused).
+from tensorflow.python.ops import gen_boosted_trees_ops as _gen_boosted_trees_ops
# Import cudnn rnn ops to make sure their ops are registered.
from tensorflow.python.ops import gen_cudnn_rnn_ops as _
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index e88fc0c01a..70a3d032f4 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -723,6 +723,7 @@ def TF_Reset(target, containers=None, config=None):
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
%unignore ExtendSession;
+%unignore ResourceHandleShapeAndType;
%include "tensorflow/python/client/tf_session_helper.h"
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index b5bee36dcd..3e08c1587e 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -15,15 +15,3 @@ py_library(
"//tensorflow/python/data/ops:readers",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 8b8adefa65..ed0c11e6c1 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -367,15 +367,3 @@ tf_py_test(
"no_windows",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 3119ab0037..fa2e86eab1 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -59,15 +59,3 @@ py_library(
"//tensorflow/python/eager:context",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index c0a6283be4..8729e085a3 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -2043,6 +2043,8 @@ class PrefetchDataset(Dataset):
"""See `Dataset.prefetch()` for details."""
super(PrefetchDataset, self).__init__()
self._input_dataset = input_dataset
+ if buffer_size is None:
+ buffer_size = -1 # This is the sentinel for auto-tuning.
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index b1bdbdab37..0fc32d51b9 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -109,15 +109,3 @@ py_test(
"//tensorflow/python:util",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 512d292ee2..4195586313 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -1095,15 +1095,3 @@ sh_test(
":offline_analyzer",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 0e089a26eb..8c0d3feece 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -398,21 +398,6 @@ py_test(
],
)
-# -----------------------------------------------------------------------------
-# Google-internal targets.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "imperative_grad",
srcs = ["imperative_grad.py"],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index c54a5a1445..209b012621 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -646,6 +646,13 @@ _default_vspace = imperative_grad.VSpace(
ones=_ones)
+def _handle_or_self(x):
+ """If x is ResourceVariable, return its handle, else x."""
+ if isinstance(x, resource_variable_ops.ResourceVariable):
+ x = x.handle
+ return x
+
+
@tf_export("GradientTape")
class GradientTape(object):
"""Record operations for automatic differentiation.
@@ -723,9 +730,7 @@ class GradientTape(object):
tensor: a Tensor or list of Tensors.
"""
for t in nest.flatten(tensor):
- if isinstance(t, resource_variable_ops.ResourceVariable):
- t = t.handle
- tape.watch(t)
+ tape.watch(_handle_or_self(t))
def watched_variables(self):
# Sorting variables by id, which is monotonically increasing in construction
@@ -739,14 +744,15 @@ class GradientTape(object):
Args:
target: Tensor to be differentiated.
- sources: a list of Tensors or Variables. `target` will be differentiated
- against elements in `sources`.
+ sources: a list or nested structure of Tensors or Variables. `target`
+ will be differentiated against elements in `sources`.
output_gradients: a list of gradients, one for each element of
target. Defaults to None.
Returns:
- a list of Tensors (or IndexedSlices, or None), one for each element in
- `sources`.
+ a list or nested structure of Tensors (or IndexedSlices, or None),
+ one for each element in `sources`. Returned structure is the same as
+ the structure of `sources`.
Raises:
RuntimeError: if called inside the context of the tape, or if called more
@@ -756,12 +762,15 @@ class GradientTape(object):
raise RuntimeError("GradientTape.gradient can only be called once "
"on non-persistent tapes, and "
"only when the context manager has exited.")
- sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
- else x
- for x in sources]
- grad = imperative_grad.imperative_grad(
- _default_vspace, self._tape, [target], sources,
+ flat_sources = nest.flatten(sources)
+ flat_sources = [_handle_or_self(x) for x in flat_sources]
+
+ flat_grad = imperative_grad.imperative_grad(
+ _default_vspace, self._tape, [target], flat_sources,
output_gradients=output_gradients)
+
if not self._persistent:
self._tape = None
+
+ grad = nest.pack_sequence_as(sources, flat_grad)
return grad
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f04d89a6d9..991b4dbe7a 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -371,6 +371,53 @@ class BackpropTest(test.TestCase):
@test_util.assert_no_new_tensors
@test_util.run_in_graph_and_eager_modes()
+ def testGradientTapeRepeatedSource(self):
+ with backprop.GradientTape(persistent=False) as g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = 2 * x
+ grad = g.gradient(target=y, sources=[x, x])
+ self.assertEqual(self.evaluate(grad), [2.0, 2.0])
+
+ @test_util.assert_no_new_tensors
+ @test_util.run_in_graph_and_eager_modes()
+ def testPersistentGradientTapeRepeatedSource(self):
+ with backprop.GradientTape(persistent=True) as g:
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(5.0)
+ g.watch(x)
+ g.watch(y)
+ z = x * x + x * y
+ grad = g.gradient(target=z, sources=[x, x])
+ self.assertEqual(self.evaluate(grad), [11.0, 11.0])
+ grad = g.gradient(target=z, sources=[y, x])
+ self.assertEqual(self.evaluate(grad), [3.0, 11.0])
+
+ @test_util.assert_no_new_tensors
+ @test_util.run_in_graph_and_eager_modes()
+ def testGradientTapeStructure(self):
+ with backprop.GradientTape(persistent=True) as g:
+ # Using different constant values because constant tensors are
+ # cached, leading to a different gradient then what one might expect.
+ x1 = constant_op.constant(3.0)
+ x2 = constant_op.constant(3.1)
+ x3 = constant_op.constant(3.2)
+ g.watch(x1)
+ g.watch(x2)
+ g.watch(x3)
+ y = x1 + 2 * x2 + 3 * x3
+ self.assertEqual(self.evaluate(g.gradient(y, x1)), [1.0])
+ self.assertEqual(self.evaluate(g.gradient(y, (x1,))), (1.0,))
+ self.assertEqual(self.evaluate(g.gradient(y, (x1, x2))), (1.0, 2.0))
+ self.assertEqual(self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])),
+ [(1.0, 2.0), (2.0, 3.0)])
+ self.assertEqual(self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
+ (1.0, 2.0, [1.0, 3.0]))
+ self.assertEqual(self.evaluate(g.gradient(y, [x1, {'x2': x2, 'x3': x3}])),
+ [1.0, {'x2': 2.0, 'x3': 3.0}])
+
+ @test_util.assert_no_new_tensors
+ @test_util.run_in_graph_and_eager_modes()
def testGradientTape(self):
with backprop.GradientTape() as g:
x = constant_op.constant(3.0)
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 6c9a14730c..8c1bb06bc3 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -232,7 +232,7 @@ class Context(object):
pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
opts, self._device_policy)
if self._execution_mode == ASYNC:
- pywrap_tensorflow.TFE_ContextOptionsSetAsync(True)
+ pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 6ebf5b2481..5f19f64846 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -97,6 +97,14 @@ class TFETest(test_util.TensorFlowTestCase):
self.assertTrue(has_cpu_device)
del ctx
+ def testAsyncBasic(self):
+ ctx = context.Context(execution_mode=context.ASYNC)
+ has_cpu_device = False
+ for x in ctx.devices():
+ has_cpu_device = has_cpu_device or 'CPU' in x
+ self.assertTrue(has_cpu_device)
+ del ctx
+
def testRunMetadata(self):
context.enable_run_metadata()
t = constant_op.constant(1.0)
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 55ba509065..8a398f6447 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1372,11 +1372,15 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
}
if (!result.empty()) {
PyObject* py_result = PyList_New(result.size());
+ tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
for (int i = 0; i < result.size(); ++i) {
if (result[i] == nullptr) {
Py_INCREF(Py_None);
result[i] = Py_None;
+ } else if (seen_results.find(result[i]) != seen_results.end()) {
+ Py_INCREF(result[i]);
}
+ seen_results.insert(result[i]);
PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
}
return py_result;
@@ -1405,16 +1409,33 @@ bool CheckInputsOk(PyObject* seq, int start_index,
PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
if (!op_def.input_arg(i).number_attr().empty() ||
!op_def.input_arg(i).type_list_attr().empty()) {
- // This item should be a list input.
- if (!PyList_Check(item)) return false;
- for (Py_ssize_t j = 0; j < PyList_Size(item); j++) {
- PyObject* inner_item = PyList_GET_ITEM(item, j);
+ // This item should be a seq input.
+ if (!PySequence_Check(item)) {
+ VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
+ << "\", Input \"" << op_def.input_arg(i).name()
+ << "\" since we expected a sequence, but got "
+ << item->ob_type->tp_name;
+ return false;
+ }
+ for (Py_ssize_t j = 0; j < PySequence_Fast_GET_SIZE(item); j++) {
+ PyObject* inner_item = PySequence_Fast_GET_ITEM(item, j);
if (!EagerTensor_CheckExact(inner_item) &&
!CheckResourceVariable(inner_item)) {
+ VLOG(1)
+ << "Falling back to slow path for Op \"" << op_def.name()
+ << "\", Input \"" << op_def.input_arg(i).name() << "\", Index "
+ << j
+ << " since we expected an EagerTensor/ResourceVariable, but got "
+ << inner_item->ob_type->tp_name;
return false;
}
}
} else if (!EagerTensor_CheckExact(item) && !CheckResourceVariable(item)) {
+ VLOG(1)
+ << "Falling back to slow path for Op \"" << op_def.name()
+ << "\", Input \"" << op_def.input_arg(i).name()
+ << "\" since we expected an EagerTensor/ResourceVariable, but got "
+ << item->ob_type->tp_name;
return false;
}
}
@@ -1726,11 +1747,11 @@ const char* GetDeviceName(PyObject* py_device_name) {
return nullptr;
}
-bool RaiseIfNotPyList(PyObject* list, const string& attr_name) {
- if (!PyList_Check(list)) {
+bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
+ if (!PySequence_Check(seq)) {
PyErr_SetString(PyExc_TypeError,
- Printf("expected a list for attr %s, got %s instead",
- attr_name.data(), list->ob_type->tp_name)
+ Printf("expected a sequence for attr %s, got %s instead",
+ attr_name.data(), seq->ob_type->tp_name)
.data());
return false;
@@ -1894,6 +1915,9 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
py_attr_value, &attr_list_sizes, status);
if (TF_GetCode(status) != TF_OK) {
+ VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
+ << "\" since we are unable to set the value for attr \""
+ << attr.name() << "\" due to: " << TF_Message(status);
RaiseFallbackException(TF_Message(status));
return nullptr;
}
@@ -1940,8 +1964,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
if (!input_arg.number_attr().empty()) {
// The item is a homogeneous list.
- if (!RaiseIfNotPyList(input, input_arg.number_attr())) return nullptr;
- Py_ssize_t len = PyList_Size(input);
+ if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
+ Py_ssize_t len = PySequence_Fast_GET_SIZE(input);
TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
if (op_exec_info.run_callbacks) {
@@ -1953,15 +1977,15 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
if (len > 0) {
// First item adds the type attr.
- if (!AddInputToOp(op_exec_info, PyList_GET_ITEM(input, 0), &input_arg,
- flattened_attrs.get(), flattened_inputs.get(), op,
- status)) {
+ if (!AddInputToOp(op_exec_info, PySequence_Fast_GET_ITEM(input, 0),
+ &input_arg, flattened_attrs.get(),
+ flattened_inputs.get(), op, status)) {
return nullptr;
}
for (Py_ssize_t j = 1; j < len; j++) {
// Since the list is homogeneous, we don't need to re-add the attr.
- if (!AddInputToOp(op_exec_info, PyList_GET_ITEM(input, j),
+ if (!AddInputToOp(op_exec_info, PySequence_Fast_GET_ITEM(input, j),
nullptr /* input_arg */,
nullptr /* flattened_attrs */,
flattened_inputs.get(), op, status)) {
@@ -1971,16 +1995,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
}
} else if (!input_arg.type_list_attr().empty()) {
// The item is a heterogeneous list.
- if (!RaiseIfNotPyList(input, input_arg.type_list_attr())) return nullptr;
+ if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
+ return nullptr;
+ }
const string& attr_name = input_arg.type_list_attr();
- Py_ssize_t len = PyList_Size(input);
+ Py_ssize_t len = PySequence_Fast_GET_SIZE(input);
tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
PyObject* py_attr_value = nullptr;
if (op_exec_info.run_callbacks) {
py_attr_value = PyTuple_New(len);
}
for (Py_ssize_t j = 0; j < len; j++) {
- PyObject* py_input = PyList_GET_ITEM(input, j);
+ PyObject* py_input = PySequence_Fast_GET_ITEM(input, j);
tensorflow::Safe_PyObjectPtr py_eager_tensor;
if (!ConvertToTensor(op_exec_info, py_input, &py_eager_tensor,
status)) {
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 5afb5a7dd5..f93bc221cc 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -9,24 +9,13 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "estimator_py",
srcs = ["estimator_lib.py"],
srcs_version = "PY2AND3",
deps = [
":baseline",
+ ":boosted_trees",
":dnn",
":dnn_linear_combined",
":estimator",
@@ -252,6 +241,53 @@ py_test(
)
py_library(
+ name = "boosted_trees",
+ srcs = ["canned/boosted_trees.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":estimator",
+ ":head",
+ ":model_fn",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
+ ],
+)
+
+py_test(
+ name = "boosted_trees_test",
+ size = "medium",
+ srcs = ["canned/boosted_trees_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":boosted_trees",
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resources",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/feature_column",
+ ],
+)
+
+py_library(
name = "dnn",
srcs = ["canned/dnn.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
new file mode 100644
index 0000000000..a9bbabd598
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -0,0 +1,736 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Estimator classes for BoostedTrees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.summary import summary
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+from tensorflow.python.util.tf_export import tf_export
+
+TreeHParams = collections.namedtuple(
+ 'TreeHParams',
+ ['n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity'])
+
+_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
+_HOLD_FOR_MULTI_DIM_SUPPORT = object()
+
+
+def _get_transformed_features(features, feature_columns):
+ """Gets the transformed features from features/feature_columns pair.
+
+ Args:
+ features: a dicionary of name to Tensor.
+ feature_columns: a list/set of tf.feature_column.
+
+ Returns:
+ result_features: a list of the transformed features, sorted by the name.
+ num_buckets: the maximum number of buckets across bucketized_columns.
+
+ Raises:
+ ValueError: when unsupported features/columns are tried.
+ """
+ num_buckets = 1
+ # pylint:disable=protected-access
+ for fc in feature_columns:
+ if isinstance(fc, feature_column_lib._BucketizedColumn):
+ # N boundaries creates (N+1) buckets.
+ num_buckets = max(num_buckets, len(fc.boundaries) + 1)
+ else:
+ raise ValueError('For now, only bucketized_column is supported but '
+ 'got: {}'.format(fc))
+ transformed = feature_column_lib._transform_features(features,
+ feature_columns)
+ # pylint:enable=protected-access
+ result_features = []
+ for column in sorted(transformed, key=lambda tc: tc.name):
+ source_name = column.source_column.name
+ squeezed_tensor = array_ops.squeeze(transformed[column], axis=1)
+ if len(squeezed_tensor.shape) > 1:
+ raise ValueError('For now, only supports features equivalent to rank 1 '
+ 'but column `{}` got: {}'.format(
+ source_name, features[source_name].shape))
+ result_features.append(squeezed_tensor)
+ return result_features, num_buckets
+
+
+def _keep_as_local_variable(tensor, name=None):
+ """Stores a tensor as a local Variable for faster read."""
+ return variable_scope.variable(
+ initial_value=tensor,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ validate_shape=False,
+ name=name)
+
+
+class _CacheTrainingStatesUsingHashTable(object):
+ """Caching logits, etc. using MutableHashTable."""
+
+ def __init__(self, example_ids, logits_dimension):
+ """Creates a cache with the given configuration.
+
+ It maintains a MutableDenseHashTable for all values.
+ The API lookup() and insert() would have those specs,
+ tree_ids: shape=[batch_size], dtype=int32
+ node_ids: shape=[batch_size], dtype=int32
+ logits: shape=[batch_size, logits_dimension], dtype=float32
+ However in the MutableDenseHashTable, ids are bitcasted into float32 and
+ all values are concatenated as a single tensor (of float32).
+
+ Hence conversion happens internally before inserting to the HashTable and
+ after lookup from it.
+
+ Args:
+ example_ids: a Rank 1 tensor to be used as a key of the cache.
+ logits_dimension: a constant (int) for the dimension of logits.
+
+ Raises:
+ ValueError: if example_ids is other than int64 or string.
+ """
+ if dtypes.as_dtype(dtypes.int64).is_compatible_with(example_ids.dtype):
+ empty_key = -1 << 62
+ elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
+ empty_key = ''
+ else:
+ raise ValueError('Unsupported example_id_feature dtype %s.',
+ example_ids.dtype)
+ # Cache holds latest <tree_id, node_id, logits> for each example.
+ # tree_id and node_id are both int32 but logits is a float32.
+ # To reduce the overhead, we store all of them together as float32 and
+ # bitcast the ids to int32.
+ self._table_ref = lookup_ops.mutable_dense_hash_table_v2(
+ empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3])
+ self._example_ids = example_ids
+ self._logits_dimension = logits_dimension
+
+ def lookup(self):
+ """Returns cached_tree_ids, cached_node_ids, cached_logits."""
+ cached_tree_ids, cached_node_ids, cached_logits = array_ops.split(
+ lookup_ops.lookup_table_find_v2(
+ self._table_ref, self._example_ids, default_value=[0.0, 0.0, 0.0]),
+ [1, 1, self._logits_dimension],
+ axis=1)
+ cached_tree_ids = array_ops.squeeze(
+ array_ops.bitcast(cached_tree_ids, dtypes.int32))
+ cached_node_ids = array_ops.squeeze(
+ array_ops.bitcast(cached_node_ids, dtypes.int32))
+ return (cached_tree_ids, cached_node_ids, cached_logits)
+
+ def insert(self, tree_ids, node_ids, logits):
+ """Inserts values and returns the op."""
+ insert_op = lookup_ops.lookup_table_insert_v2(
+ self._table_ref, self._example_ids,
+ array_ops.concat(
+ [
+ array_ops.expand_dims(
+ array_ops.bitcast(tree_ids, dtypes.float32), 1),
+ array_ops.expand_dims(
+ array_ops.bitcast(node_ids, dtypes.float32), 1),
+ logits,
+ ],
+ axis=1,
+ name='value_concat_for_cache_insert'))
+ return insert_op
+
+
+class _CacheTrainingStatesUsingVariables(object):
+ """Caching logits, etc. using Variables."""
+
+ def __init__(self, batch_size, logits_dimension):
+ """Creates a cache with the given configuration.
+
+ It maintains three variables, tree_ids, node_ids, logits, for caching.
+ tree_ids: shape=[batch_size], dtype=int32
+ node_ids: shape=[batch_size], dtype=int32
+ logits: shape=[batch_size, logits_dimension], dtype=float32
+
+ Note, this can be used only with in-memory data setting.
+
+ Args:
+ batch_size: `int`, the size of the cache.
+ logits_dimension: a constant (int) for the dimension of logits.
+ """
+ self._logits_dimension = logits_dimension
+ self._tree_ids = _keep_as_local_variable(
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ name='tree_ids_cache')
+ self._node_ids = _keep_as_local_variable(
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ name='node_ids_cache')
+ self._logits = _keep_as_local_variable(
+ array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
+ name='logits_cache')
+
+ def lookup(self):
+ """Returns cached_tree_ids, cached_node_ids, cached_logits."""
+ return (self._tree_ids, self._node_ids, self._logits)
+
+ def insert(self, tree_ids, node_ids, logits):
+ """Inserts values and returns the op."""
+ return control_flow_ops.group(
+ [
+ self._tree_ids.assign(tree_ids),
+ self._node_ids.assign(node_ids),
+ self._logits.assign(logits)
+ ],
+ name='cache_insert')
+
+
+class StopAtAttemptsHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop at the number of trees."""
+
+ def __init__(self, num_finalized_trees_tensor, num_attempted_layers_tensor,
+ max_trees, max_depth):
+ self._num_finalized_trees_tensor = num_finalized_trees_tensor
+ self._num_attempted_layers_tensor = num_attempted_layers_tensor
+ self._max_trees = max_trees
+ self._max_depth = max_depth
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs(
+ [self._num_finalized_trees_tensor, self._num_attempted_layers_tensor])
+
+ def after_run(self, run_context, run_values):
+ num_finalized_trees, num_attempted_layers = run_values.results
+ if (num_finalized_trees >= self._max_trees or
+ 1.0 * num_attempted_layers / self._max_depth > 2 * self._max_trees):
+ run_context.request_stop()
+
+
+class StopAtNumTreesHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop at the number of trees."""
+
+ def __init__(self, num_trees_tensor, max_trees):
+ self._num_trees_tensor = num_trees_tensor
+ self._max_trees = max_trees
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs(self._num_trees_tensor)
+
+ def after_run(self, run_context, run_values):
+ num_trees = run_values.results
+ if num_trees > self._max_trees:
+ run_context.request_stop()
+
+
+def _bt_model_fn(
+ features,
+ labels,
+ mode,
+ head,
+ feature_columns,
+ tree_hparams,
+ n_batches_per_layer,
+ config,
+ closed_form_grad_and_hess_fn=None,
+ example_id_column_name=None,
+ # TODO(youngheek): replace this later using other options.
+ train_in_memory=False,
+ name='TreeEnsembleModel'):
+ """Gradient Boosted Decision Tree model_fn.
+
+ Args:
+ features: dict of `Tensor`.
+ labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
+ dtype `int32` or `int64` in the range `[0, n_classes)`.
+ mode: Defines whether this is training, evaluation or prediction.
+ See `ModeKeys`.
+ head: A `head_lib._Head` instance.
+ feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
+ tree_hparams: TODO. collections.namedtuple for hyper parameters.
+ n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
+ least n_batches_per_layer accumulations.
+ config: `RunConfig` object to configure the runtime settings.
+ closed_form_grad_and_hess_fn: a function that accepts logits and labels
+ and returns gradients and hessians. By default, they are created by
+ tf.gradients() from the loss.
+ example_id_column_name: Name of the feature for a unique ID per example.
+ Currently experimental -- not exposed to public API.
+ train_in_memory: `bool`, when true, it assumes the dataset is in memory,
+ i.e., input_fn should return the entire dataset as a single batch, and
+ also n_batches_per_layer should be set as 1.
+ name: Name to use for the model.
+
+ Returns:
+ An `EstimatorSpec` instance.
+
+ Raises:
+ ValueError: mode or params are invalid, or features has the wrong type.
+ """
+ is_single_machine = (config.num_worker_replicas == 1)
+ if train_in_memory:
+ assert n_batches_per_layer == 1, (
+ 'When train_in_memory is enabled, input_fn should return the entire '
+ 'dataset as a single batch, and n_batches_per_layer should be set as '
+ '1.')
+ worker_device = control_flow_ops.no_op().device
+ # maximum number of splits possible in the whole tree =2^(D-1)-1
+ # TODO(youngheek): perhaps storage could be optimized by storing stats with
+ # the dimension max_splits_per_layer, instead of max_splits (for the entire
+ # tree).
+ max_splits = (1 << tree_hparams.max_depth) - 1
+ with ops.name_scope(name) as name:
+ # Prepare.
+ global_step = training_util.get_or_create_global_step()
+ input_feature_list, num_buckets = _get_transformed_features(
+ features, feature_columns)
+ if train_in_memory and mode == model_fn.ModeKeys.TRAIN:
+ input_feature_list = [
+ _keep_as_local_variable(feature) for feature in input_feature_list
+ ]
+ num_features = len(input_feature_list)
+
+ cache = None
+ if mode == model_fn.ModeKeys.TRAIN:
+ if train_in_memory and is_single_machine: # maybe just train_in_memory?
+ batch_size = array_ops.shape(input_feature_list[0])[0]
+ cache = _CacheTrainingStatesUsingVariables(batch_size,
+ head.logits_dimension)
+ elif example_id_column_name:
+ example_ids = features[example_id_column_name]
+ cache = _CacheTrainingStatesUsingHashTable(example_ids,
+ head.logits_dimension)
+
+ # Create Ensemble resources.
+ if is_single_machine:
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ local_tree_ensemble = tree_ensemble
+ ensemble_reload = control_flow_ops.no_op()
+ else:
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ with ops.device(worker_device):
+ local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ name=name + '_local', is_local=True)
+ # TODO(soroush): Do partial updates if this becomes a bottleneck.
+ ensemble_reload = local_tree_ensemble.deserialize(
+ *tree_ensemble.serialize())
+
+ # Create logits.
+ if mode != model_fn.ModeKeys.TRAIN:
+ logits = boosted_trees_ops.predict(
+ tree_ensemble_handle=local_tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension,
+ max_depth=tree_hparams.max_depth)
+ else:
+ if cache:
+ cached_tree_ids, cached_node_ids, cached_logits = cache.lookup()
+ else:
+ # Always start from the beginning when no cache is set up.
+ batch_size = array_ops.shape(input_feature_list[0])[0]
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ array_ops.zeros(
+ [batch_size, head.logits_dimension], dtype=dtypes.float32))
+ with ops.control_dependencies([ensemble_reload]):
+ (stamp_token, num_trees, num_finalized_trees,
+ num_attempted_layers) = local_tree_ensemble.get_states()
+ summary.scalar('ensemble/num_trees', num_trees)
+ summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+ summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+
+ partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
+ tree_ensemble_handle=local_tree_ensemble.resource_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension,
+ max_depth=tree_hparams.max_depth)
+ logits = cached_logits + partial_logits
+
+ # Create training graph.
+ def _train_op_fn(loss):
+ """Run one training iteration."""
+ train_op = []
+ if cache:
+ train_op.append(cache.insert(tree_ids, node_ids, logits))
+ if closed_form_grad_and_hess_fn:
+ gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
+ else:
+ gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
+ hessians = gradients_impl.gradients(
+ gradients, logits, name='Hessians')[0]
+ stats_summary_list = [
+ array_ops.squeeze(
+ boosted_trees_ops.make_stats_summary(
+ node_ids=node_ids,
+ gradients=gradients,
+ hessians=hessians,
+ bucketized_features_list=[input_feature_list[f]],
+ max_splits=max_splits,
+ num_buckets=num_buckets),
+ axis=0) for f in range(num_features)
+ ]
+
+ def grow_tree_from_stats_summaries(stats_summary_list):
+ """Updates ensemble based on the best gains from stats summaries."""
+ (node_ids_per_feature, gains_list, thresholds_list,
+ left_node_contribs_list, right_node_contribs_list) = (
+ boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range=array_ops.stack([
+ math_ops.reduce_min(node_ids),
+ math_ops.reduce_max(node_ids)
+ ]),
+ stats_summary_list=stats_summary_list,
+ l1=tree_hparams.l1,
+ l2=tree_hparams.l2,
+ tree_complexity=tree_hparams.tree_complexity,
+ max_splits=max_splits))
+ grow_op = boosted_trees_ops.update_ensemble(
+ # Confirm if local_tree_ensemble or tree_ensemble should be used.
+ tree_ensemble.resource_handle,
+ feature_ids=math_ops.range(0, num_features, dtype=dtypes.int32),
+ node_ids=node_ids_per_feature,
+ gains=gains_list,
+ thresholds=thresholds_list,
+ left_node_contribs=left_node_contribs_list,
+ right_node_contribs=right_node_contribs_list,
+ learning_rate=tree_hparams.learning_rate,
+ max_depth=tree_hparams.max_depth,
+ pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
+ return grow_op
+
+ if train_in_memory and is_single_machine:
+ train_op.append(state_ops.assign_add(global_step, 1))
+ train_op.append(grow_tree_from_stats_summaries(stats_summary_list))
+ else:
+ summary_accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of gradients and hessians (the last dimension).
+ shape=[num_features, max_splits, num_buckets, 2],
+ shared_name='stats_summary_accumulator')
+ apply_grad = summary_accumulator.apply_grad(
+ array_ops.stack(stats_summary_list, axis=0), stamp_token)
+
+ def grow_tree_from_accumulated_summaries_fn():
+ """Updates the tree with the best layer from accumulated summaries."""
+ # Take out the accumulated summaries from the accumulator and grow.
+ stats_summary_list = array_ops.unstack(
+ summary_accumulator.take_grad(1), axis=0)
+ grow_op = grow_tree_from_stats_summaries(stats_summary_list)
+ return grow_op
+
+ with ops.control_dependencies([apply_grad]):
+ train_op.append(state_ops.assign_add(global_step, 1))
+ if config.is_chief:
+ train_op.append(
+ control_flow_ops.cond(
+ math_ops.greater_equal(
+ summary_accumulator.num_accumulated(),
+ n_batches_per_layer),
+ grow_tree_from_accumulated_summaries_fn,
+ control_flow_ops.no_op,
+ name='wait_until_n_batches_accumulated'))
+
+ return control_flow_ops.group(train_op, name='train_op')
+
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ if mode == model_fn.ModeKeys.TRAIN:
+ # Add an early stop hook.
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks +
+ (StopAtNumTreesHook(num_trees, tree_hparams.n_trees),))
+ return estimator_spec
+
+
+def _create_classification_head(n_classes,
+ weight_column=None,
+ label_vocabulary=None):
+ """Creates a classification head. Refer to canned.head for details on args."""
+ # TODO(nponomareva): Support multi-class cases.
+ if n_classes == 2:
+ # pylint: disable=protected-access
+ return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ # pylint: enable=protected-access
+ else:
+ raise ValueError('For now only binary classification is supported.'
+ 'n_classes given as {}'.format(n_classes))
+
+
+def _create_classification_head_and_closed_form(n_classes, weight_column,
+ label_vocabulary):
+ """Creates a head for classifier and the closed form gradients/hessians."""
+ head = _create_classification_head(n_classes, weight_column, label_vocabulary)
+ if n_classes == 2 and weight_column is None and label_vocabulary is None:
+ # Use the closed-form gradients/hessians for 2 class.
+ def _grad_and_hess_for_logloss(logits, labels):
+ # TODO(youngheek): add weights handling.
+ predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
+ normalizer = math_ops.reciprocal(
+ math_ops.cast(array_ops.size(predictions), dtypes.float32))
+ gradients = (predictions - labels) * normalizer
+ hessians = predictions * (1.0 - predictions) * normalizer
+ return gradients, hessians
+
+ closed_form = _grad_and_hess_for_logloss
+ else:
+ closed_form = None
+ return (head, closed_form)
+
+
+def _create_regression_head(label_dimension, weight_column=None):
+ if label_dimension != 1:
+ raise ValueError('For now only 1 dimension regression is supported.'
+ 'label_dimension given as {}'.format(label_dimension))
+ # pylint: disable=protected-access
+ return head_lib._regression_head_with_mean_squared_error_loss(
+ label_dimension=label_dimension,
+ weight_column=weight_column,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ # pylint: enable=protected-access
+
+
+@tf_export('estimator.BoostedTreesClassifier')
+class BoostedTreesClassifier(estimator.Estimator):
+ """A Classifier for Tensorflow Boosted Trees models."""
+
+ def __init__(
+ self,
+ feature_columns,
+ n_batches_per_layer,
+ model_dir=None,
+ n_classes=_HOLD_FOR_MULTI_CLASS_SUPPORT,
+ weight_column=None,
+ label_vocabulary=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ config=None):
+ """Initializes a `BoostedTreesClassifier` instance.
+
+ Example:
+
+ ```python
+ bucketized_feature_1 = bucketized_column(
+ numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+ bucketized_feature_2 = bucketized_column(
+ numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+ classifier = estimator.BoostedTreesClassifier(
+ feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_trees=100,
+ ... <some other params>
+ )
+
+ def input_fn_train():
+ ...
+ return dataset
+
+ classifier.train(input_fn=input_fn_train)
+
+ def input_fn_eval():
+ ...
+ return dataset
+
+ metrics = classifier.evaluate(input_fn=input_fn_eval)
+ ```
+
+ Args:
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`.
+ n_batches_per_layer: the number of batches to collect statistics per
+ layer.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator
+ to continue training a previously saved model.
+ n_classes: number of label classes. Default is binary classification.
+ Multiclass support is not yet implemented.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to downweight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+ then weight_column.normalizer_fn is applied on it to get weight tensor.
+ label_vocabulary: A list of strings represents possible label values. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. If it is not given, that means labels are
+ already encoded as integer or float within [0, 1] for `n_classes=2` and
+ encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
+ Also there will be errors if vocabulary is not provided and labels are
+ string.
+ n_trees: number trees to be created.
+ max_depth: maximum depth of the tree to grow.
+ learning_rate: shrinkage parameter to be used when a tree added to the
+ model.
+ l1_regularization: regularization multiplier applied to the absolute
+ weights of the tree leafs.
+ l2_regularization: regularization multiplier applied to the square weights
+ of the tree leafs.
+ tree_complexity: regularization factor to penalize trees with more leaves.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ # TODO(nponomareva): Support multi-class cases.
+ if n_classes == _HOLD_FOR_MULTI_CLASS_SUPPORT:
+ n_classes = 2
+ head, closed_form = _create_classification_head_and_closed_form(
+ n_classes, weight_column, label_vocabulary=label_vocabulary)
+
+ # HParams for the model.
+ tree_hparams = TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity)
+
+ def _model_fn(features, labels, mode, config):
+ return _bt_model_fn( # pylint: disable=protected-access
+ features,
+ labels,
+ mode,
+ head,
+ feature_columns,
+ tree_hparams,
+ n_batches_per_layer,
+ config,
+ closed_form_grad_and_hess_fn=closed_form)
+
+ super(BoostedTreesClassifier, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+@tf_export('estimator.BoostedTreesRegressor')
+class BoostedTreesRegressor(estimator.Estimator):
+ """A Regressor for Tensorflow Boosted Trees models."""
+
+ def __init__(
+ self,
+ feature_columns,
+ n_batches_per_layer,
+ model_dir=None,
+ label_dimension=_HOLD_FOR_MULTI_DIM_SUPPORT,
+ weight_column=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ config=None):
+ """Initializes a `BoostedTreesRegressor` instance.
+
+ Example:
+
+ ```python
+ bucketized_feature_1 = bucketized_column(
+ numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+ bucketized_feature_2 = bucketized_column(
+ numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+ regressor = estimator.BoostedTreesRegressor(
+ feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_trees=100,
+ ... <some other params>
+ )
+
+ def input_fn_train():
+ ...
+ return dataset
+
+ regressor.train(input_fn=input_fn_train)
+
+ def input_fn_eval():
+ ...
+ return dataset
+
+ metrics = regressor.evaluate(input_fn=input_fn_eval)
+ ```
+
+ Args:
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`.
+ n_batches_per_layer: the number of batches to collect statistics per
+ layer.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator
+ to continue training a previously saved model.
+ label_dimension: Number of regression targets per example.
+ Multi-dimensional support is not yet implemented.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to downweight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+ then weight_column.normalizer_fn is applied on it to get weight tensor.
+ n_trees: number trees to be created.
+ max_depth: maximum depth of the tree to grow.
+ learning_rate: shrinkage parameter to be used when a tree added to the
+ model.
+ l1_regularization: regularization multiplier applied to the absolute
+ weights of the tree leafs.
+ l2_regularization: regularization multiplier applied to the square weights
+ of the tree leafs.
+ tree_complexity: regularization factor to penalize trees with more leaves.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ # TODO(nponomareva): Extend it to multi-dimension cases.
+ if label_dimension == _HOLD_FOR_MULTI_DIM_SUPPORT:
+ label_dimension = 1
+ head = _create_regression_head(label_dimension, weight_column)
+
+ # HParams for the model.
+ tree_hparams = TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity)
+
+ def _model_fn(features, labels, mode, config):
+ return _bt_model_fn( # pylint: disable=protected-access
+ features, labels, mode, head, feature_columns, tree_hparams,
+ n_batches_per_layer, config)
+
+ super(BoostedTreesRegressor, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
new file mode 100644
index 0000000000..9276fbaaa1
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -0,0 +1,799 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators and model_fn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator.canned import boosted_trees
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import checkpoint_utils
+
+NUM_FEATURES = 3
+
+BUCKET_BOUNDARIES = [-2., .5, 12.] # Boundaries for all the features.
+INPUT_FEATURES = np.array(
+ [
+ [12.5, 1.0, -2.001, -2.0001, -1.999], # feature_0 quantized:[3,2,0,0,1]
+ [2.0, -3.0, 0.5, 0.0, 0.4995], # feature_1 quantized:[2,0,2,1,1]
+ [3.0, 20.0, 50.0, -100.0, 102.75], # feature_2 quantized:[2,3,3,0,3]
+ ],
+ dtype=np.float32)
+CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]]
+REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]]
+FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)}
+
+# EXAMPLE_ID is not exposed to Estimator yet, but supported at model_fn level.
+EXAMPLE_IDS = np.array([0, 1, 2, 3, 4], dtype=np.int64)
+EXAMPLE_ID_COLUMN = '__example_id__'
+
+
+def _make_train_input_fn(is_classification):
+ """Makes train input_fn for classification/regression."""
+
+ def _input_fn():
+ features = dict(FEATURES_DICT)
+ features[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+ if is_classification:
+ labels = CLASSIFICATION_LABELS
+ else:
+ labels = REGRESSION_LABELS
+ return features, labels
+
+ return _input_fn
+
+
+class BoostedTreesClassifierTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ for i in range(NUM_FEATURES)
+ }
+
+ def _assert_checkpoint(self, model_dir, expected_global_step):
+ self.assertEqual(expected_global_step,
+ checkpoint_utils.load_variable(model_dir,
+ ops.GraphKeys.GLOBAL_STEP))
+
+ def testTrainAndEvaluateBinaryClassifier(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ # It will stop after 5 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ self._assert_checkpoint(est.model_dir, 6)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+
+ def testInferBinaryClassifier(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ # It will stop after 5 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(train_input_fn, steps=num_steps)
+
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertEquals(5, len(predictions))
+ # All labels are correct.
+ self.assertAllClose([0], predictions[0]['class_ids'])
+ self.assertAllClose([1], predictions[1]['class_ids'])
+ self.assertAllClose([1], predictions[2]['class_ids'])
+ self.assertAllClose([0], predictions[3]['class_ids'])
+ self.assertAllClose([0], predictions[4]['class_ids'])
+
+
+class BoostedTreesRegressionTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ for i in range(NUM_FEATURES)
+ }
+
+ def _assert_checkpoint(self, model_dir, expected_global_step):
+ self.assertEqual(expected_global_step,
+ checkpoint_utils.load_variable(model_dir,
+ ops.GraphKeys.GLOBAL_STEP))
+
+ def testTrainAndEvaluateRegressor(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ # It will stop after 10 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ self._assert_checkpoint(est.model_dir, 11)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 0.913176)
+
+ def testInferRegressor(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ # It will stop after 5 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(train_input_fn, steps=num_steps)
+ self._assert_checkpoint(est.model_dir, 6)
+
+ predictions = list(est.predict(input_fn=predict_input_fn))
+
+ self.assertEquals(5, len(predictions))
+ self.assertAllClose([0.703549], predictions[0]['predictions'])
+ self.assertAllClose([0.266539], predictions[1]['predictions'])
+ self.assertAllClose([0.256479], predictions[2]['predictions'])
+ self.assertAllClose([1.088732], predictions[3]['predictions'])
+ self.assertAllClose([1.901732], predictions[4]['predictions'])
+
+
+class ModelFnTests(test_util.TensorFlowTestCase):
+ """Tests bt_model_fn including unexposed internal functionalities."""
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+ self._tree_hparams = boosted_trees.TreeHParams(
+ n_trees=2,
+ max_depth=2,
+ learning_rate=0.1,
+ l1=0.,
+ l2=0.01,
+ tree_complexity=0.)
+
+ def _get_expected_ensembles_for_classification(self):
+ first_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.387675
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.181818
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0625
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ second_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.387675
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 3
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 0.0
+ original_leaf {
+ scalar: -0.181818
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.105518
+ original_leaf {
+ scalar: 0.0625
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.348397
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.181818
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.224091
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.056815
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 0
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ third_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.387675
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 3
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 0.0
+ original_leaf {
+ scalar: -0.181818
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.105518
+ original_leaf {
+ scalar: 0.0625
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.348397
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.181818
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.224091
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.056815
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.287131
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.162042
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.086986
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 3
+ }
+ """
+ return (first_round, second_round, third_round)
+
+ def _get_expected_ensembles_for_regression(self):
+ first_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.169714
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.241322
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.083951
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ second_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.169714
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.673407
+ original_leaf {
+ scalar: 0.241322
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.324102
+ original_leaf {
+ scalar: 0.083951
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.563167
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.247047
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.095273
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.222102
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 0
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ third_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.169714
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.673407
+ original_leaf {
+ scalar: 0.241322
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.324102
+ original_leaf {
+ scalar: 0.083951
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.563167
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.247047
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.095273
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.222102
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.981026
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.005166
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.180281
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 3
+ }
+ """
+ return (first_round, second_round, third_round)
+
+ def _get_train_op_and_ensemble(self, head, config, is_classification,
+ train_in_memory):
+ """Calls bt_model_fn() and returns the train_op and ensemble_serialzed."""
+ features, labels = _make_train_input_fn(is_classification)()
+ estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access
+ features=features,
+ labels=labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ head=head,
+ feature_columns=self._feature_columns,
+ tree_hparams=self._tree_hparams,
+ example_id_column_name=EXAMPLE_ID_COLUMN,
+ n_batches_per_layer=1,
+ config=config,
+ train_in_memory=train_in_memory)
+ resources.initialize_resources(resources.shared_resources()).run()
+ variables.global_variables_initializer().run()
+ variables.local_variables_initializer().run()
+
+ # Gets the train_op and serialized proto of the ensemble.
+ shared_resources = resources.shared_resources()
+ self.assertEqual(1, len(shared_resources))
+ train_op = estimator_spec.train_op
+ with ops.control_dependencies([train_op]):
+ _, ensemble_serialized = (
+ gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
+ shared_resources[0].handle))
+ return train_op, ensemble_serialized
+
+ def testTrainClassifierInMemory(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third = (
+ self._get_expected_ensembles_for_classification())
+ with self.test_session() as sess:
+ # Train with train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_classification_head(n_classes=2),
+ run_config.RunConfig(),
+ is_classification=True,
+ train_in_memory=True)
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ def testTrainClassifierNonInMemory(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third = (
+ self._get_expected_ensembles_for_classification())
+ with self.test_session() as sess:
+ # Train without train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_classification_head(n_classes=2),
+ run_config.RunConfig(),
+ is_classification=True,
+ train_in_memory=False)
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ def testTrainRegressorInMemory(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third = (
+ self._get_expected_ensembles_for_regression())
+ with self.test_session() as sess:
+ # Train with train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_regression_head(label_dimension=1),
+ run_config.RunConfig(),
+ is_classification=False,
+ train_in_memory=True)
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ def testTrainRegressorNonInMemory(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third = (
+ self._get_expected_ensembles_for_regression())
+ with self.test_session() as sess:
+ # Train without train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_regression_head(label_dimension=1),
+ run_config.RunConfig(),
+ is_classification=False,
+ train_in_memory=False)
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 6a4132bca2..2fe521b063 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -41,8 +41,11 @@ from tensorflow.python.estimator.export.export import get_temp_export_dir
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
@@ -50,6 +53,7 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import device_setter
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
@@ -183,6 +187,9 @@ class Estimator(object):
config)
self._config = config
+ # The distribute field contains an instance of DistributionStrategy.
+ self._distribution = self._config.distribute
+
# Model directory.
model_dir = compat_internal.path_to_str(model_dir)
if (model_dir is not None) and (self._config.model_dir is not None):
@@ -682,11 +689,25 @@ class Estimator(object):
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
"""Extracts the `features` and labels from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
+ # TODO(anjalisridhar): What about the default DistributionStrategy? Perhaps
+ # using any input is alright in that case. There is also a
+ # has_dataset_or_queue_runner function that we may want to extend and use.
+ if (self._distribution is not None and
+ not isinstance(result, dataset_ops.Dataset)):
+ raise ValueError('input_fn() must return a tf.data.Dataset when using a '
+ 'DistributionStrategy.')
input_hooks = []
if isinstance(result, dataset_ops.Dataset):
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
+ if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+ # TODO(josh11b): This is currently using a one-shot iterator, we
+ # will update this to an initializeable iterator once the
+ # necessory support for creating an initializable iterator is
+ # available.
+ result = self._distribution.distribute_dataset(result).get_next()
+ else:
+ iterator = result.make_initializable_iterator()
+ input_hooks.append(_DatasetInitializerHook(iterator))
+ result = iterator.get_next()
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
@@ -815,6 +836,12 @@ class Estimator(object):
return model_fn_results
def _train_model(self, input_fn, hooks, saving_listeners):
+ if self._distribution:
+ return self._train_model_distributed(input_fn, hooks, saving_listeners)
+ else:
+ return self._train_model_default(input_fn, hooks, saving_listeners)
+
+ def _train_model_default(self, input_fn, hooks, saving_listeners):
worker_hooks = []
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -826,86 +853,209 @@ class Estimator(object):
worker_hooks.extend(input_hooks)
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
+ return self._train_with_estimator_spec(estimator_spec, worker_hooks,
+ hooks, global_step_tensor,
+ saving_listeners)
- if self._warm_start_settings:
- logging.info('Warm-starting with WarmStartSettings: %s' %
- (self._warm_start_settings,))
- # pylint: disable=protected-access
- warm_starting_util.warm_start(*self._warm_start_settings)
- # pylint: enable=protected-access
- # Check if the user created a loss summary, and add one if they didn't.
- # We assume here that the summary is called 'loss'. If it is not, we will
- # make another one with the name 'loss' to ensure it shows up in the right
- # graph in TensorBoard.
- if not any([x.op.name == 'loss'
- for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
- summary.scalar('loss', estimator_spec.loss)
- ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
- worker_hooks.extend(hooks)
- worker_hooks.extend([
- training.NanTensorHook(estimator_spec.loss),
- training.LoggingTensorHook(
- {
- 'loss': estimator_spec.loss,
- 'step': global_step_tensor
- },
- every_n_iter=self._config.log_step_count_steps)
- ])
- worker_hooks.extend(estimator_spec.training_hooks)
-
- if not (estimator_spec.scaffold.saver or
- ops.get_collection(ops.GraphKeys.SAVERS)):
- ops.add_to_collection(
- ops.GraphKeys.SAVERS,
- training.Saver(
- sharded=True,
- max_to_keep=self._config.keep_checkpoint_max,
- keep_checkpoint_every_n_hours=(
- self._config.keep_checkpoint_every_n_hours),
- defer_build=True,
- save_relative_paths=True))
-
- chief_hooks = []
- all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
- saver_hooks = [
- h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
- if (self._config.save_checkpoints_secs or
- self._config.save_checkpoints_steps):
- if not saver_hooks:
- chief_hooks = [
- training.CheckpointSaverHook(
- self._model_dir,
- save_secs=self._config.save_checkpoints_secs,
- save_steps=self._config.save_checkpoints_steps,
- scaffold=estimator_spec.scaffold)
- ]
- saver_hooks = [chief_hooks[0]]
- if saving_listeners:
- if not saver_hooks:
- raise ValueError(
- 'There should be a CheckpointSaverHook to use saving_listeners. '
- 'Please set one of the RunConfig.save_checkpoints_steps or '
- 'RunConfig.save_checkpoints_secs.')
+ def _train_model_distributed(self, input_fn, hooks, saving_listeners):
+ worker_hooks = []
+ with ops.Graph().as_default() as g:
+ with self._distribution.scope():
+ random_seed.set_random_seed(self._config.tf_random_seed)
+ features, labels, input_hooks = (
+ self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.TRAIN))
+ worker_hooks.extend(input_hooks)
+ global_step_tensor = self._create_and_assert_global_step(g)
+ # The default destination for the global_step_tensor fetch call is the
+ # CPU.
+ global_step_read_tensor = self._distribution.fetch(global_step_tensor)
+ # we want to add to the global collection in the main thread not the
+ # tower threads.
+ ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
+ global_step_read_tensor)
+ grouped_estimator_spec = self._distribution.call_for_each_tower(
+ self._call_model_fn,
+ features,
+ labels, # although this will be None it seems
+ model_fn_lib.ModeKeys.TRAIN,
+ self.config)
+
+ # TODO(anjalisridhar): Figure out how to resolve the folowing scaffold
+ # parameters: init_feed_dict, init_fn.
+ scaffold_list = self._distribution.unwrap(
+ grouped_estimator_spec.scaffold)
+ init_feed_dict = [
+ s.init_feed_dict
+ for s in scaffold_list
+ if s.init_feed_dict is not None
+ ]
+ if init_feed_dict:
+ init_feed_dict = self._distribution.group(init_feed_dict)
else:
- # It is expected to have one CheckpointSaverHook. If multiple, we pick
- # up the first one to add listener.
- saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
- with training.MonitoredTrainingSession(
- master=self._config.master,
- is_chief=self._config.is_chief,
- checkpoint_dir=self._model_dir,
- scaffold=estimator_spec.scaffold,
- hooks=worker_hooks,
- chief_only_hooks=(
- tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
- save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
- config=self._session_config,
- log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
- loss = None
- while not mon_sess.should_stop():
- _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
- return loss
+ init_feed_dict = None
+
+ init_fn = [s.init_fn for s in scaffold_list if s.init_fn is not None]
+ if init_fn:
+ init_fn = self._distribution.group(init_fn)
+ else:
+ init_fn = None
+
+ init_op = [s.init_op for s in scaffold_list if s.init_op is not None]
+ if init_op:
+ init_op = self._distribution.group(init_op)
+ else:
+ init_op = None
+
+ ready_op = self._distribution.call_for_each_tower(
+ create_per_tower_ready_op, grouped_estimator_spec.scaffold)
+ if ready_op is not None:
+ ready_op = self._distribution.group(ready_op)
+ else:
+ ready_op = None
+
+ ready_for_local_init_op = self._distribution.call_for_each_tower(
+ create_per_tower_ready_for_local_init_op,
+ grouped_estimator_spec.scaffold)
+ if ready_for_local_init_op is not None:
+ ready_for_local_init_op = self._distribution.group(
+ ready_for_local_init_op)
+ else:
+ ready_for_local_init_op = None
+
+ local_init_op = [
+ s.local_init_op
+ for s in scaffold_list
+ if s.local_init_op is not None
+ ]
+ if local_init_op:
+ local_init_op = self._distribution.group(local_init_op)
+ else:
+ local_init_op = None
+
+ summary_op = [
+ s.summary_op for s in scaffold_list if s.summary_op is not None
+ ]
+ if summary_op:
+ summary_op = self._distribution.group(summary_op)
+ else:
+ summary_op = None
+
+ scaffold = monitored_session.Scaffold(
+ init_op=init_op,
+ ready_op=ready_op,
+ ready_for_local_init_op=ready_for_local_init_op,
+ local_init_op=local_init_op,
+ summary_op=summary_op,
+ init_feed_dict=init_feed_dict,
+ init_fn=init_fn)
+
+ def get_hooks_from_the_first_device(per_device_hooks):
+ hooks_list = self._distribution.unwrap(per_device_hooks)
+ assert hooks_list
+ return hooks_list[0]
+
+ training_hooks = get_hooks_from_the_first_device(
+ grouped_estimator_spec.training_hooks)
+ training_chief_hooks = get_hooks_from_the_first_device(
+ grouped_estimator_spec.training_chief_hooks)
+
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=grouped_estimator_spec.mode,
+ loss=self._distribution.unwrap(
+ self._distribution.reduce(distribute_lib.get_loss_reduction(),
+ grouped_estimator_spec.loss,
+ destinations='/device:CPU:0'))[0],
+ train_op=self._distribution.group(grouped_estimator_spec.train_op),
+ training_hooks=training_hooks,
+ training_chief_hooks=training_chief_hooks,
+ scaffold=scaffold)
+ return self._train_with_estimator_spec(estimator_spec, worker_hooks,
+ hooks, global_step_read_tensor,
+ saving_listeners)
+
+ def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
+ global_step_tensor, saving_listeners):
+ """Train a model with the given Estimator Spec."""
+ if self._warm_start_settings:
+ logging.info('Warm-starting with WarmStartSettings: %s' %
+ (self._warm_start_settings,))
+ # pylint: disable=protected-access
+ warm_starting_util.warm_start(*self._warm_start_settings)
+ # pylint: enable=protected-access
+ # Check if the user created a loss summary, and add one if they didn't.
+ # We assume here that the summary is called 'loss'. If it is not, we will
+ # make another one with the name 'loss' to ensure it shows up in the right
+ # graph in TensorBoard.
+ if not any([x.op.name == 'loss'
+ for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
+ summary.scalar('loss', estimator_spec.loss)
+ ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
+ worker_hooks.extend(hooks)
+ worker_hooks.extend([
+ training.NanTensorHook(estimator_spec.loss),
+ training.LoggingTensorHook(
+ {
+ 'loss': estimator_spec.loss,
+ 'step': global_step_tensor
+ },
+ every_n_iter=self._config.log_step_count_steps)
+ ])
+ worker_hooks.extend(estimator_spec.training_hooks)
+
+ if not (estimator_spec.scaffold.saver or
+ ops.get_collection(ops.GraphKeys.SAVERS)):
+ ops.add_to_collection(
+ ops.GraphKeys.SAVERS,
+ training.Saver(
+ sharded=True,
+ max_to_keep=self._config.keep_checkpoint_max,
+ keep_checkpoint_every_n_hours=(
+ self._config.keep_checkpoint_every_n_hours),
+ defer_build=True,
+ save_relative_paths=True))
+
+ chief_hooks = []
+ all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
+ saver_hooks = [
+ h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
+ if (self._config.save_checkpoints_secs or
+ self._config.save_checkpoints_steps):
+ if not saver_hooks:
+ chief_hooks = [
+ training.CheckpointSaverHook(
+ self._model_dir,
+ save_secs=self._config.save_checkpoints_secs,
+ save_steps=self._config.save_checkpoints_steps,
+ scaffold=estimator_spec.scaffold)
+ ]
+ saver_hooks = [chief_hooks[0]]
+ if saving_listeners:
+ if not saver_hooks:
+ raise ValueError(
+ 'There should be a CheckpointSaverHook to use saving_listeners. '
+ 'Please set one of the RunConfig.save_checkpoints_steps or '
+ 'RunConfig.save_checkpoints_secs.')
+ else:
+ # It is expected to have one CheckpointSaverHook. If multiple, we pick
+ # up the first one to add listener.
+ saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
+ with training.MonitoredTrainingSession(
+ master=self._config.master,
+ is_chief=self._config.is_chief,
+ checkpoint_dir=self._model_dir,
+ scaffold=estimator_spec.scaffold,
+ hooks=worker_hooks,
+ chief_only_hooks=(
+ tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
+ save_checkpoint_secs=0, # Saving is handled by a hook.
+ save_summaries_steps=self._config.save_summary_steps,
+ config=self._session_config,
+ log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+ loss = None
+ while not mon_sess.should_stop():
+ _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
+ return loss
def _evaluate_model(self,
input_fn,
@@ -972,6 +1122,35 @@ class Estimator(object):
return eval_results
+def create_per_tower_ready_op(scaffold):
+ """Create a Scaffold.ready_op inside a tower."""
+ if scaffold.ready_op:
+ return scaffold.ready_op
+
+ def default_ready_op():
+ return array_ops.concat([
+ variables.report_uninitialized_variables(),
+ resources.report_uninitialized_resources()
+ ], 0)
+
+ return monitored_session.Scaffold.get_or_default(
+ 'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
+
+
+def create_per_tower_ready_for_local_init_op(scaffold):
+ """Create a Scaffold.ready_for_local_init_op inside a tower."""
+ if scaffold.ready_for_local_init_op:
+ return scaffold.ready_for_local_init_op
+
+ def default_ready_for_local_init_op():
+ return variables.report_uninitialized_variables(
+ variables.global_variables())
+
+ return monitored_session.Scaffold.get_or_default(
+ 'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
+ default_ready_for_local_init_op)
+
+
def _check_checkpoint_available(model_dir):
latest_path = saver.latest_checkpoint(model_dir)
if not latest_path:
diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py
index be8930b3cb..60c59cbc18 100644
--- a/tensorflow/python/estimator/estimator_lib.py
+++ b/tensorflow/python/estimator/estimator_lib.py
@@ -21,6 +21,8 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.python.estimator.canned.baseline import BaselineClassifier
from tensorflow.python.estimator.canned.baseline import BaselineRegressor
+from tensorflow.python.estimator.canned.boosted_trees import BoostedTreesClassifier
+from tensorflow.python.estimator.canned.boosted_trees import BoostedTreesRegressor
from tensorflow.python.estimator.canned.dnn import DNNClassifier
from tensorflow.python.estimator.canned.dnn import DNNRegressor
from tensorflow.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedClassifier
@@ -52,6 +54,8 @@ _allowed_symbols = [
# Canned Estimators
'BaselineClassifier',
'BaselineRegressor',
+ 'BoostedTreesClassifier',
+ 'BoostedTreesRegressor',
'DNNClassifier',
'DNNRegressor',
'DNNLinearCombinedClassifier',
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 141eaeff64..41415b89e9 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -688,7 +688,7 @@ class RunConfig(object):
Only the properties in the following list are allowed to be replaced:
- - `model_dir`.
+ - `model_dir`,
- `tf_random_seed`,
- `save_summary_steps`,
- `save_checkpoints_steps`,
@@ -697,6 +697,7 @@ class RunConfig(object):
- `keep_checkpoint_max`,
- `keep_checkpoint_every_n_hours`,
- `log_step_count_steps`,
+ - `distribute`.
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 238a90b67d..0ae9900a1d 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -6,18 +6,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_library(
name = "feature_column_py",
srcs = ["feature_column_lib.py"],
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 14d72d8a3d..82dd2a3356 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -934,6 +934,12 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
s=("function_%s" % func_name).encode())
# pylint: enable=protected-access
+ kwargs_keys = list(kwargs.keys())
+ for key in kwargs_keys:
+ if key.startswith("experimental_"):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ del kwargs[key]
+
if kwargs:
raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
return attrs
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 65ca801cbe..83d256fab6 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1227,6 +1227,15 @@ class FunctionsFromProtos(test.TestCase):
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
function._from_library(library)
+ def testExperimentalAttrs(self):
+
+ @function.Defun(dtypes.int32, experimental_tag="tag_value")
+ def FunctionWithAttr(i):
+ return array_ops.identity(i)
+ self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+
@test_util.with_c_api
class FunctionOverloadTest(test.TestCase):
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 6593b17184..369669c2e6 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -356,6 +357,39 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(d._input_types, [dtypes.int32_ref, dtypes.int32])
self.assertEqual(d.outputs, [])
+ def testResources(self):
+ # Produce GraphDef containing a ops producing and consuming resources.
+ graph = ops.Graph()
+ with graph.as_default():
+ var = resource_variable_ops.ResourceVariable(1.0)
+ var_assign = var.assign(2.0)
+ # Use an op that requires handle shape to be set.
+ var_shape = resource_variable_ops.variable_shape(var.handle)
+ init = variables.global_variables_initializer()
+ graph_def = graph.as_graph_def()
+
+ # Import the GraphDef.
+ with ops.Graph().as_default():
+ # pylint: disable=unused-variable
+ imported_var, imported_assign, imported_shape, imported_init = (
+ importer.import_graph_def(
+ graph_def,
+ return_elements=[var.name, var_assign.name, var_shape.name,
+ init.name]))
+
+ # Make sure the handle shape is set on the imported variable.
+ new_var_shape = resource_variable_ops.variable_shape(imported_var)
+ # pylint: enable=unused-variable
+
+ # Run the imported graph.
+ # TODO(b/76173421): make this work (currently DCHECKS)
+ # with self.test_session() as sess:
+ # sess.run(imported_init)
+ # self.assertEqual(sess.run(imported_var), 1.0)
+ # self.assertEqual(sess.run(imported_assign), 2.0)
+ # self.assertEqual(list(sess.run(imported_shape)), [])
+ # self.assertEqual(list(sess.run(new_var_shape)), [])
+
def testWhileLoop(self):
# Produce GraphDef containing while loop.
graph = ops.Graph()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 25a951a2de..6930737a0c 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -42,6 +42,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -295,6 +296,7 @@ class Tensor(_TensorLike):
# Attributes used for C++ shape inference. Not inspected, only forwarded.
# If set, will be a HandleData object from cpp_shape_inference.proto.
+ # TODO(b/74620627): remove when _USE_C_SHAPES is removed
self._handle_data = None
self._id = uid()
@@ -1663,6 +1665,9 @@ class Operation(object):
self._control_inputs_val = control_input_ops
self._node_def_val = copy.deepcopy(node_def)
self._op_def_val = op_def
+ else:
+ # This will be set by self.inputs.
+ self._inputs_val = None
self._id_value = self._graph._next_id() # pylint: disable=protected-access
self._original_op = original_op
@@ -1936,6 +1941,8 @@ class Operation(object):
raise TypeError("tensor must be a Tensor: %s" % tensor)
_assert_same_graph(self, tensor)
if self._c_op:
+ # Reset cached inputs.
+ self._inputs_val = None
with errors.raise_exception_on_not_ok_status() as status:
c_api.UpdateEdge(
self._graph._c_graph, # pylint: disable=protected-access
@@ -2052,15 +2059,18 @@ class Operation(object):
def inputs(self):
"""The list of `Tensor` objects representing the data inputs of this op."""
if self._c_op:
- tf_outputs = c_api.GetOperationInputs(self._c_op)
- # pylint: disable=protected-access
- retval = [
- self.graph._get_tensor_by_tf_output(tf_output)
- for tf_output in tf_outputs
- ]
- # pylint: enable=protected-access
- return Operation._InputList(retval)
- return Operation._InputList(self._inputs_val)
+ if self._inputs_val is None:
+ tf_outputs = c_api.GetOperationInputs(self._c_op)
+ # pylint: disable=protected-access
+ retval = [
+ self.graph._get_tensor_by_tf_output(tf_output)
+ for tf_output in tf_outputs
+ ]
+ # pylint: enable=protected-access
+ self._inputs_val = Operation._InputList(retval)
+ return self._inputs_val
+ else:
+ return Operation._InputList(self._inputs_val)
@property
def _inputs(self):
@@ -2472,6 +2482,14 @@ def _set_shapes_for_outputs_c_api(op):
shape_vector = [None if d == -1 else d for d in shape_vector]
output.set_shape(tensor_shape.TensorShape(shape_vector))
+ serialized = c_api.ResourceHandleShapeAndType(op._graph._c_graph,
+ output._as_tf_output())
+ if serialized:
+ output._handle_data = (
+ cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+ compat.as_bytes(serialized)))
+ else:
+ output._handle_data = None
# TODO(skyewm): remove this when _USE_C_API flag is removed.
def _set_shapes_for_outputs(op):
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 9850f0becc..e5e3b82199 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -448,7 +448,7 @@ string AttrValueToPython(const string& type, const AttrValue& value,
return TensorToPython(value.tensor());
} else if (type == "func") {
return StringToPython(value.func().name());
- } else if (StringPiece(type).starts_with("list(")) {
+ } else if (str_util::StartsWith(type, "list(")) {
return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
} else {
return "?";
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index bc5ca195da..ca6ed42bee 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -95,7 +96,8 @@ string InferSourceFileName(const char* argv_zero) {
// operators defined in <op type>_ops.cc
const char* kExecPrefix = "gen_";
const char* kExecSuffix = "_py_wrappers_cc";
- if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) {
+ if (str_util::ConsumePrefix(&command_str, kExecPrefix) &&
+ str_util::EndsWith(command_str, kExecSuffix)) {
command_str.remove_suffix(strlen(kExecSuffix));
return strings::StrCat(command_str, ".cc");
} else {
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 43106b6e59..bf00fa6439 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -487,7 +487,13 @@ def assert_no_new_pyobjects_executing_eagerly(f):
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
- self.assertEqual(previous_count, new_count)
+ # In some cases (specifacally on MacOS), new_count is somehow
+ # smaller than previous_count.
+ # Using plain assert because not all classes using this decorator
+ # have assertLessEqual
+ assert new_count <= previous_count, (
+ "new_count(%d) is not less than or equal to previous_count(%d)" % (
+ new_count, previous_count))
gc.enable()
return decorator
@@ -968,8 +974,6 @@ class TensorFlowTestCase(googletest.TestCase):
config.graph_options.optimizer_options.opt_level = -1
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
- config.graph_options.rewrite_options.arithmetic_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
return config
if graph is None:
diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py
index a3c4c2bbeb..26c6f22d34 100644
--- a/tensorflow/python/grappler/cluster_test.py
+++ b/tensorflow/python/grappler/cluster_test.py
@@ -87,9 +87,10 @@ class ClusterTest(test.TestCase):
def testVirtualCluster(self):
with ops.Graph().as_default() as g:
- a = random_ops.random_uniform(shape=())
- b = random_ops.random_uniform(shape=())
- c = a + b
+ with ops.device('/device:GPU:0'):
+ a = random_ops.random_uniform(shape=[1024, 1024])
+ b = random_ops.random_uniform(shape=[1024, 1024])
+ c = a + b
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(c)
mg = meta_graph.create_meta_graph_def(graph=g)
@@ -102,10 +103,13 @@ class ClusterTest(test.TestCase):
'architecture': '7'
})
named_device = device_properties_pb2.NamedDevice(
- properties=device_properties, name='/GPU:0')
- grappler_cluster = cluster.Cluster(devices=[named_device])
+ properties=device_properties, name='/device:GPU:0')
+ grappler_cluster = cluster.Cluster(
+ disable_detailed_stats=False,
+ disable_timeline=False,
+ devices=[named_device])
op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
- self.assertGreater(run_time, 0)
+ self.assertEqual(run_time, 0.000545)
self.assertEqual(len(op_perfs), 15)
estimated_perf = grappler_cluster.EstimatePerformance(named_device)
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 16033e9b8f..2a06907f49 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -868,15 +868,3 @@ py_library(
"//third_party/py/numpy",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 5615241ae3..755607aafb 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import inspect # Necessary supplement to tf_inspect to deal with variadic args.
+
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
@@ -30,6 +32,8 @@ from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -143,6 +147,7 @@ class Layer(tf_base_layers.Layer):
super(Layer, self).__init__(
name=name, dtype=dtype, trainable=trainable,
activity_regularizer=kwargs.get('activity_regularizer'))
+ self._uses_inputs_arg = True
# Add properties that are Keras-only for now.
self.supports_masking = False
@@ -213,7 +218,71 @@ class Layer(tf_base_layers.Layer):
"""
return inputs
- def __call__(self, inputs, **kwargs):
+ def _inputs_from_call_args(self, call_args, call_kwargs):
+ """Get Layer inputs from __call__ *args and **kwargs.
+
+ Args:
+ call_args: The positional arguments passed to __call__.
+ call_kwargs: The keyword argument dict passed to __call__.
+
+ Returns:
+ A tuple of (inputs, non_input_kwargs). These may be the same objects as
+ were passed in (call_args and call_kwargs).
+ """
+ if getattr(self, '_uses_inputs_arg', True):
+ assert len(call_args) == 1 # TypeError raised earlier in __call__.
+ return call_args[0], call_kwargs
+ else:
+ call_arg_spec = tf_inspect.getargspec(self.call)
+ # There is no explicit "inputs" argument expected or provided to
+ # call(). Arguments which have default values are considered non-inputs,
+ # and arguments without are considered inputs.
+ if call_arg_spec.defaults:
+ if call_arg_spec.varargs is not None:
+ raise TypeError(
+ 'Layer.call() may not accept both *args and arguments with '
+ 'default values (unable to determine which are inputs to the '
+ 'Layer).')
+ keyword_arg_names = set(
+ call_arg_spec.args[-len(call_arg_spec.defaults):])
+ else:
+ keyword_arg_names = set()
+ # Training is never an input argument name, to allow signatures like
+ # call(x, training).
+ keyword_arg_names.add('training')
+ _, unwrapped_call = tf_decorator.unwrap(self.call)
+ bound_args = inspect.getcallargs(
+ unwrapped_call, *call_args, **call_kwargs)
+ if call_arg_spec.keywords is not None:
+ var_kwargs = bound_args.pop(call_arg_spec.keywords)
+ bound_args.update(var_kwargs)
+ keyword_arg_names = keyword_arg_names.union(var_kwargs.keys())
+ all_args = call_arg_spec.args
+ if all_args and bound_args[all_args[0]] is self:
+ # Ignore the 'self' argument of methods
+ bound_args.pop(call_arg_spec.args[0])
+ all_args = all_args[1:]
+ non_input_arg_values = {}
+ input_arg_values = []
+ remaining_args_are_keyword = False
+ for argument_name in all_args:
+ if argument_name in keyword_arg_names:
+ remaining_args_are_keyword = True
+ else:
+ if remaining_args_are_keyword:
+ raise TypeError(
+ 'Found a positional argument to call() after a non-input '
+ 'argument. All arguments after "training" must be keyword '
+ 'arguments, and are not tracked as inputs to the Layer.')
+ if remaining_args_are_keyword:
+ non_input_arg_values[argument_name] = bound_args[argument_name]
+ else:
+ input_arg_values.append(bound_args[argument_name])
+ if call_arg_spec.varargs is not None:
+ input_arg_values.extend(bound_args[call_arg_spec.varargs])
+ return input_arg_values, non_input_arg_values
+
+ def __call__(self, inputs, *args, **kwargs):
"""Wrapper around self.call(), for handling internal references.
If a Keras tensor is passed:
@@ -226,6 +295,10 @@ class Layer(tf_base_layers.Layer):
Arguments:
inputs: Can be a tensor or list/tuple of tensors.
+ *args: Additional positional arguments to be passed to `call()`. Only
+ allowed in subclassed Models with custom call() signatures. In other
+ cases, `Layer` inputs must be passed using the `inputs` argument and
+ non-inputs must be keyword arguments.
**kwargs: Additional keyword arguments to be passed to `call()`.
Returns:
@@ -234,12 +307,25 @@ class Layer(tf_base_layers.Layer):
Raises:
ValueError: in case the layer is missing shape information
for its `build` call.
+ TypeError: If positional arguments are passed and this `Layer` is not a
+ subclassed `Model`.
"""
# Actually call the layer (optionally building it).
- output = super(Layer, self).__call__(inputs, **kwargs)
+ output = super(Layer, self).__call__(inputs, *args, **kwargs)
+
+ if args and getattr(self, '_uses_inputs_arg', True):
+ raise TypeError(
+ 'This Layer takes an `inputs` argument to call(), and only the '
+ '`inputs` argument may be specified as a positional argument. Pass '
+ 'everything else as a keyword argument (those arguments will not be '
+ 'tracked as inputs to the Layer).')
+
if context.executing_eagerly():
return output
+ inputs, kwargs = self._inputs_from_call_args(
+ call_args=(inputs,) + args, call_kwargs=kwargs)
+
if hasattr(self, '_symbolic_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by a call to
# self._set_inputs().
diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index ea4be0d293..9f1c7de115 100644
--- a/tensorflow/python/keras/_impl/keras/engine/network.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -117,6 +117,7 @@ class Network(base_layer.Layer):
self._inbound_nodes = []
def _init_graph_network(self, inputs, outputs, name=None):
+ self._uses_inputs_arg = True
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
@@ -274,11 +275,15 @@ class Network(base_layer.Layer):
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
- if 'training' in tf_inspect.getargspec(self.call).args:
+ call_args = tf_inspect.getargspec(self.call).args
+ if 'training' in call_args:
self._expects_training_arg = True
else:
self._expects_training_arg = False
-
+ if 'inputs' in call_args:
+ self._uses_inputs_arg = True
+ else:
+ self._uses_inputs_arg = False
self.outputs = None
self.inputs = None
self.built = False
diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential.py b/tensorflow/python/keras/_impl/keras/engine/sequential.py
index 66cef1f5b9..2ef99d5ab3 100644
--- a/tensorflow/python/keras/_impl/keras/engine/sequential.py
+++ b/tensorflow/python/keras/_impl/keras/engine/sequential.py
@@ -29,6 +29,7 @@ from tensorflow.python.keras._impl.keras.engine.input_layer import Input
from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer
from tensorflow.python.keras._impl.keras.engine.training import Model
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -192,6 +193,36 @@ class Sequential(Model):
self.build()
else:
self._layers.append(layer)
+ # In implementing Checkpointable, Sequential does not track its Layers
+ # normally, since they may be added and removed (in pop()). Instead, it
+ # names everything on demand (gathering dependencies in
+ # _checkpoint_dependencies, and looking them up in
+ # _lookup_dependency). _handle_deferred_dependencies just checks whether an
+ # existing checkpoint load targets this Layer, it does not create a
+ # dependency on the Layer.
+ self._handle_deferred_dependencies(
+ name='layer-%d' % (len(self._layers) - 1), checkpointable=layer)
+
+ @property
+ def _checkpoint_dependencies(self):
+ """For implementing Checkpointable. Layers which should be saved."""
+ return super(Sequential, self)._checkpoint_dependencies + [
+ checkpointable.CheckpointableReference(
+ name='layer-%d' % layer_index, ref=layer)
+ for layer_index, layer in enumerate(self._layers)]
+
+ def _lookup_dependency(self, name):
+ """For implementing Checkpointable. Looks up a Layer."""
+ super_lookup = super(Sequential, self)._lookup_dependency(name=name)
+ if super_lookup is not None:
+ return super_lookup
+ if name.startswith('layer-'):
+ try:
+ return self._layers[int(name[6:])]
+ except IndexError:
+ return None
+ else:
+ return None
def pop(self):
"""Removes the last layer in the model.
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 08288d353e..971245c162 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -874,6 +874,11 @@ class Model(Network):
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
"""
+ if not getattr(self, '_uses_inputs_arg', True):
+ raise NotImplementedError(
+ 'Subclassed Models without "inputs" in their call() signatures do '
+ 'not yet support shape inference. File a feature request if this '
+ 'limitation bothers you.')
if self.__class__.__name__ == 'Sequential':
# Note: we can't test whether the model is `Sequential` via `isinstance`
# since `Sequential` depends on `Model`.
diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
index 58b144365b..4445900330 100644
--- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
@@ -22,7 +22,9 @@ import os
import tempfile
import numpy as np
+import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
@@ -36,6 +38,7 @@ except ImportError:
h5py = None
+# pylint: disable=not-callable
class SimpleTestModel(keras.Model):
def __init__(self, use_bn=False, use_dp=False, num_classes=10):
@@ -104,7 +107,7 @@ class NestedTestModel1(keras.Model):
def call(self, inputs):
x = self.dense1(inputs)
x = self.bn(x)
- x = self.test_net(x) # pylint: disable=not-callable
+ x = self.test_net(x)
return self.dense2(x)
@@ -161,7 +164,7 @@ def get_nested_model_3(input_dim, num_classes):
return tensor_shape.TensorShape((input_shape[0], 5))
test_model = Inner()
- x = test_model(x) # pylint: disable=not-callable
+ x = test_model(x)
outputs = keras.layers.Dense(num_classes)(x)
return keras.Model(inputs, outputs, name='nested_model_3')
@@ -574,5 +577,128 @@ class ModelSubclassingTest(test.TestCase):
self.assertGreater(loss, 0.1)
+class CustomCallModel(keras.Model):
+
+ def __init__(self):
+ super(CustomCallModel, self).__init__()
+ self.dense1 = keras.layers.Dense(1, activation='relu')
+ self.dense2 = keras.layers.Dense(1, activation='softmax')
+
+ def call(self, first, second, fiddle_with_output='no', training=True):
+ combined = self.dense1(first) + self.dense2(second)
+ if fiddle_with_output == 'yes':
+ return 10. * combined
+ else:
+ return combined
+
+
+class CustomCallSignatureTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_no_inputs_in_signature(self):
+ model = CustomCallModel()
+ first = array_ops.ones([2, 3])
+ second = array_ops.ones([2, 5])
+ output = model(first, second)
+ self.evaluate([v.initializer for v in model.variables])
+ expected_output = self.evaluate(model.dense1(first) + model.dense2(second))
+ self.assertAllClose(expected_output, self.evaluate(output))
+ output = model(first, second, fiddle_with_output='yes')
+ self.assertAllClose(10. * expected_output, self.evaluate(output))
+ output = model(first, second=second, training=False)
+ self.assertAllClose(expected_output, self.evaluate(output))
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [first, second], model.inputs)
+ with self.assertRaises(TypeError):
+ # tf.layers.Layer expects an "inputs" argument, so all-keywords doesn't
+ # work at the moment.
+ model(first=first, second=second, fiddle_with_output='yes')
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_inputs_in_signature(self):
+
+ class HasInputsAndOtherPositional(keras.Model):
+
+ def call(self, inputs, some_other_arg, training=False):
+ return inputs
+
+ model = HasInputsAndOtherPositional()
+ with self.assertRaisesRegexp(
+ TypeError, 'everything else as a keyword argument'):
+ model(array_ops.ones([]), array_ops.ones([]))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_kwargs_in_signature(self):
+
+ class HasKwargs(keras.Model):
+
+ def call(self, x, y=3, **key_words):
+ return x
+
+ model = HasKwargs()
+ arg = array_ops.ones([])
+ model(arg, a=3)
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [arg], model.inputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_args_in_signature(self):
+
+ class HasArgs(keras.Model):
+
+ def call(self, x, *args, **kwargs):
+ return [x] + list(args)
+
+ model = HasArgs()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ model(arg1, arg2, arg3, a=3)
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [arg1, arg2, arg3], model.inputs)
+
+ def test_args_and_keywords_in_signature(self):
+
+ class HasArgs(keras.Model):
+
+ def call(self, x, training=True, *args, **kwargs):
+ return x
+
+ with context.graph_mode():
+ model = HasArgs()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ with self.assertRaisesRegexp(TypeError, 'args and arguments with'):
+ model(arg1, arg2, arg3, a=3)
+
+ def test_training_no_default(self):
+
+ class TrainingNoDefault(keras.Model):
+
+ def call(self, x, training):
+ return x
+
+ with context.graph_mode():
+ model = TrainingNoDefault()
+ arg = array_ops.ones([])
+ model(arg, True)
+ six.assertCountEqual(self, [arg], model.inputs)
+
+ def test_training_no_default_with_positional(self):
+
+ class TrainingNoDefaultWithPositional(keras.Model):
+
+ def call(self, x, training, positional):
+ return x
+
+ with context.graph_mode():
+ model = TrainingNoDefaultWithPositional()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ with self.assertRaisesRegexp(TypeError, 'after a non-input'):
+ model(arg1, arg2, arg3)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 228d1c2452..ea210346c1 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1569,7 +1569,7 @@ cuda_py_test(
cuda_py_test(
name = "init_ops_test",
- size = "small",
+ size = "medium",
srcs = ["init_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2945,15 +2945,3 @@ tf_py_test(
"//tensorflow/python/eager:tape",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD
new file mode 100644
index 0000000000..30e6289420
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD
@@ -0,0 +1,76 @@
+# Description:
+# Kernel tests for Boosted Trees.
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_py_test(
+ name = "resource_ops_test",
+ size = "small",
+ srcs = ["resource_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resources",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+)
+
+tf_py_test(
+ name = "prediction_ops_test",
+ size = "small",
+ srcs = ["prediction_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:resources",
+ ],
+)
+
+tf_py_test(
+ name = "stats_ops_test",
+ size = "small",
+ srcs = ["stats_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+tf_py_test(
+ name = "training_ops_test",
+ size = "small",
+ srcs = ["training_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resources",
+ ],
+)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/__init__.py b/tensorflow/python/kernel_tests/boosted_trees/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/__init__.py
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
new file mode 100644
index 0000000000..d132f15e51
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -0,0 +1,926 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees prediction kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from google.protobuf import text_format
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import googletest
+
+
+class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
+ """Tests prediction ops for training."""
+
+ def testCachedPredictionOnEmptyEnsemble(self):
+ """Tests that prediction on a dummy ensemble does not fail."""
+ with self.test_session() as session:
+ # Create a dummy ensemble.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto='')
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # No previous cached values.
+ cached_tree_ids = [0, 0]
+ cached_node_ids = [0, 0]
+
+ # We have two features: 0 and 1. Values don't matter here on a dummy
+ # ensemble.
+ feature_0_values = [67, 5]
+ feature_1_values = [9, 17]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=2,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ # Nothing changed.
+ self.assertAllClose(cached_tree_ids, new_tree_ids)
+ self.assertAllClose(cached_node_ids, new_node_ids)
+ self.assertAllClose([[0], [0]], logits_updates)
+
+ def testNoCachedPredictionButTreeExists(self):
+ """Tests that predictions are updated once trees are added."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 15
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 8.79
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ is_finalized: true
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Two examples, none were cached before.
+ cached_tree_ids = [0, 0]
+ cached_node_ids = [0, 0]
+
+ feature_0_values = [67, 5]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=2,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ # We are in the first tree.
+ self.assertAllClose([0, 0], new_tree_ids)
+ self.assertAllClose([2, 1], new_node_ids)
+ self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logits_updates)
+
+ def testCachedPredictionIsCurrent(self):
+ """Tests that prediction based on previous node in the tree works."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 15
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ original_leaf {
+ scalar: -2
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 8.79
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ is_finalized: true
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Two examples, one was cached in node 1 first, another in node 0.
+ cached_tree_ids = [0, 0]
+ cached_node_ids = [1, 2]
+
+ # We have two features: 0 and 1. Values don't matter because trees didn't
+ # change.
+ feature_0_values = [67, 5]
+ feature_1_values = [9, 17]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=4,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ # Nothing changed.
+ self.assertAllClose(cached_tree_ids, new_tree_ids)
+ self.assertAllClose(cached_node_ids, new_node_ids)
+ self.assertAllClose([[0], [0]], logits_updates)
+
+ def testCachedPredictionFromTheSameTree(self):
+ """Tests that prediction based on previous node in the tree works."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 15
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ original_leaf {
+ scalar: -2
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 7
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.4
+ original_leaf {
+ scalar: 7.14
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 7
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.7
+ original_leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 8.79
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -5.875
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -2.075
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ is_finalized: true
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Two examples, one was cached in node 1 first, another in node 0.
+ cached_tree_ids = [0, 0]
+ cached_node_ids = [1, 0]
+
+ # We have two features: 0 and 1.
+ feature_0_values = [67, 5]
+ feature_1_values = [9, 17]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=4,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ # We are still in the same tree.
+ self.assertAllClose([0, 0], new_tree_ids)
+ # When using the full tree, the first example will end up in node 4,
+ # the second in node 5.
+ self.assertAllClose([4, 5], new_node_ids)
+ # Full predictions for each instance would be 8.79 and -5.875,
+ # so an update from the previous cached values lr*(7.14 and -2) would be
+ # 1.65 and -3.875, and then multiply them by 0.1 (lr)
+ self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logits_updates)
+
+ def testCachedPredictionFromPreviousTree(self):
+ """Tests the predictions work when we have cache from previous trees."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 28
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 8.79
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 26
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 50
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 34
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ }
+ tree_metadata {
+ is_finalized: true
+ }
+ tree_metadata {
+ is_finalized: true
+ }
+ tree_metadata {
+ is_finalized: false
+ }
+ tree_weights: 0.1
+ tree_weights: 0.1
+ tree_weights: 0.1
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Two examples, one was cached in node 1 first, another in node 2.
+ cached_tree_ids = [0, 0]
+ cached_node_ids = [1, 0]
+
+ # We have two features: 0 and 1.
+ feature_0_values = [36, 32]
+ feature_1_values = [11, 27]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=2,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+ # Example 1 will get to node 3 in tree 1 and node 2 of tree 2
+ # Example 2 will get to node 2 in tree 1 and node 1 of tree 2
+
+ # We are in the last tree.
+ self.assertAllClose([2, 2], new_tree_ids)
+ # When using the full tree, the first example will end up in node 4,
+ # the second in node 5.
+ self.assertAllClose([2, 1], new_node_ids)
+ # Example 1: tree 0: 8.79, tree 1: 5.0, tree 2: 5.0 = >
+ # change = 0.1*(5.0+5.0)
+ # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7 = >
+ # change= 0.1(1.14+7.0-7.0)
+ self.assertAllClose([[1], [0.114]], logits_updates)
+
+ def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
+ """Tests that prediction based on previous node in the tree works."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id:0
+ threshold: 33
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -0.2
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.01
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 5
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 0.5
+ original_leaf {
+ scalar: 0.0143
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0553
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0783
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 3
+ is_finalized: true
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 2
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.07
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.083
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 3
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 4
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.22
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.57
+ }
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 3
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ cached_tree_ids = [0, 0, 0, 0, 0, 0]
+ # Leaves 3,4, 7 and 8 got deleted during post-pruning, leaves 5 and 6
+ # changed the ids to 3 and 4 respectively.
+ cached_node_ids = [3, 4, 5, 6, 7, 8]
+
+ # We have two features: 0 and 1.
+ feature_0_values = [12, 17, 35, 36, 23, 11]
+ feature_1_values = [12, 12, 17, 18, 123, 24]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=3,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ # We are still in the same tree.
+ self.assertAllClose([0, 0, 0, 0, 0, 0], new_tree_ids)
+ # Examples from leaves 3,4,7,8 should be in leaf 1, examples from leaf 5
+ # and 6 in leaf 3 and 4.
+ self.assertAllClose([1, 1, 3, 4, 1, 1], new_node_ids)
+
+ cached_values = [[0.08], [0.093], [0.0553], [0.0783], [0.15 + 0.08],
+ [0.5 + 0.08]]
+ self.assertAllClose([[0.01], [0.01], [0.0553], [0.0783], [0.01], [0.01]],
+ logits_updates + cached_values)
+
+ def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
+ """Tests that prediction based on previous node in the tree works."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id:0
+ threshold: 33
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -0.2
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.01
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 5
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 0.5
+ original_leaf {
+ scalar: 0.0143
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0553
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0783
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.55
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 3
+ is_finalized: true
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 2
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.07
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.083
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 3
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 4
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.22
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.57
+ }
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 4
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ cached_tree_ids = [0, 0, 0, 0, 0, 0]
+ # Leaves 3,4, 7 and 8 got deleted during post-pruning, leaves 5 and 6
+ # changed the ids to 3 and 4 respectively.
+ cached_node_ids = [3, 4, 5, 6, 7, 8]
+
+ # We have two features: 0 and 1.
+ feature_0_values = [12, 17, 35, 36, 23, 11]
+ feature_1_values = [12, 12, 17, 18, 123, 24]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=3,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ # We are in the last tree.
+ self.assertAllClose([1, 1, 1, 1, 1, 1], new_tree_ids)
+ # Examples from leaves 3,4,7,8 should be in leaf 1, examples from leaf 5
+ # and 6 in leaf 3 and 4 in tree 0. For tree 1, all of the examples are in
+ # the root node.
+ self.assertAllClose([0, 0, 0, 0, 0, 0], new_node_ids)
+
+ cached_values = [[0.08], [0.093], [0.0553], [0.0783], [0.15 + 0.08],
+ [0.5 + 0.08]]
+ root = 0.55
+ self.assertAllClose([[root + 0.01], [root + 0.01], [root + 0.0553],
+ [root + 0.0783], [root + 0.01], [root + 0.01]],
+ logits_updates + cached_values)
+
+ def testCachedPredictionTheWholeTreeWasPruned(self):
+ """Tests that prediction based on previous node in the tree works."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.00
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: true
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: -6.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: 5.0
+ }
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ cached_tree_ids = [
+ 0,
+ 0,
+ ]
+ # The predictions were cached in 1 and 2, both were pruned to the root.
+ cached_node_ids = [1, 2]
+
+ # We have two features: 0 and 1.These are not going to be used anywhere.
+ feature_0_values = [12, 17]
+ feature_1_values = [12, 12]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ max_depth=1,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ # We are in the last tree.
+ self.assertAllClose([0, 0], new_tree_ids)
+ self.assertAllClose([0, 0], new_node_ids)
+
+ self.assertAllClose([[-6.0], [5.0]], logits_updates)
+
+
+class PredictionOpsTest(test_util.TensorFlowTestCase):
+ """Tests prediction ops for inference."""
+
+ def testPredictionMultipleTree(self):
+ """Tests the predictions work when we have multiple trees."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 28
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 8.79
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 26
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 50
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 34
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_weights: 0.2
+ tree_weights: 1.0
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [36, 32]
+ feature_1_values = [11, 27]
+
+ # Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 = >
+ # logit = 0.1*5.0+0.2*5.0+1*5
+ # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7 = >
+ # logit= 0.1*1.14+0.2*7.0-1*7.0
+ expected_logits = [[6.114], [-5.486]]
+
+ # Do with parallelization, e.g. EVAL
+ predict_op = boosted_trees_ops.predict(
+ tree_ensemble_handle,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1,
+ max_depth=2)
+
+ logits = session.run(predict_op)
+ self.assertAllClose(expected_logits, logits)
+
+ # Do without parallelization, e.g. INFER - the result is the same
+ predict_op = boosted_trees_ops.predict(
+ tree_ensemble_handle,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1,
+ max_depth=2)
+
+ logits = session.run(predict_op)
+ self.assertAllClose(expected_logits, logits)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
new file mode 100644
index 0000000000..a223241e89
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
@@ -0,0 +1,228 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for boosted_trees resource kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from google.protobuf import text_format
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import googletest
+
+
+class ResourceOpsTest(test_util.TensorFlowTestCase):
+ """Tests resource_ops."""
+
+ def testCreate(self):
+ with self.test_session():
+ ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
+ resources.initialize_resources(resources.shared_resources()).run()
+ stamp_token = ensemble.get_stamp_token()
+ self.assertEqual(0, stamp_token.eval())
+ (_, num_trees, num_finalized_trees,
+ num_attempted_layers) = ensemble.get_states()
+ self.assertEqual(0, num_trees.eval())
+ self.assertEqual(0, num_finalized_trees.eval())
+ self.assertEqual(0, num_attempted_layers.eval())
+
+ def testCreateWithProto(self):
+ with self.test_session():
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 21
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.4
+ original_leaf {
+ scalar: 7.14
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 7
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.7
+ original_leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.54
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.305
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.525
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.145
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 75
+ threshold: 21
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -1.4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.6
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.165
+ }
+ }
+ }
+ tree_weights: 0.15
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 6
+ }
+ """, ensemble_proto)
+ ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble',
+ stamp_token=7,
+ serialized_proto=ensemble_proto.SerializeToString())
+ resources.initialize_resources(resources.shared_resources()).run()
+ (stamp_token, num_trees, num_finalized_trees,
+ num_attempted_layers) = ensemble.get_states()
+ self.assertEqual(7, stamp_token.eval())
+ self.assertEqual(2, num_trees.eval())
+ self.assertEqual(1, num_finalized_trees.eval())
+ self.assertEqual(6, num_attempted_layers.eval())
+
+ def testSerializeDeserialize(self):
+ with self.test_session():
+ # Initialize.
+ ensemble = boosted_trees_ops.TreeEnsemble('ensemble', stamp_token=5)
+ resources.initialize_resources(resources.shared_resources()).run()
+ (stamp_token, num_trees, num_finalized_trees,
+ num_attempted_layers) = ensemble.get_states()
+ self.assertEqual(5, stamp_token.eval())
+ self.assertEqual(0, num_trees.eval())
+ self.assertEqual(0, num_finalized_trees.eval())
+ self.assertEqual(0, num_attempted_layers.eval())
+
+ # Deserialize.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 75
+ threshold: 21
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -1.4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.6
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.165
+ }
+ }
+ }
+ tree_weights: 0.5
+ tree_metadata {
+ num_layers_grown: 4 # it's fake intentionally.
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 5
+ }
+ """, ensemble_proto)
+ with ops.control_dependencies([
+ ensemble.deserialize(
+ stamp_token=3,
+ serialized_proto=ensemble_proto.SerializeToString())
+ ]):
+ (stamp_token, num_trees, num_finalized_trees,
+ num_attempted_layers) = ensemble.get_states()
+ self.assertEqual(3, stamp_token.eval())
+ self.assertEqual(1, num_trees.eval())
+ # This reads from metadata, not really counting the layers.
+ self.assertEqual(5, num_attempted_layers.eval())
+ self.assertEqual(0, num_finalized_trees.eval())
+
+ # Serialize.
+ new_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ new_stamp_token, new_serialized = ensemble.serialize()
+ self.assertEqual(3, new_stamp_token.eval())
+ new_ensemble_proto.ParseFromString(new_serialized.eval())
+ self.assertProtoEquals(ensemble_proto, new_ensemble_proto)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
new file mode 100644
index 0000000000..a54cc43517
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for boosted_trees stats kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.platform import googletest
+
+
+class StatsOpsTest(test_util.TensorFlowTestCase):
+ """Tests stats_ops."""
+
+ def testCalculateBestGainsWithoutRegularization(self):
+ """Testing Gain calculation without any regularization."""
+ with self.test_session() as sess:
+ max_splits = 7
+ node_id_range = [1, 2] # node 1 through 2 will be processed.
+ stats_summary_list = [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1
+ [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]], # node 1
+ [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 1
+ ] # num_features * shape=[max_splits, num_buckets, 2]
+
+ (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+ right_node_contribs_list
+ ) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ max_splits=max_splits)
+
+ self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+ self.assertAllClose([[0.004775, 0.41184], [0.02823, 0.41184]],
+ sess.run(gains_list))
+ self.assertAllEqual([[1, 1], [1, 1]], sess.run(thresholds_list))
+ # The left node contrib will be later added to the previous node value to
+ # make the left node value, and the same for right node contrib.
+ self.assertAllClose([[[-.416667], [.568966]], [[-.6], [-.75]]],
+ sess.run(left_node_contribs_list))
+ self.assertAllClose([[[-.592593], [-.75]], [[-.076923], [.568966]]],
+ sess.run(right_node_contribs_list))
+
+ def testCalculateBestGainsWithL2(self):
+ """Testing Gain calculation with L2."""
+ with self.test_session() as sess:
+ max_splits = 7
+ node_id_range = [1, 2] # node 1 through 2 will be processed.
+ stats_summary_list = [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1
+ [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]], # node 1
+ [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 1
+ ] # num_features * shape=[max_splits, num_buckets, 2]
+
+ (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+ right_node_contribs_list
+ ) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=0.0,
+ l2=0.1,
+ tree_complexity=0.0,
+ max_splits=max_splits)
+
+ self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+ self.assertAllClose([[0., 0.33931375], [0.01879096, 0.33931375]],
+ sess.run(gains_list))
+ self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
+ # The left node contrib will be later added to the previous node value to
+ # make the left node value, and the same for right node contrib.
+ self.assertAllClose([[[0.], [.485294]], [[-.5], [-.6]]],
+ sess.run(left_node_contribs_list))
+ self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
+ sess.run(right_node_contribs_list))
+
+ def testCalculateBestGainsWithL1(self):
+ """Testing Gain calculation with L1."""
+ with self.test_session() as sess:
+ max_splits = 7
+ node_id_range = [1, 2] # node 1 through 2 will be processed.
+ stats_summary_list = [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1
+ [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]], # node 1
+ [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 1
+ ] # num_features * shape=[max_splits, num_buckets, 2]
+
+ l1 = 0.1
+ (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+ right_node_contribs_list
+ ) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=l1,
+ l2=0.0,
+ tree_complexity=0.0,
+ max_splits=max_splits)
+
+ self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
+
+ self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+ self.assertAllClose([[[0.0], [0.3965517]], [[-0.4], [-0.5]]],
+ sess.run(left_node_contribs_list))
+
+ self.assertAllClose([[[-0.3333333], [-0.5]], [[0.0], [0.396552]]],
+ sess.run(right_node_contribs_list))
+
+ # Gain should also include an adjustment of the gradient by l1.
+ self.assertAllClose([[0.0, 0.191207], [0.01, 0.191207]],
+ sess.run(gains_list))
+
+ def testCalculateBestGainsWithTreeComplexity(self):
+ """Testing Gain calculation with L2."""
+ with self.test_session() as sess:
+ max_splits = 7
+ node_id_range = [1, 2] # node 1 through 2 will be processed.
+ stats_summary_list = [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1
+ [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]], # node 1
+ [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 1
+ ] # num_features * shape=[max_splits, num_buckets, 2]
+
+ l2 = 0.1
+ tree_complexity = 3.
+ (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+ right_node_contribs_list
+ ) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=0.0,
+ l2=l2,
+ tree_complexity=tree_complexity,
+ max_splits=max_splits)
+
+ self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+
+ self.assertAllClose([[-3., -2.66068625], [-2.98120904, -2.66068625]],
+ sess.run(gains_list))
+
+ self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
+ # The left node contrib will be later added to the previous node value to
+ # make the left node value, and the same for right node contrib.
+ self.assertAllClose([[[0.], [.485294]], [[-.5], [-.6]]],
+ sess.run(left_node_contribs_list))
+ self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
+ sess.run(right_node_contribs_list))
+
+ def testMakeStatsSummarySimple(self):
+ """Simple test for MakeStatsSummary."""
+ with self.test_session():
+ self.assertAllClose([[[[1., 5.], [2., 6.]], [[3., 7.], [4., 8.]]]],
+ boosted_trees_ops.make_stats_summary(
+ node_ids=[0, 0, 1, 1],
+ gradients=[[1.], [2.], [3.], [4.]],
+ hessians=[[5.], [6.], [7.], [8.]],
+ bucketized_features_list=[[0, 1, 0, 1]],
+ max_splits=2,
+ num_buckets=2).eval())
+
+ def testMakeStatsSummaryAccumulate(self):
+ """Tests that Summary actually accumulates."""
+ with self.test_session():
+ max_splits = 3
+ num_buckets = 4
+ node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
+ gradients = [[.1], [.2], [.3], [-.4], [-.05], [.06], [.07], [.08]]
+ hessians = [[.2], [.3], [.4], [.5], [.06], [.07], [.08], [.09]]
+
+ # Tests a single feature.
+ bucketized_features = [[3, 1, 2, 0, 1, 2, 0, 1]]
+ result = boosted_trees_ops.make_stats_summary(
+ node_ids, gradients, hessians, bucketized_features, max_splits,
+ num_buckets) # shape=[max_splits, num_buckets, num_features, 2]
+ self.assertAllClose(
+ [[
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0
+ [[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1
+ [[-.33, .58], [0., 0.], [.3, .4], [0., 0.]], # node 2
+ ]],
+ result.eval())
+
+ def testMakeStatsSummaryMultipleFeatures(self):
+ """Tests that MakeStatsSummary works for multiple features."""
+ with self.test_session():
+ max_splits = 3
+ num_buckets = 4
+ node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
+ gradients = [[.1], [.2], [.3], [-.4], [-.05], [.06], [.07], [.08]]
+ hessians = [[.2], [.3], [.4], [.5], [.06], [.07], [.08], [.09]]
+
+ # Tests multiple features.
+ # The output from another feature will stored be in 3rd dimension.
+ bucketized_features = [[3, 1, 2, 0, 1, 2, 0, 1], [0, 0, 0, 2, 2, 3, 3, 2]]
+ result = boosted_trees_ops.make_stats_summary(
+ node_ids, gradients, hessians, bucketized_features, max_splits,
+ num_buckets) # shape=[max_splits, num_buckets, num_features, 2]
+ self.assertAllClose(
+ [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0
+ [[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1
+ [[-.33, .58], [0., 0.], [.3, .4], [0., 0.]], # node 2
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0
+ [[.3, .5], [0., 0.], [-.05, .06], [.06, .07]], # node 1
+ [[.3, .4], [0., 0.], [-.4, .5], [.07, .08]], # node 2
+ ], # feature 1
+ ],
+ result.eval())
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
new file mode 100644
index 0000000000..4226ff75c2
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -0,0 +1,1465 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for boosted_trees training kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from google.protobuf import text_format
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import googletest
+
+
+class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
+ """Tests for growing tree ensemble from split candidates."""
+
+ def testGrowWithEmptyEnsemble(self):
+ """Test growing an empty ensemble."""
+ with self.test_session() as session:
+ # Create empty ensemble.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_ids = [0, 2, 6]
+
+ # Prepare feature inputs.
+ # Note that features 1 & 3 have the same gain but different splits.
+ feature1_nodes = np.array([0], dtype=np.int32)
+ feature1_gains = np.array([7.62], dtype=np.float32)
+ feature1_thresholds = np.array([52], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)
+
+ feature2_nodes = np.array([0], dtype=np.int32)
+ feature2_gains = np.array([0.63], dtype=np.float32)
+ feature2_thresholds = np.array([23], dtype=np.int32)
+ feature2_left_node_contribs = np.array([[-0.6]], dtype=np.float32)
+ feature2_right_node_contribs = np.array([[0.24]], dtype=np.float32)
+
+ # Feature split with the highest gain.
+ feature3_nodes = np.array([0], dtype=np.int32)
+ feature3_gains = np.array([7.65], dtype=np.float32)
+ feature3_thresholds = np.array([7], dtype=np.int32)
+ feature3_left_node_contribs = np.array([[-4.89]], dtype=np.float32)
+ feature3_right_node_contribs = np.array([[5.3]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=0.1,
+ pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+ # Tree will be finalized now, since we will reach depth 1.
+ max_depth=1,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+ gains=[feature1_gains, feature2_gains, feature3_gains],
+ thresholds=[
+ feature1_thresholds, feature2_thresholds, feature3_thresholds
+ ],
+ left_node_contribs=[
+ feature1_left_node_contribs, feature2_left_node_contribs,
+ feature3_left_node_contribs
+ ],
+ right_node_contribs=[
+ feature1_right_node_contribs, feature2_right_node_contribs,
+ feature3_right_node_contribs
+ ])
+ session.run(grow_op)
+
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ # Note that since the tree is finalized, we added a new dummy tree.
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 6
+ threshold: 7
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.65
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.489
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.53
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
+ def testGrowExistingEnsembleTreeNotFinalized(self):
+ """Test growing an existing ensemble with the last tree not finalized."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.714
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.4375
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare feature inputs.
+ # feature 1 only has a candidate for node 1, feature 2 has candidates
+ # for both nodes and feature 3 only has a candidate for node 2.
+
+ feature_ids = [0, 1, 0]
+
+ feature1_nodes = np.array([1], dtype=np.int32)
+ feature1_gains = np.array([1.4], dtype=np.float32)
+ feature1_thresholds = np.array([21], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+ feature2_nodes = np.array([1, 2], dtype=np.int32)
+ feature2_gains = np.array([0.63, 2.7], dtype=np.float32)
+ feature2_thresholds = np.array([23, 7], dtype=np.int32)
+ feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
+ feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
+
+ feature3_nodes = np.array([2], dtype=np.int32)
+ feature3_gains = np.array([1.7], dtype=np.float32)
+ feature3_thresholds = np.array([3], dtype=np.int32)
+ feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
+ feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=0.1,
+ pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+ # tree is going to be finalized now, since we reach depth 2.
+ max_depth=2,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+ gains=[feature1_gains, feature2_gains, feature3_gains],
+ thresholds=[
+ feature1_thresholds, feature2_thresholds, feature3_thresholds
+ ],
+ left_node_contribs=[
+ feature1_left_node_contribs, feature2_left_node_contribs,
+ feature3_left_node_contribs
+ ],
+ right_node_contribs=[
+ feature1_right_node_contribs, feature2_right_node_contribs,
+ feature3_right_node_contribs
+ ])
+ session.run(grow_op)
+
+ # Expect the split for node 1 to be chosen from feature 1 and
+ # the split for node 2 to be chosen from feature 2.
+ # The grown tree should be finalized as max tree depth is 2 and we have
+ # grown 2 layers.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 21
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.4
+ original_leaf {
+ scalar: 0.714
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 7
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.7
+ original_leaf {
+ scalar: -0.4375
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.114
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.879
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.5875
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.2075
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ is_finalized: true
+ num_layers_grown: 2
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
+ def testGrowExistingEnsembleTreeFinalized(self):
+ """Test growing an existing ensemble with the last tree finalized."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 0.15
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare feature inputs.
+
+ feature_ids = [75]
+
+ feature1_nodes = np.array([0], dtype=np.int32)
+ feature1_gains = np.array([-1.4], dtype=np.float32)
+ feature1_thresholds = np.array([21], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+ learning_rate=0.1,
+ max_depth=2,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes],
+ gains=[feature1_gains],
+ thresholds=[feature1_thresholds],
+ left_node_contribs=[feature1_left_node_contribs],
+ right_node_contribs=[feature1_right_node_contribs])
+ session.run(grow_op)
+
+ # Expect a new tree added, with a split on feature 75
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 75
+ threshold: 21
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -1.4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.6
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.165
+ }
+ }
+ }
+ tree_weights: 0.15
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
+ def testPrePruning(self):
+ """Test growing an existing ensemble with pre-pruning."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare feature inputs.
+ # For node 1, the best split is on feature 2 (gain -0.63), but the gain
+ # is negative so node 1 will not be split.
+ # For node 2, the best split is on feature 3, gain is positive.
+
+ feature_ids = [0, 1, 0]
+
+ feature1_nodes = np.array([1], dtype=np.int32)
+ feature1_gains = np.array([-1.4], dtype=np.float32)
+ feature1_thresholds = np.array([21], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+ feature2_nodes = np.array([1, 2], dtype=np.int32)
+ feature2_gains = np.array([-0.63, 2.7], dtype=np.float32)
+ feature2_thresholds = np.array([23, 7], dtype=np.int32)
+ feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
+ feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
+
+ feature3_nodes = np.array([2], dtype=np.int32)
+ feature3_gains = np.array([2.8], dtype=np.float32)
+ feature3_thresholds = np.array([3], dtype=np.int32)
+ feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
+ feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=0.1,
+ pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
+ max_depth=3,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+ gains=[feature1_gains, feature2_gains, feature3_gains],
+ thresholds=[
+ feature1_thresholds, feature2_thresholds, feature3_thresholds
+ ],
+ left_node_contribs=[
+ feature1_left_node_contribs, feature2_left_node_contribs,
+ feature3_left_node_contribs
+ ],
+ right_node_contribs=[
+ feature1_right_node_contribs, feature2_right_node_contribs,
+ feature3_right_node_contribs
+ ])
+ session.run(grow_op)
+
+ # Expect the split for node 1 to be chosen from feature 1 and
+ # the split for node 2 to be chosen from feature 2.
+ # The grown tree should not be finalized as max tree depth is 3 and
+ # it's only grown 2 layers.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.14
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 3
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.8
+ original_leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.45
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.182
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ is_finalized: false
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
+ def testMetadataWhenCantSplitDueToEmptySplits(self):
+ """Test that the metadata is updated even though we can't split."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.714
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.4375
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare feature inputs.
+ # feature 1 only has a candidate for node 1, feature 2 has candidates
+ # for both nodes and feature 3 only has a candidate for node 2.
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=0.1,
+ pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+ max_depth=2,
+ # No splits are available.
+ feature_ids=[],
+ node_ids=[],
+ gains=[],
+ thresholds=[],
+ left_node_contribs=[],
+ right_node_contribs=[])
+ session.run(grow_op)
+
+ # Expect no new splits created, but attempted (global) stats updated. Meta
+ # data for this tree should not be updated (we didn't succeed building a
+ # layer.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.714
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.4375
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
+ def testMetadataWhenCantSplitDuePrePruning(self):
+ """Test metadata is updated correctly when no split due to prepruning."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge("""
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare feature inputs.
+ feature_ids = [0, 1, 0]
+
+ # All the gains are negative.
+ feature1_nodes = np.array([1], dtype=np.int32)
+ feature1_gains = np.array([-1.4], dtype=np.float32)
+ feature1_thresholds = np.array([21], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+ feature2_nodes = np.array([1, 2], dtype=np.int32)
+ feature2_gains = np.array([-0.63, -2.7], dtype=np.float32)
+ feature2_thresholds = np.array([23, 7], dtype=np.int32)
+ feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
+ feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
+
+ feature3_nodes = np.array([2], dtype=np.int32)
+ feature3_gains = np.array([-2.8], dtype=np.float32)
+ feature3_thresholds = np.array([3], dtype=np.int32)
+ feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
+ feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=0.1,
+ pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
+ max_depth=3,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+ gains=[feature1_gains, feature2_gains, feature3_gains],
+ thresholds=[
+ feature1_thresholds, feature2_thresholds, feature3_thresholds
+ ],
+ left_node_contribs=[
+ feature1_left_node_contribs, feature2_left_node_contribs,
+ feature3_left_node_contribs
+ ],
+ right_node_contribs=[
+ feature1_right_node_contribs, feature2_right_node_contribs,
+ feature3_right_node_contribs
+ ])
+ session.run(grow_op)
+
+ # Expect that no new split was created because all the gains were negative
+ # Global metadata should be updated, tree metadata should not be updated.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.375
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
+ def testPostPruningOfSomeNodes(self):
+ """Test growing an ensemble with post-pruning."""
+ with self.test_session() as session:
+ # Create empty ensemble.
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare inputs.
+ # Second feature has larger (but still negative gain).
+ feature_ids = [0, 1]
+
+ feature1_nodes = np.array([0], dtype=np.int32)
+ feature1_gains = np.array([-1.3], dtype=np.float32)
+ feature1_thresholds = np.array([7], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[0.013]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+ feature2_nodes = np.array([0], dtype=np.int32)
+ feature2_gains = np.array([-0.2], dtype=np.float32)
+ feature2_thresholds = np.array([33], dtype=np.int32)
+ feature2_left_node_contribs = np.array([[0.01]], dtype=np.float32)
+ feature2_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=1.0,
+ pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+ max_depth=3,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes, feature2_nodes],
+ gains=[feature1_gains, feature2_gains],
+ thresholds=[feature1_thresholds, feature2_thresholds],
+ left_node_contribs=[
+ feature1_left_node_contribs, feature2_left_node_contribs
+ ],
+ right_node_contribs=[
+ feature1_right_node_contribs, feature2_right_node_contribs
+ ])
+
+ session.run(grow_op)
+
+ # Expect the split from second features to be chosen despite the negative
+ # gain.
+ # No pruning happened just yet.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ res_ensemble = boosted_trees_pb2.TreeEnsemble()
+ res_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 33
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -0.2
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.01
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0143
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, res_ensemble)
+
+ # Prepare the second layer.
+ # Note that node 1 gain is negative and node 2 gain is positive.
+ feature_ids = [3]
+ feature1_nodes = np.array([1, 2], dtype=np.int32)
+ feature1_gains = np.array([-0.2, 0.5], dtype=np.float32)
+ feature1_thresholds = np.array([7, 5], dtype=np.int32)
+ feature1_left_node_contribs = np.array(
+ [[0.07], [0.041]], dtype=np.float32)
+ feature1_right_node_contribs = np.array(
+ [[0.083], [0.064]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=1.0,
+ pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+ max_depth=3,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes],
+ gains=[feature1_gains],
+ thresholds=[feature1_thresholds],
+ left_node_contribs=[feature1_left_node_contribs],
+ right_node_contribs=[feature1_right_node_contribs])
+
+ session.run(grow_op)
+
+ # After adding this layer, the tree will not be finalized
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ res_ensemble = boosted_trees_pb2.TreeEnsemble()
+ res_ensemble.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id:1
+ threshold: 33
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -0.2
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 3
+ threshold: 7
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: -0.2
+ original_leaf {
+ scalar: 0.01
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 3
+ threshold: 5
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.5
+ original_leaf {
+ scalar: 0.0143
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.08
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.093
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0553
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0783
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 2)
+
+ self.assertProtoEquals(expected_result, res_ensemble)
+ # Now split the leaf 3, again with negative gain. After this layer, the
+ # tree will be finalized, and post-pruning happens. The leafs 3,4,7,8 will
+ # be pruned out.
+
+ # Prepare the third layer.
+ feature_ids = [92]
+ feature1_nodes = np.array([3], dtype=np.int32)
+ feature1_gains = np.array([-0.45], dtype=np.float32)
+ feature1_thresholds = np.array([11], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[0.15]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[0.5]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=1.0,
+ pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+ max_depth=3,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes],
+ gains=[feature1_gains],
+ thresholds=[feature1_thresholds],
+ left_node_contribs=[feature1_left_node_contribs],
+ right_node_contribs=[feature1_right_node_contribs])
+
+ session.run(grow_op)
+ # After adding this layer, the tree will be finalized
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ res_ensemble = boosted_trees_pb2.TreeEnsemble()
+ res_ensemble.ParseFromString(serialized)
+ # Node that nodes 3, 4, 7 and 8 got deleted, so metadata stores has ids
+ # mapped to their parent node 1, with the respective change in logits.
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id:1
+ threshold: 33
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -0.2
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.01
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 3
+ threshold: 5
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 0.5
+ original_leaf {
+ scalar: 0.0143
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0553
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0783
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 3
+ is_finalized: true
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 2
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.07
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.083
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 3
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 4
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.22
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 1
+ logit_change: -0.57
+ }
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 3
+ }
+ """
+ self.assertEqual(new_stamp, 3)
+ self.assertProtoEquals(expected_result, res_ensemble)
+
+ def testPostPruningOfAllNodes(self):
+ """Test growing an ensemble with post-pruning, with all nodes are pruned."""
+ with self.test_session() as session:
+ # Create empty ensemble.
+ # Create empty ensemble.
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare inputs. All have negative gains.
+ feature_ids = [0, 1]
+
+ feature1_nodes = np.array([0], dtype=np.int32)
+ feature1_gains = np.array([-1.3], dtype=np.float32)
+ feature1_thresholds = np.array([7], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[0.013]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+ feature2_nodes = np.array([0], dtype=np.int32)
+ feature2_gains = np.array([-0.62], dtype=np.float32)
+ feature2_thresholds = np.array([33], dtype=np.int32)
+ feature2_left_node_contribs = np.array([[0.01]], dtype=np.float32)
+ feature2_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=1.0,
+ pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+ max_depth=2,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes, feature2_nodes],
+ gains=[feature1_gains, feature2_gains],
+ thresholds=[feature1_thresholds, feature2_thresholds],
+ left_node_contribs=[
+ feature1_left_node_contribs, feature2_left_node_contribs
+ ],
+ right_node_contribs=[
+ feature1_right_node_contribs, feature2_right_node_contribs
+ ])
+
+ session.run(grow_op)
+
+ # Expect the split from feature 2 to be chosen despite the negative gain.
+ # The grown tree should not be finalized as max tree depth is 2 so no
+ # pruning occurs.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ res_ensemble = boosted_trees_pb2.TreeEnsemble()
+ res_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 33
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: -0.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.01
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0143
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, res_ensemble)
+
+ # Prepare inputs.
+ # All have negative gain.
+ feature_ids = [3]
+ feature1_nodes = np.array([1, 2], dtype=np.int32)
+ feature1_gains = np.array([-0.2, -0.5], dtype=np.float32)
+ feature1_thresholds = np.array([77, 79], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[0.023], [0.3]], dtype=np.float32)
+ feature1_right_node_contribs = np.array(
+ [[0.012343], [24]], dtype=np.float32)
+
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=1.0,
+ pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+ max_depth=2,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes],
+ gains=[feature1_gains],
+ thresholds=[feature1_thresholds],
+ left_node_contribs=[feature1_left_node_contribs],
+ right_node_contribs=[feature1_right_node_contribs])
+
+ session.run(grow_op)
+
+ # Expect the split from feature 1 to be chosen despite the negative gain.
+ # The grown tree should be finalized. Since all nodes have negative gain,
+ # the whole tree is pruned.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ res_ensemble = boosted_trees_pb2.TreeEnsemble()
+ res_ensemble.ParseFromString(serialized)
+
+ # Expect the ensemble to be empty as post-pruning will prune
+ # the entire finalized tree.
+ self.assertEqual(new_stamp, 2)
+ self.assertProtoEquals("""
+ trees {
+ nodes {
+ leaf {
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata{
+ num_layers_grown: 2
+ is_finalized: true
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: 0.0
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: -0.01
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: -0.0143
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: -0.033
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: -0.022343
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: -0.3143
+ }
+ post_pruned_nodes_meta {
+ new_node_id: 0
+ logit_change: -24.0143
+ }
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, res_ensemble)
+
+ def testPostPruningChangesNothing(self):
+ """Test growing an ensemble with post-pruning with all gains >0."""
+ with self.test_session() as session:
+ # Create empty ensemble.
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare inputs.
+ # Second feature has larger (but still negative gain).
+ feature_ids = [3, 4]
+
+ feature1_nodes = np.array([0], dtype=np.int32)
+ feature1_gains = np.array([7.62], dtype=np.float32)
+ feature1_thresholds = np.array([52], dtype=np.int32)
+ feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
+ feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)
+
+ feature2_nodes = np.array([0], dtype=np.int32)
+ feature2_gains = np.array([0.63], dtype=np.float32)
+ feature2_thresholds = np.array([23], dtype=np.int32)
+ feature2_left_node_contribs = np.array([[-0.6]], dtype=np.float32)
+ feature2_right_node_contribs = np.array([[0.24]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.update_ensemble(
+ tree_ensemble_handle,
+ learning_rate=1.0,
+ pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+ max_depth=1,
+ feature_ids=feature_ids,
+ node_ids=[feature1_nodes, feature2_nodes],
+ gains=[feature1_gains, feature2_gains],
+ thresholds=[feature1_thresholds, feature2_thresholds],
+ left_node_contribs=[
+ feature1_left_node_contribs, feature2_left_node_contribs
+ ],
+ right_node_contribs=[
+ feature1_right_node_contribs, feature2_right_node_contribs
+ ])
+
+ session.run(grow_op)
+
+ # Expect the split from the first feature to be chosen.
+ # Pruning got triggered but changed nothing.
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+ res_ensemble = boosted_trees_pb2.TreeEnsemble()
+ res_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 3
+ threshold: 52
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -4.375
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.143
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, res_ensemble)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index e220d05692..f3cc9636f9 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -280,15 +280,3 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index f5717a5a21..1301ef9d19 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -229,7 +229,7 @@ class FunctionalOpsTest(test.TestCase):
with self.test_session():
nums = np.array([1, 2, 3, 4, 5, 6])
with self.assertRaisesRegexp(
- TypeError, r"two structures don't have the same sequence type."):
+ TypeError, r"two structures don't have the same nested structure"):
# lambda emits tuple, but dtype is a list
functional_ops.map_fn(
lambda x: ((x + 3) * 2, -(x + 3) * 2),
@@ -316,7 +316,7 @@ class FunctionalOpsTest(test.TestCase):
initializer = np.array(1.0)
# Multiply a * 1 each time
with self.assertRaisesRegexp(
- ValueError, "two structures don't have the same number of elements"):
+ ValueError, "two structures don't have the same nested structure"):
functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
def testScan_Scoped(self):
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 3c4d038ef9..1e5c118cbc 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -649,6 +649,30 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
sess.run(outputs_2norm)/(np.sqrt(np.prod(shape))*np.sqrt(3.14)),
rtol=tol, atol=tol)
+ def testNonuniformity(self):
+ value = 0
+ abs_value = 0
+ shape = [3, 3, 10, 10]
+ count = 70
+ tol = 1e-5
+ with self.test_session(use_gpu=True): # as sess:
+ for i in range(count):
+ x = variable_scope.get_variable("{}".format(i), shape=shape,
+ initializer=
+ init_ops.convolutional_delta_orthogonal)
+ x.initializer.run()
+ y = x.eval()[1, 1, :, :]
+ determinant = np.linalg.det(y)
+ value += determinant
+ abs_value += np.abs(determinant)
+
+ # Check there is some variation in the signs of the determinants
+ self.assertLess(value, count - tol)
+ self.assertLess(-count + tol, value)
+ # Check all determinants have absolute value 1
+ # Compute the sum of the absolute values of 'count' determinants
+ self.assertAllClose(abs_value, count, rtol=tol, atol=tol)
+
class IdentityInitializerTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index fd1b5bab6f..9555e51099 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -140,15 +140,3 @@ cuda_py_test(
],
shard_count = 5,
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD
index 88a4ddf7f2..acd7566eec 100644
--- a/tensorflow/python/kernel_tests/random/BUILD
+++ b/tensorflow/python/kernel_tests/random/BUILD
@@ -121,15 +121,3 @@ cuda_py_test(
"//tensorflow/python:random_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 1e5f26a77f..242cdff6f3 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -625,6 +625,8 @@ class Layer(checkpointable.CheckpointableBase):
input_list = nest.flatten(inputs)
build_graph = not context.executing_eagerly()
+ # TODO(fchollet, allenl): Make deferred mode work with subclassed Models
+ # which don't use an "inputs" argument.
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created.
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9106461c60..207866610b 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -957,6 +957,11 @@ def _autopacking_helper(list_or_tuple, dtype, name):
Returns:
A `tf.Tensor` with value equivalent to `list_or_tuple`.
"""
+ if context.executing_eagerly():
+ # NOTE: Fast path when all the items are tensors, this doesn't do any type
+ # checking.
+ if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple):
+ return gen_array_ops.pack(list_or_tuple, name=name)
must_pack = False
converted_elems = []
with ops.name_scope(name) as scope:
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
new file mode 100644
index 0000000000..174d00987f
--- /dev/null
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -0,0 +1,160 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for boosted_trees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import resources
+
+# Re-exporting ops used by other modules.
+# pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
+# pylint: enable=unused-import
+
+from tensorflow.python.training import saver
+
+
+class PruningMode(object):
+ NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)
+
+
+class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
+ """SaveableObject implementation for TreeEnsemble."""
+
+ def __init__(self, resource_handle, create_op, name):
+ """Creates a _TreeEnsembleSavable object.
+
+ Args:
+ resource_handle: handle to the decision tree ensemble variable.
+ create_op: the op to initialize the variable.
+ name: the name to save the tree ensemble variable under.
+ """
+ stamp_token, serialized = (
+ gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
+ # slice_spec is useful for saving a slice from a variable.
+ # It's not meaningful the tree ensemble variable. So we just pass an empty
+ # value.
+ slice_spec = ''
+ specs = [
+ saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
+ name + '_stamp'),
+ saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec,
+ name + '_serialized'),
+ ]
+ super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
+ self._resource_handle = resource_handle
+ self._create_op = create_op
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ """Restores the associated tree ensemble from 'restored_tensors'.
+
+ Args:
+ restored_tensors: the tensors that were loaded from a checkpoint.
+ unused_restored_shapes: the shapes this object should conform to after
+ restore. Not meaningful for trees.
+
+ Returns:
+ The operation that restores the state of the tree ensemble variable.
+ """
+ with ops.control_dependencies([self._create_op]):
+ return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
+ self._resource_handle,
+ stamp_token=restored_tensors[0],
+ tree_ensemble_serialized=restored_tensors[1])
+
+
+class TreeEnsemble(object):
+ """Creates TreeEnsemble resource."""
+
+ def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
+ with ops.name_scope(name, 'TreeEnsemble') as name:
+ self._resource_handle = (
+ gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
+ container='', shared_name=name, name=name))
+ create_op = gen_boosted_trees_ops.boosted_trees_create_ensemble(
+ self.resource_handle,
+ stamp_token,
+ tree_ensemble_serialized=serialized_proto)
+ is_initialized_op = (
+ gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
+ self._resource_handle))
+ # Adds the variable to the savable list.
+ if not is_local:
+ saveable = _TreeEnsembleSavable(self.resource_handle, create_op,
+ self.resource_handle.name)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ resources.register_resource(
+ self.resource_handle,
+ create_op,
+ is_initialized_op,
+ is_shared=not is_local)
+
+ @property
+ def resource_handle(self):
+ return self._resource_handle
+
+ def get_stamp_token(self):
+ """Returns the current stamp token of the resource."""
+ stamp_token, _, _, _ = (
+ gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
+ self.resource_handle))
+ return stamp_token
+
+ def get_states(self):
+ """Returns states of the tree ensemble.
+
+ Returns:
+ stamp_token, num_trees, num_finalized_trees, num_attempted_layers.
+ """
+ stamp_token, num_trees, num_finalized_trees, num_attempted_layers = (
+ gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
+ self.resource_handle))
+ # Use identity to give names.
+ return (array_ops.identity(stamp_token, name='stamp_token'),
+ array_ops.identity(num_trees, name='num_trees'),
+ array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
+ array_ops.identity(
+ num_attempted_layers, name='num_attempted_layers'))
+
+ def serialize(self):
+ """Serializes the ensemble into proto and returns the serialized proto.
+
+ Returns:
+ stamp_token: int64 scalar Tensor to denote the stamp of the resource.
+ serialized_proto: string scalar Tensor of the serialized proto.
+ """
+ return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
+ self.resource_handle)
+
+ def deserialize(self, stamp_token, serialized_proto):
+ """Deserialize the input proto and resets the ensemble from it.
+
+ Args:
+ stamp_token: int64 scalar Tensor to denote the stamp of the resource.
+ serialized_proto: string scalar Tensor of the serialized proto.
+
+ Returns:
+ Operation (for dependencies).
+ """
+ return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
+ self.resource_handle, stamp_token, serialized_proto)
diff --git a/tensorflow/python/ops/distributions/BUILD b/tensorflow/python/ops/distributions/BUILD
index 50b956a267..9d9ede7ad7 100644
--- a/tensorflow/python/ops/distributions/BUILD
+++ b/tensorflow/python/ops/distributions/BUILD
@@ -26,15 +26,3 @@ py_library(
"@six_archive//:six",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 40ab22951b..9dfe5ffbf4 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -532,8 +532,7 @@ class Orthogonal(Initializer):
q, r = linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
d = array_ops.diag_part(r)
- ph = d / math_ops.abs(d)
- q *= ph
+ q *= math_ops.sign(d)
if num_rows < num_cols:
q = array_ops.matrix_transpose(q)
return self.gain * array_ops.reshape(q, shape)
@@ -579,7 +578,11 @@ class ConvolutionDeltaOrthogonal(Initializer):
a = random_ops.random_normal([shape[-1], shape[-1]],
dtype=dtype, seed=self.seed)
# Compute the qr factorization
- q, _ = linalg_ops.qr(a, full_matrices=False)
+ q, r = linalg_ops.qr(a, full_matrices=False)
+ # Make Q uniform
+ d = array_ops.diag_part(r)
+ # ph = d / math_ops.abs(d)
+ q *= math_ops.sign(d)
q = q[:shape[-2], :]
q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
if len(shape) == 3:
diff --git a/tensorflow/python/ops/linalg/BUILD b/tensorflow/python/ops/linalg/BUILD
index ce8c1580fe..07659ef44c 100644
--- a/tensorflow/python/ops/linalg/BUILD
+++ b/tensorflow/python/ops/linalg/BUILD
@@ -34,15 +34,3 @@ py_library(
"//tensorflow/python:special_math_ops",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/ops/losses/BUILD b/tensorflow/python/ops/losses/BUILD
index 07741e0c3c..4aea0265a7 100644
--- a/tensorflow/python/ops/losses/BUILD
+++ b/tensorflow/python/ops/losses/BUILD
@@ -43,15 +43,3 @@ py_test(
"//tensorflow/python:framework_for_generated_wrappers",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index a74de39eab..0c55386241 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1836,8 +1836,10 @@ def softmax_cross_entropy_with_logits_v2(
[logits, labels]) as name:
logits = ops.convert_to_tensor(logits, name="logits")
labels = ops.convert_to_tensor(labels, name="labels")
+ convert_to_float32 = (
+ logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16)
precise_logits = math_ops.cast(
- logits, dtypes.float32) if (logits.dtype == dtypes.float16) else logits
+ logits, dtypes.float32) if convert_to_float32 else logits
# labels and logits must be of the same type
labels = math_ops.cast(labels, precise_logits.dtype)
input_rank = array_ops.rank(precise_logits)
@@ -1883,8 +1885,8 @@ def softmax_cross_entropy_with_logits_v2(
del shape[dim]
cost.set_shape(shape)
- if logits.dtype == dtypes.float16:
- return math_ops.cast(cost, dtypes.float16)
+ if convert_to_float32:
+ return math_ops.cast(cost, logits.dtype)
else:
return cost
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index af9dae2aa6..da86d5f6ca 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -852,6 +852,57 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
self.assertAllClose(exp_sampled_softmax_loss,
got_sampled_softmax_loss.eval(), 1e-4)
+ def testSampledSoftmaxLossBf16(self):
+ # A simple test to verify the numerics for bfloat16.
+ def _SoftmaxCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ stable_exp_logits = np.exp(
+ logits - np.amax(logits, axis=1, keepdims=True))
+ pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
+ return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
+
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+ labels = [0, 1, 2]
+ sampled = [1, 0, 2, 3]
+ (weights, biases, hidden_acts, _, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=1,
+ labels=labels,
+ sampled=sampled,
+ subtract_log_q=True)
+ exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
+ exp_logits, exp_labels)
+
+ with self.test_session():
+ true_exp_bf16 = np.full(
+ [batch_size, 1], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_exp_bf16 = np.full(
+ [len(sampled)], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16)
+
+ got_sampled_softmax_loss = math_ops.cast(
+ nn_impl.sampled_softmax_loss(
+ weights=constant_op.constant(weights, dtype=dtypes.bfloat16),
+ biases=constant_op.constant(biases, dtype=dtypes.bfloat16),
+ labels=constant_op.constant(
+ labels, shape=(batch_size, 1), dtype=dtypes.bfloat16),
+ inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals_bf16,
+ remove_accidental_hits=False,
+ partition_strategy="div"), dtypes.float32)
+
+ self.assertAllClose(exp_sampled_softmax_loss,
+ got_sampled_softmax_loss.eval(), 1e-1)
+
class CReluTest(test_lib.TestCase):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index df873da98e..2f39ea2e7d 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -1087,6 +1087,11 @@ ops.register_proto_function(
proto_type=variable_pb2.VariableDef,
to_proto=_to_proto_fn,
from_proto=_from_proto_fn)
+ops.register_proto_function(
+ ops.GraphKeys.GLOBAL_STEP,
+ proto_type=variable_pb2.VariableDef,
+ to_proto=_to_proto_fn,
+ from_proto=_from_proto_fn)
def is_resource_variable(var):
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 1b4111bca6..96fb024715 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -334,7 +334,11 @@ def py_func(func, inp, Tout, stateful=True, name=None):
result = func(*[x.numpy() for x in inp])
result = nest.flatten(result)
- return [x if x is None else ops.convert_to_tensor(x) for x in result]
+ result = [x if x is None else ops.convert_to_tensor(x) for x in result]
+ if len(result) == 1:
+ # Mimic the automatic unwrapping in graph-mode py_func
+ result, = result
+ return result
return _internal_py_func(
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index c815aad0a0..0654104a34 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -156,18 +156,3 @@ py_test(
"@com_google_pprof//:pprof_proto_py",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD
index 362a1c49e6..994206cd63 100644
--- a/tensorflow/python/profiler/internal/BUILD
+++ b/tensorflow/python/profiler/internal/BUILD
@@ -70,18 +70,3 @@ cuda_py_test(
"no_pip",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 30e0a099d8..2609a5d222 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -235,15 +235,3 @@ py_test(
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 1de1adcfbc..6e39ce8c80 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -258,17 +258,3 @@ py_test(
"//tensorflow/core:protos_all_py",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "bin/**",
- "gen/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index d0650eb127..bbbe1e8ac5 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -560,6 +560,7 @@ class CheckpointableBase(object):
checkpointable: The Checkpointable object to restore (inheriting from
`CheckpointableBase`).
"""
+ self._maybe_initialize_checkpointable()
deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
for checkpoint_position in sorted(
deferred_dependencies_list,
diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py
index d31c375b4c..be80c36571 100644
--- a/tensorflow/python/training/device_setter.py
+++ b/tensorflow/python/training/device_setter.py
@@ -25,14 +25,13 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export
-# This is a tuple of PS ops used by tf.estimator.Esitmator which should work in
+# This is a tuple of PS ops used by tf.estimator.Estimator which should work in
# almost all of cases.
-STANDARD_PS_OPS = (
- "Variable", "VariableV2", "AutoReloadVariable", "MutableHashTable",
- "MutableHashTableV2", "MutableHashTableOfTensors",
- "MutableHashTableOfTensorsV2", "MutableDenseHashTable",
- "MutableDenseHashTableV2", "VarHandleOp"
-)
+STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable",
+ "MutableHashTable", "MutableHashTableV2",
+ "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
+ "MutableDenseHashTable", "MutableDenseHashTableV2",
+ "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp")
class _RoundRobinStrategy(object):
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 3571a41d4f..7df4812454 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -126,16 +126,18 @@ class UpdateContext(object):
def get_tower_context():
- """Returns the current TowerContext or None.
+ """Returns the current TowerContext or None if in a cross-tower context.
Note that execution:
- 1. starts in the default (single-tower) tower context;
- 2. switches to cross-tower context when entering a
- `with DistributionStrategy.scope():` block;
+ 1. starts in the default (single-tower) tower context (this function
+ will return the default TowerContext object);
+ 2. switches to cross-tower context (in which case this will return
+ None) when entering a `with DistributionStrategy.scope():` block;
3. switches to a (non-default) tower context inside
`call_for_each_tower(fn, ...)`;
4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context.
+ inside `merge_fn` you are back in the cross-tower context (and again
+ this function will return None).
Note that you can also go directly from step 1 to 4 to switch to a
cross-tower context for the default `DistributionStrategy`. You may
@@ -188,6 +190,9 @@ def get_cross_tower_context():
def get_distribution_strategy():
"""Returns the current `DistributionStrategy` object.
+ Prefer to use `get_tower_context()` or `get_cross_tower_context()`
+ instead when possible.
+
Returns:
A `DistributionStrategy` object. Inside a
`with distribution_strategy.scope()` block, it returns
@@ -526,7 +531,6 @@ class DistributionStrategy(object):
# TODO(josh11b): ClusterSpec/ClusterResolver
# TODO(josh11b): Partitioned computations, state; sharding
# TODO(josh11b): Model parallelism: "towers" with multiple devices; shuffling
- # TODO(josh11b): Tower-local variables
# TODO(josh11b): List of towers with their worker and parameter devices
# (where the parameter devices may overlap in the ps case).
@@ -556,10 +560,52 @@ class DistributionStrategy(object):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
+ def tower_local_var_scope(self, reduce_method):
+ """Inside this scope, new variables will not be mirrored.
+
+ There will still be one component variable per tower, but there is
+ no requirement that they stay in sync. Instead, when saving them
+ or calling `fetch()`, we use the value that results when calling
+ `reduce()` on all the towers' variables.
+
+ Note: tower-local implies not trainable. Instead, it is expected
+ that each tower will directly update (using `assign_add()` or
+ whatever) its local variable instance but only the aggregated
+ value (accessible using `fetch()`) will be exported from the
+ model. When it is acceptable to only aggregate on export, we
+ greatly reduce communication overhead by using tower-local
+ variables.
+
+ Note: All component variables will be initialized to the same
+ value, using the initialization expression from the first tower.
+ The values will match even if the initialization expression uses
+ random numbers.
+
+ Args:
+ reduce_method: String used as a `method_string` to `reduce()`
+ to get the value to save when checkpointing.
+
+ Returns:
+ A context manager.
+ """
+ def create_tower_local_variable(next_creator, *args, **kwargs):
+ _require_distribution_strategy_scope(self)
+ kwargs["use_resource"] = True
+ kwargs["tower_local_reduce_method"] = reduce_method
+ return next_creator(*args, **kwargs)
+
+ _require_distribution_strategy_scope(self)
+ return variable_scope.variable_creator_scope(create_tower_local_variable)
+
def colocate_vars_with(self, colocate_with_variable):
- """Controls which devices variables will be created on.
+ """Scope that controls which devices variables will be created on.
- Note this may only be used inside `self.scope()`.
+ No operations should be added to the graph inside this scope, it
+ should only be used when creating variables (some implementations
+ work by changing variable creation, others work by using a
+ tf.colocate_with() scope).
+
+ This may only be used inside `self.scope()`.
Example usage:
@@ -979,6 +1025,10 @@ class TowerContext(object):
finally:
_pop_per_thread_mode()
+ def tower_local_var_scope(self, reduce_method):
+ """Alias for distribution_strategy.tower_local_var_scope()."""
+ return self._distribution_strategy.tower_local_var_scope(reduce_method)
+
@property
def is_single_tower(self):
"""Returns whether there is a single tower or multiple."""
@@ -1025,6 +1075,8 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
+ if kwargs.pop("tower_local_reduce_method", None) is not None:
+ kwargs["trainable"] = False
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(
@@ -1032,13 +1084,8 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def colocate_vars_with(self, colocate_with_variable):
"""Does not require `self.scope`."""
- def create_colocated_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- with ops.colocate_with(colocate_with_variable):
- return next_creator(*args, **kwargs)
-
_require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_colocated_variable)
+ return ops.colocate_with(colocate_with_variable)
def distribute_dataset(self, dataset):
# TODO(josh11b): Support for this when executing eagerly is currently only
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 2d4f09a60a..4ce6f6d002 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -350,8 +350,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
elif save_summaries_steps == USE_DEFAULT:
save_summaries_steps = None
- if save_checkpoint_steps == USE_DEFAULT and \
- save_checkpoint_secs == USE_DEFAULT:
+ if (save_checkpoint_steps == USE_DEFAULT and
+ save_checkpoint_secs == USE_DEFAULT):
save_checkpoint_steps = None
save_checkpoint_secs = 600
elif save_checkpoint_secs == USE_DEFAULT:
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index bf79714f96..75665fc284 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -35,11 +35,28 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpointable
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
+def get_filtered_grad_fn(grad_fn):
+ # `distributed_context.join()` requires that its arguments are parallel
+ # across threads, and in particular that `grads_and_vars` has the same
+ # variables in the same order.
+
+ # When computing gradients in eager mode with multiple threads, you
+ # can get extra variables with a gradient of `None`. This happens when
+ # those variables are accessed in another thread during the gradient
+ # computation. To get a consistent set of variables, we filter out
+ # those with `None` gradients.
+ def filtered_grad_fn(x=None):
+ return [(g, v) for g, v in grad_fn(x) if g is not None]
+
+ return filtered_grad_fn
+
+
def _deduplicate_indexed_slices(values, indices):
"""Sums `values` associated with any non-unique `indices`.
@@ -335,6 +352,13 @@ class Optimizer(
# ... }
self._deferred_slot_restorations = {}
+ # TODO(isaprykin): When using a DistributionStrategy, and when an
+ # optimizer is created in each tower, it might be dangerous to
+ # rely on some Optimer methods. When such methods are called on a
+ # per-tower optimizer, an exception needs to be thrown. We do
+ # allow creation per-tower optimizers however, because the
+ # compute_gradients()->apply_gradients() sequence is safe.
+
def get_name(self):
return self._name
@@ -447,14 +471,33 @@ class Optimizer(
if var_list is not None:
tape.watch(var_list)
loss_value = loss()
+
+ # Scale loss if using a "mean" loss reduction and multiple towers.
+ # Have to be careful to call distribute_lib.get_loss_reduction()
+ # *after* loss() is evaluated, so we know what loss reduction it uses.
+ # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+ if distribute_lib.get_loss_reduction() == "mean":
+ num_towers = distribute_lib.get_distribution_strategy().num_towers
+ if num_towers > 1:
+ loss_value *= (1. / num_towers)
+
if var_list is None:
var_list = tape.watched_variables()
grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
+
+ # Non-callable/Tensor loss case
if context.executing_eagerly():
raise RuntimeError(
"`loss` passed to Optimizer.compute_gradients should "
"be a function when eager execution is enabled.")
+
+ # Scale loss if using a "mean" loss reduction and multiple towers.
+ if distribute_lib.get_loss_reduction() == "mean":
+ num_towers = distribute_lib.get_distribution_strategy().num_towers
+ if num_towers > 1:
+ loss *= (1. / num_towers)
+
if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
Optimizer.GATE_GRAPH]:
raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
@@ -510,11 +553,25 @@ class Optimizer(
Raises:
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
+ RuntimeError: If you should use `_distributed_apply()` instead.
"""
# This is a default implementation of apply_gradients() that can be shared
# by most optimizers. It relies on the subclass implementing the following
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
+ # Handle DistributionStrategy case.
+ if distribute_lib.get_cross_tower_context():
+ raise RuntimeError("Use `_distributed_apply()` instead of "
+ "`apply_gradients()` in a cross-tower context.")
+ # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
+ # always calling _distributed_apply(), using the default distribution
+ # as needed.
+ if distribute_lib.has_distribution_strategy():
+ grads_and_vars = get_filtered_grad_fn(lambda _: grads_and_vars)()
+ return distribute_lib.get_tower_context().merge_call(
+ self._distributed_apply, grads_and_vars, global_step, name)
+
+ # No DistributionStrategy case.
grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
if not grads_and_vars:
raise ValueError("No variables provided.")
@@ -582,6 +639,95 @@ class Optimizer(
return apply_updates
+ def _distributed_apply(self,
+ distribution,
+ grads_and_vars,
+ global_step=None,
+ name=None):
+ """A version of `apply_gradients` for cross-tower context.
+
+ This is a version of `apply_gradients()` for when you are using a
+ `DistributionStrategy` and are in a cross-tower context. If in a
+ tower context, use `apply_gradients()` as normal.
+
+ Args:
+ distribution: A `DistributionStrategy` object.
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`, and then aggregated across towers.
+ global_step: Optional (mirrored) `Variable` to increment by one
+ after the variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients across all
+ towers. If `global_step` was not None, that operation also
+ increments `global_step`.
+ """
+ reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ var_list = [v for _, v in grads_and_vars]
+ grads_and_vars = zip(reduced_grads, var_list)
+ # Note that this is called in a cross-tower context.
+ self._create_slots(var_list)
+
+ def update(v, g):
+ """Apply gradients to a replica variable."""
+ assert v is not None
+
+ try:
+ # Convert the grad to Tensor or IndexedSlices if necessary.
+ g = ops.convert_to_tensor_or_indexed_slices(g)
+ except TypeError:
+ raise TypeError("Gradient must be convertible to a Tensor"
+ " or IndexedSlices, or None: %s" % g)
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ p = _get_processor(v)
+
+ scope_name = "" if context.executing_eagerly() else v.op.name
+ # device_policy is set because non-mirrored tensors will be read in
+ # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
+ # is an example.
+ with ops.name_scope(
+ "update_" + scope_name), context.context().device_policy(
+ context.DEVICE_PLACEMENT_SILENT):
+ return p.update_op(self, g)
+
+ with ops.name_scope(name, self._name) as name:
+ self._prepare()
+
+ update_ops = [
+ op
+ for grad, var in grads_and_vars
+ for op in distribution.unwrap(distribution.update(var, update, grad))
+ ]
+
+ def finish(self, update_ops):
+ return self._finish(update_ops, "update")
+
+ non_slot_devices = distribution.non_slot_devices(var_list)
+ # Device policy is needed because hyperparameter tensors (such as
+ # AdamOptimizer's beta1_t) need to be copied across devices in Eager.
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, self, update_ops)
+ if global_step is None:
+ apply_updates = distribution.group(finish_updates, name=name)
+ else:
+ with ops.control_dependencies(distribution.unwrap(finish_updates)):
+ apply_updates = distribution.group(distribution.update(
+ global_step, state_ops.assign_add, 1, name=name))
+
+ if not context.executing_eagerly():
+ if isinstance(apply_updates, ops.Tensor):
+ apply_updates = apply_updates.op
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ if apply_updates not in train_op:
+ train_op.append(apply_updates)
+
+ return apply_updates
+
def get_slot(self, var, name):
"""Return a slot named `name` created for `var` by the Optimizer.
@@ -599,9 +745,25 @@ class Optimizer(
Returns:
The `Variable` for the slot if it was created, `None` otherwise.
"""
+ # pylint: disable=protected-access
named_slots = self._slots.get(name, None)
if not named_slots:
return None
+
+ if hasattr(var, "_mirrored_container"):
+ # NOTE: If this isn't patched, then there is no `handle` in
+ # `_resource_apply_dense`.
+ mirrored_container = var._mirrored_container()
+ assert mirrored_container is not None
+ if context.executing_eagerly():
+ key = mirrored_container._unique_id
+ else:
+ key = (mirrored_container.graph, mirrored_container._shared_name)
+ # pylint: enable=protected-access
+ mirrored_slot = named_slots.get(key, None)
+ if mirrored_slot is None: return None
+ return mirrored_slot.get(device=var.device)
+
return named_slots.get(_var_key(var), None)
def get_slot_names(self):
@@ -645,6 +807,7 @@ class Optimizer(
def _create_non_slot_variable(self, initial_value, name, colocate_with):
"""Add an extra variable, not associated with a slot."""
+ # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
eager = context.executing_eagerly()
graph = None if eager else colocate_with.graph
@@ -652,7 +815,8 @@ class Optimizer(
v = self._non_slot_dict.get(key, None)
if v is None:
self._maybe_initialize_checkpointable()
- with ops.colocate_with(colocate_with):
+ distribution_strategy = distribute_lib.get_distribution_strategy()
+ with distribution_strategy.colocate_vars_with(colocate_with):
if eager:
restored_initial_value = self._preload_simple_restoration(
name=name, shape=None)
@@ -694,7 +858,13 @@ class Optimizer(
return self._get_non_slot_variable(name, graph=graph)
def _get_non_slot_variable(self, name, graph=None):
- return self._non_slot_dict.get((name, graph), None)
+ non_slot = self._non_slot_dict.get((name, graph), None)
+ if hasattr(non_slot, "_mirrored_container"):
+ # This is a mirrored non-slot. In order to enable code like `_finish`
+ # to assign to a non-slot, return the current context replica.
+ return non_slot.get()
+ else:
+ return non_slot
def _non_slot_variables(self):
"""Additional variables created by the `Optimizer`.
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index ba0d038475..cec581d997 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1924,12 +1924,22 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
else:
meta_graph_def = meta_graph_or_file
- meta_graph.import_scoped_meta_graph(meta_graph_def,
- clear_devices=clear_devices,
- import_scope=import_scope,
- **kwargs)
+ imported_vars = meta_graph.import_scoped_meta_graph(
+ meta_graph_def,
+ clear_devices=clear_devices,
+ import_scope=import_scope,
+ **kwargs)
+
if meta_graph_def.HasField("saver_def"):
- return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)
+ # Infer the scope that is prepended by `import_scoped_meta_graph`.
+ scope = import_scope
+ var_names = list(imported_vars.keys())
+ if var_names:
+ sample_key = var_names[0]
+ sample_var = imported_vars[sample_key]
+ scope = sample_var.name[:-len(sample_key)]
+
+ return Saver(saver_def=meta_graph_def.saver_def, name=scope)
else:
if variables._all_saveable_objects(): # pylint: disable=protected-access
# Return the default saver instance for all graph variables.
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 7de778f298..d1c24b3930 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -2341,6 +2341,38 @@ class MetaGraphTest(test.TestCase):
10, size=[1, 10])
})
+ def testImportIntoImplicitNamescope(self):
+ # Test that we can import a meta graph into an implicit namescope.
+ test_dir = self._get_test_dir("import_into_namescope")
+ filename = os.path.join(test_dir, "ckpt")
+ image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
+ label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
+ with session.Session() as sess:
+ weights = variables.Variable(
+ random_ops.random_uniform([784, 10]), name="weights")
+ bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
+ nn_ops.softmax(logit, name="prediction")
+ cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
+ logits=logit, name="cost")
+ adam.AdamOptimizer().minimize(cost, name="optimize")
+ saver = saver_module.Saver()
+ sess.run(variables.global_variables_initializer())
+ saver.save(sess, filename)
+
+ graph = ops_lib.Graph()
+ with session.Session(graph=graph) as sess:
+ with ops_lib.name_scope("new_model"):
+ new_saver = saver_module.import_meta_graph(
+ filename + ".meta", graph=graph)
+
+ new_saver.restore(sess, filename)
+ sess.run(["new_model/optimize"], {
+ "new_model/image:0": np.random.random([1, 784]),
+ "new_model/label:0": np.random.randint(
+ 10, size=[1, 10])
+ })
+
def testClearDevicesOnImport(self):
# Test that we import a graph without its devices and run successfully.
with ops_lib.Graph().as_default():
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 9ac52dd071..258a6f045d 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -40,12 +40,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
@@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- with ops.colocate_with(primary):
+ distribution_strategy = distribute_lib.get_distribution_strategy()
+ with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, val, "", validate_shape, None, None)
else:
return _create_slot_var(primary, val, "", validate_shape, None, None)
@@ -148,7 +149,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- with ops.colocate_with(primary):
+ distribution_strategy = distribute_lib.get_distribution_strategy()
+ with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)
else:
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 23c2c48f4b..5622431bc9 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -60,15 +60,7 @@ def _is_namedtuple(instance, strict=False):
Returns:
True if `instance` is a `namedtuple`.
"""
- # Attemp to limit the test to plain namedtuple (not stuff inheriting from it).
- if not isinstance(instance, tuple):
- return False
- if strict and instance.__class__.__base__ != tuple:
- return False
- return (
- hasattr(instance, "_fields") and
- isinstance(instance._fields, _collections.Sequence) and
- all(isinstance(f, _six.string_types) for f in instance._fields))
+ return _pywrap_tensorflow.IsNamedtuple(instance, strict)
def _sequence_like(instance, args):
@@ -157,76 +149,7 @@ def flatten(nest):
def _same_namedtuples(nest1, nest2):
"""Returns True if the two namedtuples have the same name and fields."""
- if nest1._fields != nest2._fields:
- return False
- if nest1.__class__.__name__ != nest2.__class__.__name__:
- return False
- return True
-
-
-def _recursive_assert_same_structure(nest1, nest2, check_types):
- """Helper function for `assert_same_structure`.
-
- See `assert_same_structure` for further information about namedtuples.
-
- Args:
- nest1: An arbitrarily nested structure.
- nest2: An arbitrarily nested structure.
- check_types: If `True` (default) types of sequences are checked as
- well, including the keys of dictionaries. If set to `False`, for example
- a list and a tuple of objects will look the same if they have the same
- size. Note that namedtuples with identical name and fields are always
- considered to have the same shallow structure.
-
- Returns:
- True if `nest1` and `nest2` have the same structure.
-
- Raises:
- ValueError: If the two structure don't have the same nested structre.
- TypeError: If the two structure don't have the same sequence type.
- ValueError: If the two dictionaries don't have the same set of keys.
- """
- is_sequence_nest1 = is_sequence(nest1)
- if is_sequence_nest1 != is_sequence(nest2):
- raise ValueError(
- "The two structures don't have the same nested structure.\n\n"
- "First structure: %s\n\nSecond structure: %s." % (nest1, nest2))
-
- if not is_sequence_nest1:
- return # finished checking
-
- if check_types:
- type_nest1 = type(nest1)
- type_nest2 = type(nest2)
-
- # Duck-typing means that nest should be fine with two different namedtuples
- # with identical name and fields.
- if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True):
- if not _same_namedtuples(nest1, nest2):
- raise TypeError(
- "The two namedtuples don't have the same sequence type. First "
- "structure has type %s, while second structure has type %s."
- % (type_nest1, type_nest2))
- else:
- if type_nest1 != type_nest2:
- raise TypeError(
- "The two structures don't have the same sequence type. First "
- "structure has type %s, while second structure has type %s."
- % (type_nest1, type_nest2))
-
- if isinstance(nest1, dict):
- keys1 = set(_six.iterkeys(nest1))
- keys2 = set(_six.iterkeys(nest2))
- if keys1 != keys2:
- raise ValueError(
- "The two dictionaries don't have the same set of keys. First "
- "structure has keys {}, while second structure has keys {}."
- .format(keys1, keys2))
-
- nest1_as_sequence = [n for n in _yield_value(nest1)]
- nest2_as_sequence = [n for n in _yield_value(nest2)]
- for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence):
- _recursive_assert_same_structure(n1, n2, check_types)
+ return _pywrap_tensorflow.SameNamedtuples(nest1, nest2)
def assert_same_structure(nest1, nest2, check_types=True):
@@ -257,14 +180,7 @@ def assert_same_structure(nest1, nest2, check_types=True):
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
- len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
- len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
- if len_nest1 != len_nest2:
- raise ValueError("The two structures don't have the same number of "
- "elements.\n\nFirst structure (%i elements): %s\n\n"
- "Second structure (%i elements): %s"
- % (len_nest1, nest1, len_nest2, nest2))
- _recursive_assert_same_structure(nest1, nest2, check_types)
+ _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
def flatten_dict_items(dictionary):
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 4439d6241e..2f12b25354 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -19,11 +19,14 @@ from __future__ import division
from __future__ import print_function
import collections
+import time
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -32,6 +35,9 @@ from tensorflow.python.util import nest
class NestTest(test.TestCase):
+ PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
@@ -39,8 +45,8 @@ class NestTest(test.TestCase):
self.assertEqual(
nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
("d", "e", ("f", "g"), "h")))
- point = collections.namedtuple("Point", ["x", "y"])
- structure = (point(x=4, y=2), ((point(x=1, y=0),),))
+ structure = (NestTest.PointXY(x=4, y=2),
+ ((NestTest.PointXY(x=1, y=0),),))
flat = [4, 2, 1, 0]
self.assertEqual(nest.flatten(structure), flat)
restructured_from_flat = nest.pack_sequence_as(structure, flat)
@@ -66,6 +72,7 @@ class NestTest(test.TestCase):
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenDictOrder(self):
"""`flatten` orders dicts by key, including OrderedDicts."""
ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
@@ -87,12 +94,14 @@ class NestTest(test.TestCase):
ordered_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
+ Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack_withDicts(self):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
- named_tuple = collections.namedtuple("A", ("b", "c"))
mess = [
"z",
- named_tuple(3, 4),
+ NestTest.Abc(3, 4),
{
"c": [
1,
@@ -111,7 +120,7 @@ class NestTest(test.TestCase):
structure_of_mess = [
14,
- named_tuple("a", True),
+ NestTest.Abc("a", True),
{
"c": [
0,
@@ -157,6 +166,7 @@ class NestTest(test.TestCase):
nest.pack_sequence_as(["hello", "world"],
["and", "goodbye", "again"])
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testIsSequence(self):
self.assertFalse(nest.is_sequence("1234"))
self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
@@ -186,6 +196,23 @@ class NestTest(test.TestCase):
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
nest.flatten_dict_items(another_bad_dictionary)
+ # pylint does not correctly recognize these as class names and
+ # suggests to use variable style under_score naming.
+ # pylint: disable=invalid-name
+ Named0ab = collections.namedtuple("named_0", ("a", "b"))
+ Named1ab = collections.namedtuple("named_1", ("a", "b"))
+ SameNameab = collections.namedtuple("same_name", ("a", "b"))
+ SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
+ SameNamexy = collections.namedtuple("same_name", ("x", "y"))
+ SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
+ SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
+ NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
+ # pylint: enable=invalid-name
+
+ class SameNamedType1(SameNameab):
+ pass
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testAssertSameStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
@@ -198,23 +225,32 @@ class NestTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
- ("don't have the same number of elements\\.\n\n"
- "First structure \\(6 elements\\):.*?"
- "\n\nSecond structure \\(2 elements\\):")):
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ "More specifically: Substructure "
+ r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
+ 'substructure "type=str str=spam" is not')):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(
ValueError,
- ("don't have the same number of elements\\.\n\n"
- "First structure \\(2 elements\\):.*?"
- "\n\nSecond structure \\(1 elements\\):")):
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
+ "is not")):
nest.assert_same_structure([0, 1], np.array([0, 1]))
with self.assertRaisesRegexp(
ValueError,
- ("don't have the same number of elements\\.\n\n"
- "First structure \\(1 elements\\):.*"
- "\n\nSecond structure \\(2 elements\\):")):
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ 'is a sequence, while substructure "type=int str=0" '
+ "is not")):
nest.assert_same_structure(0, [0, 1])
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
@@ -225,21 +261,21 @@ class NestTest(test.TestCase):
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure(structure1, structure_different_nesting)
- named_type_0 = collections.namedtuple("named_0", ("a", "b"))
- named_type_1 = collections.namedtuple("named_1", ("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
- named_type_0("a", "b"))
+ NestTest.Named0ab("a", "b"))
- nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))
+ nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ NestTest.Named0ab("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure,
- named_type_0(3, 4), named_type_1(3, 4))
+ NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
- nest.assert_same_structure(named_type_0(3, 4), named_type_0([3], 4))
+ nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ NestTest.Named0ab([3], 4))
with self.assertRaisesRegexp(
ValueError,
@@ -258,36 +294,33 @@ class NestTest(test.TestCase):
"don't have the same set of keys"):
nest.assert_same_structure({"a": 1}, {"b": 1})
- same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
- same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
- nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3))
+ nest.assert_same_structure(NestTest.SameNameab(0, 1),
+ NestTest.SameNameab2(2, 3))
# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
- same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y"))
- same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y"))
nest.assert_same_structure(
- same_name_type_0(same_name_type_2(0, 1), 2),
- same_name_type_1(same_name_type_3(2, 3), 4))
+ NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
+ NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
expected_message = "The two structures don't have the same.*"
with self.assertRaisesRegexp(ValueError, expected_message):
- nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)),
- same_name_type_1(same_name_type_0(0, 1), 2))
+ nest.assert_same_structure(
+ NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
+ NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
- same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure,
- same_name_type_0(0, 1), same_name_type_1(2, 3))
+ NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
- same_name_type_1 = collections.namedtuple("same_name", ("x", "y"))
self.assertRaises(TypeError, nest.assert_same_structure,
- same_name_type_0(0, 1), same_name_type_1(2, 3))
+ NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
- class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))):
- pass
self.assertRaises(TypeError, nest.assert_same_structure,
- same_name_type_0(0, 1), SameNamedType1(2, 3))
+ NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
+ EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = (((7, 8), 9), 10, (11, 12))
@@ -310,9 +343,8 @@ class NestTest(test.TestCase):
self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
- empty_nt = collections.namedtuple("empty_nt", "")
- self.assertEqual(empty_nt(), nest.map_structure(lambda x: x + 1,
- empty_nt()))
+ self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
+ NestTest.EmptyNT()))
# This is checking actual equality of types, empty list != empty tuple
self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
@@ -352,10 +384,12 @@ class NestTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
+ ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructureWithStrings(self):
- ab_tuple = collections.namedtuple("ab_tuple", "a, b")
- inp_a = ab_tuple(a="foo", b=("bar", "baz"))
- inp_b = ab_tuple(a=2, b=(1, 3))
+ inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
+ inp_b = NestTest.ABTuple(a=2, b=(1, 3))
out = nest.map_structure(lambda string, repeats: string * repeats,
inp_a,
inp_b)
@@ -363,8 +397,8 @@ class NestTest(test.TestCase):
self.assertEqual("bar", out.b[0])
self.assertEqual("bazbazbaz", out.b[1])
- nt = ab_tuple(a=("something", "something_else"),
- b="yet another thing")
+ nt = NestTest.ABTuple(a=("something", "something_else"),
+ b="yet another thing")
rev_nt = nest.map_structure(lambda x: x[::-1], nt)
# Check the output is the correct structure, and all strings are reversed.
nest.assert_same_structure(nt, rev_nt)
@@ -431,10 +465,8 @@ class NestTest(test.TestCase):
# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
- same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
- same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
- inp_shallow = same_name_type_0(1, 2)
- inp_deep = same_name_type_1(1, [1, 2, 3])
+ inp_shallow = NestTest.SameNameab(1, 2)
+ inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
@@ -466,7 +498,7 @@ class NestTest(test.TestCase):
[1, {"c": 2}, 3, (4, 5)])
# Namedtuples.
- ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+ ab_tuple = NestTest.ABTuple
input_tree = ab_tuple(a=[0, 1], b=2)
shallow_tree = ab_tuple(a=0, b=1)
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
@@ -681,5 +713,31 @@ class NestTest(test.TestCase):
list(nest.flatten_with_joined_string_paths(inputs)), expected)
+class NestBenchmark(test.Benchmark):
+
+ def run_and_report(self, s1, s2, name):
+ burn_iter, test_iter = 100, 30000
+
+ for _ in xrange(burn_iter):
+ nest.assert_same_structure(s1, s2)
+
+ t0 = time.time()
+ for _ in xrange(test_iter):
+ nest.assert_same_structure(s1, s2)
+ t1 = time.time()
+
+ self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
+ name=name)
+
+ def benchmark_assert_structure(self):
+ s1 = (((1, 2), 3), 4, (5, 6))
+ s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+ self.run_and_report(s1, s2, "assert_same_structure_6_elem")
+
+ s1 = (((1, 2), 3), 4, (5, 6)) * 10
+ s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
+ self.run_and_report(s1, s2, "assert_same_structure_60_elem")
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index a41fa7df25..70aee4a3f6 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
namespace swig {
@@ -27,6 +28,113 @@ PyObject* CollectionsSequenceType = nullptr;
bool WarnedThatSetIsNotSequence = false;
+bool IsString(PyObject* o) {
+ return PyBytes_Check(o) ||
+#if PY_MAJOR_VERSION < 3
+ PyString_Check(o) ||
+#endif
+ PyUnicode_Check(o);
+}
+
+// Equivalent to Python's 'o.__class__.__name__'
+// Note that '__class__' attribute is set only in new-style classes.
+// A lot of tensorflow code uses __class__ without checks, so it seems like
+// we only support new-style classes.
+StringPiece GetClassName(PyObject* o) {
+ // __class__ is equivalent to type() for new style classes.
+ // type() is equivalent to PyObject_Type()
+ // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
+ // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
+ // we don't need here.
+ PyTypeObject* type = o->ob_type;
+
+ // __name__ is the value of `tp_name` after the last '.'
+ // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
+ StringPiece name(type->tp_name);
+ size_t pos = name.rfind('.');
+ if (pos != StringPiece::npos) {
+ name.remove_prefix(pos + 1);
+ }
+ return name;
+}
+
+string PyObjectToString(PyObject* o) {
+ if (o == nullptr) {
+ return "<null object>";
+ }
+ PyObject* str = PyObject_Str(o);
+ if (str) {
+#if PY_MAJOR_VERSION < 3
+ string s(PyString_AS_STRING(str));
+#else
+ string s(PyUnicode_AsUTF8(str));
+#endif
+ Py_DECREF(str);
+ return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
+ } else {
+ return "<failed to execute str() on object>";
+ }
+}
+
+// Implements the same idea as tensorflow.util.nest._yield_value
+// During construction we check if the iterable is a dictionary.
+// If so, we construct a sequence from its sorted keys that will be used
+// for iteration.
+// If not, we construct a sequence directly from the iterable.
+// At each step, we get the next element from the sequence and use it
+// either as a key or return it directly.
+//
+// 'iterable' must not be modified while ValIterator is used.
+class ValIterator {
+ public:
+ explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) {
+ if (PyDict_Check(iterable)) {
+ dict_ = iterable;
+ // PyDict_Keys returns a list, which can be used with
+ // PySequence_Fast_GET_ITEM.
+ seq_ = PyDict_Keys(iterable);
+ // Iterate through dictionaries in a deterministic order by sorting the
+ // keys. Notice this means that we ignore the original order of
+ // `OrderedDict` instances. This is intentional, to avoid potential
+ // bugs caused by mixing ordered and plain dicts (e.g., flattening
+ // a dict but using a corresponding `OrderedDict` to pack it back).
+ PyList_Sort(seq_);
+ } else {
+ seq_ = PySequence_Fast(iterable, "");
+ }
+ size_ = PySequence_Fast_GET_SIZE(seq_);
+ }
+
+ ~ValIterator() { Py_DECREF(seq_); }
+
+ // Return a borrowed reference to the next element from iterable.
+ // Return nullptr when iteration is over.
+ PyObject* next() {
+ PyObject* element = nullptr;
+ if (index_ < size_) {
+ // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
+ // references.
+ element = PySequence_Fast_GET_ITEM(seq_, index_);
+ ++index_;
+ if (dict_ != nullptr) {
+ element = PyDict_GetItem(dict_, element);
+ if (element == nullptr) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Dictionary was modified during iteration over it");
+ return nullptr;
+ }
+ }
+ }
+ return element;
+ }
+
+ private:
+ PyObject* seq_;
+ PyObject* dict_;
+ Py_ssize_t size_;
+ Py_ssize_t index_;
+};
+
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
@@ -38,7 +146,7 @@ int IsSequenceHelper(PyObject* o) {
"so consider avoiding using them.";
WarnedThatSetIsNotSequence = true;
}
- if (CollectionsSequenceType == nullptr) {
+ if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
PyErr_SetString(
PyExc_RuntimeError,
tensorflow::strings::StrCat(
@@ -49,11 +157,7 @@ int IsSequenceHelper(PyObject* o) {
}
int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
if (is_instance == -1) return -1;
- return static_cast<int>(is_instance != 0 && !PyBytes_Check(o) &&
-#if PY_MAJOR_VERSION < 3
- !PyString_Check(o) &&
-#endif
- !PyUnicode_Check(o));
+ return static_cast<int>(is_instance != 0 && !IsString(o));
}
bool FlattenHelper(PyObject* nested, PyObject* list) {
@@ -75,12 +179,16 @@ bool FlattenHelper(PyObject* nested, PyObject* list) {
// while the method is running.
PyObject* key = PyList_GET_ITEM(keys, i);
PyObject* val = PyDict_GetItem(nested, key);
- if (Py_EnterRecursiveCall(" in Flatten")) {
+ if (Py_EnterRecursiveCall(" in flatten")) {
Py_DECREF(keys);
return false;
}
- FlattenHelper(val, list);
+ const bool success = FlattenHelper(val, list);
Py_LeaveRecursiveCall();
+ if (!success) {
+ Py_DECREF(keys);
+ return false;
+ }
}
Py_DECREF(keys);
return true;
@@ -90,13 +198,159 @@ bool FlattenHelper(PyObject* nested, PyObject* list) {
PyObject* item;
PyObject* iterator = PyObject_GetIter(nested);
while ((item = PyIter_Next(iterator)) != nullptr) {
- FlattenHelper(item, list);
+ if (Py_EnterRecursiveCall(" in flatten")) {
+ Py_DECREF(iterator);
+ Py_DECREF(item);
+ return false;
+ }
+ bool success = FlattenHelper(item, list);
+ Py_LeaveRecursiveCall();
+ if (!success) {
+ Py_DECREF(iterator);
+ Py_DECREF(item);
+ return false;
+ }
Py_DECREF(item);
}
Py_DECREF(iterator);
return true;
}
+// Sets error using keys of 'dict1' and 'dict2'.
+// 'dict1' and 'dict2' are assumed to be Python dictionaries.
+void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
+ bool* is_type_error) {
+ PyObject* k1 = PyDict_Keys(dict1);
+ PyObject* k2 = PyDict_Keys(dict2);
+ *is_type_error = false;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two dictionaries don't have the same set of keys. "
+ "First structure has keys ",
+ PyObjectToString(k1), ", while second structure has keys ",
+ PyObjectToString(k2));
+ Py_DECREF(k1);
+ Py_DECREF(k2);
+}
+
+// Returns true iff there were no "internal" errors. In other words,
+// errors that has nothing to do with structure checking.
+// If an "internal" error occured, the appropriate Python error will be
+// set and the caller can propage it directly to the user.
+//
+// Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
+// be empty.
+// Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
+// with appropriate error and sets `is_type_error` to true iff
+// the error to be raised should be TypeError.
+bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
+ string* error_msg, bool* is_type_error) {
+ DCHECK(error_msg);
+ DCHECK(is_type_error);
+ const bool is_seq1 = IsSequence(o1);
+ const bool is_seq2 = IsSequence(o2);
+ if (PyErr_Occurred()) return false;
+ if (is_seq1 != is_seq2) {
+ string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
+ string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
+ *is_type_error = false;
+ *error_msg = tensorflow::strings::StrCat(
+ "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
+ non_seq_str, "\" is not");
+ return true;
+ }
+
+ // Got to scalars, so finished checking. Structures are the same.
+ if (!is_seq1) return true;
+
+ if (check_types) {
+ const PyTypeObject* type1 = o1->ob_type;
+ const PyTypeObject* type2 = o2->ob_type;
+
+ // We treat two different namedtuples with identical name and fields
+ // as having the same type.
+ const PyObject* o1_tuple = IsNamedtuple(o1, true);
+ if (o1_tuple == nullptr) return false;
+ const PyObject* o2_tuple = IsNamedtuple(o2, true);
+ if (o2_tuple == nullptr) {
+ Py_DECREF(o1_tuple);
+ return false;
+ }
+ bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
+ Py_DECREF(o1_tuple);
+ Py_DECREF(o2_tuple);
+
+ if (both_tuples) {
+ const PyObject* same_tuples = SameNamedtuples(o1, o2);
+ if (same_tuples == nullptr) return false;
+ bool not_same_tuples = same_tuples != Py_True;
+ Py_DECREF(same_tuples);
+ if (not_same_tuples) {
+ *is_type_error = true;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two namedtuples don't have the same sequence type. "
+ "First structure ",
+ PyObjectToString(o1), " has type ", type1->tp_name,
+ ", while second structure ", PyObjectToString(o2), " has type ",
+ type2->tp_name);
+ return true;
+ }
+ } else if (type1 != type2) {
+ *is_type_error = true;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two namedtuples don't have the same sequence type. "
+ "First structure ",
+ PyObjectToString(o1), " has type ", type1->tp_name,
+ ", while second structure ", PyObjectToString(o2), " has type ",
+ type2->tp_name);
+ return true;
+ }
+
+ if (PyDict_Check(o1)) {
+ if (PyDict_Size(o1) != PyDict_Size(o2)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ return true;
+ }
+
+ PyObject* key;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next(o1, &pos, &key, nullptr)) {
+ if (PyDict_GetItem(o2, key) == nullptr) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ return true;
+ }
+ }
+ }
+ }
+
+ ValIterator iter1(o1);
+ ValIterator iter2(o2);
+
+ while (true) {
+ PyObject* v1 = iter1.next();
+ PyObject* v2 = iter2.next();
+ if (v1 != nullptr && v2 != nullptr) {
+ if (Py_EnterRecursiveCall(" in assert_same_structure")) {
+ return false;
+ }
+ bool no_internal_errors = AssertSameStructureHelper(
+ v1, v2, check_types, error_msg, is_type_error);
+ Py_LeaveRecursiveCall();
+ if (!no_internal_errors) return false;
+ if (!error_msg->empty()) return true;
+ } else if (v1 == nullptr && v2 == nullptr) {
+ // Done with all recursive calls. Structure matched.
+ return true;
+ } else {
+ *is_type_error = false;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two structures don't have the same number of elements. ",
+ "First structure: ", PyObjectToString(o1),
+ ". Second structure: ", PyObjectToString(o2));
+ return true;
+ }
+ }
+}
+
} // anonymous namespace
void RegisterSequenceClass(PyObject* sequence_class) {
@@ -123,5 +377,107 @@ PyObject* Flatten(PyObject* nested) {
return nullptr;
}
}
+
+PyObject* IsNamedtuple(PyObject* o, bool strict) {
+ // Must be subclass of tuple
+ if (!PyTuple_Check(o)) {
+ Py_RETURN_FALSE;
+ }
+
+ // If strict, o.__class__.__base__ must be tuple
+ if (strict) {
+ PyObject* klass = PyObject_GetAttrString(o, "__class__");
+ if (klass == nullptr) return nullptr;
+ PyObject* base = PyObject_GetAttrString(klass, "__base__");
+ Py_DECREF(klass);
+ if (base == nullptr) return nullptr;
+
+ const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
+ // built-in object types are singletons
+ bool tuple_base = base_type == &PyTuple_Type;
+ Py_DECREF(base);
+ if (!tuple_base) {
+ Py_RETURN_FALSE;
+ }
+ }
+
+ if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please call RegisterSequenceClass before using this module")
+ .c_str());
+ return nullptr;
+ }
+
+ // o must have attribute '_fields' and every element in
+ // '_fields' must be a string.
+ int has_fields = PyObject_HasAttrString(o, "_fields");
+ if (!has_fields) {
+ Py_RETURN_FALSE;
+ }
+
+ Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
+ int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType);
+ if (is_instance == 0) {
+ Py_RETURN_FALSE;
+ } else if (is_instance == -1) {
+ return nullptr;
+ }
+
+ Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
+ const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
+ for (Py_ssize_t i = 0; i < s; ++i) {
+ // PySequence_Fast_GET_ITEM returns borrowed ref
+ PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (!IsString(elem)) {
+ Py_RETURN_FALSE;
+ }
+ }
+
+ Py_RETURN_TRUE;
+}
+
+PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
+ PyObject* f1 = PyObject_GetAttrString(o1, "_fields");
+ PyObject* f2 = PyObject_GetAttrString(o2, "_fields");
+ if (f1 == nullptr || f2 == nullptr) {
+ Py_XDECREF(f1);
+ Py_XDECREF(f2);
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ "Expected namedtuple-like objects (that have _fields attr)");
+ return nullptr;
+ }
+
+ if (PyObject_RichCompareBool(f1, f2, Py_NE)) {
+ Py_RETURN_FALSE;
+ }
+
+ if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
+ string error_msg;
+ bool is_type_error = false;
+ AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error);
+ if (!error_msg.empty()) {
+ PyErr_SetString(
+ is_type_error ? PyExc_TypeError : PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "The two structures don't have the same nested structure.\n\n",
+ "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
+ PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
+ .c_str());
+ return nullptr;
+ }
+ Py_RETURN_NONE;
+}
+
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 2af71dc753..c325baa5f8 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -33,6 +33,57 @@ namespace swig {
// dict.
bool IsSequence(PyObject* o);
+// Implements the same interface as tensorflow.util.nest._is_namedtuple
+// Returns Py_True iff `instance` should be considered a `namedtuple`.
+//
+// Args:
+// instance: An instance of a Python object.
+// strict: If True, `instance` is considered to be a `namedtuple` only if
+// it is a "plain" namedtuple. For instance, a class inheriting
+// from a `namedtuple` will be considered to be a `namedtuple`
+// iff `strict=False`.
+//
+// Returns:
+// True if `instance` is a `namedtuple`.
+PyObject* IsNamedtuple(PyObject* o, bool strict);
+
+// Implements the same interface as tensorflow.util.nest._same_namedtuples
+// Returns Py_True iff the two namedtuples have the same name and fields.
+// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
+// '_fields' attribute).
+PyObject* SameNamedtuples(PyObject* o1, PyObject* o2);
+
+// Asserts that two structures are nested in the same way.
+//
+// Note that namedtuples with identical name and fields are always considered
+// to have the same shallow structure (even with `check_types=True`).
+// For intance, this code will print `True`:
+//
+// ```python
+// def nt(a, b):
+// return collections.namedtuple('foo', 'a b')(a, b)
+// print(assert_same_structure(nt(0, 1), nt(2, 3)))
+// ```
+//
+// Args:
+// nest1: an arbitrarily nested structure.
+// nest2: an arbitrarily nested structure.
+// check_types: if `true`, types of sequences are checked as
+// well, including the keys of dictionaries. If set to `false`, for example
+// a list and a tuple of objects will look the same if they have the same
+// size. Note that namedtuples with identical name and fields are always
+// considered to have the same shallow structure.
+//
+// Raises:
+// ValueError: If the two structures do not have the same number of elements or
+// if the two structures are not nested in the same way.
+// TypeError: If the two structures differ in the type of sequence in any of
+// their substructures. Only possible if `check_types` is `True`.
+//
+// Returns:
+// Py_None on success, nullptr on error.
+PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
+
// Implements the same interface as tensorflow.util.nest.flatten
//
// Returns a flat list from a given nested structure.
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index d69084fc00..b7f201b6fe 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -34,6 +34,15 @@ limitations under the License.
%unignore tensorflow::swig::IsSequence;
%noexception tensorflow::swig::IsSequence;
+%unignore tensorflow::swig::IsNamedtuple;
+%noexception tensorflow::swig::IsNamedtuple;
+
+%unignore tensorflow::swig::SameNamedtuples;
+%noexception tensorflow::swig::SameNamedtuples;
+
+%unignore tensorflow::swig::AssertSameStructure;
+%noexception tensorflow::swig::AssertSameStructure;
+
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index 1865240014..27cdb860fe 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -56,7 +56,10 @@ cc_library(
[
"cuda/*.cc",
],
- exclude = ["cuda/cuda_platform_id.cc"],
+ exclude = [
+ "cuda/*_test.cc",
+ "cuda/cuda_platform_id.cc",
+ ],
),
),
copts = select({
@@ -72,6 +75,7 @@ cc_library(
":stream_executor",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/strings",
"@local_config_cuda//cuda:cuda_headers",
] + if_cuda_is_configured([
"//tensorflow/core:cuda",
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index ab5e6590e0..1aea0485fd 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include <functional>
#include <memory>
+#include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@@ -27,6 +29,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
+#include "tensorflow/stream_executor/cuda/cudnn_version.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
@@ -55,15 +58,6 @@ NarrowT CheckedNarrowing(const WideT& wide) {
return narrow;
}
-// Returns the "Compatibility" version number from the CuDNN version number.
-// This is the number that tries to indicate ABI compatibility.
-//
-// For example, if cudnn_version is 5107, the compatibility version
-// number will be 5100.
-size_t cudnnCompatibilityVersion(size_t cudnn_version) {
- return (cudnn_version / 100) * 100;
-}
-
} // namespace
namespace perftools {
@@ -109,6 +103,22 @@ string ToString(cudnnStatus_t status) {
}
}
+#if CUDNN_VERSION >= 6000
+string ToString(libraryPropertyType type) {
+ switch (type) {
+ case MAJOR_VERSION:
+ return "MAJOR_VERSION";
+ case MINOR_VERSION:
+ return "MINOR_VERSION";
+ case PATCH_LEVEL:
+ return "PATCH_LEVEL";
+ default:
+ return absl::StrCat(
+ "<unknown libraryPropertyType: ", static_cast<int>(type), ">");
+ }
+}
+#endif
+
template <typename T>
cudnnDataType_t GetCudnnDataType();
@@ -360,6 +370,34 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
}
}
+#if CUDNN_VERSION >= 6000
+port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
+ cudnnStatus_t status = cudnnGetProperty(type, value);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ const string error =
+ absl::StrCat("cudnnGetProperty failed for type: ", ToString(type),
+ " with status: ", ToString(status));
+ LOG(ERROR) << error;
+ return port::Status{port::error::INTERNAL, error};
+ }
+ return port::Status::OK();
+}
+#endif
+
+port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
+#if CUDNN_VERSION >= 6000
+ TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version));
+ TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version));
+ TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level));
+#else
+ size_t loaded_version = ::cudnnGetVersion();
+ version->major_version = loaded_version / 1000;
+ version->minor_version = (loaded_version / 100) % 10;
+ version->patch_level = loaded_version % 100;
+#endif
+ return port::Status::OK();
+}
+
} // namespace
CudnnSupport::CudnnSupport(CUDAExecutor* parent)
@@ -376,24 +414,19 @@ port::Status CudnnSupport::Init() {
auto status = wrap::cudnnCreate(
parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_));
if (status == CUDNN_STATUS_SUCCESS) {
- // Check whether loaded version of CuDNN matches what the source
- // was built with.
- size_t loaded_version = ::cudnnGetVersion();
- size_t loaded_compat_version = cudnnCompatibilityVersion(loaded_version);
- size_t compiled_compat_version = cudnnCompatibilityVersion(CUDNN_VERSION);
- bool library_loaded_matches_source =
- (loaded_compat_version == compiled_compat_version);
- if (!library_loaded_matches_source) {
- const string error =
- port::StrCat("Loaded runtime CuDNN library: ", loaded_version,
- " (compatibility version ", loaded_compat_version,
- ") but source was compiled with ", CUDNN_VERSION,
- " (compatibility version ", compiled_compat_version,
- "). If using a binary install, upgrade your CuDNN "
- "library to match. If building from sources, "
- "make sure the library loaded at runtime matches a "
- "compatible version specified during compile "
- "configuration.");
+ CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
+
+ CudnnVersion loaded_version;
+ TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
+ if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
+ const tensorflow::string error = absl::StrCat(
+ "Loaded runtime CuDNN library: ", loaded_version.ToString(),
+ " but source was compiled with: ", source_version.ToString(),
+ ". CuDNN library major and minor version needs to match or have "
+ "higher minor version in case of CuDNN 7.0 or later version. If "
+ "using a binary install, upgrade your CuDNN library. If building "
+ "from sources, make sure the library loaded at runtime is compatible "
+ "with the version specified during compile configuration.");
LOG(ERROR) << error;
return port::Status{port::error::INTERNAL, error};
}
diff --git a/tensorflow/stream_executor/cuda/cudnn_version.cc b/tensorflow/stream_executor/cuda/cudnn_version.cc
new file mode 100644
index 0000000000..5591801aae
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cudnn_version.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/cuda/cudnn_version.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,
+ CudnnVersion loaded_version) {
+ // Major version is neither forward or backward compatible and therefore major
+ // versions needs to match between source and library.
+ //
+ // Minor version is backward-compatible beginning with CuDNN 7 and therefore
+ // minor version of library needs to be same or higher.
+ //
+ // Patch releases are always forward and backward compatible and therefore
+ // need not match.
+ if (loaded_version.major_version != source_version.major_version) {
+ return false;
+ }
+ return ((loaded_version.minor_version == source_version.minor_version) ||
+ (source_version.major_version >= 7 &&
+ loaded_version.minor_version >= source_version.minor_version));
+}
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cudnn_version.h b/tensorflow/stream_executor/cuda/cudnn_version.h
new file mode 100644
index 0000000000..058cc87bfa
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cudnn_version.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDNN_VERSION_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDNN_VERSION_H_
+
+#include <string>
+
+#include "absl/strings/str_join.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+struct CudnnVersion {
+ CudnnVersion() = default;
+
+ CudnnVersion(int major, int minor, int patch)
+ : major_version(major), minor_version(minor), patch_level(patch) {}
+
+ std::string ToString() const {
+ return absl::StrJoin({major_version, minor_version, patch_level}, ".");
+ }
+
+ int major_version;
+ int minor_version;
+ int patch_level;
+};
+
+// Returns true if the given source CuDNN version is compatible with the given
+// loaded version.
+bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,
+ CudnnVersion loaded_version);
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDNN_VERSION_H_
diff --git a/tensorflow/stream_executor/cuda/cudnn_version_test.cc b/tensorflow/stream_executor/cuda/cudnn_version_test.cc
new file mode 100644
index 0000000000..230adafeb1
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cudnn_version_test.cc
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/cuda/cudnn_version.h"
+
+#include "testing/base/public/gunit.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+namespace {
+
+TEST(CuDNNVersion, ToString) {
+ CudnnVersion version(7, 0, 12);
+ EXPECT_EQ(version.ToString(), "7.0.12");
+}
+
+TEST(IsSourceCompatibleWithCudnnLibraryTest, Basic) {
+ // Returns true if both major and minor versions are matching and even if the
+ // patch versions are not matching.
+ EXPECT_TRUE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(7, 0, 12),
+ /*loaded_version=*/CudnnVersion(7, 0, 14)));
+ EXPECT_TRUE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(6, 1, 14),
+ /*loaded_version=*/CudnnVersion(6, 1, 00)));
+
+ // Returns false if major versions are not matching as they are neither
+ // forward or backward compatible.
+ EXPECT_FALSE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(7, 0, 12),
+ /*loaded_version=*/CudnnVersion(6, 1, 14)));
+ EXPECT_FALSE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(8, 1, 15),
+ /*loaded_version=*/CudnnVersion(7, 0, 14)));
+
+ // Returns true if the loaded version is equal or higher because minor version
+ // are backward compatible with CuDNN version 7.
+ EXPECT_TRUE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(7, 0, 14),
+ /*loaded_version=*/CudnnVersion(7, 1, 14)));
+ EXPECT_TRUE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(7, 0, 14),
+ /*loaded_version=*/CudnnVersion(7, 1, 15)));
+ EXPECT_FALSE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(7, 1, 15),
+ /*loaded_version=*/CudnnVersion(7, 0, 14)));
+
+ // Returns false if minor versions are not matching for version 6. Before
+ // version 7, minor versions are also neither forward or backward compatible.
+ EXPECT_FALSE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(6, 0, 14),
+ /*loaded_version=*/CudnnVersion(6, 1, 15)));
+ EXPECT_FALSE(IsSourceCompatibleWithCudnnLibrary(
+ /*source_version=*/CudnnVersion(6, 1, 14),
+ /*loaded_version=*/CudnnVersion(6, 0, 14)));
+}
+
+} // namespace
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc
index 81e531efb3..636199cfa2 100644
--- a/tensorflow/stream_executor/kernel.cc
+++ b/tensorflow/stream_executor/kernel.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/stream_executor/lib/demangle.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
@@ -96,7 +97,7 @@ static const char *kStubPrefix = "__device_stub_";
void KernelBase::set_name(port::StringPiece name) {
name_ = name.ToString();
port::StringPiece stubless_name = name;
- if (name.starts_with(kStubPrefix)) {
+ if (tensorflow::str_util::StartsWith(name, kStubPrefix)) {
stubless_name.remove_prefix(strlen(kStubPrefix));
}
demangled_name_ = port::Demangle(stubless_name.data());
diff --git a/tensorflow/stream_executor/lib/str_util.h b/tensorflow/stream_executor/lib/str_util.h
index 4dd6f3b0cc..5dd3d06aff 100644
--- a/tensorflow/stream_executor/lib/str_util.h
+++ b/tensorflow/stream_executor/lib/str_util.h
@@ -29,7 +29,7 @@ using tensorflow::str_util::Split;
// Returns a copy of the input string 'str' with the given 'suffix'
// removed. If the suffix doesn't match, returns a copy of the original string.
inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix) {
- if (str.ends_with(suffix)) {
+ if (tensorflow::str_util::EndsWith(str, suffix)) {
str.remove_suffix(suffix.size());
}
return str.ToString();
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index d9b0260c9f..6722536358 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -5,18 +5,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
py_binary(
name = "create_python_api",
srcs = ["create_python_api.py"],
diff --git a/tensorflow/tools/api/golden/BUILD b/tensorflow/tools/api/golden/BUILD
index 08436396a6..ebdf42df2c 100644
--- a/tensorflow/tools/api/golden/BUILD
+++ b/tensorflow/tools/api/golden/BUILD
@@ -10,15 +10,3 @@ filegroup(
name = "api_golden",
srcs = glob(["*.pbtxt"]),
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
new file mode 100644
index 0000000000..fd9be8c759
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BoostedTreesClassifier"
+tf_class {
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "config"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_dir"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_fn"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "params"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ }
+ member_method {
+ name: "evaluate"
+ argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "export_savedmodel"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
+ }
+ member_method {
+ name: "get_variable_names"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_variable_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latest_checkpoint"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "predict"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "train"
+ argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
new file mode 100644
index 0000000000..6b305be43f
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BoostedTreesRegressor"
+tf_class {
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "config"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_dir"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "model_fn"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "params"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ }
+ member_method {
+ name: "evaluate"
+ argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "export_savedmodel"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
+ }
+ member_method {
+ name: "get_variable_names"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_variable_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latest_checkpoint"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "predict"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
+ name: "train"
+ argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
index a7a6cc1e49..4946f2c51a 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
@@ -9,6 +9,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "BoostedTreesClassifier"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "BoostedTreesRegressor"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "DNNClassifier"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD
index 2d3b838957..3f4fb91042 100644
--- a/tensorflow/tools/api/lib/BUILD
+++ b/tensorflow/tools/api/lib/BUILD
@@ -26,15 +26,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 15bf1abb5f..0dc154b6d2 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -42,15 +42,3 @@ tf_cc_binary(
"//tensorflow/core:op_gen_lib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/benchmark/BUILD b/tensorflow/tools/benchmark/BUILD
index 6ed2594e6a..566a172ea7 100644
--- a/tensorflow/tools/benchmark/BUILD
+++ b/tensorflow/tools/benchmark/BUILD
@@ -90,12 +90,3 @@ tf_cc_binary(
visibility = ["//visibility:public"],
deps = [":benchmark_model_lib"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = ["**/OWNERS"],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/build_info/BUILD b/tensorflow/tools/build_info/BUILD
index cdc47076ce..7307417805 100644
--- a/tensorflow/tools/build_info/BUILD
+++ b/tensorflow/tools/build_info/BUILD
@@ -9,18 +9,3 @@ exports_files(
"gen_build_info.py",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
index 338066131b..c7cc16e669 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
@@ -33,6 +33,7 @@ yes "" | $PYTHON_BIN_PATH configure.py
which bazel
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac,-no_mac \
--test_timeout 300,450,1200,3600 --config=opt \
+ --announce_rc \
--test_size_filters=small,medium \
--jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh
index 920a261ae3..7e0e81a1eb 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_py3_cc_core.sh
@@ -31,6 +31,7 @@ export PYTHON_BIN_PATH=$(which python3)
yes "" | $PYTHON_BIN_PATH configure.py
which bazel
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac,-no_mac \
+ --announce_rc \
--test_timeout 300,450,1200,3600 \
--test_size_filters=small,medium \
--jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
index 94276c6c5c..7dfee8f371 100644
--- a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
+++ b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh
@@ -41,7 +41,7 @@ run_configure_for_gpu_build
# build_libtensorflow_tarball in ../builds/libtensorflow.sh
# cannot be used on Windows since it relies on pkg_tar rules.
# So we do something special here
-bazel build -c opt --copt=/arch:AVX \
+bazel build -c opt --copt=/arch:AVX --announce_rc \
tensorflow:libtensorflow.so \
tensorflow/tools/lib_package:clicenses_generate \
tensorflow/java:libtensorflow_jni.so \
diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD
index 316e5469e7..b9032c046e 100644
--- a/tensorflow/tools/common/BUILD
+++ b/tensorflow/tools/common/BUILD
@@ -44,14 +44,3 @@ py_test(
"//tensorflow/python:platform_test",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index 4f90c4d940..b7bfb29aae 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -68,18 +68,3 @@ exports_files(
"testdata/test_file_v0_11.py",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD
index 865af8dd7b..003a19a9ab 100644
--- a/tensorflow/tools/dist_test/server/BUILD
+++ b/tensorflow/tools/dist_test/server/BUILD
@@ -37,15 +37,3 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/docker/BUILD b/tensorflow/tools/docker/BUILD
index 7d5ae0a94d..849ba49f71 100644
--- a/tensorflow/tools/docker/BUILD
+++ b/tensorflow/tools/docker/BUILD
@@ -13,15 +13,3 @@ py_binary(
srcs_version = "PY2AND3",
deps = ["//tensorflow:tensorflow_py"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/docker/notebooks/BUILD b/tensorflow/tools/docker/notebooks/BUILD
index 89f473df4b..e9f26899c9 100644
--- a/tensorflow/tools/docker/notebooks/BUILD
+++ b/tensorflow/tools/docker/notebooks/BUILD
@@ -3,15 +3,3 @@ package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 8f10bc9e0c..d370fbd246 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -142,14 +142,3 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/tools/git/BUILD b/tensorflow/tools/git/BUILD
index 942ceab85f..daa17fbd50 100644
--- a/tensorflow/tools/git/BUILD
+++ b/tensorflow/tools/git/BUILD
@@ -9,18 +9,3 @@ licenses(["notice"]) # Apache 2.0
exports_files(
["gen_git_source.py"],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 6e21aa2846..1ad1895269 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -313,14 +313,3 @@ tf_py_test(
],
main = "python/transform_graph_test.py",
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/tools/mlpbtxt/BUILD b/tensorflow/tools/mlpbtxt/BUILD
index f9f48c6500..89c683c8c4 100644
--- a/tensorflow/tools/mlpbtxt/BUILD
+++ b/tensorflow/tools/mlpbtxt/BUILD
@@ -32,15 +32,3 @@ tf_cc_binary(
"//tensorflow/core:op_gen_lib",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 95cdf0bf3c..62fec2c402 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -113,6 +113,7 @@ filegroup(
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
"@grpc//third_party/nanopb:LICENSE.txt",
+ "@grpc//third_party/address_sorting:LICENSE",
"@nasm//:LICENSE",
"@nsync//:LICENSE",
"@pcre//:LICENCE",
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index 39c4aac1e8..ef7bfdd3c9 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -96,18 +96,3 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
-
-# -----------------------------------------------------------------------------
-# Google-internal targets. These must be at the end for syncrepo.
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/quantization/BUILD b/tensorflow/tools/quantization/BUILD
index e99ad06a06..17443a8617 100644
--- a/tensorflow/tools/quantization/BUILD
+++ b/tensorflow/tools/quantization/BUILD
@@ -76,15 +76,3 @@ py_binary(
"//tensorflow/python:platform",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD
index 159a8c1cfb..4b2026b947 100644
--- a/tensorflow/tools/test/BUILD
+++ b/tensorflow/tools/test/BUILD
@@ -92,15 +92,3 @@ tf_py_logged_benchmark(
name = "rnn_op_benchmark",
target = "//tensorflow/python/kernel_tests:rnn_test",
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/user_ops/BUILD b/tensorflow/user_ops/BUILD
index e8198efe2e..71443cc41e 100644
--- a/tensorflow/user_ops/BUILD
+++ b/tensorflow/user_ops/BUILD
@@ -50,15 +50,3 @@ tf_py_test(
additional_deps = ["//tensorflow:tensorflow_py"],
data = [":invalid_op.so"],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 6ac98de43a..ac6380dd3e 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -99,11 +99,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "eigen_archive",
urls = [
- "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/2355b229ea4c.tar.gz",
- "https://bitbucket.org/eigen/eigen/get/2355b229ea4c.tar.gz",
+ "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/6913f0cf7d06.tar.gz",
+ "https://bitbucket.org/eigen/eigen/get/6913f0cf7d06.tar.gz",
],
- sha256 = "0cadb31a35b514bf2dfd6b5d38205da94ef326ec6908fc3fd7c269948467214f",
- strip_prefix = "eigen-eigen-2355b229ea4c",
+ sha256 = "791b836cacd03e20bae5bdd25f1c4a5505a0a9975ba94a61eb4e2631fbd1d53a",
+ strip_prefix = "eigen-eigen-6913f0cf7d06",
build_file = clean_dep("//third_party:eigen.BUILD"),
patch_file = clean_dep("//third_party:eigen_fix_cuda_compilation.patch")
)
@@ -430,13 +430,14 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "grpc",
urls = [
- "https://mirror.bazel.build/github.com/grpc/grpc/archive/575bda39755b98d1f7099406bb57a6e3b2074874.tar.gz",
- "https://github.com/grpc/grpc/archive/575bda39755b98d1f7099406bb57a6e3b2074874.tar.gz",
+ "https://mirror.bazel.build/github.com/grpc/grpc/archive/bd6bdf93279a39a8cd92978fd7c9d14eccd98fc2.tar.gz",
+ "https://github.com/grpc/grpc/archive/bd6bdf93279a39a8cd92978fd7c9d14eccd98fc2.tar.gz",
],
- sha256 = "f08a5c8e265191b39cc74915b1bc1fd380d86cd0176c92b7cce30b6ac50514ad",
- strip_prefix = "grpc-575bda39755b98d1f7099406bb57a6e3b2074874",
+ sha256 = "0a05bd355e4571b01d813dddffa38e57e689ac41b264dc9b1bd6ec66463ef5d6",
+ strip_prefix = "grpc-bd6bdf93279a39a8cd92978fd7c9d14eccd98fc2",
)
+
tf_http_archive(
name = "linenoise",
sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index f8fb6ecb0c..8a2b24aa4e 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -266,8 +266,7 @@ class SPINN(tf.keras.Model):
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(
- lefts, right_in=rights, tracking=trackings)
+ reducer_output = self.reducer(lefts, rights, trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
@@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model):
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(
- premise_embed, transitions=premise_transition, training=training)
- hypothesis = self.encoder(
- hypothesis_embed, transitions=hypothesis_transition, training=training)
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropuout on the logits.
@@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
- # TODO(allenl): Allow passing Layer inputs as position arguments.
logits = self._model(premise,
- premise_transition=premise_transition,
- hypothesis=hypothesis,
- hypothesis_transition=hypothesis_transition,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
@@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -639,11 +635,8 @@ def train_or_infer_spinn(embed,
hypo, hypo_trans = inference_sentence_pair[1]
hypo_trans = inference_sentence_pair[1][1]
inference_logits = model(
- tf.constant(prem),
- premise_transition=tf.constant(prem_trans),
- hypothesis=tf.constant(hypo),
- hypothesis_transition=tf.constant(hypo_trans),
- training=False)
+ tf.constant(prem), tf.constant(prem_trans),
+ tf.constant(hypo), tf.constant(hypo_trans), training=False)
inference_logits = inference_logits[0][1:]
max_index = tf.argmax(inference_logits)
print("\nInference logits:")
diff --git a/third_party/hadoop/BUILD b/third_party/hadoop/BUILD
index 9e98154400..c3c5e428be 100644
--- a/third_party/hadoop/BUILD
+++ b/third_party/hadoop/BUILD
@@ -4,18 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE.txt"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cc_library(
name = "hdfs",
hdrs = ["hdfs.h"],
diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD
index 3262562bcc..c2adf578c7 100644
--- a/third_party/mkl/BUILD
+++ b/third_party/mkl/BUILD
@@ -24,7 +24,6 @@ load(
filegroup(
name = "LICENSE",
- visibility = ["//visibility:public"],
srcs = ["MKL_LICENSE"] + select({
"@org_tensorflow//tensorflow:linux_x86_64": [
"@mkl_linux//:LICENSE",
@@ -34,13 +33,13 @@ filegroup(
],
"@org_tensorflow//tensorflow:windows": [
"@mkl_windows//:LICENSE",
- ]
- })
+ ],
+ }),
+ visibility = ["//visibility:public"],
)
cc_library(
name = "intel_binary_blob",
-
visibility = ["//visibility:public"],
deps = select({
"@org_tensorflow//tensorflow:linux_x86_64": [
@@ -54,6 +53,6 @@ cc_library(
"@org_tensorflow//tensorflow:windows": [
"@mkl_windows//:mkl_headers",
"@mkl_windows//:mkl_libs_windows",
- ]
- })
+ ],
+ }),
)
diff --git a/third_party/mkl/mkl.BUILD b/third_party/mkl/mkl.BUILD
index 892221ec00..c3a71e4ff9 100644
--- a/third_party/mkl/mkl.BUILD
+++ b/third_party/mkl/mkl.BUILD
@@ -21,7 +21,7 @@ cc_library(
name = "mkl_libs_linux",
srcs = [
"lib/libiomp5.so",
- "lib/libmklml_intel.so"
+ "lib/libmklml_intel.so",
],
visibility = ["//visibility:public"],
)
@@ -30,7 +30,7 @@ cc_library(
name = "mkl_libs_darwin",
srcs = [
"lib/libiomp5.dylib",
- "lib/libmklml.dylib"
+ "lib/libmklml.dylib",
],
visibility = ["//visibility:public"],
)
@@ -39,7 +39,7 @@ cc_library(
name = "mkl_libs_windows",
srcs = [
"lib/libiomp5md.lib",
- "lib/mklml.lib"
+ "lib/mklml.lib",
],
visibility = ["//visibility:public"],
)
diff --git a/third_party/mpi/BUILD b/third_party/mpi/BUILD
index ff3f437e92..1d6ac2fceb 100644
--- a/third_party/mpi/BUILD
+++ b/third_party/mpi/BUILD
@@ -1,17 +1,5 @@
licenses(["restricted"])
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
load("//third_party/mpi:mpi.bzl", "mpi_hdr")
load("//third_party/mpi:mpi.bzl", "if_mpi")
diff --git a/third_party/sycl/BUILD b/third_party/sycl/BUILD
index fbdf19f205..f631b6df06 100644
--- a/third_party/sycl/BUILD
+++ b/third_party/sycl/BUILD
@@ -1,15 +1,3 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/third_party/sycl/sycl/BUILD b/third_party/sycl/sycl/BUILD
index bc1d18b7b5..b045609954 100644
--- a/third_party/sycl/sycl/BUILD
+++ b/third_party/sycl/sycl/BUILD
@@ -5,15 +5,3 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)