-rw-r--r--  configure.py | 18
-rw-r--r--  tensorflow/BUILD | 96
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 84
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.cc | 30
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.h | 51
-rw-r--r--  tensorflow/compiler/tf2xla/type_util.h | 8
-rw-r--r--  tensorflow/contrib/BUILD | 53
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py | 5
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py | 4
-rw-r--r--  tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 18
-rw-r--r--  tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py | 24
-rw-r--r--  tensorflow/contrib/cmake/python_modules.txt | 1
-rw-r--r--  tensorflow/contrib/data/BUILD | 38
-rw-r--r--  tensorflow/contrib/data/ops/indexed_dataset_ops.cc | 80
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 40
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 9
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 9
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py | 43
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py | 15
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/BUILD | 9
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py | 14
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py | 7
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py | 10
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/resample_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py | 11
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py | 8
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py | 5
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/test_utils.py | 73
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD | 57
-rw-r--r--  tensorflow/contrib/data/python/ops/contrib_op_loader.py | 24
-rw-r--r--  tensorflow/contrib/data/python/ops/error_ops.py | 5
-rw-r--r--  tensorflow/contrib/data/python/ops/indexed_dataset_ops.py | 25
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py | 13
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py | 7
-rw-r--r--  tensorflow/contrib/data/python/ops/prefetching_ops.py | 37
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py | 6
-rw-r--r--  tensorflow/contrib/data/python/ops/threadpool.py | 9
-rw-r--r--  tensorflow/contrib/data/python/ops/unique.py | 5
-rw-r--r--  tensorflow/contrib/decision_trees/proto/BUILD | 1
-rw-r--r--  tensorflow/contrib/distribute/README.md | 3
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD | 28
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/combinations.py | 3
-rw-r--r--  tensorflow/contrib/distribute/python/keras_test.py | 121
-rw-r--r--  tensorflow/contrib/distribute/python/metrics_v1_test.py | 3
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py | 26
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 14
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 12
-rw-r--r--  tensorflow/contrib/distribute/python/monitor.py | 1
-rw-r--r--  tensorflow/contrib/distribute/python/optimizer_v2_test.py | 8
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2.py | 232
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py | 90
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn.py | 7
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn_test.py | 1
-rw-r--r--  tensorflow/contrib/distribute/python/values.py | 63
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py | 23
-rw-r--r--  tensorflow/contrib/factorization/BUILD | 9
-rw-r--r--  tensorflow/contrib/lite/BUILD | 19
-rw-r--r--  tensorflow/contrib/lite/examples/android/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/java/aar_with_jni.bzl | 53
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/java/ovic/demo/app/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor.h | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h | 29
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_test.cc | 36
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD | 1
-rw-r--r--  tensorflow/contrib/makefile/Makefile | 3
-rw-r--r--  tensorflow/contrib/opt/BUILD | 5
-rw-r--r--  tensorflow/contrib/opt/python/training/shampoo_test.py | 40
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/BUILD | 7
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head_test.py | 2
-rw-r--r--  tensorflow/contrib/tpu/__init__.py | 3
-rw-r--r--  tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc | 42
-rw-r--r--  tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc | 3
-rw-r--r--  tensorflow/contrib/tpu/profiler/op_profile.proto | 4
-rw-r--r--  tensorflow/contrib/tpu/proto/optimization_parameters.proto | 33
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/async_checkpoint.py | 12
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py | 11
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py | 53
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 13
-rw-r--r--  tensorflow/contrib/training/BUILD | 1
-rw-r--r--  tensorflow/core/BUILD | 18
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt | 21
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt | 58
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt | 25
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt | 13
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt | 8
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt | 8
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt | 13
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt | 35
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt | 8
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt | 2
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 49
-rw-r--r--  tensorflow/core/common_runtime/direct_session.h | 3
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc | 28
-rw-r--r--  tensorflow/core/framework/function.cc | 8
-rw-r--r--  tensorflow/core/framework/function.h | 5
-rw-r--r--  tensorflow/core/framework/node_def_util.h | 1
-rw-r--r--  tensorflow/core/framework/op.h | 20
-rw-r--r--  tensorflow/core/framework/op_def_builder.cc | 24
-rw-r--r--  tensorflow/core/framework/op_def_builder.h | 14
-rw-r--r--  tensorflow/core/framework/resource_mgr.h | 11
-rw-r--r--  tensorflow/core/framework/run_handler.cc | 249
-rw-r--r--  tensorflow/core/framework/run_handler.h | 95
-rw-r--r--  tensorflow/core/framework/run_handler_util.cc | 57
-rw-r--r--  tensorflow/core/framework/run_handler_util.h | 43
-rw-r--r--  tensorflow/core/framework/run_handler_util_test.cc | 93
-rw-r--r--  tensorflow/core/grappler/optimizers/data/BUILD | 7
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_utils.h | 4
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_vectorization.cc | 28
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc | 112
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/BUILD | 3
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc | 29
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc | 36
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h | 23
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc | 16
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 451
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.h | 35
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | 205
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 6
-rw-r--r--  tensorflow/core/grappler/utils/functions.cc | 55
-rw-r--r--  tensorflow/core/grappler/utils/functions.h | 5
-rw-r--r--  tensorflow/core/kernels/BUILD | 47
-rw-r--r--  tensorflow/core/kernels/collective_ops.cc | 2
-rw-r--r--  tensorflow/core/kernels/data/BUILD | 1
-rw-r--r--  tensorflow/core/kernels/data/experimental/BUILD (renamed from tensorflow/contrib/data/kernels/BUILD) | 90
-rw-r--r--  tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/assert_next_dataset_op.cc) | 5
-rw-r--r--  tensorflow/core/kernels/data/experimental/csv_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/csv_dataset_op.cc) | 3
-rw-r--r--  tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc) | 5
-rw-r--r--  tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/identity_indexed_dataset.cc) | 7
-rw-r--r--  tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc) | 6
-rw-r--r--  tensorflow/core/kernels/data/experimental/indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/indexed_dataset.cc) | 14
-rw-r--r--  tensorflow/core/kernels/data/experimental/indexed_dataset.h (renamed from tensorflow/contrib/data/kernels/indexed_dataset.h) | 6
-rw-r--r--  tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/lmdb_dataset_op.cc) | 3
-rw-r--r--  tensorflow/core/kernels/data/experimental/prefetching_kernels.cc (renamed from tensorflow/contrib/data/kernels/prefetching_kernels.cc) | 23
-rw-r--r--  tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/threadpool_dataset_op.cc) | 7
-rw-r--r--  tensorflow/core/kernels/data/experimental/unique_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/unique_dataset_op.cc) | 7
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/data/multi_device_iterator_ops.cc | 34
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_iterator.cc | 42
-rw-r--r--  tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 10
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 2
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 2
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc | 2
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc | 18
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 415
-rw-r--r--  tensorflow/core/ops/experimental_dataset_ops.cc (renamed from tensorflow/contrib/data/ops/dataset_ops.cc) | 161
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 415
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl | 45
-rw-r--r--  tensorflow/core/protobuf/config.proto | 5
-rw-r--r--  tensorflow/core/protobuf/rewriter_config.proto | 2
-rw-r--r--  tensorflow/core/util/mkl_util.h | 12
-rw-r--r--  tensorflow/core/util/tensor_bundle/BUILD | 4
-rw-r--r--  tensorflow/examples/android/BUILD | 1
-rw-r--r--  tensorflow/go/op/wrappers.go | 1782
-rw-r--r--  tensorflow/python/BUILD | 10
-rw-r--r--  tensorflow/python/client/session_test.py | 16
-rw-r--r--  tensorflow/python/compat/compat.py | 2
-rw-r--r--  tensorflow/python/data/kernel_tests/BUILD | 3
-rw-r--r--  tensorflow/python/data/kernel_tests/test_base.py | 80
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py | 4
-rw-r--r--  tensorflow/python/distribute/estimator_training.py | 2
-rw-r--r--  tensorflow/python/estimator/BUILD | 1
-rw-r--r--  tensorflow/python/estimator/canned/dnn.py | 3
-rw-r--r--  tensorflow/python/estimator/canned/dnn_linear_combined_test.py | 268
-rw-r--r--  tensorflow/python/estimator/canned/linear.py | 83
-rw-r--r--  tensorflow/python/estimator/canned/linear_test.py | 138
-rw-r--r--  tensorflow/python/estimator/canned/linear_testing_utils.py | 184
-rw-r--r--  tensorflow/python/estimator/estimator.py | 44
-rw-r--r--  tensorflow/python/estimator/estimator_test.py | 94
-rw-r--r--  tensorflow/python/feature_column/BUILD | 2
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2.py | 608
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2_test.py | 1858
-rw-r--r--  tensorflow/python/framework/test_util.py | 84
-rw-r--r--  tensorflow/python/keras/backend.py | 8
-rw-r--r--  tensorflow/python/keras/engine/base_layer.py | 157
-rw-r--r--  tensorflow/python/keras/engine/training.py | 3
-rw-r--r--  tensorflow/python/keras/engine/training_distributed.py | 341
-rw-r--r--  tensorflow/python/keras/engine/training_eager_test.py | 14
-rw-r--r--  tensorflow/python/keras/engine/training_test.py | 12
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 25
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 180
-rw-r--r--  tensorflow/python/layers/base.py | 16
-rw-r--r--  tensorflow/python/layers/convolutional_test.py | 36
-rw-r--r--  tensorflow/python/layers/core_test.py | 6
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 16
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py | 2
-rw-r--r--  tensorflow/python/ops/while_v2.py | 4
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt | 24
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt | 148
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt | 46
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt | 58
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt | 58
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt | 105
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.image.pbtxt | 251
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt | 268
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt | 289
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt | 55
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt | 268
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt | 289
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt | 6
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt | 6
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt | 6
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt | 6
-rw-r--r--  tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 | 2
-rwxr-xr-x  tensorflow/tools/ci_build/builds/run_pip_tests.sh | 1
-rw-r--r--  tensorflow/tools/lib_package/BUILD | 38
-rw-r--r--  tensorflow/tools/pip_package/BUILD | 29
-rwxr-xr-x  tensorflow/workspace.bzl | 384
-rw-r--r--  third_party/gpus/crosstool/BUILD.tpl | 14
-rw-r--r--  third_party/toolchains/BUILD | 4
-rwxr-xr-x  third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD | 2
-rwxr-xr-x  third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD | 14
-rw-r--r--  tools/bazel.rc | 1
247 files changed, 6692 insertions, 7175 deletions
diff --git a/configure.py b/configure.py
index f71caa1994..0a3b9a7894 100644
--- a/configure.py
+++ b/configure.py
@@ -1488,11 +1488,7 @@ def main():
setup_python(environ_cp)
if is_windows():
- environ_cp['TF_NEED_AWS'] = '0'
- environ_cp['TF_NEED_GCP'] = '0'
- environ_cp['TF_NEED_HDFS'] = '0'
environ_cp['TF_NEED_JEMALLOC'] = '0'
- environ_cp['TF_NEED_KAFKA'] = '0'
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
@@ -1508,6 +1504,7 @@ def main():
if is_macos():
environ_cp['TF_NEED_JEMALLOC'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
+ environ_cp['TF_ENABLE_XLA'] = '0'
# The numpy package on ppc64le uses OpenBLAS which has multi-threading
# issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
@@ -1516,18 +1513,8 @@ def main():
if is_ppc64le():
write_action_env_to_bazelrc('OMP_NUM_THREADS', 1)
- set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
- 'with_jemalloc', True)
- set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
- 'with_gcp_support', True, 'gcp')
- set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
- 'with_hdfs_support', True, 'hdfs')
- set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
- 'with_aws_support', True, 'aws')
- set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
- 'with_kafka_support', True, 'kafka')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
- False, 'xla')
+ True, 'xla')
set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
@@ -1636,4 +1623,3 @@ def main():
if __name__ == '__main__':
main()
-
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 3610eea42a..5f73da68a2 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -225,60 +225,6 @@ config_setting(
)
config_setting(
- name = "with_gcp_support",
- define_values = {"with_gcp_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support",
- define_values = {"with_hdfs_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support",
- define_values = {"with_aws_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_kafka_support",
- define_values = {"with_kafka_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-# Crosses between platforms and file system libraries not supported on those
-# platforms due to limitations in nested select() statements.
-config_setting(
- name = "with_gcp_support_windows_override",
- define_values = {"with_gcp_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_windows_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_windows_override",
- define_values = {"with_aws_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_kafka_support_windows_override",
- define_values = {"with_kafka_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "with_cuda_support_windows_override",
define_values = {"using_cuda_nvcc": "true"},
values = {"cpu": "x64_windows"},
@@ -286,48 +232,6 @@ config_setting(
)
config_setting(
- name = "with_gcp_support_android_override",
- define_values = {"with_gcp_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_android_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_android_override",
- define_values = {"with_aws_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_gcp_support_ios_override",
- define_values = {"with_gcp_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_ios_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_ios_override",
- define_values = {"with_aws_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "with_xla_support",
define_values = {"with_xla_support": "true"},
visibility = ["//visibility:public"],
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 2d45507796..36c6f5d316 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -92,13 +92,51 @@ Status FunctionalizeControlFlowForFunction(
});
const FunctionBody* body = flr->GetFunctionBody(handle);
+ // Call graph optimizer. The most important optimization we need is constant
+ // folding, which will replace ops like Shape/BroadcastGradientArgs with
+ // constant shape input. Without this optimization, those ops might become
+ // dynamic input for then/else body function and XLA will complain that input
+ // is not compile time constant. We enable function inlining as well, because
+ // otherwise we won't be able to infer shape for any node depending on
+ // function call nodes.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+ *body->graph, fld);
+ }
+ // Optimizer accepts std::unique_ptr<Graph>* as input and might change
+ // underlying pointer, thus we create a new Graph and copy from body->graph.
+ std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+ CopyGraph(*body->graph, optimized_graph.get());
+ OptimizerOptions opts;
+ opts.set_opt_level(OptimizerOptions::L0);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ auto cf_consider_fn = [](const Node* n) {
+ // Skip SymbolicGradient op when doing constant folding.
+ // Enabling SymbolicGradient op in constant folding requires
+ // flr->device() to be non-null, and here we have not constructed
+ // proper Device object yet (it will be constructed in XlaCompiler).
+ return n->type_string() != FunctionLibraryDefinition::kGradientOp;
+ };
+ optimizer.Optimize(flr, flr->env(),
+ /*device=*/nullptr, &optimized_graph,
+ /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
+ cf_consider_fn);
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_opt_", func_name),
+ *optimized_graph, fld);
+ }
+
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : body->graph->nodes()) {
+ for (auto* n : optimized_graph->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -118,7 +156,14 @@ Status FunctionalizeControlFlowForFunction(
// but still rewrite the node.
new_name = iter->second;
} else {
- new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ if (associated_function.type() ==
+ AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
+ // For SymbolicGradient, `name` is always "SymbolicGradient",
+ // which is not very informative. Use node name instead.
+ new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"));
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ }
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
name, new_name, associated_function.attrs(), fld, flr,
canonicalized_name_to_new_name));
@@ -129,43 +174,10 @@ Status FunctionalizeControlFlowForFunction(
// That's fine because in that case, associated_functions will only have
// one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- body->graph, n, fld, associated_function, new_name));
+ optimized_graph.get(), n, fld, associated_function, new_name));
}
}
- // Call graph optimizer. The most important optimization we need is constant
- // folding, which will replace ops like Shape/BroadcastGradientArgs with
- // constant shape input. Without this optimization, those ops might become
- // dynamic input for then/else body function and XLA will complain that input
- // is not compile time constant. We enable function inlining as well, because
- // otherwise we won't be able to infer shape for any node depending on
- // function call nodes.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_opt_", func_name),
- *body->graph, fld);
- }
- // Optimizer accepts std::unique_ptr<Graph>* as input and might change
- // underlying pointer, thus we create a new Graph and copy from body->graph.
- std::unique_ptr<Graph> optimized_graph(new Graph(fld));
- CopyGraph(*body->graph, optimized_graph.get());
- OptimizerOptions opts;
- opts.set_opt_level(OptimizerOptions::L0);
- opts.set_do_function_inlining(true);
- opts.set_do_constant_folding(true);
- GraphOptimizer optimizer(opts);
- auto cf_consider_fn = [](const Node* n) {
- // Skip SymbolicGradient op when doing constant folding.
- // Enabling SymbolicGradient op in constant folding requires
- // flr->device() to be non-null, and here we have not constructed
- // proper Device object yet (it will be constructed in XlaCompiler).
- return n->type_string() != FunctionLibraryDefinition::kGradientOp;
- };
- optimizer.Optimize(flr, flr->env(),
- /*device=*/nullptr, &optimized_graph,
- /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
- cf_consider_fn);
-
// Functionalize the function body.
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index d6f42bac86..01dd3ba10f 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def,
}
if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
- return false;
+ // Gradient op has "f" attr, which is set to the function we are getting
+ // gradient for. We need to functionalize the gradient function.
+ return true;
}
for (const auto& iter : node_def.attr()) {
@@ -357,17 +357,18 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
// This is a function call node.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
- results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
} else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
+ // This is a SymbolicGradient op.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
if (iter.second.has_func()) {
VLOG(2) << "Found function attr for node " << node.name() << ": "
<< iter.first << " = " << iter.second.func().name();
- results.emplace_back(AssociatedFunctionInfo(
+ results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
iter.second.func().name(), iter.second.func().attr(), iter.first));
}
}
@@ -410,6 +411,21 @@ Status RewriteAssociatedFunction(
graph->RemoveNode(node);
break;
}
+ case AssociatedFunctionInfo::kSymbolicGradient: {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+ GradientDef gradient_def;
+ gradient_def.set_function_name(func.name());
+ gradient_def.set_gradient_func(rewritten_function_name);
+ string original_grad_func = fld->FindGradient(func.name());
+ if (original_grad_func.empty()) {
+ TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
+ } else if (original_grad_func != rewritten_function_name) {
+ TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
+ }
+ break;
+ }
case AssociatedFunctionInfo::kFunctionAttr: {
// Change function attr to rewritten functions.
NameAttrList func;
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 6065d0bb9a..53eab8b63e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -65,21 +65,33 @@ uint32 GetXLARandomSeed();
class AssociatedFunctionInfo {
public:
enum AssociatedFunctionType {
- kFunctionCallNode = 0,
- kFunctionAttr = 1,
+ kFunctionAttr = 0,
+ kFunctionCallNode = 1,
+ kSymbolicGradient = 2,
};
- // The node is a function call.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
- : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
-
// The function is an attr of the node.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
- const string& attr_name)
- : type_(kFunctionAttr),
- func_name_(func_name),
- attrs_(attrs),
- attr_name_(attr_name) {}
+ static AssociatedFunctionInfo FunctionAttr(const string& func_name,
+ const AttrValueMap& attrs,
+ const string& attr_name) {
+ return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
+ }
+
+ // The node is a function call.
+ static AssociatedFunctionInfo FunctionCall(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
+ /*attr_name=*/"");
+ }
+
+ // The node is a SymbolicGradient op.
+ static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
+ /*attr_name=*/"");
+ }
AssociatedFunctionType type() const { return type_; }
@@ -90,6 +102,13 @@ class AssociatedFunctionInfo {
const AttrValueMap& attrs() const { return attrs_; }
private:
+ AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
+ const AttrValueMap& attrs, const string& attr_name)
+ : type_(type),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
// Available for all instances.
AssociatedFunctionType type_;
string func_name_;
@@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def,
// Gets functions associated with the node. Current cases:
// 1. For function call node, its function name;
-// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
+// and returned attrs will be this node's attributes;
+// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
const Node& node, FunctionLibraryRuntime* flr);
// Changes associated functions for the node. Current cases:
// 1. For function call node, creates a new node with the new function name and
// remove the old node;
-// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+// 2. For SymbolicGradient op, add or replace GradientDef in
+// FunctionLibraryDefinition;
+// 3. For nodes like XlaWhile/XlaIf, modify their function attributes.
Status RewriteAssociatedFunction(
Graph* graph, Node* node, FunctionLibraryDefinition* fld,
const AssociatedFunctionInfo& associated_function,
diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h
index bda667eb1f..6354216eee 100644
--- a/tensorflow/compiler/tf2xla/type_util.h
+++ b/tensorflow/compiler/tf2xla/type_util.h
@@ -25,6 +25,14 @@ namespace tensorflow {
// Converts a Tensorflow DataType to an XLA PrimitiveType.
Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
+// N.B.: there is intentionally no function to convert an XLA PrimitiveType to
+// a TensorFlow DataType. The mapping from TF types to XLA types is not
+// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the
+// inverse would not be a well-defined function. If you find that you want the
+// inverse mapping, then most likely you should be preserving the original
+// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow
+// type.
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index ae5ca32bcf..98dff965a9 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -112,26 +112,14 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
- "//tensorflow/contrib/kafka",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
- "//tensorflow/contrib/kinesis",
- ],
- "//conditions:default": [],
- }) + if_not_windows_cuda([
- "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
- ]) + if_not_windows([
- ]) + select({
"//tensorflow:linux_s390x": [],
"//tensorflow:windows": [],
"//conditions:default": [
"//tensorflow/contrib/bigtable",
"//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
+ "//tensorflow/contrib/kafka",
+ "//tensorflow/contrib/kinesis",
"//tensorflow/contrib/tensorrt:init_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
],
@@ -144,7 +132,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/coder:all_kernels",
- "//tensorflow/contrib/data/kernels:dataset_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/hadoop:dataset_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
@@ -159,20 +146,14 @@ cc_library(
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([
"//tensorflow/contrib/nccl:nccl_kernels",
]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_kernels",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_kernels",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
- ]),
+ }),
)
cc_library(
@@ -181,8 +162,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
@@ -198,18 +177,12 @@ cc_library(
"//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
] + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
- ]),
+ }),
)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index 6b6fe9663a..839eedd3a8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -188,9 +188,8 @@ class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
# Train for a few steps.
est.train(input_fn=_train_input_fn, steps=1000)
- # 10 steps for dnn + 3 for 1 tree of depth 3 + 1 after the tree finished
- # + 1 for resource variables.
- self._assert_checkpoint(est.model_dir, global_step=15)
+ # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
+ self._assert_checkpoint(est.model_dir, global_step=14)
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
self.assertLess(0.5, res["auc"])
est.predict(input_fn=_eval_input_fn)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index d7b14e00ba..c155128c0e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -238,8 +238,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
output_leaf_index=False)
classifier.fit(input_fn=_train_input_fn, steps=15)
- # When no override of global steps, 6 steps were used.
- self._assert_checkpoint(classifier.model_dir, global_step=6)
+ # When no override of global steps, 5 steps were used.
+ self._assert_checkpoint(classifier.model_dir, global_step=5)
def testOverridesGlobalSteps(self):
learner_config = learner_pb2.LearnerConfig()
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index c7eb2493a8..8531e97f90 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -402,13 +402,13 @@ class GradientBoostedDecisionTreeModel(object):
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
self._num_quantiles = num_quantiles
- self._max_tree_depth = variables.Variable(
+ self._max_tree_depth = variables.VariableV1(
initial_value=self._learner_config.constraints.max_tree_depth)
- self._attempted_trees = variables.Variable(
+ self._attempted_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="attempted_trees")
- self._finalized_trees = variables.Variable(
+ self._finalized_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="finalized_trees")
@@ -770,28 +770,28 @@ class GradientBoostedDecisionTreeModel(object):
fc_name_idx += 1
# Create ensemble stats variables.
- num_layer_examples = variables.Variable(
+ num_layer_examples = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_examples",
trainable=False)
- num_layer_steps = variables.Variable(
+ num_layer_steps = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_steps",
trainable=False)
- num_layers = variables.Variable(
+ num_layers = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layers",
trainable=False)
- active_tree = variables.Variable(
+ active_tree = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_tree",
trainable=False)
- active_layer = variables.Variable(
+ active_layer = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_layer",
trainable=False)
# Variable that becomes false once bias centering is done.
- continue_centering = variables.Variable(
+ continue_centering = variables.VariableV1(
initial_value=self._center_bias,
name="continue_centering",
trainable=False)
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 9d9941f696..6d20a2e7f4 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -239,7 +239,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -503,7 +503,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -607,7 +607,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -711,7 +711,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -783,7 +783,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -847,7 +847,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1090,7 +1090,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1194,7 +1194,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1299,7 +1299,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1405,7 +1405,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1524,7 +1524,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1656,7 +1656,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index c0763f4c0e..2975b167ec 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -132,7 +132,6 @@ tensorflow/contrib/cudnn_rnn/python
tensorflow/contrib/cudnn_rnn/python/layers
tensorflow/contrib/cudnn_rnn/python/ops
tensorflow/contrib/data
-tensorflow/contrib/data/kernels
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
tensorflow/contrib/data/python/kernel_tests/serialization
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 9f710613dd..38f1c65a4d 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -4,17 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load(
- "//tensorflow:tensorflow.bzl",
- "tf_custom_op_library",
- "tf_gen_op_libs",
- "if_not_windows",
-)
-load(
- "//tensorflow/core:platform/default/build_config_root.bzl",
- "if_static",
-)
-
py_library(
name = "data",
srcs = ["__init__.py"],
@@ -25,30 +14,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-cc_library(
- name = "lib_proto_parsing_for_dataset_ops",
- deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]),
-)
-
-tf_custom_op_library(
- name = "_dataset_ops.so",
- srcs = [
- "ops/dataset_ops.cc",
- "ops/indexed_dataset_ops.cc",
- ],
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/contrib/data/kernels:indexed_dataset",
- ] + if_static(
- extra_deps = [":lib_proto_parsing_for_dataset_ops"],
- otherwise = [],
- ),
-)
-
-tf_gen_op_libs(
- op_lib_names = [
- "dataset_ops",
- "indexed_dataset_ops",
- ],
-)
diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
deleted file mode 100644
index cd9b7c68a0..0000000000
--- a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("IdentityIndexedDataset")
- .Input("size: uint64")
- .Output("handle: variant")
- .SetIsStateful()
- .SetShapeFn(
- shape_inference::ScalarShape); // TODO(saeta): check input shapes.
-
-///////////////////////////////////////////////////////////////////////////////
-// IndexedDataset Internals
-///////////////////////////////////////////////////////////////////////////////
-
-// Creates the handle.
-REGISTER_OP("MaterializedIndexDatasetHandle")
- .Output("handle: resource")
- .Attr("container: string")
- .Attr("shared_name: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
-
-// Actually materialize the materialize handle.
-REGISTER_OP("IndexedDatasetMaterialize")
- .Input("dataset: variant")
- .Input("materialized: resource")
- .SetShapeFn(shape_inference::NoOutputs);
-
-namespace {
-
-Status GetShapeFn(shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
-}
-
-} // namespace
-
-REGISTER_OP("IndexedDatasetGet")
- .Input("materialized: resource")
- .Input("index: uint64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(GetShapeFn)
- .Doc(R"doc(
-Gets the element at `index` from `materialized` IndexedDataset.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ce52c990ce..33784afa3f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -31,6 +31,7 @@ py_test(
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -54,6 +55,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -77,6 +79,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -97,6 +100,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
],
@@ -112,6 +116,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:random_seed",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -130,6 +135,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -139,12 +145,12 @@ py_test(
name = "indexed_dataset_ops_test",
srcs = ["indexed_dataset_ops_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:gen_dataset_ops",
"//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -170,6 +176,7 @@ py_test(
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",
],
@@ -189,6 +196,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/estimator:estimator_py",
],
@@ -215,6 +223,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//third_party/py/numpy",
],
)
@@ -240,6 +249,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -259,6 +269,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -283,6 +294,7 @@ py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -301,6 +313,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
@@ -316,6 +329,7 @@ cuda_py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -341,6 +355,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -366,6 +381,7 @@ py_library(
"//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
@@ -412,6 +428,7 @@ py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -434,6 +451,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -454,6 +472,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -471,6 +490,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -490,6 +510,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"@org_sqlite//:python",
],
)
@@ -534,6 +555,7 @@ py_library(
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -550,6 +572,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -568,6 +591,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -588,6 +612,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -605,17 +630,8 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
],
)
-
-py_library(
- name = "test_utils",
- srcs = ["test_utils.py"],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/util:nest",
- ],
-)
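
The BUILD hunks above retire the per-package test_utils helper in favor of the shared //tensorflow/python/data/kernel_tests:test_base target, and the test-file hunks that follow apply the same substitution. A minimal sketch of the migrated pattern, assuming only the helpers visible in this diff (DatasetTestBase, assertDatasetsEqual, getNext); the dataset contents are illustrative:

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class ExampleDatasetTest(test_base.DatasetTestBase):

  def testDoubledRange(self):
    # Shared helpers replace the per-file _assert_datasets_equal/_get_next copies.
    ds_actual = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    ds_expected = dataset_ops.Dataset.range(0, 10, 2)
    self.assertDatasetsEqual(ds_actual, ds_expected)

    # getNext returns a zero-argument callable usable with self.evaluate.
    nxt = self.getNext(dataset_ops.Dataset.range(2))
    self.assertEqual(0, self.evaluate(nxt()))
    self.assertEqual(1, self.evaluate(nxt()))


if __name__ == "__main__":
  test.main()
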
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index e2508de9e9..fed7de5f2b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -40,12 +41,8 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
@@ -723,7 +720,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-class RestructuredDatasetTest(test.TestCase):
+class RestructuredDatasetTest(test_base.DatasetTestBase):
def test_assert_element_shape(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 48971f2ccc..ae401f786c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,7 +36,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class GroupByReducerTest(test.TestCase):
+class GroupByReducerTest(test_base.DatasetTestBase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
@@ -198,7 +199,7 @@ class GroupByReducerTest(test.TestCase):
self.assertEqual(y, 45)
-class GroupByWindowTest(test.TestCase):
+class GroupByWindowTest(test_base.DatasetTestBase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -345,7 +346,7 @@ class GroupByWindowTest(test.TestCase):
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
-class BucketTest(test.TestCase):
+class BucketTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
@@ -570,7 +571,7 @@ def _get_record_shape(sparse):
return tensor_shape.TensorShape([None])
-class BucketBySequenceLength(test.TestCase):
+class BucketBySequenceLength(test_base.DatasetTestBase):
def testBucket(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index f8e74e4583..5b3c512b64 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -43,37 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test.TestCase):
-
- def _get_next(self, dataset):
- # Returns a no argument function whose result is fed to self.evaluate to
- # yield the next element
- it = dataset.make_one_shot_iterator()
- if context.executing_eagerly():
- return it.get_next
- else:
- get_next = it.get_next()
- return lambda: get_next
-
- def _assert_datasets_equal(self, ds1, ds2):
- assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
- '%s') % (ds1.output_shapes,
- ds2.output_shapes)
- assert ds1.output_types == ds2.output_types
- assert ds1.output_classes == ds2.output_classes
- next1 = self._get_next(ds1)
- next2 = self._get_next(ds2)
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = self.evaluate(next1())
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- self.evaluate(next2())
- break
- op2 = self.evaluate(next2())
- self.assertAllEqual(op1, op2)
+class CsvDatasetOpTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -108,7 +79,7 @@ class CsvDatasetOpTest(test.TestCase):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
dataset_actual, dataset_expected = self._make_test_datasets(
inputs, **kwargs)
- self._assert_datasets_equal(dataset_actual, dataset_expected)
+ self.assertDatasetsEqual(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
dataset,
@@ -116,7 +87,7 @@ class CsvDatasetOpTest(test.TestCase):
expected_err_re=None):
if expected_err_re is None:
# Verify that output is expected, without errors
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
@@ -128,7 +99,7 @@ class CsvDatasetOpTest(test.TestCase):
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
while True:
try:
self.evaluate(nxt())
@@ -354,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- self._assert_datasets_equal(
+ self.assertDatasetsEqual(
ds_actual.repeat(5).prefetch(1),
ds_expected.repeat(5).prefetch(1))
@@ -377,7 +348,7 @@ class CsvDatasetOpTest(test.TestCase):
ds = readers.make_csv_dataset(
file_path, batch_size=1, shuffle=False, num_epochs=1)
- nxt = self._get_next(ds)
+ nxt = self.getNext(ds)
result = list(self.evaluate(nxt()).values())
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
index a2ab3de52e..722e87e555 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -25,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index eb110324d1..bc10c21472 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import test
-class DirectedInterleaveDatasetTest(test.TestCase):
+class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
def testBasic(self):
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index f3968cdc15..cc22ea1df7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import get_single_element
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test.TestCase, parameterized.TestCase):
+class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Zero", 0, 1),
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index 9c508d686d..d4d3d4adb2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -19,29 +19,30 @@ from __future__ import print_function
import unittest
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.platform import test
-class IndexedDatasetOpsTest(test.TestCase):
+class IndexedDatasetOpsTest(test_base.DatasetTestBase):
def testLowLevelIndexedDatasetOps(self):
- identity = gen_dataset_ops.identity_indexed_dataset(
+ identity = ged_ops.experimental_identity_indexed_dataset(
ops.convert_to_tensor(16, dtype=dtypes.uint64))
- handle = gen_dataset_ops.materialized_index_dataset_handle(
+ handle = ged_ops.experimental_materialized_index_dataset_handle(
container="",
shared_name="",
output_types=[dtypes.uint64],
output_shapes=[[]])
- materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
+ materialize = ged_ops.experimental_indexed_dataset_materialize(
+ identity, handle)
index = array_ops.placeholder(dtypes.uint64)
- get_op = gen_dataset_ops.indexed_dataset_get(
+ get_op = ged_ops.experimental_indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
with self.cached_session() as sess:
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index b9e74dfddb..28bd670ab5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -25,6 +25,7 @@ import time
from six.moves import zip_longest
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -36,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class ParallelInterleaveDatasetTest(test.TestCase):
+class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 7e2326bd17..58a1d7c93b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
@@ -33,7 +34,7 @@ from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-class CheckpointInputPipelineHookTest(test.TestCase):
+class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
@staticmethod
def _model_fn(features, labels, mode, config):
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 1cc5ddc9a2..d2a72272db 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -22,6 +22,7 @@ import os
import shutil
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.util import compat
prefix_path = "tensorflow/core/lib"
-class LMDBDatasetTest(test.TestCase):
+class LMDBDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(LMDBDatasetTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index e8519381d6..385c4ef6ea 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -41,7 +42,7 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase):
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 25aea0393f..751e6d5b30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -21,6 +21,7 @@ import time
from tensorflow.contrib.data.python.ops import map_defun
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,8 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapDefunTest(test.TestCase):
+
+class MapDefunTest(test_base.DatasetTestBase):
def testMapDefunSimple(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index 1ae92bdeff..d7b5edcd9a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -15,6 +15,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -31,6 +32,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -57,7 +59,6 @@ py_test(
srcs = ["map_vectorization_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
@@ -67,6 +68,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -85,6 +87,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -102,6 +105,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -121,6 +125,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -137,6 +142,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -151,6 +157,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
index d10da80442..fe1b5280ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -18,12 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class AssertNextDatasetTest(test.TestCase):
+class AssertNextDatasetTest(test_base.DatasetTestBase):
def testAssertNext(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
index 9518c2e1ad..b43efb5c7c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -31,7 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class HoistRandomUniformTest(test.TestCase, parameterized.TestCase):
+class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index e75edf6086..e9e3fc81e5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,7 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
+class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
index dd547db086..f7907eb890 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index 5b493f44c9..a5ea85f454 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -22,9 +22,9 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import test_utils
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _get_test_datasets(self,
base_dataset,
@@ -85,7 +85,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
num_parallel_calls)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationBadMapFn(self):
# Test map functions that give an error
@@ -112,7 +112,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
# TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(optimized, unoptimized)
+ self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):
@@ -124,7 +124,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
@@ -138,7 +138,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
@@ -148,7 +148,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
index 3b62a7e468..33c250ab2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -23,12 +23,13 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ModelDatasetTest(test.TestCase):
+class ModelDatasetTest(test_base.DatasetTestBase):
def testModelMap(self):
k = 1024 * 1024
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
index 507feda3ad..b9e60cfa4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -26,7 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class NoopEliminationTest(test.TestCase):
+class NoopEliminationTest(test_base.DatasetTestBase):
def testNoopElimination(self):
a = constant_op.constant(1, dtype=dtypes.int64)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index a3fb824ce9..04f499f8c5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -28,7 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase):
+class OptimizeDatasetTest(test_base.DatasetTestBase):
def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
index c4623bca73..66ccaceea5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -72,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
-class ParseExampleTest(test.TestCase):
+class ParseExampleTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 33a64ea767..7a6a7a709a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -22,6 +22,7 @@ import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -35,7 +36,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class PrefetchingKernelsOpsTest(test.TestCase):
+class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
def setUp(self):
self._event = threading.Event()
@@ -244,7 +245,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
-class PrefetchToDeviceTest(test.TestCase):
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -445,7 +446,7 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
-class CopyToDeviceTest(test.TestCase):
+class CopyToDeviceTest(test_base.DatasetTestBase):
def testCopyToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index db8fe6aa1b..2e901587f4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import counter
from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -27,7 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index ed75b27a44..66ed547b6d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
@@ -242,7 +243,7 @@ class ReadBatchFeaturesTest(
self.assertEqual(32, shape[0])
-class MakeCsvDatasetTest(test.TestCase):
+class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
return readers.make_csv_dataset(
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index 08b9f03816..f443b5501b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -25,6 +25,7 @@ import zlib
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
@@ -32,11 +33,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class FixedLengthRecordDatasetTestBase(test.TestCase):
+class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing FixedLengthRecordDataset."""
def setUp(self):
@@ -63,7 +63,7 @@ class FixedLengthRecordDatasetTestBase(test.TestCase):
return filenames
-class ReadBatchFeaturesTestBase(test.TestCase):
+class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing `make_batched_feature_dataset`."""
def setUp(self):
@@ -273,7 +273,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
self.assertAllEqual(expected_batch[i], actual_batch[i])
-class TextLineDatasetTestBase(test.TestCase):
+class TextLineDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TextLineDataset."""
def _lineText(self, f, l):
@@ -313,7 +313,7 @@ class TextLineDatasetTestBase(test.TestCase):
return filenames
-class TFRecordDatasetTestBase(test.TestCase):
+class TFRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TFRecordDataset."""
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index 16b1441baa..32474bd411 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -24,6 +24,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.data.python.ops import resampling
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -57,7 +58,7 @@ def _time_resampling(
return end_time - start_time
-class ResampleTest(test.TestCase, parameterized.TestCase):
+class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index dde678bd54..bdf80eae4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
import numpy as np
from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -33,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ScanDatasetTest(test.TestCase):
+class ScanDatasetTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
index 14cd3e9c4a..a10f85263a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -90,6 +91,16 @@ class StatsDatasetSerializationTest(
lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
None, num_outputs)
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 440e48db30..c97002a255 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class ShuffleAndRepeatTest(test.TestCase):
+class ShuffleAndRepeatTest(test_base.DatasetTestBase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 90d18dca2a..c5a7862322 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import sliding
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class SlideDatasetTest(test.TestCase, parameterized.TestCase):
+class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
sliding.sliding_window_batch(
window_size=1, stride=1, window_shift=1, window_stride=1))
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSlideSparse(self):
def _sparse(i):
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
index 1f5c725a92..319a2ea263 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
@@ -24,12 +24,13 @@ import os
import sqlite3
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTestBase(test.TestCase):
+class SqlDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing SqlDataset."""
def _createSqlDataset(self, output_types, num_repeats=1):
@@ -92,5 +93,3 @@ class SqlDatasetTestBase(test.TestCase):
9007199254740992.0)])
conn.commit()
conn.close()
-
-
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index b1b4c23510..80f2625927 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -19,10 +19,10 @@ from __future__ import print_function
from tensorflow.core.framework import summary_pb2
-from tensorflow.python.platform import test
+from tensorflow.python.data.kernel_tests import test_base
-class StatsDatasetTestBase(test.TestCase):
+class StatsDatasetTestBase(test_base.DatasetTestBase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
def _assertSummaryContains(self, summary_str, tag):
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
deleted file mode 100644
index 4c3353fe40..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test utilities for tf.data functionality."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class DatasetTestBase(test.TestCase):
- """Base class for dataset tests."""
-
- def _assert_datasets_equal(self, dataset1, dataset2):
- # TODO(rachelim): support sparse tensor outputs
- next1 = dataset1.make_one_shot_iterator().get_next()
- next2 = dataset2.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- while True:
- try:
- op1 = sess.run(next1)
- except errors.OutOfRangeError:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next2)
- break
- op2 = sess.run(next2)
-
- op1 = nest.flatten(op1)
- op2 = nest.flatten(op2)
- assert len(op1) == len(op2)
- for i in range(len(op1)):
- self.assertAllEqual(op1[i], op2[i])
-
- def _assert_datasets_raise_same_error(self,
- dataset1,
- dataset2,
- exception_class,
- replacements=None):
- # We are defining next1 and next2 in the same line so that we get identical
- # file:line_number in the error messages
- # pylint: disable=line-too-long
- next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
- # pylint: enable=line-too-long
- with self.cached_session() as sess:
- try:
- sess.run(next1)
- raise ValueError(
- "Expected dataset to raise an error of type %s, but it did not." %
- repr(exception_class))
- except exception_class as e:
- expected_message = e.message
- for old, new, count in replacements:
- expected_message = expected_message.replace(old, new, count)
- # Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exception_class,
- re.escape(expected_message)):
- sess.run(next2)
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 8d335e87d5..08de3a9143 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import threadpool
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
+class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index f994c8563f..8856ce5afb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class UniqueDatasetTest(test.TestCase):
+class UniqueDatasetTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 8b7b3ac0f7..79134c7bc6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _structuredDataset(self, structure, shape, dtype):
if structure is None:
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index 867ee2ba37..fca546a570 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.data.python.ops import writers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class TFRecordWriterTest(test.TestCase):
+class TFRecordWriterTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordWriterTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index a14781cd93..5cd1ed542b 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -78,7 +78,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":batching",
- ":gen_dataset_ops",
":interleave_ops",
":optimization",
":parsing_ops",
@@ -86,6 +85,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
@@ -148,8 +148,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -179,12 +178,11 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
":random_ops",
"//tensorflow/contrib/stateless",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
@@ -199,9 +197,8 @@ py_library(
srcs = ["optimization.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -304,8 +301,7 @@ py_library(
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -321,9 +317,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -342,47 +337,11 @@ py_library(
],
)
-tf_gen_op_wrapper_py(
- name = "gen_dataset_ops",
- out = "gen_dataset_ops.py",
- deps = [
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- ],
-)
-
-tf_kernel_library(
- name = "dataset_ops_kernels",
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/core:framework",
- ],
- alwayslink = 1,
-)
-
-tf_custom_op_py_library(
- name = "contrib_op_loader",
- srcs = ["contrib_op_loader.py"],
- dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
- kernels = [
- ":dataset_ops_kernels",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":gen_dataset_ops",
- "//tensorflow/contrib/util:util_py",
- "//tensorflow/python:platform",
- ],
-)
-
py_library(
name = "indexed_dataset_ops",
srcs = ["indexed_dataset_ops.py"],
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -394,7 +353,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- ":contrib_op_loader",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/contrib/data/python/ops/contrib_op_loader.py
deleted file mode 100644
index 8f495a9dc9..0000000000
--- a/tensorflow/contrib/data/python/ops/contrib_op_loader.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Python helper for loading contrib ops and kernels."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.util import loader
-from tensorflow.python.platform import resource_loader
-
-_dataset_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 615dbcabd4..f962e623ee 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
def ignore_errors():
@@ -60,7 +59,7 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
self._input_dataset = input_dataset
def _as_variant_tensor(self):
- return gen_dataset_ops.ignore_errors_dataset(
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
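The recurring pattern in this change: the Python wrappers generated from the contrib op registrations (`gen_dataset_ops` plus the `contrib_op_loader` that pulled in `_dataset_ops.so`) are replaced by the core `gen_experimental_dataset_ops` module, whose kernels ship with the main TensorFlow runtime and whose ops carry an `experimental_` prefix. The public contrib API is untouched; a minimal sketch of the unchanged user-facing call path, assuming a build that includes this change:

    # Sketch only: `ignore_errors()` is the public wrapper whose internals this
    # diff rebinds from the contrib-generated op to the core
    # `experimental_ignore_errors_dataset` op.
    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import error_ops

    dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 0.0, 4.0])
    # Produces Inf for the third element; check_numerics turns that into an error.
    dataset = dataset.map(lambda x: tf.check_numerics(1.0 / x, "error"))
    # User code is unaffected by the op migration; only the generated binding moved.
    dataset = dataset.apply(error_ops.ignore_errors())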
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
index cc76ab0850..9c06474a2f 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -19,14 +19,13 @@ from __future__ import print_function
import abc
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
class MaterializedIndexedDataset(object):
@@ -57,7 +56,7 @@ class MaterializedIndexedDataset(object):
A tensor containing the values corresponding to `index`.
"""
# TODO(saeta): nest.pack_sequence_as(...)
- return gen_dataset_ops.indexed_dataset_get(
+ return ged_ops.experimental_indexed_dataset_get(
self._materialized_resource,
index,
output_types=nest.flatten(
@@ -90,16 +89,18 @@ class IndexedDataset(dataset_ops.Dataset):
container = ""
if shared_name is None:
shared_name = ""
- materialized_resource = gen_dataset_ops.materialized_index_dataset_handle(
- container=container,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self.output_shapes, self.output_classes)))
+ materialized_resource = (
+ ged_ops.experimental_materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self.output_shapes,
+ self.output_classes))))
with ops.colocate_with(materialized_resource):
- materializer = gen_dataset_ops.indexed_dataset_materialize(
+ materializer = ged_ops.experimental_indexed_dataset_materialize(
self._as_variant_tensor(), materialized_resource)
return MaterializedIndexedDataset(materialized_resource, materializer,
self.output_classes, self.output_types,
@@ -170,7 +171,7 @@ class IdentityIndexedDataset(IndexedDataset):
return tensor_shape.scalar()
def _as_variant_tensor(self):
- return gen_dataset_ops.identity_indexed_dataset(self._size)
+ return ged_ops.experimental_identity_indexed_dataset(self._size)
def _inputs(self):
return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index bfa3fdf543..1ee9db1aa8 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -18,8 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -28,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@@ -167,10 +166,12 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# pylint: disable=protected-access
- return gen_dataset_ops.directed_interleave_dataset(
- self._selector_input._as_variant_tensor(),
- [data_input._as_variant_tensor() for data_input in self._data_inputs],
- **dataset_ops.flat_structure(self))
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
# pylint: enable=protected-access
def _inputs(self):
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 3eb172acd5..30348ede36 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -17,12 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
@@ -54,7 +53,7 @@ def model():
Returns:
A `Dataset` transformation function, which can be passed to
- @{tf.data.Dataset.apply}.
+ `tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
@@ -97,7 +96,7 @@ class _AssertNextDataset(dataset_ops.UnaryDataset):
transformations, dtype=dtypes.string, name="transformations")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.assert_next_dataset(
+ return gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._transformations,
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 58395879e6..46f82e453a 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -19,8 +19,6 @@ from __future__ import print_function
import warnings
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
@@ -32,7 +30,8 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
@@ -64,7 +63,7 @@ def function_buffering_resource(string_arg,
"""
if shared_name is None:
shared_name = ""
- return gen_dataset_ops.function_buffering_resource(
+ return ged_ops.experimental_function_buffering_resource(
string_arg=string_arg,
target_device=target_device,
shared_name=shared_name,
@@ -78,14 +77,14 @@ def function_buffering_resource(string_arg,
def function_buffering_resource_get_next(function_buffer_resource,
output_types,
name=None):
- return gen_dataset_ops.function_buffering_resource_get_next(
+ return ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=function_buffer_resource,
output_types=output_types,
name=name)
def function_buffering_resource_reset(function_buffer_resource, name=None):
- return gen_dataset_ops.function_buffering_resource_reset(
+ return ged_ops.experimental_function_buffering_resource_reset(
function_buffer_resource=function_buffer_resource, name=name)
@@ -136,7 +135,7 @@ class _PrefetchToDeviceIterator(object):
ret = remote_iterator.get_next()
return nest.flatten(sparse.serialize_sparse_tensors(ret))
- iterator_device = gen_dataset_ops.iterator_get_device(
+ iterator_device = ged_ops.experimental_iterator_get_device(
self._input_iterator._iterator_resource)
with ops.device(device):
@@ -162,10 +161,11 @@ class _PrefetchToDeviceIterator(object):
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
self._buffering_resource,
- output_types=nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
ret = sparse.deserialize_sparse_tensors(
nest.pack_sequence_as(self.output_types, flat_ret),
@@ -219,7 +219,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
buffer_size):
with ops.device("/device:CPU:0"):
super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
self._resource)
self._device = device
@@ -238,7 +238,8 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
output_types=self._flat_output_types,
- target_device=gen_dataset_ops.iterator_get_device(self._resource),
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
string_arg=input_iterator_handle,
buffer_size=buffer_size,
shared_name=iterator_ops._generate_shared_name(
@@ -252,7 +253,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
with ops.device(self._device):
- ret = gen_dataset_ops.function_buffering_resource_get_next(
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=self._buffering_resource,
output_types=self._flat_output_types)
return sparse.deserialize_sparse_tensors(
@@ -409,12 +410,12 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
"""
# pylint: disable=protected-access
ds_variant = self._input_dataset._as_variant_tensor()
- resource = core_gen_dataset_ops.anonymous_iterator(
+ resource = gen_dataset_ops.anonymous_iterator(
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
with ops.control_dependencies(
- [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return core_gen_dataset_ops.iterator_to_string_handle(resource)
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
@function.Defun()
def _remote_init_func():
@@ -463,7 +464,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
Returns:
Tensor constant 0
"""
- iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
string_handle,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
@@ -504,7 +505,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
def _as_variant_tensor(self):
with ops.device(self._target_device):
- return core_gen_dataset_ops.generator_dataset(
+ return gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index d9d06e2703..360971e200 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -23,7 +23,6 @@ import csv
import numpy as np
from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
@@ -38,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
@@ -629,7 +629,7 @@ class CsvDataset(dataset_ops.DatasetSource):
def _as_variant_tensor(self):
# Constructs graph node for the dataset op.
- return contrib_gen_dataset_ops.csv_dataset(
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
@@ -1013,7 +1013,7 @@ class LMDBDataset(dataset_ops.DatasetSource):
filenames, dtype=dtypes.string, name="filenames")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.lmdb_dataset(
+ return gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index 9d165ad52a..f73c3fd9cb 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -19,10 +19,9 @@ from __future__ import print_function
import threading
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
_uid_counter = 0
@@ -47,7 +46,7 @@ class PrivateThreadPool(object):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
@@ -55,7 +54,7 @@ class PrivateThreadPool(object):
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device=context.context().device_name)
else:
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name)
@@ -70,7 +69,7 @@ class _ThreadPoolDataset(dataset_ops.UnaryDataset):
self._thread_pool = thread_pool
def _as_variant_tensor(self):
- return gen_dataset_ops.thread_pool_dataset(
+ return ged_ops.experimental_thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index bad67a580d..ed363a7090 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
def unique():
@@ -61,7 +60,7 @@ class _UniqueDataset(dataset_ops.UnaryDataset):
"`tf.int32`, `tf.int64`, or `tf.string` component.")
def _as_variant_tensor(self):
- return gen_dataset_ops.unique_dataset(
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD
index 3b50a48336..06940a90d5 100644
--- a/tensorflow/contrib/decision_trees/proto/BUILD
+++ b/tensorflow/contrib/decision_trees/proto/BUILD
@@ -17,7 +17,6 @@ tf_proto_library(
name = "generic_tree_model",
srcs = ["generic_tree_model.proto"],
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 91a27f97b7..2e025765e4 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use
important to shuffle your dataset in your `input_fn`.
`MirroredStrategy` will insert a `tf.data.Dataset.shard` call in your
-`input_fn`. As a result, each worker gets a fraction of your input data.
+`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
+gets a fraction of your input data.
### Performance Tips
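For context on the README change above: with `auto_shard_dataset=True`, `MirroredStrategy` splits the dataset across workers roughly as if a `shard` call were inserted into the `input_fn`. A hedged sketch of the equivalent manual sharding (worker count and index are illustrative; in practice they come from the cluster configuration):

    import numpy as np
    import tensorflow as tf

    # Hypothetical worker configuration, for illustration only.
    num_workers = 2
    worker_index = 0

    def input_fn():
      features = np.arange(100, dtype=np.float32)
      dataset = tf.data.Dataset.from_tensor_slices(features)
      # Roughly what auto-sharding inserts on each worker: keep every
      # num_workers-th element, starting at worker_index.
      dataset = dataset.shard(num_workers, worker_index)
      return dataset.shuffle(buffer_size=50).repeat().batch(10)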
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index e329b964c4..422983dbef 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,6 +22,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
+ ":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -29,7 +30,6 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -648,6 +648,32 @@ cuda_py_test(
)
py_library(
+ name = "prefetching_ops_v2",
+ srcs = ["prefetching_ops_v2.py"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:prefetching_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_v2_test",
+ srcs = ["prefetching_ops_v2_test.py"],
+ additional_deps = [
+ ":prefetching_ops_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index c900b41e14..9809204f8f 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -216,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Configures the object.
Args:
- session_config: a @{tf.ConfigProto}
+ session_config: a `tf.ConfigProto`
cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
cluster configurations.
task_type: the current task type, such as "worker".
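To make the `configure` arguments documented above concrete, a hedged example of the kind of values they take (addresses and task assignment are hypothetical):

    import tensorflow as tf

    session_config = tf.ConfigProto(allow_soft_placement=True)
    # A dict form of the cluster; a tf.train.ClusterSpec or ClusterDef works too.
    cluster_spec = {
        "chief": ["10.0.0.1:2222"],
        "worker": ["10.0.0.2:2222", "10.0.0.3:2222"],
    }
    task_type = "worker"  # this process is one of the workers
    task_id = 0           # index into cluster_spec["worker"]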
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 244d1fcec8..82ca041cc2 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -59,6 +59,7 @@ from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import rmsprop
from tensorflow.python.util import tf_inspect
@@ -354,6 +355,8 @@ gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+rmsprop_optimizer_v1_fn = NamedObject(
+ "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
adagrad_optimizer_v1_fn]
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index a0b8bde132..3aab2c521f 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -173,13 +173,42 @@ def batch_wrapper(dataset, batch_size, distribution):
return dataset.batch(batch_size)
-def all_combinations():
+def get_model():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+ return model
+
+
+def get_dataset(distribution):
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = batch_wrapper(dataset, 10, distribution)
+ return dataset
+
+
+strategies = [combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step]
+
+
+def strategy_combinations():
return combinations.combine(
- distribution=[combinations.default_strategy,
- combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus,
- combinations.tpu_strategy_one_step],
+ distribution=strategies,
+ mode=['graph'])
+
+
+def strategy_and_optimizer_combinations():
+ return combinations.combine(
+ distribution=strategies,
+ optimizer=[combinations.adagrad_optimizer_v1_fn,
+ combinations.adam_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.rmsprop_optimizer_v1_fn],
mode=['graph'])
@@ -360,9 +389,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_model_with_numpy_arrays(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -392,23 +419,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
# with batch_size
model.predict(inputs, batch_size=8)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -461,23 +482,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -486,11 +501,23 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
+ @combinations.generate(strategy_and_optimizer_combinations())
+ def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
+ with self.cached_session():
+ model = get_model()
+
+ loss = 'mse'
+ model.compile(optimizer(), loss, distribute=distribution)
+
+ dataset = get_dataset(distribution)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+
def test_unsupported_features(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -500,11 +527,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
# Test with validation split
with self.assertRaisesRegexp(
@@ -541,9 +564,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_with_unsupported_predefined_callbacks(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -552,11 +573,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
def schedule(_):
return 0.001
@@ -580,9 +597,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_dataset_input_shape_validation(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
@@ -616,17 +631,13 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
mode=['graph']))
def test_dataset_input_shape_fully_defined(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
model.compile(optimizer, loss, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = get_dataset(distribution)
# Input shapes are not fully known. Batch dimension is unknown as we are
# not using the drop_remainder argument.
dataset = dataset.repeat(100).batch(10)
@@ -698,7 +709,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
class NormalizationLayerWithDistributionStrategyTest(
test.TestCase, parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_batchnorm_correctness(self, distribution):
with self.cached_session():
model = keras.models.Sequential()
@@ -726,7 +737,7 @@ class NormalizationLayerWithDistributionStrategyTest(
class CorrectnessWithDistributionStrategyTest(test.TestCase,
parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_metric_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
@@ -756,7 +767,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
self.assertEqual(history.history['binary_accuracy'], [1.0])
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index f7773aff4f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -86,11 +86,10 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
- dataset_fn).make_initializable_iterator()
+ dataset_fn).make_one_shot_iterator()
value, update = distribution.call_for_each_tower(
metric_fn, iterator.get_next())
update = distribution.group(update)
- self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_towers" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index d082d5c419..ba147e7824 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,14 +41,6 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
- def _get_iterator(self, ds):
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate(iterator.initializer)
- return iterator
-
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -70,7 +62,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -106,7 +99,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.group(
@@ -165,7 +159,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -249,7 +244,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -342,7 +338,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -435,7 +432,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 945f450387..4d7516063c 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
data to devices.
+ auto_shard_dataset: whether to auto-shard the dataset when there are
+ multiple workers.
"""
def __init__(self,
@@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus=None,
num_gpus_per_worker=None,
cross_tower_ops=None,
- prefetch_on_device=None):
+ prefetch_on_device=None,
+ auto_shard_dataset=False):
super(MirroredStrategy, self).__init__()
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
+ self._auto_shard_dataset = auto_shard_dataset
    # Remember num GPUs which might be needed by `configure` method.
if num_gpus is not None and num_gpus_per_worker is not None:
raise ValueError(
@@ -477,13 +481,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if self._cluster_spec:
return values.MultiWorkerDataset(
partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
+ self._prefetch_on_device, self._auto_shard_dataset)
else:
return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn),
- self._devices,
- self._prefetch_on_device,
- source_device=device_util.resolve("/device:CPU:0"))
+ self._call_dataset_fn(dataset_fn), self._devices,
+ self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
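A short sketch of how the new `auto_shard_dataset` flag is surfaced to users of `MirroredStrategy`; the constructor arguments match this diff, while the Keras/Estimator wiring in the comments mirrors the usage elsewhere in this change and is indicative rather than exhaustive:

    from tensorflow.contrib.distribute.python import mirrored_strategy

    # Auto-sharding of the input dataset across workers is now opt-in.
    strategy = mirrored_strategy.MirroredStrategy(
        num_gpus_per_worker=2,
        auto_shard_dataset=True)

    # The strategy is then passed along unchanged, e.g. via
    # model.compile(optimizer, loss, distribute=strategy) in Keras, or
    # tf.estimator.RunConfig(train_distribute=strategy) for Estimators.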
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 04c712ce1d..f51e543624 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -300,15 +300,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- ds = dist.distribute_dataset(
- lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate([iterator.initializer])
-
- features = iterator.get_next()
+ features = dist.distribute_dataset(
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+ ).make_one_shot_iterator().get_next()
with dist.scope():
result = dist.call_for_each_tower(
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 17b7ab74f6..7644acedc9 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,7 +51,6 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
- session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 3064433129..6e9ba37a19 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,11 +42,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- ds = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -55,7 +52,6 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
- sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
new file mode 100644
index 0000000000..8d949943b7
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -0,0 +1,232 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extension of prefetching_ops to support more than one device."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest as data_nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.util import nest
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+  """A replacement for `tf.data.Iterator` that prefetches to one or more devices.
+
+ Args:
+ input_dataset: The input dataset.
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ devices: Devices on which to prefetch.
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be shared
+ under the given name across multiple sessions that share the same devices
+ (e.g. when using a remote server). Only used if one_shot is False.
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ devices,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+ self._devices = devices
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ target_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+ self._buffering_resources = []
+ for device in nest.flatten(self._devices):
+ with ops.device(device):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)),
+ target_device=target_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name)
+ self._buffering_resources.append(buffer_resource_handle)
+
+ if not self._one_shot:
+ reset_ops = []
+ for buffer_resource in self._buffering_resources:
+ reset_ops.append(
+ ged_ops.experimental_function_buffering_resource_reset(
+ buffer_resource))
+ with ops.control_dependencies(reset_ops):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_result = []
+ # TODO(priyag): This will fail if the input size (typically number of
+ # batches) is not divisible by number of devices.
+ # How do we handle that more gracefully / let the user know?
+ for buffer_resource in self._buffering_resources:
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ buffer_resource,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ data_nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+ flat_result.append(ret)
+
+ return nest.pack_sequence_as(self._devices, flat_result)
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to other device(s)."""
+
+ def __init__(self, input_dataset, devices, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._devices = devices
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ def make_one_shot_iterator(self):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=True,
+ devices=self._devices,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "make_initializable_iterator is not supported when eager "
+ "execution is enabled.")
+
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ devices=self._devices,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+    # transformation methods is called).
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_devices()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ # TODO(priyag): Fix the output types, shapes and classes to match the result
+ # of get_next (which has the additional nesting layer of devices now).
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+def prefetch_to_devices(devices, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `devices`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final `Dataset` in the input pipeline.
+
+ Args:
+ devices: A nested structure of devices on which to prefetch the data. It can
+ be a single device name, or a tuple or list of device names.
+ buffer_size: (Optional.) The number of elements to buffer on each device.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
+
+ return _apply_fn
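A minimal usage sketch of the new `prefetch_to_devices` transformation, mirroring the tests that follow; the device names assume a machine with at least one GPU:

    from tensorflow.contrib.distribute.python import prefetching_ops_v2
    from tensorflow.python.data.ops import dataset_ops

    dataset = dataset_ops.Dataset.range(10)
    # Must be the last transformation in the pipeline (see _as_variant_tensor above).
    dataset = dataset.apply(
        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))

    iterator = dataset.make_one_shot_iterator()
    # get_next() returns one element per device, packed like the `devices`
    # argument: here, a list of two tensors.
    next_elements = iterator.get_next()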
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
new file mode 100644
index 0000000000..16799104e8
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -0,0 +1,90 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for prefetching_ops_v2."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchingOpsV2Test(test.TestCase):
+
+ def testPrefetchToOneDevice(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToTwoDevicesInAList(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ output = []
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ for _ in range(4):
+ result = sess.run(next_element)
+ self.assertEqual(2, len(result))
+ output.extend(result)
+ self.assertEquals(set(range(8)), set(output))
+
+ def testPrefetchToTwoDevicesWithReinit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 23bf36184f..1b5a4f64e5 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -51,11 +50,7 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- self._iterator = self._distributed_input.make_one_shot_iterator()
- else:
- # TODO(priyag): Expose initializer via some initializer property.
- self._iterator = self._distributed_input.make_initializable_iterator()
+ self._iterator = self._distributed_input.make_one_shot_iterator()
class StandardSingleLossStep(StandardInputStep):
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 1ff9b9ceec..f1ada49fa3 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,7 +50,6 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
- sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index a0cd029f51..4955ded4d5 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next()
+ data_list = self._iterator.get_next(name=name)
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,26 +703,21 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(
- self,
- dataset,
- devices,
- prefetch_on_device=None,
- source_device="/cpu:0",
- ):
+ def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices
- self._source_device = source_device if source_device is not None else "/cpu:0"
# Default to using prefetching in graph mode, unless specified.
- # TODO(rohanj): Enable prefetching in eager mode.
+ # TODO(priyag): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- self._dataset = dataset
- if not self._prefetch_on_device:
+ if self._prefetch_on_device:
+ self._dataset = dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(self._devices))
+ else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -730,33 +725,15 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
- # Graph mode prefetching with one shot iterator is disabled.
- if not context.executing_eagerly():
- raise ValueError("Cannot create a one shot iterator. Please use "
- "`make_initializable_iterator()` instead.")
- # Eager mode prefetching would error out in constructor. Only remaining
- # cases are non-prefetching eager / graph mode. We delegate to
- # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
- return PerDeviceDataIterator(
- dataset_iterator, self._devices, prefetch_on_device=False)
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- # Eager mode generates already initialized iterators. Hence we cannot create
- # an initializable iterator.
- if context.executing_eagerly():
- raise ValueError("Cannot create initializable iterator in Eager mode. "
- "Please use `make_one_shot_iterator` instead.")
- if self._prefetch_on_device:
- dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
- self._dataset, self._devices, source_device=self._source_device)
- else:
- dataset_iterator = self._dataset.make_initializable_iterator()
- return PerDeviceDataIterator(
- dataset_iterator,
- self._devices,
- prefetch_on_device=self._prefetch_on_device)
+ dataset_iterator = self._dataset.make_initializable_iterator()
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -816,7 +793,8 @@ class MultiWorkerDataset(object):
eager mode.
"""
- def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+ def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+ auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
@@ -824,6 +802,7 @@ class MultiWorkerDataset(object):
worker_device_map: a dict mapping from each worker to a list of devices
that belong to this worker.
prefetch_on_device: whether to prefetch to devices.
+ auto_shard: whether to auto-shard the dataset.
"""
self._worker_device_map = worker_device_map
self._datasets = {}
@@ -833,13 +812,11 @@ class MultiWorkerDataset(object):
six.iteritems(worker_device_map)):
with ops.device(worker):
worker_input = dataset_fn()
- worker_input = input_ops.auto_shard_dataset(
- worker_input, len(worker_device_map), i)
+ if auto_shard:
+ worker_input = input_ops.auto_shard_dataset(
+ worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input,
- worker_devices,
- source_device=worker,
- prefetch_on_device=prefetch_on_device)
+ worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}
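A hedged usage sketch of the reworked PerDeviceDataset: the constructor signature, the graph-mode prefetching default, and values.select_device are taken from the diff above, while the device list and dataset are illustrative.

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import values

    devices = ["/device:GPU:0", "/device:GPU:1"]  # assumption: two local GPUs
    dataset = tf.data.Dataset.range(8).batch(2)

    # prefetch_on_device=None defaults to prefetching in graph mode only; the
    # prefetching path now goes through prefetching_ops_v2.prefetch_to_devices.
    per_device = values.PerDeviceDataset(dataset, devices, prefetch_on_device=None)

    # Both iterator factories now simply forward prefetch_on_device.
    iterator = per_device.make_initializable_iterator()
    next_element = iterator.get_next()

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      print(sess.run([values.select_device(d, next_element) for d in devices]))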
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 002d61f46e..ae3e134333 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,11 +349,7 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
- if context.executing_eagerly():
- iterator = per_device_dataset.make_one_shot_iterator()
- else:
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -370,14 +366,21 @@ class PerDeviceDatasetTest(test.TestCase):
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
+ # With prefetching, we cannot guarantee which input ends up on which
+ # device, so we verify that the complete set seen on all devices is
+ # correct, and equal numbers are distributed to each device.
+ combined_actual = []
+ combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- computed_value = self.evaluate(
- [values.select_device(d, next_element) for d in devices])
- self.assertEqual(expected_value, computed_value)
+ combined_actual.extend(
+ self.evaluate(
+ [values.select_device(d, next_element) for d in devices]))
+ combined_expected.extend(expected_value)
+
+ self.assertEqual(set(combined_expected), set(combined_actual))
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index e344d7a23b..510f292508 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -154,6 +154,8 @@ tf_py_test(
],
tags = [
"no_pip", # b/38283730
+ "noasan", # b/116875897
+ "nomsan",
"notsan", # Flaky: b/30756419
],
)
@@ -177,7 +179,11 @@ tf_py_test(
"//tensorflow/python:random_seed",
"//tensorflow/python:variables",
],
- tags = ["notsan"], # b/62863147
+ tags = [
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan", # b/62863147
+ ],
)
py_library(
@@ -276,6 +282,7 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
+ "notsan", # b/116875897
],
)
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index f320b53d94..f3ebe3b245 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -26,6 +26,14 @@ config_setting(
},
)
+# Enables inclusion of TensorFlow kernels via the TF Lite Flex delegate.
+# WARNING: This build flag is experimental and subject to change.
+config_setting(
+ name = "with_tflite_flex",
+ define_values = {"with_tflite_flex": "true"},
+ visibility = ["//visibility:public"],
+)
+
cc_library(
name = "schema_fbs_version",
hdrs = ["version.h"],
@@ -157,6 +165,10 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
+ defines = select({
+ ":with_tflite_flex": ["TFLITE_FLEX"],
+ "//conditions:default": [],
+ }),
linkopts = [
] + select({
"//tensorflow:android": [
@@ -180,7 +192,12 @@ cc_library(
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
"//tensorflow/contrib/lite/profiling:profiler",
"//tensorflow/contrib/lite/schema:schema_fbs",
- ],
+ ] + select({
+ ":with_tflite_flex": [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ ],
+ "//conditions:default": [],
+ }),
)
cc_library(
diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD
index 4d2437e7d3..d180cb4785 100644
--- a/tensorflow/contrib/lite/examples/android/BUILD
+++ b/tensorflow/contrib/lite/examples/android/BUILD
@@ -28,6 +28,7 @@ android_binary(
srcs = glob([
"app/src/main/java/**/*.java",
]),
+ aapt_version = "aapt",
# Package assets from assets dir as well as all model targets.
# Remove undesired models (and corresponding Activities in source)
# to reduce APK size.
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index db837cf29e..9d2aead266 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -3,12 +3,12 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
def aar_with_jni(name, android_library):
- # Generate dummy AndroidManifest.xml for dummy apk usage
- # (dummy apk is generated by <name>_dummy_app_for_so target below)
- native.genrule(
- name = name + "_binary_manifest_generator",
- outs = [name + "_generated_AndroidManifest.xml"],
- cmd = """
+ # Generate dummy AndroidManifest.xml for dummy apk usage
+ # (dummy apk is generated by <name>_dummy_app_for_so target below)
+ native.genrule(
+ name = name + "_binary_manifest_generator",
+ outs = [name + "_generated_AndroidManifest.xml"],
+ cmd = """
cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
@@ -17,27 +17,28 @@ cat > $(OUTS) <<EOF
</manifest>
EOF
""",
- )
+ )
- # Generate dummy apk including .so files and later we extract out
- # .so files and throw away the apk.
- android_binary(
- name = name + "_dummy_app_for_so",
- manifest = name + "_generated_AndroidManifest.xml",
- custom_package = "dummy.package.for.so",
- deps = [android_library],
- # In some platforms we don't have an Android SDK/NDK and this target
- # can't be built. We need to prevent the build system from trying to
- # use the target in that case.
- tags = ["manual"],
- )
+ # Generate dummy apk including .so files and later we extract out
+ # .so files and throw away the apk.
+ android_binary(
+ name = name + "_dummy_app_for_so",
+ aapt_version = "aapt",
+ manifest = name + "_generated_AndroidManifest.xml",
+ custom_package = "dummy.package.for.so",
+ deps = [android_library],
+ # In some platforms we don't have an Android SDK/NDK and this target
+ # can't be built. We need to prevent the build system from trying to
+ # use the target in that case.
+ tags = ["manual"],
+ )
- native.genrule(
- name = name,
- srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
- outs = [name + ".aar"],
- tags = ["manual"],
- cmd = """
+ native.genrule(
+ name = name,
+ srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
+ outs = [name + ".aar"],
+ tags = ["manual"],
+ cmd = """
cp $(location {}.aar) $(location :{}.aar)
chmod +w $(location :{}.aar)
origdir=$$PWD
@@ -46,4 +47,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*"
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
- )
+ )
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
index 220d6c2159..5ad738389e 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2.0
android_binary(
name = "TfLiteCameraDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [
"//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index b2e3a9bd7d..058240aada 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -8,6 +8,7 @@ android_binary(
srcs = [
"OvicBenchmarkerActivity.java",
],
+ aapt_version = "aapt",
assets = [
"//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 765c3a03ef..689cea03e7 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -37,10 +37,6 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
- return GetTensorDims(data.data(), data.size());
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
index 5e688ce452..9f5b33d217 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -86,35 +86,6 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
if (tensor == nullptr) {
return RuntimeShape();
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
index bf2068d320..2ed73ba82d 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
@@ -21,28 +21,32 @@ namespace {
using ::testing::ElementsAre;
-TEST(TensorTest, GetTensorDims4D) {
- Dims<4> d = GetTensorDims({2, 3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape4D) {
+ RuntimeShape d = GetTensorShape({2, 3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(2, 3, 4, 5));
}
-TEST(TensorTest, GetTensorDims3D) {
- Dims<4> d = GetTensorDims({3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape3D) {
+ RuntimeShape d = GetTensorShape({3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(3, 4, 5));
}
-TEST(TensorTest, GetTensorDims2D) {
- Dims<4> d = GetTensorDims({4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20));
+TEST(TensorTest, GetTensorShape2D) {
+ RuntimeShape d = GetTensorShape({4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(4, 5));
}
-TEST(TensorTest, GetTensorDims1D) {
- Dims<4> d = GetTensorDims({5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5));
+TEST(TensorTest, GetTensorShape1D) {
+ RuntimeShape d = GetTensorShape({5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(5));
}
} // namespace
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
index f18a2ca07a..2e5033dab1 100644
--- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
@@ -20,6 +20,7 @@ filegroup(
android_binary(
name = "SmartReplyDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [":assets"],
assets_dir = "",
custom_package = "com.example.android.smartreply",
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index d962a5e12d..36125c198e 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -133,7 +133,8 @@ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*benchmark*.cc) \
-tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc
+tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc \
+tensorflow/contrib/makefile/downloads/absl/absl/hash/internal/print_hash_of.cc
ABSL_CC_SRCS := $(filter-out $(ABSL_CC_EXCLUDE_SRCS), $(ABSL_CC_ALL_SRCS))
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index f4ac70eb1a..6a67c6295d 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -377,6 +377,11 @@ py_test(
size = "large",
srcs = ["python/training/shampoo_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan",
+ ],
deps = [
":opt_py",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index 05bcf2cfa3..a2fd8fbd87 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -54,9 +54,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -105,9 +105,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -164,9 +164,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -254,9 +254,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -310,9 +310,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -383,9 +383,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(sample_size_2, size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -463,9 +463,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(sample_size, size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -533,9 +533,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
gbar_weight = 0.1
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -628,9 +628,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
@@ -705,9 +705,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index c230919168..cb1f707028 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -159,7 +159,12 @@ py_test(
],
shard_count = 4,
srcs_version = "PY2AND3",
- tags = ["no_pip_gpu"], # b/63391119
+ tags = [
+ "no_pip_gpu", # b/63391119
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan",
+ ],
deps = [
":estimators",
":feature_keys",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 647455ae42..04d17bc123 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -104,7 +104,7 @@ class EvaluationMetricsTests(test.TestCase):
"ticker":
array_ops.reshape(
math_ops.cast(
- variables.Variable(
+ variables.VariableV1(
name="ticker",
initial_value=0,
dtype=dtypes.int64,
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 766466968a..6ce6b779a2 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,7 +55,9 @@
@@TPUDistributionStrategy
@@keras_to_tpu_model
+
@@AsyncCheckpointSaverHook
+@@TPUInMemoryEvalHook
"""
from __future__ import absolute_import
@@ -65,6 +67,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
+from tensorflow.contrib.tpu.python.tpu.async_checkpoint import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 1bd1a31e11..bc1a0c5284 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -104,18 +104,9 @@ Status RegisterPerTableLoadOpsForAlgorithmBody(
}
}
{
- auto* table_id_attr = op_def->add_attr();
- table_id_attr->set_name("table_id");
- table_id_attr->set_type("int");
- table_id_attr->set_has_minimum(true);
- table_id_attr->set_minimum(-1);
- table_id_attr->mutable_default_value()->set_i(-1);
- }
- {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
- table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -147,11 +138,9 @@ parameters that are loaded from a checkpoint before a training loop is
executed.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ EmbeddingLayerConfiguration proto.
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
-table_id: Index of this table in the EmbeddingLayerConfiguration proto
- (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -160,14 +149,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- // Exactly one must be non-default.
- if ((table_id >= 0) == (!table_name.empty())) {
- return errors::InvalidArgument(
- "exactly one of table_id or table_name must be non-default");
+ if (table_name.empty()) {
+ return errors::InvalidArgument("table_name attribute must be set");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
@@ -241,18 +226,9 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody(
}
}
{
- auto* table_id_attr = op_def->add_attr();
- table_id_attr->set_name("table_id");
- table_id_attr->set_type("int");
- table_id_attr->set_has_minimum(true);
- table_id_attr->set_minimum(-1);
- table_id_attr->mutable_default_value()->set_i(-1);
- }
- {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
- table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -283,11 +259,9 @@ the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ EmbeddingLayerConfiguration proto.
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
-table_id: Index of this table in the EmbeddingLayerConfiguration proto
- (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -296,14 +270,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- // Exactly one must be non-default.
- if ((table_id >= 0) == (!table_name.empty())) {
- return errors::InvalidArgument(
- "exactly one of table_id or table_name must be non-default");
+ if (table_name.empty()) {
+ return errors::InvalidArgument("table_name must be non-empty");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index b498599962..8e6e9aa0cd 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,8 +156,7 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)))
- << new_session_response.error_message();
+ stub->NewSession(&context, new_session_request, &new_session_response)));
std::cout << "Profile session succeed for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index b25d06dda8..292108f949 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -66,8 +66,8 @@ message Metrics {
// - it does not reveal the peak core FLOPS of the hardware
double flops = 2;
- // The VMEM bandwidth used to load operands from HBM, as a fraction of
- // thereotical VMEM bandwidth on the specific hardware.
+ // The memory bandwidth used to load operands, as a fraction of
+ // theoretical memory bandwidth on the specific hardware.
double memory_bandwidth = 3;
double raw_time = 11; // Elapsed core-time in picoseconds.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index fc1320501b..a43f45554f 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -22,13 +22,22 @@ message LearningRate {
}
}
+// Each optimizer's parameter proto has a link to its documentation and CPU
+// implementation (if available) for user reference.
+
+// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151
message AdagradParameters {
float initial_accumulator = 1;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423
message StochasticGradientDescentParameters {
}
+// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192
message FtrlParameters {
float l1 = 1;
float l2 = 2;
@@ -41,21 +50,38 @@ message FtrlParameters {
// learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
+//
+// https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54
+//
+// Note that the code by default implements the lazy version of Adam
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
+// unless the use_non_lazy_adam parameter is set, in which case it implements
+// the normal version of Adam that updates all parameters in the embedding
+// table, even for entries that are not used in the current minibatch
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
+// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
+// order to get correct results; a warning will be printed otherwise (which may
+// change to an error in the future).
message AdamParameters {
float beta1 = 3;
float beta2 = 4;
float epsilon = 5;
float initial_m = 6;
float initial_v = 7;
+ bool use_non_lazy_adam = 8;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L271
message MomentumParameters {
float momentum = 1;
bool use_nesterov = 2;
float initial_accum = 3;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L356
message RmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -64,6 +90,8 @@ message RmsPropParameters {
float initial_mom = 5;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L372
message CenteredRmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -73,6 +101,7 @@ message CenteredRmsPropParameters {
float initial_mg = 6;
}
+// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
message MdlAdagradLightParameters {
float l2 = 1;
float lr_power = 2;
@@ -91,6 +120,8 @@ message MdlAdagradLightParameters {
float initial_benefit = 15;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68
message AdadeltaParameters {
float rho = 1;
float epsilon = 2;
@@ -98,6 +129,8 @@ message AdadeltaParameters {
float initial_update = 4;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164
message ProximalAdagradParameters {
float l1 = 1;
float l2 = 2;
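The optimizer messages above now carry links to their documentation and CPU kernels. A small sketch of populating AdamParameters with the new use_non_lazy_adam field; the field names come from the proto above, while the generated-module import path is an assumption:

    # Hedged sketch; assumes the generated *_pb2 module is importable from this path.
    from tensorflow.contrib.tpu.proto import optimization_parameters_pb2 as opt_pb2

    adam = opt_pb2.AdamParameters()
    adam.beta1 = 0.9
    adam.beta2 = 0.999
    adam.epsilon = 1e-8
    # Opt into the non-lazy variant; per the comment above, gradient accumulation
    # should also be enabled elsewhere in the configuration for correct results.
    adam.use_non_lazy_adam = True
    print(adam)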
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
index e06a720e82..20b7ba0997 100644
--- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
-
"""Hook for asynchronous checkpointing.
This hook dispatches checkpoint writing operations in a separate thread to
@@ -28,18 +27,16 @@ import threading
import time
from tensorflow.core.util.event_pb2 import SessionLog
-
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
-class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
+class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
@@ -67,7 +64,7 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of `saver` or `scaffold` should be set.
"""
- logging.info("Create CheckpointSaverHook.")
+ logging.info("Create AsyncCheckpointSaverHook.")
if saver is not None and scaffold is not None:
raise ValueError("You cannot provide both saver and scaffold.")
self._saver = saver
@@ -144,6 +141,10 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
def _save(self, session, step, asynchronous=True):
"""Saves the latest checkpoint, returns should_stop."""
+ # Skip saving on step 0
+ if step == 0:
+ return
+
def _save_fn():
"""Run the saver process."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
@@ -162,7 +163,6 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
end_time - start_time)
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
- logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
for l in self._listeners:
l.before_save(session, step)
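AsyncCheckpointSaverHook now derives from CheckpointSaverHook and skips the step-0 save. A hedged usage sketch; the constructor keywords are assumed to mirror tf.train.CheckpointSaverHook, and the training loop is illustrative:

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import async_checkpoint

    hook = async_checkpoint.AsyncCheckpointSaverHook(
        checkpoint_dir="/tmp/async_ckpt", save_steps=100)

    global_step = tf.train.get_or_create_global_step()
    train_op = tf.assign_add(global_step, 1)

    # Disable the default synchronous saver so only the async hook writes checkpoints.
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir="/tmp/async_ckpt",
        hooks=[hook],
        save_checkpoint_secs=None) as sess:
      for _ in range(300):
        sess.run(train_op)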
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 956d0142a3..696656e840 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -959,7 +959,16 @@ class TPUFunction(object):
# Compute our outfeed depending on the execution mode
if is_training:
- self._cloned_model._make_train_function()
+ if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
+ # For a Keras optimizer, we try to place its weight variables on the TPU
+ # device. Keras creates optimizer variables (e.g. momentum values for
+ # the Momentum optimizer) when _make_train_function is invoked.
+ with keras_tpu_variables.replicated_variable_for_optimizer(
+ self._tpu_assignment.num_towers):
+ self._cloned_model._make_train_function()
+ else:
+ self._cloned_model._make_train_function()
+
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
for tensor in self._cloned_model.train_function.outputs
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 170977d8ab..598da7418e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -25,10 +25,15 @@ from __future__ import print_function
import contextlib
+import numpy as np
+
from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -285,3 +290,51 @@ def replicated_scope(num_replicas):
return variable_scope.variable_scope(
"", custom_getter=_replicated_variable_getter)
+
+
+@contextlib.contextmanager
+def replicated_variable_for_optimizer(num_replicas):
+ """Context manager for optimizer weights. Overrides K.variable."""
+ if num_replicas == 1:
+ yield
+ return
+
+ try:
+ old_v = backend.variable
+
+ def opt_variable(value, dtype=None, name=None, constraint=None):
+ """Instantiates a variable and returns it."""
+ if dtype is None:
+ dtype = backend.floatx()
+
+ variables = []
+ for i in range(num_replicas):
+ # Keras holds the variables in the optimizer class instance, so the name
+ # does not matter here. The ResourceVariable constructor will find a
+ # unique name (including name=None) for each replica.
+ with ops.device("device:TPU:{}".format(i)):
+ v = resource_variable_ops.ResourceVariable(
+ value,
+ dtype=dtypes_module.as_dtype(dtype),
+ name=name,
+ constraint=constraint)
+ variables.append(v)
+ name = "replicate_{}_{}".format("variable" if name is None else name,
+ ops.uid())
+ v = ReplicatedVariable(name, variables)
+
+ # pylint: disable=protected-access
+
+ if isinstance(value, np.ndarray):
+ v._keras_shape = value.shape
+ elif hasattr(value, "shape"):
+ v._keras_shape = backend.int_shape(value)
+ v._uses_learning_phase = False
+ backend.track_variable(v)
+ return v
+
+ backend.variable = opt_variable
+ yield
+
+ finally:
+ backend.variable = old_v
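The keras_support.py hunk earlier in this change shows the intended call site for replicated_variable_for_optimizer; a condensed sketch of the same pattern follows (the model is a stand-in for the cloned Keras model, and the replica count is an assumption):

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables

    # Tiny compiled Keras model standing in for self._cloned_model.
    model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    model.compile(optimizer="sgd", loss="mse")

    num_towers = 8  # assumption: one replica per TPU core

    # Inside this context K.variable() creates one ResourceVariable per replica on
    # /device:TPU:i and wraps them in a ReplicatedVariable, so optimizer slot
    # variables created by _make_train_function() land on the TPU devices.
    with keras_tpu_variables.replicated_variable_for_optimizer(num_towers):
      model._make_train_function()  # pylint: disable=protected-access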
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 23c54511ca..545cee637f 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -231,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
+ `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -254,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
size of the tensors. The `tensors` are concatenated along their major (batch)
dimension, and so must be >= rank 1. The `host_call` is useful for writing
- summaries with @{tf.contrib.summary.create_file_writer}.
+ summaries with `tf.contrib.summary.create_file_writer`.
"""
def __new__(cls,
@@ -404,12 +404,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._feed_error = None
self._finished = False
+ self._should_initialize_tpu = True
def begin(self):
logging.info('TPU job name %s', self._master_job)
self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_ops = [tpu.initialize_system(job=self._master_job)]
- self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ if self._should_initialize_tpu:
+ self._init_ops = [tpu.initialize_system(job=self._master_job)]
+ self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ else:
+ self._init_ops = []
+ self._finalize_ops = []
summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
self._init_ops.extend(summary_writer_init_ops)
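The docstring edits above describe TPUEstimatorSpec.eval_metrics as a (metric_fn, tensors) tuple evaluated on the CPU host. A hedged model_fn sketch; layer sizes and the optimizer choice are illustrative:

    import tensorflow as tf

    def metric_fn(labels, logits):
      # Runs on the CPU host with tensors transferred back from the TPU.
      return {"accuracy": tf.metrics.accuracy(labels, tf.argmax(logits, axis=1))}

    def model_fn(features, labels, mode, params):
      logits = tf.layers.dense(features, 10)
      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
      if mode == tf.estimator.ModeKeys.EVAL:
        # eval_metrics is a (metric_fn, tensors) tuple rather than a dict of ops.
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))
      optimizer = tf.contrib.tpu.CrossShardOptimizer(
          tf.train.GradientDescentOptimizer(0.01))
      train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
      return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)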
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index ddf8365d61..b565ebd073 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -313,6 +313,5 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ca247dc56b..7da4b9fbd0 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1039,6 +1039,7 @@ tf_gen_op_libs(
"dataset_ops",
"decode_proto_ops",
"encode_proto_ops",
+ "experimental_dataset_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -1169,6 +1170,7 @@ cc_library(
":dataset_ops_op_lib",
":decode_proto_ops_op_lib",
":encode_proto_ops_op_lib",
+ ":experimental_dataset_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -2484,6 +2486,8 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
"framework/resource_var.h",
+ "framework/run_handler.h",
+ "framework/run_handler_util.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
@@ -2970,6 +2974,7 @@ tf_cuda_library(
":core_cpu_internal",
":device_tracer",
":framework",
+ ":framework_internal",
":graph",
":lib",
":lib_internal",
@@ -4117,6 +4122,19 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "framework_run_handler_util_test",
+ size = "small",
+ srcs = ["framework/run_handler_util_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ],
+)
+
tf_cuda_cc_test(
name = "common_runtime_direct_session_test",
size = "small",
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
new file mode 100644
index 0000000000..fa8fc96bb2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalAssertNextDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
new file mode 100644
index 0000000000..5fd88e7a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalCSVDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000000..ac1f9719fe
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
@@ -0,0 +1,21 @@
+op {
+ graph_op_name: "ExperimentalDirectedInterleaveDataset"
+ in_arg {
+ name: "selector_input_dataset"
+ description: <<END
+A dataset of scalar `DT_INT64` elements that determines which of the
+`N` data inputs should produce the next output element.
+END
+ }
+ in_arg {
+ name: "data_input_datasets"
+ description: <<END
+`N` datasets with the same type that will be interleaved according to
+the values of `selector_input_dataset`.
+END
+ }
+ summary: <<END
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
new file mode 100644
index 0000000000..66511eff60
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
@@ -0,0 +1,58 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResource"
+ in_arg {
+ name: "string_arg"
+ description: <<END
+String argument to the function call.
+END
+ }
+ in_arg {
+ name: "target_device"
+ description: <<END
+Target device to execute the function on.
+END
+ }
+ out_arg {
+ name: "resource"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name across
+multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+Function to be executed.
+END
+ }
+ attr {
+ name: "buffer_size"
+ description: <<END
+Size of the buffer.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Creates a resource that fills up a buffer by making function calls.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
new file mode 100644
index 0000000000..bf4b66b22b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
@@ -0,0 +1,25 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceGetNext"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A list of return values.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Gets the next element from a FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
new file mode 100644
index 0000000000..729718ddb3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceReset"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ summary: <<END
+Resets the FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
new file mode 100644
index 0000000000..fe266c111f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIdentityIndexedDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000000..d42546516d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIgnoreErrorsDataset"
+ summary: <<END
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
new file mode 100644
index 0000000000..e285f87e10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetGet"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
new file mode 100644
index 0000000000..60c32473b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetMaterialize"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
new file mode 100644
index 0000000000..b72b229e9a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIteratorGetDevice"
+ summary: <<END
+Returns the name of the device on which `resource` has been placed.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
new file mode 100644
index 0000000000..b38b23a51d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalLMDBDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
new file mode 100644
index 0000000000..9676b9d284
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
new file mode 100644
index 0000000000..d73b5bfda3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolDataset"
+ in_arg {
+ name: "thread_pool"
+ description: <<END
+A resource produced by the ThreadPoolHandle op.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
new file mode 100644
index 0000000000..48bf93406c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolHandle"
+ out_arg {
+ name: "handle"
+ description: <<END
+A resource that can be consumed by one or more ExperimentalThreadPoolDataset
+ops.
+END
+ }
+ attr {
+ name: "num_threads"
+ description: <<END
+The number of threads in the thread pool.
+END
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ description: <<END
+The maximum degree of parallelism to use within operations that execute on this
+threadpool.
+END
+ }
+ attr {
+ name: "display_name"
+ description: <<END
+A human-readable name for the threads that may be visible in some
+visualizations.
+END
+ }
+ summary: <<END
+Creates a custom thread pool that can be consumed by one or more
+`ExperimentalThreadPoolDataset` ops.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
new file mode 100644
index 0000000000..68ed797a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalUniqueDataset"
+ summary: <<END
+Creates a dataset that contains the unique elements of `input_dataset`.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
index 40d7d371ca..7142a0e3f2 100644
--- a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
@@ -9,7 +9,7 @@ The lower regularized incomplete Gamma function is defined as:
where
-\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
+\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\)
is the lower incomplete Gamma function.
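For reference, the corrected docstring corresponds to the standard lower regularized incomplete Gamma function:

    P(a, x) = \frac{\gamma(a, x)}{\Gamma(a)}, \qquad
    \gamma(a, x) = \int_{0}^{x} t^{a-1} e^{-t} \, dt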
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 841181f8c3..458e133b68 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
#endif // __ANDROID__
}
+static RunHandlerPool* GetOrCreateRunHandlerPool(
+ const SessionOptions& options) {
+ static RunHandlerPool* pool =
+ new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
+ return pool;
+}
+
+bool DirectSession::ShouldUseRunHandlerPool() const {
+ if (options_.config.session_inter_op_thread_pool_size() > 0 ||
+ options_.config.use_per_session_threads()) {
+ return false;
+ }
+ return true;
+}
+
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr,
DirectSessionFactory* const factory)
@@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- Executor::Args::Runner default_runner = [this,
- pool](Executor::Args::Closure c) {
- SchedClosure(pool, std::move(c));
- };
+ std::unique_ptr<RunHandler> handler;
+ if (ShouldUseRunHandlerPool() &&
+ run_options.experimental().use_run_handler_pool()) {
+ // Non-null only when a global inter-op pool is used.
+ VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
+ handler = GetOrCreateRunHandlerPool(options_)->Get();
+ }
+ auto* handler_ptr = handler.get();
+
+ Executor::Args::Runner default_runner = nullptr;
+
+ if (pool == nullptr) {
+ default_runner = [](Executor::Args::Closure c) { c(); };
+ } else if (handler_ptr != nullptr) {
+ default_runner = [handler_ptr](Executor::Args::Closure c) {
+ handler_ptr->ScheduleInterOpClosure(std::move(c));
+ };
+ } else {
+ default_runner = [this, pool](Executor::Args::Closure c) {
+ SchedClosure(pool, std::move(c));
+ };
+ }
+
for (const auto& item : executors_and_keys->items) {
- // TODO(zhengxq): support partial run.
- // TODO(zhengxq): if the device picks its own threadpool, we need to assign
+ // TODO(azaks): support partial run.
+ // TODO(azaks): if the device picks its own threadpool, we need to assign
// less threads to the main compute pool by default.
thread::ThreadPool* device_thread_pool =
item.device->tensorflow_device_thread_pool();
+ // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+ // thread pool(s).
if (!device_thread_pool) {
args.runner = default_runner;
} else {
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 4a6a921ea7..3a168bbe3f 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -247,6 +247,9 @@ class DirectSession : public Session {
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata);
+ // Returns whether inter-op execution uses a global pool.
+ bool ShouldUseRunHandlerPool() const;
+
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 65e816c202..e3e431f800 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -625,6 +625,34 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
+TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
+ Initialize({3, 2, -1, 0});
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+
+ // Prepares RunOptions with the run handler pool enabled.
+ RunOptions run_options;
+ run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+ Status s = session->Run(run_options, inputs, output_names, target_nodes,
+ &outputs, nullptr);
+ TF_ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ // The first output should be initialized and contain the correct
+ // value.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
GraphDef def;
Graph g(OpRegistry::Global());
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index a17959a448..20f957190b 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
return Status::OK();
}
+Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
+ mutex_lock l(mu_);
+ bool added;
+ TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
+ TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
+ return Status::OK();
+}
+
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
const auto& i = function_defs_.find(func);
if (i == function_defs_.end()) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e01eb7503d..4d6d68e214 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// a non-OK status if "func" was not found in the library, OK otherwise.
Status ReplaceFunction(const string& func, const FunctionDef& fdef);
+ // Replaces the gradient corresponding to `grad.function_name()`. Returns
+ // a non-OK status if "grad.function_name()" was not found in the library, OK
+ // otherwise.
+ Status ReplaceGradient(const GradientDef& grad);
+
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.
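
The new ReplaceGradient() swaps out an already-registered gradient in place. A minimal, hedged sketch (the function and gradient names are hypothetical, and `flib_def` is assumed to be a FunctionLibraryDefinition that already holds a gradient for "Foo"):

    GradientDef grad;
    grad.set_function_name("Foo");        // function whose gradient to replace
    grad.set_gradient_func("FooGradV2");  // hypothetical new gradient function
    TF_RETURN_IF_ERROR(flib_def->ReplaceGradient(grad));
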
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 187bfa2c88..0ff67554eb 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
-#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 25f8de8dcc..81ed5f95f0 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -209,16 +209,16 @@ template <>
class OpDefBuilderWrapper<true> {
public:
OpDefBuilderWrapper(const char name[]) : builder_(name) {}
- OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
- builder_.Attr(spec);
+ OpDefBuilderWrapper<true>& Attr(string spec) {
+ builder_.Attr(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Input(StringPiece spec) {
- builder_.Input(spec);
+ OpDefBuilderWrapper<true>& Input(string spec) {
+ builder_.Input(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Output(StringPiece spec) {
- builder_.Output(spec);
+ OpDefBuilderWrapper<true>& Output(string spec) {
+ builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& SetIsCommutative() {
@@ -237,12 +237,12 @@ class OpDefBuilderWrapper<true> {
builder_.SetAllowsUninitializedInput();
return *this;
}
- OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
- builder_.Deprecated(version, explanation);
+ OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
+ builder_.Deprecated(version, std::move(explanation));
return *this;
}
- OpDefBuilderWrapper<true>& Doc(StringPiece text) {
- builder_.Doc(text);
+ OpDefBuilderWrapper<true>& Doc(string text) {
+ builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper<true>& SetShapeFn(
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 34a7a43d38..8a9bb63182 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -526,32 +526,32 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
-OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def()->set_name(string(op_name)); // NOLINT
+OpDefBuilder::OpDefBuilder(string op_name) {
+ op_def()->set_name(std::move(op_name));
}
-OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
- attrs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Attr(string spec) {
+ attrs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
- inputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Input(string spec) {
+ inputs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
- outputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Output(string spec) {
+ outputs_.push_back(std::move(spec));
return *this;
}
#ifndef TF_LEAN_BINARY
-OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
+OpDefBuilder& OpDefBuilder::Doc(string text) {
if (!doc_.empty()) {
errors_.push_back(
strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
} else {
- doc_.assign(text.data(), text.size());
+ doc_ = std::move(text);
}
return *this;
}
@@ -577,14 +577,14 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
return *this;
}
-OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
+OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
if (op_def()->has_deprecation()) {
errors_.push_back(
strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
} else {
OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
- deprecation->set_explanation(string(explanation));
+ deprecation->set_explanation(std::move(explanation));
}
return *this;
}
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index 0b39d6e848..8077b20598 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -51,7 +51,7 @@ struct OpRegistrationData {
class OpDefBuilder {
public:
// Constructs an OpDef with just the name field set.
- explicit OpDefBuilder(StringPiece op_name);
+ explicit OpDefBuilder(string op_name);
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
@@ -84,7 +84,7 @@ class OpDefBuilder {
// * Ability to restrict the type of the tensor like the existing
// restrictions for type attrs.
// Perhaps by linking the type of the tensor to a type attr?
- OpDefBuilder& Attr(StringPiece spec);
+ OpDefBuilder& Attr(string spec);
// Adds an input or output to this OpDefBuilder (and returns *this).
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
@@ -101,8 +101,8 @@ class OpDefBuilder {
// in the spec?
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
// handling?
- OpDefBuilder& Input(StringPiece spec);
- OpDefBuilder& Output(StringPiece spec);
+ OpDefBuilder& Input(string spec);
+ OpDefBuilder& Output(string spec);
// Turns on the indicated boolean flag in this OpDefBuilder (and
// returns *this).
@@ -112,7 +112,7 @@ class OpDefBuilder {
OpDefBuilder& SetAllowsUninitializedInput();
// Deprecate the op at a certain GraphDef version.
- OpDefBuilder& Deprecated(int version, StringPiece explanation);
+ OpDefBuilder& Deprecated(int version, string explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
@@ -128,9 +128,9 @@ class OpDefBuilder {
// to suppress the automatically-generated type documentation in
// generated output.
#ifndef TF_LEAN_BINARY
- OpDefBuilder& Doc(StringPiece text);
+ OpDefBuilder& Doc(string text);
#else
- OpDefBuilder& Doc(StringPiece text) { return *this; }
+ OpDefBuilder& Doc(string text) { return *this; }
#endif
// Sets the shape function to be used for shape inference.
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index abb6635984..4a531648d9 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -248,10 +248,16 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
// Create a resource pointed by a given resource handle.
+//
+// If successful, the caller transfers the ownership of one ref on `value` to
+// `ctx->resource_mgr()`.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
// Looks up a resource pointed by a given resource handle.
+//
+// If the lookup is successful, the caller takes the ownership of one ref on
+// `*value`, and must call its `Unref()` method when it has finished using it.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
@@ -262,6 +268,11 @@ Status LookupResources(
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
// Looks up or creates a resource.
+//
+// If successful, the caller takes the ownership of one ref on `*value`, and
+// must call its `Unref()` method when it has finished using it. If the
+// `creator` is invoked, its reference on the created resource is transferred
+// to `ctx->resource_mgr()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator);
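
The ownership comments above imply a specific pattern in kernels: a successful lookup hands the caller one reference that it must later release. A minimal sketch, assuming a hypothetical ResourceBase subclass `MyVar` with a `tensor()` accessor:

    Status ReadMyVar(OpKernelContext* ctx, const ResourceHandle& handle,
                     Tensor* out) {
      MyVar* var = nullptr;
      // On success we own one ref on `var` and must release it ourselves.
      TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var));
      *out = var->tensor();  // hypothetical accessor
      var->Unref();          // release the ref taken by LookupResource
      return Status::OK();
    }
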
diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc
new file mode 100644
index 0000000000..0c4007eafc
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.cc
@@ -0,0 +1,249 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/run_handler.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/run_handler_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+// Contains the concrete implementation of the RunHandler.
+// Externally visible RunHandler class simply forwards the work to this one.
+class RunHandler::Impl {
+ public:
+ explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
+ Reset();
+ }
+
+ ~Impl() {}
+
+ void set_inter_op_scheduling_range(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ inter_op_scheduling_range_.store(EncodePartition(start, limit),
+ std::memory_order_release);
+ }
+
+ std::uint_fast32_t inter_op_scheduling_range() const {
+ return inter_op_scheduling_range_.load(std::memory_order_acquire);
+ }
+
+ // Returns the time (in microseconds since the Unix epoch) at which the
+ // handler was requested via RunHandlerPool::Get().
+ uint64 start_time_us() const { return start_time_us_; }
+
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ void Reset();
+
+ RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
+
+ private:
+ // Encoding/decoding logic for storing [start, limit) in a single
+ // uint_fast32_t. We assume that pool_num_threads < (1 << 16).
+ const int kMaxPartitionBits = 16;
+ const int kMaxThreads = 1 << kMaxPartitionBits;
+
+ std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ return (start << kMaxPartitionBits) | limit;
+ }
+
+ void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
+ std::uint_fast32_t* limit) {
+ *limit = val & (kMaxThreads - 1);
+ val >>= kMaxPartitionBits;
+ *start = val;
+ }
+
+ std::atomic_uint_fast32_t inter_op_scheduling_range_;
+ RunHandlerPool::Impl* pool_impl_; // NOT OWNED.
+ uint64 start_time_us_;
+};
+
+// Contains shared state across all run handlers present in the pool. Also
+// responsible for pool management decisions.
+// This class is thread safe.
+class RunHandlerPool::Impl {
+ public:
+ explicit Impl(int num_inter_op_threads)
+ : max_handlers_(128),
+ inter_op_thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)),
+ iterations_(0) {
+ VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
+ for (int i = 0; i < max_handlers_; ++i) {
+ handlers_.emplace_back(new RunHandler::Impl(this));
+ free_handlers_.push_back(handlers_.back().get());
+ }
+ }
+
+ ~Impl() {
+ // Sanity check that all handlers have been returned to the pool before
+ // destruction.
+ DCHECK_EQ(handlers_.size(), max_handlers_);
+ DCHECK_EQ(free_handlers_.size(), handlers_.size());
+ DCHECK_EQ(sorted_active_handlers_.size(), 0);
+ }
+
+ thread::ThreadPool* inter_op_thread_pool() const {
+ return inter_op_thread_pool_.get();
+ }
+
+ std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ while (free_handlers_.empty()) {
+ one_handler_free_.wait(l);
+ }
+ // Remove the last entry from free_handlers_ and add to the end of
+ // sorted_active_handlers_.
+ auto* handler_impl = free_handlers_.back();
+ handler_impl->Reset();
+ // Sortedness isn't violated if we simply add at the end of the list, since
+ // handlers are expected to be obtained in increasing order of time.
+ sorted_active_handlers_.push_back(handler_impl);
+ DCHECK_LE(sorted_active_handlers_.size(), max_handlers_);
+ free_handlers_.pop_back();
+
+ RecomputePoolStatsLocked();
+ return WrapUnique<RunHandler>(new RunHandler(handler_impl));
+ }
+
+ void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ DCHECK_GT(sorted_active_handlers_.size(), 0);
+
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ double elapsed = (now - handler->start_time_us()) / 1000.0;
+ time_hist_.Add(elapsed);
+
+ // Erase from and update sorted_active_handlers_. Add it to the end of
+ // free_handlers_.
+ auto iter = std::find(sorted_active_handlers_.begin(),
+ sorted_active_handlers_.end(), handler);
+ DCHECK(iter != sorted_active_handlers_.end())
+ << "Unexpected handler: " << handler
+ << " is being requested for release";
+
+ // Remove this handler from this list and add it to the list of free
+ // handlers.
+ sorted_active_handlers_.erase(iter);
+ free_handlers_.push_back(handler);
+ DCHECK_LE(free_handlers_.size(), max_handlers_);
+
+ RecomputePoolStatsLocked();
+ }
+ one_handler_free_.notify_one();
+ }
+
+ private:
+ void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Maximum number of handlers pre-created at pool construction time. The
+ // number was chosen on the assumption that each handler will want at least
+ // one inter-op thread for execution (during compute-intensive workloads
+ // such as inference).
+ const int max_handlers_;
+
+ // Thread safe part.
+ const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
+
+ // Thread-compatible members, accessed only while holding the pool's lock.
+ // Handlers are sorted by start time.
+ std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
+ std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
+ // Histogram of elapsed runtime of every handler (in ms).
+ histogram::Histogram time_hist_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
+ int64 iterations_ GUARDED_BY(mu_);
+ condition_variable one_handler_free_;
+ mutex mu_;
+};
+
+void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
+ int num_active_requests = sorted_active_handlers_.size();
+ if (num_active_requests == 0) return;
+
+ int num_threads = inter_op_thread_pool_->NumThreads();
+
+ inter_op_start_.resize(num_active_requests);
+ inter_op_limit_.resize(num_active_requests);
+
+ const int kMinThreadsPerRequest = 3;
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ kMinThreadsPerRequest, &inter_op_start_,
+ &inter_op_limit_);
+
+ for (int i = 0; i < num_active_requests; ++i) {
+ sorted_active_handlers_[i]->set_inter_op_scheduling_range(
+ inter_op_start_[i], inter_op_limit_[i]);
+ }
+
+ if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
+ VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
+ VLOG(1) << "Active session runs: " << num_active_requests;
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ string ranges_str = "";
+ string times_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) {
+ times_str += " ";
+ ranges_str += " ";
+ }
+
+ times_str += strings::StrCat(
+ (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
+ ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
+ inter_op_limit_[i], ")");
+ }
+ VLOG(1) << "Elapsed times are: " << times_str;
+ VLOG(1) << "Ranges are: " << ranges_str;
+ }
+}
+
+void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
+ std::uint_fast32_t start = 0, limit = 0;
+ DecodePartition(inter_op_scheduling_range(), &start, &limit);
+ pool_impl_->inter_op_thread_pool()->Schedule(std::move(fn));
+}
+
+void RunHandler::Impl::Reset() {
+ set_inter_op_scheduling_range(
+ 0, pool_impl_->inter_op_thread_pool()->NumThreads());
+ start_time_us_ = tensorflow::Env::Default()->NowMicros();
+}
+
+RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
+ : impl_(new Impl(num_inter_op_threads)) {}
+
+RunHandlerPool::~RunHandlerPool() {}
+
+std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
+
+RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
+
+void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
+ impl_->ScheduleInterOpClosure(std::move(fn));
+}
+
+RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
+} // namespace tensorflow
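
For reference, the EncodePartition/DecodePartition packing above is a plain 16-bit shift-and-mask over the [start, limit) pair; a standalone numeric sketch (valid under the stated assumption pool_num_threads < (1 << 16)):

    const int kMaxPartitionBits = 16;
    const int kMaxThreads = 1 << kMaxPartitionBits;
    std::uint_fast32_t start = 2, limit = 10;
    std::uint_fast32_t packed = (start << kMaxPartitionBits) | limit;  // 131082
    std::uint_fast32_t decoded_limit = packed & (kMaxThreads - 1);     // 10
    std::uint_fast32_t decoded_start = packed >> kMaxPartitionBits;    // 2

Note that ScheduleInterOpClosure() decodes this range but still forwards the closure to the shared inter-op pool unchanged.
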
diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h
new file mode 100644
index 0000000000..72fa6301b4
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.h
@@ -0,0 +1,95 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+class RunHandler;
+
+// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
+// that can be used for tracking inter-op work for a given Session::Run().
+// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
+// 'active' when its unique_ptr is returned by Get() and is being used by a
+// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
+//
+// Expected usage:
+//
+// * Create a single RunHandlerPool (say run_handler_pool_).
+//
+// * When a Session::Run() is invoked, obtain a handler by:
+// auto handler = run_handler_pool_->Get();
+//
+// * Use handler for scheduling all inter-op work by:
+// handler->ScheduleInterOpClosure(closure);
+//
+// This class is thread safe.
+class RunHandlerPool {
+ public:
+ explicit RunHandlerPool(int num_inter_op_threads);
+ ~RunHandlerPool();
+
+ // Returns an inactive RunHandler from the pool.
+ //
+ // RunHandlers in RunHandlerPool are initially 'inactive'.
+ // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
+ // and is being used by a client. It becomes 'inactive' once more when the
+ // unique_ptr is destroyed.
+ //
+ // Will block unless there is an inactive handler.
+ std::unique_ptr<RunHandler> Get();
+
+ private:
+ class Impl;
+ friend class RunHandler;
+
+ std::unique_ptr<Impl> impl_;
+};
+
+// RunHandler can be used to schedule inter-op closures to run on a global pool
+// shared across all Session::Run(s).
+//
+// It can only be created via RunHandlerPool::Get().
+//
+// This class can be used instead of directly scheduling closures on a global
+// pool since it maintains a global view across all sessions and optimizes pool
+// scheduling to improve (median and tail) latency.
+//
+// This class is thread safe.
+class RunHandler {
+ public:
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ ~RunHandler();
+
+ private:
+ class Impl;
+ friend class RunHandlerPool::Impl;
+
+ explicit RunHandler(Impl* impl);
+
+ Impl* impl_; // NOT OWNED.
+};
+
+} // end namespace tensorflow.
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc
new file mode 100644
index 0000000000..3087998c69
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.cc
@@ -0,0 +1,57 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec) {
+ // Each request is expected to have weight W[i] = num_active_requests - i.
+ // Therefore, total_weight = sum of all request weights.
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float demand_factor = static_cast<float>(num_threads) / total_weight;
+ float last_cumulative_weight = 0.0;
+ min_threads_per_request = std::max(1, min_threads_per_request);
+ for (int i = 0; i != num_active_requests; i++) {
+ float cumulative_weight =
+ static_cast<float>(i + 1) *
+ (num_active_requests - static_cast<float>(i) * 0.5f);
+ float weight = cumulative_weight - last_cumulative_weight;
+ // Quantize thread_demand by rounding up, while also satisfying the
+ // `min_threads_per_request` constraint.
+ // Note: We subtract a small epsilon (0.00001) to prevent ceil(..) from
+ // rounding weights like 4.0 to 5.
+ int demand =
+ std::max(min_threads_per_request,
+ static_cast<int>(ceil(weight * demand_factor - 0.00001f)));
+ // For the quantized range [start, end): compute the floor of the real
+ // start, expand downwards from there with length `demand`, and adjust for
+ // boundary conditions.
+ int start = last_cumulative_weight * demand_factor;
+ int end = std::min(num_threads, start + demand);
+ start = std::max(0, std::min(start, end - demand));
+ start_vec->at(i) = start;
+ end_vec->at(i) = end;
+ last_cumulative_weight = cumulative_weight;
+ }
+}
+} // namespace tensorflow
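
A worked example of ComputeInterOpSchedulingRanges(): with 3 active requests, 12 threads, and min_threads_per_request = 1, the weights are 3, 2, 1 (total 6), demand_factor = 12 / 6 = 2, and the quantized demands are 6, 4, and 2 threads. A minimal standalone sketch calling the function as declared in run_handler_util.h:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    #include "tensorflow/core/framework/run_handler_util.h"

    int main() {
      const int kRequests = 3, kThreads = 12, kMinThreadsPerRequest = 1;
      std::vector<std::uint_fast32_t> start(kRequests), end(kRequests);
      tensorflow::ComputeInterOpSchedulingRanges(kRequests, kThreads,
                                                 kMinThreadsPerRequest,
                                                 &start, &end);
      // Prints "[0, 6) [6, 10) [10, 12)": older (lower-index) requests get
      // wider ranges, and together the ranges cover every thread.
      for (int i = 0; i < kRequests; ++i) {
        std::cout << "[" << start[i] << ", " << end[i] << ") ";
      }
      std::cout << std::endl;
    }
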
diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h
new file mode 100644
index 0000000000..c0c36aeccb
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.h
@@ -0,0 +1,43 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+
+#include <cstdint>
+#include <vector>
+
+namespace tensorflow {
+
+// Assign thread ranges to requests.
+// Requests are numbered 0...num_active_requests-1, and
+// threads are numbered 0...num_threads-1.
+// On return, the range start_vec->at(i)...end_vec->at(i)-1
+// indicates the subrange of the threads available to request i.
+// The ranges given to different requests may overlap.
+// Lower numbered requests will tend to be assigned more threads.
+// Thus, a client might associate older requests with lower
+// array indices so they receive access to more threads.
+// However, the routine ensures that each request is given access
+// to at least min(min_threads_per_request, num_threads) threads.
+// Every thread will be assigned to at least one request range,
+// assuming there is at least one request.
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec);
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc
new file mode 100644
index 0000000000..a1928c132b
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <vector>
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+namespace tensorflow {
+namespace {
+
+void VerifyFunction(int num_active_requests, int num_threads,
+ int min_threads_per_request, bool print_stats = false) {
+ if (print_stats) {
+ LOG(INFO) << "Test case# num_active_requests: " << num_active_requests
+ << " num_threads: " << num_threads
+ << " min_threads: " << min_threads_per_request;
+ }
+ std::vector<std::uint_fast32_t> start(num_active_requests);
+ std::vector<std::uint_fast32_t> end(num_active_requests);
+
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ min_threads_per_request, &start, &end);
+ string range_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) range_str += " ";
+ range_str += strings::StrCat("[", start[i], ", ", end[i], ")");
+
+ ASSERT_GE(start[i], 0) << range_str;
+ ASSERT_LE(end[i], num_threads) << range_str;
+ if (i > 0) {
+ // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i)
+ ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str;
+ // No missing threads.
+ ASSERT_GE(end[i - 1], start[i]) << range_str;
+ }
+ // Each interval is at least of size 'min_threads_per_request'.
+ ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str;
+ // Verify that the assigned (quantized) thread count does not overestimate
+ // the real demand by much when the demand is high (>=
+ // min_threads_per_request).
+ float entry_weight = num_active_requests - i;
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float thread_demand = (entry_weight * num_threads) / total_weight;
+ if (thread_demand > min_threads_per_request) {
+ // We expect some over-estimation of threads due to quantization,
+ // but we hope it's not more than 1 extra thread.
+ ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0)
+ << "Ranges: " << range_str << " thread_demand: " << thread_demand
+ << " i: " << i;
+ }
+ }
+ ASSERT_EQ(end[num_active_requests - 1], num_threads);
+ ASSERT_EQ(start[0], 0);
+ if (print_stats) {
+ LOG(INFO) << "Assigned ranges: " << range_str;
+ }
+}
+
+TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) {
+ const int kMinThreadsPerRequestBound = 12;
+ const int kMaxActiveRequests = 128;
+ const int kMaxThreads = 128;
+
+ for (int min_threads_per_request = 1;
+ min_threads_per_request <= kMinThreadsPerRequestBound;
+ ++min_threads_per_request) {
+ for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests;
+ ++num_active_requests) {
+ for (int num_threads = min_threads_per_request;
+ num_threads <= kMaxThreads; ++num_threads) {
+ VerifyFunction(num_active_requests, num_threads,
+ min_threads_per_request);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 81c1bddf67..5a3abbb545 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -124,10 +124,10 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
] + tf_protos_all(),
)
@@ -523,6 +523,7 @@ cc_library(
":function_utils",
":graph_utils",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -538,6 +539,7 @@ tf_cc_test(
srcs = ["vectorization_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
":function_utils",
":vectorization_utils",
"//tensorflow/core:framework",
@@ -547,7 +549,10 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ # For ops we need registered
+ "//tensorflow/core/kernels/data:dataset_ops",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:logging_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 5dd7819100..3af34f6904 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -116,8 +116,8 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the node name using the `prefix` name as a prefix while guaranteeing the
-// name is unique across the graph.
+// Sets the function name using the `prefix` name as a prefix while guaranteeing
+// the name is unique across the function library.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 32ab912619..9328a7ca99 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -86,21 +86,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// efficient vectorization with VectorizeMapDefun.
FunctionDef* vectorized_func =
CreateMapDefunWrapper(map_node, orig_func, library);
- NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
- DCHECK_EQ(map_defun_node->op(), "MapDefun");
-
- // Create a copy of the original function so that we can mutate it, and
- // attach that to the map defun node.
- FunctionDef* map_defun_fn = library->add_function();
- *map_defun_fn = orig_func;
- graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
- map_defun_fn);
- (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
- map_defun_fn->signature().name());
-
- vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
- map_defun_node);
- return vectorized_func;
+ const NodeDef& map_defun_node = vectorized_func->node_def(0);
+ DCHECK_EQ(map_defun_node.op(), "MapDefun");
+
+ // TODO(b/116285210): Unreferenced functions should get cleaned up later
+ FunctionDef* result;
+ Status s = vectorization_utils::VectorizeMapDefun(
+ *vectorized_func, map_defun_node, library, &result);
+
+ if (!s.ok()) {
+ LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ return vectorized_func;
+ }
+ return result;
}
bool IsOutputShapesFullyDefined(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index ed1bd6bc97..f4faf41549 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -30,72 +30,51 @@ namespace {
using test::function::GDef;
using test::function::NDef;
-void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
- TensorShapeProto* t) {
- for (size_t i = 0; i < dims.size(); ++i) {
- auto* d = t->add_dim();
- d->set_size(dims[i]);
- }
-}
-
-AttrValue MakeShapeListAttr(
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
- AttrValue shapes_attr;
- for (size_t i = 0; i < shapes.size(); ++i) {
- MakeTensorShapeProtoHelper(shapes[i],
- shapes_attr.mutable_list()->add_shape());
- }
-
- return shapes_attr;
-}
-
-NodeDef MakeMapNodeHelper(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- StringPiece map_op_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name, StringPiece map_op_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return test::function::NDef(
name, map_op_name, {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef(string(function_name))},
{"Targuments", {}},
- {"output_shapes", MakeShapeListAttr(output_shapes)},
+ {"output_shapes", output_shapes},
{"output_types", output_types}});
}
-NodeDef MakeMapNode(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
output_shapes, output_types);
}
-NodeDef MakeBatchNode(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDataset",
- {string(input_node_name), string(input_batch_size_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeBatchV2Node(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDatasetV2",
- {string(input_node_name), string(input_batch_size_name),
- string(input_drop_remainder_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ StringPiece input_drop_remainder_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
+NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) {
return NDef(name, "RangeDataset", inputs,
- {{"output_shapes", MakeShapeListAttr({{}})},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})},
{"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}
@@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
item.graph = GDef(
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("input", "InputDataset", {},
- {{"output_shapes", MakeShapeListAttr({{}})}}),
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}),
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
// FunctionLib
@@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}
+TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {FunctionDefHelper::Create(
+ "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {},
+ {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
+ {{"res", "o:z"}, {"res2", "o:z"}})});
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 1462cb234d..37aa24b947 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -9,13 +9,14 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
- "//tensorflow/core/grappler/optimizers/data:function_utils",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
] + tf_protos_all()
cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index c1739737a0..3af6bab409 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,26 +23,21 @@ namespace vectorization_utils {
class CastVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
}
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
-
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
+ // Add new Cast node with the same op and attrs as the original node
+ auto new_cast_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
+ // Add input and output mappings
+ input_ports->push_back({new_cast_node, 0});
+ output_ports->push_back({new_cast_node, 0});
return Status::OK();
}
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 776d3179c5..74ce520ce1 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,31 +23,29 @@ namespace vectorization_utils {
class UnpackVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_unpack_node);
+ // Add new Unpack node with the same op and attrs as the original node
+ auto new_unpack_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
// Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
+ int new_axis = node.def().attr().at("axis").i() + 1;
+ new_unpack_node->AddAttr("axis", new_axis);
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
+ // Add the input mappings
+ input_ports->push_back({new_unpack_node, 0});
+
+ // Add the output mappings
+ int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
+ output_ports->push_back({new_unpack_node, i});
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index d341dbba7d..56eb88c95e 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -17,30 +17,33 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> Port;
+
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
class Vectorizer {
public:
virtual ~Vectorizer() {}
- // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs, and adding mappings to `conversion_map`
- // from old output tensor names to new (vectorized) output tensor names.
- // The new node(s) collectively have the same number of inputs and outputs as
- // the node being converted, and use the tensor names in `inputs` as their
- // inputs.
- virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) = 0;
+ // on elements of the vector inputs. The new Node(s) collectively have the
+ // same number of input and output ports as the node being converted.
+ // Adds mappings for the new nodes' input and output ports to `input_ports`
+ // and `output_ports` respectively, where the i'th Port in each vector
+ // corresponds to the i'th input/output port of the node to be converted.
+ virtual Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) = 0;
};
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 86e303564b..663ceba027 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -24,9 +24,9 @@ namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* inputs,
+ std::vector<Port>* outputs) override {
return Status::OK();
}
};
@@ -39,10 +39,12 @@ TEST(TestVectorizer, TestTestVectorizer) {
auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
EXPECT_NE(vectorizer, nullptr);
- FunctionDef function;
- NodeDef node;
- std::map<string, string> conversion_map;
- EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+ Graph g(OpRegistry::Global());
+ NodeDef node_def;
+ Status s;
+ Node* node = g.AddNode(node_def, &s);
+ std::vector<Port> inputs, outputs;
+ EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
}
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cb56b65985..cea667f668 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -36,255 +40,346 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-using function_utils::FunctionDefTensorDesc;
-
namespace {
-void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
- const string& output_retval, const DataType t) {
- // Set to unknown shape
- TensorShapeProto tensor_shape_proto;
- PartialTensorShape().AsProto(&tensor_shape_proto);
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
- function_utils::AddFunctionOutputWithUniqueName(
- "vectorized_out", output_retval, map_defun_fn, t);
+const char* const kRetValOp = "_Retval";
- *(*map_defun_node->mutable_attr())["output_shapes"]
- .mutable_list()
- ->add_shape() = tensor_shape_proto;
- (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
}
-void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, int output_position) {
- DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
- << "Trying to remove output that doesn't exist. Output number: "
- << output_position;
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
- int num_later_outputs =
- map_defun_fn->signature().output_arg_size() - output_position - 1;
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
- // Remove from map_defun_fn's ret dict and output args
- map_defun_fn->mutable_ret()->erase(
- map_defun_fn->signature().output_arg(output_position).name());
- map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
- output_position, 1);
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
- // Renumber outputs that come after
- for (int i = 0; i < num_later_outputs; ++i) {
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i + 1),
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i),
- outer_scope);
- }
- map_defun_node->mutable_attr()
- ->at("output_shapes")
- .mutable_list()
- ->mutable_shape()
- ->DeleteSubrange(output_position, 1);
- map_defun_node->mutable_attr()
- ->at("output_types")
- .mutable_list()
- ->mutable_type()
- ->ExtractSubrange(output_position, 1, nullptr);
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
}
-int FindOutputToConvert(const FunctionDef& function,
- const std::set<string>& unconvertible,
- FunctionDefTensorDesc* f) {
- for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
- const string& ret_key = function.signature().output_arg(i).name();
- *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
- if (unconvertible.find(f->node_name) == unconvertible.end()) {
- return i;
- }
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
}
- return -1;
}
// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
-// Each instance of the class encapsulates all the data necessary to vectorize a
-// MapDefun op in place.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
class Vectorization {
public:
- Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node)
- : outer_scope_(outer_scope),
- map_defun_fn_(map_defun_fn),
- map_defun_node_(map_defun_node) {}
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
- // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
- // the outer_scope_, until there are no convertible outputs remaining.
- // This method is idempotent.
- void Vectorize();
+ // Adds the vectorized function and new map_defun_fn to `lib`, and points
+ // `result` to the former. Returns an error status if
+ // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
+ // along the way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Vectorizes the map defun function's output at output_position
- Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
- // Given a descriptor of the original output tensor, gets a string
- // corresponding to the converted output tensor.
- Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
- string* converted);
- Status AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc);
+ // Converts FunctionDefs to Graphs.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
// Adds mappings from node's output tensors to converted output tensors,
// creating the necessary new node(s). Generally, the steps to convert an op
// are:
- // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
- // and modify map_defun_node_ attrs accordingly
- // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
// These operations collectively compute the same value as what running
// the original operation on slices of the input tensors would produce.
// For example, a Cast op in MapDefun translates to a Cast op in
- // outer_scope_, since the vectorized version of Cast is itself.
- // 3) Set inputs of new node(s) to the corresponding converted inputs (that
- // are now outputs of map_defun_node_)
- // 4) For each output of the old node, add the mapping of output strings to
- // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
- Status AddConversionMappingFromOp(const NodeDef& node,
- const FunctionDefTensorDesc& output_desc);
-
- // Maps a tensor name to the name of the corresponding vectorized tensor. For
- // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
- std::map<string, string> conversion_map_;
- // Unconvertible node names
- std::set<string> unconvertible_;
-
- FunctionDef* outer_scope_;
- FunctionDef* map_defun_fn_;
- NodeDef* map_defun_node_;
+ // `outer_scope_`, since the vectorized version of Cast is itself.
+ // 2) Promote the inputs of the op to outputs of the
+ // `map_defun_node_` and `map_defun_fn_`.
+ // 3) Add edges between the promoted inputs (that are now outputs of
+ // `map_defun_node_`) and the input ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+ Status AddConversionMapping(Node* op_node);
+
+ // Maps a tensor to the corresponding vectorized tensor. For example,
+ // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
+ std::map<TensorDesc, TensorDesc> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+ // Note that FunctionBody has a pointer to a Graph object that corresponds
+ // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+ // that denote the function arguments and return values. These nodes have the
+ // attrs "T" for the type, and "index" for the argument / retval index
+ // respectively. FunctionBody also keeps track of arg/ret_nodes and
+ // arg/ret_types, which should be ordered according to argument/output indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+ Status status_;
};
-Status Vectorization::AddConversionMappingFromOp(
- const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
- for (const string& input_name : node.input()) {
- if (IsControlInput(input_name)) {
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
return errors::InvalidArgument(
"Vectorizing outputs with control inputs is currently not "
"supported.");
}
}
- // TODO(rachelim): Have some mechanism for registering converters and some
- // uniform, simpler way to represent them.
-
- DataTypeVector types;
- const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
- TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
-
- std::vector<string> promoted_inputs;
- promoted_inputs.reserve(node.input_size());
- for (int i = 0; i < node.input_size(); ++i) {
- promoted_inputs.push_back(strings::StrCat(
- map_defun_node_->name(),
- ":output:", map_defun_fn_->signature().output_arg_size() + i));
- }
-
- auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
if (vectorizer == nullptr) {
return errors::Unimplemented("No vectorizer registered for op: ",
- node.op());
+ op_node->type_string());
+ }
+ std::vector<Port> input_ports, output_ports;
+ input_ports.reserve(op_node->num_inputs());
+ output_ports.reserve(op_node->num_outputs());
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ &input_ports, &output_ports));
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ if (op_node->num_outputs() != output_ports.size() ||
+ op_node->num_inputs() != input_ports.size() ||
+ input_edges.size() != input_ports.size()) {
+ return errors::Internal("Vectorizer inputs/outputs don't match.");
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
- &conversion_map_));
+ // Promote the inputs of the op to MapDefun outputs and connect the edges
+ // accordingly.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
+ input_ports[i].first, input_ports[i].second);
+ }
- // If we get here, the conversion was successful, so we promote the inputs
- // of the ops to MapDefun outputs.
- for (int i = 0; i < types.size(); ++i) {
- AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
}
return Status::OK();
}
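
For context, the call sites above imply the contract for per-op vectorizers: given the original node and the outer graph, a vectorizer creates the batched node(s) and reports which ports the caller should wire up. A minimal, hypothetical sketch under that assumption (the real Vectorizer base class and registration macro are not part of this diff, and the class and variable names below are illustrative only):

class IdentityLikeVectorizer {
 public:
  // Mirrors vectorizer->Vectorize(*op_node, outer_scope_.get(), &input_ports,
  // &output_ports) above; Port is assumed to be a (Node*, port index) pair.
  Status Vectorize(const Node& node, Graph* outer_scope,
                   std::vector<Port>* input_ports,
                   std::vector<Port>* output_ports) {
    // For an elementwise op the batched version is the op itself, so re-create
    // it in the outer scope. In practice the node name would need to be
    // uniquified to avoid collisions with existing outer-scope nodes.
    Status s;
    Node* vectorized = outer_scope->AddNode(node.def(), &s);
    TF_RETURN_IF_ERROR(s);
    // Tell AddConversionMapping where to connect the promoted MapDefun outputs
    // (input ports) and which ports replace the original op's outputs.
    for (int i = 0; i < node.num_inputs(); ++i) {
      input_ports->push_back({vectorized, i});
    }
    for (int i = 0; i < node.num_outputs(); ++i) {
      output_ports->push_back({vectorized, i});
    }
    return Status::OK();
  }
};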
-Status Vectorization::AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc) {
- int input_index = function_utils::FindFunctionInputWithName(
- output_desc.node_name, *map_defun_fn_);
- if (input_index == -1) {
- return errors::Internal("Cannot convert non-existent input.");
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+ if (auto found = gtl::FindOrNull(conversion_map_, output)) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ converted_output = *found;
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ converted_output = conversion_map_.at(output);
}
- conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
return Status::OK();
}
-Status Vectorization::ConvertOutputHelper(
- const FunctionDefTensorDesc& output_desc, string* converted) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
- *converted = *found;
- return Status::OK();
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
}
- int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
- *map_defun_fn_);
- if (index == -1) { // The output comes from an input
- TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
} else {
- TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
- map_defun_fn_->node_def(index), output_desc));
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
- *converted = conversion_map_.at(output_desc.full_str);
- return Status::OK();
}
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+ // Convert outer_scope and map_defun_fn to FunctionBody objects so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
-Status Vectorization::ConvertOutput(int output_position,
- const FunctionDefTensorDesc& output_desc) {
- string converted_output_name;
- TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
+ }
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ // Add mappings from map_defun_fn_ arg nodes to the corresponding
+ // map_defun_node_ input nodes in the conversion map.
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
- // Remove the old output and make everything that referenced it point
- // to the new string
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node_->name(), ":output:", output_position),
- converted_output_name, outer_scope_);
- RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
- output_position);
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ }
return Status::OK();
}
-void Vectorization::Vectorize() {
- while (true) {
- FunctionDefTensorDesc desc;
- int output_position =
- FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
- if (output_position == -1) break;
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
- if (!ConvertOutput(output_position, desc).ok()) {
- unconvertible_.insert(desc.node_name);
- }
- }
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
- // If we've converted all the outputs of the MapDefun function, we no longer
- // need the MapDefun node and can delete it.
- if (map_defun_fn_->signature().output_arg_size() == 0) {
- outer_scope_->mutable_node_def()->DeleteSubrange(
- function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
- *outer_scope_),
- 1);
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
}
- if (!unconvertible_.empty()) {
- VLOG(2) << "The following nodes could not be converted: ["
- << absl::StrJoin(unconvertible_, ", ") << "].";
- }
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
}
+
} // namespace
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node) {
- Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
} // end namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
index bb405faa77..bd7d390900 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -24,22 +24,28 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-// Given a function, `map_defun_fn`, that is mapped across some input vector
-// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
-// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
-// `outer_scope`; that is, replacing `map_defun_fn` operations with new
-// `outer_scope` operations that produce the same vector output(s) as executing
-// the `map_defun_fn` operations on elements of vector input(s) would. If all
-// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
-// eliminated from `outer_scope` altogether. However, if some operations cannot
-// be lifted, and this vectorization only succeeds partially, `map_defun_node`
-// remains to be used for operations that were not lifted.
+// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`)
+// that maps a function in lib across some input vector elements,
+// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope`
+// by "lifting" operations from the MapDefun function to the new function
+// (`result`); that is, replacing operations in the MapDefun function with
+// operations that produce the same vector output(s) as executing the original
+// operations on elements of vector input(s) would. If all operations in the
+// MapDefun function are successfully lifted, `result` contains no MapDefun
+// node at all. However, if some operations cannot be lifted and this
+// vectorization only succeeds partially, a MapDefun node remains in `result` to
+// be used for operations that were not lifted, and the modified MapDefun
+// function is added to `lib`. The newly vectorized function `result` is also
+// added to `lib`.
+//
+// Returns Status::OK() if the vectorization is completely or partially
+// successful. Otherwise, returns an error, and sets `result` to nullptr.
//
// Example:
// If the input to the `VectorizeMapDefun` function is a MapDefun
// whose `map_defun_fn` performs the Cast operation, the vectorization will
// eliminate the MapDefun. This is because the Cast operation supports
-// any tensor shape and can thus be lifted to the `outer_scope`.
+// any tensor shape and can thus be lifted to `result`.
//
// Before:
//
@@ -68,7 +74,7 @@ namespace vectorization_utils {
//
// After:
//
-// outer_scope +------+
+// result +------+
// +---------------+ Arg0 +---------+
// | +---+--+ |
// | | |
@@ -80,8 +86,9 @@ namespace vectorization_utils {
// +---------------+ Ret0 +---------+
// +------+
//
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node);
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result);
} // end namespace vectorization_utils
} // end namespace grappler
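
A minimal usage sketch of the new signature, mirroring the updated tests below. Here `outer_scope` and `map_defun_fn` are assumed to be the caller's FunctionDefs and `map_defun_node` the MapDefun NodeDef inside `outer_scope`:

FunctionDefLibrary lib;
*lib.add_function() = outer_scope;
*lib.add_function() = map_defun_fn;
FunctionDef* result = nullptr;
TF_RETURN_IF_ERROR(vectorization_utils::VectorizeMapDefun(
    outer_scope, map_defun_node, &lib, &result));
// On success, `result` points to the (possibly still partially mapped)
// vectorized function that was added to `lib`.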
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index e129fa9237..1ff62217dd 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -60,6 +61,11 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
return node;
}
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
+
// TODO(rachelim): Use FunctionDefHelper::Create instead
FunctionDef CreateFunction(
StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
@@ -85,7 +91,6 @@ FunctionDef CreateFunction(
return func;
}
-TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
// Before:
//
@@ -133,10 +138,15 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
- EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
}
// Before:
@@ -149,12 +159,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// | +-----------+ Arg0 +---+ Arg1 +----+ |
// | | +---+--+ +---+--+ | |
// | | | | | |
-// | | +------+ | +---v--+ | |
-// | | |Const | | | Op0 | | |
-// | | +---v--+ | +---+--+ | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
// | | | | | | |
// | | | +---v--+ +---v--+ | |
-// | | +---| XOp1 | | XOp2 | | |
+// | | +---| XOp1 | | Cast | | |
// | | +---+--+ +---+--+ | |
// | | | | | |
// | | MapDefun +---v--+ +---v--+ | |
@@ -165,23 +175,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// +---------------+ Ret0 +---+ Ret1 +--------+
// +------+ +------+
//
-// where XOp1 and XOp2 are not convertible.
+// where XOp1 is not convertible.
//
// After:
//
-// No change because the ops are not convertible.
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
{{"ret0", DT_INT32}, {"ret1", DT_INT32}},
- {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
NodeDef* x_op1 =
- function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
- NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
- CHECK_NOTNULL(x_op2);
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
FunctionDef outer = CreateFunction(
"outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
@@ -193,12 +230,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
- // They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
- EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
}
// Before:
@@ -257,14 +304,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -330,16 +382,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -411,21 +468,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), "x");
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -486,7 +548,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{"ret1", "MyUnstack:output:1"},
{"ret2", "MyUnstack:output:2"}});
NodeDef* cast_op =
- AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
CHECK_NOTNULL(cast_op);
NodeDef* unstack_op =
AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
@@ -505,25 +567,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 2);
+ EXPECT_EQ(vectorized->node_def_size(), 2);
}
// Before:
@@ -561,9 +628,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}},
{{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
- // The attrs aren't relevant
- NodeDef* print_op =
- function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
CHECK_NOTNULL(print_op);
NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
false, &inner);
@@ -578,11 +647,27 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ // We check this somewhat manually as the names of nodes may have changed
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
}
// TODO(rachelim): More test cases when we get around to implementing them:
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e18a5f21d2..406c1b60ce 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -115,6 +115,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ if (cfg_.disable_meta_optimizer()) {
+ return Status::OK();
+ }
if (!cfg_.disable_model_pruning()) {
optimizers->push_back(MakeUnique<ModelPruner>());
}
@@ -489,6 +492,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
+ if (cfg.disable_meta_optimizer()) {
+ return false;
+ }
return !cfg.disable_model_pruning() ||
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.function_optimization() != RewriterConfig::OFF ||
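
A small sketch of how the new switch is used, assuming the standard proto-generated setter for the new RewriterConfig field:

RewriterConfig cfg;
cfg.set_disable_meta_optimizer(true);
// With the flag set, MetaOptimizerEnabled(cfg) returns false and
// InitializeOptimizers() registers no optimizers, so no graph rewrites run
// regardless of the remaining RewriterConfig settings.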
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a428aea7f5..6861fb423c 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -41,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration.op_def, nullptr, &outputs_range_map));
- connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+ connectivity->RegisterFunctionBodyOutputs(node.name(),
+ std::move(outputs_range_map));
return Status::OK();
}
@@ -75,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
} // namespace
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
- const InputArgExpansion& input_arg_expansion) {
- const auto& input_name = input_arg_expansion.input_name;
+ InputArgExpansion input_arg_expansion) {
+ string input_name = input_arg_expansion.input_name;
const auto& placeholders = input_arg_expansion.placeholders;
- input_arg_expansions_.emplace(input_name, input_arg_expansion);
+
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
- input_arg_placeholders_.emplace(
- placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+ input_arg_placeholders_.insert(
+ {placeholder, InputArgPlaceholder{input_name, /*position=*/i}});
}
+ input_arg_expansions_.insert(
+ {std::move(input_name), std::move(input_arg_expansion)});
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
- const string& node_name, const tensorflow::NameRangeMap& outputs) {
- function_body_outputs_[node_name] = outputs;
+ const string& node_name, tensorflow::NameRangeMap&& outputs) {
+ function_body_outputs_[node_name] = std::move(outputs);
}
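
The two registration functions above switch to sink-style parameters: a by-value argument that is moved into the map for RegisterInputArgExpansion, and an rvalue reference for RegisterFunctionBodyOutputs. A generic sketch of the same pattern, with illustrative names only:

#include <string>
#include <unordered_map>
#include <utility>

class ExpansionRegistry {
 public:
  // By-value sink: callers can pass an lvalue (one copy) or an rvalue (a
  // move), and the function moves the argument into long-lived storage,
  // avoiding a second copy.
  void Register(std::string name, std::string expansion) {
    expansions_.emplace(std::move(name), std::move(expansion));
  }

 private:
  std::unordered_map<std::string, std::string> expansions_;
};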
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
@@ -174,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const auto& output_range = output->second;
if (position == -1) {
+ graph_def_inputs->reserve(graph_def_inputs->size() +
+ output_range.second - output_range.first);
// If position is not defined expand node output range
for (int i = output_range.first; i < output_range.second; ++i) {
- i == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", i));
+ graph_def_inputs->push_back(
+ i == 0 ? node_name : strings::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
@@ -187,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
" position: ", position, " (out of range)");
}
int pos = output_range.first + position;
- pos == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", pos));
+ graph_def_inputs->push_back(
+ pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
}
return Status::OK();
@@ -211,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
}
function_body_node->clear_input();
- for (const string& expanded_input : expanded_inputs)
- function_body_node->add_input(expanded_input);
+ for (string& expanded_input : expanded_inputs)
+ function_body_node->add_input(std::move(expanded_input));
return Status::OK();
}
@@ -323,7 +326,7 @@ GrapplerFunctionItem::GrapplerFunctionItem(
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
- feed.emplace_back(placeholder, Tensor());
+ feed.push_back({placeholder, Tensor()});
input_arg_placeholders_.insert(placeholder);
}
}
@@ -460,7 +463,7 @@ Status InstantiationBodyParameters(
auto it = func_instantiation_attr.find(placeholder);
if (it != func_instantiation_attr.end()) {
- body_parameters->emplace(placeholder, it->second);
+ body_parameters->insert({placeholder, it->second});
} else {
return errors::InvalidArgument("Can't resolve placeholder: ",
placeholder);
@@ -498,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// GraphDef input format (name[:position])
GrapplerFunctionConnectivity connectivity;
- std::vector<InputArgExpansion> inputs;
- std::vector<OutputArgExpansion> outputs;
- std::vector<string> keep_nodes;
-
// Function body shares the library with the graph that instantiated it.
GraphDef function_body;
*function_body.mutable_library() = flib.ToProto();
@@ -518,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
}
}
+ std::vector<InputArgExpansion> inputs;
+ inputs.reserve(signature.input_arg_size());
+
// For each input argument create a placeholder in function body.
for (const OpDef::ArgDef& input : signature.input_arg()) {
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
@@ -542,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
/*is_ref*/ input.is_ref(),
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
- inputs.push_back(input_expansion);
+ inputs.push_back(std::move(input_expansion));
}
+ std::vector<string> keep_nodes;
// Add all function nodes to the function body
for (const NodeDef& func_def_node : func.node_def()) {
NodeDef* new_node = function_body.add_node();
@@ -572,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
}
+ std::vector<OutputArgExpansion> outputs;
+ outputs.reserve(signature.output_arg_size());
// Add function outputs
for (const OpDef::ArgDef& out : signature.output_arg()) {
std::vector<string> output_tensors;
@@ -589,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
OutputArgExpansion output{/*output_name=*/out.name(),
/*data_type=*/output_data_type,
/*is_ref=*/out.is_ref(),
- /*output_tensors=*/output_tensors};
- outputs.push_back(output);
+ /*output_tensors=*/std::move(output_tensors)};
+ outputs.push_back(std::move(output));
}
bool is_stateful = signature.is_stateful();
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 733caf325f..ef944ced09 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -70,9 +71,9 @@ struct OutputArgExpansion {
// and fold it back when doing backward conversion.
class GrapplerFunctionConnectivity {
public:
- void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion);
+ void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
void RegisterFunctionBodyOutputs(const string& node_name,
- const tensorflow::NameRangeMap& outputs);
+ tensorflow::NameRangeMap&& outputs);
// Expand input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0b8e9ec527..9439ab332c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1197,8 +1197,10 @@ tf_cc_test(
tf_cc_test(
name = "example_parsing_ops_test",
- size = "large",
+ size = "medium",
srcs = ["example_parsing_ops_test.cc"],
+ shard_count = 4,
+ tags = ["optonly"],
deps = [
":example_parsing_ops",
":ops_testutil",
@@ -4049,11 +4051,6 @@ cc_library(
)
SPARSE_DEPS = [
- ":bounds_check",
- ":cwise_op",
- ":fill_functor",
- ":scatter_functor",
- "//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:sparse_ops_op_lib",
@@ -4086,7 +4083,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_cross_op",
prefix = "sparse_cross_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4098,13 +4097,19 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_dense_binary_op_shared",
prefix = "sparse_dense_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_sparse_binary_op_shared",
prefix = "sparse_sparse_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4136,7 +4141,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_softmax",
prefix = "sparse_softmax",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4148,25 +4155,37 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_tensor_dense_add_op",
prefix = "sparse_tensor_dense_add_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":scatter_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_tensor_dense_matmul_op",
prefix = "sparse_tensor_dense_matmul_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ ":fill_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_to_dense_op",
prefix = "sparse_to_dense_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_xent_op",
prefix = "sparse_xent_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index fa959b5a0e..82e2913b64 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -132,7 +132,6 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
- col_params_.instance.shape = c->input(0).shape();
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@@ -144,6 +143,7 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
c->forward_input_or_allocate_output(
{0}, 0, c->input(0).shape(), &output),
done);
+ col_params_.instance.shape = c->input(0).shape();
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, col_exec, done](const Status& s) {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 87efdff789..6333853cdf 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -765,6 +765,7 @@ tf_kernel_library(
":window_dataset_op",
":writer_ops",
":zip_dataset_op",
+ "//tensorflow/core/kernels/data/experimental:dataset_kernels",
],
)
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index ec6cb37193..43406db3ed 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -1,22 +1,26 @@
# Description:
-# Contains kernels for datasets and iterators.
+# Contains experimental kernels for datasets and iterators.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+)
+
cc_library(
name = "indexed_dataset_headers",
hdrs = ["indexed_dataset.h"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "indexed_dataset",
srcs = [
"identity_indexed_dataset.cc",
@@ -24,103 +28,102 @@ cc_library(
],
deps = [
":indexed_dataset_headers",
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "prefetching_kernels",
srcs = ["prefetching_kernels.cc"],
deps = [
- "//tensorflow/core:core_cpu_headers_lib",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "directed_interleave_dataset_op",
srcs = ["directed_interleave_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "csv_dataset_op",
srcs = ["csv_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "ignore_errors_dataset_op",
srcs = ["ignore_errors_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "lmdb_dataset_op",
srcs = ["lmdb_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
"@lmdb",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "unique_dataset_op",
srcs = ["unique_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "assert_next_dataset_op",
srcs = ["assert_next_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "dataset_kernels",
deps = [
":assert_next_dataset_op",
@@ -132,8 +135,5 @@ cc_library(
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
index c19a609780..3511cca0f5 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -147,8 +147,9 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
-REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
- AssertNextDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
index 21ec50fb6b..7451ca4cb1 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -852,7 +852,8 @@ class CSVDatasetOp : public DatasetOpKernel {
}; // class CSVDatasetOp
// Register the kernel implementation for CSVDataset.
-REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+ CSVDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
index a5321620bf..c47a9099c4 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -272,8 +272,9 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
- DirectedInterleaveDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
index c3cb45dbf7..2141f118ca 100644
--- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -147,8 +147,9 @@ class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU),
- IdentityIndexedDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
index beec344534..b34377c642 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace data {
@@ -133,8 +132,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
- IgnoreErrorsDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+ IgnoreErrorsDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
index ced8ab0d60..75ea462f40 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -361,12 +361,14 @@ class IndexedDatasetGet : public OpKernel {
};
REGISTER_KERNEL_BUILDER(
- Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
MaterializedHandleOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU),
- MaterializeDatasetOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
- IndexedDatasetGet);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
index 7aa2d3fdbc..27a8360cbc 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.h
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
-#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -116,4 +116,4 @@ Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
} // namespace data
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
index d233c1f8ec..8a88d32f0c 100644
--- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -210,7 +210,8 @@ class LMDBDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
+ LMDBDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
index 96f1dd0059..2c6179d9f5 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -338,20 +338,20 @@ class FunctionBufferResourceHandleOp : public OpKernel {
DataTypeVector output_types_;
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_CPU)
.HostMemory("resource")
.HostMemory("string_arg")
.HostMemory("target_device"),
FunctionBufferResourceHandleOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_GPU)
.HostMemory("resource")
.HostMemory("string_arg")
.HostMemory("target_device"),
FunctionBufferResourceHandleOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_SYCL)
.HostMemory("resource")
.HostMemory("string_arg")
@@ -403,16 +403,16 @@ class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_CPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_GPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_SYCL)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
@@ -440,16 +440,16 @@ class FunctionBufferingResourceResetOp : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_CPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_GPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_SYCL)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
@@ -473,8 +473,9 @@ class IteratorGetDeviceOp : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
- IteratorGetDeviceOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
+ IteratorGetDeviceOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index 30fa97a636..c80493d3a1 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -209,10 +209,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
- ThreadPoolDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
+ ThreadPoolDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
index 57fc5697a4..cd612e0eb2 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -199,8 +199,9 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
HANDLE_TYPE(DT_INT64);
HANDLE_TYPE(DT_STRING);
default:
- LOG(FATAL) << "UniqueDataset unhandled data type: "
- << DataTypeString(lhs.dtype());
+ DCHECK(false) << "UniqueDataset unhandled data type: "
+ << DataTypeString(lhs.dtype());
+ return false;
}
}
};
@@ -215,7 +216,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 2bbf4af664..b4c7f9e510 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -37,6 +37,8 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 5f143967d9..d909b9e9d3 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -134,19 +134,17 @@ class MultiDeviceIterator : public ResourceBase {
void Reset() LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
+ if (!background_thread_finished_) {
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
}
}
RunPendingCallbacks();
@@ -182,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase {
buffer_[shard_num].cond_var.notify_all();
}
} else {
- if (background_thread_finished_) {
+ if (end_of_iterator_) {
produced_output = true;
elem.end_of_sequence = true;
} else {
@@ -219,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase {
while (!buffer_[i].callbacks.empty()) {
if (buffer_[i].data.empty()) {
HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
+ if (end_of_iterator_) {
+ elem.end_of_sequence = true;
+ } else {
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ }
cancellation_elements.push_back(std::move(elem));
} else {
cancellation_elements.push_back(
@@ -293,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase {
{
mutex_lock l(mu_);
background_thread_finished_ = true;
+ end_of_iterator_ = true;
shutdown_cond_var_.notify_all();
}
RunPendingCallbacks();
@@ -312,6 +315,7 @@ class MultiDeviceIterator : public ResourceBase {
std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
bool background_thread_finished_ GUARDED_BY(mu_) = false;
bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool end_of_iterator_ GUARDED_BY(mu_) = false;
bool cancelled_ GUARDED_BY(mu_) = false;
condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 2e6e0465f7..2bb38bf0b9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1084,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
+//
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
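The comment above singles out the degree of parallelism as the one tunable knob of this implementation; on the Python side that knob is the num_parallel_calls argument of Dataset.interleave, which this V2 kernel backs. A minimal sketch, assuming a build whose interleave already accepts num_parallel_calls (the file pattern is hypothetical):

    import tensorflow as tf

    files = tf.data.Dataset.list_files("/tmp/records/*.tfrecord")  # hypothetical path
    dataset = files.interleave(
        tf.data.TFRecordDataset,   # open one reader per matched file
        cycle_length=4,            # number of files interleaved at once
        num_parallel_calls=4)      # the single parallelism knob noted above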
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index ee20249bfe..da067a4e6f 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -27,6 +27,8 @@ namespace tensorflow {
namespace data {
namespace {
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
@@ -104,18 +106,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
- std::shared_ptr<InvocationResult> result = invocation_results_[i];
- TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ const auto& result = *(invocation_results_[i]);
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
- result->return_values.size()));
- for (size_t j = 0; j < result->return_values.size(); j++) {
- TF_RETURN_IF_ERROR(
- writer->WriteTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- result->return_values[j]));
+ result.return_values.size()));
+ for (size_t j = 0; j < result.return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result.return_values[j]));
}
- if (result->end_of_input) {
+ if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")),
@@ -133,9 +134,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
- std::shared_ptr<InvocationResult> result(new InvocationResult());
- invocation_results_.push_back(result);
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
+ auto& result = *invocation_results_.back();
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
@@ -151,17 +152,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
": ", size, " is not a valid value of type size_t."));
}
}
- result->return_values.reserve(num_return_values);
+ result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(
- reader->ReadTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ result.return_values.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ &result.return_values.back()));
}
- result->end_of_input = reader->Contains(full_name(
+ result.end_of_input = reader->Contains(full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")));
- result->notification.Notify();
+ result.notification.Notify();
}
return Status::OK();
}
@@ -257,7 +257,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
while (!busy()) {
- invocation_results_.emplace_back(new InvocationResult());
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 7e528a71be..c8abfb9eb5 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -118,16 +118,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
private:
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 52157ed5fb..f406ad2ab5 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -853,7 +853,7 @@ class MklConvCustomBackpropFilterOp
// MKL DNN allocates large buffers when a conv gradient filter primtive is
// created. So we don't cache conv backward primitives when the env
- // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+ // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
convBwdFilterDims, do_not_cache);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index c38c9cc27c..a501ce2c93 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -713,7 +713,7 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
TFPaddingToMklDnnPadding(this->padding_));
// We don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if primitve descriptor
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitive descriptor
// includes potentialy large buffers. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 184e0cb003..b332edad0a 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -901,7 +901,7 @@ class MklConvOp : public OpKernel {
// In some cases, primitve descriptor includes potentialy large buffers,
// we don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL DNN allocates buffers
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
// 2. 1x1 convolution with stride != 1
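For reference, the environment variable named in these comments is read through ReadBoolFromEnvVar with a default of true (see the mkl_util.h hunk below), so the memory-saving behaviour is on unless explicitly switched off. A minimal sketch of overriding it from Python, assuming an MKL-enabled build:

    import os

    # Set before any convolution runs: "false" re-enables caching of the
    # backward-conv primitives at the cost of larger MKL-DNN buffers.
    os.environ["TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE"] = "false"

    import tensorflow as tf  # import after setting the variable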
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 427044ca67..23d76986bf 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -172,17 +172,21 @@ REGISTER_KERNEL_BUILDER(
.Device(DEVICE_GPU) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
- ResourceHandleOp<Var>) \
- REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") \
- .Device(DEVICE_GPU) \
- .HostMemory("resources") \
- .TypeConstraint<type>("dtypes"), \
- ResourceHandlesOp<Var>)
-
+ ResourceHandleOp<Var>)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
+
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
+ .Device(DEVICE_GPU)
+ .HostMemory("resources")
+ .TypeConstraint("dtypes",
+ {DT_INT64, DT_COMPLEX64,
+ DT_COMPLEX128, DT_HALF, DT_FLOAT,
+ DT_DOUBLE, DT_BOOL, DT_VARIANT}),
+ ResourceHandlesOp<Var>);
+
#endif // GOOGLE_CUDA
template <typename T>
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 32ce31cf23..43c14d83b5 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21532,6 +21532,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index d1a771f005..f6bd5dce26 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -17,24 +17,16 @@ limitations under the License.
namespace tensorflow {
-REGISTER_OP("DirectedInterleaveDataset")
+REGISTER_OP("ExperimentalDirectedInterleaveDataset")
.Input("selector_input_dataset: variant")
.Input("data_input_datasets: N * variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("N: int >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
-
-selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines
- which of the `N` data inputs should produce the next output element.
-data_input_datasets: `N` datasets with the same type that will be interleaved
- according to the values of `selector_input_dataset`.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("CSVDataset")
+REGISTER_OP("ExperimentalCSVDataset")
.Input("filenames: string")
.Input("compression_type: string")
.Input("buffer_size: int64")
@@ -76,35 +68,26 @@ REGISTER_OP("CSVDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("IgnoreErrorsDataset")
+REGISTER_OP("ExperimentalIgnoreErrorsDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("UniqueDataset")
+REGISTER_OP("ExperimentalUniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the unique elements of `input_dataset`.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("IteratorGetDevice")
+REGISTER_OP("ExperimentalIteratorGetDevice")
.Input("resource: resource")
.Output("device: string")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Returns the name of the device on which `resource` has been placed.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("FunctionBufferingResource")
+REGISTER_OP("ExperimentalFunctionBufferingResource")
.Input("string_arg: string")
.Input("target_device: string")
.Output("resource: resource")
@@ -113,77 +96,36 @@ REGISTER_OP("FunctionBufferingResource")
.Attr("f: func")
.Attr("buffer_size: int")
.Attr("output_types: list(type)")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Creates a resource that fills up a buffer by making function calls.
-
-string_arg: String argument to the function call.
-target_device: Target device to execute the function on.
-resource: Handle to the resource created.
-f: Function to be executed.
-buffer_size: Size of the buffer.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceGetNext")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext")
.Input("function_buffer_resource: resource")
.Attr("output_types: list(type)")
.Output("output: output_types")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Gets the next element from a FunctionBufferingResource.
+ .SetShapeFn(shape_inference::UnknownShape);
-function_buffer_resource: The FunctionBufferingResource handle.
-output: A list of return values.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceReset")
+REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
.Input("function_buffer_resource: resource")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Resets the FunctionBufferingResource.
-
-function_buffer_resource: The FunctionBufferingResource handle.
-)doc");
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThreadPoolDataset")
+REGISTER_OP("ExperimentalThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that uses a custom thread pool to compute `input_dataset`.
-
-handle: A resource produced by the ThreadPoolHandle op.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("ThreadPoolHandle")
+REGISTER_OP("ExperimentalThreadPoolHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("num_threads: int")
.Attr("max_intra_op_parallelism: int = 1")
.Attr("display_name: string")
.Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Doc(R"doc(
-Creates a custom thread pool with the given number of threads.
-
-handle: A resource that can be consumed by one or more ThreadPoolDataset ops.
-num_threads: The number of threads in the thread pool.
-max_intra_op_parallelism: The maximum degree of parallelism to use within
- operations that execute on this threadpool.
-display_name: A human-readable name for the threads that may be visible in
- some visualizations.
-)doc");
-
-REGISTER_OP("AssertNextDataset")
+ .Attr("shared_name: string = ''");
+
+REGISTER_OP("ExperimentalAssertNextDataset")
.Input("input_dataset: variant")
.Input("transformations: string")
.Output("handle: variant")
@@ -196,7 +138,7 @@ REGISTER_OP("AssertNextDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("LMDBDataset")
+REGISTER_OP("ExperimentalLMDBDataset")
.Input("filenames: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
@@ -205,4 +147,61 @@ REGISTER_OP("LMDBDataset")
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ExperimentalIdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materializes the dataset into the handle.
+REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+        output_shapes.size(), " vs. ", c->num_outputs(), ")");
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ExperimentalIndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 02a7f8d717..abee803889 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10039,6 +10039,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index bb841aeab7..3b14757945 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -641,54 +641,41 @@ def tf_additional_lib_deps():
def tf_additional_core_deps():
return select({
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/core/platform/cloud:gcs_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_hdfs_support_windows_override": [],
- "//tensorflow:with_hdfs_support_android_override": [],
- "//tensorflow:with_hdfs_support_ios_override": [],
- "//tensorflow:with_hdfs_support": [
- "//tensorflow/core/platform/hadoop:hadoop_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support_android_override": [],
- "//tensorflow:with_aws_support_ios_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/core/platform/s3:s3_file_system",
+ "//tensorflow/core/platform/hadoop:hadoop_file_system",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
"//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
"//tensorflow/contrib/cloud/kernels:gcs_config_ops",
],
- "//conditions:default": [],
})
def tf_lib_proto_parsing_deps():
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 85cd02350a..104ab039cb 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -453,6 +453,11 @@ message RunOptions {
// same group_key value (in a distributed computation where tasks
// run disjoint graphs).
int64 collective_graph_key = 1;
+ // If true, then operations (using the inter-op pool) across all
+ // session::run() calls will be centrally scheduled, optimizing for (median
+ // and tail) latency.
+ // Consider using this option for CPU-bound workloads like inference.
+ bool use_run_handler_pool = 2;
};
Experimental experimental = 8;
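Per-call opt-in from Python would look roughly like the following, assuming the regenerated config protos expose the new field (session, feeds and fetches are placeholders):

    import tensorflow as tf

    run_options = tf.RunOptions(
        experimental=tf.RunOptions.Experimental(use_run_handler_pool=True))
    # All session.run() calls issued with this option share a centrally
    # scheduled inter-op pool, optimizing median and tail latency for
    # CPU-bound workloads such as inference.
    # outputs = sess.run(fetches, feed_dict=feeds, options=run_options)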
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 482178a540..8e0448d536 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -77,6 +77,8 @@ message RewriterConfig {
Toggle scoped_allocator_optimization = 15;
// Force small ops onto the CPU (default is ON).
Toggle pin_to_host_optimization = 18;
+ // Disable the entire meta optimizer (off by default).
+ bool disable_meta_optimizer = 19;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
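Since the new field is a plain bool on RewriterConfig, turning Grappler off entirely becomes a one-line config change. A minimal sketch, again assuming the regenerated Python protos pick up the field:

    import tensorflow as tf

    config = tf.ConfigProto()
    # Disable the whole meta optimizer (it runs by default).
    config.graph_options.rewrite_options.disable_meta_optimizer = True
    # sess = tf.Session(config=config)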
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index cf7ffd8149..04aaea4f89 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -2039,8 +2039,8 @@ class MklPrimitiveFactory {
/// Fuction to check whether primitive memory optimization is enabled
static inline bool IsPrimitiveMemOptEnabled() {
bool is_primitive_mem_opt_enabled = true;
- TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
- &is_primitive_mem_opt_enabled));
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
return is_primitive_mem_opt_enabled;
}
@@ -2095,9 +2095,8 @@ static inline memory::format get_desired_format(int channel,
fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = is_2d
- ? memory::format::nChw8c
- : memory::format::ncdhw; //not support avx2 for 3d yet.
+ fmt_desired = is_2d ? memory::format::nChw8c
+ : memory::format::ncdhw; // no avx2 support for 3d yet.
} else {
fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
@@ -2209,7 +2208,8 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
// utility function to determine if it is conv 1x1 and stride != 1
// for purpose of temporarily disabling primitive reuse
-inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
+ memory::dims strides) {
if (filter_dims.size() != 4 || strides.size() != 2) return false;
return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 4d4db86df2..f40ec9b752 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -65,6 +65,10 @@ tf_cc_test(
name = "tensor_bundle_test",
srcs = ["tensor_bundle_test.cc"],
data = glob(["testdata/**"]),
+ tags = [
+ "nomsan",
+ "notsan",
+ ],
deps = [
":tensor_bundle",
"//tensorflow/core:framework",
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index f327b645f5..f5f0d7c3c8 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -68,6 +68,7 @@ android_binary(
srcs = glob([
"src/**/*.java",
]),
+ aapt_version = "aapt",
# Package assets from assets dir as well as all model targets. Remove undesired models
# (and corresponding Activities in source) to reduce APK size.
assets = [
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 2f297d5161..b4d4db3e4d 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3742,27 +3742,6 @@ func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf
return op.Output(0)
}
-// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
-// layer.
-func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesGetEnsembleStates",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// Creates a tree ensemble model and returns a handle to it.
//
// Arguments:
@@ -4059,6 +4038,364 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true
return op.Output(0), op.Output(1), op.Output(2)
}
+// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
+type LogUniformCandidateSamplerAttr func(optionalAttr)
+
+// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a log-uniform distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "LogUniformCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
+type UniformCandidateSamplerAttr func(optionalAttr)
+
+// UniformCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a uniform distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UniformCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
+type GenerateVocabRemappingAttr func(optionalAttr)
+
+// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
+//
+// value: Number of entries in the old vocab file to consider. If -1,
+// use the entire old vocabulary.
+// If not specified, defaults to -1
+//
+// REQUIRES: value >= -1
+func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
+ return func(m optionalAttr) {
+ m["old_vocab_size"] = value
+ }
+}
+
+// Given a path to new and old vocabulary files, returns a remapping Tensor of
+//
+// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
+// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
+// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
+// in the new vocabulary is not in the old vocabulary. The old vocabulary is
+// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
+// default value of -1.
+//
+// `num_vocab_offset` enables
+// use in the partitioned variable case, and should generally be set through
+// examining partitioning info. The format of the files should be a text file,
+// with each line containing a single entity within the vocabulary.
+//
+// For example, with `new_vocab_file` a text file containing each of the following
+// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
+// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
+// `[0, -1, 2]`.
+//
+// The op also returns a count of how many entries in the new vocabulary
+// were present in the old vocabulary, which is used to calculate the number of
+// values to initialize in a weight matrix remapping
+//
+// This functionality can be used to remap both row vocabularies (typically,
+// features) and column vocabularies (typically, classes) from TensorFlow
+// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
+// corresponding to div-partitioned variables. Moreover, the underlying remapping
+// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
+// use the corresponding index_table_from_file() as the FeatureColumn framework
+// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
+//
+// Arguments:
+// new_vocab_file: Path to the new vocab file.
+// old_vocab_file: Path to the old vocab file.
+// new_vocab_offset: How many entries into the new vocab file to start reading.
+// num_new_vocab: Number of entries in the new vocab file to remap.
+//
+// Returns A Tensor of length num_new_vocab where the element at index i
+// is equal to the old ID that maps to the new ID i. This element is -1 for any
+// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab.
+func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "GenerateVocabRemapping",
+ Input: []tf.Input{
+ new_vocab_file, old_vocab_file,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Broadcasts a tensor value to one or more other devices.
+func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
+ opspec := tf.OpSpec{
+ Type: "CollectiveBcastSend",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Mutually reduces multiple tensors of identical type and shape.
+func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
+ opspec := tf.OpSpec{
+ Type: "CollectiveReduce",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// AbortAttr is an optional argument to Abort.
+type AbortAttr func(optionalAttr)
+
+// AbortErrorMsg sets the optional error_msg attribute to value.
+//
+// value: A string which is the message associated with the exception.
+// If not specified, defaults to ""
+func AbortErrorMsg(value string) AbortAttr {
+ return func(m optionalAttr) {
+ m["error_msg"] = value
+ }
+}
+
+// AbortExitWithoutError sets the optional exit_without_error attribute to value.
+// If not specified, defaults to false
+func AbortExitWithoutError(value bool) AbortAttr {
+ return func(m optionalAttr) {
+ m["exit_without_error"] = value
+ }
+}
+
+// Raises an exception to abort the process when called.
+//
+// If exit_without_error is true, the process will exit normally;
+// otherwise it will exit with a SIGABRT signal.
+//
+// Returns nothing but an exception.
+//
+// Returns the created operation.
+func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Abort",
+
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Forwards the input to the output.
+//
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
+//
+// Arguments:
+// input: A boolean scalar, representing the branch predicate of the Switch op.
+//
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LoopCond",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a tensor of zeros with the same shape and type as x.
+//
+// Arguments:
+// x: a tensor of type T.
+//
+// Returns a tensor of the same shape and type as x but filled with zeros.
+func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ZerosLike",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a copy of the input tensor.
+func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Snapshot",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign.
type ResourceStridedSliceAssignAttr func(optionalAttr)
@@ -10182,23 +10519,6 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass
return scope.AddOperation(opspec)
}
-// Broadcasts a tensor value to one or more other devices.
-func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
- opspec := tf.OpSpec{
- Type: "CollectiveBcastSend",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Split a `SparseTensor` into `num_split` tensors along one dimension.
//
// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
@@ -10776,23 +11096,6 @@ func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, update
return scope.AddOperation(opspec)
}
-// Mutually reduces multiple tensors of identical type and shape.
-func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
- opspec := tf.OpSpec{
- Type: "CollectiveReduce",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Updates the tree ensemble by either adding a layer to the last tree being grown
//
// or by starting a new tree.
@@ -11671,6 +11974,49 @@ func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.
return scope.AddOperation(opspec)
}
+// Exits the current frame to its parent frame.
+//
+// Exit makes its input `data` available to the parent frame.
+//
+// Arguments:
+// data: The tensor to be made available to the parent frame.
+//
+// Returns The same tensor as `data`.
+func Exit(scope *Scope, data tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exit",
+ Input: []tf.Input{
+ data,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Produce a string tensor that encodes the state of a Reader.
+//
+// Not all Readers support being serialized, so this can produce an
+// Unimplemented error.
+//
+// Arguments:
+// reader_handle: Handle to a Reader.
+func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReaderSerializeStateV2",
+ Input: []tf.Input{
+ reader_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
@@ -11804,68 +12150,6 @@ func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (o
return op.Output(0)
}
-// StringSplitV2Attr is an optional argument to StringSplitV2.
-type StringSplitV2Attr func(optionalAttr)
-
-// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
-//
-// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
-// If not specified, defaults to -1
-func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
- return func(m optionalAttr) {
- m["maxsplit"] = value
- }
-}
-
-// Split elements of `source` based on `sep` into a `SparseTensor`.
-//
-// Let N be the size of source (typically N will be the batch size). Split each
-// element of `source` based on `sep` and return a `SparseTensor`
-// containing the split tokens. Empty tokens are ignored.
-//
-// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
-// then the output will be
-// ```
-// st.indices = [0, 0;
-// 0, 1;
-// 1, 0;
-// 1, 1;
-// 1, 2]
-// st.shape = [2, 3]
-// st.values = ['hello', 'world', 'a', 'b', 'c']
-// ```
-//
-// If `sep` is given, consecutive delimiters are not grouped together and are
-// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
-// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
-// string, consecutive whitespace are regarded as a single separator, and the
-// result will contain no empty strings at the startor end if the string has
-// leading or trailing whitespace.
-//
-// Note that the above mentioned behavior matches python's str.split.
-//
-// Arguments:
-// input: `1-D` string `Tensor`, the strings to split.
-// sep: `0-D` string `Tensor`, the delimiter character.
-func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringSplitV2",
- Input: []tf.Input{
- input, sep,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// MaxPoolAttr is an optional argument to MaxPool.
type MaxPoolAttr func(optionalAttr)
@@ -12435,21 +12719,6 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...
return op.Output(0)
}
-// Computes softsign: `features / (abs(features) + 1)`.
-func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Softsign",
- Input: []tf.Input{
- features,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a TensorList which, when stacked, has the value of `tensor`.
//
// Each tensor in the result list corresponds to one row of the input tensor.
@@ -12470,81 +12739,6 @@ func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Outpu
return op.Output(0)
}
-// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
-type GenerateVocabRemappingAttr func(optionalAttr)
-
-// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
-//
-// value: Number of entries in the old vocab file to consider. If -1,
-// use the entire old vocabulary.
-// If not specified, defaults to -1
-//
-// REQUIRES: value >= -1
-func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
- return func(m optionalAttr) {
- m["old_vocab_size"] = value
- }
-}
-
-// Given a path to new and old vocabulary files, returns a remapping Tensor of
-//
-// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
-// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
-// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
-// in the new vocabulary is not in the old vocabulary. The old vocabulary is
-// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
-// default value of -1.
-//
-// `num_vocab_offset` enables
-// use in the partitioned variable case, and should generally be set through
-// examining partitioning info. The format of the files should be a text file,
-// with each line containing a single entity within the vocabulary.
-//
-// For example, with `new_vocab_file` a text file containing each of the following
-// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
-// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
-// `[0, -1, 2]`.
-//
-// The op also returns a count of how many entries in the new vocabulary
-// were present in the old vocabulary, which is used to calculate the number of
-// values to initialize in a weight matrix remapping
-//
-// This functionality can be used to remap both row vocabularies (typically,
-// features) and column vocabularies (typically, classes) from TensorFlow
-// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
-// corresponding to div-partitioned variables. Moreover, the underlying remapping
-// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
-// use the corresponding index_table_from_file() as the FeatureColumn framework
-// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
-//
-// Arguments:
-// new_vocab_file: Path to the new vocab file.
-// old_vocab_file: Path to the old vocab file.
-// new_vocab_offset: How many entries into the new vocab file to start reading.
-// num_new_vocab: Number of entries in the new vocab file to remap.
-//
-// Returns A Tensor of length num_new_vocab where the element at index i
-// is equal to the old ID that maps to the new ID i. This element is -1 for any
-// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab.
-func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "GenerateVocabRemapping",
- Input: []tf.Input{
- new_vocab_file, old_vocab_file,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Assigns sparse updates to the variable referenced by `resource`.
//
// This operation computes
@@ -13547,6 +13741,27 @@ func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAtt
return op.Output(0)
}
+// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns the stamp token of the tree ensemble resource, the number of trees in
+// the tree ensemble resource, the number of trees that were finished successfully,
+// the number of layers we attempted to build (but not necessarily succeeded), and
+// a size-2 tensor that contains the start and end ids of the nodes in the latest
+// layer.
+func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesGetEnsembleStates",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign.
type ResourceApplyPowerSignAttr func(optionalAttr)
@@ -16327,79 +16542,6 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra
return op.Output(0)
}
-// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
-type LogUniformCandidateSamplerAttr func(optionalAttr)
-
-// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a log-uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "LogUniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
//
// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
@@ -19444,31 +19586,6 @@ func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output
return op.Output(0)
}
-// Read an element from the TensorArray into output `value`.
-//
-// Arguments:
-// handle: The handle to a TensorArray.
-//
-// flow_in: A float scalar that enforces proper chaining of operations.
-// dtype: The type of the elem that is returned.
-//
-// Returns The tensor that is read from the TensorArray.
-func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- opspec := tf.OpSpec{
- Type: "TensorArrayReadV3",
- Input: []tf.Input{
- handle, index, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// QuantizeV2Attr is an optional argument to QuantizeV2.
type QuantizeV2Attr func(optionalAttr)
@@ -20866,6 +20983,201 @@ func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (ou
return op.Output(0)
}
+// EnterAttr is an optional argument to Enter.
+type EnterAttr func(optionalAttr)
+
+// EnterIsConstant sets the optional is_constant attribute to value.
+//
+// value: If true, the output is constant within the child frame.
+// If not specified, defaults to false
+func EnterIsConstant(value bool) EnterAttr {
+ return func(m optionalAttr) {
+ m["is_constant"] = value
+ }
+}
+
+// EnterParallelIterations sets the optional parallel_iterations attribute to value.
+//
+// value: The number of iterations allowed to run in parallel.
+// If not specified, defaults to 10
+func EnterParallelIterations(value int64) EnterAttr {
+ return func(m optionalAttr) {
+ m["parallel_iterations"] = value
+ }
+}
+
+// Creates or finds a child frame, and makes `data` available to the child frame.
+//
+// This op is used together with `Exit` to create loops in the graph.
+// The unique `frame_name` is used by the `Executor` to identify frames. If
+// `is_constant` is true, `output` is a constant in the child frame; otherwise
+// it may be changed in the child frame. At most `parallel_iterations` iterations
+// are run in parallel in the child frame.
+//
+// Arguments:
+// data: The tensor to be made available to the child frame.
+// frame_name: The name of the child frame.
+//
+// Returns The same tensor as `data`.
+func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"frame_name": frame_name}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Enter",
+ Input: []tf.Input{
+ data,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
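A construction-only sketch of how `Enter` pairs with `Exit` to move a tensor into and back out of a child frame. The frame name is arbitrary, and a real loop would also involve `Merge`, `Switch`, `NextIteration`, and `LoopCond`, which higher-level APIs normally emit; this only shows the frame round-trip.

```go
package main

import (
	"fmt"

	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s, int32(0))
	// Move x into a child frame (name is arbitrary), marking it constant within
	// that frame, then hand it back to the parent frame.
	entered := op.Enter(s.SubScope("enter"), x, "while/while_context",
		op.EnterIsConstant(true), op.EnterParallelIterations(10))
	exited := op.Exit(s.SubScope("exit"), entered)

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	fmt.Println("built frame round-trip ending at:", exited.Op.Name())
}
```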
+// Adds all input tensors element-wise.
+//
+// Arguments:
+// inputs: Must all be the same size and shape.
+func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AddN",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
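A small end-to-end sketch using `AddN` (standard tensorflow/go session boilerplate assumed):

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	a := op.Const(s.SubScope("a"), []float32{1, 2, 3})
	b := op.Const(s.SubScope("b"), []float32{10, 20, 30})
	c := op.Const(s.SubScope("c"), []float32{100, 200, 300})
	sum := op.AddN(s, []tf.Output{a, b, c})

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{sum}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [111 222 333]
}
```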
+// TryRpcAttr is an optional argument to TryRpc.
+type TryRpcAttr func(optionalAttr)
+
+// TryRpcProtocol sets the optional protocol attribute to value.
+//
+// value: RPC protocol to use. Empty string means use the default protocol.
+// Options include 'grpc'.
+// If not specified, defaults to ""
+func TryRpcProtocol(value string) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["protocol"] = value
+ }
+}
+
+// TryRpcFailFast sets the optional fail_fast attribute to value.
+//
+// value: `boolean`. If `true` (default), then failures to connect
+// (i.e., the server does not immediately respond) cause an RPC failure.
+// If not specified, defaults to true
+func TryRpcFailFast(value bool) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["fail_fast"] = value
+ }
+}
+
+// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
+//
+// value: `int`. If `0` (default), then the kernel will run the RPC
+// request and only time out if the RPC deadline passes or the session times out.
+// If this value is greater than `0`, then the op will raise an exception if
+// the RPC takes longer than `timeout_in_ms`.
+// If not specified, defaults to 0
+func TryRpcTimeoutInMs(value int64) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["timeout_in_ms"] = value
+ }
+}
+
+// Perform batches of RPC requests.
+//
+// This op asynchronously performs either a single RPC request, or a batch
+// of requests. RPC requests are defined by three main parameters:
+//
+// - `address` (the host+port or BNS address of the request)
+// - `method` (the method name for the request)
+// - `request` (the serialized proto string, or vector of strings,
+// of the RPC request argument).
+//
+// For example, if you have an RPC service running on port localhost:2345,
+// and its interface is configured with the following proto declaration:
+//
+// ```
+// service MyService {
+// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+// }
+// };
+// ```
+//
+// then call this op with arguments:
+//
+// ```
+// address = "localhost:2345"
+// method = "MyService/MyMethod"
+// ```
+//
+// The `request` tensor is a string tensor representing serialized `MyRequestProto`
+// strings; and the output string tensor `response` will have the same shape
+// and contain (upon successful completion) corresponding serialized
+// `MyResponseProto` strings.
+//
+// For example, to send a single, empty, `MyRequestProto`, call
+// this op with `request = ""`. To send 5 **parallel** empty requests,
+// call this op with `request = ["", "", "", "", ""]`.
+//
+// More generally, one can create a batch of `MyRequestProto` serialized protos
+// from regular batched tensors using the `encode_proto` op, and convert
+// the response `MyResponseProto` serialized protos to batched tensors
+// using the `decode_proto` op.
+//
+// **NOTE** Working with serialized proto strings is faster than instantiating
+// actual proto objects in memory, so no performance degradation is expected
+// compared to writing custom kernels for this workflow.
+//
+// Unlike the standard `Rpc` op, if the connection fails or the remote worker
+// returns an error status, this op does **not** reraise the exception.
+// Instead, the `status_code` and `status_message` entry for the corresponding RPC
+// call is set with the error returned from the RPC call. The `response` tensor
+// will contain valid response values for those minibatch entries whose RPCs did
+// not fail; the rest of the entries will have empty strings.
+//
+// Arguments:
+// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `method` and `request`.
+// method: `0-D` or `1-D`. The method address on the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `request`.
+// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `method`.
+//
+// Returns three tensors, each with the same shape as `request`: the serialized
+// proto strings of the rpc responses, the corresponding tensorflow Status enum
+// codes, and the corresponding Status messages returned from the RPC calls.
+func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TryRpc",
+ Input: []tf.Input{
+ address, method, request,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
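A hedged sketch of calling `TryRpc` against a hypothetical endpoint. Nothing needs to be listening on the port for the run to complete: the point of `TryRpc` (unlike `Rpc`) is that a connection failure shows up in `status_code`/`status_message` rather than failing the step. The address, method, and timeout are illustrative.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Hypothetical endpoint and method, matching the doc example above.
	address := op.Const(s.SubScope("addr"), "localhost:2345")
	method := op.Const(s.SubScope("method"), "MyService/MyMethod")
	request := op.Const(s.SubScope("req"), "") // a single empty serialized request
	response, code, msg := op.TryRpc(s, address, method, request,
		op.TryRpcProtocol("grpc"), op.TryRpcTimeoutInMs(1000))
	_ = response

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{code, msg}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println("status code:", out[0].Value(), "message:", out[1].Value())
}
```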
// Delete the tensor specified by its handle in the session.
//
// Arguments:
@@ -21612,29 +21924,6 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
-// Forwards the input to the output.
-//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
-//
-// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
-//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LoopCond",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the sum along segments of a tensor.
//
// Read
@@ -24163,6 +24452,31 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr
return op.Output(0)
}
+// Read an element from the TensorArray into output `value`.
+//
+// Arguments:
+// handle: The handle to a TensorArray.
+//
+// flow_in: A float scalar that enforces proper chaining of operations.
+// dtype: The type of the elem that is returned.
+//
+// Returns The tensor that is read from the TensorArray.
+func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayReadV3",
+ Input: []tf.Input{
+ handle, index, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
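A sketch of reading back a value written to a TensorArray. It assumes the companion `TensorArrayV3` and `TensorArrayWriteV3` wrappers elsewhere in this package keep their usual signatures; note how the read consumes the write's `flow_out`, which is what enforces ordering.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// A one-element float TensorArray: write index 0, then read it back.
	size := op.Const(s.SubScope("size"), int32(1))
	handle, flow := op.TensorArrayV3(s, size, tf.Float)
	index := op.Const(s.SubScope("index"), int32(0))
	value := op.Const(s.SubScope("value"), float32(42))
	flowOut := op.TensorArrayWriteV3(s, handle, index, value, flow)
	read := op.TensorArrayReadV3(s, handle, index, flowOut, tf.Float)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{read}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // 42
}
```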
// Computes the gradient for the tanh of `x` wrt its input.
//
// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
@@ -27849,178 +28163,6 @@ func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Outpu
return op.Output(0)
}
-// EncodeProtoAttr is an optional argument to EncodeProto.
-type EncodeProtoAttr func(optionalAttr)
-
-// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
-// If not specified, defaults to "local://"
-func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
- return func(m optionalAttr) {
- m["descriptor_source"] = value
- }
-}
-
-// The op serializes protobuf messages provided in the input tensors.
-//
-// The types of the tensors in `values` must match the schema for the
-// fields specified in `field_names`. All the tensors in `values` must
-// have a common shape prefix, *batch_shape*.
-//
-// The `sizes` tensor specifies repeat counts for each field. The repeat
-// count (last dimension) of a each tensor in `values` must be greater
-// than or equal to corresponding repeat count in `sizes`.
-//
-// A `message_type` name must be provided to give context for the field
-// names. The actual message descriptor can be looked up either in the
-// linked-in descriptor pool or a filename provided by the caller using
-// the `descriptor_source` attribute.
-//
-// The `descriptor_source` attribute selects a source of protocol
-// descriptors to consult when looking up `message_type`. This may be a
-// filename containing a serialized `FileDescriptorSet` message,
-// or the special value `local://`, in which case only descriptors linked
-// into the code will be searched; the filename can be on any filesystem
-// accessible to TensorFlow.
-//
-// You can build a `descriptor_source` file using the `--descriptor_set_out`
-// and `--include_imports` options to the protocol compiler `protoc`.
-//
-// The `local://` database only covers descriptors linked into the
-// code via C++ libraries, not Python imports. You can link in a proto descriptor
-// by creating a cc_library target with alwayslink=1.
-//
-// There are a few special cases in the value mapping:
-//
-// Submessage and group fields must be pre-serialized as TensorFlow strings.
-//
-// TensorFlow lacks support for unsigned int64s, so they must be
-// represented as `tf.int64` with the same twos-complement bit pattern
-// (the obvious way).
-//
-// Unsigned int32 values can be represented exactly with `tf.int64`, or
-// with sign wrapping if the input is of type `tf.int32`.
-//
-// Arguments:
-// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
-// values: List of tensors containing values for the corresponding field.
-// field_names: List of strings containing proto field names.
-// message_type: Name of the proto message type to decode.
-//
-// Returns Tensor of serialized protos with shape `batch_shape`.
-func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodeProto",
- Input: []tf.Input{
- sizes, tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Creates a TensorArray for storing the gradients of values in the given handle.
-//
-// If the given TensorArray gradient already exists, returns a reference to it.
-//
-// Locks the size of the original TensorArray by disabling its dynamic size flag.
-//
-// **A note about the input flow_in:**
-//
-// The handle flow_in forces the execution of the gradient lookup to occur
-// only after certain other operations have occurred. For example, when
-// the forward TensorArray is dynamically sized, writes to this TensorArray
-// may resize the object. The gradient TensorArray is statically sized based
-// on the size of the forward TensorArray when this operation executes.
-// Furthermore, the size of the forward TensorArray is frozen by this call.
-// As a result, the flow is used to ensure that the call to generate the gradient
-// TensorArray only happens after all writes are executed.
-//
-// In the case of dynamically sized TensorArrays, gradient computation should
-// only be performed on read operations that have themselves been chained via
-// flow to occur only after all writes have executed. That way the final size
-// of the forward TensorArray is known when this operation is called.
-//
-// **A note about the source attribute:**
-//
-// TensorArray gradient calls use an accumulator TensorArray object. If
-// multiple gradients are calculated and run in the same session, the multiple
-// gradient nodes may accidentally flow through the same accumulator TensorArray.
-// This double counts and generally breaks the TensorArray gradient flow.
-//
-// The solution is to identify which gradient call this particular
-// TensorArray gradient is being called in. This is performed by identifying
-// a unique string (e.g. "gradients", "gradients_1", ...) from the input
-// gradient Tensor's name. This string is used as a suffix when creating
-// the TensorArray gradient object here (the attribute `source`).
-//
-// The attribute `source` is added as a suffix to the forward TensorArray's
-// name when performing the creation / lookup, so that each separate gradient
-// calculation gets its own TensorArray accumulator.
-//
-// Arguments:
-// handle: The handle to the forward TensorArray.
-// flow_in: A float scalar that enforces proper chaining of operations.
-// source: The gradient source string, used to decide which gradient TensorArray
-// to return.
-func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"source": source}
- opspec := tf.OpSpec{
- Type: "TensorArrayGradV3",
- Input: []tf.Input{
- handle, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// Creates a dataset that splits a SparseTensor into elements row-wise.
-func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseTensorSliceDataset",
- Input: []tf.Input{
- indices, values, dense_shape,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns x / y element-wise for real types.
-//
-// If `x` and `y` are reals, this will return the floating-point division.
-//
-// *NOTE*: `Div` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "RealDiv",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adds v into specified rows of x.
//
// Computes y = x; y[i, :] += v; return y.
@@ -28316,6 +28458,255 @@ func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...Sta
return op.Output(0)
}
+// StringSplitV2Attr is an optional argument to StringSplitV2.
+type StringSplitV2Attr func(optionalAttr)
+
+// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
+//
+// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
+// If not specified, defaults to -1
+func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
+ return func(m optionalAttr) {
+ m["maxsplit"] = value
+ }
+}
+
+// Split elements of `source` based on `sep` into a `SparseTensor`.
+//
+// Let N be the size of source (typically N will be the batch size). Split each
+// element of `source` based on `sep` and return a `SparseTensor`
+// containing the split tokens. Empty tokens are ignored.
+//
+// For example, if N = 2, source[0] is 'hello world' and source[1] is 'a b c',
+// then the output will be
+// ```
+// st.indices = [0, 0;
+// 0, 1;
+// 1, 0;
+// 1, 1;
+// 1, 2]
+// st.shape = [2, 3]
+// st.values = ['hello', 'world', 'a', 'b', 'c']
+// ```
+//
+// If `sep` is given, consecutive delimiters are not grouped together and are
+// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+// string, consecutive whitespace is regarded as a single separator, and the
+// result will contain no empty strings at the start or end if the string has
+// leading or trailing whitespace.
+//
+// Note that the above-mentioned behavior matches Python's str.split.
+//
+// Arguments:
+// input: `1-D` string `Tensor`, the strings to split.
+// sep: `0-D` string `Tensor`, the delimiter character.
+func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringSplitV2",
+ Input: []tf.Input{
+ input, sep,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
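A sketch reproducing the 'hello world' / 'a b c' example from the comment above (session boilerplate assumed):

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	source := op.Const(s.SubScope("source"), []string{"hello world", "a b c"})
	sep := op.Const(s.SubScope("sep"), "") // empty: split on runs of whitespace
	indices, values, shape := op.StringSplitV2(s, source, sep,
		op.StringSplitV2Maxsplit(-1))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{indices, values, shape}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println("values:", out[1].Value()) // [hello world a b c]
	fmt.Println("shape:", out[2].Value())  // [2 3]
}
```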
+// Computes softsign: `features / (abs(features) + 1)`.
+func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Softsign",
+ Input: []tf.Input{
+ features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
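A quick numerical check of the softsign formula `x / (|x| + 1)`:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s, []float32{-4, -1, 0, 1, 4})
	y := op.Softsign(s, x) // element-wise x / (|x| + 1)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{y}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // approximately [-0.8 -0.5 0 0.5 0.8]
}
```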
+// EncodeProtoAttr is an optional argument to EncodeProto.
+type EncodeProtoAttr func(optionalAttr)
+
+// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
+// If not specified, defaults to "local://"
+func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
+ return func(m optionalAttr) {
+ m["descriptor_source"] = value
+ }
+}
+
+// The op serializes protobuf messages provided in the input tensors.
+//
+// The types of the tensors in `values` must match the schema for the
+// fields specified in `field_names`. All the tensors in `values` must
+// have a common shape prefix, *batch_shape*.
+//
+// The `sizes` tensor specifies repeat counts for each field. The repeat
+// count (last dimension) of each tensor in `values` must be greater
+// than or equal to the corresponding repeat count in `sizes`.
+//
+// A `message_type` name must be provided to give context for the field
+// names. The actual message descriptor can be looked up either in the
+// linked-in descriptor pool or a filename provided by the caller using
+// the `descriptor_source` attribute.
+//
+// The `descriptor_source` attribute selects a source of protocol
+// descriptors to consult when looking up `message_type`. This may be a
+// filename containing a serialized `FileDescriptorSet` message,
+// or the special value `local://`, in which case only descriptors linked
+// into the code will be searched; the filename can be on any filesystem
+// accessible to TensorFlow.
+//
+// You can build a `descriptor_source` file using the `--descriptor_set_out`
+// and `--include_imports` options to the protocol compiler `protoc`.
+//
+// The `local://` database only covers descriptors linked into the
+// code via C++ libraries, not Python imports. You can link in a proto descriptor
+// by creating a cc_library target with alwayslink=1.
+//
+// There are a few special cases in the value mapping:
+//
+// Submessage and group fields must be pre-serialized as TensorFlow strings.
+//
+// TensorFlow lacks support for unsigned int64s, so they must be
+// represented as `tf.int64` with the same twos-complement bit pattern
+// (the obvious way).
+//
+// Unsigned int32 values can be represented exactly with `tf.int64`, or
+// with sign wrapping if the input is of type `tf.int32`.
+//
+// Arguments:
+// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+// values: List of tensors containing values for the corresponding field.
+// field_names: List of strings containing proto field names.
+// message_type: Name of the proto message type to serialize.
+//
+// Returns Tensor of serialized protos with shape `batch_shape`.
+func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeProto",
+ Input: []tf.Input{
+ sizes, tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
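A construction-only sketch of `EncodeProto`. The message type `MyMessage`, its repeated int32 field `values`, and the descriptor file path are hypothetical; a descriptor set for it would be produced with `protoc --descriptor_set_out=... --include_imports`.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Hypothetical message type with one repeated int32 field "values".
	sizes := op.Const(s.SubScope("sizes"), [][]int32{{2}, {1}})          // per-field repeat counts, batch of 2
	values := op.Const(s.SubScope("values"), [][]int32{{7, 8}, {9, 0}}) // padded to the max repeat count
	encoded := op.EncodeProto(s, sizes, []tf.Output{values},
		[]string{"values"}, "MyMessage",
		op.EncodeProtoDescriptorSource("/tmp/my_message.desc")) // hypothetical descriptor set

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	fmt.Println("serialized protos have shape batch_shape:", encoded.Shape())
}
```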
+// Creates a TensorArray for storing the gradients of values in the given handle.
+//
+// If the given TensorArray gradient already exists, returns a reference to it.
+//
+// Locks the size of the original TensorArray by disabling its dynamic size flag.
+//
+// **A note about the input flow_in:**
+//
+// The handle flow_in forces the execution of the gradient lookup to occur
+// only after certain other operations have occurred. For example, when
+// the forward TensorArray is dynamically sized, writes to this TensorArray
+// may resize the object. The gradient TensorArray is statically sized based
+// on the size of the forward TensorArray when this operation executes.
+// Furthermore, the size of the forward TensorArray is frozen by this call.
+// As a result, the flow is used to ensure that the call to generate the gradient
+// TensorArray only happens after all writes are executed.
+//
+// In the case of dynamically sized TensorArrays, gradient computation should
+// only be performed on read operations that have themselves been chained via
+// flow to occur only after all writes have executed. That way the final size
+// of the forward TensorArray is known when this operation is called.
+//
+// **A note about the source attribute:**
+//
+// TensorArray gradient calls use an accumulator TensorArray object. If
+// multiple gradients are calculated and run in the same session, the multiple
+// gradient nodes may accidentally flow through the same accumulator TensorArray.
+// This double counts and generally breaks the TensorArray gradient flow.
+//
+// The solution is to identify which gradient call this particular
+// TensorArray gradient is being called in. This is performed by identifying
+// a unique string (e.g. "gradients", "gradients_1", ...) from the input
+// gradient Tensor's name. This string is used as a suffix when creating
+// the TensorArray gradient object here (the attribute `source`).
+//
+// The attribute `source` is added as a suffix to the forward TensorArray's
+// name when performing the creation / lookup, so that each separate gradient
+// calculation gets its own TensorArray accumulator.
+//
+// Arguments:
+// handle: The handle to the forward TensorArray.
+// flow_in: A float scalar that enforces proper chaining of operations.
+// source: The gradient source string, used to decide which gradient TensorArray
+// to return.
+func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"source": source}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGradV3",
+ Input: []tf.Input{
+ handle, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
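A construction-only sketch of obtaining a gradient TensorArray. `TensorArrayV3` is assumed to be the companion wrapper in this package; the `source` string plays the role described above, namespacing the accumulator per gradient computation.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Forward TensorArray of four float elements.
	size := op.Const(s, int32(4))
	handle, flow := op.TensorArrayV3(s, size, tf.Float)
	// The "gradients" suffix keeps this accumulator separate from any other
	// gradient computation that touches the same forward TensorArray.
	gradHandle, _ := op.TensorArrayGradV3(s, handle, flow, "gradients")

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	fmt.Println("gradient TensorArray op:", gradHandle.Op.Name())
}
```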
+// Creates a dataset that splits a SparseTensor into elements row-wise.
+func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseTensorSliceDataset",
+ Input: []tf.Input{
+ indices, values, dense_shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns x / y element-wise for real types.
+//
+// If `x` and `y` are reals, this will return the floating-point division.
+//
+// *NOTE*: `Div` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "RealDiv",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a dataset that concatenates `input_dataset` with `another_dataset`.
func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
@@ -32600,79 +32991,6 @@ func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.
return weights, biases
}
-// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
-type UniformCandidateSamplerAttr func(optionalAttr)
-
-// UniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "UniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// CTCLossAttr is an optional argument to CTCLoss.
type CTCLossAttr func(optionalAttr)
@@ -32823,321 +33141,3 @@ func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Outpu
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Add all input tensors element wise.
-//
-// Arguments:
-// inputs: Must all be the same size and shape.
-func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AddN",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// TryRpcAttr is an optional argument to TryRpc.
-type TryRpcAttr func(optionalAttr)
-
-// TryRpcProtocol sets the optional protocol attribute to value.
-//
-// value: RPC protocol to use. Empty string means use the default protocol.
-// Options include 'grpc'.
-// If not specified, defaults to ""
-func TryRpcProtocol(value string) TryRpcAttr {
- return func(m optionalAttr) {
- m["protocol"] = value
- }
-}
-
-// TryRpcFailFast sets the optional fail_fast attribute to value.
-//
-// value: `boolean`. If `true` (default), then failures to connect
-// (i.e., the server does not immediately respond) cause an RPC failure.
-// If not specified, defaults to true
-func TryRpcFailFast(value bool) TryRpcAttr {
- return func(m optionalAttr) {
- m["fail_fast"] = value
- }
-}
-
-// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
-//
-// value: `int`. If `0` (default), then the kernel will run the RPC
-// request and only time out if the RPC deadline passes or the session times out.
-// If this value is greater than `0`, then the op will raise an exception if
-// the RPC takes longer than `timeout_in_ms`.
-// If not specified, defaults to 0
-func TryRpcTimeoutInMs(value int64) TryRpcAttr {
- return func(m optionalAttr) {
- m["timeout_in_ms"] = value
- }
-}
-
-// Perform batches of RPC requests.
-//
-// This op asynchronously performs either a single RPC request, or a batch
-// of requests. RPC requests are defined by three main parameters:
-//
-// - `address` (the host+port or BNS address of the request)
-// - `method` (the method name for the request)
-// - `request` (the serialized proto string, or vector of strings,
-// of the RPC request argument).
-//
-// For example, if you have an RPC service running on port localhost:2345,
-// and its interface is configured with the following proto declaration:
-//
-// ```
-// service MyService {
-// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
-// }
-// };
-// ```
-//
-// then call this op with arguments:
-//
-// ```
-// address = "localhost:2345"
-// method = "MyService/MyMethod"
-// ```
-//
-// The `request` tensor is a string tensor representing serialized `MyRequestProto`
-// strings; and the output string tensor `response` will have the same shape
-// and contain (upon successful completion) corresponding serialized
-// `MyResponseProto` strings.
-//
-// For example, to send a single, empty, `MyRequestProto`, call
-// this op with `request = ""`. To send 5 **parallel** empty requests,
-// call this op with `request = ["", "", "", "", ""]`.
-//
-// More generally, one can create a batch of `MyRequestProto` serialized protos
-// from regular batched tensors using the `encode_proto` op, and convert
-// the response `MyResponseProto` serialized protos to batched tensors
-// using the `decode_proto` op.
-//
-// **NOTE** Working with serialized proto strings is faster than instantiating
-// actual proto objects in memory, so no performance degradation is expected
-// compared to writing custom kernels for this workflow.
-//
-// Unlike the standard `Rpc` op, if the connection fails or the remote worker
-// returns an error status, this op does **not** reraise the exception.
-// Instead, the `status_code` and `status_message` entry for the corresponding RPC
-// call is set with the error returned from the RPC call. The `response` tensor
-// will contain valid response values for those minibatch entries whose RPCs did
-// not fail; the rest of the entries will have empty strings.
-//
-// Arguments:
-// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `method` and `request`.
-// method: `0-D` or `1-D`. The method address on the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `request`.
-// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `method`.
-//
-// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages
-// returned from the RPC calls.
-func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TryRpc",
- Input: []tf.Input{
- address, method, request,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// EnterAttr is an optional argument to Enter.
-type EnterAttr func(optionalAttr)
-
-// EnterIsConstant sets the optional is_constant attribute to value.
-//
-// value: If true, the output is constant within the child frame.
-// If not specified, defaults to false
-func EnterIsConstant(value bool) EnterAttr {
- return func(m optionalAttr) {
- m["is_constant"] = value
- }
-}
-
-// EnterParallelIterations sets the optional parallel_iterations attribute to value.
-//
-// value: The number of iterations allowed to run in parallel.
-// If not specified, defaults to 10
-func EnterParallelIterations(value int64) EnterAttr {
- return func(m optionalAttr) {
- m["parallel_iterations"] = value
- }
-}
-
-// Creates or finds a child frame, and makes `data` available to the child frame.
-//
-// This op is used together with `Exit` to create loops in the graph.
-// The unique `frame_name` is used by the `Executor` to identify frames. If
-// `is_constant` is true, `output` is a constant in the child frame; otherwise
-// it may be changed in the child frame. At most `parallel_iterations` iterations
-// are run in parallel in the child frame.
-//
-// Arguments:
-// data: The tensor to be made available to the child frame.
-// frame_name: The name of the child frame.
-//
-// Returns The same tensor as `data`.
-func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"frame_name": frame_name}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Enter",
- Input: []tf.Input{
- data,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Produce a string tensor that encodes the state of a Reader.
-//
-// Not all Readers support being serialized, so this can produce an
-// Unimplemented error.
-//
-// Arguments:
-// reader_handle: Handle to a Reader.
-func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReaderSerializeStateV2",
- Input: []tf.Input{
- reader_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Exits the current frame to its parent frame.
-//
-// Exit makes its input `data` available to the parent frame.
-//
-// Arguments:
-// data: The tensor to be made available to the parent frame.
-//
-// Returns The same tensor as `data`.
-func Exit(scope *Scope, data tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Exit",
- Input: []tf.Input{
- data,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a copy of the input tensor.
-func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Snapshot",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a tensor of zeros with the same shape and type as x.
-//
-// Arguments:
-// x: a tensor of type T.
-//
-// Returns a tensor of the same shape and type as x but filled with zeros.
-func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ZerosLike",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// AbortAttr is an optional argument to Abort.
-type AbortAttr func(optionalAttr)
-
-// AbortErrorMsg sets the optional error_msg attribute to value.
-//
-// value: A string which is the message associated with the exception.
-// If not specified, defaults to ""
-func AbortErrorMsg(value string) AbortAttr {
- return func(m optionalAttr) {
- m["error_msg"] = value
- }
-}
-
-// AbortExitWithoutError sets the optional exit_without_error attribute to value.
-// If not specified, defaults to false
-func AbortExitWithoutError(value bool) AbortAttr {
- return func(m optionalAttr) {
- m["exit_without_error"] = value
- }
-}
-
-// Raise a exception to abort the process when called.
-//
-// If exit_without_error is true, the process will exit normally,
-// otherwise it will exit with a SIGABORT signal.
-//
-// Returns nothing but an exception.
-//
-// Returns the created operation.
-func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Abort",
-
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
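The Enter/Exit wrappers removed above expose TensorFlow's low-level control-flow frames, which user code almost never calls directly; they are emitted when a loop is lowered into the graph. As a minimal sketch (assuming the TF 1.x Python graph API, not part of this patch), a high-level `tf.while_loop` call is what generates the Enter/Exit (plus Merge/Switch/NextIteration) nodes, with `parallel_iterations` mapping to the Enter attribute of the same name:

```python
# Minimal sketch: tf.while_loop builds the loop frame that Enter/Exit delimit.
import tensorflow as tf

i = tf.constant(0)
r = tf.while_loop(cond=lambda i: i < 5,
                  body=lambda i: i + 1,
                  loop_vars=[i],
                  parallel_iterations=10)  # forwarded to the Enter op

with tf.Session() as sess:
  print(sess.run(r))  # 5
```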
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 410b3a553a..9275ad767e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1639,6 +1639,15 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "experimental_dataset_ops_gen",
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow:__subpackages__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "image_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
)
@@ -2008,6 +2017,7 @@ py_library(
":array_ops",
":cond_v2_impl",
":constant_op",
+ ":control_flow_ops",
":control_flow_util",
":framework_ops",
":function_def_to_graph",
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 5c0c405306..347833ce8f 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -120,11 +120,17 @@ class SessionTest(test_util.TensorFlowTestCase):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
- devices = sess.list_devices()
- self.assertEqual(2, len(devices))
- for device in devices:
- self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
- device.name).device_type)
+ num_cpu_devices = 0
+ num_gpu_devices = 0
+ for device in sess.list_devices():
+ device_type = framework_device_lib.DeviceSpec.from_string(
+ device.name).device_type
+ if device_type == 'CPU':
+ num_cpu_devices += 1
+ elif device_type == 'GPU':
+ num_gpu_devices += 1
+ self.assertEqual(2, num_cpu_devices)
+ self.assertEqual(0, num_gpu_devices)
def testPerSessionThreads(self):
with session.Session(
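The rewritten assertion above counts devices per type rather than assuming every device returned by `list_devices()` is a CPU, which tolerates device types other than CPU and GPU appearing in the list. A minimal sketch of the same counting pattern (illustrative, outside the test harness):

```python
# Tally the session's devices by device type.
import collections
import tensorflow as tf

with tf.Session() as sess:
  counts = collections.Counter(
      tf.DeviceSpec.from_string(d.name).device_type
      for d in sess.list_devices())
  print(counts)  # e.g. Counter({'CPU': 1}) on a CPU-only build
```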
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 88cad5d6d9..b74fce3a4c 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 27)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 28)
@tf_export("compat.forward_compatible")
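The daily horizon bump above controls `tf.compat.forward_compatible`, which producer code uses to decide when it may start emitting newly added graph behavior. A minimal sketch of that gating pattern (the branch contents are purely illustrative):

```python
# forward_compatible(y, m, d) is True only when the compiled-in horizon is
# strictly later than the given date, i.e. once the window has passed.
from tensorflow.python.compat import compat

def choose_codepath():
  if compat.forward_compatible(2018, 9, 28):
    return 'new'  # illustrative: rely on the newly added behavior
  return 'old'    # illustrative: keep emitting the old construct
```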
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 5f9818566f..cadfe7f9e0 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -471,6 +471,9 @@ py_library(
srcs = ["test_base.py"],
deps = [
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/util:nest",
],
)
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b4f64115b7..b730e10949 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -17,6 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
@@ -24,6 +30,80 @@ class DatasetTestBase(test.TestCase):
"""Base class for dataset tests."""
def assertSparseValuesEqual(self, a, b):
+ """Asserts that two SparseTensors/SparseTensorValues are equal."""
self.assertAllEqual(a.indices, b.indices)
self.assertAllEqual(a.values, b.values)
self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def getNext(self, dataset):
+ """Returns a callable that returns the next element of the dataset.
+
+ Example use:
+ ```python
+ # In both graph and eager modes
+ dataset = ...
+ nxt = self.getNext(dataset)
+ result = self.evaluate(nxt())
+ ```
+
+ Args:
+ dataset: A dataset whose next element is returned
+
+ Returns:
+ A callable that returns the next element of `dataset`
+ """
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ nxt = it.get_next()
+ return lambda: nxt
+
+ def assertDatasetsEqual(self, dataset1, dataset2):
+ """Checks that datasets are equal. Supports both graph and eager mode."""
+ self.assertEqual(dataset1.output_types, dataset2.output_types)
+ self.assertEqual(dataset1.output_classes, dataset2.output_classes)
+
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ if isinstance(
+ op1[i],
+ (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ self.assertSparseValuesEqual(op1[i], op2[i])
+ else:
+ self.assertAllEqual(op1[i], op2[i])
+
+ def assertDatasetsRaiseSameError(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ """Checks that datasets raise the same error on the first get_next call."""
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ try:
+ self.evaluate(next1())
+ raise ValueError(
+ 'Expected dataset to raise an error of type %s, but it did not.' %
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements or []:
+ expected_message = expected_message.replace(old, new, count)
+ # Check that the first segment of the error messages is the same.
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
+ self.evaluate(next2())
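A hypothetical test built on the new helpers (the concrete datasets below are illustrative and not part of this patch); `assertDatasetsEqual` walks both datasets element by element in either graph or eager mode via `getNext`:

```python
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class ExampleDatasetTest(test_base.DatasetTestBase):

  def testMapMatchesStridedRange(self):
    ds1 = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    ds2 = dataset_ops.Dataset.range(0, 10, 2)  # same elements, same dtype
    self.assertDatasetsEqual(ds1, ds2)


if __name__ == '__main__':
  test.main()
```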
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index bd3562f1ff..b9b77d4a5b 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -126,7 +126,7 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
- session_config: an optional @{tf.ConfigProto} object.
+ session_config: an optional `tf.ConfigProto` object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
@@ -685,7 +685,7 @@ def run_distribute_coordinator(worker_fn,
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- session_config: an optional @{tf.ConfigProto} object which will be passed
+ session_config: an optional `tf.ConfigProto` object which will be passed
to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index 8daa34c885..0289689134 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -62,7 +62,7 @@ def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
# and "ps". More details can be found at the documentation of
- # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ # `tf.estimator.RunConfig.global_id_in_cluster`.
task_type_ordered_list = []
if chief_task_type in cluster_spec.jobs:
task_type_ordered_list = [chief_task_type]
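The ordering comment above determines how a task's global id is computed. A hypothetical, simplified helper (not the patched function) illustrating the idea of offsetting the task id by the sizes of the task types that sort before it:

```python
# Simplified sketch: chief/master first, then evaluator, worker, ps.
def global_id(cluster_spec, task_type, task_id, chief_task_type='chief'):
  ordered = [t for t in (chief_task_type, 'evaluator', 'worker', 'ps')
             if t in cluster_spec]
  offset = 0
  for t in ordered:
    if t == task_type:
      return offset + task_id
    offset += len(cluster_spec[t])


cluster = {'chief': ['h0:2222'],
           'worker': ['h1:2222', 'h2:2222'],
           'ps': ['h3:2222']}
assert global_id(cluster, 'ps', 0) == 3      # 1 chief + 2 workers before it
assert global_id(cluster, 'worker', 1) == 2  # 1 chief + task_id 1
```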
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index ba1b7ec2b5..1c4c5951df 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -344,6 +344,7 @@ py_test(
":pandas_io",
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 97971f9561..a6c2aaa7d9 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -131,9 +131,7 @@ class _DNNModel(training.Model):
name=None,
**kwargs):
super(_DNNModel, self).__init__(name=name, **kwargs)
- self._is_v2 = False
if feature_column_v2.is_feature_column_v2(feature_columns):
- self._is_v2 = True
self._input_layer = feature_column_v2.FeatureLayer(
feature_columns=feature_columns,
name='input_layer',
@@ -190,7 +188,6 @@ class _DNNModel(training.Model):
_scope=logits_scope)
self._add_layer(self._logits_layer, logits_scope.name)
self._logits_scope_name = logits_scope.name
- self._logits_layer._use_resource_variables = False # pylint: disable=protected-access
self._input_layer_partitioner = input_layer_partitioner
def call(self, features, mode):
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index d16318659b..ae968e717a 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
@@ -35,6 +36,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
@@ -119,7 +121,16 @@ class LinearOnlyRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorEvaluationTest(
@@ -128,7 +139,16 @@ class LinearOnlyRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorPredictTest(
@@ -137,7 +157,16 @@ class LinearOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorIntegrationTest(
@@ -146,7 +175,16 @@ class LinearOnlyRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorTrainingTest(
@@ -155,7 +193,16 @@ class LinearOnlyRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
def _linear_classifier_fn(feature_columns,
@@ -185,7 +232,18 @@ class LinearOnlyClassifierTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierClassesEvaluationTest(
@@ -194,7 +252,18 @@ class LinearOnlyClassifierClassesEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierClassesEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierPredictTest(
@@ -203,7 +272,18 @@ class LinearOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierIntegrationTest(
@@ -212,9 +292,21 @@ class LinearOnlyClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
def setUp(self):
@@ -225,13 +317,15 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- label_dimension, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=linear_feature_columns,
@@ -257,14 +351,14 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -293,9 +387,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -326,9 +421,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -376,7 +472,8 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
# A function to mimic dnn-classifier init reuse same tests.
@@ -407,7 +504,16 @@ class DNNOnlyClassifierEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierEvaluateV2Test(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierPredictTest(
@@ -416,7 +522,16 @@ class DNNOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierPredictV2Test(
+ dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierTrainTest(
@@ -425,7 +540,16 @@ class DNNOnlyClassifierTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
# A function to mimic dnn-regressor init reuse same tests.
@@ -454,7 +578,16 @@ class DNNOnlyRegressorEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorEvaluateV2Test(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorPredictTest(
@@ -463,7 +596,16 @@ class DNNOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorPredictV2Test(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorTrainTest(
@@ -472,9 +614,19 @@ class DNNOnlyRegressorTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+class DNNOnlyRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
+
+
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def setUp(self):
@@ -488,13 +640,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- n_classes, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size, fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedClassifier(
linear_feature_columns=linear_feature_columns,
@@ -520,14 +673,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
n_classes = 3
input_dimension = 2
@@ -559,9 +712,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -593,9 +747,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 3
@@ -647,9 +802,11 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedTests(test.TestCase):
def setUp(self):
@@ -681,9 +838,9 @@ class DNNLinearCombinedTests(test.TestCase):
return optimizer_mock
- def test_train_op_calls_both_dnn_and_linear(self):
+ def test_train_op_calls_both_dnn_and_linear(self, fc_impl):
opt = gradient_descent.GradientDescentOptimizer(1.)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[0.], [1.]])},
y=np.array([[0.], [1.]]),
@@ -708,7 +865,7 @@ class DNNLinearCombinedTests(test.TestCase):
checkpoint_utils.load_variable(
self._model_dir, 'dnn_called'))
- def test_dnn_and_linear_logits_are_added(self):
+ def test_dnn_and_linear_logits_are_added(self, fc_impl):
with ops.Graph().as_default():
variables_lib.Variable([[1.0]], name='linear/linear_model/x/weights')
variables_lib.Variable([2.0], name='linear/linear_model/bias_weights')
@@ -719,7 +876,7 @@ class DNNLinearCombinedTests(test.TestCase):
variables_lib.Variable(1, name='global_step', dtype=dtypes.int64)
linear_testing_utils.save_variables_to_ckpt(self._model_dir)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=[x_column],
dnn_hidden_units=[1],
@@ -737,6 +894,7 @@ class DNNLinearCombinedTests(test.TestCase):
next(est.predict(input_fn=input_fn)))
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedWarmStartingTest(test.TestCase):
def setUp(self):
@@ -758,11 +916,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._ckpt_and_vocab_dir)
- def test_classifier_basic_warm_starting(self):
+ def test_classifier_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedClassifier default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -798,11 +956,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_classifier.get_variable_value(variable_name),
warm_started_dnn_lc_classifier.get_variable_value(variable_name))
- def test_regressor_basic_warm_starting(self):
+ def test_regressor_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedRegressor default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -836,11 +994,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_regressor.get_variable_value(variable_name),
warm_started_dnn_lc_regressor.get_variable_value(variable_name))
- def test_warm_starting_selective_variables(self):
+ def test_warm_starting_selective_variables(self, fc_impl):
"""Tests selecting variables to warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
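The V1/V2 duplication above follows two patterns: subclassed `*V2Test` classes that pass `fc_lib`/`fc_impl` explicitly, and `@parameterized.parameters((feature_column,), (feature_column_v2,))` class decorators that feed each feature-column module into every test method. A minimal sketch of the decorator pattern, mirroring the patch (the test body is illustrative):

```python
from absl.testing import parameterized

from tensorflow.python.feature_column import feature_column
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.platform import test


@parameterized.parameters((feature_column,), (feature_column_v2,))
class NumericColumnSmokeTest(test.TestCase):

  def test_numeric_column_name(self, fc_impl):
    # Runs once with feature_column and once with feature_column_v2.
    column = fc_impl.numeric_column('x', shape=(2,))
    self.assertEqual('x', column.name)


if __name__ == '__main__':
  test.main()
```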
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 115dd18518..8b96284bd3 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -25,14 +25,18 @@ import six
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variable_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import estimator_export
@@ -46,23 +50,42 @@ def _get_default_optimizer(feature_columns):
return ftrl.FtrlOptimizer(learning_rate=learning_rate)
-def _compute_fraction_of_zero(cols_to_vars):
- """Given a linear cols_to_vars dict, compute the fraction of zero weights.
+def _get_expanded_variable_list(var_list):
+ """Given a list of variables, expands them if they are partitioned.
Args:
- cols_to_vars: A dictionary mapping FeatureColumns to lists of tf.Variables
- like one returned from feature_column_lib.linear_model.
+ var_list: A list of variables.
+
+ Returns:
+ A list of variables where each partitioned variable is expanded to its
+ components.
+ """
+ returned_list = []
+ for variable in var_list:
+ if (isinstance(variable, variable_ops.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ returned_list.append(variable) # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ returned_list.extend(list(variable))
+ return returned_list
+
+
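`_get_expanded_variable_list` flattens `PartitionedVariable`s into their per-shard variables. A minimal illustration of what that expansion means (scope name, shape, and partitioner are illustrative), relying on the fact that a partitioned variable is iterable over its shards:

```python
import tensorflow as tf

with tf.variable_scope(
    'weights', partitioner=tf.fixed_size_partitioner(num_shards=2)):
  w = tf.get_variable('w', shape=[4, 1])  # a PartitionedVariable

# Iterating a PartitionedVariable yields its component Variables; this is the
# list(variable) expansion performed for the partitioned case above.
shards = list(w)
print(len(shards))  # 2 shards, each covering part of the [4, 1] weight
```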
+# TODO(rohanj): Consider making this a public utility method.
+def _compute_fraction_of_zero(variables):
+ """Given a linear variables list, compute the fraction of zero weights.
+
+ Args:
+ variables: A list of variables, or a list of lists of variables.
Returns:
The fraction of zeros (sparsity) in the linear model.
"""
all_weight_vars = []
- for var_or_var_list in cols_to_vars.values():
+ for var_or_var_list in variables:
+ var_list = nest.flatten(var_or_var_list)
# Skip empty-lists associated with columns that created no Variables.
- if var_or_var_list:
- all_weight_vars += [
- array_ops.reshape(var, [-1]) for var in var_or_var_list
- ]
+ if var_list:
+ all_weight_vars += [array_ops.reshape(var, [-1]) for var in var_list]
return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0))
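A small numeric check of the sparsity computation above (illustrative weights only): three zeros out of four concatenated weights give a fraction of 0.75.

```python
import tensorflow as tf

w1 = tf.constant([[0.0, 1.0]])        # one zero of two
w2 = tf.constant([[0.0], [0.0]])      # two zeros of two
flat = tf.concat([tf.reshape(w1, [-1]), tf.reshape(w2, [-1])], axis=0)
sparsity = tf.nn.zero_fraction(flat)  # the op the helper above relies on

with tf.Session() as sess:
  print(sess.run(sparsity))           # 0.75
```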
@@ -92,14 +115,36 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
Returns:
A `Tensor` representing the logits.
"""
- cols_to_vars = {}
- logits = feature_column_lib.linear_model(
- features=features,
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- cols_to_vars=cols_to_vars)
- bias = cols_to_vars.pop('bias')
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ linear_model = feature_column_v2.LinearModel(
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ shared_state_manager=shared_state_manager)
+ logits = linear_model(features)
+ bias = linear_model.bias_variable
+
+ # We'd like to get all the non-bias variables associated with this
+ # LinearModel. This includes the shared embedding variables as well.
+ variables = linear_model.variables
+ variables.remove(bias)
+ variables.extend(shared_state_manager.variables)
+
+ # Expand (potential) Partitioned variables
+ bias = _get_expanded_variable_list([bias])
+ variables = _get_expanded_variable_list(variables)
+ else:
+ linear_model = feature_column._LinearModel( # pylint: disable=protected-access
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ name='linear_model')
+ logits = linear_model(features)
+ cols_to_vars = linear_model.cols_to_vars()
+ bias = cols_to_vars.pop('bias')
+ variables = cols_to_vars.values()
+
if units > 1:
summary.histogram('bias', bias)
else:
@@ -107,7 +152,7 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
# so we should provide a scalar summary.
summary.scalar('bias', bias[0][0])
summary.scalar('fraction_of_zero_weights',
- _compute_fraction_of_zero(cols_to_vars))
+ _compute_fraction_of_zero(variables))
return logits
return linear_logit_fn
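The new branch above dispatches on `feature_column_v2.is_feature_column_v2(feature_columns)`, which, as used here, appears to require every column to come from the V2 library before the `LinearModel`/`SharedEmbeddingStateManager` path is taken; otherwise the legacy `_LinearModel` path is used. A small illustrative check of that criterion (column names are arbitrary):

```python
from tensorflow.python.feature_column import feature_column as fc_v1
from tensorflow.python.feature_column import feature_column_v2 as fc_v2

v2_only = [fc_v2.numeric_column('x')]
mixed = [fc_v2.numeric_column('x'), fc_v1.numeric_column('y')]

print(fc_v2.is_feature_column_v2(v2_only))  # True  -> v2 LinearModel path
print(fc_v2.is_feature_column_v2(mixed))    # False -> legacy _LinearModel path
```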
diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py
index 59a230417d..3e6da5de22 100644
--- a/tensorflow/python/estimator/canned/linear_test.py
+++ b/tensorflow/python/estimator/canned/linear_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import linear_testing_utils
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.platform import test
@@ -40,7 +42,16 @@ class LinearRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorEvaluationTest(
@@ -49,7 +60,16 @@ class LinearRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorPredictTest(
@@ -58,7 +78,16 @@ class LinearRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorIntegrationTest(
@@ -67,7 +96,16 @@ class LinearRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorTrainingTest(
@@ -76,19 +114,37 @@ class LinearRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
-# Tests for Linear Classifier.
+class LinearRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
+
+# Tests for Linear Classifier.
class LinearClassifierTrainingTest(
linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierEvaluationTest(
@@ -97,7 +153,18 @@ class LinearClassifierEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierPredictTest(
@@ -106,7 +173,18 @@ class LinearClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierIntegrationTest(
@@ -115,7 +193,18 @@ class LinearClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
# Tests for Linear logit_fn.
@@ -124,7 +213,17 @@ class LinearLogitFnTest(linear_testing_utils.BaseLinearLogitFnTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- linear_testing_utils.BaseLinearLogitFnTest.__init__(self)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column)
+
+
+class LinearLogitFnV2Test(linear_testing_utils.BaseLinearLogitFnTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column_v2)
# Tests for warm-starting with Linear logit_fn.
@@ -134,7 +233,22 @@ class LinearWarmStartingTest(linear_testing_utils.BaseLinearWarmStartingTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearWarmStartingTest.__init__(
- self, _linear_classifier_fn, _linear_regressor_fn)
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column)
+
+
+class LinearWarmStartingV2Test(linear_testing_utils.BaseLinearWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearWarmStartingTest.__init__(
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column_v2)
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 65cdd50061..827352a70b 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -37,7 +37,8 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -152,8 +153,9 @@ class CheckPartitionerVarHook(session_run_hook.SessionRunHook):
class BaseLinearRegressorPartitionerTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -173,7 +175,7 @@ class BaseLinearRegressorPartitionerTest(object):
return [partitions, 1] if shape[0] == x_dim else [1]
regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.categorical_column_with_hash_bucket(
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
'language', hash_bucket_size=x_dim),),
partitioner=_partitioner,
model_dir=self._model_dir)
@@ -209,9 +211,8 @@ class BaseLinearRegressorPartitionerTest(object):
'_get_replica_device_setter',
return_value=lambda _: '/cpu:0'):
linear_regressor = self._linear_regressor_fn(
- feature_columns=(
- feature_column_lib.categorical_column_with_hash_bucket(
- 'language', hash_bucket_size=x_dim),),
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
+ 'language', hash_bucket_size=x_dim),),
config=FakeRunConfig(),
model_dir=self._model_dir)
@@ -232,8 +233,9 @@ class BaseLinearRegressorPartitionerTest(object):
# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.
class BaseLinearRegressorEvaluationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -252,7 +254,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=1)
@@ -276,7 +278,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)
@@ -308,7 +310,7 @@ class BaseLinearRegressorEvaluationTest(object):
return features, labels
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='weights',
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(input_fn=_input_fn, steps=1)
@@ -336,8 +338,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column(
- 'age', shape=(x_dim,)),),
+ feature_columns=(self._fc_lib.numeric_column('age', shape=(x_dim,)),),
label_dimension=label_dim,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
@@ -374,8 +375,8 @@ class BaseLinearRegressorEvaluationTest(object):
batch_size = 2
feature_columns = [
- feature_column_lib.numeric_column('age'),
- feature_column_lib.numeric_column('height')
+ self._fc_lib.numeric_column('age'),
+ self._fc_lib.numeric_column('height')
]
input_fn = numpy_io.numpy_input_fn(
x={'age': np.array([20, 40]),
@@ -402,8 +403,9 @@ class BaseLinearRegressorEvaluationTest(object):
class BaseLinearRegressorPredictTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -422,7 +424,7 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x'),),
+ feature_columns=(self._fc_lib.numeric_column('x'),),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -441,7 +443,7 @@ class BaseLinearRegressorPredictTest(object):
batch_size = 2
label_dimension = 3
x_dim = 4
- feature_columns = (feature_column_lib.numeric_column('x', shape=(x_dim,)),)
+ feature_columns = (self._fc_lib.numeric_column('x', shape=(x_dim,)),)
with ops.Graph().as_default():
variables_lib.Variable( # shape=[x_dim, label_dimension]
[[1., 2., 3.], [2., 3., 4.], [3., 4., 5.], [4., 5., 6.]],
@@ -479,8 +481,8 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x0'),
- feature_column_lib.numeric_column('x1')),
+ feature_columns=(self._fc_lib.numeric_column('x0'),
+ self._fc_lib.numeric_column('x1')),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -515,9 +517,8 @@ class BaseLinearRegressorPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -561,8 +562,9 @@ class BaseLinearRegressorPredictTest(object):
class BaseLinearRegressorIntegrationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -575,7 +577,7 @@ class BaseLinearRegressorIntegrationTest(object):
def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
input_dimension, label_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
@@ -597,7 +599,7 @@ class BaseLinearRegressorIntegrationTest(object):
self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -729,8 +731,9 @@ class BaseLinearRegressorIntegrationTest(object):
class BaseLinearRegressorTrainingTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -808,7 +811,7 @@ class BaseLinearRegressorTrainingTest(object):
label = 5.
age = 17
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, and validate final checkpoint.
@@ -820,7 +823,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimLabel(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -840,7 +843,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimWeight(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -867,7 +870,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (0 - 5.)^2 = 25.
mock_optimizer = self._mock_optimizer(expected_loss=25.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -900,7 +903,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (175 - 5)^2 = 28900
mock_optimizer = self._mock_optimizer(expected_loss=28900.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -935,7 +938,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004
mock_optimizer = self._mock_optimizer(expected_loss=52004.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -954,8 +957,9 @@ class BaseLinearRegressorTrainingTest(object):
class BaseLinearClassifierTrainingTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1031,7 +1035,7 @@ class BaseLinearClassifierTrainingTest(object):
label = 0
age = 17
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1051,7 +1055,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1078,7 +1082,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1103,7 +1107,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1129,7 +1133,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1166,7 +1170,7 @@ class BaseLinearClassifierTrainingTest(object):
expected_loss=-1 * math.log(1.0/n_classes))
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1229,7 +1233,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1277,7 +1281,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1341,7 +1345,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1368,8 +1372,9 @@ class BaseLinearClassifierTrainingTest(object):
class BaseLinearClassifierEvaluationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1398,7 +1403,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1464,7 +1469,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1540,7 +1545,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
weight_column='w',
model_dir=self._model_dir)
@@ -1605,8 +1610,9 @@ class BaseLinearClassifierEvaluationTest(object):
class BaseLinearClassifierPredictTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1634,7 +1640,7 @@ class BaseLinearClassifierPredictTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
label_vocabulary=label_vocabulary,
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1730,9 +1736,8 @@ class BaseLinearClassifierPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -1776,8 +1781,9 @@ class BaseLinearClassifierPredictTest(object):
class BaseLinearClassifierIntegrationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1789,7 +1795,7 @@ class BaseLinearClassifierIntegrationTest(object):
def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,
predict_input_fn, input_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_classifier_fn(
feature_columns=feature_columns,
@@ -1811,7 +1817,7 @@ class BaseLinearClassifierIntegrationTest(object):
self.assertAllEqual((prediction_length, 1), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -1961,9 +1967,12 @@ class BaseLinearClassifierIntegrationTest(object):
class BaseLinearLogitFnTest(object):
+ def __init__(self, fc_lib=feature_column):
+ self._fc_lib = fc_lib
+
def test_basic_logit_correctness(self):
"""linear_logit_fn simply wraps feature_column_lib.linear_model."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
with ops.Graph().as_default():
logit_fn = linear._linear_logit_fn_builder(units=2, feature_columns=[age])
logits = logit_fn(features={'age': [[23.], [31.]]})
@@ -1983,12 +1992,14 @@ class BaseLinearLogitFnTest(object):
def test_compute_fraction_of_zero(self):
"""Tests the calculation of sparsity."""
- age = feature_column_lib.numeric_column('age')
- occupation = feature_column_lib.categorical_column_with_hash_bucket(
+ if self._fc_lib != feature_column:
+ return
+ age = feature_column.numeric_column('age')
+ occupation = feature_column.categorical_column_with_hash_bucket(
'occupation', hash_bucket_size=5)
with ops.Graph().as_default():
cols_to_vars = {}
- feature_column_lib.linear_model(
+ feature_column.linear_model(
features={
'age': [[23.], [31.]],
'occupation': [['doctor'], ['engineer']]
@@ -1997,7 +2008,42 @@ class BaseLinearLogitFnTest(object):
units=3,
cols_to_vars=cols_to_vars)
cols_to_vars.pop('bias')
- fraction_zero = linear._compute_fraction_of_zero(cols_to_vars)
+ fraction_zero = linear._compute_fraction_of_zero(cols_to_vars.values())
+ age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ 'linear_model/age')[0]
+ with tf_session.Session() as sess:
+ sess.run([variables_lib.global_variables_initializer()])
+ # Upon initialization, all variables will be zero.
+ self.assertAllClose(1, fraction_zero.eval())
+
+ sess.run(age_var.assign([[2.0, 0.0, -1.0]]))
+      # 1 of the 3 age weights is zero, and all 15 occupation weights
+      # (5 hash buckets x 3-dim output) are zero, so 16 of 18 weights are zero.
+ self.assertAllClose(16. / 18., fraction_zero.eval())
+
+ def test_compute_fraction_of_zero_v2(self):
+ """Tests the calculation of sparsity."""
+ if self._fc_lib != feature_column_v2:
+ return
+
+ age = feature_column_v2.numeric_column('age')
+ occupation = feature_column_v2.categorical_column_with_hash_bucket(
+ 'occupation', hash_bucket_size=5)
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ with ops.Graph().as_default():
+ model = feature_column_v2.LinearModel(
+ feature_columns=[age, occupation],
+ units=3,
+ shared_state_manager=shared_state_manager)
+ features = {
+ 'age': [[23.], [31.]],
+ 'occupation': [['doctor'], ['engineer']]
+ }
+ model(features)
+ variables = model.variables
+ variables.remove(model.bias_variable)
+ variables.extend(shared_state_manager.variables)
+ fraction_zero = linear._compute_fraction_of_zero(variables)
age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
'linear_model/age')[0]
with tf_session.Session() as sess:
@@ -2013,9 +2059,13 @@ class BaseLinearLogitFnTest(object):
class BaseLinearWarmStartingTest(object):
- def __init__(self, _linear_classifier_fn, _linear_regressor_fn):
+ def __init__(self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column):
self._linear_classifier_fn = _linear_classifier_fn
self._linear_regressor_fn = _linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
@@ -2039,7 +2089,7 @@ class BaseLinearWarmStartingTest(object):
def test_classifier_basic_warm_starting(self):
"""Tests correctness of LinearClassifier default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2066,7 +2116,7 @@ class BaseLinearWarmStartingTest(object):
def test_regressor_basic_warm_starting(self):
"""Tests correctness of LinearRegressor default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearRegressor and train to save a checkpoint.
linear_regressor = self._linear_regressor_fn(
@@ -2091,7 +2141,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_selective_variables(self):
"""Tests selecting variables to warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2128,7 +2178,7 @@ class BaseLinearWarmStartingTest(object):
vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')
with open(vocab_file, 'w') as f:
f.write('\n'.join(vocab_list))
- occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=vocab_file,
vocabulary_size=len(vocab_list))
@@ -2152,7 +2202,7 @@ class BaseLinearWarmStartingTest(object):
'new_occupation_vocab')
with open(new_vocab_file, 'w') as f:
f.write('\n'.join(new_vocab_list))
- new_occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ new_occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=new_vocab_file,
vocabulary_size=len(new_vocab_list))
@@ -2205,7 +2255,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_with_naming_change(self):
"""Tests warm-starting with a Tensor name remapping."""
- age_in_years = feature_column_lib.numeric_column('age_in_years')
+ age_in_years = self._fc_lib.numeric_column('age_in_years')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2219,7 +2269,7 @@ class BaseLinearWarmStartingTest(object):
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
# accumulator values that change).
warm_started_linear_classifier = self._linear_classifier_fn(
- feature_columns=[feature_column_lib.numeric_column('age')],
+ feature_columns=[self._fc_lib.numeric_column('age')],
n_classes=4,
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
        # The 'age' variable corresponds to the 'age_in_years' variable in the
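The base test classes above now accept an `fc_lib` argument so the same suites can run against either feature-column implementation. A minimal sketch of a concrete test class that opts into the v2 library (the subclass name is illustrative; it assumes the base classes live in `linear_testing_utils`, as in the existing linear tests):

```python
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import linear_testing_utils
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.platform import test


class LinearRegressorTrainingV2Test(
    linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    # fc_lib selects which feature_column module the shared base tests use
    # when they build numeric/categorical columns.
    linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
        self, linear.LinearRegressor, fc_lib=feature_column_v2)
```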
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 827b405e51..34faf03bb0 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -144,7 +144,7 @@ class Estimator(object):
* `labels`: This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `tf.Tensor` or `dict` of same (for multi-head models).
- If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will
+ If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
be passed. If the `model_fn`'s signature does not accept
`mode`, the `model_fn` must still be able to handle
`labels=None`.
@@ -803,9 +803,9 @@ class Estimator(object):
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
Only one of the modes is used for saving variables to the `SavedModel`
- (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN},
- @{tf.estimator.ModeKeys#EVAL$EVAL}, then
- @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three
+ (order of preference: `tf.estimator.ModeKeys.TRAIN`,
+ `tf.estimator.ModeKeys.EVAL`, then
+ `tf.estimator.ModeKeys.PREDICT`), such that up to three
`tf.MetaGraphDefs` are saved with a single set of variables in a single
`SavedModel` directory.
@@ -1101,7 +1101,7 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}.
+ be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
Args:
graph: The graph in which to create the global step tensor.
@@ -1414,6 +1414,36 @@ class Estimator(object):
# It is expected to have one CheckpointSaverHook. If multiple, we pick
# up the first one to add listener.
saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
+
+ # Add summary hooks to worker 0 if we are running with a master, to ensure
+ # that summaries are written at correct intervals even with long-running
+ # evaluations.
+ save_summary_steps = self._config.save_summary_steps
+ log_step_count_steps = self._config.log_step_count_steps
+ if (self._config.cluster_spec and self._config.cluster_spec.jobs and
+ (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
+ # Update config values to prevent the default hooks from being created on
+ # the master or other workers.
+ save_summary_steps = 0
+ log_step_count_steps = None
+
+ if (self._config.task_type == run_config.TaskType.WORKER and
+ self._config.task_id == 0):
+ if (self._config.save_summary_steps and
+ self._config.save_summary_steps > 0):
+ worker_hooks.append(
+ training.SummarySaverHook(
+ save_steps=self._config.save_summary_steps,
+ output_dir=self._config.model_dir,
+ scaffold=estimator_spec.scaffold))
+
+ if (self._config.log_step_count_steps and
+ self._config.log_step_count_steps > 0):
+ worker_hooks.append(
+ training.StepCounterHook(
+ every_n_steps=self._config.log_step_count_steps,
+ output_dir=self._config.model_dir))
+
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
@@ -1423,9 +1453,9 @@ class Estimator(object):
chief_only_hooks=(
tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
+ save_summaries_steps=save_summary_steps,
config=self._session_config,
- log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+ log_step_count_steps=log_step_count_steps) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
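The estimator.py change above disables the default summary and step-count hooks when the cluster spec contains a master task, and instead attaches explicit `SummarySaverHook`/`StepCounterHook` instances on worker 0. A rough sketch of a `TF_CONFIG` that exercises this path (addresses and ports are placeholders; the tests added below cover the same scenario):

```python
import json
import os

# Placeholder cluster: one ps, one worker, and a master. With a master present,
# summaries are no longer written through MonitoredTrainingSession defaults;
# worker 0 instead gets explicit SummarySaverHook and StepCounterHook instances.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['localhost:2221'],
        'worker': ['localhost:2222'],
        'master': ['localhost:2223'],
    },
    'task': {'type': 'worker', 'index': 0},
})

# run_config.RunConfig() reads TF_CONFIG from the environment, so an Estimator
# constructed with it takes the worker-0 branch added above during train().
```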
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index bc2504ca19..246dfb1a4b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import glob
+import json
import os
import tempfile
@@ -969,6 +970,99 @@ class EstimatorTrainTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
est.train(dummy_input_fn, steps=1)
+ def test_master_distributed_hooks(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.MASTER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_0(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_nonzero(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 1
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 5800b693b4..ac53a84eef 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -156,7 +156,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 538641c251..b79373c475 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,14 +136,11 @@ import six
from tensorflow.python.eager import context
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -153,7 +150,6 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
@@ -245,28 +241,19 @@ class StateManager(object):
raise NotImplementedError('StateManager.get_resource')
-class _InputLayerStateManager(StateManager):
- """Manages the state of InputLayer."""
+class _StateManagerImpl(StateManager):
+ """Manages the state of FeatureLayer and LinearModel."""
- def __init__(self, layer, feature_columns, trainable):
- """Creates an _InputLayerStateManager object.
+ def __init__(self, layer, trainable):
+    """Creates a _StateManagerImpl object.
Args:
layer: The input layer this state manager is associated with.
- feature_columns: List of feature columns for the input layer
trainable: Whether by default, variables created are trainable or not.
"""
self._trainable = trainable
self._layer = layer
- self._cols_to_vars_map = {}
- self._cols_to_names_map = {}
- for column in sorted(feature_columns, key=lambda x: x.name):
- self._cols_to_vars_map[column] = {}
- base_name = column.name
- if isinstance(column, SharedEmbeddingColumn):
- base_name = column.shared_collection_name
- with variable_scope.variable_scope(base_name) as vs:
- self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+ self._cols_to_vars_map = collections.defaultdict(lambda: {})
def create_variable(self,
feature_column,
@@ -277,19 +264,20 @@ class _InputLayerStateManager(StateManager):
initializer=None):
if name in self._cols_to_vars_map[feature_column]:
raise ValueError('Variable already exists.')
- with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
- var = self._layer.add_variable(
- name=name,
- shape=shape,
- dtype=dtype,
- initializer=initializer,
- trainable=self._trainable and trainable,
- # TODO(rohanj): Get rid of this hack once we have a mechanism for
- # specifying a default partitioner for an entire layer. In that case,
- # the default getter for Layers should work.
- getter=variable_scope.get_variable)
- self._cols_to_vars_map[feature_column][name] = var
- return var
+
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ use_resource=True,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
def get_variable(self, feature_column, name):
if name in self._cols_to_vars_map[feature_column]:
@@ -313,12 +301,15 @@ class FeatureLayer(Layer):
keywords_embedded = embedding_column(
categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
columns = [price, keywords_embedded, ...]
- features = tf.parse_example(..., features=make_parse_example_spec(columns))
feature_layer = FeatureLayer(columns)
+
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)."""
+    prediction = tf.layers.dense(dense_tensor, 1)
+ ```
+ """
def __init__(self,
feature_columns,
@@ -375,8 +366,7 @@ class FeatureLayer(Layer):
super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
self._feature_columns = _normalize_feature_columns(feature_columns)
- self._state_manager = _InputLayerStateManager(self, self._feature_columns,
- self.trainable)
+ self._state_manager = _StateManagerImpl(self, self.trainable)
self._shared_state_manager = shared_state_manager
for column in sorted(self._feature_columns, key=lambda x: x.name):
if not isinstance(column, DenseColumn):
@@ -394,8 +384,9 @@ class FeatureLayer(Layer):
if isinstance(column, SharedEmbeddingColumn):
column.create_state(self._shared_state_manager)
else:
- with variable_scope.variable_scope(None, default_name=self.name):
- column.create_state(self._state_manager)
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ column.create_state(self._state_manager)
super(FeatureLayer, self).build(None)
def call(self, features, cols_to_output_tensors=None):
@@ -424,19 +415,20 @@ class FeatureLayer(Layer):
output_tensors = []
ordered_columns = []
for column in sorted(self._feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- if isinstance(column, SharedEmbeddingColumn):
- tensor = column.get_dense_tensor(transformation_cache,
- self._shared_state_manager)
- else:
- tensor = column.get_dense_tensor(transformation_cache,
- self._state_manager)
- num_elements = column.variable_shape.num_elements()
- batch_size = array_ops.shape(tensor)[0]
- tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- output_tensors.append(tensor)
- if cols_to_output_tensors is not None:
- cols_to_output_tensors[column] = tensor
+ with ops.name_scope(column.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -448,20 +440,18 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def linear_model(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a linear prediction `Tensor` based on given `feature_columns`.
+def _strip_leading_slashes(name):
+ return name.rsplit('/', 1)[-1]
- This function generates a weighted sum based on output dimension `units`.
+
+class LinearModel(Layer):
+ """Produces a linear prediction `Tensor` based on given `feature_columns`.
+
+ This layer generates a weighted sum based on output dimension `units`.
Weighted sum refers to logits in classification problems. It refers to the
prediction itself for linear regression problems.
- Note on supported columns: `linear_model` treats categorical columns as
+ Note on supported columns: `LinearModel` treats categorical columns as
`indicator_column`s. To be specific, assume the input as `SparseTensor` looks
like:
@@ -486,308 +476,195 @@ def linear_model(features,
keywords = categorical_column_with_hash_bucket("keywords", 10K)
keywords_price = crossed_column('keywords', price_buckets, ...)
columns = [price_buckets, keywords, keywords_price ...]
+ linear_model = LinearModel(columns)
+
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- prediction = linear_model(features, columns)
+ prediction = linear_model(features)
```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values are `Tensor` or `SparseTensor` depending on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_FeatureColumn`s.
- units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a categorical column
- is multivalent. Except `numeric_column`, almost all columns passed to
- `linear_model` are considered as categorical columns. It combines each
- categorical column independently. Currently "mean", "sqrtn" and "sum" are
- supported, with "sum" the default for linear model. "sqrtn" often achieves
- good accuracy, in particular with bag-of-words columns.
- * "sum": do not normalize features in the column
- * "mean": do l1 normalization on features in the column
- * "sqrtn": do l2 normalization on features in the column
- For example, for two features represented as the categorical columns:
-
- ```python
- # Feature 1
-
- shape = [2, 2]
- {
- [0, 0]: "a"
- [0, 1]: "b"
- [1, 0]: "c"
- }
-
- # Feature 2
-
- shape = [2, 3]
- {
- [0, 0]: "d"
- [1, 0]: "e"
- [1, 1]: "f"
- [1, 2]: "g"
- }
- ```
- with `sparse_combiner` as "mean", the linear model outputs conceptly are:
- ```
- y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
- y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
- ```
- where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
- assigned to the presence of `x` in the input features.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that, variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to associated list of `Variable`s. For
- example, after the call, we might have cols_to_vars = {
- _NumericColumn(
- key='numeric_feature1', shape=(1,):
- [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
- 'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
- _NumericColumn(
- key='numeric_feature2', shape=(2,)):
- [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
- If a column creates no variables, its value will be an empty list. Note
- that cols_to_vars will also contain a string key 'bias' that maps to a
- list of Variables.
-
- Returns:
- A `Tensor` which represents predictions/logits of a linear model. Its shape
- is (batch_size, units) and its dtype is `float32`.
-
- Raises:
- ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
- nor `_CategoricalColumn`.
- """
- with variable_scope.variable_scope(None, 'linear_model') as vs:
- model_name = _strip_leading_slashes(vs.name)
- linear_model_layer = _LinearModel(
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
- name=model_name)
- retval = linear_model_layer(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(linear_model_layer.cols_to_vars())
- return retval
-
-
-def _add_to_collections(var, weight_collections):
- """Adds a var to the list of weight_collections provided.
-
- Handles the case for partitioned and non-partitioned variables.
-
- Args:
- var: A variable or Partitioned Variable.
- weight_collections: List of collections to add variable to.
- """
- for weight_collection in weight_collections:
- # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
- if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
- continue
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collection(weight_collection, constituent_var)
- else:
- ops.add_to_collection(weight_collection, var)
-
-
-class _FCLinearWrapper(base.Layer):
- """Wraps a _FeatureColumn in a layer for use in a linear model.
-
- See `linear_model` above.
"""
def __init__(self,
- feature_column,
+ feature_columns,
units=1,
sparse_combiner='sum',
- weight_collections=None,
trainable=True,
name=None,
+ shared_state_manager=None,
**kwargs):
- super(_FCLinearWrapper, self).__init__(
- trainable=trainable, name=name, **kwargs)
- self._feature_column = feature_column
- self._units = units
- self._sparse_combiner = sparse_combiner
- self._weight_collections = weight_collections
+ """Constructs a LinearModel.
- def build(self, _):
- if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- else:
- num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=[num_elements, self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(weight, self._weight_collections)
- self._weight_var = weight
- self.built = True
-
- def call(self, builder):
- weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access
- column=self._feature_column,
- builder=builder,
- units=self._units,
- sparse_combiner=self._sparse_combiner,
- weight_collections=self._weight_collections,
- trainable=self.trainable,
- weight_var=self._weight_var)
- return weighted_sum
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+        from `FeatureColumn`s.
+ units: An integer, dimensionality of the output space. Default value is 1.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except `numeric_column`, almost all columns passed to
+ `linear_model` are considered as categorical columns. It combines each
+        `LinearModel` are considered categorical columns. It combines each
+ are supported, with "sum" the default for linear model. "sqrtn" often
+ achieves good accuracy, in particular with bag-of-words columns.
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+
+      with `sparse_combiner` as "mean", the linear model outputs are conceptually:
+      ```
+      y_0 = 1.0 / 2.0 * (w_a + w_b) + w_c + b_0
+      y_1 = w_d + 1.0 / 3.0 * (w_e + w_f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the Linear Model. All variables and ops created will
+ be scoped by this name.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+ of SharedEmbeddingColumns. For more info, look at `FeatureLayer`.
+ **kwargs: Keyword arguments to construct a layer.
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `DenseColumn`
+ nor `CategoricalColumn`.
+ """
+ super(LinearModel, self).__init__(name=name, trainable=trainable, **kwargs)
-class _BiasLayer(base.Layer):
- """A layer for the bias term.
- """
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._feature_columns = sorted(self._feature_columns, key=lambda x: x.name)
+ for column in self._feature_columns:
+ if not isinstance(column, (DenseColumn, CategoricalColumn)):
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ 'DenseColumn or CategoricalColumn. Given: {}'.format(column))
- def __init__(self,
- units=1,
- trainable=True,
- weight_collections=None,
- name=None,
- **kwargs):
- super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
self._units = units
- self._weight_collections = weight_collections
-
- def build(self, _):
- self._bias_variable = self.add_variable(
- 'bias_weights',
- shape=[self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(self._bias_variable, self._weight_collections)
- self.built = True
-
- def call(self, _):
- return self._bias_variable
+ self._sparse_combiner = sparse_combiner
+ self._state_manager = _StateManagerImpl(self, self.trainable)
+ self._shared_state_manager = shared_state_manager
+ self._bias_variable = None
-def _get_expanded_variable_list(var_list):
- returned_list = []
- for variable in var_list:
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- returned_list.append(variable) # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- returned_list.extend(list(variable))
- return returned_list
+ def build(self, _):
+ # Create state for shared embedding columns.
+ for column in self._feature_columns:
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ # We need variable scopes for now because we want the variable partitioning
+    # information to percolate down. We also use _pure_variable_scope here
+ # since we want to open up a name_scope in the `call` method while creating
+ # the ops.
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ for column in self._feature_columns:
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ # Create the state for each feature column
+ if not isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._state_manager)
+
+ # Create a weight variable for each column.
+ if isinstance(column, CategoricalColumn):
+ first_dim = column.num_buckets
+ else:
+ first_dim = column.variable_shape.num_elements()
+ self._state_manager.create_variable(
+ column,
+ name='weights',
+ dtype=dtypes.float32,
+ shape=(first_dim, self._units),
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+
+ # Create a bias variable.
+ self._bias_variable = self.add_variable(
+ name='bias_weights',
+ dtype=dtypes.float32,
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable,
+ use_resource=True,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
+ super(LinearModel, self).build(None)
+ def call(self, features):
+    """Returns a `Tensor` that represents the predictions of a linear model.
-class _LinearModel(training.Model):
- """Creates a linear model using feature columns.
+ Args:
+      features: A mapping from key to tensors. `FeatureColumn`s look up via
+        these keys. For example `numeric_column('price')` will look at 'price'
+        key in this dict. Values are `Tensor` or `SparseTensor` depending on
+        corresponding `FeatureColumn`.
- See `linear_model` for details.
- """
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its
+ shape is (batch_size, units) and its dtype is `float32`.
- def __init__(self,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- name=None,
- **kwargs):
- super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access
- feature_columns)
- self._weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- column_layers = {}
- for column in sorted(self._feature_columns, key=lambda x: x.name):
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
- # Having the fully expressed variable scope name ends up doubly
- # expressing the outer scope (scope with which this method was called)
- # in the name of the variable that would get created.
- column_name = _strip_leading_slashes(vs.name)
- column_layer = _FCLinearWrapper(column, units, sparse_combiner,
- self._weight_collections, trainable,
- column_name, **kwargs)
- column_layers[column_name] = column_layer
- self._column_layers = self._add_layers(column_layers)
- self._bias_layer = _BiasLayer(
- units=units,
- trainable=trainable,
- weight_collections=self._weight_collections,
- name='bias_layer',
- **kwargs)
- self._cols_to_vars = {}
-
- def cols_to_vars(self):
- """Returns a dict mapping _FeatureColumns to variables.
-
- See `linear_model` for more information.
- This is not populated till `call` is called i.e. layer is built.
+ Raises:
+      ValueError: If `features` is not a dictionary.
"""
- return self._cols_to_vars
-
- def call(self, features):
- with variable_scope.variable_scope(self.name):
- for column in self._feature_columns:
- if not isinstance(
- column,
- (
- fc_old._DenseColumn, # pylint: disable=protected-access
- fc_old._CategoricalColumn)): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be either a '
- '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+ if not isinstance(features, dict):
+ raise ValueError('We expected a dictionary here. Instead we got: ',
+ features)
+ with ops.name_scope(self.name):
+ transformation_cache = FeatureTransformationCache(features)
weighted_sums = []
- ordered_columns = []
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
- column = layer._feature_column # pylint: disable=protected-access
- ordered_columns.append(column)
- weighted_sum = layer(builder)
- weighted_sums.append(weighted_sum)
- self._cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)
-
- _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+ for column in self._feature_columns:
+ with ops.name_scope(column.name):
+ # All the weights used in the linear model are owned by the state
+ # manager associated with this Linear Model.
+ weight_var = self._state_manager.get_variable(column, 'weights')
+
+ # The embedding weights for the SharedEmbeddingColumn are owned by
+ # the shared_state_manager and so we need to pass that in while
+ # creating the weighted sum. For all other columns, the state is owned
+ # by the Linear Model's state manager.
+ if isinstance(column, SharedEmbeddingColumn):
+ state_manager = self._shared_state_manager
+ else:
+ state_manager = self._state_manager
+ weighted_sum = _create_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ sparse_combiner=self._sparse_combiner,
+ weight_var=weight_var)
+ weighted_sums.append(weighted_sum)
+
+ _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
predictions_no_bias = math_ops.add_n(
weighted_sums, name='weighted_sum_no_bias')
predictions = nn_ops.bias_add(
- predictions_no_bias,
- self._bias_layer( # pylint: disable=not-callable
- builder,
- scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
- name='weighted_sum')
- bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
- return predictions
-
- def _add_layers(self, layers):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of layers.Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for name, layer in layers.items():
- setattr(self, 'layer-%s' % name, layer)
- return layers
+ predictions_no_bias, self._bias_variable, name='weighted_sum')
+ return predictions
+
+ @property
+ def bias_variable(self):
+ return self._bias_variable
def _transform_features(features, feature_columns, state_manager):
@@ -2053,58 +1930,32 @@ def is_feature_column_v2(feature_columns):
return True
-def _create_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_weighted_sum(column, transformation_cache, state_manager,
+ sparse_combiner, weight_var):
"""Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
else:
return _create_dense_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
-def _create_dense_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_dense_column_weighted_sum(column, transformation_cache,
+ state_manager, weight_var):
"""Create a weighted sum of a dense column for linear_model."""
tensor = column.get_dense_tensor(transformation_cache, state_manager)
num_elements = column.variable_shape.num_elements()
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=[num_elements, units],
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
- return math_ops.matmul(tensor, weight, name='weighted_sum')
+ return math_ops.matmul(tensor, weight_var, name='weighted_sum')
class CategoricalColumn(FeatureColumn):
@@ -2145,14 +1996,8 @@ class CategoricalColumn(FeatureColumn):
pass
-def _create_categorical_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_categorical_column_weighted_sum(
+ column, transformation_cache, state_manager, sparse_combiner, weight_var):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""Create a weighted sum of a categorical column for linear_model.
@@ -2191,17 +2036,8 @@ def _create_categorical_column_weighted_sum(column,
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=(column.num_buckets, units),
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
return _safe_embedding_lookup_sparse(
- weight,
+ weight_var,
id_tensor,
sparse_weights=weight_tensor,
combiner=sparse_combiner,
@@ -2777,6 +2613,7 @@ class SharedEmbeddingStateManager(Layer):
dtype=dtype,
trainable=self.trainable and trainable,
initializer=initializer,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -2836,6 +2673,10 @@ class SharedEmbeddingColumn(
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
+ if not isinstance(state_manager, SharedEmbeddingStateManager):
+ raise ValueError('Expected state_manager to be of type '
+ 'SharedEmbeddingStateManager. Obtained type: {}'.format(
+ type(state_manager)))
embedding_shape = (self.categorical_column.num_buckets, self.dimension)
state_manager.create_variable(
name=self.shared_collection_name,
@@ -3447,6 +3288,7 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
+ # TODO(rohanj): Look into removing this convert_to_tensor call.
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
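The functional `linear_model(features, columns)` entry point is replaced by the `LinearModel` layer, and the rewritten tests below use the new object API. A small sketch of the migration (column name and feature values are illustrative, mirroring the updated tests):

```python
from tensorflow.python.feature_column import feature_column_v2 as fc

price = fc.numeric_column('price')
features = {'price': [[1.], [5.]]}

# Old (removed in this change):
#   predictions = fc.linear_model(features, [price])
# New: build a LinearModel layer once, then call it on the features dict.
model = fc.LinearModel([price])
predictions = model(features)

# Weights are exposed through the layer rather than graph collections.
price_var, bias = model.variables
```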
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 2970431167..d3787146ed 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -31,9 +31,7 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
@@ -48,7 +46,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -360,26 +357,12 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(a.default_value, ((3., 2.),))
def test_linear_model(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[10.], [50.]], predictions.eval())
-
- def test_keras_linear_model(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.]], price_var.eval())
@@ -564,13 +547,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_one_input_value(self):
"""Tests linear_model() for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight variable per bucket, all initialized to zero.
@@ -589,13 +572,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_two_input_values(self):
"""Tests linear_model() for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight per bucket per input column, all initialized to zero.
@@ -616,62 +599,6 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
- def test_keras_linear_model_one_input_value(self):
- """Tests _LinearModel for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight variable per bucket, all initialized to zero.
- self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 1st bucket, whose weight is 20.
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 4th bucket, whose weight is 50.
- self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
-
- def test_keras_linear_model_two_input_values(self):
- """Tests _LinearModel for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight per bucket per input column, all initialized to zero.
- self.assertAllClose(
- [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
- [60.], [70.], [80.], [90.], [100.]]))
- # 1st example:
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 6th bucket, whose weight is 70.
- # 2nd example:
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 9th bucket, whose weight is 100.
- self.assertAllClose([[80.], [140.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[81.], [141.]], predictions.eval())
-
class HashedCategoricalColumnTest(test.TestCase):
@@ -852,39 +779,18 @@ class HashedCategoricalColumnTest(test.TestCase):
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 3: wire_var[3] = 4
- # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
- self.assertAllClose(((4.,), (6.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -1103,93 +1009,12 @@ class CrossedColumnTest(test.TestCase):
Uses data from test_get_sparse_tensors_simple.
"""
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'a': constant_op.constant(((-1., .5), (.5, 1.))),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
- with _initialized_session() as sess:
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(
- ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
- # Expected ids after cross = (1, 0, 1, 3, 4, 2)
- self.assertAllClose(((3.,), (14.,)), predictions.eval())
- sess.run(bias.assign((.1,)))
- self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
-
- def test_linear_model_with_weights(self):
-
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
- """Produces sparse IDs and sparse weights."""
-
- @property
- def name(self):
- return 'test_column'
-
- @property
- def _parse_example_spec(self):
- return {
- self.name: parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
- dtypes.float32),
- }
-
- @property
- def _num_buckets(self):
- return 5
-
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
-
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
- """Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
- id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
-
- t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError,
- 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- fc.linear_model({
- t.name: sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[0, 1, 2],
- dense_shape=(2, 2)),
- '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[1., 10., 2.],
- dense_shape=(2, 2)),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
-
- def test_keras_linear_model(self):
- """Tests _LinearModel.
-
- Uses data from test_get_sparse_tensors_simple.
- """
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ predictions = model({
'a':
constant_op.constant(((-1., .5), (.5, 1.))),
'c':
@@ -1197,13 +1022,12 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
+ })
+ crossed_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
- crossed_var.eval())
+ self.assertAllClose(
+ ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
self.assertAllClose(((0.,), (0.,)), predictions.eval())
sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
# Expected ids after cross = (1, 0, 1, 3, 4, 2)
@@ -1211,9 +1035,9 @@ class CrossedColumnTest(test.TestCase):
sess.run(bias.assign((.1,)))
self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
- def test_keras_linear_model_with_weights(self):
+ def test_linear_model_with_weights(self):
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ class _TestColumnWithWeights(fc.CategoricalColumn):
"""Produces sparse IDs and sparse weights."""
@property
@@ -1221,38 +1045,36 @@ class CrossedColumnTest(test.TestCase):
return 'test_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {
- self.name:
- parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name):
- parsing_ops.VarLenFeature(dtypes.float32),
- }
+ self.name: parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
+ dtypes.float32),
+ }
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 5
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
+ def transform_feature(self, transformation_cache, state_manager):
+ return (transformation_cache.get(self.name, state_manager),
+ transformation_cache.get('{}_weights'.format(self.name),
+ state_manager))
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
"""Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
+ ids_and_weights = transformation_cache.get(self, state_manager)
+ return fc.CategoricalColumn.IdWeightPair(
id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
with self.assertRaisesRegexp(
ValueError,
'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ model({
t.name:
sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -1268,37 +1090,7 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
-
-
-def get_linear_model_bias(name='linear_model'):
- with variable_scope.variable_scope(name, reuse=True):
- return variable_scope.get_variable('bias_weights')
-
-
-def get_linear_model_column_var(column, name='linear_model'):
- return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
- name + '/' + column.name)[0]
-
-
-def get_keras_linear_model_predictions(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- keras_linear_model = _LinearModel(
- feature_columns,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- name='linear_model')
- retval = keras_linear_model(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(keras_linear_model.cols_to_vars())
- return retval
+ })
class LinearModelTest(test.TestCase):
@@ -1306,56 +1098,50 @@ class LinearModelTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.linear_model(features={}, feature_columns=[])
+ fc.LinearModel(feature_columns=[])
def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+ with self.assertRaisesRegexp(ValueError, 'must be a FeatureColumn'):
+ fc.LinearModel(feature_columns='NotSupported')
def test_should_be_dense_or_categorical_column(self):
- class NotSupportedColumn(fc_old._FeatureColumn):
+ class NotSupportedColumn(fc.FeatureColumn):
@property
def name(self):
return 'NotSupportedColumn'
- def _transform_feature(self, cache):
+ def transform_feature(self, transformation_cache, state_manager):
pass
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
pass
with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- fc.linear_model(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+ ValueError, 'must be either a DenseColumn or CategoricalColumn'):
+ fc.LinearModel(feature_columns=[NotSupportedColumn()])
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ fc.LinearModel(feature_columns={'a': fc.numeric_column('a')})
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ fc.LinearModel(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
def test_dense_bias(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
sess.run(price_var.assign([[10.]]))
@@ -1363,16 +1149,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[15.], [55.]], predictions.eval())
def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
@@ -1381,18 +1167,17 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([wire_cast, price])
+ predictions = model(features)
+ price_var, wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
@@ -1402,38 +1187,36 @@ class LinearModelTest(test.TestCase):
def test_dense_and_sparse_column(self):
"""When the column is both dense and sparse, uses sparse tensors."""
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+ class _DenseAndSparseColumn(fc.DenseColumn, fc.CategoricalColumn):
@property
def name(self):
return 'dense_and_sparse_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
+ def transform_feature(self, transformation_cache, state_manager):
+ return transformation_cache.get(self.name, state_manager)
@property
- def _variable_shape(self):
+ def variable_shape(self):
raise ValueError('Should not use this method.')
- def _get_dense_tensor(self, inputs, weight_collections=None,
- trainable=None):
+ def get_dense_tensor(self, transformation_cache, state_manager):
raise ValueError('Should not use this method.')
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 4
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
sp_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=[2, 0, 3],
dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+ return fc.CategoricalColumn.IdWeightPair(sp_tensor, None)
dense_and_sparse_column = _DenseAndSparseColumn()
with ops.Graph().as_default():
@@ -1442,10 +1225,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {dense_and_sparse_column.name: sp_tensor}
- predictions = fc.linear_model(features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
+ model = fc.LinearModel([dense_and_sparse_column])
+ predictions = model(features)
+ dense_and_sparse_column_var, bias = model.variables
with _initialized_session() as sess:
sess.run(dense_and_sparse_column_var.assign(
[[10.], [100.], [1000.], [10000.]]))
@@ -1453,12 +1235,12 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((1, 3)), price_var.eval())
@@ -1468,16 +1250,16 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], units=3)
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
@@ -1490,18 +1272,19 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price])
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose([[0.], [0.]], price_var.eval())
sess.run(price_var.assign([[10.], [100.]]))
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = array_ops.sparse_placeholder(dtypes.string)
wire_value = sparse_tensor.SparseTensorValue(
@@ -1509,8 +1292,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
dense_shape=[2, 2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
self.assertAllClose(
@@ -1522,25 +1306,24 @@ class LinearModelTest(test.TestCase):
predictions.eval(feed_dict={wire_tensor: wire_value}))
def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], sparse_combiner='mean')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [5010.]], predictions.eval())
def test_sparse_combiner_with_negative_weights(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
@@ -1551,22 +1334,21 @@ class LinearModelTest(test.TestCase):
'wire_cast': wire_tensor,
'weights': constant_op.constant([[1., 1., -1.0]])
}
- predictions = fc.linear_model(
- features, [wire_cast_weights], sparse_combiner='sum')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast_weights], sparse_combiner='sum')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [-9985.]], predictions.eval())
def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((2, 3)), price_var.eval())
@@ -1576,21 +1358,22 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price_var.eval())
@@ -1599,17 +1382,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- predictions = fc.linear_model(features, [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
+ price1_var, price2_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price1_var.eval())
@@ -1620,115 +1402,55 @@ class LinearModelTest(test.TestCase):
sess.run(bias.assign([7.]))
self.assertAllClose([[3217.], [4657.]], predictions.eval())
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- fc.linear_model(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ model(features)
+ price_var, bias = model.variables
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertIn(bias, trainable_vars)
self.assertIn(price_var, trainable_vars)
def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast])
+ model = fc.LinearModel([wire_cast])
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ wire_cast_var, bias = model.variables
self.assertIn(bias, trainable_vars)
self.assertIn(wire_cast_var, trainable_vars)
def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], trainable=False)
+ model = fc.LinearModel([price], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast], trainable=False)
+ model = fc.LinearModel([wire_cast], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1736,15 +1458,15 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([price_a, wire_cast, price_b])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1752,17 +1474,45 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([wire_cast, price_b, price_a])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
+ def test_variable_names(self):
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+
+ with ops.Graph().as_default():
+ model = fc.LinearModel(all_cols)
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ model(features)
+ variable_names = [var.name for var in model.variables]
+ self.assertItemsEqual([
+ 'linear_model/dense_feature_bucketized/weights:0',
+ 'linear_model/price1/weights:0',
+ 'linear_model/sparse_feature_embedding/embedding_weights:0',
+ 'linear_model/sparse_feature_embedding/weights:0',
+ 'linear_model/bias_weights:0',
+ ], variable_names)
+
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -1771,12 +1521,13 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ model(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -1786,17 +1537,19 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2, price3])
+ model = fc.LinearModel([price1, price2, price3])
+ model(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'must have the same size and shape'):
@@ -1804,14 +1557,15 @@ class LinearModelTest(test.TestCase):
predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
sess.run(
predictions,
@@ -1821,14 +1575,14 @@ class LinearModelTest(test.TestCase):
})
def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
input_fn = numpy_io.numpy_input_fn(
@@ -1839,15 +1593,14 @@ class LinearModelTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
# self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1859,14 +1612,14 @@ class LinearModelTest(test.TestCase):
coord.join(threads)
def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
# Provides 1-dim tensor and dense tensor.
@@ -1880,11 +1633,10 @@ class LinearModelTest(test.TestCase):
self.assertEqual(1, features['price'].shape.ndims)
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1893,16 +1645,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
# Provides 1-dim tensor and dense tensor.
@@ -1921,10 +1673,9 @@ class LinearModelTest(test.TestCase):
dense_shape=(2,))
country_data = np.array(['US', 'CA'])
- net = fc.linear_model(features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ model = fc.LinearModel([price_buckets, body_style, country])
+ net = model(features)
+ body_style_var, _, price_buckets_var, bias = model.variables
with _initialized_session() as sess:
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1940,7 +1691,7 @@ class LinearModelTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -1948,29 +1699,31 @@ class LinearModelTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ net = model(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
sess.run(net, feed_dict={features['price']: np.array(1)})
def test_multiple_linear_models(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features1 = {'price': [[1.], [5.]]}
features2 = {'price': [[2.], [10.]]}
- predictions1 = fc.linear_model(features1, [price])
- predictions2 = fc.linear_model(features2, [price])
- bias1 = get_linear_model_bias(name='linear_model')
- bias2 = get_linear_model_bias(name='linear_model_1')
- price_var1 = get_linear_model_column_var(price, name='linear_model')
- price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ model1 = fc.LinearModel([price])
+ model2 = fc.LinearModel([price])
+ predictions1 = model1(features1)
+ predictions2 = model2(features2)
+ price_var1, bias1 = model1.variables
+ price_var2, bias2 = model2.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias1.eval())
sess.run(price_var1.assign([[10.]]))
@@ -1982,664 +1735,6 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[25.], [105.]], predictions2.eval())
-class _LinearModelTest(test.TestCase):
-
- def test_raises_if_empty_feature_columns(self):
- with self.assertRaisesRegexp(ValueError,
- 'feature_columns must not be empty'):
- get_keras_linear_model_predictions(features={}, feature_columns=[])
-
- def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns='NotSupported')
-
- def test_should_be_dense_or_categorical_column(self):
-
- class NotSupportedColumn(fc_old._FeatureColumn):
-
- @property
- def name(self):
- return 'NotSupportedColumn'
-
- def _transform_feature(self, cache):
- pass
-
- @property
- def _parse_example_spec(self):
- pass
-
- with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
-
- def test_does_not_support_dict_columns(self):
- with self.assertRaisesRegexp(
- ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
-
- def test_raises_if_duplicate_name(self):
- with self.assertRaisesRegexp(
- ValueError, 'Duplicate feature column name found for columns'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
-
- def test_dense_bias(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- sess.run(price_var.assign([[10.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[15.], [55.]], predictions.eval())
-
- def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features,
- [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[1015.], [10065.]], predictions.eval())
-
- def test_dense_and_sparse_column(self):
- """When the column is both dense and sparse, uses sparse tensors."""
-
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
-
- @property
- def name(self):
- return 'dense_and_sparse_column'
-
- @property
- def _parse_example_spec(self):
- return {self.name: parsing_ops.VarLenFeature(self.dtype)}
-
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
-
- @property
- def _variable_shape(self):
- raise ValueError('Should not use this method.')
-
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None):
- raise ValueError('Should not use this method.')
-
- @property
- def _num_buckets(self):
- return 4
-
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
- sp_tensor = sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 0], [1, 1]],
- values=[2, 0, 3],
- dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
-
- dense_and_sparse_column = _DenseAndSparseColumn()
- with ops.Graph().as_default():
- sp_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {dense_and_sparse_column.name: sp_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
- with _initialized_session() as sess:
- sess.run(
- dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
- [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((1, 3)), price_var.eval())
- sess.run(price_var.assign([[10., 100., 1000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
- predictions.eval())
-
- def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
- sess.run(
- wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
- [1000., 1100.,
- 1200.], [10000., 11000., 12000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
- predictions.eval())
-
- def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([[0.], [0.]], price_var.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = array_ops.sparse_placeholder(dtypes.string)
- wire_value = sparse_tensor.SparseTensorValue(
- values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
- indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
- dense_shape=[2, 2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
- self.assertAllClose(
- np.zeros((2, 1)),
- predictions.eval(feed_dict={wire_tensor: wire_value}))
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- self.assertAllClose(
- [[1010.], [11000.]],
- predictions.eval(feed_dict={wire_tensor: wire_value}))
-
- def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [5010.]], predictions.eval())
-
- def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((2, 3)), price_var.eval())
- sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
- sess.run(bias.assign([2., 3., 4.]))
- self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
- predictions.eval())
-
- def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- with self.assertRaisesRegexp(
- Exception,
- r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- get_keras_linear_model_predictions(features, [price])
-
- def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
- with ops.Graph().as_default():
- features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price1_var.eval())
- self.assertAllClose([[0.]], price2_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price1_var.assign([[10.], [100.]]))
- sess.run(price2_var.assign([[1000.]]))
- sess.run(bias.assign([7.]))
- self.assertAllClose([[3217.], [4657.]], predictions.eval())
-
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(
- features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
- def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertIn(bias, trainable_vars)
- self.assertIn(price_var, trainable_vars)
-
- def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast])
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, trainable_vars)
- self.assertIn(wire_cast_var, trainable_vars)
-
- def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': [[1.], [5.], [7.]], # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2])
-
- def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]], # batchsize = 2
- 'price3': [[3.], [4.], [5.]] # batchsize = 3
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2, price3])
-
- def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- with self.assertRaisesRegexp(errors.OpError,
- 'must have the same size and shape'):
- sess.run(
- predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
-
- def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- sess.run(
- predictions,
- feed_dict={
- features['price1']: [[1.], [5.]],
- features['price2']: [[1.], [5.]],
- })
-
- def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- input_fn = numpy_io.numpy_input_fn(
- x={
- 'price': np.array([-1., 2., 13., 104.]),
- 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
- },
- batch_size=2,
- shuffle=False)
- features = input_fn()
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- # self.assertEqual(1 + 3 + 5, net.shape[1])
- with _initialized_session() as sess:
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
-
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
-
- coord.request_stop()
- coord.join(threads)
-
- def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price':
- constant_op.constant([
- -1.,
- 12.,
- ]),
- 'body-style':
- sparse_tensor.SparseTensor(
- indices=((0,), (1,)),
- values=('sedan', 'hardtop'),
- dense_shape=(2,)),
- }
- self.assertEqual(1, features['price'].shape.ndims)
- self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
-
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
-
- def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
- 'country', vocabulary_list=['US', 'JP', 'CA'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- 'body-style': array_ops.sparse_placeholder(dtypes.string),
- 'country': array_ops.placeholder(dtypes.string),
- }
- self.assertIsNone(features['price'].shape.ndims)
- self.assertIsNone(features['body-style'].get_shape().ndims)
-
- price_data = np.array([-1., 12.])
- body_style_data = sparse_tensor.SparseTensorValue(
- indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
- country_data = np.array(['US', 'CA'])
-
- net = get_keras_linear_model_predictions(
- features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
- with _initialized_session() as sess:
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
- sess.run(
- net,
- feed_dict={
- features['price']: price_data,
- features['body-style']: body_style_data,
- features['country']: country_data
- }))
-
- def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
- features = {
- 'price': constant_op.constant(0),
- }
- self.assertEqual(0, features['price'].shape.ndims)
-
- # Static rank 0 should fail
- with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- get_keras_linear_model_predictions(features, [price])
-
- # Dynamic rank 0 should fail
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- }
- net = get_keras_linear_model_predictions(features, [price])
- self.assertEqual(1, net.shape[1])
- with _initialized_session() as sess:
- with self.assertRaisesOpError('Feature .* cannot have rank 0'):
- sess.run(net, feed_dict={features['price']: np.array(1)})
-
-
class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -3739,47 +2834,22 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
- key='wire',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size,
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
+ wire_column = fc.categorical_column_with_vocabulary_file(
key='wire',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size,
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
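Note on the migration above: the functional fc.linear_model entry point and the get_linear_model_bias / get_linear_model_column_var collection helpers give way to the object-style fc.LinearModel layer, whose weights are read straight off model.variables. A minimal sketch of the new pattern, assuming fc is the v2 feature_column module imported by this test file; the column and feature values below are illustrative placeholders, not taken from the diff:

    # Illustrative only; mirrors the calls shown in the hunk above.
    from tensorflow.python.feature_column import feature_column_v2 as fc
    from tensorflow.python.framework import sparse_tensor

    wire = fc.categorical_column_with_vocabulary_list(
        key='wire', vocabulary_list=('omar', 'stringer', 'marlo'),
        num_oov_buckets=1)

    model = fc.LinearModel((wire,))          # build the layer once
    predictions = model({                    # then call it on a features dict
        'wire': sparse_tensor.SparseTensorValue(
            indices=((0, 0), (1, 0)), values=('marlo', 'omar'),
            dense_shape=(2, 1))
    })
    wire_var, bias = model.variables         # no collection lookups needed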
@@ -4140,45 +3210,21 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'),
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
+ wire_column = fc.categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'),
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4398,39 +3444,18 @@ class IdentityCategoricalColumnTest(test.TestCase):
}))
def test_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
- self.assertEqual(3, column.num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] = 1
- # weight_var[2] + weight_var[1] = 3+2 = 5
- self.assertAllClose(((1.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
self.assertEqual(3, column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -4656,27 +3681,8 @@ class IndicatorColumnTest(test.TestCase):
self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
def test_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
- with ops.Graph().as_default():
- features = {
- 'animal':
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
- }
-
- predictions = fc.linear_model(features, [animal])
- weight_var = get_linear_model_column_var(animal)
- with _initialized_session():
- # All should be zero-initialized.
- self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
- self.assertAllClose([[0.]], predictions.eval())
- weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
- self.assertAllClose([[2. + 3.]], predictions.eval())
-
- def test_keras_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
@@ -4684,8 +3690,9 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- predictions = get_keras_linear_model_predictions(features, [animal])
- weight_var = get_linear_model_column_var(animal)
+ model = fc.LinearModel([animal])
+ predictions = model(features)
+ weight_var, _ = model.variables
with _initialized_session():
# All should be zero-initialized.
self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
@@ -5137,17 +4144,16 @@ class EmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
- categorical_column.name: sparse_input
- }, (embedding_column,))
+ model = fc.LinearModel((embedding_column,))
+ predictions = model({categorical_column.name: sparse_input})
expected_var_names = (
'linear_model/bias_weights:0',
'linear_model/aaa_embedding/weights:0',
@@ -5189,82 +4195,6 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 4
- vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(batch_size, 5))
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
- categorical_column,
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column.name: sparse_input
- }, (embedding_column,))
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_embedding/weights:0',
- 'linear_model/aaa_embedding/embedding_weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_embedding/embedding_weights:0']
- linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # example 2, ids [], embedding[2] = [0, 0]
- # example 3, ids [1], embedding[3] = [3, 5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
- self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
-
def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
@@ -5765,27 +4695,31 @@ class SharedEmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
+ model = fc.LinearModel(
+ (embedding_column_a, embedding_column_b),
+ shared_state_manager=fc.SharedEmbeddingStateManager())
+ predictions = model({
categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
+ categorical_column_b.name: input_b
+ })
+
# Linear weights do not follow the column name. But this is a rare use
# case, and fixing it would add too much complexity to the code.
expected_var_names = (
'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ 'linear_model/aaa_shared_embedding/weights:0',
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0',
+ 'linear_model/bbb_shared_embedding/weights:0',
)
self.assertItemsEqual(
expected_var_names,
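The renaming above reflects the new ownership model for shared embeddings: the single embedding table is created by a SharedEmbeddingStateManager (hence shared_embedding_state_manager/aaa_bbb_shared_embedding), while each column keeps its own linear weights under linear_model/<key>_shared_embedding. A rough sketch using only APIs named in this hunk; the import path is an assumption and the input arrays mirror the removed test data:

    from tensorflow.python.feature_column import feature_column_v2 as fc
    import numpy as np

    input_a = np.array([[2, -1, -1],    # example 0, ids [2]
                        [0, 1, -1]])    # example 1, ids [0, 1]
    input_b = np.array([[0, -1, -1],    # example 0, ids [0]
                        [-1, -1, -1]])  # example 1, ids []

    categorical_a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    categorical_b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
    embedding_a, embedding_b = fc.shared_embedding_columns_v2(
        [categorical_a, categorical_b], dimension=2)

    model = fc.LinearModel(
        (embedding_a, embedding_b),
        shared_state_manager=fc.SharedEmbeddingStateManager())
    predictions = model({'aaa': input_a, 'bbb': input_b})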
@@ -5797,102 +4731,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
bias = trainable_vars['linear_model/bias_weights:0']
embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0']
linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ 'linear_model/aaa_shared_embedding/weights:0']
linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights_a.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
- linear_weights_b.assign(((3.,), (5.,))).eval()
- # example 0, ids [0], embedding[0] = [1, 2]
- # example 1, ids [], embedding[1] = [0, 0]
- # sum(embeddings * linear_weights)
- # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
- self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
-
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 2
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
- # Linear weights do not follow the column name. But this is a rare use
- # case, and fixing it would add too much complexity to the code.
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
- linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
- linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ 'linear_model/bbb_shared_embedding/weights:0']
with _initialized_session():
# Predictions with all zero weights.
self.assertAllClose(np.zeros((1,)), bias.eval())
@@ -6291,13 +5134,14 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2)),
weight_tensor.eval())
- def test_keras_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6308,9 +5152,8 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(.5, 1., .1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -6321,15 +5164,16 @@ class WeightedCategoricalColumnTest(test.TestCase):
# = 3*1 + 2*.1 = 3+.2 = 3.2
self.assertAllClose(((.5,), (3.2,)), predictions.eval())
- def test_keras_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError,
r'Dimensions.*are not compatible'):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6340,122 +5184,23 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (0, 1), (1, 0), (1, 1)),
values=(.5, 11., 1., .1),
dense_shape=(2, 2))
- }, (column,))
-
- def test_keras_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
- # Disabling the constant folding optimizer here since it changes the
- # error message differently on CPU and GPU.
- config = config_pb2.ConfigProto()
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- with _initialized_session(config):
- with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
- predictions.eval()
+ })
- def test_keras_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,), sparse_combiner='mean')
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2)),
- 'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(.5, 1., .1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, r'Dimensions.*are not compatible'):
- fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (0, 1), (1, 0), (1, 1)),
- values=(.5, 11., 1., .1),
- dense_shape=(2, 2))
- }, (column,))
-
- def test_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
+ 'values': ((.5,), (1.,))
+ })
# Disabling the constant folding optimizer here since it changes the
# error message differently on CPU and GPU.
config = config_pb2.ConfigProto()
@@ -6466,20 +5211,21 @@ class WeightedCategoricalColumnTest(test.TestCase):
predictions.eval()
def test_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
+ model = fc.LinearModel((column,))
+ predictions = model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index cd0b03be43..6673bc5561 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,8 +24,8 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
-import os
import math
+import os
import random
import re
import tempfile
@@ -402,11 +402,14 @@ def with_c_shapes(cls):
return cls
-def enable_cond_v2(fn):
- """Decorator for enabling CondV2 on a test.
+def enable_control_flow_v2(fn):
+ """Decorator for enabling CondV2 and WhileV2 on a test.
- Note this enables using CondV2 after running the test class's setup/teardown
- methods.
+ Note this enables using CondV2 and WhileV2 after running the test class's
+ setup/teardown methods.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
Args:
fn: the function to be wrapped
@@ -416,21 +419,56 @@ def enable_cond_v2(fn):
"""
def wrapper(*args, **kwargs):
- prev_value = control_flow_ops.ENABLE_COND_V2
+ enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
+ enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
control_flow_ops.ENABLE_COND_V2 = True
+ control_flow_ops.ENABLE_WHILE_V2 = True
try:
fn(*args, **kwargs)
finally:
- control_flow_ops.ENABLE_COND_V2 = prev_value
+ control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
+ control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
return wrapper
-def with_cond_v2(cls):
- """Adds methods that call original methods but with CondV2 enabled.
+def with_control_flow_v2(cls):
+ """Adds methods that call original methods with WhileV2 and CondV2 enabled.
- Note this enables CondV2 in new methods after running the test class's
- setup method.
+ Note this enables CondV2 and WhileV2 in new methods after running the test
+ class's setup method.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
+
+ If a test function has _disable_control_flow_v2 attr set to True (using the
+ @disable_control_flow_v2 decorator), the v2 function is not generated for it.
+
+ Example:
+
+ @test_util.with_control_flow_v2
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ @test_util.disable_control_flow_v2("b/xyzabc")
+ def testDisabledForV2(self):
+ ...
+
+ Generated class:
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ def testEnabledForV2WithControlFlowV2(self):
+ // Enable V2 flags.
+ testEnabledForV2(self)
+ // Restore V2 flags.
+
+ def testDisabledForV2(self):
+ ...
Args:
cls: class to decorate
@@ -438,15 +476,33 @@ def with_cond_v2(cls):
Returns:
cls with new test methods added
"""
- if control_flow_ops.ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
return cls
for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCondV2", enable_cond_v2(value))
+ if (callable(value) and name.startswith("test") and
+ not getattr(value, "_disable_control_flow_v2", False)):
+ setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
return cls
+def disable_control_flow_v2(unused_msg):
+ """Decorator for a function in a with_control_flow_v2 enabled test class.
+
+ Blocks the function from being run with v2 control flow ops.
+
+ Args:
+ unused_msg: Reason for disabling.
+
+ Returns:
+ The wrapped function with _disable_control_flow_v2 attr set to True.
+ """
+ def wrapper(func):
+ func._disable_control_flow_v2 = True
+ return func
+ return wrapper
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 4589c821e5..584facc859 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -1511,12 +1511,8 @@ def batch_dot(x, y, axes=None):
out = math_ops.reduce_sum(
math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else:
- if axes is not None:
- adj_x = None if axes[0] == ndim(x) - 1 else True
- adj_y = True if axes[1] == ndim(y) - 1 else None
- else:
- adj_x = None
- adj_y = None
+ adj_x = None if axes[0] == ndim(x) - 1 else True
+ adj_y = True if axes[1] == ndim(y) - 1 else None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if diff:
if x_ndim > y_ndim:
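The dropped None-check is justified because axes has already been given a default earlier in batch_dot before this branch runs (this is an inference from the surrounding code, not stated in the hunk). A standalone worked example of the retained adjoint logic, with made-up shapes:

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.random.rand(5, 2, 3), dtype=tf.float32)  # (batch, 2, 3)
    y = tf.constant(np.random.rand(5, 4, 3), dtype=tf.float32)  # (batch, 4, 3)
    axes = (2, 2)  # contract the last axis of x with the last axis of y

    adj_x = None if axes[0] == x.shape.ndims - 1 else True  # last axis of x: no adjoint
    adj_y = True if axes[1] == y.shape.ndims - 1 else None  # last axis of y: adjoint y
    out = tf.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)  # shape (5, 2, 4)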
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index e98b131ae6..a75ce30d31 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections as collections_lib
import enum # pylint: disable=g-bad-import-order
+import functools
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
@@ -160,9 +161,13 @@ class Layer(checkpointable.CheckpointableBase):
self._trainable_weights = []
self._non_trainable_weights = []
self._updates = []
- # When executing eagerly, _losses is a list of zero-argument lambdas which
- # return tensors. When using graph execution, _losses is a list of ops.
+ # A list of zero-argument lambdas which return Tensors, used for variable
+ # regularizers.
+ self._callable_losses = []
+ # A list of Tensors containing activity regularizers and losses manually
+ # added through `add_loss`. Empty when executing eagerly.
self._losses = []
+ self._in_call = False # Flag for error checking in add_loss
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
@@ -359,20 +364,20 @@ class Layer(checkpointable.CheckpointableBase):
def losses(self):
"""Losses which are associated with this `Layer`.
- Note that when executing eagerly, getting this property evaluates
- regularizers. When using graph execution, variable regularization ops have
- already been created and are simply returned here.
+ Variable regularization tensors are created when this property is accessed,
+ so it is eager safe: accessing `losses` under a `tf.GradientTape` will
+ propagate gradients back to the corresponding variables.
Returns:
A list of tensors.
"""
- if context.executing_eagerly():
- # _losses may only contain variable regularization losses when executing
- # eagerly, and they have been saved as lambdas to be executed when
- # requested.
- return [regularizer() for regularizer in self._losses]
- else:
- return self._losses
+ collected_losses = []
+ collected_losses.extend(self._losses)
+ for regularizer in self._callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ collected_losses.append(loss_tensor)
+ return collected_losses
@doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
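Because the lambdas stored for variable regularizers are re-evaluated every time losses is read, the property now works under a gradient tape in eager mode. A minimal sketch, assuming a TF 1.x-style program with eager execution enabled; the layer, regularizer strength, and input are placeholders:

    import tensorflow as tf

    tf.enable_eager_execution()

    layer = tf.keras.layers.Dense(
        1, kernel_regularizer=tf.keras.regularizers.l2(0.01))
    x = tf.ones([4, 3])

    with tf.GradientTape() as tape:
      y = layer(x)
      # Reading .losses re-runs the stored regularizer lambdas, so the loss
      # tensor is created under the tape and gradients reach the kernel.
      reg_loss = tf.add_n(layer.losses)

    grads = tape.gradient(reg_loss, layer.trainable_variables)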
@@ -393,7 +398,9 @@ class Layer(checkpointable.CheckpointableBase):
from `Layer.call()`).
Arguments:
- losses: Loss tensor, or list/tuple of tensors.
+ losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
+ may also be zero-argument callables which create a loss tensor. Only
+ callable losses are supported when executing eagerly.
inputs: If anything other than None is passed, it signals the losses
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
@@ -403,29 +410,45 @@ class Layer(checkpointable.CheckpointableBase):
(e.g. weight regularization losses).
Raises:
- RuntimeError: If called in Eager mode.
+ RuntimeError: If called in Eager mode with a `Tensor` rather than a
+ callable, or if `inputs` is not None.
"""
- if context.executing_eagerly():
- # TODO(fchollet): it should be possible (and highly desirable) to support
- # `add_loss` in eager mode. This allows great convenience and flexibility
- # in defining custom losses on the fly (e.g. in VAEs).
- # Simply appending the loss value to `self._losses`
- # is the correct behavior.
- # The only caveat is that we need to force the user to only call
- # `add_loss` from inside a model or Layer's `call` method
- # (otherwise the loss computation cannot be backproped through).
- raise RuntimeError('Layer.add_loss not supported in Eager mode.')
-
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly:
+ if inputs is not None:
+ raise RuntimeError(
+ 'Activity regularization (via the "inputs" argument to '
+ 'Layer.add_loss) is not supported when executing eagerly. Consider '
+ 'returning activity regularization losses from a Model\'s call() '
+ 'method.')
+ if getattr(self, '_in_call', False):
+ # TODO(psv): Support activity regularization and a way to reset losses.
+ raise RuntimeError(
+ 'Adding losses inside a Layer\'s call() method is not currently '
+ 'supported when executing eagerly. Please file a feature request '
+ 'if you need this limitation lifted.')
losses = generic_utils.to_list(losses)
- losses = [ops.convert_to_tensor(loss, dtype=backend.floatx())
- if not tensor_util.is_tensor(loss) else loss for loss in losses]
- self._losses += losses
- if inputs is None:
- for loss in losses:
- loss._unconditional_loss = True # pylint: disable=protected-access
- else:
- for loss in losses:
- loss._unconditional_loss = False # pylint: disable=protected-access
+
+ def _tag_unconditional(loss):
+ if callable(loss):
+ loss = loss()
+ if loss is None:
+ return None # Will be filtered out when computing the .losses property
+ if not tensor_util.is_tensor(loss):
+ loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+ loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
+ return loss
+
+ for loss in losses:
+ if callable(loss):
+ self._callable_losses.append(
+ functools.partial(_tag_unconditional, loss))
+ else:
+ if executing_eagerly:
+ raise RuntimeError(
+ 'Layer.add_loss only supported for zero-argument lambdas when '
+ 'executing eagerly.')
+ self._losses.append(_tag_unconditional(loss))
def get_losses_for(self, inputs):
"""Retrieves losses relevant to a specific set of inputs.
@@ -599,56 +622,20 @@ class Layer(checkpointable.CheckpointableBase):
return variable
def _handle_weight_regularization(self, name, variable, regularizer):
- # `init_graph` should point to the graph in which variable initialization
- # will occur; it should be None if and only if initialization will take
- # place in the eager context.
- init_graph = None
- if not context.executing_eagerly():
- default_graph = ops.get_default_graph()
- if default_graph.building_function:
- with ops.init_scope():
- # Retrieve the variables from the graph into which variables
- # will be lifted; if initialization ops will be lifted into
- # the eager context, then there is nothing to retrieve, since variable
- # collections are not supported when eager execution is enabled.
- if not context.executing_eagerly():
- init_graph = ops.get_default_graph()
- else:
- # Initialization ops will not be lifted out of the default graph.
- init_graph = default_graph
-
- if init_graph is not None: # pylint: disable=protected-access
- # The variable was created and initialized in a graph.
- if regularizer:
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
- elif regularizer: # initialization took place in an eager context
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request '
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when executing
- # eagerly.
- #
- # TODO(akshayka): Do the same for graphs as well, so that losses
- # collected in a while_loop can be run outside its control flow
- # context and so that losses won't be swallowed up by graph functions
- # (i.e., `.losses()` should always create regularizers).
- self._losses.append(lambda: regularizer(variable))
+ """Create lambdas which compute regularization losses."""
+
+ def _loss_for_variable(v):
+ """Creates a regularization loss `Tensor` for variable `v`."""
+ with ops.colocate_with(v):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ return regularization
+
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ self.add_loss(functools.partial(_loss_for_variable, v))
+ else:
+ self.add_loss(functools.partial(_loss_for_variable, variable))
def _handle_activity_regularization(self, inputs, outputs):
# Apply activity regularization.
@@ -766,7 +753,9 @@ class Layer(checkpointable.CheckpointableBase):
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
+ self._in_call = True
outputs = self.call(inputs, *args, **kwargs)
+ self._in_call = False
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None (layer: ' +
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 46bffd7068..5091cac836 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -851,7 +851,8 @@ class Model(Network):
# able to clone a Dataset on multiple workers we can remove this lambda.
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
- K.get_session().run(iterator.initializer)
+ with self._distribution_strategy.scope():
+ K.get_session().run(iterator.initializer)
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 1b64f904d5..a6470458d2 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -112,100 +112,99 @@ def fit_loop(
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- # Create a train function that is composed of all the parameters above.
- distributed_train_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_train_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [1]
- else:
- ins = dataset_inputs + dataset_targets
+ # Create a train function that is composed of all the parameters above.
+ distributed_train_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
+
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [1]
+ else:
+ ins = dataset_inputs + dataset_targets
- do_validation = False
- if validation_steps:
- do_validation = True
+ do_validation = False
+ if validation_steps:
+ do_validation = True
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- callbacks = cbks.configure_callbacks(
- callbacks,
- model,
- do_validation=do_validation,
- val_inputs=None,
- val_targets=None,
- epochs=epochs,
- steps_per_epoch=steps_per_epoch,
- verbose=verbose)
- out_labels = model.metrics_names or []
- callbacks.on_train_begin()
-
- assert steps_per_epoch is not None
-
- for epoch in range(initial_epoch, epochs):
- # Reset stateful metrics
- for m in model.stateful_metric_functions:
- m.reset_states()
- callbacks.on_epoch_begin(epoch)
- epoch_logs = {}
- for step_index in range(steps_per_epoch):
- batch_logs = {'batch': step_index, 'size': 1}
- callbacks.on_batch_begin(step_index, batch_logs)
- try:
- outs = distributed_train_function(ins)
- except errors.OutOfRangeError:
- logging.warning('Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset '
- 'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
- steps_per_epoch * epochs)
- break
-
- if not isinstance(outs, list):
- outs = [outs]
-
- outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
- out_labels,
- model.stateful_metric_names, outs)
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- callbacks.on_batch_end(step_index, batch_logs)
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=do_validation,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ out_labels = model.metrics_names or []
+ callbacks.on_train_begin()
+
+ assert steps_per_epoch is not None
+
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' %
+ steps_per_epoch * epochs)
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
+ out_labels,
+ model.stateful_metric_names,
+ outs)
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_iterator,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
break
- if do_validation:
- val_outs = test_loop(
- model,
- val_iterator,
- steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
+ callbacks.on_train_end()
- callbacks.on_epoch_end(epoch, epoch_logs)
- if callbacks.model.stop_training:
- break
- callbacks.on_train_end()
-
- # Copy the weights back from the replicated model to the original model.
- with current_strategy.scope():
+ # Copy the weights back from the replicated model to the original model.
updated_weights = current_strategy.unwrap(
model._grouped_model)[0].get_weights()
model.set_weights(updated_weights)
- return model.history
+ return model.history
def _experimental_fit_loop(
@@ -427,66 +426,65 @@ def test_loop(model, iterator, verbose=0, steps=None):
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- distributed_test_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_test_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [0]
- else:
- ins = dataset_inputs + dataset_targets
+ distributed_test_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
- for m in model.stateful_metric_functions:
- m.reset_states()
- stateful_metric_indices = [
- i for i, name in enumerate(model.metrics_names)
- if str(name) in model.stateful_metric_names
- ]
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [0]
+ else:
+ ins = dataset_inputs + dataset_targets
- outs = []
- if verbose == 1:
- progbar = Progbar(target=steps)
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names
+ ]
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ outs = []
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- assert steps is not None
- for step in range(steps):
- batch_outs = distributed_test_function(ins)
- batch_outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, model.metrics_names,
- model.stateful_metric_names, batch_outs)
- if isinstance(batch_outs, list):
- if step == 0:
- outs = [0.] * len(batch_outs)
- for i, batch_out in enumerate(batch_outs):
- if i in stateful_metric_indices:
- outs[i] = batch_out
- else:
- outs[i] += batch_out
- else:
- if step == 0:
- outs.append(0.)
- outs[0] += batch_outs
- if verbose >= 1:
- progbar.update(step + 1)
- for i in range(len(outs)):
- if i not in stateful_metric_indices:
- outs[i] /= steps
+ assert steps is not None
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, model.metrics_names,
+ model.stateful_metric_names, batch_outs)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ outs = [0.] * len(batch_outs)
+ for i, batch_out in enumerate(batch_outs):
+ if i in stateful_metric_indices:
+ outs[i] = batch_out
+ else:
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ if i not in stateful_metric_indices:
+ outs[i] /= steps
- if len(outs) == 1:
- return outs[0]
- return outs
+ if len(outs) == 1:
+ return outs[0]
+ return outs
def _experimental_test_loop(model, iterator, verbose=0, steps=None):
@@ -647,51 +645,50 @@ def predict_loop(model, iterator, verbose=0, steps=None):
dataset_inputs = distributed_training_utils.flatten_perdevice_values(
current_strategy, inputs)
- distributed_predict_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_predict_function',
- **all_session_args)
+ distributed_predict_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + [0]
- else:
- ins = dataset_inputs
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + [0]
+ else:
+ ins = dataset_inputs
- if verbose == 1:
- progbar = Progbar(target=steps)
+ if verbose == 1:
+ progbar = Progbar(target=steps)
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- if steps is not None:
- # Since we do not know how many samples we will see, we cannot pre-allocate
- # the returned Numpy arrays. Instead, we store one array per batch seen
- # and concatenate them upon returning.
- unconcatenated_outs = []
- for step in range(steps):
- batch_outs = distributed_predict_function(ins)
- if not isinstance(batch_outs, list):
- batch_outs = [batch_outs]
- if step == 0:
- for _ in batch_outs:
- unconcatenated_outs.append([])
- # TODO(anjalisridhar): Should combine the outputs from multiple towers
- # correctly here.
- for i, batch_out in enumerate(batch_outs):
- unconcatenated_outs[i].append(batch_out)
- if verbose >= 1:
- progbar.update(step + 1)
- if len(unconcatenated_outs) == 1:
- return np.concatenate(unconcatenated_outs[0], axis=0)
- return [
- np.concatenate(unconcatenated_outs[i], axis=0)
- for i in range(len(unconcatenated_outs))
- ]
+ if steps is not None:
+ # Since we do not know how many samples we will see, we cannot
+ # pre-allocate the returned Numpy arrays. Instead, we store one array per
+ # batch seen and concatenate them upon returning.
+ unconcatenated_outs = []
+ for step in range(steps):
+ batch_outs = distributed_predict_function(ins)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step == 0:
+ for _ in batch_outs:
+ unconcatenated_outs.append([])
+ # TODO(anjalisridhar): Should combine the outputs from multiple towers
+ # correctly here.
+ for i, batch_out in enumerate(batch_outs):
+ unconcatenated_outs[i].append(batch_out)
+ if verbose >= 1:
+ progbar.update(step + 1)
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
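As the comment in the predict loop notes, the total number of samples is unknown up front, so outputs are collected one array per batch and concatenated only at the end. A standalone sketch of that accumulation, with hypothetical shapes:

import numpy as np

# Pretend the distributed predict function returned one output per step,
# each of shape (batch_size, 4), over 3 steps.
batch_outputs = [np.full((2, 4), step, dtype=np.float32) for step in range(3)]

unconcatenated_outs = [[]]          # one list per model output
for batch_out in batch_outputs:
  unconcatenated_outs[0].append(batch_out)

# Concatenate along the batch axis once all steps have been seen.
result = np.concatenate(unconcatenated_outs[0], axis=0)
print(result.shape)                 # (6, 4)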
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index db7ccb181f..1f5176c4d7 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -192,6 +192,20 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
+ def test_no_loss_in_call(self):
+
+ class HasLoss(keras.layers.Layer):
+
+ def call(self, x):
+ self.add_loss(x)
+ return x
+
+ layer = HasLoss()
+ with self.assertRaises(RuntimeError):
+ layer(1.)
+
+ with ops.Graph().as_default():
+ layer(1.)
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 30be4131a4..54ad74c08b 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
@@ -2427,6 +2428,17 @@ class TestTrainingWithMetrics(test.TestCase):
scores = model.train_on_batch(x, y, sample_weight=w)
self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1)
+ def test_losses_in_defun(self):
+ with context.eager_mode():
+ layer = keras.layers.Dense(1, kernel_regularizer='l1')
+ layer(array_ops.ones([1, 10]))
+
+ @function.defun
+ def get_losses():
+ return layer.losses
+
+ self.assertAllEqual(self.evaluate(layer.losses),
+ self.evaluate(get_losses()))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 280c18ec00..9490746fd9 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1480,7 +1480,7 @@ cuda_py_test(
name = "control_flow_ops_py_test",
# TODO(b/70473603): change this back to "small" once the C API is
# permanently enabled
- size = "medium",
+ size = "large",
srcs = ["control_flow_ops_py_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1512,6 +1512,7 @@ cuda_py_test(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python:while_v2",
],
)
@@ -2358,7 +2359,7 @@ cuda_py_test(
cuda_py_test(
name = "transpose_op_test",
- size = "large",
+ size = "medium",
srcs = ["transpose_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2366,10 +2367,11 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
],
- shard_count = 2,
+ shard_count = 4,
tags = [
"no_gpu",
"no_oss",
+ "optonly", # times out
],
)
@@ -2488,6 +2490,7 @@ cuda_py_test(
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
],
+ shard_count = 2,
tags = [
"optonly", # flaky timeouts unless optimized
],
@@ -2508,7 +2511,7 @@ cuda_py_test(
cuda_py_test(
name = "conv_ops_test",
- size = "large",
+ size = "medium",
srcs = ["conv_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2527,6 +2530,9 @@ cuda_py_test(
"//tensorflow/python:variables",
],
shard_count = 4,
+ tags = [
+ "optonly", # times out
+ ],
)
cuda_py_test(
@@ -2586,7 +2592,7 @@ cuda_py_test(
cuda_py_test(
name = "fft_ops_test",
- size = "large",
+ size = "medium",
srcs = ["fft_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2596,7 +2602,8 @@ cuda_py_test(
"//tensorflow/python:spectral_ops",
"//tensorflow/python:spectral_ops_test_util",
],
- shard_count = 3,
+ shard_count = 4,
+ tags = ["optonly"],
)
cuda_py_test(
@@ -2661,7 +2668,7 @@ cuda_py_test(
cuda_py_test(
name = "scatter_ops_test",
- size = "large", # NOTE: This is not run by default.
+ size = "medium", # NOTE: This is not run by default.
srcs = ["scatter_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2670,11 +2677,13 @@ cuda_py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
],
+ shard_count = 2,
+ tags = ["optonly"],
)
cuda_py_test(
name = "slice_op_test",
- size = "large",
+ size = "medium",
srcs = ["slice_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 083de84775..d91a848e01 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -63,6 +62,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2 # pylint: disable=unused-import
# pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad
# pylint: enable=unused-import
@@ -125,7 +125,7 @@ def isum(s, maximum_iterations=None):
return r_s
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
@@ -332,10 +332,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
+ @test_util.disable_control_flow_v2("b/113294340")
def testCondBool(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296297")
-
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
fn2 = lambda: math_ops.subtract(values, 1)
@@ -366,6 +364,7 @@ class ControlFlowTest(test.TestCase):
"has been marked as not fetchable"):
sess.run(t, feed_dict={x: 3})
+ @test_util.disable_control_flow_v2("Not relevant")
def testFeedable(self):
with self.cached_session() as sess:
c = constant_op.constant(2)
@@ -383,10 +382,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "may not be fed"):
sess.run(r, feed_dict={t: 3})
+ @test_util.disable_control_flow_v2("b/113296180 (IndexedSlices)")
def testCondIndexedSlices(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296180")
-
with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
@@ -401,10 +398,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, val)
self.assertAllEqual(0, ind)
+ @test_util.disable_control_flow_v2("b/113296161 (SparseTensors)")
def testCondSparseTensor(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296161 (SparseTensors)")
-
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
@@ -435,10 +430,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
+ @test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113293074")
-
with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
@@ -510,10 +503,8 @@ class ControlFlowTest(test.TestCase):
result = r.eval()
self.assertAllEqual(12, result)
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testCond_4(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
@@ -587,10 +578,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
self.assertAllEqual([2.0], r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/79881896")
-
with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -629,10 +618,9 @@ class ControlFlowTest(test.TestCase):
merged_op = control_flow_ops.merge([assign_v, orig_v])
self.assertAllEqual([1.0], sess.run(merged_op.output))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondSwitchIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
pred = constant_op.constant(True)
@@ -646,10 +634,9 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondRecvIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
with ops.device(test.gpu_device_name()):
@@ -665,10 +652,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
-
graph = ops.Graph()
with graph.as_default():
x = constant_op.constant(10.0, name="x")
@@ -694,10 +679,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+ @test_util.disable_control_flow_v2(
+ "b/110550782 (gradient w.r.t external variable)")
def testCondGrad_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
-
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
@@ -729,10 +713,8 @@ class ControlFlowTest(test.TestCase):
result = gradients_impl.gradients(z, x)[0]
self.assertEqual(1.0, result.eval())
+ @test_util.disable_control_flow_v2("b/113327884")
def testCondGrad_Gather(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113327884")
-
with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -756,6 +738,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(dense_gv, [0.0, 2.0])
# Microbenchmark: 256,000 iterations/s.
+ @test_util.disable_control_flow_v2("b/116630618 (Times out)")
def testWhile_1(self):
with self.cached_session():
n = constant_op.constant(0)
@@ -764,6 +747,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependencies(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -779,6 +763,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(result.eval(), 2)
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependenciesNoInput(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -794,6 +779,7 @@ class ControlFlowTest(test.TestCase):
result.eval()
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefs_1(self):
with self.cached_session() as sess:
x = variables.VariableV1(0)._ref() # pylint: disable=protected-access
@@ -824,18 +810,22 @@ class ControlFlowTest(test.TestCase):
r = isum(s)
self.assertAllEqual(45, r.eval())
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testWhileWithMaximumIterations(self):
with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithMaximumIterationsAndSingleArgument(self):
with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nested), b/115920078 (gradients)")
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -861,6 +851,7 @@ class ControlFlowTest(test.TestCase):
# Should execute without issue.
self.assertEqual(3, self.evaluate(loop_execute))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while_loop)")
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -904,10 +895,8 @@ class ControlFlowTest(test.TestCase):
r"context '.*' \(currently defined in '.*'\)"):
_ = gradients_impl.gradients(loop_with_maxiter, v)
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
v = constant_op.constant(1.0)
def create_while_loop():
@@ -939,6 +928,8 @@ class ControlFlowTest(test.TestCase):
r"while loop context '' \(currently defined in 'cond/.+'\)"):
_ = gradients_impl.gradients(loop, v)
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nesting), b/115776323 (max_iters)")
def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
v = constant_op.constant(1.0)
@@ -1048,6 +1039,7 @@ class ControlFlowTest(test.TestCase):
result = r[3].eval()
self.assertAllEqual(42, result)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhile_5(self):
with self.cached_session():
@@ -1072,6 +1064,7 @@ class ControlFlowTest(test.TestCase):
result = r[2].eval()
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
def testBufferForwarding(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1122,6 +1115,7 @@ class ControlFlowTest(test.TestCase):
self._testWhile_Gpu_1(use_gpu=False)
self._testWhile_Gpu_1(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShape(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1139,6 +1133,7 @@ class ControlFlowTest(test.TestCase):
r = r[1] * array_ops.ones([8, 8])
self.assertAllEqual(np.ones((8, 8)), r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Scalar(self):
with self.cached_session():
n = 0
@@ -1147,6 +1142,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Vector(self):
with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
@@ -1155,6 +1151,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual([10000], r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShapeInference(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1169,7 +1166,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(
c, b, [i, m],
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
- self.assertTrue(r[1].get_shape()[0].value is None)
+ self.assertIsNone(r[1].get_shape()[0].value)
self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
with self.assertRaisesRegexp(
@@ -1180,6 +1177,7 @@ class ControlFlowTest(test.TestCase):
r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileShapeInferenceSparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -1211,6 +1209,7 @@ class ControlFlowTest(test.TestCase):
c, b, [i, x],
[i.get_shape(), tensor_shape.TensorShape([5])])
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileShapeInferenceIndexedSlices(self):
with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1265,6 +1264,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertEqual(225, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_1(self):
self._testNestedWhile_1(use_gpu=False)
self._testNestedWhile_1(use_gpu=True)
@@ -1297,6 +1297,7 @@ class ControlFlowTest(test.TestCase):
outer_c, outer_b, [s0], parallel_iterations=1)
self.assertEqual(1048576.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_2(self):
self._testNestedWhile_2(use_gpu=False)
self._testNestedWhile_2(use_gpu=True)
@@ -1350,6 +1351,7 @@ class ControlFlowTest(test.TestCase):
lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/79881896 (control_deps)")
def testWhileWithControl_5(self):
with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
@@ -1364,9 +1366,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
with self.cached_session() as sess:
@@ -1380,10 +1379,8 @@ class ControlFlowTest(test.TestCase):
(constant_op.constant(5),))
self.assertEqual(0, sess.run(loop))
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondWithControl_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
@@ -1405,9 +1402,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(4, r.eval())
self.assertAllClose(65536.0, v.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondExitControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
v = variables.Variable(1)
@@ -1432,8 +1428,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1445,8 +1439,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1458,9 +1450,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
n = constant_op.constant(0.0)
@@ -1477,18 +1466,17 @@ class ControlFlowTest(test.TestCase):
lambda: control_flow_ops.while_loop(c, b, [n]),
lambda: math_ops.multiply(n, 2.0))
r1 = gradients_impl.gradients(r, [n])
- self.assertEqual(10, sess.run(r, {p: True}))
+ self.assertEqual(10., sess.run(r, {p: True}))
self.assertEqual([1.0], sess.run(r1, {p: True}))
self.assertEqual(0.0, sess.run(r, {p: False}))
self.assertEqual([2.0], sess.run(r1, {p: False}))
+ @test_util.disable_control_flow_v2("b/116743589")
def testCondWhile_3(self):
self._testCondWhile_3(use_gpu=False)
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
@@ -1505,8 +1493,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1516,8 +1502,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1532,6 +1516,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
# NOTE: It is ok to have parallel_iterations > 1
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_1(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1554,6 +1539,7 @@ class ControlFlowTest(test.TestCase):
result = select.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_2(self):
with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
@@ -1580,6 +1566,7 @@ class ControlFlowTest(test.TestCase):
result2 = select2.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_3(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1601,7 +1588,7 @@ class ControlFlowTest(test.TestCase):
result = r[1].eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
- # b/24814703
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_4(self):
with self.cached_session():
var_a = variables.Variable(0, name="a")
@@ -1629,7 +1616,7 @@ class ControlFlowTest(test.TestCase):
lpa.eval() # Run the loop
self.assertEqual(10, var_b.eval())
- # b/24736492
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_5(self):
with self.cached_session():
# Create some variables.
@@ -1659,7 +1646,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
self.assertEqual(10, var_b.eval())
- # b/24814668
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_6(self):
with self.cached_session():
# Create some variables.
@@ -1689,6 +1676,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(55, var_b.eval())
self.assertEqual(10, var_a.eval())
+ @test_util.disable_control_flow_v2("b/116742472 (resource accumulator)")
def testWhileQueue_1(self):
with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
@@ -1707,6 +1695,7 @@ class ControlFlowTest(test.TestCase):
for i in xrange(10):
self.assertEqual([i], q.dequeue().eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileStack_1(self):
with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
@@ -1775,6 +1764,7 @@ class ControlFlowTest(test.TestCase):
with self.session(graph=graph) as sess:
self.assertAllClose(1024.0, sess.run(r))
+ @test_util.disable_control_flow_v2("b/116351701 (colocation)")
def testWhileGrad_ColocateGradients(self):
self._testWhileGrad_ColocateGradients(colocate=False)
self._testWhileGrad_ColocateGradients(colocate=True)
@@ -1790,6 +1780,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Shape(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -1861,8 +1852,6 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1885,10 +1874,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhileCondWhileGrad(self):
self._testNestedWhileCondWhileGrad(use_gpu=False)
self._testNestedWhileCondWhileGrad(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116823782")
def testWhileGrad_Variable(self):
with self.cached_session():
a = variables.Variable(3.0)
@@ -1902,8 +1893,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1919,6 +1908,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116340060")
def testGradInWhileWrtInitialLoopVal(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
@@ -1936,6 +1926,7 @@ class ControlFlowTest(test.TestCase):
"loop invariants or wrt the input parameters to the loop body."):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testWhileGradInWhile(self):
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1952,9 +1943,8 @@ class ControlFlowTest(test.TestCase):
[tensor_shape.unknown_shape()])
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testCondGradInNestedWhiles(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
@@ -1972,6 +1962,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
+ @test_util.disable_control_flow_v2("b/116255781 (flat_args)")
def testWhile_NestedInput(self):
with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
@@ -1999,6 +1990,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
sess.run(r_flattened))
+ @test_util.disable_control_flow_v2("b/116255781(flat_args)")
def testWhile_NestedBadArityFails(self):
with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
@@ -2057,6 +2049,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients([rx], x)
self.assertAllClose(1024.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
def testWhileGrad_NoGradient(self):
with self.cached_session():
v = constant_op.constant(2.0, name="v")
@@ -2067,6 +2060,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)
self.assertAllClose(1.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGrad_NoDependency(self):
with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
@@ -2180,10 +2174,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(8.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_Simple(self):
self._testNestedWhileGrad_Simple(use_gpu=False)
self._testNestedWhileGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_SerialInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2207,6 +2203,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(256.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_ParallelInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2230,6 +2227,8 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2(
+ "Nested loops and TensorArrays not supported")
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
@@ -2268,13 +2267,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_Simple(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_UnknownShape(self):
with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
@@ -2292,6 +2290,7 @@ class ControlFlowTest(test.TestCase):
r = sess.run(r, feed_dict={v: 2.0})
self.assertAllClose(1024.0, r)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Concat(self):
with self.cached_session() as sess:
x = variable_scope.get_variable("x", initializer=[[1., 2.]])
@@ -2315,6 +2314,7 @@ class ControlFlowTest(test.TestCase):
sess.run(op)
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefsWithGradients_1(self):
with self.cached_session() as sess:
x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access
@@ -2343,6 +2343,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad)
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileGrad_IndexedSlices(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2364,6 +2365,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileGrad_SparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2386,6 +2388,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testCallGradInLoop(self):
with self.cached_session() as sess:
i0 = constant_op.constant(0)
@@ -2405,6 +2408,8 @@ class ControlFlowTest(test.TestCase):
c, b, [i0, constant_op.constant(0.0)])
self.assertAllClose(600.0, sess.run(output_grad)[1])
+ @test_util.disable_control_flow_v2(
+ "b/116255781 (flat_args), b/115660901 (TensorArray)")
def testWhileAndTensorArray(self):
with self.cached_session() as sess:
param = constant_op.constant(2.0)
@@ -2509,6 +2514,7 @@ class ControlFlowTest(test.TestCase):
all_ops = x.graph.get_operations()
self.assertFalse(any([name in op.name for op in all_ops]))
+ @test_util.disable_control_flow_v2("b/116255781 (flat args)")
def testWhileGradGradFail(self):
theta = variables.Variable(initial_value=1.)
@@ -2538,6 +2544,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, y)[0]
self.assertEqual(388.0, r.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath1(self):
q = variables.Variable([7., 8.])
@@ -2555,6 +2562,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath2(self):
q = variables.Variable([7., 8.])
@@ -2572,6 +2580,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testIssue16504(self):
c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
w = variables.Variable(
@@ -2595,6 +2604,7 @@ class ControlFlowTest(test.TestCase):
grad, = gradients_impl.gradients(w, c)
self.assertIsNotNone(grad)
+ @test_util.disable_control_flow_v2("b/116270461 (resource)")
def testStopGradMultiFlows(self):
with self.cached_session():
@@ -2653,10 +2663,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCase(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
@@ -2708,10 +2717,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCaseSideEffects(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
@@ -2746,10 +2754,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, r0.eval())
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testOneOpCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
@@ -3031,9 +3037,11 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, x)[0]
self.assertEqual(r.eval(), 524288.0)
- self.assertEqual(
- len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
- 1)
+ # while_v2 does not have stacks.
+ if not control_flow_ops.ENABLE_WHILE_V2:
+ self.assertEqual(
+ len([op for op in x.graph.get_operations() if op.type == "StackV2"
+ ]), 1)
class ControlFlowContextCheckTest(test.TestCase):
@@ -3393,7 +3401,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class EagerTest(test.TestCase):
def testCond(self):
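Throughout this file the per-test `if control_flow_ops.ENABLE_COND_V2: return unittest.skip(...)` guards are replaced by declarative `@test_util.disable_control_flow_v2(...)` annotations, and the class decorator moves from `with_cond_v2` to `with_control_flow_v2`. The real helpers live in `test_util` and are not reproduced here; the sketch below is only a rough, self-contained illustration of how such a decorator pair can work, not the actual implementation:

ENABLE_V2 = False  # stand-in for the cond_v2/while_v2 toggles


def disable_control_flow_v2(reason):
  """Marks a test so no v2 variant is generated; `reason` records the bug id."""
  def mark(fn):
    fn._disable_control_flow_v2 = True
    return fn
  return mark


def with_control_flow_v2(cls):
  """Adds a <name>WithControlFlowV2 variant of every unmarked test method."""
  for name, fn in list(vars(cls).items()):
    if (name.startswith("test") and callable(fn)
        and not getattr(fn, "_disable_control_flow_v2", False)):
      def v2_variant(self, _fn=fn):
        global ENABLE_V2
        old, ENABLE_V2 = ENABLE_V2, True
        try:
          _fn(self)
        finally:
          ENABLE_V2 = old
      setattr(cls, name + "WithControlFlowV2", v2_variant)
  return cls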
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 3ba880d7a1..e399ece232 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,10 +131,20 @@ class Layer(base_layer.Layer):
def add_loss(self, losses, inputs=None):
previous_losses_length = len(self._losses)
+ previous_callable_losses_length = len(self._callable_losses)
super(Layer, self).add_loss(losses, inputs=inputs)
- # TODO(fchollet): deprecate collection below.
- new_losses = self._losses[previous_losses_length:]
- _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
+ if not context.executing_eagerly():
+ # TODO(fchollet): deprecate collection below.
+ new_losses = self._losses[previous_losses_length:]
+ new_callable_losses = self._callable_losses[
+ previous_callable_losses_length:]
+ for regularizer in new_callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ new_losses.append(loss_tensor)
+ _add_elements_to_collection(
+ new_losses,
+ ops.GraphKeys.REGULARIZATION_LOSSES)
def _name_scope(self):
"""Determines op naming for the Layer."""
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index d61d3b6dba..257fa27156 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -207,7 +207,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -217,7 +218,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DNoBias(self):
height, width = 7, 9
@@ -445,7 +447,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DPointwiseRegularizer(self):
length = 9
@@ -455,7 +458,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DBiasRegularizer(self):
length = 9
@@ -465,7 +469,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DNoBias(self):
length = 9
@@ -682,7 +687,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DPointwiseRegularizer(self):
height, width = 7, 9
@@ -692,7 +698,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -702,7 +709,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DNoBias(self):
height, width = 7, 9
@@ -839,7 +847,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeBiasRegularizer(self):
height, width = 7, 9
@@ -849,7 +858,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeNoBias(self):
height, width = 7, 9
@@ -1017,7 +1027,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeBiasRegularizer(self):
depth, height, width = 5, 7, 9
@@ -1027,7 +1038,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeNoBias(self):
depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..d26f3f4789 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testKernelRegularizerWithReuse(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testFunctionalDense(self):
with self.cached_session():
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 87f8bd85a5..9d7d31df22 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,8 +60,17 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+# The while_v2 module.
+_while_v2 = None
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+# Note: Setting this to True is not sufficient to switch to the v2 while_loop.
+# Users must also import the while_v2 module to set the _while_v2 module
+# variable above. We do this to avoid a circular dependency:
+# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops
+# A ValueError is raised in tf.while_loop if this is set to True and the
+# `_while_v2` module is not set.
+ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -3211,6 +3220,13 @@ def while_loop(cond,
```
"""
+ if ENABLE_WHILE_V2 and not context.executing_eagerly():
+ if not _while_v2:
+ raise ValueError("The while_v2 module is not set. Did you forget to "
+ "import tensorflow.python.ops."
+ "while_v2?")
+ return _while_v2.while_loop(cond, body, loop_vars, name)
+
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 43cca1a498..c2751e529a 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -611,7 +611,7 @@ class LSTMStateTuple(_LSTMStateTuple):
# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+ """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead.
Basic LSTM recurrent network cell.
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 875be31602..6791e1cd61 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import function
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
@@ -41,6 +43,8 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
+control_flow_ops._while_v2 = sys.modules[__name__]
+
# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges
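The single added assignment is the registration half of the hand-off described in control_flow_ops.py: at import time the module stores a reference to itself in the hook that `while_loop` checks, which breaks the circular import chain control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops. A self-contained toy version of the same late-binding idea, with hypothetical names:

import types

# Stand-in for the low-level module that owns the hook but cannot import the
# v2 implementation without creating an import cycle.
low_level = types.ModuleType("low_level")
low_level._impl_v2 = None


def dispatch(x):
  # Mirrors control_flow_ops.while_loop: fail loudly if the v2 module was
  # never imported and therefore never registered itself.
  if low_level._impl_v2 is None:
    raise ValueError("v2 module not set; did you forget to import it?")
  return low_level._impl_v2.run(x)


# Stand-in for the v2 module registering itself on import, mirroring
# `control_flow_ops._while_v2 = sys.modules[__name__]`.
v2_module = types.ModuleType("impl_v2")
v2_module.run = lambda x: x * 2
low_level._impl_v2 = v2_module

print(dispatch(21))  # 42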
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
deleted file mode 100644
index eb41deee13..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
+++ /dev/null
@@ -1,24 +0,0 @@
-path: "tensorflow.ConfigProto.Experimental"
-tf_proto {
- descriptor {
- name: "Experimental"
- field {
- name: "collective_group_leader"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "executor_type"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
deleted file mode 100644
index e565b903d2..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ /dev/null
@@ -1,148 +0,0 @@
-path: "tensorflow.ConfigProto"
-tf_proto {
- descriptor {
- name: "ConfigProto"
- field {
- name: "device_count"
- number: 1
- label: LABEL_REPEATED
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ConfigProto.DeviceCountEntry"
- }
- field {
- name: "intra_op_parallelism_threads"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "inter_op_parallelism_threads"
- number: 5
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "use_per_session_threads"
- number: 9
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "session_inter_op_thread_pool"
- number: 12
- label: LABEL_REPEATED
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ThreadPoolOptionProto"
- }
- field {
- name: "placement_period"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "device_filters"
- number: 4
- label: LABEL_REPEATED
- type: TYPE_STRING
- }
- field {
- name: "gpu_options"
- number: 6
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.GPUOptions"
- }
- field {
- name: "allow_soft_placement"
- number: 7
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "log_device_placement"
- number: 8
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "graph_options"
- number: 10
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.GraphOptions"
- }
- field {
- name: "operation_timeout_in_ms"
- number: 11
- label: LABEL_OPTIONAL
- type: TYPE_INT64
- }
- field {
- name: "rpc_options"
- number: 13
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.RPCOptions"
- }
- field {
- name: "cluster_def"
- number: 14
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ClusterDef"
- }
- field {
- name: "isolate_session_state"
- number: 15
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "experimental"
- number: 16
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ConfigProto.Experimental"
- }
- nested_type {
- name: "DeviceCountEntry"
- field {
- name: "key"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "value"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- options {
- map_entry: true
- }
- }
- nested_type {
- name: "Experimental"
- field {
- name: "collective_group_leader"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "executor_type"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- }
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
deleted file mode 100644
index 4f0147a523..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.data.Iterator"
-tf_class {
- is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_classes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shapes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'iterator_resource\', \'initializer\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "from_string_handle"
- argspec: "args=[\'string_handle\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "from_structure"
- argspec: "args=[\'output_types\', \'output_shapes\', \'shared_name\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "get_next"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "make_initializer"
- argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "string_handle"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
deleted file mode 100644
index c23b04b4ef..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ /dev/null
@@ -1,58 +0,0 @@
-path: "tensorflow.estimator.BoostedTreesClassifier"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
- is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "params"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
- }
- member_method {
- name: "eval_dir"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
- }
- member_method {
- name: "get_variable_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable_value"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "latest_checkpoint"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
- }
- member_method {
- name: "train"
- argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
deleted file mode 100644
index 6878d28fff..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ /dev/null
@@ -1,58 +0,0 @@
-path: "tensorflow.estimator.BoostedTreesRegressor"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
- is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "params"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
- }
- member_method {
- name: "eval_dir"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
- }
- member_method {
- name: "get_variable_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable_value"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "latest_checkpoint"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
- }
- member_method {
- name: "train"
- argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
-}
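(Editor's sketch, not part of the diff: the deleted golden file above records the BoostedTreesRegressor constructor signature. A minimal construction against that argspec might look like the following; the column name, boundaries, and toy input_fn are assumptions, and BoostedTrees estimators expect bucketized or indicator feature columns rather than raw numeric ones.)

import tensorflow as tf

# Bucketize a numeric feature, as the boosted-trees estimators require.
numeric = tf.feature_column.numeric_column('x')
bucketized = tf.feature_column.bucketized_column(numeric, boundaries=[1.5, 2.5, 3.5])

# feature_columns and n_batches_per_layer are the only required arguments per the argspec.
regressor = tf.estimator.BoostedTreesRegressor(
    feature_columns=[bucketized],
    n_batches_per_layer=1,
    n_trees=50,
    max_depth=6,
    learning_rate=0.1)

def input_fn():
    features = {'x': tf.constant([[1.0], [2.0], [3.0], [4.0]])}
    labels = tf.constant([[1.5], [2.5], [3.5], [4.5]])
    return features, labels

regressor.train(input_fn=input_fn, max_steps=10)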
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
deleted file mode 100644
index bf1f94b6ae..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
+++ /dev/null
@@ -1,105 +0,0 @@
-path: "tensorflow.estimator.RunConfig"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.run_config.RunConfig\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "cluster_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "device_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "eval_distribute"
- mtype: "<type \'property\'>"
- }
- member {
- name: "evaluation_master"
- mtype: "<type \'property\'>"
- }
- member {
- name: "global_id_in_cluster"
- mtype: "<type \'property\'>"
- }
- member {
- name: "is_chief"
- mtype: "<type \'property\'>"
- }
- member {
- name: "keep_checkpoint_every_n_hours"
- mtype: "<type \'property\'>"
- }
- member {
- name: "keep_checkpoint_max"
- mtype: "<type \'property\'>"
- }
- member {
- name: "log_step_count_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "master"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "num_ps_replicas"
- mtype: "<type \'property\'>"
- }
- member {
- name: "num_worker_replicas"
- mtype: "<type \'property\'>"
- }
- member {
- name: "protocol"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_checkpoints_secs"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_checkpoints_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_summary_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "service"
- mtype: "<type \'property\'>"
- }
- member {
- name: "session_config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "task_id"
- mtype: "<type \'property\'>"
- }
- member {
- name: "task_type"
- mtype: "<type \'property\'>"
- }
- member {
- name: "tf_random_seed"
- mtype: "<type \'property\'>"
- }
- member {
- name: "train_distribute"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "replace"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
-}
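(Editor's sketch, not part of the diff: a RunConfig built against the __init__ and replace signatures recorded in the deleted golden file above; the model directory is a placeholder.)

import tensorflow as tf

config = tf.estimator.RunConfig(
    model_dir='/tmp/example_model',   # placeholder path
    tf_random_seed=42,
    save_summary_steps=100,
    save_checkpoints_steps=500,       # set either this or save_checkpoints_secs, not both
    keep_checkpoint_max=5,
    log_step_count_steps=100)

# replace() returns a copy with the named fields overridden, per the argspec above.
eval_config = config.replace(save_checkpoints_steps=1000)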
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
deleted file mode 100644
index 5c46dc5ee7..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ /dev/null
@@ -1,251 +0,0 @@
-path: "tensorflow.image"
-tf_module {
- member {
- name: "ResizeMethod"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "adjust_brightness"
- argspec: "args=[\'image\', \'delta\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "adjust_contrast"
- argspec: "args=[\'images\', \'contrast_factor\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "adjust_gamma"
- argspec: "args=[\'image\', \'gamma\', \'gain\'], varargs=None, keywords=None, defaults=[\'1\', \'1\'], "
- }
- member_method {
- name: "adjust_hue"
- argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "adjust_jpeg_quality"
- argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "adjust_saturation"
- argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "central_crop"
- argspec: "args=[\'image\', \'central_fraction\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "convert_image_dtype"
- argspec: "args=[\'image\', \'dtype\', \'saturate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "crop_and_resize"
- argspec: "args=[\'image\', \'boxes\', \'box_ind\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'bilinear\', \'0\', \'None\'], "
- }
- member_method {
- name: "crop_to_bounding_box"
- argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "decode_and_crop_jpeg"
- argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
- }
- member_method {
- name: "decode_bmp"
- argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
- }
- member_method {
- name: "decode_gif"
- argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "decode_image"
- argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
- }
- member_method {
- name: "decode_jpeg"
- argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
- }
- member_method {
- name: "decode_png"
- argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'None\'], "
- }
- member_method {
- name: "draw_bounding_boxes"
- argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "encode_jpeg"
- argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'95\', \'False\', \'False\', \'True\', \'in\', \'300\', \'300\', \'\', \'None\'], "
- }
- member_method {
- name: "encode_png"
- argspec: "args=[\'image\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
- }
- member_method {
- name: "extract_glimpse"
- argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], "
- }
- member_method {
- name: "extract_image_patches"
- argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "extract_jpeg_shape"
- argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
- }
- member_method {
- name: "flip_left_right"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "flip_up_down"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "grayscale_to_rgb"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "hsv_to_rgb"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "image_gradients"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "is_jpeg"
- argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "non_max_suppression"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
- }
- member_method {
- name: "non_max_suppression_overlaps"
- argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
- }
- member_method {
- name: "non_max_suppression_padded"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
- }
- member_method {
- name: "pad_to_bounding_box"
- argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "per_image_standardization"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "psnr"
- argspec: "args=[\'a\', \'b\', \'max_val\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_brightness"
- argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_contrast"
- argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_flip_left_right"
- argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_flip_up_down"
- argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_hue"
- argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_jpeg_quality"
- argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_saturation"
- argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "resize_area"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_bicubic"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_bilinear"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_image_with_crop_or_pad"
- argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "resize_image_with_pad"
- argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], "
- }
- member_method {
- name: "resize_images"
- argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], "
- }
- member_method {
- name: "resize_nearest_neighbor"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "rgb_to_grayscale"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "rgb_to_hsv"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "rgb_to_yiq"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "rgb_to_yuv"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "rot90"
- argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
- }
- member_method {
- name: "sample_distorted_bounding_box"
- argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "sobel_edges"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "ssim"
- argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "ssim_multiscale"
- argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], "
- }
- member_method {
- name: "total_variation"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "transpose_image"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "yiq_to_rgb"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "yuv_to_rgb"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
-}
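(Editor's sketch, not part of the diff: a few of the tf.image ops listed in the deleted golden file above, chained into a small preprocessing pipeline; 'photo.jpg' is a placeholder file name.)

import tensorflow as tf

contents = tf.read_file('photo.jpg')                  # placeholder path
image = tf.image.decode_jpeg(contents, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize_images(image, [224, 224])     # method defaults to bilinear
image = tf.image.random_flip_left_right(image)
image = tf.image.adjust_brightness(image, delta=0.1)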
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
deleted file mode 100644
index e579fe6a1a..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ /dev/null
@@ -1,268 +0,0 @@
-path: "tensorflow.keras.Model"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
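(Editor's sketch, not part of the diff: a functional-API tf.keras.Model exercising compile/fit/summary/get_weights from the listing above; layer sizes and data are arbitrary.)

import numpy as np
import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(8,))
hidden = tf.keras.layers.Dense(4, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer='sgd', loss='mse')
model.fit(np.random.rand(32, 8), np.random.rand(32, 1),
          batch_size=8, epochs=1, verbose=0)
model.summary()
weights = model.get_weights()
model.set_weights(weights)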
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
deleted file mode 100644
index 6f05cdd093..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ /dev/null
@@ -1,289 +0,0 @@
-path: "tensorflow.keras.Sequential"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "add"
- argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "pop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_classes"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict_proba"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
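(Editor's sketch, not part of the diff: a Sequential model using the add/pop/compile/predict methods listed above, which distinguish it from the plain Model class; shapes and data are arbitrary.)

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential(name='tiny_mlp')
model.add(tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)))
model.add(tf.keras.layers.Dense(8))
model.pop()                                            # remove the last layer again
model.add(tf.keras.layers.Dense(1))
model.compile(optimizer='sgd', loss='mse')
model.fit(np.random.rand(32, 8), np.random.rand(32, 1),
          batch_size=8, epochs=1, verbose=0)
predictions = model.predict(np.random.rand(4, 8), batch_size=4)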
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
deleted file mode 100644
index 2e9de9ebb2..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
+++ /dev/null
@@ -1,55 +0,0 @@
-path: "tensorflow.keras.activations"
-tf_module {
- member_method {
- name: "deserialize"
- argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "elu"
- argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
- }
- member_method {
- name: "get"
- argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "hard_sigmoid"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "linear"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "relu"
- argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], "
- }
- member_method {
- name: "selu"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "serialize"
- argspec: "args=[\'activation\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "sigmoid"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "softmax"
- argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
- }
- member_method {
- name: "softplus"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "softsign"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "tanh"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
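(Editor's sketch, not part of the diff: the tf.keras.activations functions listed above applied to a small constant tensor, including the get/serialize round trip.)

import tensorflow as tf

x = tf.constant([-2.0, 0.0, 3.0])
relu_out = tf.keras.activations.relu(x, alpha=0.0, max_value=None, threshold=0)
soft_out = tf.keras.activations.softmax(tf.reshape(x, [1, -1]), axis=-1)
fn = tf.keras.activations.get('selu')        # look up an activation by name
name = tf.keras.activations.serialize(fn)    # back to the string 'selu'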
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
deleted file mode 100644
index 56914e1746..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ /dev/null
@@ -1,268 +0,0 @@
-path: "tensorflow.keras.models.Model"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
deleted file mode 100644
index 4c1c54001d..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ /dev/null
@@ -1,289 +0,0 @@
-path: "tensorflow.keras.models.Sequential"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "add"
- argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "pop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_classes"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict_proba"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
index 537e73aa89..47b5b56faf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
@@ -8,5 +8,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
index cec04a2bf0..c0c2e7b9f8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
@@ -55,6 +55,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
enum_type {
name: "TraceLevel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
index 537e73aa89..47b5b56faf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
@@ -8,5 +8,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
index cec04a2bf0..c0c2e7b9f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
@@ -55,6 +55,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
enum_type {
name: "TraceLevel"
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
index a30858db82..dd8d705331 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
@@ -26,7 +26,7 @@ ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
ENV NVIDIA_REQUIRE_CUDA "cuda>=9.0"
ENV NCCL_VERSION 2.2.13
-ENV CUDNN_VERSION 7.2.1.38
+ENV CUDNN_VERSION 7.1.4.18
# TODO(b/110903506): /usr/loca/cuda/lib64/stubs should not be needed in
# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to used at runtime. The
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 17198a6560..7d5cf3f843 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -111,7 +111,6 @@ bazel clean
# virtualenv.
export TF_NEED_GCP=0
export TF_NEED_HDFS=0
-export TF_ENABLE_XLA=0
# Obtain the path to Python binary
if [[ ${IS_VIRTUALENV} == "1" ]]; then
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 095ac1f4cc..b9f4902639 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -137,16 +137,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -171,7 +161,14 @@ genrule(
"@grpc//third_party/nanopb:LICENSE.txt",
"@grpc//third_party/address_sorting:LICENSE",
],
- ),
+ ) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ ],
+ }),
outs = ["include/tensorflow/c/LICENSE"],
cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
tools = [":concat_licenses.sh"],
@@ -205,16 +202,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -232,7 +219,14 @@ genrule(
]) + if_mkl([
"//third_party/mkl:LICENSE",
"//third_party/mkl_dnn:LICENSE",
- ]),
+ ]) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ ],
+ }),
outs = ["include/tensorflow/jni/LICENSE"],
cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
tools = [":concat_licenses.sh"],
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index cce60ccea0..c621812535 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -66,8 +66,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
"//tensorflow/contrib/gan:gan",
@@ -169,17 +167,6 @@ filegroup(
"@zlib_archive//:zlib.h",
"@org_python_pypi_backports_weakref//:LICENSE",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googleapis_googleapis//:LICENSE",
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -188,11 +175,6 @@ filegroup(
],
"//conditions:default": [],
}) + select({
- "//tensorflow:with_kafka_support": [
- "@kafka//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow/core/kernels:xsmm": [
"@libxsmm_archive//:LICENSE.md",
],
@@ -215,7 +197,16 @@ filegroup(
"@ngraph_tf//:LICENSE",
"@nlohmann_json_lib//:LICENSE.MIT",
"@tbb//:LICENSE",
- ]) + tf_additional_license_deps(),
+ ]) + tf_additional_license_deps() + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googleapis_googleapis//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ "@kafka//:LICENSE",
+ ],
+ }),
)
sh_binary(
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 6966783efd..9b4b698874 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -57,39 +57,39 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# Point //external/local_config_arm_compiler to //external/arm_compiler
arm_compiler_configure(
name = "local_config_arm_compiler",
- remote_config_repo = "../arm_compiler",
build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD"),
+ remote_config_repo = "../arm_compiler",
)
mkl_repository(
name = "mkl_linux",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
+ strip_prefix = "mklml_lnx_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
],
- sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
- strip_prefix = "mklml_lnx_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_windows",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
+ strip_prefix = "mklml_win_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
],
- sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
- strip_prefix = "mklml_win_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_darwin",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
+ strip_prefix = "mklml_mac_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
],
- sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
- strip_prefix = "mklml_mac_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
if path_prefix:
@@ -98,39 +98,40 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn",
+ build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
+ sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
+ strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
"https://github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
],
- sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
- strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
- build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
)
tf_http_archive(
name = "com_google_absl",
+ build_file = clean_dep("//third_party:com_google_absl.BUILD"),
+ sha256 = "7dd09690ae7ca4551de3111d4a86b75b23ec17445f273d3c42bdcdc1c7b02e4e",
+ strip_prefix = "abseil-cpp-48cd2c3f351ff188bc85684b84a91b6e6d17d896",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/e291c279e458761e77a69b09b129d3d1e81f1e80.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/e291c279e458761e77a69b09b129d3d1e81f1e80.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/48cd2c3f351ff188bc85684b84a91b6e6d17d896.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/48cd2c3f351ff188bc85684b84a91b6e6d17d896.tar.gz",
],
- sha256 = "278a1af58b633be886fe81bf7061dca6b5fea99566850d1319fffdaa1a061792",
- strip_prefix = "abseil-cpp-e291c279e458761e77a69b09b129d3d1e81f1e80",
- build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
tf_http_archive(
name = "eigen_archive",
+ build_file = clean_dep("//third_party:eigen.BUILD"),
+ sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
+ strip_prefix = "eigen-eigen-fd6845384b86",
urls = [
"https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
"https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
],
- sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
- strip_prefix = "eigen-eigen-fd6845384b86",
- build_file = clean_dep("//third_party:eigen.BUILD"),
)
tf_http_archive(
name = "arm_compiler",
+ build_file = clean_dep("//:arm_compiler.BUILD"),
sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969",
strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf",
urls = [
@@ -139,216 +140,211 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# remove the whitelist entry in third_party/repo.bzl.
# "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz",
],
- build_file = clean_dep("//:arm_compiler.BUILD"),
)
tf_http_archive(
name = "libxsmm_archive",
+ build_file = clean_dep("//third_party:libxsmm.BUILD"),
+ sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
+ strip_prefix = "libxsmm-1.9",
urls = [
"https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz",
"https://github.com/hfp/libxsmm/archive/1.9.tar.gz",
],
- sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
- strip_prefix = "libxsmm-1.9",
- build_file = clean_dep("//third_party:libxsmm.BUILD"),
)
tf_http_archive(
name = "ortools_archive",
+ build_file = clean_dep("//third_party:ortools.BUILD"),
+ sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
+ strip_prefix = "or-tools-6.7.2/src",
urls = [
"https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz",
"https://github.com/google/or-tools/archive/v6.7.2.tar.gz",
],
- sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
- strip_prefix = "or-tools-6.7.2/src",
- build_file = clean_dep("//third_party:ortools.BUILD"),
)
tf_http_archive(
name = "com_googlesource_code_re2",
+ sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe",
+ strip_prefix = "re2-2018-07-01",
+ system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/re2/archive/2018-07-01.tar.gz",
"https://github.com/google/re2/archive/2018-07-01.tar.gz",
],
- sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe",
- strip_prefix = "re2-2018-07-01",
- system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
)
tf_http_archive(
name = "com_github_googlecloudplatform_google_cloud_cpp",
- urls = [
- "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
- "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
- ],
sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944",
strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c",
system_build_file = clean_dep("//third_party/systemlibs:google_cloud_cpp.BUILD"),
system_link_files = {
"//third_party/systemlibs:google_cloud_cpp.google.cloud.bigtable.BUILD": "google/cloud/bigtable/BUILD",
},
+ urls = [
+ "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
+ "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
+ ],
)
tf_http_archive(
name = "com_github_googleapis_googleapis",
+ build_file = clean_dep("//third_party:googleapis.BUILD"),
+ sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
+ strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
+ system_build_file = clean_dep("//third_party/systemlibs:googleapis.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
"https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
],
- sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
- strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
- build_file = clean_dep("//third_party:googleapis.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:googleapis.BUILD"),
)
tf_http_archive(
name = "gemmlowp",
+ sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
+ strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
urls = [
"https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
"https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
],
- sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
- strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
)
tf_http_archive(
name = "farmhash_archive",
+ build_file = clean_dep("//third_party:farmhash.BUILD"),
+ sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
+ strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
urls = [
"https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
"https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
],
- sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
- strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
- build_file = clean_dep("//third_party:farmhash.BUILD"),
)
tf_http_archive(
name = "highwayhash",
+ build_file = clean_dep("//third_party:highwayhash.BUILD"),
+ sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
+ strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
urls = [
"http://mirror.bazel.build/github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
"https://github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
],
- sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
- strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
- build_file = clean_dep("//third_party:highwayhash.BUILD"),
)
tf_http_archive(
name = "nasm",
+ build_file = clean_dep("//third_party:nasm.BUILD"),
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
+ system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
urls = [
"https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
"http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
"http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
],
- sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
- strip_prefix = "nasm-2.13.03",
- build_file = clean_dep("//third_party:nasm.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
)
tf_http_archive(
name = "jpeg",
+ build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
+ sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
+ strip_prefix = "libjpeg-turbo-2.0.0",
+ system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
"https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
],
- sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
- strip_prefix = "libjpeg-turbo-2.0.0",
- build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
)
tf_http_archive(
name = "png_archive",
+ build_file = clean_dep("//third_party:png.BUILD"),
+ patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
+ sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
+ strip_prefix = "libpng-1.6.34",
+ system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
"https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
],
- sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
- strip_prefix = "libpng-1.6.34",
- build_file = clean_dep("//third_party:png.BUILD"),
- patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
- system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
)
tf_http_archive(
name = "org_sqlite",
+ build_file = clean_dep("//third_party:sqlite.BUILD"),
+ sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
+ strip_prefix = "sqlite-amalgamation-3240000",
+ system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
urls = [
"https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
"https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
],
- sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
- strip_prefix = "sqlite-amalgamation-3240000",
- build_file = clean_dep("//third_party:sqlite.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
)
tf_http_archive(
name = "gif_archive",
+ build_file = clean_dep("//third_party:gif.BUILD"),
+ sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
+ strip_prefix = "giflib-5.1.4",
+ system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
urls = [
"https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
"http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
],
- sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
- strip_prefix = "giflib-5.1.4",
- build_file = clean_dep("//third_party:gif.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
)
tf_http_archive(
name = "six_archive",
+ build_file = clean_dep("//third_party:six.BUILD"),
+ sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
+ strip_prefix = "six-1.10.0",
+ system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
"https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
],
- sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
- strip_prefix = "six-1.10.0",
- build_file = clean_dep("//third_party:six.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
)
tf_http_archive(
name = "astor_archive",
+ build_file = clean_dep("//third_party:astor.BUILD"),
+ sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
+ strip_prefix = "astor-0.6.2",
+ system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
"https://pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
],
- sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
- strip_prefix = "astor-0.6.2",
- build_file = clean_dep("//third_party:astor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
)
tf_http_archive(
name = "gast_archive",
+ build_file = clean_dep("//third_party:gast.BUILD"),
+ sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
+ strip_prefix = "gast-0.2.0",
+ system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
"https://pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
],
- sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
- strip_prefix = "gast-0.2.0",
- build_file = clean_dep("//third_party:gast.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
)
tf_http_archive(
name = "termcolor_archive",
+ build_file = clean_dep("//third_party:termcolor.BUILD"),
+ sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
+ strip_prefix = "termcolor-1.1.0",
+ system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
"https://pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
],
- sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
- strip_prefix = "termcolor-1.1.0",
- build_file = clean_dep("//third_party:termcolor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
)
tf_http_archive(
name = "absl_py",
- urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
- "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
- ],
sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
strip_prefix = "abseil-py-pypi-v0.2.2",
system_build_file = clean_dep("//third_party/systemlibs:absl_py.BUILD"),
@@ -356,17 +352,21 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
"//third_party/systemlibs:absl_py.absl.flags.BUILD": "absl/flags/BUILD",
"//third_party/systemlibs:absl_py.absl.testing.BUILD": "absl/testing/BUILD",
},
+ urls = [
+ "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
+ "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
+ ],
)
tf_http_archive(
name = "org_python_pypi_backports_weakref",
+ build_file = clean_dep("//third_party:backports_weakref.BUILD"),
+ sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
+ strip_prefix = "backports.weakref-1.0rc1/src",
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
"https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
],
- sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
- strip_prefix = "backports.weakref-1.0rc1/src",
- build_file = clean_dep("//third_party:backports_weakref.BUILD"),
)
filegroup_external(
@@ -389,9 +389,9 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "protobuf_archive",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
# We need to import the protobuf library under the names com_google_protobuf
@@ -399,222 +399,222 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# Unfortunately there is no way to alias http_archives at the moment.
tf_http_archive(
name = "com_google_protobuf",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
tf_http_archive(
name = "com_google_protobuf_cc",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
tf_http_archive(
name = "nsync",
+ sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
+ strip_prefix = "nsync-1.20.1",
+ system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/nsync/archive/1.20.1.tar.gz",
"https://github.com/google/nsync/archive/1.20.1.tar.gz",
],
- sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
- strip_prefix = "nsync-1.20.1",
- system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
)
tf_http_archive(
name = "com_google_googletest",
+ sha256 = "353ab86e35cea1cd386115279cf4b16695bbf21b897bfbf2721cf4cb5f64ade8",
+ strip_prefix = "googletest-997d343dd680e541ef96ce71ee54a91daf2577a0",
urls = [
"https://mirror.bazel.build/github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
"https://github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
],
- sha256 = "353ab86e35cea1cd386115279cf4b16695bbf21b897bfbf2721cf4cb5f64ade8",
- strip_prefix = "googletest-997d343dd680e541ef96ce71ee54a91daf2577a0",
)
tf_http_archive(
name = "com_github_gflags_gflags",
+ sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
+ strip_prefix = "gflags-2.2.1",
urls = [
"https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
"https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
],
- sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
- strip_prefix = "gflags-2.2.1",
)
tf_http_archive(
name = "pcre",
+ build_file = clean_dep("//third_party:pcre.BUILD"),
sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5",
+ strip_prefix = "pcre-8.42",
+ system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
urls = [
"https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
"http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
],
- strip_prefix = "pcre-8.42",
- build_file = clean_dep("//third_party:pcre.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
)
tf_http_archive(
name = "swig",
+ build_file = clean_dep("//third_party:swig.BUILD"),
sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453",
+ strip_prefix = "swig-3.0.8",
+ system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
urls = [
"https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
"http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
"http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
],
- strip_prefix = "swig-3.0.8",
- build_file = clean_dep("//third_party:swig.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
)
tf_http_archive(
name = "curl",
+ build_file = clean_dep("//third_party:curl.BUILD"),
sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5",
+ strip_prefix = "curl-7.60.0",
+ system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
urls = [
"https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz",
"https://curl.haxx.se/download/curl-7.60.0.tar.gz",
],
- strip_prefix = "curl-7.60.0",
- build_file = clean_dep("//third_party:curl.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
)
tf_http_archive(
name = "grpc",
+ sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
+ strip_prefix = "grpc-1.13.0",
+ system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz",
"https://github.com/grpc/grpc/archive/v1.13.0.tar.gz",
],
- sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
- strip_prefix = "grpc-1.13.0",
- system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
)
tf_http_archive(
name = "linenoise",
+ build_file = clean_dep("//third_party:linenoise.BUILD"),
sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
+ strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
urls = [
"https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
"https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
],
- strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
- build_file = clean_dep("//third_party:linenoise.BUILD"),
)
# TODO(phawkins): currently, this rule uses an unofficial LLVM mirror.
# Switch to an official source of snapshots if/when possible.
tf_http_archive(
name = "llvm",
+ build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
+ sha256 = "a4f8bfe7e3e69069934a87e612a1d4d3b8b6af13e0f1213a42a6046e1bcd50d8",
+ strip_prefix = "llvm-d3429e96fe1e45b1dc0106463832523f37faf271",
urls = [
"https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d3429e96fe1e45b1dc0106463832523f37faf271.tar.gz",
"https://github.com/llvm-mirror/llvm/archive/d3429e96fe1e45b1dc0106463832523f37faf271.tar.gz",
],
- sha256 = "a4f8bfe7e3e69069934a87e612a1d4d3b8b6af13e0f1213a42a6046e1bcd50d8",
- strip_prefix = "llvm-d3429e96fe1e45b1dc0106463832523f37faf271",
- build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
tf_http_archive(
name = "lmdb",
+ build_file = clean_dep("//third_party:lmdb.BUILD"),
+ sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
+ strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
+ system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
"https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
],
- sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
- strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
- build_file = clean_dep("//third_party:lmdb.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
)
tf_http_archive(
name = "jsoncpp_git",
+ build_file = clean_dep("//third_party:jsoncpp.BUILD"),
+ sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
+ strip_prefix = "jsoncpp-1.8.4",
+ system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
"https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
],
- sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
- strip_prefix = "jsoncpp-1.8.4",
- build_file = clean_dep("//third_party:jsoncpp.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
)
tf_http_archive(
name = "boringssl",
+ sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
+ strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
+ system_build_file = clean_dep("//third_party/systemlibs:boringssl.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
"https://github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
],
- sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
- strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
- system_build_file = clean_dep("//third_party/systemlibs:boringssl.BUILD"),
)
tf_http_archive(
name = "zlib_archive",
+ build_file = clean_dep("//third_party:zlib.BUILD"),
+ sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+ strip_prefix = "zlib-1.2.11",
+ system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
urls = [
"https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
"https://zlib.net/zlib-1.2.11.tar.gz",
],
- sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
- strip_prefix = "zlib-1.2.11",
- build_file = clean_dep("//third_party:zlib.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
)
tf_http_archive(
name = "fft2d",
+ build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
+ sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
urls = [
"https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
],
- sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
- build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
)
tf_http_archive(
name = "snappy",
+ build_file = clean_dep("//third_party:snappy.BUILD"),
+ sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
+ strip_prefix = "snappy-1.1.7",
+ system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz",
"https://github.com/google/snappy/archive/1.1.7.tar.gz",
],
- sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
- strip_prefix = "snappy-1.1.7",
- build_file = clean_dep("//third_party:snappy.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
)
tf_http_archive(
name = "nccl_archive",
+ build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
+ sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
+ strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
urls = [
"https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
"https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
],
- sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
- strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
- build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
)
tf_http_archive(
name = "kafka",
+ build_file = clean_dep("//third_party:kafka/BUILD"),
+ patch_file = clean_dep("//third_party/kafka:config.patch"),
+ sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
+ strip_prefix = "librdkafka-0.11.5",
urls = [
"https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
"https://github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
],
- sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
- strip_prefix = "librdkafka-0.11.5",
- build_file = clean_dep("//third_party:kafka/BUILD"),
- patch_file = clean_dep("//third_party/kafka:config.patch"),
)
tf_http_archive(
name = "aws",
+ build_file = clean_dep("//third_party:aws.BUILD"),
+ sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
+ strip_prefix = "aws-sdk-cpp-1.3.15",
urls = [
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
"https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
],
- sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
- strip_prefix = "aws-sdk-cpp-1.3.15",
- build_file = clean_dep("//third_party:aws.BUILD"),
)
java_import_external(
@@ -644,14 +644,14 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "jemalloc",
+ build_file = clean_dep("//third_party:jemalloc.BUILD"),
+ sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
+ strip_prefix = "jemalloc-4.4.0",
+ system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
"https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
],
- sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
- strip_prefix = "jemalloc-4.4.0",
- build_file = clean_dep("//third_party:jemalloc.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
)
java_import_external(
@@ -700,196 +700,196 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_pprof",
+ build_file = clean_dep("//third_party:pprof.BUILD"),
+ sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
+ strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
urls = [
"https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
"https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
],
- sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
- strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
- build_file = clean_dep("//third_party:pprof.BUILD"),
)
tf_http_archive(
name = "cub_archive",
+ build_file = clean_dep("//third_party:cub.BUILD"),
+ sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+ strip_prefix = "cub-1.8.0",
urls = [
"https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
"https://github.com/NVlabs/cub/archive/1.8.0.zip",
],
- sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
- strip_prefix = "cub-1.8.0",
- build_file = clean_dep("//third_party:cub.BUILD"),
)
tf_http_archive(
name = "cython",
+ build_file = clean_dep("//third_party:cython.BUILD"),
+ delete = ["BUILD.bazel"],
sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa",
+ strip_prefix = "cython-0.28.4",
+ system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz",
"https://github.com/cython/cython/archive/0.28.4.tar.gz",
],
- strip_prefix = "cython-0.28.4",
- build_file = clean_dep("//third_party:cython.BUILD"),
- delete = ["BUILD.bazel"],
- system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
)
tf_http_archive(
name = "bazel_toolchains",
+ sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
+ strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
"https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
],
- strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
- sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
)
tf_http_archive(
name = "arm_neon_2_x86_sse",
+ build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5",
strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d",
urls = [
"https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
"https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
],
- build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
)
tf_http_archive(
name = "double_conversion",
+ build_file = clean_dep("//third_party:double_conversion.BUILD"),
+ sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
+ strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
+ system_build_file = clean_dep("//third_party/systemlibs:double_conversion.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
"https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
],
- sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
- strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
- build_file = clean_dep("//third_party:double_conversion.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:double_conversion.BUILD"),
)
tf_http_archive(
name = "tflite_mobilenet",
+ build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
],
- build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
)
tf_http_archive(
name = "tflite_mobilenet_ssd",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "767057f2837a46d97882734b03428e8dd640b93236052b312b2f0e45613c1cf0",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_mobilenet_ssd_quant",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_mobilenet_ssd_quant_protobuf",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "09280972c5777f1aa775ef67cb4ac5d5ed21970acd8535aeca62450ef14f0d79",
+ strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
],
- strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_conv_actions_frozen",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_smartreply",
+ build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
],
- build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
)
tf_http_archive(
name = "tflite_ovic_testdata",
+ build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c",
+ strip_prefix = "ovic",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
"https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
],
- build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
- strip_prefix = "ovic",
)
tf_http_archive(
name = "build_bazel_rules_android",
sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
+ strip_prefix = "rules_android-0.1.1",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
"https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
],
- strip_prefix = "rules_android-0.1.1",
)
tf_http_archive(
name = "tbb",
+ build_file = clean_dep("//third_party/ngraph:tbb.BUILD"),
+ sha256 = "724686f90bcda78f13b76f297d964008737ccd6399328143c1c0093e73ae6a13",
+ strip_prefix = "tbb-tbb_2018",
urls = [
"https://mirror.bazel.build/github.com/01org/tbb/archive/tbb_2018.zip",
"https://github.com/01org/tbb/archive/tbb_2018.zip",
],
- sha256 = "724686f90bcda78f13b76f297d964008737ccd6399328143c1c0093e73ae6a13",
- strip_prefix = "tbb-tbb_2018",
- build_file = clean_dep("//third_party/ngraph:tbb.BUILD"),
)
tf_http_archive(
name = "ngraph",
+ build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
+ sha256 = "bf9dcc88e5c66021e3aac80491a231711211540d613bf9b6bd28db3f5bb86b62",
+ strip_prefix = "ngraph-0.8.1",
urls = [
"https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.8.1.tar.gz",
"https://github.com/NervanaSystems/ngraph/archive/v0.8.1.tar.gz",
],
- sha256 = "bf9dcc88e5c66021e3aac80491a231711211540d613bf9b6bd28db3f5bb86b62",
- strip_prefix = "ngraph-0.8.1",
- build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
)
tf_http_archive(
name = "nlohmann_json_lib",
+ build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
+ sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
+ strip_prefix = "json-3.1.1",
urls = [
"https://mirror.bazel.build/github.com/nlohmann/json/archive/v3.1.1.tar.gz",
"https://github.com/nlohmann/json/archive/v3.1.1.tar.gz",
],
- sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
- strip_prefix = "json-3.1.1",
- build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
)
tf_http_archive(
name = "ngraph_tf",
+ build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
+ sha256 = "402f84c748c113780a60f35f39aab118435285543aee4900d712b76fbf8a21ee",
+ strip_prefix = "ngraph-tf-0.6.1",
urls = [
"https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.6.1.tar.gz",
"https://github.com/NervanaSystems/ngraph-tf/archive/v0.6.1.tar.gz",
],
- sha256 = "402f84c748c113780a60f35f39aab118435285543aee4900d712b76fbf8a21ee",
- strip_prefix = "ngraph-tf-0.6.1",
- build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
)
##############################################################################
diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
index f638756d23..c8812fab33 100644
--- a/third_party/gpus/crosstool/BUILD.tpl
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -2,6 +2,20 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+toolchain(
+ name = "toolchain-linux-x86_64",
+ exec_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ target_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ toolchain = ":cc-compiler-local",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
cc_toolchain_suite(
name = "toolchain",
toolchains = {
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index 7256a7d96e..bcbc4dda11 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -26,12 +26,10 @@ platform(
constraint_values = [
"@bazel_tools//platforms:x86_64",
"@bazel_tools//platforms:linux",
- "@bazel_tools//tools/cpp:clang",
- "@bazel_toolchains//constraints:xenial",
],
remote_execution_properties = """
properties: {
name: "container-image"
- value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:06b585f42eed3b2030e9566b8f88f48d7472fa0f47e59765bc115376c8801bdf"
+ value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:e5099ff15650986e268a43ee99e2d2b7ffe2459b8b6935385078d1d3b2ed4d02"
}""",
)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
index 2d3e41127d..05abcb56d8 100755
--- a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
@@ -1253,7 +1253,7 @@ genrule(
"cuda/lib/libcupti.so.9.0",
],
cmd = """
-if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.2.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
""",
)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
index a56b4513fb..6442e7628a 100755
--- a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
@@ -2,6 +2,20 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+toolchain(
+ name = "toolchain-linux-x86_64",
+ exec_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ target_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ toolchain = ":cc-compiler-local",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
cc_toolchain_suite(
name = "toolchain",
toolchains = {
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 3734fab715..0cd148ed87 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -73,6 +73,7 @@ build --define=grpc_no_ares=true
build --spawn_strategy=standalone
build --genrule_strategy=standalone
build -c opt
+build --define=with_jemalloc=false
# Other build flags.
build --define=grpc_no_ares=true