author     Manjunath Kudlur <keveman@gmail.com>    2015-11-06 16:27:58 -0800
committer  Manjunath Kudlur <keveman@gmail.com>    2015-11-06 16:27:58 -0800
commit     f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
tree       ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/python
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
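As a minimal sketch of the data-flow-graph model described above (not part of the commit, and assuming the package is importable as "tensorflow"), a small graph can be built and launched with the Python API added in this change:

import tensorflow.python.platform
import tensorflow as tf

# Nodes in the data flow graph: two constant matrices and a matmul op.
a = tf.constant([[1.0, 2.0]])
b = tf.constant([[3.0], [4.0]])
product = tf.matmul(a, b)

# Launch the graph in a session and evaluate the matmul node.
with tf.Session() as sess:
    print(sess.run(product))  # [[ 11.]]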
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/BUILD965
-rw-r--r--tensorflow/python/__init__.py42
-rwxr-xr-xtensorflow/python/client/__init__.py0
-rw-r--r--tensorflow/python/client/client_lib.py40
-rw-r--r--tensorflow/python/client/events_writer.i34
-rw-r--r--tensorflow/python/client/events_writer_test.py54
-rw-r--r--tensorflow/python/client/graph_util.py138
-rw-r--r--tensorflow/python/client/graph_util_test.py126
-rw-r--r--tensorflow/python/client/notebook.py104
-rw-r--r--tensorflow/python/client/session.py567
-rw-r--r--tensorflow/python/client/session_test.py555
-rw-r--r--tensorflow/python/client/tensorflow_server.i16
-rw-r--r--tensorflow/python/client/test_construction_fails_op.cc22
-rw-r--r--tensorflow/python/client/tf_session.i235
-rw-r--r--tensorflow/python/client/tf_session_helper.cc518
-rw-r--r--tensorflow/python/client/tf_session_helper.h56
-rwxr-xr-xtensorflow/python/framework/__init__.py0
-rw-r--r--tensorflow/python/framework/device.py220
-rw-r--r--tensorflow/python/framework/device_test.py122
-rw-r--r--tensorflow/python/framework/docs.py492
-rw-r--r--tensorflow/python/framework/errors.py410
-rw-r--r--tensorflow/python/framework/errors_test.py63
-rw-r--r--tensorflow/python/framework/framework_lib.py70
-rw-r--r--tensorflow/python/framework/gen_docs_combined.py114
-rwxr-xr-xtensorflow/python/framework/gen_docs_test.sh4
-rw-r--r--tensorflow/python/framework/importer.py303
-rw-r--r--tensorflow/python/framework/importer_test.py546
-rw-r--r--tensorflow/python/framework/op_def_registry.py23
-rw-r--r--tensorflow/python/framework/ops.py2985
-rw-r--r--tensorflow/python/framework/ops_test.py825
-rw-r--r--tensorflow/python/framework/python_op_gen.cc678
-rw-r--r--tensorflow/python/framework/python_op_gen.h17
-rw-r--r--tensorflow/python/framework/python_op_gen_main.cc30
-rw-r--r--tensorflow/python/framework/random_seed.py136
-rw-r--r--tensorflow/python/framework/registry.py64
-rw-r--r--tensorflow/python/framework/registry_test.py38
-rw-r--r--tensorflow/python/framework/tensor_shape.py743
-rw-r--r--tensorflow/python/framework/tensor_shape_test.py232
-rw-r--r--tensorflow/python/framework/tensor_util.py511
-rw-r--r--tensorflow/python/framework/tensor_util_test.py379
-rw-r--r--tensorflow/python/framework/test_kernel_label_op.cc47
-rw-r--r--tensorflow/python/framework/test_util.py437
-rw-r--r--tensorflow/python/framework/test_util_test.py128
-rw-r--r--tensorflow/python/framework/types.py418
-rw-r--r--tensorflow/python/framework/types_test.py174
-rwxr-xr-xtensorflow/python/kernel_tests/__init__.py0
-rw-r--r--tensorflow/python/kernel_tests/argmax_op_test.py61
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py45
-rw-r--r--tensorflow/python/kernel_tests/attention_ops_test.py166
-rw-r--r--tensorflow/python/kernel_tests/batch_matmul_op_test.py195
-rw-r--r--tensorflow/python/kernel_tests/bcast_ops_test.py76
-rw-r--r--tensorflow/python/kernel_tests/bias_op_test.py93
-rw-r--r--tensorflow/python/kernel_tests/candidate_sampler_ops_test.py114
-rw-r--r--tensorflow/python/kernel_tests/cast_op_test.py165
-rw-r--r--tensorflow/python/kernel_tests/cholesky_op_test.py74
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py222
-rw-r--r--tensorflow/python/kernel_tests/concat_op_test.py276
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py524
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py1260
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py1009
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py1187
-rw-r--r--tensorflow/python/kernel_tests/decode_csv_op_test.py148
-rw-r--r--tensorflow/python/kernel_tests/decode_raw_op_test.py44
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py60
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_test.py151
-rw-r--r--tensorflow/python/kernel_tests/determinant_op_test.py72
-rw-r--r--tensorflow/python/kernel_tests/diag_op_test.py80
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py99
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py107
-rw-r--r--tensorflow/python/kernel_tests/edit_distance_op_test.py153
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py422
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py1043
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py71
-rw-r--r--tensorflow/python/kernel_tests/gradient_checker.py251
-rw-r--r--tensorflow/python/kernel_tests/gradient_checker_test.py178
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py47
-rw-r--r--tensorflow/python/kernel_tests/in_topk_op_test.py36
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py252
-rw-r--r--tensorflow/python/kernel_tests/io_ops_test.py53
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py49
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py117
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py50
-rw-r--r--tensorflow/python/kernel_tests/lookup_table_op_test.py195
-rw-r--r--tensorflow/python/kernel_tests/lrn_op_test.py101
-rw-r--r--tensorflow/python/kernel_tests/matmul_op_test.py206
-rw-r--r--tensorflow/python/kernel_tests/matrix_inverse_op_test.py79
-rw-r--r--tensorflow/python/kernel_tests/numerics_test.py91
-rw-r--r--tensorflow/python/kernel_tests/pack_op_test.py47
-rw-r--r--tensorflow/python/kernel_tests/pad_op_test.py140
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py414
-rw-r--r--tensorflow/python/kernel_tests/pooling_ops_test.py819
-rw-r--r--tensorflow/python/kernel_tests/random_ops_test.py242
-rw-r--r--tensorflow/python/kernel_tests/random_shuffle_queue_test.py1054
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py362
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py533
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py181
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py106
-rw-r--r--tensorflow/python/kernel_tests/reverse_sequence_op_test.py109
-rw-r--r--tensorflow/python/kernel_tests/save_restore_ops_test.py21
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py49
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py269
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py389
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py235
-rw-r--r--tensorflow/python/kernel_tests/softmax_op_test.py65
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py47
-rw-r--r--tensorflow/python/kernel_tests/sparse_concat_op_test.py260
-rw-r--r--tensorflow/python/kernel_tests/sparse_matmul_op_test.py82
-rw-r--r--tensorflow/python/kernel_tests/sparse_reorder_op_test.py56
-rw-r--r--tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py111
-rw-r--r--tensorflow/python/kernel_tests/sparsemask_op_test.py32
-rw-r--r--tensorflow/python/kernel_tests/split_op_test.py132
-rw-r--r--tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py34
-rw-r--r--tensorflow/python/kernel_tests/string_to_number_op_test.py66
-rw-r--r--tensorflow/python/kernel_tests/summary_image_op_test.py63
-rw-r--r--tensorflow/python/kernel_tests/summary_ops_test.py83
-rw-r--r--tensorflow/python/kernel_tests/topk_op_test.py52
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py176
-rw-r--r--tensorflow/python/kernel_tests/unique_op_test.py22
-rw-r--r--tensorflow/python/kernel_tests/unpack_op_test.py56
-rw-r--r--tensorflow/python/kernel_tests/variable_ops_test.py225
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py160
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py242
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py43
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py110
-rwxr-xr-xtensorflow/python/lib/__init__.py0
-rwxr-xr-xtensorflow/python/lib/core/__init__.py0
-rw-r--r--tensorflow/python/lib/core/pywrap_status_test.py35
-rw-r--r--tensorflow/python/lib/core/status.i116
-rw-r--r--tensorflow/python/lib/core/status_helper.i16
-rw-r--r--tensorflow/python/lib/core/strings.i94
-rwxr-xr-xtensorflow/python/lib/io/__init__.py0
-rw-r--r--tensorflow/python/lib/io/py_record_reader.cc49
-rw-r--r--tensorflow/python/lib/io/py_record_reader.h50
-rw-r--r--tensorflow/python/lib/io/py_record_reader.i39
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc44
-rw-r--r--tensorflow/python/lib/io/py_record_writer.h38
-rw-r--r--tensorflow/python/lib/io/py_record_writer.i38
-rw-r--r--tensorflow/python/lib/io/python_io.py29
-rw-r--r--tensorflow/python/lib/io/tf_record.py68
-rwxr-xr-xtensorflow/python/ops/__init__.py0
-rw-r--r--tensorflow/python/ops/array_grad.py187
-rw-r--r--tensorflow/python/ops/array_ops.py1207
-rw-r--r--tensorflow/python/ops/attention_ops.py34
-rw-r--r--tensorflow/python/ops/candidate_sampling_ops.py365
-rw-r--r--tensorflow/python/ops/clip_ops.py234
-rw-r--r--tensorflow/python/ops/common_shapes.py371
-rw-r--r--tensorflow/python/ops/constant_op.py189
-rw-r--r--tensorflow/python/ops/control_flow_grad.py100
-rw-r--r--tensorflow/python/ops/control_flow_ops.py1561
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py88
-rw-r--r--tensorflow/python/ops/data_flow_grad.py37
-rw-r--r--tensorflow/python/ops/data_flow_ops.py680
-rw-r--r--tensorflow/python/ops/embedding_ops.py197
-rw-r--r--tensorflow/python/ops/gradients.py661
-rw-r--r--tensorflow/python/ops/gradients_test.py337
-rw-r--r--tensorflow/python/ops/image_ops.py786
-rw-r--r--tensorflow/python/ops/image_ops_test.py771
-rw-r--r--tensorflow/python/ops/init_ops.py181
-rw-r--r--tensorflow/python/ops/io_ops.py541
-rw-r--r--tensorflow/python/ops/linalg_grad.py25
-rw-r--r--tensorflow/python/ops/linalg_ops.py62
-rw-r--r--tensorflow/python/ops/logging_ops.py58
-rw-r--r--tensorflow/python/ops/math_grad.py506
-rw-r--r--tensorflow/python/ops/math_ops.py1201
-rw-r--r--tensorflow/python/ops/math_ops_test.py68
-rw-r--r--tensorflow/python/ops/nn.py816
-rw-r--r--tensorflow/python/ops/nn_grad.py229
-rw-r--r--tensorflow/python/ops/nn_ops.py365
-rw-r--r--tensorflow/python/ops/nn_test.py882
-rw-r--r--tensorflow/python/ops/numerics.py50
-rw-r--r--tensorflow/python/ops/op_def_library.py640
-rw-r--r--tensorflow/python/ops/op_def_library_test.py1402
-rw-r--r--tensorflow/python/ops/parsing_ops.py390
-rw-r--r--tensorflow/python/ops/random_ops.py181
-rw-r--r--tensorflow/python/ops/sparse_grad.py12
-rw-r--r--tensorflow/python/ops/sparse_ops.py458
-rw-r--r--tensorflow/python/ops/sparse_ops_test.py212
-rw-r--r--tensorflow/python/ops/standard_ops.py41
-rw-r--r--tensorflow/python/ops/state_grad.py18
-rw-r--r--tensorflow/python/ops/state_ops.py189
-rw-r--r--tensorflow/python/ops/string_ops.py12
-rw-r--r--tensorflow/python/ops/summary_ops.py177
-rw-r--r--tensorflow/python/ops/variable_scope.py333
-rw-r--r--tensorflow/python/ops/variables.py569
-rw-r--r--tensorflow/python/platform/__init__.py6
-rw-r--r--tensorflow/python/platform/app.py13
-rw-r--r--tensorflow/python/platform/base.i176
-rw-r--r--tensorflow/python/platform/control_imports.py13
-rwxr-xr-xtensorflow/python/platform/default/__init__.py0
-rw-r--r--tensorflow/python/platform/default/_app.py11
-rw-r--r--tensorflow/python/platform/default/_flags.py92
-rw-r--r--tensorflow/python/platform/default/_gfile.py404
-rw-r--r--tensorflow/python/platform/default/_googletest.py68
-rw-r--r--tensorflow/python/platform/default/_init.py1
-rw-r--r--tensorflow/python/platform/default/_logging.py182
-rw-r--r--tensorflow/python/platform/default/_parameterized.py2
-rw-r--r--tensorflow/python/platform/default/_resource_loader.py26
-rw-r--r--tensorflow/python/platform/default/_status_bar.py5
-rw-r--r--tensorflow/python/platform/default/flags_test.py53
-rw-r--r--tensorflow/python/platform/default/gfile_test.py147
-rw-r--r--tensorflow/python/platform/default/logging_test.py13
-rw-r--r--tensorflow/python/platform/flags.py10
-rw-r--r--tensorflow/python/platform/gfile.py10
-rw-r--r--tensorflow/python/platform/googletest.py10
-rw-r--r--tensorflow/python/platform/logging.py10
-rw-r--r--tensorflow/python/platform/numpy.i3085
-rw-r--r--tensorflow/python/platform/parameterized.py10
-rw-r--r--tensorflow/python/platform/resource_loader.py10
-rw-r--r--tensorflow/python/platform/status_bar.py10
-rw-r--r--tensorflow/python/platform/test.py6
-rw-r--r--tensorflow/python/summary/README.md15
-rwxr-xr-xtensorflow/python/summary/__init__.py0
-rw-r--r--tensorflow/python/summary/event_accumulator.py433
-rw-r--r--tensorflow/python/summary/event_accumulator_test.py422
-rw-r--r--tensorflow/python/summary/event_multiplexer.py346
-rw-r--r--tensorflow/python/summary/event_multiplexer_test.py244
-rwxr-xr-xtensorflow/python/summary/impl/__init__.py0
-rw-r--r--tensorflow/python/summary/impl/directory_watcher.py115
-rw-r--r--tensorflow/python/summary/impl/directory_watcher_test.py102
-rw-r--r--tensorflow/python/summary/impl/event_file_loader.py49
-rw-r--r--tensorflow/python/summary/impl/event_file_loader_test.py59
-rw-r--r--tensorflow/python/summary/impl/reservoir.py164
-rw-r--r--tensorflow/python/summary/impl/reservoir_test.py178
-rw-r--r--tensorflow/python/tensorflow.i14
-rwxr-xr-xtensorflow/python/training/__init__.py0
-rw-r--r--tensorflow/python/training/adagrad.py58
-rw-r--r--tensorflow/python/training/adagrad_test.py144
-rw-r--r--tensorflow/python/training/adam.py142
-rw-r--r--tensorflow/python/training/adam_test.py174
-rw-r--r--tensorflow/python/training/checkpoint_state.proto18
-rw-r--r--tensorflow/python/training/coordinator.py186
-rw-r--r--tensorflow/python/training/coordinator_test.py98
-rw-r--r--tensorflow/python/training/ftrl.py283
-rw-r--r--tensorflow/python/training/ftrl_test.py234
-rw-r--r--tensorflow/python/training/gradient_descent.py44
-rw-r--r--tensorflow/python/training/gradient_descent_test.py105
-rw-r--r--tensorflow/python/training/input.py501
-rw-r--r--tensorflow/python/training/input_test.py477
-rw-r--r--tensorflow/python/training/learning_rate_decay.py65
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py60
-rw-r--r--tensorflow/python/training/momentum.py51
-rw-r--r--tensorflow/python/training/momentum_test.py258
-rw-r--r--tensorflow/python/training/moving_averages.py247
-rw-r--r--tensorflow/python/training/moving_averages_test.py130
-rw-r--r--tensorflow/python/training/optimizer.py426
-rw-r--r--tensorflow/python/training/queue_runner.py233
-rw-r--r--tensorflow/python/training/queue_runner_test.py186
-rw-r--r--tensorflow/python/training/rmsprop.py81
-rw-r--r--tensorflow/python/training/rmsprop_test.py158
-rw-r--r--tensorflow/python/training/saver.proto30
-rw-r--r--tensorflow/python/training/saver.py887
-rw-r--r--tensorflow/python/training/saver_test.py563
-rw-r--r--tensorflow/python/training/summary_io.py226
-rw-r--r--tensorflow/python/training/summary_writer_test.py151
-rw-r--r--tensorflow/python/training/training.py138
-rw-r--r--tensorflow/python/training/training_ops.py115
-rw-r--r--tensorflow/python/training/training_ops_test.py159
-rw-r--r--tensorflow/python/training/training_util.py57
-rwxr-xr-xtensorflow/python/user_ops/__init__.py0
-rw-r--r--tensorflow/python/user_ops/user_ops.py10
-rwxr-xr-xtensorflow/python/util/__init__.py0
-rw-r--r--tensorflow/python/util/port.i11
-rwxr-xr-xtensorflow/python/util/protobuf/__init__.py0
-rw-r--r--tensorflow/python/util/protobuf/compare.py384
-rw-r--r--tensorflow/python/util/protobuf/compare_test.proto49
-rw-r--r--tensorflow/python/util/protobuf/compare_test.py652
266 files changed, 62734 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
new file mode 100644
index 0000000000..89eb22daba
--- /dev/null
+++ b/tensorflow/python/BUILD
@@ -0,0 +1,965 @@
+# Description:
+# Python support for TensorFlow.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("/tensorflow/tensorflow", "tf_cuda_library")
+load("/tensorflow/tensorflow", "tf_gen_op_wrapper_py")
+load("/tensorflow/tensorflow", "py_tests")
+load("/tensorflow/tensorflow", "cuda_py_tests")
+load("/tensorflow/tensorflow", "tf_py_wrap_cc")
+load("/tensorflow/core/platform/default/build_config", "tf_proto_library_py")
+
+config_setting(
+ name = "macosx",
+ values = {"cpu": "darwin"},
+)
+
+numpy_macosx_include_dir = select({
+ ":macosx": ["-I/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/numpy/core/include"],
+ "//conditions:default": [],
+})
+
+py_library(
+ name = "python",
+ srcs = ["__init__.py"],
+ visibility = ["//tensorflow:__pkg__"],
+ deps = [
+ ":client",
+ ":client_testlib",
+ ":framework",
+ ":framework_test_lib",
+ ":platform",
+ ":platform_test",
+ ":summary",
+ ":training",
+ ],
+)
+
+py_library(
+ name = "platform",
+ srcs = glob(["platform/**/*.py"]),
+)
+
+py_library(
+ name = "platform_test",
+ srcs = [
+ "platform/default/_googletest.py",
+ "platform/googletest.py",
+ ],
+ deps = [":platform"],
+)
+
+py_tests(
+ name = "platform_tests",
+ srcs = glob(["platform/default/*_test.py"]),
+ additional_deps = [
+ ":platform",
+ ":platform_test",
+ ],
+ prefix = "platform",
+)
+
+cc_library(
+ name = "py_record_reader_lib",
+ srcs = [
+ "lib/io/py_record_reader.cc",
+ ],
+ hdrs = [
+ "lib/io/py_record_reader.h",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "py_record_writer_lib",
+ srcs = [
+ "lib/io/py_record_writer.cc",
+ ],
+ hdrs = [
+ "lib/io/py_record_writer.h",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+py_test(
+ name = "pywrap_status_test",
+ size = "small",
+ srcs = ["lib/core/pywrap_status_test.py"],
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+cc_library(
+ name = "python_op_gen_main",
+ srcs = [
+ "framework/python_op_gen.cc",
+ "framework/python_op_gen.h",
+ "framework/python_op_gen_main.cc",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_cc",
+ ],
+)
+
+py_library(
+ name = "framework",
+ srcs = [
+ # TODO(mrry): Move this to framework.
+ "client/graph_util.py",
+ "framework/device.py",
+ "framework/errors.py",
+ "framework/framework_lib.py",
+ "framework/importer.py",
+ "framework/op_def_registry.py",
+ "framework/ops.py",
+ "framework/random_seed.py",
+ "framework/registry.py",
+ "framework/tensor_shape.py",
+ "framework/types.py",
+ "framework/tensor_util.py",
+ "ops/common_shapes.py",
+ ],
+ deps = [
+ ":platform",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+# subinclude("//third_party/py/cython:build_defs")
+
+py_library(
+ name = "extra_py_tests_deps",
+ deps = ["//tensorflow:tensorflow_py"],
+)
+
+py_library(
+ name = "framework_test_lib",
+ srcs = [
+ "framework/test_util.py",
+ ],
+ deps = [
+ ":framework",
+ ":platform_test",
+ ":pywrap_tensorflow",
+ ":session",
+ ":util",
+ ],
+)
+
+py_library(
+ name = "client_testlib",
+ srcs = [
+ "platform/test.py",
+ ],
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ ],
+)
+
+py_test(
+ name = "framework_errors_test",
+ srcs = ["framework/errors_test.py"],
+ main = "framework/errors_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "framework_importer_test",
+ srcs = ["framework/importer_test.py"],
+ main = "framework/importer_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":ops",
+ ":platform_test",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "test_kernel_label_op",
+ out = "framework/test_kernel_label_op.py",
+ deps = [":test_kernel_label_op_kernel"],
+)
+
+cc_library(
+ name = "test_kernel_label_op_kernel",
+ srcs = ["framework/test_kernel_label_op.cc"],
+ linkstatic = 1,
+ deps = ["//tensorflow/core:framework"],
+ alwayslink = 1,
+)
+
+py_test(
+ name = "framework_ops_test",
+ srcs = ["framework/ops_test.py"],
+ main = "framework/ops_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":ops",
+ ":platform_test",
+ ":session",
+ ":test_kernel_label_op",
+ ],
+)
+
+py_test(
+ name = "framework_tensor_shape_test",
+ srcs = ["framework/tensor_shape_test.py"],
+ main = "framework/tensor_shape_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "framework_tensor_util_test",
+ srcs = ["framework/tensor_util_test.py"],
+ main = "framework/tensor_util_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":ops",
+ ":platform_test",
+ ],
+)
+
+py_test(
+ name = "framework_test_util_test",
+ srcs = ["framework/test_util_test.py"],
+ main = "framework/test_util_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":ops",
+ ":platform_test",
+ ],
+)
+
+py_test(
+ name = "framework_types_test",
+ srcs = ["framework/types_test.py"],
+ main = "framework/types_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":platform_test",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "op_def_library_test",
+ srcs = ["ops/op_def_library_test.py"],
+ main = "ops/op_def_library_test.py",
+ deps = [
+ ":framework_test_lib",
+ ":ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "array_ops",
+ hidden = [
+ "BroadcastGradientArgs",
+ "Concat",
+ "Const",
+ "EditDistance",
+ "Pack",
+ "Placeholder",
+ "RefIdentity",
+ "Split",
+ "Slice",
+ "TileGrad", # Exported through array_grad instead of array_ops.
+ "ZerosLike", # TODO(josh11b): Use this instead of the Python version.
+ "Unpack",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "attention_ops",
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "candidate_sampling_ops",
+ hidden = [
+ "AllCandidateSampler",
+ "ComputeAccidentalHits",
+ "FixedUnigramCandidateSampler",
+ "LogUniformCandidateSampler",
+ "ThreadUnsafeUnigramCandidateSampler",
+ "UniformCandidateSampler",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "control_flow_ops",
+ hidden = [
+ "Switch",
+ "Merge",
+ "Exit",
+ ],
+ require_shape_functions = True,
+ deps = [
+ "//tensorflow/core:control_flow_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "data_flow_ops",
+ hidden = [
+ "FIFOQueue",
+ "HashTable",
+ "InitializeTable",
+ "LookupTableFind",
+ "LookupTableSize",
+ "Mutex",
+ "MutexAcquire",
+ "MutexRelease",
+ "QueueClose",
+ "QueueDequeue",
+ "QueueDequeueMany",
+ "QueueEnqueue",
+ "QueueEnqueueMany",
+ "QueueSize",
+ "RandomShuffleQueue",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "image_ops",
+ hidden = [
+ "ScaleImageGrad",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "io_ops",
+ hidden = [
+ "FixedLengthRecordReader",
+ "IdentityReader",
+ "ReaderClose",
+ "ReaderEnqueueWork",
+ "ReaderNumRecordsProduced",
+ "ReaderNumWorkUnitsCompleted",
+ "ReaderRead",
+ "ReaderReset",
+ "ReaderRestoreState",
+ "ReaderSerializeState",
+ "ReaderWorkQueueLength",
+ "Restore",
+ "RestoreSlice",
+ "Save",
+ "SaveSlices",
+ "ShardedFilename",
+ "ShardedFilespec",
+ "TextLineReader",
+ "TFRecordReader",
+ "WholeFileReader",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "linalg_ops",
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "logging_ops",
+ hidden = [
+ "Assert",
+ "Print",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "math_ops",
+ hidden = [
+ "Abs",
+ "All",
+ "Any",
+ "BatchMatMul",
+ "Complex",
+ "Max",
+ "Mean",
+ "Min",
+ "Pow",
+ "Prod",
+ "Range",
+ "SparseMatMul",
+ "Sum",
+ "MatMul",
+ "Sigmoid",
+ "Tanh",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "nn_ops",
+ hidden = [
+ "AvgPoolGrad", # "*Grad" accessible through nn_grad instead of nn_ops.
+ "BatchNormWithGlobalNormalizationGrad",
+ "SoftmaxCrossEntropyWithLogits",
+ "LRNGrad",
+ "MaxPoolGrad",
+ "MaxPoolGradWithArgmax",
+ "ReluGrad",
+ "Relu6Grad",
+ "SoftplusGrad",
+ "BiasAdd",
+ "Relu6",
+ "AvgPool",
+ "MaxPool",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "parsing_ops",
+ hidden = ["ParseExample"],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "random_ops",
+ hidden = [
+ "RandomUniform",
+ "RandomShuffle",
+ "RandomStandardNormal",
+ "TruncatedNormal",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "state_ops",
+ hidden = [
+ "Variable",
+ "TemporaryVariable",
+ "DestroyTemporaryVariable",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "sparse_ops",
+ hidden = [
+ "SparseConcat",
+ "SparseSelectLastK",
+ "SparseReorder",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "string_ops",
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "summary_ops",
+ hidden = [
+ "HistogramAccumulatorSummary",
+ "HistogramSummary",
+ "ImageSummary",
+ "MergeSummary",
+ "ScalarSummary",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
+ name = "user_ops",
+ hidden = [
+ "Fact",
+ ],
+ require_shape_functions = False,
+)
+
+tf_gen_op_wrapper_py(
+ name = "training_ops",
+ out = "training/gen_training_ops.py",
+ require_shape_functions = True,
+)
+
+py_library(
+ name = "ops",
+ srcs = [
+ "ops/array_grad.py",
+ "ops/array_ops.py",
+ "ops/attention_ops.py",
+ "ops/candidate_sampling_ops.py",
+ "ops/clip_ops.py",
+ "ops/constant_op.py",
+ "ops/control_flow_grad.py",
+ "ops/control_flow_ops.py",
+ "ops/data_flow_grad.py",
+ "ops/data_flow_ops.py",
+ "ops/embedding_ops.py",
+ "ops/gen_array_ops.py",
+ "ops/gen_attention_ops.py",
+ "ops/gen_control_flow_ops.py",
+ "ops/gen_data_flow_ops.py",
+ "ops/gen_image_ops.py",
+ "ops/gen_io_ops.py",
+ "ops/gen_linalg_ops.py",
+ "ops/gen_logging_ops.py",
+ "ops/gen_math_ops.py",
+ "ops/gen_nn_ops.py",
+ "ops/gen_random_ops.py",
+ "ops/gen_state_ops.py",
+ "ops/gen_string_ops.py",
+ "ops/gen_summary_ops.py",
+ "ops/gradients.py",
+ "ops/image_ops.py",
+ "ops/init_ops.py",
+ "ops/io_ops.py",
+ "ops/linalg_grad.py",
+ "ops/linalg_ops.py",
+ "ops/logging_ops.py",
+ "ops/math_grad.py",
+ "ops/math_ops.py",
+ "ops/nn.py",
+ "ops/nn_grad.py",
+ "ops/nn_ops.py",
+ "ops/numerics.py",
+ "ops/op_def_library.py",
+ "ops/parsing_ops.py",
+ "ops/random_ops.py",
+ "ops/sparse_grad.py",
+ "ops/sparse_ops.py",
+ "ops/standard_ops.py",
+ "ops/state_grad.py",
+ "ops/state_ops.py",
+ "ops/string_ops.py",
+ "ops/summary_ops.py",
+ "ops/variable_scope.py",
+ "ops/variables.py",
+ "user_ops/user_ops.py",
+ ],
+ deps = [
+ ":array_ops",
+ ":candidate_sampling_ops",
+ ":control_flow_ops",
+ ":data_flow_ops",
+ ":framework",
+ ":io_ops",
+ ":linalg_ops",
+ ":logging_ops",
+ ":math_ops",
+ ":nn_ops",
+ ":parsing_ops",
+ ":random_ops",
+ ":sparse_ops",
+ ":string_ops",
+ ":summary_ops",
+ ":user_ops",
+ ],
+)
+
+py_library(
+ name = "training",
+ srcs = glob(
+ ["training/**/*.py"],
+ exclude = ["**/*test*"],
+ ),
+ deps = [
+ ":client",
+ ":framework",
+ ":lib",
+ ":ops",
+ ":protos_all_py",
+ ":pywrap_tensorflow",
+ ":training_ops",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_library(
+ name = "client",
+ srcs = glob(
+ ["client/**/*.py"],
+ exclude = ["**/*test*"],
+ ),
+ deps = [
+ ":framework",
+ ":ops",
+ ":session",
+ ":training_ops",
+ ],
+)
+
+py_library(
+ name = "util",
+ srcs = glob(["util/**/*.py"]),
+ deps = [
+ "//google/protobuf:protobuf_python",
+ ],
+)
+
+tf_proto_library_py(
+ name = "protos_all",
+ srcs = glob(
+ ["**/*.proto"],
+ exclude = ["util/protobuf/compare_test.proto"],
+ ),
+)
+
+tf_proto_library_py(
+ name = "compare_test_proto",
+ testonly = 1,
+ srcs = ["util/protobuf/compare_test.proto"],
+)
+
+py_test(
+ name = "protobuf_compare_test",
+ srcs = ["util/protobuf/compare_test.py"],
+ main = "util/protobuf/compare_test.py",
+ deps = [
+ ":compare_test_proto_py",
+ ":platform_test",
+ ":util",
+ ],
+)
+
+py_test(
+ name = "events_writer_test",
+ size = "small",
+ srcs = [
+ "client/events_writer_test.py",
+ ],
+ deps = [
+ ":framework_test_lib",
+ ":lib",
+ ":platform_test",
+ ],
+)
+
+tf_cuda_library(
+ name = "tf_session_helper",
+ srcs = ["client/tf_session_helper.cc"],
+ hdrs = ["client/tf_session_helper.h"],
+ copts = numpy_macosx_include_dir + ["-I/usr/include/python2.7"],
+ deps = [
+ ":construction_fails_op",
+ ":test_kernel_label_op_kernel",
+ "//tensorflow/core",
+ "//tensorflow/core:kernels",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:local",
+ "//tensorflow/core:protos_cc",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "client/pywraptensorflow_server_lib",
+ srcs = ["client/tensorflow_server.i"],
+ copts = numpy_macosx_include_dir,
+ swig_includes = [
+ "lib/core/status.i",
+ "lib/core/strings.i",
+ "platform/base.i",
+ ],
+ deps = [
+ "//tensorflow/core",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_cc",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "pywrap_tensorflow",
+ srcs = ["tensorflow.i"],
+ copts = numpy_macosx_include_dir,
+ swig_includes = [
+ "client/events_writer.i",
+ "client/tf_session.i",
+ "lib/core/status.i",
+ "lib/core/status_helper.i",
+ "lib/core/strings.i",
+ "lib/io/py_record_reader.i",
+ "lib/io/py_record_writer.i",
+ "platform/base.i",
+ "platform/numpy.i",
+ "util/port.i",
+ ],
+ deps = [
+ ":py_record_reader_lib",
+ ":py_record_writer_lib",
+ ":tf_session_helper",
+ ],
+)
+
+py_library(
+ name = "lib",
+ srcs = glob(["lib/**/*.py"]),
+ deps = [
+ ":pywrap_tensorflow",
+ ],
+)
+
+py_library(
+ name = "session",
+ srcs = ["client/session.py"],
+ deps = [
+ ":framework",
+ ":ops",
+ ":pywrap_tensorflow",
+ ],
+)
+
+# Just used by tests.
+tf_cuda_library(
+ name = "construction_fails_op",
+ testonly = 1,
+ srcs = ["client/test_construction_fails_op.cc"],
+ deps = [
+ "//tensorflow/core",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_cc",
+ ],
+ alwayslink = 1,
+)
+
+py_test(
+ name = "session_test",
+ srcs = ["client/session_test.py"],
+ deps = [
+ ":framework",
+ ":framework_test_lib",
+ ":session",
+ ],
+)
+
+py_test(
+ name = "graph_util_test",
+ srcs = ["client/graph_util_test.py"],
+ deps = [
+ ":framework",
+ ":framework_test_lib",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "kernel_tests/gradient_checker",
+ srcs = ["kernel_tests/gradient_checker.py"],
+)
+
+cpu_only_kernel_test_list = glob([
+ "kernel_tests/attention_ops_test.py",
+ "kernel_tests/barrier_ops_test.py",
+ "kernel_tests/bcast_ops_test.py",
+ "kernel_tests/candidate_sampler_ops_test.py",
+ "kernel_tests/cholesky_op_test.py",
+ "kernel_tests/clip_ops_test.py",
+ "kernel_tests/decode_csv_op_test.py",
+ "kernel_tests/decode_raw_op_test.py",
+ "kernel_tests/determinant_op_test.py",
+ "kernel_tests/diag_op_test.py",
+ "kernel_tests/edit_distance_op_test.py",
+ "kernel_tests/fifo_queue_test.py",
+ "kernel_tests/identity_op_py_test.py",
+ "kernel_tests/in_topk_op_test.py",
+ "kernel_tests/io_ops_test.py",
+ "kernel_tests/listdiff_op_test.py",
+ "kernel_tests/logging_ops_test.py",
+ "kernel_tests/lookup_table_op_test.py",
+ "kernel_tests/lrn_op_py_test.py",
+ "kernel_tests/matrix_inverse_op_test.py",
+ "kernel_tests/mutex_ops_test.py",
+ "kernel_tests/parsing_ops_test.py",
+ "kernel_tests/queue_ops_test.py",
+ "kernel_tests/random_shuffle_queue_test.py",
+ "kernel_tests/save_restore_ops_test.py",
+ "kernel_tests/segment_reduction_ops_test.py",
+ "kernel_tests/sparse_concat_op_test.py",
+ "kernel_tests/sparse_reorder_op_test.py",
+ "kernel_tests/sparse_to_dense_op_test.py",
+ "kernel_tests/sparsemask_op_test.py",
+ "kernel_tests/summary_ops_test.py",
+ "kernel_tests/topk_op_test.py",
+ "kernel_tests/unique_op_test.py",
+ "kernel_tests/variable_scope_test.py",
+ "kernel_tests/variables_test.py",
+ "kernel_tests/where_op_test.py",
+])
+
+py_tests(
+ name = "cpu_only_kernel_tests",
+ srcs = cpu_only_kernel_test_list,
+)
+
+py_tests(
+ name = "reader_ops_test",
+ srcs = ["kernel_tests/reader_ops_test.py"],
+ additional_deps = [
+ ":lib",
+ ],
+)
+
+cuda_py_tests(
+ name = "op_tests",
+ srcs = glob(
+ ["ops/*_test.py"],
+ exclude = [
+ "ops/image_ops_test.py",
+ "ops/op_def_library_test.py",
+ ],
+ ),
+)
+
+cuda_py_tests(
+ name = "kernel_tests",
+ srcs = glob(
+ ["kernel_tests/*_test.py"],
+ exclude = [
+ "**/reader_ops_test.py",
+ # Sharded below
+ "**/cwise_ops_test.py",
+ "**/conv_ops_test.py",
+ "**/linalg_grad_test.py",
+ "**/pooling_ops_test.py",
+ ] + cpu_only_kernel_test_list,
+ ),
+)
+
+cuda_py_tests(
+ name = "kernel_tests_with_sharding",
+ srcs = [
+ "kernel_tests/conv_ops_test.py",
+ "kernel_tests/cwise_ops_test.py",
+ "kernel_tests/linalg_grad_test.py",
+ "kernel_tests/pooling_ops_test.py",
+ ],
+ shard_count = 2,
+)
+
+cuda_py_tests(
+ name = "image_ops_test",
+ srcs = [
+ "ops/image_ops_test.py",
+ ],
+ data = [
+ "//tensorflow/core:image_testdata",
+ ],
+ shard_count = 5,
+)
+
+cuda_py_tests(
+ name = "training_tests",
+ srcs = glob(
+ ["training/*_test.py"],
+ exclude = ["training/input_test.py"],
+ ),
+ additional_deps = [
+ ":training",
+ ],
+)
+
+py_tests(
+ name = "training_tests",
+ srcs = glob(
+ ["training/input_test.py"],
+ ),
+ additional_deps = [
+ ":training",
+ ],
+)
+
+py_library(
+ name = "summary",
+ srcs = glob(
+ ["summary/**/*.py"],
+ exclude = ["**/*test*"],
+ ),
+ deps = [
+ ":client",
+ ":framework",
+ ":pywrap_tensorflow",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_tests(
+ name = "summary_tests",
+ srcs = glob(["summary/**/*_test.py"]),
+ additional_deps = [
+ ":summary",
+ ":training",
+ ],
+)
+
+py_library(
+ name = "docs",
+ srcs = [
+ "framework/docs.py",
+ ],
+ deps = [
+ ":platform",
+ ],
+)
+
+py_binary(
+ name = "gen_docs_combined",
+ srcs = [
+ "framework/gen_docs_combined.py",
+ ],
+ main = "framework/gen_docs_combined.py",
+ deps = [
+ ":docs",
+ ":platform",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+sh_test(
+ name = "gen_docs_test",
+ size = "small",
+ srcs = [
+ "framework/gen_docs_test.sh",
+ ],
+ data = [
+ ":gen_docs_combined",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
new file mode 100644
index 0000000000..5527c01173
--- /dev/null
+++ b/tensorflow/python/__init__.py
@@ -0,0 +1,42 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long
+"""Import core names of TensorFlow.
+
+Programs that want to build Brain Ops and Graphs without having to import the
+constructors and utilities individually can import this file:
+
+import tensorflow.python.platform
+import tensorflow as tf
+
+"""
+
+import tensorflow.python.platform
+from tensorflow.core.framework.graph_pb2 import *
+from tensorflow.core.framework.summary_pb2 import *
+from tensorflow.core.framework.config_pb2 import *
+from tensorflow.core.util.event_pb2 import *
+
+# Framework
+from tensorflow.python.framework.framework_lib import *
+
+# Session
+from tensorflow.python.client.client_lib import *
+
+# Ops
+from tensorflow.python.ops.standard_ops import *
+
+# Bring in nn, image_ops, and user_ops as subpackages.
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import image_ops as image
+from tensorflow.python.user_ops import user_ops
+
+# Import the names from python/training.py as train.Name.
+from tensorflow.python.training import training as train
+
+# Sub-package for performing i/o directly instead of via ops in a graph.
+from tensorflow.python.lib.io import python_io
+
+# Make some application and test modules available.
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import logging
+from tensorflow.python.platform import test
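
As a usage sketch (not part of the commit), the re-exports above make the library reachable from the single "tf" namespace, assuming the package is importable as "tensorflow":

import tensorflow.python.platform
import tensorflow as tf

x = tf.constant([1.0, -2.0, 3.0])  # exported via standard_ops
y = tf.nn.relu(x)                  # nn is brought in as a subpackage above
with tf.Session() as sess:         # Session comes from client_lib
    print(sess.run(y))             # [ 1.  0.  3.]

# The other subpackages imported above follow the same pattern, e.g.
# tf.train (training), tf.image (image_ops), and tf.python_io (lib/io).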
diff --git a/tensorflow/python/client/__init__.py b/tensorflow/python/client/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/client/__init__.py
diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py
new file mode 100644
index 0000000000..9148ed17c0
--- /dev/null
+++ b/tensorflow/python/client/client_lib.py
@@ -0,0 +1,40 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long
+"""This library contains classes for launching graphs and executing operations.
+
+The [basic usage](../../get_started/index.md#basic-usage) guide has
+examples of how a graph is launched in a [`tf.Session`](#Session).
+
+## Session management
+
+@@Session
+
+@@get_default_session
+
+## Error classes
+
+@@OpError
+@@CancelledError
+@@UnknownError
+@@InvalidArgumentError
+@@DeadlineExceededError
+@@NotFoundError
+@@AlreadyExistsError
+@@PermissionDeniedError
+@@UnauthenticatedError
+@@ResourceExhaustedError
+@@FailedPreconditionError
+@@AbortedError
+@@OutOfRangeError
+@@UnimplementedError
+@@InternalError
+@@UnavailableError
+@@DataLossError
+"""
+
+from tensorflow.python.client.session import InteractiveSession
+from tensorflow.python.client.session import Session
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework.errors import OpError
+
+from tensorflow.python.framework.ops import get_default_session
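
A minimal sketch (not part of the commit) of the launch-and-run pattern and the error classes documented in this module, assuming the top-level package import shown earlier:

import tensorflow.python.platform
import tensorflow as tf
from tensorflow.python.framework import errors

x = tf.placeholder(tf.float32, shape=[2])
y = x * 2.0

sess = tf.Session()
try:
    print(sess.run(y, feed_dict={x: [1.0, 2.0]}))  # [ 2.  4.]
    # Running again without feeding the placeholder fails with an OpError.
    sess.run(y)
except errors.OpError as e:
    print("run failed: %s" % e)
finally:
    sess.close()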
diff --git a/tensorflow/python/client/events_writer.i b/tensorflow/python/client/events_writer.i
new file mode 100644
index 0000000000..cbf42e2791
--- /dev/null
+++ b/tensorflow/python/client/events_writer.i
@@ -0,0 +1,34 @@
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/core/util/events_writer.h"
+#include "tensorflow/core/util/event.pb.h"
+%}
+
+%nodefaultctor EventsWriter;
+
+%ignoreall
+%unignore tensorflow;
+%unignore tensorflow::EventsWriter;
+%unignore tensorflow::EventsWriter::EventsWriter;
+%unignore tensorflow::EventsWriter::~EventsWriter;
+%unignore tensorflow::EventsWriter::FileName;
+%rename("_WriteSerializedEvent") tensorflow::EventsWriter::WriteSerializedEvent;
+%unignore tensorflow::EventsWriter::Flush;
+%unignore tensorflow::EventsWriter::Close;
+%include "tensorflow/core/util/events_writer.h"
+%unignoreall
+
+%newobject tensorflow::EventsWriter::EventsWriter;
+
+
+%extend tensorflow::EventsWriter {
+%insert("python") %{
+ def WriteEvent(self, event):
+ from tensorflow.core.util.event_pb2 import Event
+ if not isinstance(event, Event):
+ raise TypeError("Expected an event_pb2.Event proto, "
+ "but got %s" % type(event))
+ return self._WriteSerializedEvent(event.SerializeToString())
+%}
+}
diff --git a/tensorflow/python/client/events_writer_test.py b/tensorflow/python/client/events_writer_test.py
new file mode 100644
index 0000000000..60bce49b1f
--- /dev/null
+++ b/tensorflow/python/client/events_writer_test.py
@@ -0,0 +1,54 @@
+"""Tests for the SWIG-wrapped events writer."""
+import os.path
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
+
+ def testWriteEvents(self):
+ file_prefix = os.path.join(self.get_temp_dir(), "events")
+ writer = pywrap_tensorflow.EventsWriter(file_prefix)
+ filename = writer.FileName()
+ event_written = event_pb2.Event(
+ wall_time=123.45, step=67,
+ summary=summary_pb2.Summary(
+ value=[summary_pb2.Summary.Value(tag="foo", simple_value=89.0)]))
+ writer.WriteEvent(event_written)
+ writer.Flush()
+ writer.Close()
+
+ with self.assertRaises(IOError):
+ for r in tf_record.tf_record_iterator(filename + "DOES_NOT_EXIST"):
+ self.assertTrue(False)
+
+ reader = tf_record.tf_record_iterator(filename)
+ event_read = event_pb2.Event()
+
+ event_read.ParseFromString(next(reader))
+ self.assertTrue(event_read.HasField("file_version"))
+
+ event_read.ParseFromString(next(reader))
+ # Second event
+ self.assertProtoEquals("""
+ wall_time: 123.45 step: 67
+ summary { value { tag: 'foo' simple_value: 89.0 } }
+ """, event_read)
+
+ with self.assertRaises(StopIteration):
+ next(reader)
+
+ def testWriteEventInvalidType(self):
+ class _Invalid(object):
+ def __str__(self): return "Invalid"
+ with self.assertRaisesRegexp(TypeError, "Invalid"):
+ pywrap_tensorflow.EventsWriter("foo").WriteEvent(_Invalid())
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py
new file mode 100644
index 0000000000..4c65a445ae
--- /dev/null
+++ b/tensorflow/python/client/graph_util.py
@@ -0,0 +1,138 @@
+"""Helpers to manipulate a tensor graph in python.
+"""
+
+import tensorflow.python.platform
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.platform import logging
+
+_VARIABLE_OPS = {
+ "Assign",
+ "AssignAdd",
+ "AssignSub",
+ "Queue",
+ "RandomParameters",
+ "ScatterAdd",
+ "ScatterSub",
+ "ScatterUpdate",
+ "Variable",
+}
+
+
+def _is_variable_op(op):
+ """Returns true if 'op' refers to a Variable node."""
+ return op in _VARIABLE_OPS
+
+
+def set_cpu0(device_string):
+ """Creates a new device string based on `device_string' but using /CPU:0.
+
+ If the device is already on /CPU:0, this is a no-op.
+
+ Args:
+ device_string: A device string.
+
+ Returns:
+ A device string.
+ """
+ parsed_device = pydev.from_string(device_string)
+ parsed_device.device_type = "CPU"
+ parsed_device.device_index = 0
+ return parsed_device.to_string()
+
+
+def must_run_on_cpu(node, pin_variables_on_cpu=False):
+ """Returns True if the given node_def must run on CPU, otherwise False.
+
+ Args:
+ node: The node to be assigned to a device. Could be either an ops.Operation
+ or NodeDef.
+ pin_variables_on_cpu: If True, this function will return True if node_def
+ represents a variable-related op.
+
+ Returns:
+ True if the given node must run on CPU, otherwise False.
+ """
+
+ if isinstance(node, ops.Operation):
+ node_def = node.node_def
+ else:
+ assert isinstance(node, graph_pb2.NodeDef)
+ node_def = node
+
+ # If the op is a variable-related op, should we pin it on CPU?
+ if pin_variables_on_cpu and _is_variable_op(node_def.op):
+ return True
+
+ # Constant operations producing a string or int32 must run on CPU.
+ if node_def.op == "Const":
+ # Get the value of the 'dtype' attr
+ dtype = node_def.attr["dtype"].type
+ if dtype == types.string or dtype == types.int32:
+ return True
+
+ if node_def.op == "DynamicStitch":
+ dtype = node_def.attr["T"].type
+ if dtype == types.int32:
+ # DynamicStitch on GPU only works for int32 values.
+ return True
+
+ if node_def.op in ["Cast"]:
+ dtype = node_def.attr["SrcT"].type
+ if dtype == types.int32:
+ # Cast on GPU does not work for int32 values.
+ return True
+ return False
+
+
+################################################################################
+#
+# device functions for use in with g.device(...)
+#
+################################################################################
+
+
+def pin_variables_on_cpu(op):
+ """Returns a CPU device for Variable nodes if the device is not specified.
+
+ Args:
+ op: The ops.Operation object describing the node for which a device
+ should be chosen. The op.device field is respected.
+
+ Returns:
+ A device containing "/device:CPU:0" if the node is related to a variable.
+ """
+ device = op.device if op.device is not None else ""
+ dev = pydev.from_string(device)
+
+ # If a device type exists already, do not override.
+ if dev.device_type:
+ return device
+
+ if isinstance(op, ops.Operation):
+ node_def = op.node_def
+ else:
+ assert isinstance(op, graph_pb2.NodeDef)
+ node_def = op
+
+ if _is_variable_op(node_def.op):
+ return set_cpu0(device)
+ return device
+
+
+def pin_to_cpu(op):
+ """Returns a CPU device for the given node."""
+ device = op.device if op.device is not None else ""
+ dev = pydev.from_string(device)
+
+ if not dev.device_type:
+ return set_cpu0(device)
+ if dev.device_type == "CPU":
+ return device
+
+ logging.info("Operation %s has been assigned to a non-CPU (%s), so "
+ "it will not be pinned to the CPU.", op.name, dev.device_type)
+ return device
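
The test below exercises the two device functions; as a small additional sketch (not part of the commit), must_run_on_cpu and set_cpu0 can also be called directly:

import tensorflow.python.platform

from tensorflow.python.client import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op

with ops.Graph().as_default():
    c = constant_op.constant("string constants are pinned to the CPU")
    # Const nodes producing strings are one of the cases detected above.
    print(graph_util.must_run_on_cpu(c.op))  # True
    # Rewrite an existing device string to its CPU:0 counterpart.
    print(graph_util.set_cpu0("/job:worker/device:GPU:3"))  # .../device:CPU:0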
diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py
new file mode 100644
index 0000000000..8066f722a8
--- /dev/null
+++ b/tensorflow/python/client/graph_util_test.py
@@ -0,0 +1,126 @@
+"""Tests for tensorflow.python.client.graph_util."""
+import tensorflow.python.platform
+
+from tensorflow.python.client import graph_util
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+# pylint: disable=unused-import
+from tensorflow.python.ops import math_ops
+# pylint: enable=unused-import
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import googletest
+
+
+class DeviceFunctionsTest(googletest.TestCase):
+
+ def testPinToCpu(self):
+ with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu):
+ const_a = constant_op.constant(5.0)
+ const_b = constant_op.constant(10.0)
+ add_c = const_a + const_b
+ var_v = state_ops.variable_op([], dtype=types.float32)
+ assign_c_to_v = state_ops.assign(var_v, add_c)
+ const_string = constant_op.constant("on a cpu")
+ dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
+ dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
+ self.assertEqual(const_a.device, "/device:CPU:0")
+ self.assertEqual(const_b.device, "/device:CPU:0")
+ self.assertEqual(add_c.device, "/device:CPU:0")
+ self.assertEqual(var_v.device, "/device:CPU:0")
+ self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
+ self.assertEqual(const_string.device, "/device:CPU:0")
+ self.assertEqual(dynamic_stitch_int_result.device, "/device:CPU:0")
+ self.assertEqual(dynamic_stitch_float_result.device, "/device:CPU:0")
+
+ def testPinRequiredOpsOnCPU(self):
+ with ops.Graph().as_default() as g, g.device(
+ graph_util.pin_variables_on_cpu):
+ const_a = constant_op.constant(5.0)
+ const_b = constant_op.constant(10.0)
+ add_c = const_a + const_b
+ var_v = state_ops.variable_op([], dtype=types.float32)
+ assign_c_to_v = state_ops.assign(var_v, add_c)
+ dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
+ dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
+ # Non-variable ops should not specify a device
+ self.assertEqual(const_a.device, None)
+ self.assertEqual(const_b.device, None)
+ self.assertEqual(add_c.device, None)
+ # Variable ops specify a device
+ self.assertEqual(var_v.device, "/device:CPU:0")
+ self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
+
+ def testTwoDeviceFunctions(self):
+ with ops.Graph().as_default() as g:
+ var_0 = state_ops.variable_op([1], dtype=types.float32)
+ with g.device(graph_util.pin_variables_on_cpu):
+ var_1 = state_ops.variable_op([1], dtype=types.float32)
+ var_2 = state_ops.variable_op([1], dtype=types.float32)
+ var_3 = state_ops.variable_op([1], dtype=types.float32)
+ with g.device(graph_util.pin_variables_on_cpu):
+ var_4 = state_ops.variable_op([1], dtype=types.float32)
+ with g.device("/device:GPU:0"):
+ var_5 = state_ops.variable_op([1], dtype=types.float32)
+ var_6 = state_ops.variable_op([1], dtype=types.float32)
+
+ self.assertEqual(var_0.device, None)
+ self.assertEqual(var_1.device, "/device:CPU:0")
+ self.assertEqual(var_2.device, None)
+ self.assertEqual(var_3.device, None)
+ self.assertEqual(var_4.device, "/device:CPU:0")
+ self.assertEqual(var_5.device, "/device:GPU:0")
+ self.assertEqual(var_6.device, "/device:CPU:0")
+
+ def testExplicitDevice(self):
+ with ops.Graph().as_default() as g:
+ const_0 = constant_op.constant(5.0)
+ with g.device("/device:GPU:0"):
+ const_1 = constant_op.constant(5.0)
+ with g.device("/device:GPU:1"):
+ const_2 = constant_op.constant(5.0)
+ with g.device("/device:CPU:0"):
+ const_3 = constant_op.constant(5.0)
+ with g.device("/device:CPU:1"):
+ const_4 = constant_op.constant(5.0)
+ with g.device("/job:ps"):
+ const_5 = constant_op.constant(5.0)
+
+ self.assertEqual(const_0.device, None)
+ self.assertEqual(const_1.device, "/device:GPU:0")
+ self.assertEqual(const_2.device, "/device:GPU:1")
+ self.assertEqual(const_3.device, "/device:CPU:0")
+ self.assertEqual(const_4.device, "/device:CPU:1")
+ self.assertEqual(const_5.device, "/job:ps")
+
+ def testDefaultDevice(self):
+ with ops.Graph().as_default() as g, g.device(
+ graph_util.pin_variables_on_cpu):
+ with g.device("/job:ps"):
+ const_0 = constant_op.constant(5.0)
+ with g.device("/device:GPU:0"):
+ const_1 = constant_op.constant(5.0)
+ with g.device("/device:GPU:1"):
+ const_2 = constant_op.constant(5.0)
+ with g.device("/device:CPU:0"):
+ const_3 = constant_op.constant(5.0)
+ with g.device("/device:CPU:1"):
+ const_4 = constant_op.constant(5.0)
+ with g.device("/replica:0"):
+ const_5 = constant_op.constant(5.0)
+
+ self.assertEqual(const_0.device, "/job:ps")
+ self.assertEqual(const_1.device, "/device:GPU:0")
+ self.assertEqual(const_2.device, "/device:GPU:1")
+ self.assertEqual(const_3.device, "/device:CPU:0")
+ self.assertEqual(const_4.device, "/device:CPU:1")
+ self.assertEqual(const_5.device, "/replica:0")
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py
new file mode 100644
index 0000000000..1871fbc632
--- /dev/null
+++ b/tensorflow/python/client/notebook.py
@@ -0,0 +1,104 @@
+"""Notebook front-end to TensorFlow.
+
+When you run this binary, you'll see output like the following, which indicates
+the serving URL of the notebook:
+
+ The IPython Notebook is running at: http://127.0.0.1:8888/
+
+Press "Shift+Enter" to execute a cell
+Press "Enter" on a cell to go into edit mode.
+Press "Escape" to go back into command mode and use arrow keys to navigate.
+Press "a" in command mode to insert cell above or "b" to insert cell below.
+
+Your root notebooks directory is FLAGS.notebook_dir
+"""
+
+
+import os
+import socket
+import sys
+
+# pylint: disable=g-import-not-at-top
+# Official recommended way of turning on fast protocol buffers as of 10/21/14
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
+
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+ "password", None,
+ "Password to require. If set, the server will allow public access."
+ " Only used if notebook config file does not exist.")
+
+flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks",
+ "root location where to store notebooks")
+
+ORIG_ARGV = sys.argv
+# Main notebook process calls itself with argv[1]="kernel" to start kernel
+# subprocesses.
+IS_KERNEL = len(sys.argv) > 1 and sys.argv[1] == "kernel"
+
+
+def main(unused_argv):
+ sys.argv = ORIG_ARGV
+
+ if not IS_KERNEL:
+ # Drop all flags.
+ sys.argv = [sys.argv[0]]
+ # NOTE(sadovsky): For some reason, putting this import at the top level
+ # breaks inline plotting. It's probably a bug in the stone-age version of
+ # matplotlib.
+ from IPython.html.notebookapp import NotebookApp # pylint: disable=g-import-not-at-top
+ notebookapp = NotebookApp.instance()
+ notebookapp.open_browser = True
+
+ # password functionality adopted from quality/ranklab/main/tools/notebook.py
+ # add options to run with "password"
+ if FLAGS.password:
+ from IPython.lib import passwd # pylint: disable=g-import-not-at-top
+ notebookapp.ip = "0.0.0.0"
+ notebookapp.password = passwd(FLAGS.password)
+ else:
+ print ("\nNo password specified; Notebook server will only be available"
+ " on the local machine.\n")
+ notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir])
+
+ if notebookapp.ip == "0.0.0.0":
+ proto = "https" if notebookapp.certfile else "http"
+ url = "%s://%s:%d%s" % (proto, socket.gethostname(), notebookapp.port,
+ notebookapp.base_project_url)
+ print "\nNotebook server will be publicly available at: %s\n" % url
+
+ notebookapp.start()
+ return
+
+ # Drop the --flagfile flag so that notebook doesn't complain about an
+ # "unrecognized alias" when parsing sys.argv.
+ sys.argv = ([sys.argv[0]] +
+ [z for z in sys.argv[1:] if not z.startswith("--flagfile")])
+ from IPython.kernel.zmq.kernelapp import IPKernelApp # pylint: disable=g-import-not-at-top
+ kernelapp = IPKernelApp.instance()
+ kernelapp.initialize()
+
+ # Enable inline plotting. Equivalent to running "%matplotlib inline".
+ ipshell = kernelapp.shell
+ ipshell.enable_matplotlib("inline")
+
+ kernelapp.start()
+
+
+if __name__ == "__main__":
+ # When the user starts the main notebook process, we don't touch sys.argv.
+ # When the main process launches kernel subprocesses, it writes all flags
+ # to a tmpfile and sets --flagfile to that tmpfile, so for kernel
+ # subprocesses here we drop all flags *except* --flagfile, then call
+ # app.run(), and then (in main) restore all flags before starting the
+ # kernel app.
+ if IS_KERNEL:
+ # Drop everything except --flagfile.
+ sys.argv = ([sys.argv[0]] +
+ [x for x in sys.argv[1:] if x.startswith("--flagfile")])
+ app.run()
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
new file mode 100644
index 0000000000..7da9b41cf4
--- /dev/null
+++ b/tensorflow/python/client/session.py
@@ -0,0 +1,567 @@
+"""A client interface for TensorFlow."""
+
+import re
+import sys
+import threading
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python import pywrap_tensorflow as tf_session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import logging
+
+
+class SessionInterface(object):
+ """Base class for implementations of TensorFlow client sessions."""
+
+ @property
+ def graph(self):
+ """The underlying TensorFlow graph, to be used in building Operations."""
+ raise NotImplementedError('graph')
+
+ @property
+ def sess_str(self):
+ """The TensorFlow process to which this session will connect."""
+ raise NotImplementedError('sess_str')
+
+ def run(self, fetches, feed_dict=None):
+ """Runs operations in the session. See `Session.run()` for details."""
+ raise NotImplementedError('Run')
+
+
+class BaseSession(SessionInterface):
+ """A class for interacting with a TensorFlow computation.
+
+ The BaseSession enables incremental graph building with inline
+ execution of Operations and evaluation of Tensors.
+ """
+
+ def __init__(self, target='', graph=None, config=None):
+ """Constructs a new TensorFlow session.
+
+ Args:
+ target: (Optional) The TensorFlow execution engine to connect to.
+ graph: (Optional) The graph to be used. If this argument is None,
+ the default graph will be used.
+ config: (Optional) ConfigProto proto used to configure the session.
+
+ Raises:
+ RuntimeError: If an error occurs while creating the TensorFlow
+ session.
+ """
+ if graph is None:
+ self._graph = ops.get_default_graph()
+ else:
+ self._graph = graph
+
+ self._opened = False
+ self._closed = False
+
+ self._current_version = 0
+ self._extend_lock = threading.Lock()
+ self._target = target
+
+ self._session = None
+
+ try:
+ opts = tf_session.TF_NewSessionOptions(target=target, config=config)
+ status = tf_session.TF_NewStatus()
+ self._session = tf_session.TF_NewSession(opts, status)
+ if tf_session.TF_GetCode(status) != 0:
+ message = tf_session.TF_Message(status)
+ raise RuntimeError(message)
+
+ finally:
+ tf_session.TF_DeleteSessionOptions(opts)
+ tf_session.TF_DeleteStatus(status)
+
+ def close(self):
+ """Closes this session.
+
+ Calling this method frees all resources associated with the session.
+
+ Raises:
+ RuntimeError: If an error occurs while closing the session.
+ """
+ with self._extend_lock:
+ if self._opened and not self._closed:
+ self._closed = True
+ try:
+ status = tf_session.TF_NewStatus()
+ tf_session.TF_CloseSession(self._session, status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ def __del__(self):
+ self.close()
+ try:
+ status = tf_session.TF_NewStatus()
+ if self._session is not None:
+ tf_session.TF_DeleteSession(self._session, status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ self._session = None
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ @property
+ def graph(self):
+ """The graph that was launched in this session."""
+ return self._graph
+
+ @property
+ def graph_def(self):
+ """A serializable version of the underlying TensorFlow graph.
+
+ Returns:
+ A graph_pb2.GraphDef proto containing nodes for all of the Operations in
+ the underlying TensorFlow graph.
+ """
+ return self._graph.as_graph_def()
+
+ @property
+ def sess_str(self):
+ return self._target
+
+ def as_default(self):
+ """Returns a context manager that makes this object the default session.
+
+ Use with the `with` keyword to specify that calls to
+ [`Operation.run()`](framework.md#Operation.run) or
+ [`Tensor.eval()`](framework.md#Tensor.eval) should be executed in
+ this session.
+
+ ```python
+ c = tf.constant(...)
+ sess = tf.Session()
+
+ with sess.as_default():
+ assert tf.get_default_session() is sess
+ print c.eval()
+ ```
+
+ To get the current default session, use
+ [`tf.get_default_session()`](#get_default_session).
+
+
+ *N.B.* The `as_default` context manager *does not* close the
+ session when you exit the context, and you must close the session
+ explicitly.
+
+ ```python
+ c = tf.constant(...)
+ sess = tf.Session()
+ with sess.as_default():
+ print c.eval()
+ # ...
+ with sess.as_default():
+ print c.eval()
+
+ sess.close()
+ ```
+
+ Alternatively, you can use `with tf.Session():` to create a
+ session that is automatically closed on exiting the context,
+ including when an uncaught exception is raised.
+
+ *N.B.* The default graph is a property of the current thread. If you
+ create a new thread, and wish to use the default session in that
+ thread, you must explicitly add a `with sess.as_default():` in that
+ thread's function.
+
+ Returns:
+ A context manager using this session as the default session.
+
+ """
+ return ops.default_session(self)
+
+ # Eventually, this registration could be opened up to support custom
+ # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn),
+ # where the signatures are:
+ # fetch_fn : Type -> (list of Tensors,
+ # lambda: list of fetched np.ndarray -> TypeVal)
+ # feed_fn : Type, TypeVal -> list of (Tensor, value)
+ # Conceptually, fetch_fn describes how to expand fetch into its
+ # component Tensors and how to contract the fetched results back into
+ # a single return value. feed_fn describes how to unpack a single fed
+ # value and map it to feeds of a Tensor and its corresponding value.
+ # pylint: disable=g-long-lambda
+ _REGISTERED_EXPANSIONS = [
+ # SparseTensors are fetched as SparseTensorValues. They can be fed
+ # SparseTensorValues or normal tuples.
+ (ops.SparseTensor,
+ lambda fetch: (
+ [fetch.indices, fetch.values, fetch.shape],
+ lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
+ lambda feed, feed_val: list(zip(
+ [feed.indices, feed.values, feed.shape], feed_val))),
+ # The default catches all types and performs no expansions.
+ (object,
+ lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
+ lambda feed, feed_val: [(feed, feed_val)])]
+ # pylint: enable=g-long-lambda
+
+ def run(self, fetches, feed_dict=None):
+ """Runs the operations and evaluates the tensors in `fetches`.
+
+ This method runs one "step" of TensorFlow computation, by
+ running the necessary graph fragment to execute every `Operation`
+ and evaluate every `Tensor` in `fetches`, substituting the values in
+ `feed_dict` for the corresponding input values.
+
+ The `fetches` argument may be a list of graph elements or a single
+ graph element, and these determine the return value of this
+ method. A graph element can be one of the following types:
+
+ * If the *i*th element of `fetches` is an
+ [`Operation`](framework.md#Operation), the *i*th return value
+ will be `None`.
+ * If the *i*th element of `fetches` is a
+ [`Tensor`](framework.md#Tensor), the *i*th return value will
+ be a numpy ndarray containing the value of that tensor.
+ * If the *i*th element of `fetches` is a
+ [`SparseTensor`](sparse_ops.md#SparseTensor), the *i*th
+ return value will be a
+ [`SparseTensorValue`](sparse_ops.md#SparseTensorValue)
+ containing the value of that sparse tensor.
+
+ The optional `feed_dict` argument allows the caller to override
+ the value of tensors in the graph. Each key in `feed_dict` can be
+ one of the following types:
+
+ * If the key is a [`Tensor`](framework.md#Tensor), the
+ value may be a Python scalar, string, list, or numpy ndarray
+ that can be converted to the same `dtype` as that
+ tensor. Additionally, if the key is a
+ [placeholder](io_ops.md#placeholder), the shape of the value
+ will be checked for compatibility with the placeholder.
+ * If the key is a [`SparseTensor`](sparse_ops.md#SparseTensor),
+ the value should be a
+ [`SparseTensorValue`](sparse_ops.md#SparseTensorValue).
+
+ Args:
+ fetches: A single graph element, or a list of graph elements
+ (described above).
+ feed_dict: A dictionary that maps graph elements to values
+ (described above).
+
+ Returns:
+ Either a single value if `fetches` is a single graph element, or
+ a list of values if `fetches` is a list (described above).
+
+ Raises:
+ RuntimeError: If this `Session` is in an invalid state (e.g. has been
+ closed).
+ TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
+ ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
+ `Tensor` that doesn't exist.
+
+ """
+ def _fetch_fn(fetch):
+ for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(fetch, tensor_type):
+ return fetch_fn(fetch)
+ raise TypeError('Fetch argument %r has invalid type %r'
+ % (fetch, type(fetch)))
+
+ def _feed_fn(feed, feed_val):
+ for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(feed, tensor_type):
+ return feed_fn(feed, feed_val)
+ raise TypeError('Feed argument %r has invalid type %r'
+ % (feed, type(feed)))
+
+ # Check session.
+ if self._closed:
+ raise RuntimeError('Attempted to use a closed Session.')
+
+ # Validate and process fetches.
+ is_list_fetch = isinstance(fetches, (list, tuple))
+ if not is_list_fetch:
+ fetches = [fetches]
+
+ unique_fetch_targets = set()
+ target_list = []
+
+ fetch_info = []
+ for fetch in fetches:
+ subfetches, fetch_contraction_fn = _fetch_fn(fetch)
+ subfetch_names = []
+ for subfetch in subfetches:
+ try:
+ fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
+ allow_operation=True)
+ if isinstance(fetch_t, ops.Operation):
+ target_list.append(fetch_t.name)
+ else:
+ subfetch_names.append(fetch_t.name)
+ except TypeError as e:
+ raise TypeError('Fetch argument %r of %r has invalid type %r, '
+ 'must be a string or Tensor. (%s)'
+ % (subfetch, fetch, type(subfetch), e.message))
+ except ValueError as e:
+ raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
+ 'Tensor. (%s)' % (subfetch, fetch, e.message))
+ except KeyError as e:
+ raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
+ 'Tensor. (%s)' % (subfetch, fetch, e.message))
+ unique_fetch_targets.update(subfetch_names)
+ fetch_info.append((subfetch_names, fetch_contraction_fn))
+
+ unique_fetch_targets = list(unique_fetch_targets)
+
+ # Create request.
+ feed_dict_string = {}
+
+ # Validate and process feed_dict.
+ if feed_dict:
+ for feed, feed_val in feed_dict.iteritems():
+ for subfeed, subfeed_val in _feed_fn(feed, feed_val):
+ try:
+ subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
+ allow_operation=False)
+ except Exception as e:
+ e.message = ('Cannot interpret feed_dict key as Tensor: '
+ + e.message)
+ e.args = (e.message,)
+ raise e
+ np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype)
+ if subfeed_t.op.type == 'Placeholder':
+ if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
+ raise ValueError(
+ 'Cannot feed value of shape %r for Tensor %r, '
+ 'which has shape %r'
+ % (np_val.shape, subfeed_t.name,
+ tuple(subfeed_t.get_shape().dims)))
+ feed_dict_string[str(subfeed_t.name)] = np_val
+
+ # Run request and get response.
+ results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
+
+ # User may have fetched the same tensor multiple times, but we
+ # only fetch them from the runtime once. Furthermore, they may
+ # be wrapped as a tuple of tensors. Here we map the results back
+ # to what the client asked for.
+ fetched_results = dict(zip(unique_fetch_targets, results))
+ ret = []
+ for fetch_names, fetch_contraction_fn in fetch_info:
+ if fetch_names:
+ fetched_vals = [fetched_results[name] for name in fetch_names]
+ ret.append(fetch_contraction_fn(fetched_vals))
+ else:
+ ret.append(None)
+
+ if is_list_fetch:
+ return ret
+ else:
+ return ret[0]
+
+ # Captures the name of a node in an error status.
+ _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
+
+ def _do_run(self, target_list, fetch_list, feed_dict):
+ """Runs a step based on the given fetches and feeds.
+
+ Args:
+ target_list: A list of strings corresponding to names of tensors
+ or operations to be run, but not fetched.
+ fetch_list: A list of strings corresponding to names of tensors to be
+ fetched and operations to be run.
+ feed_dict: A dictionary that maps tensor names to numpy ndarrays.
+
+ Returns:
+ A list of numpy ndarrays, corresponding to the elements of
+ `fetch_list`. If the ith element of `fetch_list` contains the
+ name of an operation, the first Tensor output of that operation
+ will be returned for that element.
+ """
+ try:
+ # Ensure any changes to the graph are reflected in the runtime.
+ with self._extend_lock:
+ if self._graph.version > self._current_version:
+ graph_def = self._graph.as_graph_def(
+ from_version=self._current_version)
+
+ try:
+ status = tf_session.TF_NewStatus()
+ tf_session.TF_ExtendGraph(
+ self._session, graph_def.SerializeToString(), status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ self._opened = True
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ self._current_version = self._graph.version
+
+ return tf_session.TF_Run(self._session, feed_dict, fetch_list,
+ target_list)
+
+ except tf_session.StatusNotOK as e:
+ e_type, e_value, e_traceback = sys.exc_info()
+ m = BaseSession._NODEDEF_NAME_RE.search(e.error_message)
+ if m is not None:
+ node_name = m.group(1)
+ node_def = None
+ try:
+ op = self._graph.get_operation_by_name(node_name)
+ node_def = op.node_def
+ except KeyError:
+ op = None
+ # pylint: disable=protected-access
+ raise errors._make_specific_exception(node_def, op, e.error_message,
+ e.code)
+ # pylint: enable=protected-access
+ raise e_type, e_value, e_traceback
+
+
+class Session(BaseSession):
+ """A class for running TensorFlow operations.
+
+ A `Session` object encapsulates the environment in which `Operation`
+ objects are executed, and `Tensor` objects are evaluated. For
+ example:
+
+ ```python
+ # Build a graph.
+ a = tf.constant(5.0)
+ b = tf.constant(6.0)
+ c = a * b
+
+ # Launch the graph in a session.
+ sess = tf.Session()
+
+ # Evaluate the tensor `c`.
+ print sess.run(c)
+ ```
+
+ A session may own resources, such as
+ [variables](state_ops.md#Variable), [queues](io_ops.md#QueueBase),
+ and [readers](io_ops.md#ReaderBase). It is important to release
+ these resources when they are no longer required. To do this, either
+ invoke the [`close()`](#Session.close) method on the session, or use
+ the session as a context manager. The following two examples are
+ equivalent:
+
+ ```python
+ # Using the `close()` method.
+ sess = tf.Session()
+ sess.run(...)
+ sess.close()
+
+ # Using the context manager.
+ with tf.Session() as sess:
+ sess.run(...)
+ ```
+
+ The [`ConfigProto`]
+ (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
+ protocol buffer exposes various configuration options for a
+ session. For example, to create a session that uses soft constraints
+ for device placement, and log the resulting placement decisions,
+ create a session as follows:
+
+ ```python
+ # Launch the graph in a session that allows soft device placement and
+ # logs the placement decisions.
+ sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
+ log_device_placement=True))
+ ```
+
+ @@__init__
+ @@run
+ @@close
+
+ @@graph
+
+ @@as_default
+
+ """
+
+ def __init__(self, target='', graph=None, config=None):
+ """Creates a new TensorFlow session.
+
+ If no `graph` argument is specified when constructing the session,
+ the default graph will be launched in the session. If you are
+ using more than one graph (created with `tf.Graph()`) in the same
+ process, you will have to use different sessions for each graph,
+ but each graph can be used in multiple sessions. In this case, it
+ is often clearer to pass the graph to be launched explicitly to
+ the session constructor.
+
+ Args:
+ target: (Optional.) The execution engine to connect to.
+ Defaults to using an in-process engine. At present, no value
+ other than the empty string is supported.
+ graph: (Optional.) The `Graph` to be launched (described above).
+ config: (Optional.) A [`ConfigProto`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
+ protocol buffer with configuration options for the session.
+
+ """
+ super(Session, self).__init__(target, graph, config=config)
+ self._context_managers = [self.graph.as_default(), self.as_default()]
+
+ def __enter__(self):
+ for context_manager in self._context_managers:
+ context_manager.__enter__()
+ return self
+
+ def __exit__(self, exec_type, exec_value, exec_tb):
+ if exec_type is errors.OpError:
+ logging.error('Session closing due to OpError: %s', (exec_value,))
+
+ for context_manager in reversed(self._context_managers):
+ context_manager.__exit__(exec_type, exec_value, exec_tb)
+
+ self.close()
+
+
+class InteractiveSession(BaseSession):
+ """A TensorFlow `Session` for use in interactive contexts, such as a shell.
+
+ In some cases, such as interactive shells and IPython notebooks, it is
+ useful to be able to define a `Session` without using a with block: this
+ style enables statements to be executed immediately, rather than at the
+ termination of the block. In that case, it must be closed using
+ `Session.close()`. For example:
+
+ ```python
+ sess = InteractiveSession()
+ a = tf.constant(5.0)
+ b = tf.constant(6.0)
+ c = a * b
+ print c.run()
+ sess.close()
+ ```
+
+ @@__init__
+ @@close
+ """
+
+ def __init__(self, target='', graph=None):
+ """Initializes an `InteractiveSession` object similar to `Session`.
+
+ Args:
+ target: Optional. The TensorFlow execution engine to connect to.
+ graph: Optional. The `Graph` object to be used. If this argument is None,
+ the default graph will be used.
+ """
+ super(InteractiveSession, self).__init__(target, graph)
+ self._default_session = self.as_default()
+ self._default_session.__enter__()
+ self._explicit_graph = graph
+ if self._explicit_graph is not None:
+ self._default_graph = graph.as_default()
+ self._default_graph.__enter__()
+
+ def close(self):
+ """Closes an `InteractiveSession`."""
+ super(InteractiveSession, self).close()
+ if self._explicit_graph is not None:
+ self._default_graph.__exit__(None, None, None)
+ self._default_session.__exit__(None, None, None)
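Taken together, the docstrings above describe a small contract for `run()`: Operations fetch as `None`, Tensors fetch as numpy ndarrays, and `feed_dict` keys override graph values. Below is a minimal usage sketch, assuming the usual top-level re-exports (`tf.Session`, `tf.placeholder`, `tf.matmul`, `tf.Variable`) mirror the modules used in this file:

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[1, 2])   # fed below
w = tf.constant([[2.0], [3.0]])
y = tf.matmul(x, w)
v = tf.Variable([[0.0]], name='v_example')

with tf.Session() as sess:
  # Fetching an Operation returns None for that position.
  assert sess.run(v.initializer) is None
  # Fetching Tensors returns numpy ndarrays; feed_dict overrides x.
  y_val, v_val = sess.run(
      [y, v], feed_dict={x: np.array([[1.0, 1.0]], dtype=np.float32)})
  print y_val  # [[ 5.]]
  print v_val  # [[ 0.]]
```

Because the `with` block uses `Session` as a context manager, the session is closed automatically on exit, as the class docstring notes.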
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
new file mode 100644
index 0000000000..4492840dcf
--- /dev/null
+++ b/tensorflow/python/client/session_test.py
@@ -0,0 +1,555 @@
+"""Tests for tensorflow.python.client.session.Session."""
+import threading
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.core.framework import config_pb2
+from tensorflow.core.lib.core import error_codes_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+# NOTE(mrry): Dummy shape registration for op used in the tests.
+ops.RegisterShape('ConstructionFails')(None)
+
+
+class SessionTest(test_util.TensorFlowTestCase):
+
+ def testUseExistingGraph(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(6.0, shape=[1, 1])
+ b = constant_op.constant(7.0, shape=[1, 1])
+ c = math_ops.matmul(a, b, name='matmul')
+ with session.Session(graph=g):
+ result = c.eval()
+ self.assertAllEqual(result, [[42.0]])
+
+ def testUseDefaultGraph(self):
+ with ops.Graph().as_default(), ops.device('/cpu:0'):
+ a = constant_op.constant(6.0, shape=[1, 1])
+ b = constant_op.constant(7.0, shape=[1, 1])
+ c = math_ops.matmul(a, b, name='matmul')
+ with session.Session():
+ result = c.eval()
+ self.assertAllEqual(result, [[42.0]])
+
+ def testCreate(self):
+ with session.Session():
+ inp = constant_op.constant(10.0, name='W1')
+ copy = array_ops.identity(inp)
+ # Test with feed.
+ # TODO(mrry): Investigate why order='F' didn't work.
+ arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
+ copy_val = copy.eval({'W1:0': arr})
+ self.assertAllEqual(arr, copy_val)
+ # Test without feed.
+ copy_val = copy.eval()
+ self.assertAllEqual(np.asarray(10.0, dtype=np.float32), copy_val)
+
+ def testManyCPUs(self):
+ # TODO(keveman): Implement ListDevices and test for the number of
+ # devices returned by ListDevices.
+ with session.Session(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2})):
+ inp = constant_op.constant(10.0, name='W1')
+ self.assertAllEqual(inp.eval(), 10.0)
+
+ def testErrorsReported(self):
+ with session.Session() as s:
+ constant_op.constant(10.0, name='W1')
+ with self.assertRaises(ValueError):
+ s.run('foo:0')
+
+ def testErrorPayload(self):
+ with session.Session():
+ a = array_ops.placeholder(types.float32)
+ with self.assertRaisesOpError(lambda e: e.op == a.op):
+ a.eval()
+
+ def testOpConstructionErrorPayload(self):
+ with session.Session():
+ failing_op = ops.get_default_graph().create_op(
+ 'ConstructionFails', [], [], name='f')
+
+ def exc_predicate(e):
+ return (e.op == failing_op
+ and e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+ with self.assertRaisesOpError(exc_predicate):
+ failing_op.run()
+
+ def testErrorBasedOn(self):
+ with session.Session() as sess:
+ a = constant_op.constant(0.0, shape=[2, 3])
+ # NOTE(mrry): The original_op is nonsense, but used here to test that the
+ # errors are reported correctly.
+ # pylint: disable=protected-access
+ with sess.graph._original_op(a.op):
+ b = array_ops.identity(a, name='id')
+ with sess.graph._original_op(b.op):
+ c = array_ops.placeholder(types.float32)
+ # pylint: enable=protected-access
+
+ def exc_predicate(e):
+ return (e.op == c.op
+ and e.op._original_op == b.op
+ and e.op._original_op._original_op == a.op)
+ with self.assertRaisesOpError(exc_predicate):
+ c.eval()
+
+ def testFetchTensorObject(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ results_with_list = s.run([c])
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0])
+ results_with_single = s.run(c)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single)
+ results_with_get = c.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get)
+ a_val, b_val = s.run([a, b]) # Test multiple fetches.
+ self.assertAllEqual([[1.0, 1.0]], a_val)
+ self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
+
+ def testFetchScalar(self):
+ with session.Session() as s:
+ for scalar in np.int32, np.int64, np.float32, np.float64:
+ x = scalar(7)
+ y = scalar(8)
+ tf_x = constant_op.constant(x, shape=[])
+ tf_y = constant_op.constant(y)
+ tf_xy = math_ops.add(tf_x, tf_y)
+ # Single fetch
+ xy = s.run(tf_xy)
+ self.assertEqual(scalar, type(xy))
+ self.assertEqual(x + y, xy)
+ # List fetch
+ xy, = s.run([tf_xy])
+ self.assertEqual(scalar, type(xy))
+ self.assertEqual(x + y, xy)
+
+ def testFetchOperationObject(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ v = variables.Variable(a, name='testFetchOperationObject_v')
+ s.run(v.initializer)
+ v_val = s.run(v)
+ self.assertAllEqual([[1.0, 1.0]], v_val)
+
+ def testFetchSparseTensor(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = ops.SparseTensor(
+ constant_op.constant(indices),
+ constant_op.constant(values),
+ constant_op.constant(shape))
+ # Single fetch, use as tuple
+ sp_out = s.run(sp)
+ indices_out, values_out, shape_out = sp_out
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Single fetch, use as SparseTensorValue
+ sp_out = s.run(sp)
+ self.assertAllEqual(sp_out.indices, indices)
+ self.assertAllEqual(sp_out.values, values)
+ self.assertAllEqual(sp_out.shape, shape)
+ # Tuple fetch, use as tuple
+ indices_out, values_out, shape_out = s.run(sp)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # List fetch, use as tuple
+ (indices_out, values_out, shape_out), = s.run([sp])
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # List fetch, use as SparseTensorValue
+ sp_out, = s.run([sp])
+ self.assertAllEqual(sp_out.indices, indices)
+ self.assertAllEqual(sp_out.values, values)
+ self.assertAllEqual(sp_out.shape, shape)
+
+ def testFeedSparseTensor(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = ops.SparseTensor(
+ array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
+ array_ops.placeholder(dtype=np.float32, shape=(2,)),
+ array_ops.placeholder(dtype=np.int64, shape=(3,)),)
+ sp_indices = array_ops.identity(sp.indices)
+ sp_values = array_ops.identity(sp.values)
+ sp_shape = array_ops.identity(sp.shape)
+ sp2 = ops.SparseTensor(sp_indices, sp_values, sp_shape)
+ # Feed with tuple
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape],
+ {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue, fetch SparseTensorValue
+ sp2_out = s.run(sp2, {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(sp2_out.indices, indices)
+ self.assertAllEqual(sp2_out.values, values)
+ self.assertAllEqual(sp2_out.shape, shape)
+
+ def testExtendWithStatelessOperations(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ c_val = s.run(c)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
+ d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
+ e = math_ops.matmul(c, d)
+ # Extend will happen here.
+ e_val = s.run(e)
+ self.assertAllEqual([[24.0]], e_val)
+
+ def testExtendWithStatefulOperations(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
+ v.initializer.run()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ # Extend will happen here.
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+
+ def testExtendWithGroupBy(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ p = variables.Variable(a, name='testExtendWithGroupBy_p')
+ a_val = a.eval() # Force an Extend after this op.
+ self.assertAllEqual([[1.0, 1.0]], a_val)
+
+ b = constant_op.constant(2.0, shape=[1, 2])
+ q = variables.Variable(b, name='testExtendWithGroupBy_q')
+ # Extend will happen here.
+ init = control_flow_ops.group(p.initializer, q.initializer)
+ s.run(init)
+ p_val, q_val = s.run([p, q])
+
+ self.assertAllEqual([[1.0, 1.0]], p_val)
+ self.assertAllEqual([[2.0, 2.0]], q_val)
+
+ def testTensorGetMethod(self):
+ with session.Session():
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+
+ c_val = c.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
+
+ fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
+ self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)
+
+ def testOperationRunMethod(self):
+ with session.Session():
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[1, 2], name='b')
+ v = variables.Variable(a, a.dtype)
+ assign_a_to_v = state_ops.assign(v, a)
+
+ assign_a_to_v.eval()
+
+ v_val = v.eval()
+ self.assertAllEqual([[1.0, 1.0]], v_val)
+
+ assign_b_to_v = state_ops.assign(v, b)
+
+ assign_b_to_v.eval()
+ v_val = v.eval()
+ self.assertAllEqual([[2.0, 2.0]], v_val)
+
+ assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
+ v_val = v.eval()
+ self.assertAllEqual([[3.0, 3.0]], v_val)
+
+ def testDefaultGraph(self):
+ with session.Session() as s:
+ self.assertEqual(ops.get_default_graph(), s.graph)
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ self.assertEqual(ops.get_default_graph(), a.graph)
+ self.assertEqual(ops.get_default_graph(), b.graph)
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='testDefaultGraph_v')
+ v.initializer.run()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+ self.assertEqual(ops.get_default_graph(), s.graph)
+
+ def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
+ with session.Session() as s:
+ self.assertEqual(ops.get_default_graph(), s.graph)
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='var_%d' % i)
+
+ # Block here until all threads have constructed their graph.
+ constructed_event.set()
+ continue_event.wait()
+
+ assign_c_to_v = state_ops.assign(v, c)
+ v.initializer.run()
+ assign_c_to_v.eval()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+ self.assertEqual(ops.get_default_graph(), s.graph)
+
+ def testDefaultGraphWithThreads(self):
+ # Fork ten threads that use their thread-local default graph.
+ threads = []
+ constructed_events = [threading.Event() for _ in range(10)]
+ continue_event = threading.Event()
+ for i, constructed_event in enumerate(constructed_events):
+ t = self.checkedThread(target=self._testDefaultGraphInThread,
+ args=(constructed_event, continue_event, i))
+ threads.append(t)
+ for t in threads:
+ t.start()
+ for constructed_event in constructed_events:
+ constructed_event.wait()
+ continue_event.set()
+ for t in threads:
+ t.join()
+
+ def testParallelRun(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ ev = threading.Event()
+
+ def run_step():
+ ev.wait()
+ val = c.eval(session=sess)
+ self.assertEqual(val, 5.0)
+ threads = [self.checkedThread(target=run_step) for _ in range(100)]
+ for t in threads:
+ t.start()
+ ev.set()
+ for t in threads:
+ t.join()
+
+ def testRunFeedDict(self):
+ with session.Session() as s:
+ x = array_ops.zeros([2])
+
+ y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
+ self.assertAllEqual(y, 2 * np.ones(2))
+
+ y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
+ self.assertAllEqual(y, 2 * np.ones(2))
+
+ y = s.run(2 * x, feed_dict={x: [1, 1]})
+ assert (y == 2 * np.ones(2)).all()
+
+ def testGraphDef(self):
+ with session.Session() as sess:
+ self.assertProtoEquals('', sess.graph_def)
+ c = constant_op.constant(5.0, name='c')
+ self.assertEquals(len(sess.graph_def.node), 1)
+ d = constant_op.constant(6.0, name='d')
+ self.assertEquals(len(sess.graph_def.node), 2)
+ self.assertAllEqual(c.eval(), 5.0)
+ self.assertAllEqual(d.eval(), 6.0)
+ e = constant_op.constant(7.0, name='e')
+ self.assertEquals(len(sess.graph_def.node), 3)
+ self.assertAllEqual(e.eval(), 7.0)
+
+ def testUseAfterClose(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ self.assertAllEqual(sess.run(c), 5.0)
+ with self.assertRaisesWithPredicateMatch(
+ RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)):
+ sess.run(c)
+
+ def testUseAfterCloseConcurrent(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ self.assertAllEqual(sess.run(c), 5.0)
+
+ def update_thread():
+ with self.assertRaisesWithPredicateMatch(
+ RuntimeError,
+ lambda e: 'Attempted to use a closed Session.' in str(e)):
+ while True:
+ sess.run(c)
+ t = threading.Thread(target=update_thread)
+ t.start()
+ time.sleep(0.1)
+ sess.close()
+ t.join()
+
+ def testNotEntered(self):
+ # pylint: disable=protected-access
+ self.assertEqual(ops._default_session_stack.get_default(), None)
+ # pylint: enable=protected-access
+ with ops.device('/cpu:0'):
+ sess = session.Session()
+ c_1 = constant_op.constant(5.0)
+ with sess.graph.as_default():
+ c_2 = constant_op.constant(5.0)
+ self.assertEqual(c_1.graph, c_2.graph)
+ self.assertEqual(sess.run(c_2), 5.0)
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: 'No default session is registered.' in str(e)):
+ c_2.eval()
+
+ def testInteractive(self):
+ with ops.device('/cpu:0'):
+ sess = session.InteractiveSession()
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval())
+ d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
+ e = math_ops.matmul(c, d)
+ self.assertAllEqual([[24.0]], e.eval())
+ sess.close()
+
+ def testSharedGraph(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+
+ with session.Session(graph=g) as sess1:
+ with session.Session(graph=g) as sess2:
+ self.assertAllEqual(sess1.run(c), sess2.run(c))
+
+ def testDuplicatedInputs(self):
+ with session.Session() as sess:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[1, 3])
+ a_val, b_val, a2_val = sess.run([a, b, a])
+ self.assertAllEqual(a_val, [[1.0, 1.0]])
+ self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
+ self.assertAllEqual(a2_val, [[1.0, 1.0]])
+
+ def testFeedAndFetch(self):
+ with session.Session():
+ for dtype in [types.float32,
+ types.float64,
+ types.int32,
+ types.uint8,
+ types.int16,
+ types.int8,
+ types.int64,
+ types.bool,
+ types.complex64]:
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ np_dtype = dtype.as_numpy_dtype
+
+ feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
+ out_t = array_ops.identity(feed_t)
+
+ np_array = np.random.randint(-10, 10, shape)
+
+ if dtype == types.bool:
+ np_array = np_array > 0
+ elif dtype == types.complex64:
+ np_array = np.sqrt(np_array.astype(np_dtype))
+ else:
+ np_array = np_array.astype(np_dtype)
+
+ self.assertAllEqual(np_array,
+ out_t.eval(feed_dict={feed_t: np_array}))
+
+ def testStringFetch(self):
+ with session.Session():
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ size = 1
+ for s in shape:
+ size *= s
+ c_list = np.array([str(i) for i in xrange(size)],
+ dtype=np.object).reshape(shape) if size > 0 else []
+ c = constant_op.constant(c_list)
+ self.assertAllEqual(c.eval(), c_list)
+
+ def testStringFeed(self):
+ with session.Session():
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ size = 1
+ for s in shape:
+ size *= s
+ c_list = np.array([str(i) for i in xrange(size)],
+ dtype=np.object).reshape(shape)
+ feed_t = array_ops.placeholder(dtype=types.string, shape=shape)
+ c = array_ops.identity(feed_t)
+ self.assertAllEqual(c.eval(feed_dict={feed_t: c_list}), c_list)
+
+ def testStringFeedWithNullCharacters(self):
+ with session.Session():
+ c_list = ['\n\x01\x00', '\n\x00\x01']
+ feed_t = array_ops.placeholder(dtype=types.string, shape=[2])
+ c = array_ops.identity(feed_t)
+ out = c.eval(feed_dict={feed_t: c_list})
+ self.assertEqual(c_list[0], out[0])
+ self.assertEqual(c_list[1], out[1])
+
+ def testInvalidTargetFails(self):
+ with self.assertRaises(RuntimeError):
+ session.Session("INVALID_TARGET")
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/client/tensorflow_server.i b/tensorflow/python/client/tensorflow_server.i
new file mode 100644
index 0000000000..65b3826961
--- /dev/null
+++ b/tensorflow/python/client/tensorflow_server.i
@@ -0,0 +1,16 @@
+%include "tensorflow/python/platform/base.i"
+%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i"
+
+%{
+#include "tensorflow/core/public/tensorflow_server.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::LaunchTensorFlow;
+
+%include "tensorflow/core/public/tensorflow_server.h"
+
+%unignoreall
+
diff --git a/tensorflow/python/client/test_construction_fails_op.cc b/tensorflow/python/client/test_construction_fails_op.cc
new file mode 100644
index 0000000000..47b2b5b49c
--- /dev/null
+++ b/tensorflow/python/client/test_construction_fails_op.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ConstructionFails");
+
+class ConstructionFailsOp : public OpKernel {
+ public:
+ explicit ConstructionFailsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES(ctx, false,
+ errors::InvalidArgument("Failure during construction."));
+ }
+
+ void Compute(OpKernelContext* ctx) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("ConstructionFails").Device(DEVICE_CPU),
+ ConstructionFailsOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
new file mode 100644
index 0000000000..30e80f779f
--- /dev/null
+++ b/tensorflow/python/client/tf_session.i
@@ -0,0 +1,235 @@
+%include "tensorflow/python/platform/base.i"
+
+%{
+
+#include "numpy/arrayobject.h"
+
+#include "tensorflow/python/client/tf_session_helper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+%}
+
+// Implements the StatusNotOK exception.
+%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i"
+
+// Required to use PyArray_* functions.
+%include "tensorflow/python/platform/numpy.i"
+%init %{
+import_array();
+%}
+
+// Release the Python GIL for the duration of most methods.
+%exception {
+ Py_BEGIN_ALLOW_THREADS;
+ $action
+ Py_END_ALLOW_THREADS;
+}
+
+// Proto input arguments to C API functions are passed as a (const
+// void*, size_t) pair. In Python, typemap these to a single string
+// argument.
+%typemap(in) (const void* proto, size_t proto_len) {
+ char* c_string;
+ Py_ssize_t py_size;
+ if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+ $1 = static_cast<void*>(c_string);
+ $2 = static_cast<size_t>(py_size);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
+////////////////////////////////////////////////////////////////////////////////
+
+// The wrapper takes a vector of pairs of feed names and feed
+// values. In Python this is represented as dictionary mapping strings
+// to numpy arrays.
+%typemap(in) const tensorflow::FeedVector& inputs (
+ tensorflow::FeedVector temp,
+ tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr)),
+ tensorflow::Safe_PyObjectPtr temp_array_list(tensorflow::make_safe(nullptr))) {
+ if (!PyDict_Check($input)) {
+ SWIG_fail;
+ }
+
+ temp_string_list = tensorflow::make_safe(PyList_New(0));
+ if (!temp_string_list) {
+ SWIG_fail;
+ }
+ temp_array_list = tensorflow::make_safe(PyList_New(0));
+ if (!temp_array_list) {
+ SWIG_fail;
+ }
+
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next($input, &pos, &key, &value)) {
+ const char* key_string = PyString_AsString(key);
+ if (!key_string) {
+ SWIG_fail;
+ }
+
+ // The ndarray must be stored as contiguous bytes in C (row-major) order.
+ PyObject* array_object = PyArray_FromAny(
+ value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr);
+ if (!array_object) {
+ SWIG_fail;
+ }
+ PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_object);
+
+ // Keep a reference to the key and the array, in case the incoming dict is
+ // modified, and/or to avoid leaking references on failure.
+ if (PyList_Append(temp_string_list.get(), key) == -1) {
+ SWIG_fail;
+ }
+ if (PyList_Append(temp_array_list.get(), array_object) == -1) {
+ SWIG_fail;
+ }
+
+ temp.push_back(std::make_pair(key_string, array));
+ }
+ $1 = &temp;
+}
+
+// The wrapper also takes a list of fetch and target names. In Python this is
+// represented as a list of strings.
+%typemap(in) const tensorflow::NameVector& (
+ tensorflow::NameVector temp,
+ tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr))) {
+ if (!PyList_Check($input)) {
+ SWIG_fail;
+ }
+
+ Py_ssize_t len = PyList_Size($input);
+
+ temp_string_list = tensorflow::make_safe(PyList_New(len));
+ if (!temp_string_list) {
+ SWIG_fail;
+ }
+
+ for (Py_ssize_t i = 0; i < len; ++i) {
+ PyObject* elem = PyList_GetItem($input, i);
+ if (!elem) {
+ SWIG_fail;
+ }
+
+ // Keep a reference to the string in case the incoming list is modified.
+ PyList_SET_ITEM(temp_string_list.get(), i, elem);
+ Py_INCREF(elem);
+
+ const char* fetch_name = PyString_AsString(elem);
+ if (!fetch_name) {
+ PyErr_SetString(PyExc_TypeError,
+ "a fetch or target name was not a string");
+ SWIG_fail;
+ }
+
+ // TODO(mrry): Avoid copying the fetch name in, if this impacts performance.
+ temp.push_back(fetch_name);
+ }
+ $1 = &temp;
+}
+
+
+// The wrapper has two outputs: a tensorflow::Status, and a vector of
+// PyObjects containing the fetch results (iff the status is OK). Since
+// the interpretation of the vector depends on the status, we define
+// them as two consecutive out arguments, so that they can be accessed
+// together in a typemap.
+
+// Define temporaries for the argout outputs.
+%typemap(in, numinputs=0) tensorflow::Status* out_status (
+ tensorflow::Status temp) {
+ $1 = &temp;
+}
+%typemap(in, numinputs=0) tensorflow::PyObjectVector* out_values (
+ tensorflow::PyObjectVector temp) {
+ $1 = &temp;
+}
+
+// Raise a StatusNotOK exception if the out_status is not OK;
+// otherwise build a Python list of outputs and return it.
+%typemap(argout, fragment="StatusNotOK") (
+ tensorflow::Status* out_status, tensorflow::PyObjectVector* out_values) {
+ if (!$1->ok()) {
+ RaiseStatusNotOK(*$1, $descriptor(tensorflow::Status*));
+ SWIG_fail;
+ } else {
+ tensorflow::Safe_PyObjectVector out_values_safe;
+ for (int i = 0; i < $2->size(); ++i) {
+ out_values_safe.emplace_back(tensorflow::make_safe($2->at(i)));
+ }
+
+ $result = PyList_New($2->size());
+ if (!$result) {
+ SWIG_fail;
+ }
+
+ for (int i = 0; i < $2->size(); ++i) {
+ PyList_SET_ITEM($result, i, $2->at(i));
+ out_values_safe[i].release();
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// END TYPEMAPS FOR tensorflow::TF_Run_wrapper()
+////////////////////////////////////////////////////////////////////////////////
+
+
+
+// Include the functions from tensor_c_api.h, except TF_Run.
+%ignoreall
+%unignore TF_Code;
+%unignore TF_Status;
+%unignore TF_NewStatus;
+%unignore TF_DeleteStatus;
+%unignore TF_GetCode;
+%unignore TF_Message;
+%unignore TF_SessionOptions;
+%rename("_TF_SetTarget") TF_SetTarget;
+%rename("_TF_SetConfig") TF_SetConfig;
+%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
+%unignore TF_DeleteSessionOptions;
+%unignore TF_NewSession;
+%unignore TF_CloseSession;
+%unignore TF_DeleteSession;
+%unignore TF_ExtendGraph;
+%include "tensorflow/core/public/tensor_c_api.h"
+%ignoreall
+
+%insert("python") %{
+ def TF_NewSessionOptions(target=None, config=None):
+ opts = _TF_NewSessionOptions()
+ if target is not None:
+ _TF_SetTarget(opts, target)
+ if config is not None:
+ from tensorflow.core.framework import config_pb2
+ if not isinstance(config, config_pb2.ConfigProto):
+ raise TypeError("Expected config_pb2.ConfigProto, "
+ "but got %s" % type(config))
+ status = TF_NewStatus()
+ config_str = config.SerializeToString()
+ _TF_SetConfig(opts, config_str, len(config_str), status)
+ if TF_GetCode(status) != 0:
+ raise ValueError(TF_Message(status))
+ return opts
+%}
+
+// Include the wrapper for TF_Run from tf_session_helper.h.
+
+// The %exception block above releases the Python GIL for the length
+// of each wrapped method. We disable this behavior for TF_Run
+// because it uses the Python allocator.
+%noexception tensorflow::TF_Run_wrapper;
+%rename(TF_Run) tensorflow::TF_Run_wrapper;
+%unignore tensorflow;
+%unignore TF_Run;
+
+%include "tensorflow/python/client/tf_session_helper.h"
+
+%unignoreall
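The typemaps above define the Python-facing signature of `TF_Run`: a raw session handle, a dict mapping feed names to numpy arrays, a list of fetch names, and a list of target op names, returning a list of ndarrays. The sketch below only exercises the session lifecycle calls exposed here, mirroring what `BaseSession.__init__`, `close()`, and `__del__` do in session.py; ordinary code should go through `tf.Session()` instead.

```python
from tensorflow.python import pywrap_tensorflow as tf_session

opts = tf_session.TF_NewSessionOptions(target='', config=None)
status = tf_session.TF_NewStatus()
try:
  sess = tf_session.TF_NewSession(opts, status)
  if tf_session.TF_GetCode(status) != 0:
    raise RuntimeError(tf_session.TF_Message(status))
  # After TF_ExtendGraph(sess, graph_def_bytes, status), feeds, fetches and
  # targets would be passed as tf_session.TF_Run(sess, feeds, fetches, targets).
  tf_session.TF_CloseSession(sess, status)
  tf_session.TF_DeleteSession(sess, status)
finally:
  tf_session.TF_DeleteStatus(status)
  tf_session.TF_DeleteSessionOptions(opts)
```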
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
new file mode 100644
index 0000000000..06483da87b
--- /dev/null
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -0,0 +1,518 @@
+#include "tensorflow/python/client/tf_session_helper.h"
+
+#include <cstring>
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Container types for the various temporary values used internally in
+// the wrapper.
+
+// A TF_TensorVector is a vector of borrowed pointers to TF_Tensors.
+typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector;
+
+// Safe containers for (an) owned TF_Tensor(s). On destruction, the
+// tensor will be deleted by TF_DeleteTensor.
+typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
+ Safe_TF_TensorPtr;
+typedef std::vector<Safe_TF_TensorPtr> Safe_TF_TensorVector;
+Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) {
+ return Safe_TF_TensorPtr(tensor, TF_DeleteTensor);
+}
+
+// Safe container for an owned TF_Status. On destruction, the status
+// will be deleted by TF_DeleteStatus.
+typedef std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>
+ Safe_TF_StatusPtr;
+Safe_TF_StatusPtr make_safe(TF_Status* status) {
+ return Safe_TF_StatusPtr(status, TF_DeleteStatus);
+}
+
+Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
+ TF_DataType* out_tf_datatype) {
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ if (PyDict_Next(descr->fields, &pos, &key, &value)) {
+ const char* key_string = PyString_AsString(key);
+ if (!key_string) {
+ return errors::Internal("Corrupt numpy type descriptor");
+ }
+ tensorflow::string key = key_string;
+ // The typenames here should match the field names in the custom struct
+ // types constructed in test_util.py.
+ // TODO(mrry,keveman): Investigate Numpy type registration to replace this
+ // hard-coding of names.
+ if (key == "quint8") {
+ *out_tf_datatype = TF_QUINT8;
+ } else if (key == "qint8") {
+ *out_tf_datatype = TF_QINT8;
+ } else if (key == "qint32") {
+ *out_tf_datatype = TF_QINT32;
+ } else {
+ return errors::Internal("Unsupported numpy data type");
+ }
+ return Status::OK();
+ }
+ return errors::Internal("Unsupported numpy data type");
+}
+
+Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
+ TF_DataType* out_tf_datatype) {
+ int pyarray_type = PyArray_TYPE(array);
+ PyArray_Descr* descr = array->descr;
+ switch (pyarray_type) {
+ case NPY_FLOAT32:
+ *out_tf_datatype = TF_FLOAT;
+ break;
+ case NPY_FLOAT64:
+ *out_tf_datatype = TF_DOUBLE;
+ break;
+ case NPY_INT32:
+ *out_tf_datatype = TF_INT32;
+ break;
+ case NPY_UINT8:
+ *out_tf_datatype = TF_UINT8;
+ break;
+ case NPY_INT16:
+ *out_tf_datatype = TF_INT16;
+ break;
+ case NPY_INT8:
+ *out_tf_datatype = TF_INT8;
+ break;
+ case NPY_INT64:
+ *out_tf_datatype = TF_INT64;
+ break;
+ case NPY_BOOL:
+ *out_tf_datatype = TF_BOOL;
+ break;
+ case NPY_COMPLEX64:
+ *out_tf_datatype = TF_COMPLEX;
+ break;
+ case NPY_OBJECT:
+ *out_tf_datatype = TF_STRING;
+ break;
+ case NPY_VOID:
+ // Quantized types are currently represented as custom struct types.
+ // PyArray_TYPE returns NPY_VOID for structs, and we should look into
+ // descr to derive the actual type.
+ return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
+ default:
+ // TODO(mrry): Support these.
+ return errors::Internal("Unsupported feed type");
+ }
+ return Status::OK();
+}
+
+Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
+ int* out_pyarray_type) {
+ switch (tf_datatype) {
+ case TF_FLOAT:
+ *out_pyarray_type = NPY_FLOAT32;
+ break;
+ case TF_DOUBLE:
+ *out_pyarray_type = NPY_FLOAT64;
+ break;
+ case TF_INT32:
+ *out_pyarray_type = NPY_INT32;
+ break;
+ case TF_UINT8:
+ *out_pyarray_type = NPY_UINT8;
+ break;
+ case TF_INT16:
+ *out_pyarray_type = NPY_INT16;
+ break;
+ case TF_INT8:
+ *out_pyarray_type = NPY_INT8;
+ break;
+ case TF_INT64:
+ *out_pyarray_type = NPY_INT64;
+ break;
+ case TF_BOOL:
+ *out_pyarray_type = NPY_BOOL;
+ break;
+ case TF_COMPLEX:
+ *out_pyarray_type = NPY_COMPLEX64;
+ break;
+ case TF_STRING:
+ *out_pyarray_type = NPY_OBJECT;
+ break;
+ // TODO(keveman): These should be changed to NPY_VOID, and the type used for
+ // the resulting numpy array should be the custom struct types that we
+ // expect for quantized types.
+ case TF_QINT8:
+ *out_pyarray_type = NPY_INT8;
+ break;
+ case TF_QUINT8:
+ *out_pyarray_type = NPY_UINT8;
+ break;
+ case TF_QINT32:
+ *out_pyarray_type = NPY_INT32;
+ break;
+ case TF_BFLOAT16:
+ *out_pyarray_type = NPY_UINT16;
+ break;
+ default:
+ return errors::Internal("Unsupported fetch type");
+ }
+ return Status::OK();
+}
+
+// Iterate over the string array 'array', extract the ptr and len of each string
+// element and call f(ptr, len).
+template <typename F>
+Status PyStringArrayMap(PyArrayObject* array, F f) {
+ Safe_PyObjectPtr iter = tensorflow::make_safe(
+ PyArray_IterNew(reinterpret_cast<PyObject*>(array)));
+ while (PyArray_ITER_NOTDONE(iter.get())) {
+ auto item = tensorflow::make_safe(
+ PyArray_GETITEM(array, PyArray_ITER_DATA(iter.get())));
+ if (!item.get()) {
+ return errors::Internal("Unable to get element from the feed.");
+ }
+ char* ptr;
+ Py_ssize_t len;
+ int success = PyString_AsStringAndSize(item.get(), &ptr, &len);
+ if (success != 0) {
+ return errors::Internal("Unable to get element from the feed.");
+ }
+ f(ptr, len);
+ PyArray_ITER_NEXT(iter.get());
+ }
+ return Status::OK();
+}
+
+// Encode the strings in 'array' into a contiguous buffer and return the base of
+// the buffer. The caller takes ownership of the buffer.
+Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
+ size_t* size, void** buffer) {
+ // Compute bytes needed for encoding.
+ *size = 0;
+ TF_RETURN_IF_ERROR(
+ PyStringArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
+ *size += sizeof(tensorflow::uint64) +
+ tensorflow::core::VarintLength(len) + len;
+ }));
+ // Encode all strings.
+ std::unique_ptr<char[]> base_ptr(new char[*size]);
+ char* base = base_ptr.get();
+ char* data_start = base + sizeof(tensorflow::uint64) * nelems;
+ char* dst = data_start; // Where next string is encoded.
+ tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
+
+ TF_RETURN_IF_ERROR(PyStringArrayMap(
+ array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) {
+ *offsets = (dst - data_start);
+ offsets++;
+ dst = tensorflow::core::EncodeVarint64(dst, len);
+ memcpy(dst, ptr, len);
+ dst += len;
+ }));
+ CHECK_EQ(dst, base + *size);
+ *buffer = base_ptr.release();
+ return Status::OK();
+}
+
+// Determine the pointer and length of the string at index 'i' in the string
+// tensor 'src', which contains 'num_elements' strings in total.
+static Status TF_StringTensor_GetPtrAndLen(const TF_Tensor* src,
+ tensorflow::int64 num_elements,
+ tensorflow::int64 i,
+ const char** ptr,
+ tensorflow::uint64* len) {
+ const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
+ const size_t src_size = TF_TensorByteSize(src);
+ const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
+ const char* limit = input + src_size;
+ tensorflow::uint64 offset =
+ reinterpret_cast<const tensorflow::uint64*>(input)[i];
+ const char* p =
+ tensorflow::core::GetVarint64Ptr(data_start + offset, limit, len);
+ if (offset >= (limit - data_start) || !p || (*len > (limit - p))) {
+ return errors::InvalidArgument("Malformed TF_STRING tensor; element ", i,
+ " out of range");
+ }
+ *ptr = p;
+ return Status::OK();
+}
+
+// Copy the string at offset 'i' in the (linearized) string tensor 'tensor' into
+// 'pyarray' at offset pointed by the 'i_ptr' iterator.
+static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
+ TF_Tensor* tensor,
+ tensorflow::int64 num_elements,
+ tensorflow::int64 i) {
+ const char* ptr;
+ tensorflow::uint64 len;
+ TF_RETURN_IF_ERROR(
+ TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len));
+ auto py_string = tensorflow::make_safe(PyString_FromStringAndSize(ptr, len));
+ int success =
+ PyArray_SETITEM(pyarray, PyArray_ITER_DATA(i_ptr), py_string.get());
+ if (success != 0) {
+ return errors::Internal("Error setting element ", i);
+ }
+ return Status::OK();
+}
+
+// Converts the given TF_Tensor to a Numpy array.
+// If the returned status is OK, the caller becomes the owner of *out_array.
+Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) {
+ // A fetched operation will correspond to a null tensor, and a None
+ // in Python.
+ if (tensor == nullptr) {
+ Py_INCREF(Py_None);
+ *out_array = Py_None;
+ return Status::OK();
+ }
+
+ const int ndims = TF_NumDims(tensor);
+ gtl::InlinedVector<npy_intp, 4> dims(ndims);
+ tensorflow::int64 nelems = 1;
+ for (int i = 0; i < ndims; ++i) {
+ dims[i] = TF_Dim(tensor, i);
+ nelems *= dims[i];
+ }
+
+ // Convert TensorFlow dtype to numpy type descriptor.
+ int type_num;
+ TF_RETURN_IF_ERROR(
+ TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
+ PyArray_Descr* descr = PyArray_DescrFromType(type_num);
+
+ // Copy the TF_TensorData into a newly-created ndarray and return it.
+ // TODO(mrry): Perhaps investigate zero-copy approaches. This would involve
+ // creating an ndarray-like object that wraps the TF_Tensor buffer, and
+ // maps its destructor to TF_DeleteTensor.
+ Safe_PyObjectPtr safe_out_array =
+ tensorflow::make_safe(PyArray_Empty(ndims, dims.data(), descr, 0));
+ if (!safe_out_array) {
+ return errors::Internal("Could not allocate ndarray");
+ }
+ PyArrayObject* py_array =
+ reinterpret_cast<PyArrayObject*>(safe_out_array.get());
+ if (PyArray_NBYTES(py_array) != TF_TensorByteSize(tensor)) {
+ if (TF_TensorType(tensor) == TF_STRING) {
+ // Copy element by element.
+ auto iter = tensorflow::make_safe(PyArray_IterNew(safe_out_array.get()));
+ for (tensorflow::int64 i = 0; i < nelems; ++i) {
+ auto s =
+ CopyStringToPyArrayElement(py_array, iter.get(), tensor, nelems, i);
+ if (!s.ok()) {
+ return s;
+ }
+ PyArray_ITER_NEXT(iter.get());
+ }
+ } else {
+ return errors::Internal("ndarray was ", PyArray_NBYTES(py_array),
+ " bytes but TF_Tensor was ",
+ TF_TensorByteSize(tensor), " bytes");
+ }
+ } else {
+ memcpy(py_array->data, TF_TensorData(tensor), PyArray_NBYTES(py_array));
+ }
+
+ // PyArray_Return turns rank 0 arrays into numpy scalars
+ *out_array = PyArray_Return(
+ reinterpret_cast<PyArrayObject*>(safe_out_array.release()));
+ return Status::OK();
+}
+
+tensorflow::Status TF_Status_to_Status(TF_Status* tf_status) {
+ TF_Code code = TF_GetCode(tf_status);
+ const string message(TF_Message(tf_status));
+
+ switch (code) {
+ case TF_OK:
+ return Status::OK();
+ case TF_CANCELLED:
+ return errors::Cancelled(message);
+ case TF_UNKNOWN:
+ return errors::Unknown(message);
+ case TF_INVALID_ARGUMENT:
+ return errors::InvalidArgument(message);
+ case TF_DEADLINE_EXCEEDED:
+ return errors::DeadlineExceeded(message);
+ case TF_NOT_FOUND:
+ return errors::NotFound(message);
+ case TF_ALREADY_EXISTS:
+ return errors::AlreadyExists(message);
+ case TF_PERMISSION_DENIED:
+ return errors::PermissionDenied(message);
+ case TF_UNAUTHENTICATED:
+ return errors::Unauthenticated(message);
+ case TF_RESOURCE_EXHAUSTED:
+ return errors::ResourceExhausted(message);
+ case TF_FAILED_PRECONDITION:
+ return errors::FailedPrecondition(message);
+ case TF_ABORTED:
+ return errors::Aborted(message);
+ case TF_OUT_OF_RANGE:
+ return errors::OutOfRange(message);
+ case TF_UNIMPLEMENTED:
+ return errors::Unimplemented(message);
+ case TF_INTERNAL:
+ return errors::Internal(message);
+ case TF_UNAVAILABLE:
+ return errors::Unavailable(message);
+ case TF_DATA_LOSS:
+ return errors::DataLoss(message);
+ default:
+ return errors::Internal("Got error with unknown code: ", code, " ",
+ message);
+ }
+}
+
+static bool numpy_imported = false;
+
+} // namespace
+
+Safe_PyObjectPtr make_safe(PyObject* o) {
+ return Safe_PyObjectPtr(o, Py_DECREF_wrapper);
+}
+
+// Wrapper for TF_Run that converts the arguments to appropriate types.
+// If *out_status is OK, the caller becomes the owner of the PyObjects
+// in *out_values.
+void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ PyObjectVector* out_values) {
+ // 0. Ensure that numpy has been imported.
+ if (!numpy_imported) {
+ import_array();
+ numpy_imported = true;
+ }
+
+ // 1. Convert the feed inputs to the appropriate form for TF_Run.
+ NameVector input_names;
+ Safe_PyObjectVector
+ py_inputs_safe; // Used to decref the input arrays on failure.
+ Safe_TF_TensorVector inputs_safe; // Used to delete tensors on failure.
+ TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run.
+
+ for (const auto& name_and_array : inputs) {
+ py_inputs_safe.emplace_back(
+ make_safe(reinterpret_cast<PyObject*>(name_and_array.second)));
+ }
+
+ for (int i = 0; i < inputs.size(); ++i) {
+ input_names.push_back(inputs[i].first);
+ PyArrayObject* array = inputs[i].second;
+
+ // Convert numpy dtype to TensorFlow dtype.
+ TF_DataType dtype;
+ *out_status = PyArray_TYPE_to_TF_DataType(array, &dtype);
+ if (!out_status->ok()) {
+ return;
+ }
+
+ tensorflow::int64 nelems = 1;
+ gtl::InlinedVector<tensorflow::int64, 4> dims;
+ for (int i = 0; i < PyArray_NDIM(array); ++i) {
+ dims.push_back(PyArray_SHAPE(array)[i]);
+ nelems *= dims[i];
+ }
+
+ // Create a TF_Tensor based on the fed data. In the case of non-string data
+ // type, this steals a reference to array, which will be relinquished when
+ // the underlying buffer is deallocated. For string, a new temporary buffer
+ // is allocated into which the strings are encoded.
+ if (dtype != TF_STRING) {
+ // NOTE(mrry): We currently copy the numpy array into a new
+ // buffer to avoid possible issues on deallocation (such as
+ // having to acquire the Python Global Interpreter Lock).
+ // TODO(mrry): Investigate in what cases we can safely acquire the GIL
+ // on deallocation and avoid this copy.
+ size_t size = PyArray_NBYTES(array);
+ // NOTE(mrry): 32 is the upper bound on current alignment
+ // requirements for tensorflow::Tensor. We hard code this here to
+ // avoid taking a dependency on Eigen in the client code.
+ void* data = tensorflow::cpu_allocator()->AllocateRaw(32, size);
+ std::memcpy(data, array->data, size);
+ inputs_safe.emplace_back(make_safe(
+ TF_NewTensor(dtype, dims.data(), dims.size(), data, size,
+ [](void* data, size_t len, void* arg) {
+ tensorflow::cpu_allocator()->DeallocateRaw(data);
+ },
+ nullptr)));
+ // The destruction of the numpy array will now be handled by the
+ // inputs_safe destructor.
+ py_inputs_safe[i].reset();
+ } else {
+ size_t size;
+ void* encoded;
+ Status s = EncodePyStringArray(array, nelems, &size, &encoded);
+ if (!s.ok()) {
+ *out_status = s;
+ return;
+ }
+ inputs_safe.emplace_back(
+ make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), encoded, size,
+ [](void* data, size_t len, void* arg) {
+ delete[] reinterpret_cast<char*>(data);
+ },
+ array)));
+ // The destruction of the numpy array will now be handled by the
+ // inputs_safe destructor.
+ py_inputs_safe[i].reset();
+ }
+ inputs_unsafe.push_back(inputs_safe.back().get());
+ }
+
+ // 2. Allocate a container for the output data.
+ TF_TensorVector outputs(output_names.size());
+
+ Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
+
+ // 3. Actually call TF_Run().
+ Py_BEGIN_ALLOW_THREADS;
+ TF_Run(session, input_names.data(), inputs_unsafe.data(), input_names.size(),
+ const_cast<const char**>(output_names.data()), outputs.data(),
+ output_names.size(), const_cast<const char**>(target_nodes.data()),
+ target_nodes.size(), status.get());
+ Py_END_ALLOW_THREADS;
+
+ // 4. The TensorFlow runtime has taken ownership of the fed tensors,
+ // so we release the safe pointers to them.
+ for (auto& input : inputs_safe) {
+ input.release();
+ }
+
+ if (TF_GetCode(status.get()) != TF_OK) {
+ *out_status = TF_Status_to_Status(status.get());
+ return;
+ }
+
+ // 5. We now own the fetched tensors, so set up a safe container to
+ // delete them when we exit this scope.
+ Safe_TF_TensorVector tf_outputs_safe;
+ for (const auto& output : outputs) {
+ tf_outputs_safe.emplace_back(make_safe(output));
+ }
+
+ // 6. Convert the fetched tensors into numpy ndarrays. Store them in a safe
+ // container so that we do not leak them if a later conversion fails.
+ Safe_PyObjectVector py_outputs_safe;
+ for (int i = 0; i < output_names.size(); ++i) {
+ PyObject* py_array;
+ *out_status = TF_Tensor_to_PyObject(outputs[i], &py_array);
+ if (!out_status->ok()) {
+ return;
+ }
+ py_outputs_safe.emplace_back(make_safe(py_array));
+ }
+
+ // 7. If we reach this point, we have successfully built a list of objects
+ // so we can release them from the safe container.
+ for (auto& output : py_outputs_safe) {
+ out_values->push_back(output.release());
+ }
+ *out_status = Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
new file mode 100644
index 0000000000..12a7527ed9
--- /dev/null
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -0,0 +1,56 @@
+#ifndef TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
+#define TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
+
+#include <Python.h>
+
+#include "numpy/arrayobject.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor_c_api.h"
+
+namespace tensorflow {
+
+// Container types for the various arguments and temporary values used
+// in the wrapper.
+
+// A FeedVector is a vector of tensor name and numpy array pairs. The
+// name is a borrowed C string.
+typedef tensorflow::gtl::InlinedVector<std::pair<const char*, PyArrayObject*>,
+ 8> FeedVector;
+
+// A NameVector is a vector of tensor or operation names, as borrowed
+// C strings.
+typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector;
+
+// A PyObjectVector is a vector of borrowed pointers to PyObjects.
+typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector;
+
+ // Safe containers for owned PyObjects. On destruction, the reference
+ // count of each contained object will be decremented.
+inline void Py_DECREF_wrapper(PyObject* o) { Py_DECREF(o); }
+typedef void (*Py_DECREF_wrapper_type)(PyObject*);
+typedef std::unique_ptr<PyObject, Py_DECREF_wrapper_type> Safe_PyObjectPtr;
+typedef std::vector<Safe_PyObjectPtr> Safe_PyObjectVector;
+Safe_PyObjectPtr make_safe(PyObject* o);
+
+// Run the graph associated with the session starting with the
+ // supplied inputs[]. Regardless of success or failure, inputs[] are
+// stolen by the implementation (i.e. the implementation will
+// eventually call Py_DECREF on each array input).
+//
+// On success, the tensors corresponding to output_names[0,noutputs-1]
+// are placed in out_values[], and these outputs[] become the property
+// of the caller (the caller must eventually call Py_DECREF on them).
+//
+// On failure, out_status contains a tensorflow::Status with an error
+// message.
+void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ PyObjectVector* out_values);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
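A minimal sketch of how this wrapper surfaces through the public client API, assuming the usual placeholder/add ops and tf.Session: feeds are numpy ndarrays keyed by tensor, and fetched values come back as numpy ndarrays.

    import numpy as np
    import tensorflow as tf

    a = tf.placeholder(tf.float32, shape=[2])
    b = tf.placeholder(tf.float32, shape=[2])
    c = tf.add(a, b)

    with tf.Session() as sess:
        # feed_dict values are numpy ndarrays; the fetched value is an ndarray.
        result = sess.run(c, feed_dict={a: np.array([1., 2.], dtype=np.float32),
                                        b: np.array([3., 4.], dtype=np.float32)})
        print(result)  # [ 4.  6.]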
diff --git a/tensorflow/python/framework/__init__.py b/tensorflow/python/framework/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/framework/__init__.py
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
new file mode 100644
index 0000000000..676e5f779a
--- /dev/null
+++ b/tensorflow/python/framework/device.py
@@ -0,0 +1,220 @@
+"""Class to represent a device."""
+import copy
+
+
+class Device(object):
+ """Represents a Device."""
+
+ def __init__(self, job=None, replica=None, task=None, device_type=None,
+ device_index=None):
+ """Create a new device object.
+
+ Args:
+ job: string. Optional device job name.
+ replica: int. Optional replica index.
+ task: int. Optional task index.
+ device_type: Optional device type string (e.g. "CPU" or "GPU").
+ device_index: int. Optional device index. If left
+ unspecified, device represents 'any' device_index.
+ """
+ self.job = job
+ self.replica = replica
+ self.task = task
+ if device_type == "cpu" or device_type == "gpu":
+ # For backwards compatibility only, we support lowercase variants of
+ # cpu and gpu but turn them into uppercase here.
+ self.device_type = device_type.upper()
+ else:
+ self.device_type = device_type
+ self.device_index = device_index
+
+ def _clear(self):
+ self._job = None
+ self._replica = None
+ self._task = None
+ self.device_type = None
+ self.device_index = None
+
+ @property
+ def job(self):
+ return self._job
+
+ @job.setter
+ def job(self, job):
+ if job is not None:
+ self._job = str(job)
+ else:
+ self._job = None
+
+ @property
+ def replica(self):
+ return self._replica
+
+ @replica.setter
+ def replica(self, replica):
+ if replica is not None:
+ self._replica = int(replica)
+ else:
+ self._replica = None
+
+ @property
+ def task(self):
+ return self._task
+
+ @task.setter
+ def task(self, task):
+ if task is not None:
+ self._task = int(task)
+ else:
+ self._task = None
+
+ def parse_from_string(self, spec):
+ """Parse a Device name into its components.
+
+ Args:
+ spec: a string of the form
+ /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
+ or
+ /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
+ (CPU and GPU are mutually exclusive).
+ All entries are optional.
+
+ Returns:
+ The Device, for convenience.
+
+ Raises:
+ ValueError: if the spec was not valid.
+ """
+ self._clear()
+ splits = [x.split(":") for x in spec.split("/")]
+ for y in splits:
+ ly = len(y)
+ if y:
+ # NOTE(mdevin): we use the property setters here.
+ if ly == 2 and y[0] == "job":
+ self.job = y[1]
+ elif ly == 2 and y[0] == "replica":
+ self.replica = y[1]
+ elif ly == 2 and y[0] == "task":
+ self.task = y[1]
+ elif ((ly == 1 or ly == 2) and
+ ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))):
+ if self.device_type is not None:
+ raise ValueError("Cannot specify multiple device types: %s" % spec)
+ self.device_type = y[0].upper()
+ if ly == 2 and y[1] != "*":
+ self.device_index = int(y[1])
+ elif ly == 3 and y[0] == "device":
+ if self.device_type is not None:
+ raise ValueError("Cannot specify multiple device types: %s" % spec)
+ self.device_type = y[1]
+ if y[2] != "*":
+ self.device_index = int(y[2])
+ elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison
+ raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))
+
+ return self
+
+ def merge_from(self, dev):
+ """Merge the properties of "dev" into this Device.
+
+ Args:
+ dev: a Device.
+ """
+ if dev.job is not None:
+ self.job = dev.job
+ if dev.replica is not None:
+ self.replica = dev.replica
+ if dev.task is not None:
+ self.task = dev.task
+ if dev.device_type is not None:
+ self.device_type = dev.device_type
+ if dev.device_index is not None:
+ self.device_index = dev.device_index
+
+ def to_string(self):
+ """Return a Device specification string.
+
+ Returns:
+ a string of the form /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
+ or /job:<name>/replica:<id>/task:<id>/device:GPU:<id>.
+ """
+ dev = ""
+ if self.job is not None:
+ dev += "/job:" + self.job
+ if self.replica is not None:
+ dev += "/replica:" + str(self.replica)
+ if self.task is not None:
+ dev += "/task:" + str(self.task)
+ if self.device_type is not None:
+ device_index_string = "*"
+ if self.device_index is not None:
+ device_index_string = str(self.device_index)
+ dev += "/device:%s:%s" % (self.device_type, device_index_string)
+ return dev
+
+
+def from_string(spec):
+ """Construct a Device from a string.
+
+ Args:
+ spec: a string of the form
+ /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
+ or
+ /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
+ (CPU and GPU are mutually exclusive).
+ All entries are optional.
+
+ Returns:
+ A Device.
+ """
+ return Device().parse_from_string(spec)
+
+
+def check_valid(spec):
+ """Check that a device spec is valid.
+
+ Args:
+ spec: a string.
+
+ Raises:
+ An exception if the spec is invalid.
+ """
+ # Construct a device. It will assert a failure if spec is invalid.
+ from_string(spec)
+
+
+def merge_device(spec):
+ """Returns a device function that merges devices specifications.
+
+ This can be used to merge partial specifications of devices. The
+ innermost setting for a device field takes precedence. For example:
+
+ with tf.Device(MergeDevice("/device:GPU:0"))
+ # Nodes created here have device "/device:GPU:0"
+ with tf.Device(MergeDevice("/job:worker")):
+ # Nodes created here have device "/job:worker/device:GPU:0"
+ with tf.Device(MergeDevice("/device:CPU:0")):
+ # Nodes created here have device "/job:worker/device:CPU:0"
+ with tf.Device(MergeDevice("/job:ps")):
+ # Nodes created here have device "/job:ps/device:CPU:0"
+
+ Args:
+ spec: A device or a device spec string (partially) describing the
+ device that should be used for all nodes created in the scope of
+ the returned device function's with block.
+
+ Returns:
+ A device function with the above-described behavior.
+
+ Raises:
+ ValueError: if the spec was not valid.
+ """
+ if not isinstance(spec, Device):
+ spec = from_string(spec or "")
+ def _device_function(node_def):
+ current_device = from_string(node_def.device or "")
+ copy_spec = copy.copy(spec)
+ copy_spec.merge_from(current_device) # current_device takes precedence.
+ return copy_spec
+ return _device_function
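A quick usage sketch based only on the code above (the fake NodeDef stand-in is invented for illustration):

    from tensorflow.python.framework import device

    d = device.from_string("/job:worker/replica:0/task:1/gpu:2")
    assert d.job == "worker" and d.device_type == "GPU" and d.device_index == 2

    # merge_from() overwrites only the fields that are set on the argument.
    d.merge_from(device.from_string("/job:ps/cpu:0"))
    assert d.to_string() == "/job:ps/replica:0/task:1/device:CPU:0"

    # merge_device() yields a device function suitable for device scopes.
    class _FakeNodeDef(object):  # stand-in for a NodeDef, illustration only
        device = "/device:GPU:1"

    fn = device.merge_device("/job:worker")
    print(fn(_FakeNodeDef()).to_string())  # /job:worker/device:GPU:1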
diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py
new file mode 100644
index 0000000000..0a244b0815
--- /dev/null
+++ b/tensorflow/python/framework/device_test.py
@@ -0,0 +1,122 @@
+"""Tests for tensorflow.python.framework.device."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import device
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class DeviceTest(test_util.TensorFlowTestCase):
+
+ def testEmpty(self):
+ d = device.Device()
+ self.assertEquals("", d.ToString())
+ d.parse_from_string("")
+ self.assertEquals("", d.ToString())
+
+ def testConstructor(self):
+ d = device.Device(job="j", replica=0, task=1,
+ device_type="CPU", device_index=2)
+ self.assertEquals("j", d.job)
+ self.assertEquals(0, d.replica)
+ self.assertEquals(1, d.task)
+ self.assertEquals("CPU", d.device_type)
+ self.assertEquals(2, d.device_index)
+ self.assertEquals("/job:j/replica:0/task:1/device:CPU:2", d.to_string())
+
+ d = device.Device(device_type="GPU", device_index=0)
+ self.assertEquals("/device:GPU:0", d.to_string())
+
+ def testto_string(self):
+ d = device.Device()
+ d.job = "foo"
+ self.assertEquals("/job:foo", d.to_string())
+ d.task = 3
+ self.assertEquals("/job:foo/task:3", d.to_string())
+ d.device_type = "CPU"
+ d.device_index = 0
+ self.assertEquals("/job:foo/task:3/device:CPU:0", d.to_string())
+ d.task = None
+ d.replica = 12
+ self.assertEquals("/job:foo/replica:12/device:CPU:0", d.to_string())
+ d.device_type = "GPU"
+ d.device_index = 2
+ self.assertEquals("/job:foo/replica:12/device:GPU:2", d.to_string())
+ d.device_type = "CPU"
+ d.device_index = 1
+ self.assertEquals("/job:foo/replica:12/device:CPU:1", d.to_string())
+ d.device_type = None
+ d.device_index = None
+ d.cpu = None
+ self.assertEquals("/job:foo/replica:12", d.to_string())
+
+ # Test wildcard
+ d = device.Device(job="foo", replica=12, task=3, device_type="GPU")
+ self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
+
+ def testParse(self):
+ d = device.Device()
+ d.parse_from_string("/job:foo/replica:0")
+ self.assertEquals("/job:foo/replica:0", d.to_string())
+ d.parse_from_string("/replica:1/task:0/cpu:0")
+ self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
+ d.parse_from_string("/replica:1/task:0/device:CPU:0")
+ self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
+ d.parse_from_string("/job:muu/gpu:2")
+ self.assertEquals("/job:muu/device:GPU:2", d.to_string())
+ with self.assertRaises(Exception) as e:
+ d.parse_from_string("/job:muu/gpu:2/cpu:0")
+ self.assertTrue("Cannot specify multiple device" in e.exception.message)
+
+ def testFromString(self):
+ d = device.from_string("/job:foo/replica:0")
+ self.assertEquals("/job:foo/replica:0", d.to_string())
+ with self.assertRaises(Exception) as e:
+ d = device.from_string("/job:muu/gpu:2/cpu:0")
+ self.assertTrue("Cannot specify multiple device" in e.exception.message)
+
+ d = device.from_string("/job:foo/replica:0/task:3/cpu:*")
+ self.assertEquals(None, d.device_index)
+ d = device.from_string("/job:foo/replica:0/task:3/gpu:7")
+ self.assertEquals(7, d.device_index)
+ d = device.from_string("/job:foo/replica:0/task:3/device:GPU:7")
+ self.assertEquals(7, d.device_index)
+
+ def testMerge(self):
+ d = device.from_string("/job:foo/replica:0")
+ self.assertEquals("/job:foo/replica:0", d.to_string())
+ d.merge_from(device.from_string("/task:1/gpu:2"))
+ self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())
+
+ d = device.Device()
+ d.merge_from(device.from_string("/task:1/cpu:0"))
+ self.assertEquals("/task:1/device:CPU:0", d.to_string())
+ d.merge_from(device.from_string("/job:boo/gpu:0"))
+ self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string())
+ d.merge_from(device.from_string("/job:muu/cpu:2"))
+ self.assertEquals("/job:muu/task:1/device:CPU:2", d.to_string())
+ d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
+ self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
+
+ def testCheckValid(self):
+ device.CheckValid("/job:foo/replica:0")
+
+ with self.assertRaises(Exception) as e:
+ device.CheckValid("/job:j/replica:foo")
+ self.assertTrue("invalid literal for int" in e.exception.message)
+
+ with self.assertRaises(Exception) as e:
+ device.CheckValid("/job:j/task:bar")
+ self.assertTrue("invalid literal for int" in e.exception.message)
+
+ with self.assertRaises(Exception) as e:
+ device.CheckValid("/bar:muu/baz:2")
+ self.assertTrue("Unknown attribute: 'bar'" in e.exception.message)
+
+ with self.assertRaises(Exception) as e:
+ device.CheckValid("/cpu:0/gpu:2")
+ self.assertTrue("Cannot specify multiple device" in e.exception.message)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py
new file mode 100644
index 0000000000..68dbb3df72
--- /dev/null
+++ b/tensorflow/python/framework/docs.py
@@ -0,0 +1,492 @@
+"""Updates generated docs from Python doc comments.
+
+This both updates the files in the file system and executes g4 commands to
+make sure any changes are ready to be submitted.
+"""
+
+import inspect
+import os
+import re
+import sys
+
+
+_arg_re = re.compile(" *([*]{0,2}[a-zA-Z][a-zA-Z0-9_]*):")
+_section_re = re.compile("([A-Z][a-zA-Z ]*):$")
+_always_drop_symbol_re = re.compile("_[_a-zA-Z0-9]")
+_anchor_re = re.compile(r"^[\w.]+$")
+_member_mark = "@@"
+
+
+class Document(object):
+ """Base class for an automatically generated document."""
+
+ def write_markdown_to_file(self, f):
+ """Writes a Markdown-formatted version of this document to file `f`.
+
+ Args:
+ f: The output file.
+ """
+ raise NotImplementedError("Document.WriteToFile")
+
+
+class Index(Document):
+ """An automatically generated index for a collection of documents."""
+
+ def __init__(self, module_to_name, members, filename_to_library_map):
+ """Creates a new Index.
+
+ Args:
+ module_to_name: Dictionary mapping modules to short names.
+ members: Dictionary mapping member name to (fullname, member).
+ filename_to_library_map: A list of (filename, Library) pairs. The order
+ corresponds to the order in which the libraries appear in the index.
+ """
+ self._module_to_name = module_to_name
+ self._members = members
+ self._filename_to_library_map = filename_to_library_map
+
+ def write_markdown_to_file(self, f):
+ """Writes this index to file `f`.
+
+ The output is formatted as an unordered list. Each list element
+ contains the title of the library, followed by a list of symbols
+ in that library hyperlinked to the corresponding anchor in that
+ library.
+
+ Args:
+ f: The output file.
+ """
+ print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->"
+ print >>f, ""
+ print >>f, "# TensorFlow Python reference documentation"
+ print >>f, ""
+ for filename, library in self._filename_to_library_map:
+ per_symbol_links = []
+ for name in sorted(library.mentioned):
+ if name in self._members:
+ fullname, member = self._members[name]
+ anchor = _get_anchor(self._module_to_name, fullname)
+ prefix = "class " * inspect.isclass(member)
+ per_symbol_links.append("[%s%s](%s#%s)" %
+ (prefix, name, filename, anchor))
+ if per_symbol_links:
+ print >>f, "* <b>[%s](%s)</b>: %s" % (library.title, filename,
+ ",\n ".join(per_symbol_links))
+ print >>f, ""
+
+ # actually include the files right here
+ print >>f, '<div class="sections-order" style="display: none;">\n<!--'
+ for filename, _ in self._filename_to_library_map:
+ print >>f, "<!-- %s -->" % filename
+ print >>f, "-->\n</div>"
+
+def collect_members(module_to_name):
+ """Collect all symbols from a list of modules.
+
+ Args:
+ module_to_name: Dictionary mapping modules to short names.
+
+ Returns:
+ Dictionary mapping name to (fullname, member) pairs.
+ """
+ members = {}
+ for module, module_name in module_to_name.iteritems():
+ for name, member in inspect.getmembers(module):
+ if ((inspect.isfunction(member) or inspect.isclass(member)) and
+ not _always_drop_symbol_re.match(name)):
+ fullname = '%s.%s' % (module_name, name)
+ if name in members:
+ other_fullname, other_member = members[name]
+ if member is not other_member:
+ raise RuntimeError("Short name collision between %s and %s" %
+ (fullname, other_fullname))
+ if len(fullname) == len(other_fullname):
+ raise RuntimeError("Can't decide whether to use %s or %s for %s: "
+ "both full names have length %d" %
+ (fullname, other_fullname, len(fullname)))
+ if len(fullname) > len(other_fullname):
+ continue # Use the shorter full name
+ members[name] = fullname, member
+ return members
+
+
+def _get_anchor(module_to_name, fullname):
+ """Turn a full member name into an anchor.
+
+ Args:
+ module_to_name: Dictionary mapping modules to short names.
+ fullname: Fully qualified name of symbol.
+
+ Returns:
+ HTML anchor string. The longest module name prefix of fullname is
+ removed to make the anchor.
+
+ Raises:
+ ValueError: If fullname uses characters invalid in an anchor.
+ """
+ if not _anchor_re.match(fullname):
+ raise ValueError("'%s' is not a valid anchor" % fullname)
+ anchor = fullname
+ for module_name in module_to_name.itervalues():
+ if fullname.startswith(module_name + "."):
+ rest = fullname[len(module_name)+1:]
+ # Use this prefix iff it is longer than any found before
+ if len(anchor) > len(rest):
+ anchor = rest
+ return anchor
+
+
+class Library(Document):
+ """An automatically generated document for a set of functions and classes."""
+
+ def __init__(self,
+ title,
+ module,
+ module_to_name,
+ members,
+ documented,
+ exclude_symbols=(),
+ catch_all=False):
+ """Creates a new Library.
+
+ Args:
+ title: A human-readable title for the library.
+ module: Module to pull high level docstring from (for table of contents,
+ list of Ops to document, etc.).
+ module_to_name: Dictionary mapping modules to short names.
+ members: Dictionary mapping member name to (fullname, member).
+ documented: Set of documented names to update.
+ exclude_symbols: A list of specific symbols to exclude.
+ """
+ self._title = title
+ self._module = module
+ self._module_to_name = module_to_name
+ self._members = dict(members) # Copy since we mutate it below
+ self._exclude_symbols = frozenset(exclude_symbols)
+ documented.update(exclude_symbols)
+ self._documented = documented
+ self._mentioned = set()
+
+ @property
+ def title(self):
+ """The human-readable title for this library."""
+ return self._title
+
+ @property
+ def mentioned(self):
+ """Set of names mentioned in this library."""
+ return self._mentioned
+
+ @property
+ def exclude_symbols(self):
+ """Set of excluded symbols."""
+ return self._exclude_symbols
+
+ def _should_include_member(self, name, member):
+ """Returns True if this member should be included in the document."""
+ # Always exclude symbols matching _always_drop_symbol_re.
+ if _always_drop_symbol_re.match(name):
+ return False
+ # Finally, exclude any specifically-excluded symbols.
+ if name in self._exclude_symbols:
+ return False
+ return True
+
+ def get_imported_modules(self, module):
+ """Returns the list of modules imported from `module`."""
+ for name, member in inspect.getmembers(module):
+ if inspect.ismodule(member):
+ yield name, member
+
+ def get_class_members(self, cls_name, cls):
+ """Returns the list of class members to document in `cls`.
+
+ This function filters the class members to ONLY return those
+ defined by the class. It drops the inherited ones.
+
+ Args:
+ cls_name: Qualified name of `cls`.
+ cls: An inspect object of type 'class'.
+
+ Yields:
+ name, member tuples.
+ """
+ for name, member in inspect.getmembers(cls):
+ # Only show methods and properties presently.
+ if not (inspect.ismethod(member) or isinstance(member, property)):
+ continue
+ if ((inspect.ismethod(member) and member.__name__ == "__init__")
+ or self._should_include_member(name, member)):
+ yield name, ("%s.%s" % (cls_name, name), member)
+
+ def _generate_signature_for_function(self, func):
+ """Given a function, returns a string representing its args."""
+ args_list = []
+ argspec = inspect.getargspec(func)
+ first_arg_with_default = (
+ len(argspec.args or []) - len(argspec.defaults or []))
+ for arg in argspec.args[:first_arg_with_default]:
+ if arg == "self":
+ # Python documentation typically skips `self` when printing method
+ # signatures.
+ continue
+ args_list.append(arg)
+ if argspec.defaults:
+ for arg, default in zip(
+ argspec.args[first_arg_with_default:], argspec.defaults):
+ args_list.append("%s=%r" % (arg, default))
+ if argspec.varargs:
+ args_list.append("*" + argspec.varargs)
+ if argspec.keywords:
+ args_list.append("**" + argspec.keywords)
+ return "(" + ", ".join(args_list) + ")"
+
+ def _remove_docstring_indent(self, docstring):
+ """Remove indenting.
+
+ We follow Python's convention and remove the minimum indent of the lines
+ after the first, preserving relative indentation; see:
+ https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
+
+ Args:
+ docstring: A docstring.
+
+ Returns:
+ A list of strings, one per line, with the minimum indent stripped.
+ """
+ docstring = docstring or ""
+ lines = docstring.strip().split("\n")
+
+ min_indent = len(docstring)
+ for l in lines[1:]:
+ l = l.rstrip()
+ if l:
+ i = 0
+ while i < len(l) and l[i] == " ":
+ i += 1
+ if i < min_indent: min_indent = i
+ for i in range(1, len(lines)):
+ l = lines[i].rstrip()
+ if len(l) >= min_indent:
+ l = l[min_indent:]
+ lines[i] = l
+ return lines
+
+ def _print_formatted_docstring(self, docstring, f):
+ """Formats the given `docstring` as Markdown and prints it to `f`."""
+ lines = self._remove_docstring_indent(docstring)
+
+ # Output the lines, identifying "Args" and other section blocks.
+ i = 0
+
+ def _at_start_of_section():
+ """Returns the header if lines[i] is at start of a docstring section."""
+ l = lines[i]
+ match = _section_re.match(l)
+ if match and i + 1 < len(
+ lines) and lines[i + 1].startswith(" "):
+ return match.group(1)
+ else:
+ return None
+
+ while i < len(lines):
+ l = lines[i]
+
+ section_header = _at_start_of_section()
+ if section_header:
+ if i == 0 or lines[i-1]:
+ print >>f, ""
+ # Use at least H4 to keep these out of the TOC.
+ print >>f, "##### " + section_header + ":"
+ print >>f, ""
+ i += 1
+ outputting_list = False
+ while i < len(lines):
+ l = lines[i]
+ # A new section header terminates the section.
+ if _at_start_of_section():
+ break
+ match = _arg_re.match(l)
+ if match:
+ if not outputting_list:
+ # We need to start a list. In Markdown, a blank line needs to
+ # precede a list.
+ print >>f, ""
+ outputting_list = True
+ suffix = l[len(match.group()):].lstrip()
+ print >>f, "* <b>" + match.group(1) + "</b>: " + suffix
+ else:
+ # For lines that don't start with _arg_re, continue the list if it
+ # has enough indentation.
+ outputting_list &= l.startswith(" ")
+ print >>f, l
+ i += 1
+ else:
+ print >>f, l
+ i += 1
+
+ def _print_function(self, f, prefix, fullname, func):
+ """Prints the given function to `f`."""
+ heading = prefix + " " + fullname
+ if not isinstance(func, property):
+ heading += self._generate_signature_for_function(func)
+ heading += " {#%s}" % _get_anchor(self._module_to_name, fullname)
+ print >>f, heading
+ print >>f, ""
+ self._print_formatted_docstring(inspect.getdoc(func), f)
+ print >>f, ""
+
+ def _write_member_markdown_to_file(self, f, name, member):
+ """Print `member` to `f`."""
+ if inspect.isfunction(member):
+ print >>f, "- - -"
+ print >>f, ""
+ self._print_function(f, "###", name, member)
+ print >>f, ""
+ elif inspect.ismethod(member):
+ print >>f, "- - -"
+ print >>f, ""
+ self._print_function(f, "####", name, member)
+ print >>f, ""
+ elif isinstance(member, property):
+ print >>f, "- - -"
+ print >>f, ""
+ self._print_function(f, "####", name, member)
+ elif inspect.isclass(member):
+ print >>f, "- - -"
+ print >>f, ""
+ print >>f, "### class %s {#%s}" % (
+ name, _get_anchor(self._module_to_name, name))
+ print >>f, ""
+ self._write_class_markdown_to_file(f, name, member)
+ print >>f, ""
+ else:
+ raise RuntimeError("Member %s has unknown type %s" % (name, type(member)))
+
+ def _write_docstring_markdown_to_file(self, f, docstring, members, imports):
+ for l in self._remove_docstring_indent(docstring):
+ if l.startswith(_member_mark):
+ name = l[len(_member_mark):].strip(" \t")
+ if name in members:
+ self._documented.add(name)
+ self._mentioned.add(name)
+ self._write_member_markdown_to_file(f, *members[name])
+ del members[name]
+ elif name in imports:
+ self._write_module_markdown_to_file(f, imports[name])
+ else:
+ raise ValueError("%s: unknown member `%s`" % (self._title, name))
+ else:
+ print >>f, l
+
+ def _write_class_markdown_to_file(self, f, name, cls):
+ """Write the class doc to 'f'.
+
+ Args:
+ f: File to write to.
+ name: Name to use for the class.
+ cls: The class object.
+ """
+ # Build the list of class methods to document.
+ methods = dict(self.get_class_members(name, cls))
+ # Used later to check if any methods were called out in the class
+ # docstring.
+ num_methods = len(methods)
+ self._write_docstring_markdown_to_file(f, inspect.getdoc(cls), methods, {})
+
+ # If some methods were not described, describe them now if they are
+ # defined by the class itself (not inherited). If NO methods were
+ # described, describe all methods.
+ #
+ # TODO(mdevin): when all methods have been categorized make it an error
+ # if some methods are not categorized.
+ any_method_called_out = (len(methods) != num_methods)
+ if any_method_called_out:
+ other_methods = {n: m for n, m in methods.iteritems()
+ if n in cls.__dict__}
+ if other_methods:
+ print >>f, "\n#### Other Methods"
+ else:
+ other_methods = methods
+ for name in sorted(other_methods):
+ self._write_member_markdown_to_file(f, *other_methods[name])
+
+ def _write_module_markdown_to_file(self, f, module):
+ imports = dict(self.get_imported_modules(module))
+ self._write_docstring_markdown_to_file(f, inspect.getdoc(module),
+ self._members, imports)
+
+ def write_markdown_to_file(self, f):
+ """Prints this library to file `f`.
+
+ Args:
+ f: File to write to.
+
+ Returns:
+ Dictionary of documented members.
+ """
+ print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->"
+ print >>f, ""
+ # TODO(mdevin): Do not insert these. Let the doc writer put them in
+ # the module docstring explicitly.
+ print >>f, "#", self._title
+ print >>f, "[TOC]"
+ print >>f, ""
+ if self._module is not None:
+ self._write_module_markdown_to_file(f, self._module)
+
+ def write_other_members(self, f, catch_all=False):
+ """Writes the leftover members to `f`.
+
+ Args:
+ f: File to write to.
+ catch_all: If true, document all missing symbols from any module.
+ Otherwise, document missing symbols from just this module.
+ """
+ if catch_all:
+ names = self._members.iteritems()
+ else:
+ names = inspect.getmembers(self._module)
+ leftovers = []
+ for name, _ in names:
+ if name in self._members and name not in self._documented:
+ leftovers.append(name)
+ if leftovers:
+ print "%s: undocumented members: %d" % (self._title, len(leftovers))
+ print >>f, "\n## Other Functions and Classes"
+ for name in sorted(leftovers):
+ print " %s" % name
+ self._documented.add(name)
+ self._mentioned.add(name)
+ self._write_member_markdown_to_file(f, *self._members[name])
+
+ def assert_no_leftovers(self):
+ """Generate an error if there are leftover members."""
+ leftovers = []
+ for name in self._members.iterkeys():
+ if name in self._members and name not in self._documented:
+ leftovers.append(name)
+ if leftovers:
+ raise RuntimeError("%s: undocumented members: %s" %
+ (self._title, ", ".join(leftovers)))
+
+
+def write_libraries(dir, libraries):
+ """Write a list of libraries to disk.
+
+ Args:
+ dir: Output directory.
+ libraries: List of (filename, library) pairs.
+ """
+ files = [open(os.path.join(dir, k), "w") for k, _ in libraries]
+ # Document mentioned symbols for all libraries
+ for f, (_, v) in zip(files, libraries):
+ v.write_markdown_to_file(f)
+ # Document symbols that no library mentioned. We do this after writing
+ # out all libraries so that earlier libraries know what later libraries
+ # documented.
+ for f, (_, v) in zip(files, libraries):
+ v.write_other_members(f)
+ f.close()
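For reference, a sketch of the docstring shape this generator consumes (the module and function names below are invented): "@@name" lines in a module docstring pull in the named member's documentation, and sections such as "Args:" with indented "name:" entries are rendered as Markdown lists.

    """Frobnication utilities.

    ## Frobnication

    @@frobnicate
    """

    def frobnicate(x, strength=1.0):
      """Frobnicates `x`.

      Args:
        x: The value to frobnicate.
        strength: How hard to frobnicate.

      Returns:
        The frobnicated value.
      """
      return x * strength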
diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py
new file mode 100644
index 0000000000..fe8f107cec
--- /dev/null
+++ b/tensorflow/python/framework/errors.py
@@ -0,0 +1,410 @@
+"""Exception types for TensorFlow errors."""
+import traceback
+import warnings
+
+from tensorflow.core.lib.core import error_codes_pb2
+
+
+class OpError(Exception):
+ """A generic error that is raised when TensorFlow execution fails.
+
+ Whenever possible, the session will raise a more specific subclass
+ of `OpError` from the `tf.errors` module.
+
+ @@op
+ @@node_def
+ """
+
+ def __init__(self, node_def, op, message, error_code):
+ """Creates a new OpError indicating that a particular op failed.
+
+ Args:
+ node_def: The graph_pb2.NodeDef proto representing the op that failed.
+ op: The ops.Operation that failed, if known; otherwise None.
+ message: The message string describing the failure.
+ error_code: The error_codes_pb2.Code describing the error.
+ """
+ super(OpError, self).__init__()
+ self._message = message
+ self._node_def = node_def
+ self._op = op
+ self._error_code = error_code
+
+ @property
+ def message(self):
+ """The error message that describes the error."""
+ return self._message
+
+ @property
+ def op(self):
+ """The operation that failed, if known.
+
+ *N.B.* If the failed op was synthesized at runtime, e.g. a `Send`
+ or `Recv` op, there will be no corresponding
+ [`Operation`](framework.md#Operation) object. In that case, this
+ will return `None`, and you should instead use the
+ [`node_def`](OpError.node_def) to discover information about the op.
+
+ Returns:
+ The `Operation` that failed, or None.
+ """
+ return self._op
+
+ @property
+ def error_code(self):
+ """The integer error code that describes the error."""
+ return self._error_code
+
+ @property
+ def node_def(self):
+ """The `NodeDef` proto representing the op that failed."""
+ return self._node_def
+
+ def __str__(self):
+ if self._op is not None:
+ output = ["%s\nCaused by op %r, defined at:\n"
+ % (self.message, self._op.name,)]
+ curr_traceback_list = traceback.format_list(self._op.traceback)
+ output.extend(curr_traceback_list)
+ original_op = self._op._original_op
+ while original_op is not None:
+ output.append(
+ "\n...which was originally created as op %r, defined at:\n"
+ % (original_op.name,))
+ prev_traceback_list = curr_traceback_list
+ curr_traceback_list = traceback.format_list(original_op.traceback)
+
+ # Attempt to elide large common subsequences of the subsequent
+ # stack traces.
+ #
+ # TODO(mrry): Consider computing the actual longest common subsequence.
+ is_eliding = False
+ elide_count = 0
+ last_elided_line = None
+ for line, line_in_prev in zip(curr_traceback_list, prev_traceback_list):
+ if line == line_in_prev:
+ if is_eliding:
+ elide_count += 1
+ last_elided_line = line
+ else:
+ output.append(line)
+ is_eliding = True
+ elide_count = 0
+ else:
+ if is_eliding:
+ if elide_count > 0:
+ output.extend(
+ ["[elided %d identical lines from previous traceback]\n"
+ % (elide_count - 1,), last_elided_line])
+ is_eliding = False
+ output.append(line)
+
+ original_op = original_op._original_op
+ return ''.join(output)
+ else:
+ return self.message
+
+
+OK = error_codes_pb2.OK
+CANCELLED = error_codes_pb2.CANCELLED
+UNKNOWN = error_codes_pb2.UNKNOWN
+INVALID_ARGUMENT = error_codes_pb2.INVALID_ARGUMENT
+DEADLINE_EXCEEDED = error_codes_pb2.DEADLINE_EXCEEDED
+NOT_FOUND = error_codes_pb2.NOT_FOUND
+ALREADY_EXISTS = error_codes_pb2.ALREADY_EXISTS
+PERMISSION_DENIED = error_codes_pb2.PERMISSION_DENIED
+UNAUTHENTICATED = error_codes_pb2.UNAUTHENTICATED
+RESOURCE_EXHAUSTED = error_codes_pb2.RESOURCE_EXHAUSTED
+FAILED_PRECONDITION = error_codes_pb2.FAILED_PRECONDITION
+ABORTED = error_codes_pb2.ABORTED
+OUT_OF_RANGE = error_codes_pb2.OUT_OF_RANGE
+UNIMPLEMENTED = error_codes_pb2.UNIMPLEMENTED
+INTERNAL = error_codes_pb2.INTERNAL
+UNAVAILABLE = error_codes_pb2.UNAVAILABLE
+DATA_LOSS = error_codes_pb2.DATA_LOSS
+
+
+class CancelledError(OpError):
+ """Raised when an operation or step is cancelled.
+
+ For example, a long-running operation (e.g.
+ [`queue.enqueue()`](io_ops.md#QueueBase.enqueue)) may be cancelled by
+ running another operation (e.g.
+ [`queue.close(cancel_pending_enqueues=True)`](io_ops.md#QueueBase.close)),
+ or by [closing the session](client.md#Session.close). A step that is
+ running such a long-running operation will fail by raising `CancelledError`.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates a `CancelledError`."""
+ super(CancelledError, self).__init__(node_def, op, message, CANCELLED)
+
+
+class UnknownError(OpError):
+ """Unknown error.
+
+ An example of where this error may be returned is if a Status value
+ received from another address space belongs to an error-space that
+ is not known to this address space. Also errors raised by APIs that
+ do not return enough error information may be converted to this
+ error.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message, error_code=UNKNOWN):
+ """Creates an `UnknownError`."""
+ super(UnknownError, self).__init__(node_def, op, message, error_code)
+
+
+class InvalidArgumentError(OpError):
+ """Raised when an operation receives an invalid argument.
+
+ This may occur, for example, if an operation receives an input
+ tensor that has an invalid value or shape. For example, the
+ [`tf.matmul()`](math_ops.md#matmul) op will raise this error if it
+ receives an input that is not a matrix, and the
+ [`tf.reshape()`](array_ops.md#reshape) op will raise this error if
+ the new shape does not match the number of elements in the input
+ tensor.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `InvalidArgumentError`."""
+ super(InvalidArgumentError, self).__init__(node_def, op, message,
+ INVALID_ARGUMENT)
+
+
+class DeadlineExceededError(OpError):
+ """Raised when a deadline expires before an operation could complete.
+
+ This exception is not currently used.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates a `DeadlineExceededError`."""
+ super(DeadlineExceededError, self).__init__(node_def, op, message,
+ DEADLINE_EXCEEDED)
+
+
+class NotFoundError(OpError):
+ """Raised when a requested entity (e.g., a file or directory) was not found.
+
+ For example, running the
+ [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation
+ could raise `NotFoundError` if it receives the name of a file that
+ does not exist.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates a `NotFoundError`."""
+ super(NotFoundError, self).__init__(node_def, op, message, NOT_FOUND)
+
+
+class AlreadyExistsError(OpError):
+ """Raised when an entity that we attempted to create already exists.
+
+ For example, running an operation that saves a file
+ (e.g. [`tf.train.Saver.save()`](train.md#Saver.save)) could
+ potentially raise this exception if an explicit filename for an
+ existing file was passed.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `AlreadyExistsError`."""
+ super(AlreadyExistsError, self).__init__(node_def, op, message,
+ ALREADY_EXISTS)
+
+
+class PermissionDeniedError(OpError):
+ """Raised when the caller does not have permission to run an operation.
+
+ For example, running the
+ [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation
+ could raise `PermissionDeniedError` if it receives the name of a
+ file for which the user does not have the read file permission.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates a `PermissionDeniedError`."""
+ super(PermissionDeniedError, self).__init__(node_def, op, message,
+ PERMISSION_DENIED)
+
+
+class UnauthenticatedError(OpError):
+ """The request does not have valid authentication credentials.
+
+ This exception is not currently used.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `UnauthenticatedError`."""
+ super(UnauthenticatedError, self).__init__(node_def, op, message,
+ UNAUTHENTICATED)
+
+
+class ResourceExhaustedError(OpError):
+ """Some resource has been exhausted.
+
+ For example, this error might be raised if a per-user quota is
+ exhausted, or perhaps the entire file system is out of space.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates a `ResourceExhaustedError`."""
+ super(ResourceExhaustedError, self).__init__(node_def, op, message,
+ RESOURCE_EXHAUSTED)
+
+
+class FailedPreconditionError(OpError):
+ """Operation was rejected because the system is not in a state to execute it.
+
+ This exception is most commonly raised when running an operation
+ that reads a [`tf.Variable`](state_ops.md#Variable) before it has
+ been initialized.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates a `FailedPreconditionError`."""
+ super(FailedPreconditionError, self).__init__(node_def, op, message,
+ FAILED_PRECONDITION)
+
+
+class AbortedError(OpError):
+ """The operation was aborted, typically due to a concurrent action.
+
+ For example, running a [`queue.enqueue()`](io_ops.md#QueueBase.enqueue)
+ operation may raise `AbortedError` if a
+ [`queue.close()`](io_ops.md#QueueBase.close) operation previously ran.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `AbortedError`."""
+ super(AbortedError, self).__init__(node_def, op, message, ABORTED)
+
+
+class OutOfRangeError(OpError):
+ """Raised when an operation executed past the valid range.
+
+ This exception is raised in "end-of-file" conditions, such as when a
+ [`queue.dequeue()`](io_ops.md#QueueBase.dequeue) operation is
+ blocked on an empty queue, and a
+ [`queue.close()`](io_ops.md#QueueBase.close) operation executes.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `OutOfRangeError`."""
+ super(OutOfRangeError, self).__init__(node_def, op, message,
+ OUT_OF_RANGE)
+
+
+class UnimplementedError(OpError):
+ """Raised when an operation has not been implemented.
+
+ Some operations may raise this error when passed otherwise-valid
+ arguments that they do not currently support. For example, running
+ the [`tf.nn.max_pool()`](nn.md#max_pool) operation would raise this
+ error if pooling was requested on the batch dimension, because this
+ is not yet supported.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `UnimplementedError`."""
+ super(UnimplementedError, self).__init__(node_def, op, message,
+ UNIMPLEMENTED)
+
+
+class InternalError(OpError):
+ """Raised when the system experiences an internal error.
+
+ This exception is raised when some invariant expected by the runtime
+ has been broken. Catching this exception is not recommended.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `InternalError`."""
+ super(InternalError, self).__init__(node_def, op, message, INTERNAL)
+
+
+class UnavailableError(OpError):
+ """Raised when the runtime is currently unavailable.
+
+ This exception is not currently used.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates an `UnavailableError`."""
+ super(UnavailableError, self).__init__(node_def, op, message,
+ UNAVAILABLE)
+
+
+class DataLossError(OpError):
+ """Raised when unrecoverable data loss or corruption is encountered.
+
+ For example, this may be raised by running a
+ [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation,
+ if the file is truncated while it is being read.
+
+ @@__init__
+ """
+
+ def __init__(self, node_def, op, message):
+ """Creates a `DataLossError`."""
+ super(DataLossError, self).__init__(node_def, op, message, DATA_LOSS)
+
+
+_CODE_TO_EXCEPTION_CLASS = {
+ CANCELLED: CancelledError,
+ UNKNOWN: UnknownError,
+ INVALID_ARGUMENT: InvalidArgumentError,
+ DEADLINE_EXCEEDED: DeadlineExceededError,
+ NOT_FOUND: NotFoundError,
+ ALREADY_EXISTS: AlreadyExistsError,
+ PERMISSION_DENIED: PermissionDeniedError,
+ UNAUTHENTICATED: UnauthenticatedError,
+ RESOURCE_EXHAUSTED: ResourceExhaustedError,
+ FAILED_PRECONDITION: FailedPreconditionError,
+ ABORTED: AbortedError,
+ OUT_OF_RANGE: OutOfRangeError,
+ UNIMPLEMENTED: UnimplementedError,
+ INTERNAL: InternalError,
+ UNAVAILABLE: UnavailableError,
+ DATA_LOSS: DataLossError,
+}
+
+
+def _make_specific_exception(node_def, op, message, error_code):
+ try:
+ exc_type = _CODE_TO_EXCEPTION_CLASS[error_code]
+ return exc_type(node_def, op, message)
+ except KeyError:
+ warnings.warn("Unknown error code: %d" % error_code)
+ return UnknownError(node_def, op, message, error_code)
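A small usage sketch, assuming the standard placeholder/mul ops and tf.Session: every runtime failure in a step is raised as one of the OpError subclasses above, so callers can catch either the specific class or OpError itself.

    import tensorflow as tf

    x = tf.placeholder(tf.float32)
    y = tf.mul(x, 2.0)
    with tf.Session() as sess:
        try:
            sess.run(y)  # x was never fed, so the step fails.
        except tf.errors.OpError as e:
            # In practice an InvalidArgumentError here; every subclass carries
            # .message, .error_code, .op and .node_def.
            print("caught %s: %s" % (type(e).__name__, e.message))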
diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py
new file mode 100644
index 0000000000..ab59a729f6
--- /dev/null
+++ b/tensorflow/python/framework/errors_test.py
@@ -0,0 +1,63 @@
+"""Tests for tensorflow.python.framework.errors."""
+import tensorflow.python.platform
+
+import warnings
+
+import tensorflow as tf
+
+from tensorflow.core.lib.core import error_codes_pb2
+
+class ErrorsTest(tf.test.TestCase):
+
+ def testUniqueClassForEachErrorCode(self):
+ for error_code, exc_type in [
+ (tf.errors.CANCELLED, tf.errors.CancelledError),
+ (tf.errors.UNKNOWN, tf.errors.UnknownError),
+ (tf.errors.INVALID_ARGUMENT, tf.errors.InvalidArgumentError),
+ (tf.errors.DEADLINE_EXCEEDED, tf.errors.DeadlineExceededError),
+ (tf.errors.NOT_FOUND, tf.errors.NotFoundError),
+ (tf.errors.ALREADY_EXISTS, tf.errors.AlreadyExistsError),
+ (tf.errors.PERMISSION_DENIED, tf.errors.PermissionDeniedError),
+ (tf.errors.UNAUTHENTICATED, tf.errors.UnauthenticatedError),
+ (tf.errors.RESOURCE_EXHAUSTED, tf.errors.ResourceExhaustedError),
+ (tf.errors.FAILED_PRECONDITION, tf.errors.FailedPreconditionError),
+ (tf.errors.ABORTED, tf.errors.AbortedError),
+ (tf.errors.OUT_OF_RANGE, tf.errors.OutOfRangeError),
+ (tf.errors.UNIMPLEMENTED, tf.errors.UnimplementedError),
+ (tf.errors.INTERNAL, tf.errors.InternalError),
+ (tf.errors.UNAVAILABLE, tf.errors.UnavailableError),
+ (tf.errors.DATA_LOSS, tf.errors.DataLossError),
+ ]:
+ # pylint: disable=protected-access
+ self.assertTrue(isinstance(
+ tf.errors._make_specific_exception(None, None, None, error_code),
+ exc_type))
+ # pylint: enable=protected-access
+
+ def testKnownErrorClassForEachErrorCodeInProto(self):
+ for error_code in error_codes_pb2.Code.values():
+ # pylint: disable=line-too-long
+ if error_code in (error_codes_pb2.OK,
+ error_codes_pb2.DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_):
+ continue
+ # pylint: enable=line-too-long
+ with warnings.catch_warnings(record=True) as w:
+ # pylint: disable=protected-access
+ exc = tf.errors._make_specific_exception(None, None, None, error_code)
+ # pylint: enable=protected-access
+ self.assertEqual(0, len(w)) # No warning is raised.
+ self.assertTrue(isinstance(exc, tf.errors.OpError))
+ self.assertTrue(tf.errors.OpError in exc.__class__.__bases__)
+
+ def testUnknownErrorCodeCausesWarning(self):
+ with warnings.catch_warnings(record=True) as w:
+ # pylint: disable=protected-access
+ exc = tf.errors._make_specific_exception(None, None, None, 37)
+ # pylint: enable=protected-access
+ self.assertEqual(1, len(w))
+ self.assertTrue("Unknown error code: 37" in str(w[0].message))
+ self.assertTrue(isinstance(exc, tf.errors.OpError))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
new file mode 100644
index 0000000000..e317cfda8d
--- /dev/null
+++ b/tensorflow/python/framework/framework_lib.py
@@ -0,0 +1,70 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long
+"""Import names from the framework library.
+
+## Core graph data structures
+
+@@Graph
+@@Operation
+@@Tensor
+
+## Tensor types
+
+@@DType
+@@as_dtype
+
+## Utility functions
+
+@@device
+@@name_scope
+@@control_dependencies
+@@convert_to_tensor
+@@get_default_graph
+@@import_graph_def
+
+## Graph collections
+
+@@add_to_collection
+@@get_collection
+@@GraphKeys
+
+## Defining new operations
+
+@@RegisterGradient
+@@NoGradient
+@@RegisterShape
+@@TensorShape
+@@Dimension
+@@op_scope
+@@get_seed
+"""
+
+# Classes used when building a Graph.
+from tensorflow.python.framework.ops import Graph
+from tensorflow.python.framework.ops import Operation
+from tensorflow.python.framework.ops import Tensor
+from tensorflow.python.framework.ops import SparseTensor
+from tensorflow.python.framework.ops import SparseTensorValue
+from tensorflow.python.framework.ops import IndexedSlices
+
+# Utilities used when building a Graph.
+from tensorflow.python.framework.ops import device
+from tensorflow.python.framework.ops import name_scope
+from tensorflow.python.framework.ops import op_scope
+from tensorflow.python.framework.ops import control_dependencies
+from tensorflow.python.framework.ops import get_default_graph
+from tensorflow.python.framework.ops import GraphKeys
+from tensorflow.python.framework.ops import add_to_collection
+from tensorflow.python.framework.ops import get_collection
+from tensorflow.python.framework.ops import convert_to_tensor
+from tensorflow.python.framework.random_seed import get_seed
+from tensorflow.python.framework.random_seed import set_random_seed
+from tensorflow.python.framework.importer import import_graph_def
+
+# Needed when you defined a new Op in C++.
+from tensorflow.python.framework.ops import RegisterGradient
+from tensorflow.python.framework.ops import NoGradient
+from tensorflow.python.framework.ops import RegisterShape
+from tensorflow.python.framework.tensor_shape import Dimension
+from tensorflow.python.framework.tensor_shape import TensorShape
+
+from tensorflow.python.framework.types import *
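A brief graph-building sketch using the names exported here, assuming the usual constant/add ops are available:

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
        with tf.device("/job:worker/device:CPU:0"), tf.name_scope("layer1"):
            a = tf.constant(1.0, name="a")
            b = tf.constant(2.0, name="b")
            c = tf.add(a, b, name="c")

    assert c.name == "layer1/c:0"  # name_scope prefixes op names
    assert c.graph is g            # ops were added to the explicit graph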
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
new file mode 100644
index 0000000000..a726d880e7
--- /dev/null
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -0,0 +1,114 @@
+"""Updates generated docs from Python doc comments."""
+
+import os.path
+
+import tensorflow.python.platform
+import sys
+import tensorflow as tf
+
+from tensorflow.python.framework import docs
+from tensorflow.python.framework import framework_lib
+from tensorflow.python.client import client_lib
+
+
+tf.flags.DEFINE_string("out_dir", None,
+ "Directory to which docs should be written.")
+tf.flags.DEFINE_boolean("print_hidden_regex", False,
+ "Dump a regular expression matching any hidden symbol")
+FLAGS = tf.flags.FLAGS
+
+
+def get_module_to_name():
+ return {tf: 'tf',
+ tf.errors: 'tf.errors',
+ tf.image: 'tf.image',
+ tf.nn: 'tf.nn',
+ tf.train: 'tf.train',
+ tf.python_io: 'tf.python_io'}
+
+def all_libraries(module_to_name, members, documented):
+ # A list of (filename, docs.Library) pairs representing the individual files
+ # that we want to create.
+ def library(name, title, module=None, **args):
+ if module is None:
+ module = sys.modules["tensorflow.python.ops" +
+ ("" if name == "ops" else "." + name)]
+ return (name + ".md", docs.Library(title=title,
+ module_to_name=module_to_name,
+ members=members,
+ documented=documented,
+ module=module,
+ **args))
+ return [
+ # Splits of module 'tf'.
+ library("framework", "Building Graphs", framework_lib),
+ library("constant_op", "Constants, Sequences, and Random Values"),
+ library("state_ops", "Variables"),
+ library("array_ops", "Tensor Transformations",
+ exclude_symbols=["list_diff"]),
+ library("math_ops", "Math",
+ exclude_symbols=["sparse_matmul", "arg_min", "arg_max",
+ "lin_space", "sparse_segment_mean_grad"]),
+ library("control_flow_ops", "Control Flow"),
+ library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"]),
+ library("sparse_ops", "Sparse Tensors"),
+ library("io_ops", "Inputs and Readers",
+ exclude_symbols=["LookupTableBase", "HashTable",
+ "initialize_all_tables",
+ "string_to_hash_bucket"]),
+ library("python_io", "Data IO (Python functions)", tf.python_io),
+ library("nn", "Neural Network", tf.nn,
+ exclude_symbols=["deconv2d", "conv2d_backprop_input",
+ "conv2d_backprop_filter", "avg_pool_grad",
+ "max_pool_grad", "max_pool_grad_with_argmax",
+ "batch_norm_with_global_normalization_grad",
+ "lrn_grad", "relu6_grad", "softplus_grad",
+ "xw_plus_b", "relu_layer", "lrn",
+ "batch_norm_with_global_normalization",
+ "batch_norm_with_global_normalization_grad",
+ "all_candidate_sampler"]),
+ library('client', "Running Graphs", client_lib,
+ exclude_symbols=["InteractiveSession"]),
+ library("train", "Training", tf.train,
+ exclude_symbols=["Feature", "Features", "BytesList", "FloatList",
+ "Int64List", "Example", "InferenceExample",
+ "RankingExample", "SequenceExample"]),
+ ]
+
+_hidden_symbols = ["Event", "Summary",
+ "HistogramProto", "ConfigProto", "NodeDef", "GraphDef",
+ "GPUOptions", "SessionInterface", "BaseSession"]
+
+def main(unused_argv):
+ if not FLAGS.out_dir:
+ tf.logging.error("out_dir not specified")
+ return -1
+
+ # Document libraries
+ documented = set()
+ module_to_name = get_module_to_name()
+ members = docs.collect_members(module_to_name)
+ libraries = all_libraries(module_to_name, members, documented)
+ docs.write_libraries(FLAGS.out_dir, libraries)
+
+ # Make it easy to search for hidden symbols
+ if FLAGS.print_hidden_regex:
+ hidden = set(_hidden_symbols)
+ for _, lib in libraries:
+ hidden.update(lib.exclude_symbols)
+ print r"hidden symbols regex = r'\b(%s)\b'" % "|".join(sorted(hidden))
+
+ # Verify that all symbols are mentioned in some library doc.
+ catch_all = docs.Library(title="Catch All", module=None,
+ exclude_symbols=_hidden_symbols,
+ module_to_name=module_to_name, members=members,
+ documented=documented)
+ catch_all.assert_no_leftovers()
+
+ # Generate index
+ with open(os.path.join(FLAGS.out_dir, "index.md"), "w") as f:
+ docs.Index(module_to_name, members, libraries).write_markdown_to_file(f)
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/tensorflow/python/framework/gen_docs_test.sh b/tensorflow/python/framework/gen_docs_test.sh
new file mode 100755
index 0000000000..fda214d93c
--- /dev/null
+++ b/tensorflow/python/framework/gen_docs_test.sh
@@ -0,0 +1,4 @@
+#!/bin/bash -eux
+DIR=$TEST_SRCDIR/tensorflow/python
+$DIR/gen_docs_combined --out_dir $TEST_TMPDIR
+echo "PASS"
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
new file mode 100644
index 0000000000..6ad2a1b009
--- /dev/null
+++ b/tensorflow/python/framework/importer.py
@@ -0,0 +1,303 @@
+"""A utility function for importing TensorFlow graphs."""
+import contextlib
+
+import tensorflow.python.platform
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types as types_lib
+
+
+# TODO(josh11b): SWIG the code from node_def_util instead of duplicating
+# the logic here.
+def _GetNodeAttr(node_def, attr_name):
+ if attr_name not in node_def.attr:
+ raise ValueError('Expected one attr with name %r in %s.'
+ % (attr_name, str(node_def)))
+ return node_def.attr[attr_name]
+
+
+def _ArgToTypesNoRef(node_def, arg_def):
+ if arg_def.number_attr:
+ repeats = _GetNodeAttr(node_def, arg_def.number_attr).i
+ if arg_def.type_attr:
+ dtype = _GetNodeAttr(node_def, arg_def.type_attr).type
+ else:
+ assert arg_def.type != types_pb2.DT_INVALID
+ dtype = arg_def.type
+ return [dtype] * repeats
+ elif arg_def.type_attr:
+ return [_GetNodeAttr(node_def, arg_def.type_attr).type]
+ elif arg_def.type_list_attr:
+ return _GetNodeAttr(node_def, arg_def.type_list_attr).list.type
+ else:
+ assert arg_def.type != types_pb2.DT_INVALID
+ return [arg_def.type]
+
+
+def _SingleArgToTypes(node_def, arg_def):
+ types = _ArgToTypesNoRef(node_def, arg_def)
+ if arg_def.is_ref:
+ return [types_lib.as_dtype(dt).as_ref.as_datatype_enum for dt in types]
+ return types
+
+
+def _ArgsToTypes(node_def, arg_list):
+ types = []
+ for arg_def in arg_list:
+ types.extend(_SingleArgToTypes(node_def, arg_def))
+ return types
+
+
+def _InputTypes(node_def, op_dict):
+ op_def = op_dict[node_def.op]
+ return _ArgsToTypes(node_def, op_def.input_arg)
+
+
+def _OutputTypes(node_def, op_dict):
+ op_def = op_dict[node_def.op]
+ return _ArgsToTypes(node_def, op_def.output_arg)
+
+
+def _IsControlInput(input_name):
+ # Expected format: '^operation_name' (control input).
+ return input_name.startswith('^')
+
+
+def _ParseTensorName(tensor_name):
+ """Parses a tensor name into an operation name and output index.
+
+ This function will canonicalize tensor names as follows:
+
+ * "foo:0" -> ("foo", 0)
+ * "foo:7" -> ("foo", 7)
+ * "foo" -> ("foo", 0)
+ * "foo:bar:baz" -> ValueError
+
+ Args:
+ tensor_name: The name of a tensor.
+
+ Returns:
+ A tuple containing the operation name, and the output index.
+
+ Raises:
+ ValueError: If `tensor_name` cannot be interpreted as the name of a tensor.
+ """
+ components = tensor_name.split(':')
+ if len(components) == 2:
+ # Expected format: 'operation_name:output_index'.
+ try:
+ output_index = int(components[1])
+ except ValueError:
+ raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
+ return components[0], output_index
+ elif len(components) == 1:
+ # Expected format: 'operation_name' (implicit 0th output).
+ return components[0], 0
+ else:
+ raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
+
+
+def _CanonicalInputName(input_name):
+ if _IsControlInput(input_name):
+ return input_name
+ input_op_name, output_index = _ParseTensorName(input_name)
+ return '%s:%d' % (input_op_name, output_index)
+
+
+def _InvalidNodeMessage(node, message):
+ return 'graph_def is invalid at node %r: %s.' % (node.name, message)
+
+
+@contextlib.contextmanager
+def _MaybeDevice(device):
+ """Applies the given device only if device is not None or empty."""
+ if device:
+ with ops.device(device):
+ yield
+ else:
+ yield
+
+
+def import_graph_def(graph_def, input_map=None, return_elements=None,
+ name=None, op_dict=None):
+ """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
+
+ This function provides a way to import a serialized TensorFlow
+ [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+ protocol buffer, and extract individual objects in the `GraphDef` as
+ [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
+ [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
+ `GraphDef` proto.
+
+ Args:
+ graph_def: A `GraphDef` proto containing operations to be imported into
+ the default graph.
+ input_map: A dictionary mapping input names (as strings) in `graph_def`
+ to `Tensor` objects. The values of the named input tensors in the
+ imported graph will be re-mapped to the respective `Tensor` values.
+ return_elements: A list of strings containing operation names in
+ `graph_def` that will be returned as `Operation` objects; and/or
+ tensor names in `graph_def` that will be returned as `Tensor` objects.
+ name: (Optional.) A prefix that will be prepended to the names in
+ `graph_def`. Defaults to `"import"`.
+ op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
+ Must contain an `OpDef` proto for each op type named in `graph_def`.
+ If omitted, uses the `OpDef` protos registered in the global registry.
+
+ Returns:
+ A list of `Operation` and/or `Tensor` objects from the imported graph,
+ corresponding to the names in `return_elements`.
+
+ Raises:
+ TypeError: If `graph_def` is not a `GraphDef` proto,
+ `input_map` is not a dictionary mapping strings to `Tensor` objects,
+ or `return_elements` is not a list of strings.
+ ValueError: If `input_map`, or `return_elements` contains names that
+ do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
+ it refers to an unknown tensor).
+ """
+ # Type checks for inputs.
+ if not isinstance(graph_def, graph_pb2.GraphDef):
+ raise TypeError('graph_def must be a GraphDef proto.')
+ if input_map is None:
+ input_map = {}
+ else:
+ if not (isinstance(input_map, dict)
+ and all(isinstance(k, basestring) for k in input_map.keys())):
+ raise TypeError('input_map must be a dictionary mapping strings to '
+ 'Tensor objects.')
+ if (return_elements is not None
+ and not (isinstance(return_elements, (list, tuple))
+ and all(isinstance(x, basestring) for x in return_elements))):
+ raise TypeError('return_elements must be a list of strings.')
+
+ # Use a canonical representation for all tensor names.
+ input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
+ used_input_keys = set()
+
+ name_to_op = {}
+
+ if op_dict is None:
+ op_dict = op_def_registry.get_registered_ops()
+
+ with ops.op_scope(input_map.values(), name, 'import'):
+ g = ops.get_default_graph()
+
+ with ops.name_scope('_inputs'):
+ input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
+
+ # NOTE(mrry): We do this in two passes, because there may be a cycle in
+ # `graph_def`.
+
+ # 1. Add operations without their inputs.
+ for node in graph_def.node:
+ output_types = _OutputTypes(node, op_dict)
+ with _MaybeDevice(node.device):
+ name_to_op[node.name] = g.create_op(
+ node.op, [], output_types, name=node.name, attrs=node.attr,
+ compute_shapes=False)
+
+ # 2. Add inputs to the operations.
+ for node in graph_def.node:
+ op = name_to_op[node.name]
+ input_types = _InputTypes(node, op_dict)
+
+ # NOTE(mrry): We cannot use zip here because control inputs do not appear
+ # in the list of input_types.
+ for i, input_name in enumerate(
+ [_CanonicalInputName(x) for x in node.input]):
+
+ if _IsControlInput(input_name):
+ # (a) Input is a control input that should be taken from an op
+ # in "graph_def".
+ try:
+ source_op = name_to_op[input_name[1:]]
+ except KeyError:
+ raise ValueError(
+ _InvalidNodeMessage(
+ node,
+ 'Control input %r not found in graph_def.' % (input_name,)))
+ # pylint: disable=protected-access
+ op._add_control_input(source_op)
+ # pylint: enable=protected-access
+
+ else:
+ try:
+ input_type = input_types[i]
+ except IndexError:
+ raise ValueError(_InvalidNodeMessage(
+ node, 'More inputs specified (%r) than the op expects.'
+ % (input_name,)))
+
+ if input_name in input_map:
+ # (b) Input should be replaced by a tensor from the caller.
+ source_tensor = input_map[input_name]
+ used_input_keys.add(input_name)
+
+ else:
+ # (c) Input should be taken from an op in `graph_def`.
+ operation_name, output_index = _ParseTensorName(input_name)
+ try:
+ source_op = name_to_op[operation_name]
+ source_tensor = source_op.values()[output_index]
+ except (KeyError, IndexError):
+ raise ValueError(
+ _InvalidNodeMessage(
+ node,
+ 'Input tensor %r not found in graph_def.'
+ % (input_name,)))
+
+ try:
+ # pylint: disable=protected-access
+ op._add_input(source_tensor, dtype=input_type)
+ # pylint: enable=protected-access
+ except TypeError as te:
+ raise ValueError(
+ _InvalidNodeMessage(node, 'Input tensor %r %s'
+ % (input_name, te.message)))
+
+ # pylint: disable=protected-access
+ if op._input_dtypes != input_types:
+ raise ValueError(
+ _InvalidNodeMessage(
+ node,
+ 'Input types mismatch (expected %r but got %r)'
+ % (", ".join(types_lib.as_dtype(x).name for x in input_types),
+ ", ".join(x.name for x in op._input_dtypes))))
+ # pylint: enable=protected-access
+
+ # Execute shape inference for this op.
+ # NOTE(mrry): If the graph contains a cycle, the full shape information
+ # may not be available for this op's inputs.
+ ops.set_shapes_for_outputs(op)
+
+ # Treat unused input mappings as an error, because they are likely to be
+ # due to a typo.
+ unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
+ if unused_input_keys:
+ raise ValueError(
+ 'Attempted to map inputs that were not found in graph_def: [%s]'
+ % ', '.join(unused_input_keys))
+
+ if return_elements is None:
+ return None
+ else:
+ ret = []
+ for name in return_elements:
+ if ':' in name:
+ try:
+ operation_name, output_index = _ParseTensorName(name)
+ ret.append(name_to_op[operation_name].outputs[output_index])
+ except (ValueError, KeyError, IndexError):
+ raise ValueError(
+ 'Requested return_element %r not found in graph_def.' % name)
+ else:
+ try:
+ ret.append(name_to_op[name])
+ except KeyError:
+ raise ValueError(
+ 'Requested return_element %r not found in graph_def.' % name)
+ return ret
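+
+
+# Illustrative usage sketch (not exercised in this module): assuming `graph_def`
+# was produced by `Graph.as_graph_def()` and contains an op named 'input' and a
+# tensor named 'output:0' (hypothetical names), the import below remaps the
+# input to an existing Tensor `x` and retrieves the output tensor:
+#
+#   with ops.Graph().as_default():
+#     output, = import_graph_def(graph_def,
+#                                input_map={'input:0': x},
+#                                return_elements=['output:0'],
+#                                name='import')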
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
new file mode 100644
index 0000000000..470092313a
--- /dev/null
+++ b/tensorflow/python/framework/importer_test.py
@@ -0,0 +1,546 @@
+"""Tests for tensorflow.python.framework.importer."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python.framework import device
+from tensorflow.python.framework import op_def_registry
+
+
+_op_list = op_def_pb2.OpList()
+text_format.Merge("""
+ op {
+ name: 'None'
+ }
+ op {
+ name: 'Oi'
+ output_arg { name: 'a' type: DT_INT32 }
+ }
+ op {
+ name: 'Or'
+ output_arg { name: 'a' type: DT_INT32 is_ref: true }
+ }
+ op {
+ name: 'Of'
+ output_arg { name: 'a' type: DT_FLOAT }
+ }
+ op {
+ name: 'Ii'
+ input_arg { name: 'a' type: DT_INT32 }
+ }
+ op {
+ name: 'If'
+ input_arg { name: 'a' type: DT_FLOAT }
+ }
+ op {
+ name: 'Oii'
+ output_arg { name: 'a' type: DT_INT32 }
+ output_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'Oif'
+ output_arg { name: 'a' type: DT_INT32 }
+ output_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iii'
+ input_arg { name: 'a' type: DT_INT32 }
+ input_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'Iff'
+ input_arg { name: 'a' type: DT_FLOAT }
+ input_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iif'
+ input_arg { name: 'a' type: DT_INT32 }
+ input_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iri'
+ input_arg { name: 'a' type: DT_INT32 is_ref: true }
+ input_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'In'
+ input_arg { name: 'a' number_attr: 'N' type_attr: 'T' }
+ attr { name: 'N' type: 'int' minimum: 1 }
+ attr { name: 'T' type: 'type' }
+ }
+ op {
+ name: 'Otl'
+ output_arg { name: 'a' type_list_attr: 't' }
+ attr { name: 't' type: 'list(type)' minimum: 1 }
+ }
+ op {
+ name: 'Unary'
+ input_arg { name: 'a' type_attr: 'T' }
+ output_arg { name: 'b' type_attr: 'T' }
+ attr { name: 'T' type: 'type' }
+ }
+""", _op_list)
+op_def_registry.register_op_list(_op_list)
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+for op_def in _op_list.op:
+ tf.RegisterShape(op_def.name)(None)
+
+class ImportGraphDefTest(tf.test.TestCase):
+
+ def _MakeGraphDef(self, text):
+ ret = tf.GraphDef()
+ text_format.Merge(text, ret)
+ return ret
+
+ def testBasic(self):
+ with tf.Graph().as_default():
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oif' }
+ node { name: 'B' op: 'Otl'
+ attr { key: 't'
+ value { list { type: DT_INT32 type: DT_FLOAT } } } }
+ node { name: 'C' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_FLOAT } }
+ input: 'A:1' input: 'B:1' }
+ """),
+ return_elements=['A', 'B', 'C', 'D'],
+ name='import')
+
+ # Assert that the import process creates distinct tensors.
+ self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
+ self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
+ self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
+ self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
+ self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
+ self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)
+
+ # Assert that the ops are connected according to the GraphDef topology.
+ self.assertEqual(c.inputs[0], a.outputs[0])
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[1])
+ self.assertEqual(d.inputs[1], b.outputs[1])
+
+ # Check the types of the returned ops and tensors.
+ self.assertEqual(a.type, 'Oif')
+ self.assertEqual(b.type, 'Otl')
+ self.assertEqual(c.type, 'In')
+ self.assertEqual(d.type, 'In')
+ self.assertEqual(a.outputs[0].dtype, tf.int32)
+ self.assertEqual(a.outputs[1].dtype, tf.float32)
+ self.assertEqual(b.outputs[0].dtype, tf.int32)
+ self.assertEqual(b.outputs[1].dtype, tf.float32)
+
+ # Check the names of the returned ops.
+ self.assertEqual(a.name, 'import/A')
+ self.assertEqual(b.name, 'import/B')
+ self.assertEqual(c.name, 'import/C')
+ self.assertEqual(d.name, 'import/D')
+
+ def testInputMap(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(0, dtype=tf.int32)
+ feed_b_1 = tf.constant(1, dtype=tf.int32)
+
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Oii' }
+ node { name: 'C' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:1' input: 'B:1' }
+ """),
+ input_map={'A:0': feed_a_0, 'B:1': feed_b_1},
+ return_elements=['A', 'B', 'C', 'D'])
+
+ self.assertEqual(c.inputs[0], feed_a_0)
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[1])
+ self.assertEqual(d.inputs[1], feed_b_1)
+
+ def testImplicitZerothOutput(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Ii' input: 'A' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(b.inputs[0], a.outputs[0])
+
+ def testInputMapImplicitZerothOutput(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(0, dtype=tf.int32)
+ b, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Ii' input: 'A:0' }
+ """),
+ input_map={'A': feed_a_0},
+ return_elements=['B'])
+
+ self.assertEqual(b.inputs[0], feed_a_0)
+
+ def testWithControlDependency(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ node { name: 'B' op: 'None' input: '^A' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(b.control_inputs, [a])
+
+ def testWithRefs(self):
+ with tf.Graph().as_default():
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Or' }
+ node { name: 'B' op: 'Oi' }
+ node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' }
+ """),
+ return_elements=['A', 'B', 'C', 'D'])
+
+ self.assertEqual(c.inputs[0], a.outputs[0])
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[0])
+ self.assertEqual(d.inputs[1], b.outputs[0])
+
+ self.assertEqual(a.outputs[0].dtype, tf.int32_ref)
+ self.assertEqual(c._input_dtypes, [tf.int32, tf.int32])
+ self.assertEqual(c.outputs, [])
+ self.assertEqual(d._input_dtypes,
+ [tf.int32_ref, tf.int32])
+ self.assertEqual(d.outputs, [])
+
+ def testCyclic(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Unary'
+ attr { key: 'T' value { type: DT_INT32 } } input: 'B:0' }
+ node { name: 'B' op: 'Unary'
+ attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(a.inputs[0], b.outputs[0])
+ self.assertEqual(b.inputs[0], a.outputs[0])
+
+ def testTypeMismatchInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """))
+ self.assertTrue(
+ 'Cannot convert a tensor of type int32 to an input of type float' in
+ str(e.exception))
+
+ def testInvalidSignatureTooManyInputsInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'None' input: 'A:0' }
+ """))
+ self.assertTrue('More inputs specified (u\'A:0\') than the op expects' in
+ str(e.exception))
+
+ def testInvalidSignatureNotEnoughInputsInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'Iif' input: 'A:0' }
+ """))
+ self.assertTrue('Input types mismatch (expected \'int32, float32\' but '
+ 'got \'int32\')' in str(e.exception))
+
+ def testMissingInputOpInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """))
+ self.assertTrue('Input tensor %r not found' % (u'A:0',) in
+ str(e.exception))
+
+ def testMissingInputOpInGraphDefButAppearsInInputMap(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(5.0)
+ b, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """),
+ input_map={'A:0': feed_a_0},
+ return_elements=['B'])
+ self.assertEqual(b.inputs[0], feed_a_0)
+
+ def testMissingInputTensorInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Of' }
+ node { name: 'B' op: 'If' input: 'A:1' }
+ """))
+ self.assertTrue('Input tensor %r not found' % (u'A:1',) in
+ str(e.exception))
+
+ def testMissingControlInputInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: '^A' }
+ """))
+ self.assertTrue('Control input %r not found' % (u'^A',) in
+ str(e.exception))
+
+ def testInvalidTensorNameOutputIndexInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: 'A:B' }
+ """))
+ self.assertEqual(
+ 'Cannot convert %r to a tensor name.' % (u'A:B',), str(e.exception))
+
+ def testInvalidTensorNameInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: 'A:B:0' }
+ """))
+ self.assertEqual(
+ 'Cannot convert %r to a tensor name.' % (u'A:B:0',), str(e.exception))
+
+ def testMissingReturnOperation(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ return_elements=['B'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('B') in
+ str(e.exception))
+
+ def testMissingReturnTensor(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['A:1'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('A:1') in
+ str(e.exception))
+
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['B:0'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('B:0') in
+ str(e.exception))
+
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['A:B:0'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('A:B:0') in
+ str(e.exception))
+
+ def testMissingInputMap(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ input_map={'B:0': tf.constant(5.0)})
+ self.assertTrue('not found in graph_def: [B:0]' in str(e.exception))
+
+ def testInputMapTypeMismatch(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'Ii' input: 'A:0' }
+ """),
+ input_map={'A:0': tf.constant(5.0)})
+ self.assertTrue(
+ 'Cannot convert a tensor of type float32 to an input of type int32.'
+ in str(e.exception))
+
+ def testNoReturns(self):
+ with tf.Graph().as_default() as g:
+ ret = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """))
+ self.assertEqual(ret, None)
+
+ a = g.get_operation_by_name('import/A')
+ self.assertEqual(a.type, 'None')
+
+ def testOverrideNamePrefix(self):
+ with tf.Graph().as_default():
+ a, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ return_elements=['A'], name='imported_graph')
+ self.assertEqual(a.name, 'imported_graph/A')
+
+ def testEmptyGraph(self):
+ with tf.Graph().as_default() as g:
+ init_version = g.version
+ tf.import_graph_def(self._MakeGraphDef(''))
+ self.assertEqual(init_version, g.version)
+
+ def testInvalidInputForGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def('')
+ self.assertEqual(
+ 'graph_def must be a GraphDef proto.', str(e.exception))
+
+ def testInvalidInputForInputMap(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def(self._MakeGraphDef(''),
+ input_map=[tf.constant(5.0)])
+ self.assertEqual('input_map must be a dictionary mapping strings to '
+ 'Tensor objects.', str(e.exception))
+
+ def testInvalidInputForReturnOperations(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7])
+ self.assertEqual(
+ 'return_elements must be a list of strings.', str(e.exception))
+
+ def testWithExtensionAndAttr(self):
+ with tf.Graph().as_default() as g:
+ c = tf.constant(5.0, dtype=tf.float32, name='c')
+ tf.pack([c, c], name='pack')
+ gdef = g.as_graph_def()
+
+ with self.test_session():
+ pack, = tf.import_graph_def(gdef, return_elements=['pack'])
+ self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
+
+ def testWithDevice(self):
+ with tf.Graph().as_default() as g:
+ # No device.
+ a = tf.constant(3.0, name='a')
+
+ with tf.device('/cpu:0'):
+ b = tf.constant(4.0, name='b')
+ with tf.device('/job:worker'):
+ c = tf.constant(5.0, name='c')
+
+ gdef = g.as_graph_def()
+
+ with tf.Graph().as_default():
+ a2, b2, c2 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual(a.device, a2.device)
+ self.assertEqual(b.device, b2.device)
+ self.assertEqual(c.device, c2.device)
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/task:0')):
+ a3, b3, c3 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/task:0', a3.device)
+ self.assertEqual('/task:0/device:CPU:0', b3.device) # canonicalized.
+ self.assertEqual(c.device + '/task:0', c3.device)
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/job:ps')):
+ a4, b4, c4 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/job:ps', a4.device)
+ self.assertEqual('/job:ps/device:CPU:0', b4.device) # canonicalized.
+ self.assertEqual(c.device, c4.device) # worker overrides ps.
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/gpu:0')):
+ a5, b5, c5 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/device:GPU:0', a5.device)
+ self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu.
+ self.assertEqual(c.device + '/device:GPU:0', c5.device)
+
+ def testGradient(self):
+ with tf.Graph().as_default() as g:
+ inputs = tf.placeholder(tf.float32, shape=[None, 100], name="input")
+ weights = tf.placeholder(tf.float32, shape=[100, 10], name="weights")
+ biases = tf.placeholder(tf.float32, shape=[10], name="biases")
+ activations = tf.nn.relu(tf.matmul(inputs, weights) + biases,
+ name="activations")
+ loss = tf.reduce_mean(activations, name="loss")
+ gdef = g.as_graph_def()
+
+ with tf.Graph().as_default() as g:
+ input_placeholder = tf.placeholder(tf.float32, shape=[32, 100])
+ weights_var = tf.Variable(tf.truncated_normal([100, 10]), name="weights")
+ biases_var = tf.Variable(tf.zeros(10), name="biases")
+ activations, loss = tf.import_graph_def(
+ gdef,
+ input_map={"input:0": input_placeholder,
+ "weights:0": weights_var,
+ "biases:0": biases_var},
+ return_elements=["activations:0", "loss:0"])
+ self.assertEqual([32, 10], activations.get_shape())
+ self.assertEqual([], loss.get_shape())
+ weights_grad, biases_grad = tf.gradients(loss, [weights_var, biases_var])
+ self.assertEqual([100, 10], weights_grad.get_shape())
+ self.assertEqual([10], biases_grad.get_shape())
+
+ def testLargeGraph(self):
+ with self.test_session():
+ # The default message byte limit is 64M. Ours is 2G, with a warning at 512M.
+ # Adding a float32 tensor with roughly 150M entries should blow through the
+ # warning, but not the hard limit.
+ input_shape = [150, 1024, 1024]
+ tensor_input = tf.np.random.rand(*input_shape).astype(tf.np.float32)
+ t = tf.constant(tensor_input, shape=input_shape)
+ g = tf.identity(t)
+ g.eval()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/framework/op_def_registry.py b/tensorflow/python/framework/op_def_registry.py
new file mode 100644
index 0000000000..2ec8c94a10
--- /dev/null
+++ b/tensorflow/python/framework/op_def_registry.py
@@ -0,0 +1,23 @@
+"""Global registry for OpDefs."""
+
+from tensorflow.core.framework import op_def_pb2
+
+
+_registered_ops = {}
+
+
+def register_op_list(op_list):
+ """Register all the ops in an op_def_pb2.OpList."""
+ if not isinstance(op_list, op_def_pb2.OpList):
+ raise TypeError("%s is %s, not an op_def_pb2.OpList" %
+ (op_list, type(op_list)))
+ for op_def in op_list.op:
+ if op_def.name in _registered_ops:
+ assert _registered_ops[op_def.name] == op_def
+ else:
+ _registered_ops[op_def.name] = op_def
+
+
+def get_registered_ops():
+ """Returns a dictionary mapping names to OpDefs."""
+ return _registered_ops
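+
+
+# Illustrative sketch of how this registry is typically populated (the op name
+# 'MyOp' is a placeholder; see importer_test.py for a real example):
+#
+#   from google.protobuf import text_format
+#   op_list = op_def_pb2.OpList()
+#   text_format.Merge("op { name: 'MyOp' }", op_list)
+#   register_op_list(op_list)
+#   assert 'MyOp' in get_registered_ops()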
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
new file mode 100644
index 0000000000..0b0442cea1
--- /dev/null
+++ b/tensorflow/python/framework/ops.py
@@ -0,0 +1,2985 @@
+"""Classes and functions used to construct graphs."""
+# pylint: disable=g-bad-name
+import collections
+import contextlib
+import copy
+import linecache
+import re
+import sys
+import threading
+import weakref
+
+import tensorflow.python.platform
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import registry
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+
+
+def _convert_stack(stack):
+ """Converts a stack extracted using _extract_stack() to a traceback stack.
+
+ Args:
+ stack: A list of n 4-tuples, (filename, lineno, name, frame_globals).
+
+ Returns:
+ A list of n 4-tuples (filename, lineno, name, code), where the code tuple
+ element is calculated from the corresponding elements of the input tuple.
+ """
+ ret = []
+ for filename, lineno, name, frame_globals in stack:
+ linecache.checkcache(filename)
+ line = linecache.getline(filename, lineno, frame_globals)
+ if line:
+ line = line.strip()
+ else:
+ line = None
+ ret.append((filename, lineno, name, line))
+ return ret
+
+
+# pylint: disable=line-too-long
+def _extract_stack():
+ """A lightweight re-implementation of traceback.extract_stack.
+
+ NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
+ each stack frame using linecache, which results in an abundance of stat()
+ calls. This implementation does not retrieve the code, and any consumer
+ should apply _convert_stack to the result to obtain a traceback that can
+ be formatted etc. using traceback methods.
+
+ Returns:
+ A list of 4-tuples (filename, lineno, name, frame_globals) corresponding to
+ the call stack of the current thread.
+ """
+ # pylint: enable=line-too-long
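+ # Raise and immediately catch a dummy exception so that sys.exc_info()
+ # hands us a traceback anchored at this frame; f_back then lets us walk
+ # the caller's stack without touching linecache.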
+ try:
+ raise ZeroDivisionError
+ except ZeroDivisionError:
+ f = sys.exc_info()[2].tb_frame.f_back
+ ret = []
+ while f is not None:
+ lineno = f.f_lineno
+ co = f.f_code
+ filename = co.co_filename
+ name = co.co_name
+ frame_globals = f.f_globals
+ ret.append((filename, lineno, name, frame_globals))
+ f = f.f_back
+ ret.reverse()
+ return ret
+
+
+class Tensor(object):
+ """Represents a value produced by an `Operation`.
+
+ A `Tensor` is a symbolic handle to one of the outputs of an
+ `Operation`. It does not hold the values of that operation's output,
+ but instead provides a means of computing those values in a
+ TensorFlow [`Session`](client.md#Session).
+
+ This class has two primary purposes:
+
+ 1. A `Tensor` can be passed as an input to another `Operation`.
+ This builds a dataflow connection between operations, which
+ enables TensorFlow to execute an entire `Graph` that represents a
+ large, multi-step computation.
+
+ 2. After the graph has been launched in a session, the value of the
+ `Tensor` can be computed by passing it to
+ [`Session.run()`](client.md#Session.run).
+ `t.eval()` is a shortcut for calling
+ `tf.get_default_session().run(t)`.
+
+ In the following example, `c`, `d`, and `e` are symbolic `Tensor`
+ objects, whereas `result` is a numpy array that stores a concrete
+ value:
+
+ ```python
+ # Build a dataflow graph.
+ c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
+ d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
+ e = tf.matmul(c, d)
+
+ # Construct a `Session` to execute the graph.
+ sess = tf.Session()
+
+ # Execute the graph and store the value that `e` represents in `result`.
+ result = sess.run(e)
+ ```
+
+ @@dtype
+ @@name
+ @@value_index
+ @@graph
+ @@op
+ @@consumers
+
+ @@eval
+
+ @@get_shape
+ @@set_shape
+
+ """
+
+ # List of Python operators that we allow to override.
+ OVERLOADABLE_OPERATORS = {
+ # Binary.
+ "__add__", "__radd__",
+ "__sub__", "__rsub__",
+ "__mul__", "__rmul__",
+ "__div__", "__rdiv__",
+ "__truediv__", "__rtruediv__",
+ "__mod__", "__rmod__",
+ "__lt__", "__le__",
+ "__gt__", "__ge__",
+ "__and__", "__rand__",
+ "__or__", "__ror__",
+ "__xor__", "__rxor__",
+ "__getitem__",
+ # Unary.
+ "__invert__",
+ "__neg__", "__abs__"}
+
+ def __init__(self, op, value_index, dtype):
+ """Creates a new `Tensor`.
+
+ Args:
+ op: An `Operation`. `Operation` that computes this tensor.
+ value_index: An `int`. Index of the operation's endpoint that produces
+ this tensor.
+ dtype: A `types.DType`. Type of data stored in this tensor.
+
+ Raises:
+ TypeError: If the op is not an `Operation`.
+ """
+ if not isinstance(op, Operation):
+ raise TypeError("op needs to be an Operation: %s" % op)
+ self._op = op
+ self._value_index = value_index
+ self._dtype = types.as_dtype(dtype)
+ self._shape = tensor_shape.unknown_shape()
+ # List of operations that use this Tensor as input. We maintain this list
+ # to easily navigate a computation graph.
+ self._consumers = []
+
+ @property
+ def op(self):
+ """The `Operation` that produces this tensor as an output."""
+ return self._op
+
+ @property
+ def dtype(self):
+ """The `DType` of elements in this tensor."""
+ return self._dtype
+
+ @property
+ def graph(self):
+ """The `Graph` that contains this tensor."""
+ return self._op.graph
+
+ @property
+ def name(self):
+ """The string name of this tensor."""
+ if not self._op.name:
+ raise ValueError("Operation was not named: %s" % self._op)
+ return "%s:%d" % (self._op.name, self._value_index)
+
+ @property
+ def device(self):
+ """The name of the device on which this tensor will be produced, or None."""
+ return self._op.device
+
+ def _shape_as_list(self):
+ if self._shape.ndims is not None:
+ return [dim.value for dim in self._shape.dims]
+ else:
+ return None
+
+ def get_shape(self):
+ """Returns the `TensorShape` that represents the shape of this tensor.
+
+ The shape is computed using shape inference functions that are
+ registered for each `Operation` type using `tf.RegisterShape`.
+ See [`TensorShape`](framework.md#TensorShape) for more details of what a shape
+ represents.
+
+ The inferred shape of a tensor is used to provide shape
+ information without having to launch the graph in a session. This
+ can be used for debugging, and providing early error messages. For
+ example:
+
+ ```python
+ c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+
+ print c.get_shape()
+ ==> TensorShape([Dimension(2), Dimension(3)])
+
+ d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
+
+ print d.get_shape()
+ ==> TensorShape([Dimension(4), Dimension(2)])
+
+ # Raises a ValueError, because `c` and `d` do not have compatible
+ # inner dimensions.
+ e = tf.matmul(c, d)
+
+ f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
+
+ print f.get_shape()
+ ==> TensorShape([Dimension(3), Dimension(4)])
+ ```
+
+ In some cases, the inferred shape may have unknown dimensions. If
+ the caller has additional information about the values of these
+ dimensions, `Tensor.set_shape()` can be used to augment the
+ inferred shape.
+
+ Returns:
+ A `TensorShape` representing the shape of this tensor.
+ """
+ return self._shape
+
+ def set_shape(self, shape):
+ """Updates the shape of this tensor.
+
+ This method can be called multiple times, and will merge the given
+ `shape` with the current shape of this tensor. It can be used to
+ provide additional information about the shape of this tensor that
+ cannot be inferred from the graph alone. For example, this can be used
+ to provide additional information about the shapes of images:
+
+ ```python
+ _, image_data = tf.TFRecordReader(...).read(...)
+ image = tf.image.decode_png(image_data, channels=3)
+
+ # The height and width dimensions of `image` are data dependent, and
+ # cannot be computed without executing the op.
+ print image.get_shape()
+ ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
+
+ # We know that each image in this dataset is 28 x 28 pixels.
+ image.set_shape([28, 28, 3])
+ print image.get_shape()
+ ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
+ ```
+
+ Args:
+ shape: A `TensorShape` representing the shape of this tensor.
+
+ Raises:
+ ValueError: If `shape` is not compatible with the current shape of
+ this tensor.
+ """
+ self._shape = self._shape.merge_with(shape)
+
+ @property
+ def value_index(self):
+ """The index of this tensor in the outputs of its `Operation`."""
+ return self._value_index
+
+ def consumers(self):
+ """Returns a list of `Operation`s that consume this tensor.
+
+ Returns:
+ A list of `Operation`s.
+ """
+ return self._consumers
+
+ def _add_consumer(self, consumer):
+ """Add a consumer to this tensor.
+
+ Args:
+ consumer: an Operation.
+
+ Raises:
+ TypeError: if the consumer is not an Operation.
+ """
+ if not isinstance(consumer, Operation):
+ raise TypeError("Consumer must be an Operation: %s" % consumer)
+ self._consumers.append(consumer)
+
+ def _as_node_def_input(self):
+ """Return a value to use for the NodeDef "input" attribute.
+
+ The returned string can be used in a NodeDef "input" attribute
+ to indicate that the NodeDef uses this Tensor as input.
+
+ Raises:
+ ValueError: if this Tensor's Operation does not have a name.
+
+ Returns:
+ a string.
+ """
+ if not self._op.name:
+ raise ValueError("Operation was not named: %s" % self._op)
+ if self._value_index == 0:
+ return self._op.name
+ else:
+ return "%s:%d" % (self._op.name, self._value_index)
+
+ def __str__(self):
+ return "Tensor(\"%s\"%s%s%s)" % (
+ self.name,
+ (", shape=%s" % self.get_shape())
+ if self.get_shape().ndims is not None else "",
+ (", dtype=%s" % self._dtype.name) if self._dtype else "",
+ (", device=%s" % self.device) if self.device else "")
+
+ def __hash__(self):
+ # Necessary to support Python's collection membership operators
+ return id(self)
+
+ def __eq__(self, other):
+ # Necessary to support Python's collection membership operators
+ return id(self) == id(other)
+
+ # NOTE(mrry): This enables the Tensor's overloaded "right" binary
+ # operators to run when the left operand is an ndarray, because it
+ # accords the Tensor class higher priority than an ndarray, or a
+ # numpy matrix.
+ # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
+ # mechanism, which allows more control over how Tensors interact
+ # with ndarrays.
+ __array_priority__ = 100
+
+ @staticmethod
+ def _override_operator(operator, func):
+ """Overrides (string) operator on Tensors to call func.
+
+ Args:
+ operator: the string name of the operator to override.
+ func: the function that replaces the overridden operator.
+
+ Raises:
+ ValueError: If operator has already been overwritten,
+ or if operator is not allowed to be overwritten.
+ """
+ if getattr(Tensor, operator, None) is not None:
+ # check to see if this is a default method-wrapper which will be true
+ # for the comparison operators.
+ if not isinstance(getattr(Tensor, operator, None), type(all.__call__)):
+ raise ValueError("operator %s cannot be overwritten again." % operator)
+ if operator not in Tensor.OVERLOADABLE_OPERATORS:
+ raise ValueError("Overriding %s is disallowed" % operator)
+ setattr(Tensor, operator, func)
+
+ def __iter__(self):
+ """Dummy method to prevent iteration. Do not call.
+
+ NOTE(mrry): If we register __getitem__ as an overloaded operator,
+ Python will valiantly attempt to iterate over the Tensor from 0 to
+ infinity. Declaring this method prevents this unintended
+ behavior.
+
+ Raises:
+ TypeError: when invoked.
+ """
+ raise TypeError("'Tensor' object is not iterable")
+
+ def eval(self, feed_dict=None, session=None):
+ """Evaluates this tensor in a `Session`.
+
+ Calling this method will execute all preceding operations that
+ produce the inputs needed for the operation that produces this
+ tensor.
+
+ *N.B.* Before invoking `Tensor.eval()`, its graph must have been
+ launched in a session, and either a default session must be
+ available, or `session` must be specified explicitly.
+
+ Args:
+ feed_dict: A dictionary that maps `Tensor` objects to feed values.
+ See [`Session.run()`](client.md#Session.run) for a description of
+ the valid feed values.
+ session: (Optional.) The `Session` to be used to evaluate this tensor. If
+ none, the default session will be used.
+
+ Returns:
+ A numpy array corresponding to the value of this tensor.
+
+ """
+ return _eval_using_default_session(self, feed_dict, self.graph, session)
+
+
+def _TensorTensorConversionFunction(t, dtype=None, name=None):
+ _ = name
+ if dtype and not dtype.is_compatible_with(t.dtype):
+ raise ValueError(
+ "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
+ % (dtype.name, t.dtype.name, str(t)))
+ return t
+
+
+_tensor_conversion_func_registry = {
+ 0: [(Tensor, _TensorTensorConversionFunction)]}
+
+
+def convert_to_tensor(value, dtype=None, name=None):
+ """Converts the given `value` to a `Tensor`.
+
+ This function converts Python objects of various types to `Tensor`
+ objects. It accepts `Tensor` objects, numpy arrays, Python lists,
+ and Python scalars. For example:
+
+ ```python
+ import numpy as np
+ array = np.random.rand(32, 100, 100)
+
+ def my_func(arg):
+ arg = tf.convert_to_tensor(arg, dtype=tf.float32)
+ return tf.matmul(arg, arg) + arg
+
+ # The following calls are equivalent.
+ value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
+ value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
+ value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
+ ```
+
+ This function can be useful when composing a new operation in Python
+ (such as `my_func` in the example above). All standard Python op
+ constructors apply this function to each of their Tensor-valued
+ inputs, which allows those ops to accept numpy arrays, Python lists,
+ and scalars in addition to `Tensor` objects.
+
+ Args:
+ value: An object whose type has a registered `Tensor` conversion function.
+ dtype: Optional element type for the returned tensor. If missing, the
+ type is inferred from the type of `value`.
+ name: Optional name to use if a new `Tensor` is created.
+
+ Returns:
+ A `Tensor` based on `value`.
+
+ Raises:
+ TypeError: If no conversion function is registered for `value`.
+ RuntimeError: If a registered conversion function returns an invalid value.
+
+ """
+ error_prefix = "" if name is None else "%s: " % name
+ if dtype is not None:
+ dtype = types.as_dtype(dtype)
+ for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()):
+ for base_type, conversion_func in funcs_at_priority:
+ if isinstance(value, base_type):
+ ret = conversion_func(value, dtype=dtype, name=name)
+ if not isinstance(ret, Tensor):
+ raise RuntimeError(
+ "%sConversion function %r for type %s returned non-Tensor: %r"
+ % (error_prefix, conversion_func, base_type, ret))
+ if dtype and not dtype.is_compatible_with(ret.dtype):
+ raise RuntimeError(
+ "%sConversion function %r for type %s returned incompatible "
+ "dtype: requested = %s, actual = %s"
+ % (error_prefix, conversion_func, base_type,
+ dtype.name, ret.dtype.name))
+ return ret
+ raise TypeError("%sCannot convert %r with type %s to Tensor: "
+ "no conversion function registered."
+ % (error_prefix, value, type(value)))
+
+
+def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
+ """Converts the given object to a `Tensor` or an `IndexedSlices`.
+
+ If `value` is an `IndexedSlices` it is returned
+ unmodified. Otherwise, it is converted to a `Tensor` using
+ `convert_to_tensor()`.
+
+ Args:
+ value: An `IndexedSlices` or an object that can be consumed by
+ `convert_to_tensor()`.
+ dtype: (Optional.) The required `DType` of the returned `Tensor` or
+ `IndexedSlices`.
+ name: (Optional.) A name to use if a new `Tensor` is created.
+
+ Returns:
+ A `Tensor` or an `IndexedSlices` based on `value`.
+
+ Raises:
+ ValueError: If `dtype` does not match the element type of `value`.
+ """
+ if isinstance(value, IndexedSlices):
+ if dtype and not types.as_dtype(dtype).is_compatible_with(value.dtype):
+ raise ValueError(
+ "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
+ % (types.as_dtype(dtype).name, value.dtype.name, str(value)))
+ return value
+ else:
+ return convert_to_tensor(value, dtype, name)
+
+
+def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
+ """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
+
+ Args:
+ values: A list of `None`, `IndexedSlices`, or objects that can be consumed
+ by `convert_to_tensor()`.
+ dtype: (Optional.) The required `DType` of the returned `Tensor` or
+ `IndexedSlices`.
+ name: (Optional.) A name prefix to use when a new `Tensor` is created,
+ in which case element `i` will be given the name `name + '_' + i`.
+
+ Returns:
+ A list of `Tensor` and/or `IndexedSlices` objects.
+
+ Raises:
+ TypeError: If no conversion function is registered for an element in
+ `values`.
+ RuntimeError: If a registered conversion function returns an invalid
+ value.
+ """
+ if not isinstance(values, collections.Sequence):
+ raise TypeError("values must be a list.")
+ ret = []
+ for i, value in enumerate(values):
+ if value is None:
+ ret.append(value)
+ else:
+ n = None if name is None else "%s_%d" % (name, i)
+ ret.append(
+ convert_to_tensor_or_indexed_slices(value, dtype=dtype, name=n))
+ return ret
+
+
+def register_tensor_conversion_function(base_type, conversion_func,
+ priority=100):
+ """Registers a function for converting objects of base_type to Tensor.
+
+ The conversion function must have the following signature:
+
+ def conversion_func(value, dtype=None, name=None):
+ # ...
+
+ It must return a Tensor with the given dtype if specified. If the
+ conversion function creates a new Tensor, it should use the given
+ name if specified. All exceptions will be propagated to the caller.
+
+ NOTE: The conversion functions will execute in order of priority,
+ followed by order of registration. To ensure that a conversion
+ function F runs before another conversion function G, ensure that
+ F is registered with a smaller priority than G.
+
+ Args:
+ base_type: The base type or tuple of base types for all objects that
+ `conversion_func` accepts.
+ conversion_func: A function that converts instances of base_type to Tensor.
+ priority: Optional integer that indicates the priority for applying this
+ conversion function. Conversion functions with smaller priority values
+ run earlier than conversion functions with larger priority values.
+ Defaults to 100.
+
+ Raises:
+ TypeError: If the arguments do not have the appropriate type.
+
+ """
+ if not (isinstance(base_type, type) or
+ (isinstance(base_type, tuple)
+ and all(isinstance(x, type) for x in base_type))):
+ raise TypeError("base_type must be a type or a tuple of types.")
+ if not callable(conversion_func):
+ raise TypeError("conversion_func must be callable.")
+
+ try:
+ funcs_at_priority = _tensor_conversion_func_registry[priority]
+ except KeyError:
+ funcs_at_priority = []
+ _tensor_conversion_func_registry[priority] = funcs_at_priority
+ funcs_at_priority.append((base_type, conversion_func))
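+
+
+# Illustrative sketch of registering a conversion function for a hypothetical
+# wrapper type (MyWrapper and its `tensor` attribute are placeholders, not part
+# of this module):
+#
+#   class MyWrapper(object):
+#     def __init__(self, tensor):
+#       self.tensor = tensor
+#
+#   def _my_wrapper_to_tensor(value, dtype=None, name=None):
+#     return convert_to_tensor(value.tensor, dtype=dtype, name=name)
+#
+#   register_tensor_conversion_function(MyWrapper, _my_wrapper_to_tensor)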
+
+
+class IndexedSlices(object):
+ """A sparse representation of a set of tensor slices at given indices.
+
+ This class is a simple wrapper for a pair of `Tensor` objects:
+
+ * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
+ * `indices`: A 1-D integer `Tensor` with shape `[D0]`.
+
+ An `IndexedSlices` is typically used to represent a subset of a larger
+ tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
+ The values in `indices` are the indices in the first dimension of
+ the slices that have been extracted from the larger tensor.
+
+ The dense tensor `dense` represented by an `IndexedSlices` `slices` has
+
+ ```python
+ dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
+ ```
+
+ The `IndexedSlices` class is used principally in the definition of
+ gradients for operations that have sparse gradients
+ (e.g. [`tf.gather`](array_ops.md#gather)).
+
+ Contrast this representation with
+ [`SparseTensor`](sparse_ops.md#SparseTensor),
+ which uses multi-dimensional indices and scalar values.
+
+ @@__init__
+
+ @@values
+ @@indices
+ @@dense_shape
+
+ @@name
+ @@dtype
+ @@device
+ @@op
+ """
+
+ def __init__(self, values, indices, dense_shape=None):
+ """Creates an `IndexedSlices`."""
+ self._values = values
+ self._indices = indices
+ self._dense_shape = dense_shape
+
+ @property
+ def values(self):
+ """A `Tensor` containing the values of the slices."""
+ return self._values
+
+ @property
+ def indices(self):
+ """A 1-D `Tensor` containing the indices of the slices."""
+ return self._indices
+
+ @property
+ def dense_shape(self):
+ """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
+ return self._dense_shape
+
+ @property
+ def name(self):
+ """The name of this `IndexedSlices`."""
+ return self.values.name
+
+ @property
+ def device(self):
+ """The name of the device on which `values` will be produced, or `None`."""
+ return self.values.device
+
+ @property
+ def op(self):
+ """The `Operation` that produces `values` as an output."""
+ return self.values.op
+
+ @property
+ def dtype(self):
+ """The `DType` of elements in this tensor."""
+ return self.values.dtype
+
+ def __str__(self):
+ return "IndexedSlices(indices=%s, values=%s)" % (
+ self._indices, self._values)
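+
+
+# Illustrative sketch: given a values tensor `v` of shape [2, 3], an indices
+# tensor `i` equal to [0, 2], and a dense_shape tensor `s` equal to [5, 3],
+# `IndexedSlices(v, i, s)` stands for a [5, 3] dense tensor whose rows 0 and 2
+# are v[0] and v[1] and whose remaining rows are zero.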
+
+
+def assert_same_graph(items, expected_graph=None):
+ """Asserts all items are from the same graph.
+
+ Args:
+ items: List of graph items (e.g., Variable, Tensor, SparseTensor,
+ Operation, or IndexedSlices).
+ expected_graph: Expected graph. If not specified, assert all tensors are
+ from the same graph.
+ Returns:
+ items, for chaining.
+ Raises:
+ ValueError: If any graphs do not match.
+ """
+ for item in items:
+ if not expected_graph:
+ expected_graph = item.graph
+ elif expected_graph != item.graph:
+ raise ValueError("Items must be from the same graph.")
+ return items
+
+
+class SparseTensor(object):
+ """Represents a sparse tensor.
+
+ TensorFlow represents a sparse tensor as three separate dense tensors:
+ `indices`, `values`, and `dense_shape`. In Python, the three tensors are
+ collected into a `SparseTensor` class for ease of use. If you have separate
+ `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
+ object before passing to the Ops below.
+
+ Concretely, the sparse tensor `SparseTensor(indices, values, dense_shape)` is
+
+ * `indices`: A 2-D int64 tensor of shape `[N, ndims]`.
+ * `values`: A 1-D tensor of any type and shape `[N]`.
+ * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`.
+
+ where `N` and `ndims` are the number of values, and number of dimensions in
+ the `SparseTensor` respectively.
+
+ The corresponding dense tensor satisfies
+
+ ```python
+ dense.shape = dense_shape
+ dense[tuple(indices[i])] = values[i]
+ ```
+
+ By convention, `indices` should be sorted in row-major order (or equivalently
+ lexicographic order on the tuples `indices[i]`). This is not enforced when
+ `SparseTensor` objects are constructed, but most Ops assume correct ordering.
+ If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the
+ misordered `SparseTensor`.
+
+ Example: The sparse tensor
+
+ ```python
+ SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4])
+ ```
+
+ represents the dense tensor
+
+ ```python
+ [[1, 0, 0, 0]
+ [0, 0, 2, 0]
+ [0, 0, 0, 0]]
+ ```
+
+ @@__init__
+ @@indices
+ @@values
+ @@dtype
+ @@shape
+ @@graph
+ """
+
+ def __init__(self, indices, values, shape):
+ """Creates a `SparseTensor`.
+
+ Args:
+ indices: A 2-D int64 tensor of shape `[N, ndims]`.
+ values: A 1-D tensor of any type and shape `[N]`.
+ shape: A 1-D int64 tensor of shape `[ndims]`.
+
+ Returns:
+ A `SparseTensor`
+ """
+ with op_scope([indices, values, shape], None, "SparseTensor"):
+ indices = convert_to_tensor(indices, name="indices")
+ values = convert_to_tensor(values, name="values")
+ shape = convert_to_tensor(shape, name="shape")
+ self._indices = indices
+ self._values = values
+ self._shape = shape
+
+ indices_shape = indices.get_shape().with_rank(2)
+ values_shape = values.get_shape().with_rank(1)
+ shape_shape = shape.get_shape().with_rank(1)
+
+ # Assert number of rows in indices match the number of elements in values.
+ indices_shape[0].merge_with(values_shape[0])
+ # Assert number of columns in indices matches the number of elements in
+ # shape.
+ indices_shape[1].merge_with(shape_shape[0])
+
+ @property
+ def indices(self):
+ """The indices of non-zero values in the represented dense tensor.
+
+ Returns:
+ A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
+ number of non-zero values in the tensor, and `ndims` is the rank.
+ """
+ return self._indices
+
+ @property
+ def values(self):
+ """The non-zero values in the represented dense tensor.
+
+ Returns:
+ A 1-D Tensor of any data type.
+ """
+ return self._values
+
+ @property
+ def dtype(self):
+ """The `DType` of elements in this tensor."""
+ return self._values.dtype
+
+ @property
+ def shape(self):
+ """A 1-D Tensor of int64 representing the shape of the dense tensor."""
+ return self._shape
+
+ @property
+ def graph(self):
+ """The `Graph` that contains the index, value, and shape tensors."""
+ return self._indices.graph
+
+ def __str__(self):
+ return "SparseTensor(indices=%s, values=%s, shape=%s)" % (
+ self._indices, self._values, self._shape)
+
+
+SparseTensorValue = collections.namedtuple("SparseTensorValue",
+ ["indices", "values", "shape"])
+
+
+def _device_string(dev_spec):
+ if isinstance(dev_spec, pydev.Device):
+ return dev_spec.to_string()
+ else:
+ return dev_spec
+
+
+def _NodeDef(op_type, name, device=None, attrs=None):
+ """Create a NodeDef proto.
+
+ Args:
+ op_type: Value for the "op" attribute of the NodeDef proto.
+ name: Value for the "name" attribute of the NodeDef proto.
+ device: string, device, or function from NodeDef to string.
+ Value for the "device" attribute of the NodeDef proto.
+ attrs: optional dict of attr name to AttrValue proto, used to fill the "attr" attribute of the NodeDef proto.
+
+ Returns:
+ A graph_pb2.NodeDef protocol buffer.
+ """
+ node_def = graph_pb2.NodeDef()
+ node_def.op = str(op_type)
+ node_def.name = str(name)
+ if attrs is not None:
+ for k, v in attrs.iteritems():
+ node_def.attr[k].CopyFrom(v)
+ if device is not None:
+ if callable(device):
+ node_def.device = device(node_def)
+ else:
+ node_def.device = _device_string(device)
+ return node_def
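+
+
+# Illustrative sketch: a NodeDef for a hypothetical float "Const" node pinned
+# to a device (the attr value construction is abbreviated; DT_FLOAT is enum
+# value 1 in the DataType proto):
+#
+#   dtype_attr = attr_value_pb2.AttrValue(type=1)
+#   node = _NodeDef("Const", "my_const", device="/job:worker/task:0",
+#                   attrs={"dtype": dtype_attr})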
+
+
+# Copied from core/framework/node_def_util.cc
+# TODO(mrry,josh11b): Consolidate this validation in C++ code.
+_VALID_OP_NAME_REGEX = re.compile("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*")
+
+
+class Operation(object):
+ """Represents a graph node that performs computation on tensors.
+
+ An `Operation` is a node in a TensorFlow `Graph` that takes zero or
+ more `Tensor` objects as input, and produces zero or more `Tensor`
+ objects as output. Objects of type `Operation` are created by
+ calling a Python op constructor (such as [`tf.matmul()`](math_ops.md#matmul))
+ or [`Graph.create_op()`](framework.md#Graph.create_op).
+
+ For example `c = tf.matmul(a, b)` creates an `Operation` of type
+ "MatMul" that takes tensors `a` and `b` as input, and produces `c`
+ as output.
+
+ After the graph has been launched in a session, an `Operation` can
+ be executed by passing it to [`Session.run()`](client.md#Session.run).
+ `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
+
+ @@name
+ @@type
+ @@inputs
+ @@control_inputs
+ @@outputs
+ @@device
+ @@graph
+
+ @@run
+
+ @@get_attr
+ @@traceback
+ """
+
+ def __init__(self, node_def, g, inputs=None, output_types=None,
+ control_inputs=None, input_types=None, original_op=None,
+ op_def=None):
+ """Creates an `Operation`.
+
+ NOTE: This constructor validates the name of the Operation (passed
+ as "node_def.name"). Valid Operation names match the following
+ regular expression:
+
+ [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
+
+ Args:
+ node_def: graph_pb2.NodeDef. NodeDef for the Operation.
+ Used for attributes of graph_pb2.NodeDef, typically "name",
+ "op", and "device". The "input" attribute is irrelevant here
+ as it will be computed when generating the model.
+ g: Graph. The parent graph.
+ inputs: list of Tensor objects. The inputs to this Operation.
+ output_types: list of types_pb2.DataType. List of the types of the
+ Tensors computed by this operation. The length of this list indicates
+ the number of output endpoints of the Operation.
+ control_inputs: list of operations or tensors from which to have a
+ control dependency.
+ input_types: List of types_pb2.DataType representing the
+ types of the Tensors accepted by the Operation. By default
+ uses [x.dtype.base_dtype for x in inputs]. Operations that expect
+ reference-typed inputs must specify these explicitly.
+ original_op: Optional. Used to associate the new Operation with an
+ existing Operation (for example, a replica with the op that was
+ replicated).
+ op_def: Optional. The op_def_pb2.OpDef proto that describes the
+ op type that this Operation represents.
+
+ Raises:
+ TypeError: if control inputs are not Operations or Tensors,
+ or if node_def is not a NodeDef,
+ or if g is not a Graph,
+ or if inputs are not Tensors,
+ or if inputs and input_types are incompatible.
+ ValueError: if the node_def name is not valid.
+ """
+ if not isinstance(node_def, graph_pb2.NodeDef):
+ raise TypeError("node_def needs to be a NodeDef: %s" % node_def)
+ if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
+ raise ValueError(
+ "Cannot create an Operation with a NodeDef larger than 2GB.")
+ if not _VALID_OP_NAME_REGEX.match(node_def.name):
+ raise ValueError("'%s' is not a valid node name" % node_def.name)
+ if not isinstance(g, Graph):
+ raise TypeError("g needs to be a Graph: %s" % g)
+ self._node_def = copy.deepcopy(node_def)
+ self._graph = g
+ if inputs is None:
+ inputs = []
+ self._inputs = inputs
+ for a in self._inputs:
+ if not isinstance(a, Tensor):
+ raise TypeError("input needs to be a Tensor: %s" % a)
+ # Mark that we consume the inputs.
+ a._add_consumer(self) # pylint: disable=protected-access
+ if output_types is None:
+ output_types = []
+ self._output_types = output_types
+ self._outputs = [Tensor(self, i, output_types[i])
+ for i in xrange(len(output_types))]
+ if input_types is None:
+ input_types = [i.dtype.base_dtype for i in self._inputs]
+ else:
+ if not all(x.is_compatible_with(i.dtype)
+ for i, x in zip(self._inputs, input_types)):
+ raise TypeError("Inputs are not compatible with input types")
+ self._input_types = input_types
+
+ # Build the list of control inputs.
+ self._control_inputs = []
+ if control_inputs:
+ for c in control_inputs:
+ c_op = None
+ if isinstance(c, Operation):
+ c_op = c
+ elif isinstance(c, (Tensor, IndexedSlices)):
+ c_op = c.op
+ else:
+ raise TypeError("Control input must be an Operation, "
+ "a Tensor, or IndexedSlices: %s" % c)
+ self._control_inputs.append(c_op)
+
+ self._original_op = original_op
+ self._op_def = op_def
+ self._traceback = _extract_stack()
+ # Add this op to the current control flow context:
+ self._control_flow_context = g._get_control_flow_context()
+ if g._get_control_flow_context() is not None:
+ g._get_control_flow_context().AddOp(self)
+    # NOTE(keveman): The control flow context's AddOp may create new ops and
+    # set op.inputs[index] = new_op. Thus the new ops' ids could be larger
+    # than this op's id even though this op depends on them. Therefore, we
+    # delay assigning an id to this op until all ops it could depend on have
+    # been created.
+ self._id_value = self._graph._next_id() # pylint: disable=protected-access
+ self._recompute_node_def()
+
+ def values(self):
+ """DEPRECATED: Use outputs."""
+ return tuple(self.outputs)
+
+ def _get_control_flow_context(self):
+ """Returns the current control flow context.
+
+ Returns:
+ A context object.
+ """
+ return self._control_flow_context
+
+ @property
+ def name(self):
+ """The full name of this operation."""
+ return self._node_def.name
+
+ @property
+ def _id(self):
+ """The unique integer id of this operation."""
+ return self._id_value
+
+ @property
+ def device(self):
+ """The name of the device to which this op has been assigned, if any.
+
+ Returns:
+ The string name of the device to which this op has been
+ assigned, or None if it has not been assigned to a device.
+ """
+ dev = self._node_def.device
+ return None if not dev else dev
+
+ def _set_device(self, device):
+ """Set the device of this operation.
+
+ Args:
+      device: string or Device. The device to set.
+ """
+ self._node_def.device = _device_string(device)
+
+ def _add_input(self, tensor, dtype=None):
+ """Add a new input to this operation.
+
+ Args:
+ tensor: the Tensor to add as an input.
+      dtype: types.DType; the type of the input. Defaults to
+        the tensor's dtype.
+
+ Raises:
+ TypeError: if tensor is not a Tensor,
+ or if input tensor type is not convertible to dtype.
+ ValueError: if the Tensor is from a different graph.
+ """
+ if not isinstance(tensor, Tensor):
+ raise TypeError("tensor must be a Tensor: %s" % tensor)
+ assert_same_graph([self, tensor])
+ if dtype is None:
+ dtype = tensor.dtype
+ else:
+ dtype = types.as_dtype(dtype)
+ if not dtype.is_compatible_with(tensor.dtype):
+ raise TypeError(
+ "Cannot convert a tensor of type %s to an input of type %s"
+ % (tensor.dtype.name, dtype.name))
+ self._inputs.append(tensor)
+ self._input_types.append(dtype)
+ tensor._add_consumer(self) # pylint: disable=protected-access
+ self._recompute_node_def()
+
+ def _update_input(self, index, tensor, dtype=None):
+ """Update the input to this operation at the given index.
+
+ NOTE: This is for TF internal use only. Please don't use it.
+
+ Args:
+ index: the index of the input to update.
+ tensor: the Tensor to be used as the input at the given index.
+      dtype: types.DType; the type of the input. Defaults to
+        the tensor's dtype.
+
+ Raises:
+ TypeError: if tensor is not a Tensor,
+ or if input tensor type is not convertible to dtype.
+ ValueError: if the Tensor is from a different graph.
+ """
+ if not isinstance(tensor, Tensor):
+ raise TypeError("tensor must be a Tensor: %s" % tensor)
+ assert_same_graph([self, tensor])
+ if dtype is None:
+ dtype = tensor.dtype
+ else:
+ dtype = types.as_dtype(dtype)
+ if not dtype.is_compatible_with(tensor.dtype):
+ raise TypeError(
+ "Cannot convert a tensor of type %s to an input of type %s"
+ % (tensor.dtype.name, dtype.name))
+
+ self._inputs[index].consumers().remove(self)
+ self._inputs[index] = tensor
+ self._input_types[index] = dtype
+ tensor._add_consumer(self) # pylint: disable=protected-access
+ self._recompute_node_def()
+
+ def _add_control_input(self, op):
+ """Add a new control input to this operation.
+
+ Args:
+ op: the Operation to add as control input.
+
+ Raises:
+ TypeError: if op is not an Operation.
+ ValueError: if op is from a different graph.
+ """
+ if not isinstance(op, Operation):
+ raise TypeError("op must be an Operation: %s" % op)
+ assert_same_graph([self, op])
+ self._control_inputs.append(op)
+ self._recompute_node_def()
+
+ # Methods below are used when building the NodeDef and Graph proto.
+ def _recompute_node_def(self):
+ del self._node_def.input[:]
+ self._node_def.input.extend([t._as_node_def_input() for t in self._inputs])
+ if self._control_inputs:
+ self._node_def.input.extend(["^%s" % op.name for op in
+ self._control_inputs])
+
+ def __str__(self):
+ return str(self._node_def)
+
+ @property
+ def outputs(self):
+ """The list of `Tensor` objects representing the outputs of this op."""
+ return self._outputs
+
+# pylint: disable=protected-access
+ class _InputList(object):
+ """Immutable input list wrapper."""
+
+ def __init__(self, op):
+ self._op = op
+
+ def __iter__(self):
+ return iter(self._op._inputs)
+
+ def __len__(self):
+ return len(self._op._inputs)
+
+ def __bool__(self):
+ return bool(self._op._inputs)
+
+ def __getitem__(self, i):
+ return self._op._inputs[i]
+# pylint: enable=protected-access
+
+ @property
+ def inputs(self):
+ """The list of `Tensor` objects representing the data inputs of this op."""
+ return Operation._InputList(self)
+
+ @property
+ def _input_dtypes(self):
+ return self._input_types
+
+ @property
+ def control_inputs(self):
+ """The `Operation` objects on which this op has a control dependency.
+
+ Before this op is executed, TensorFlow will ensure that the
+ operations in `self.control_inputs` have finished executing. This
+ mechanism can be used to run ops sequentially for performance
+ reasons, or to ensure that the side effects of an op are observed
+ in the correct order.
+
+ Returns:
+ A list of `Operation` objects.
+
+ """
+ return self._control_inputs
+
+ @property
+ def type(self):
+ """The type of the op (e.g. `"MatMul"`)."""
+ return self._node_def.op
+
+ @property
+ def graph(self):
+ """The `Graph` that contains this operation."""
+ return self._graph
+
+ @property
+ def node_def(self):
+ """Returns a serialized `NodeDef` representation of this operation.
+
+ Returns:
+ A
+ [`NodeDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+ protocol buffer.
+ """
+ return self._node_def
+
+ @property
+ def op_def(self):
+ """Returns the `OpDef` proto that represents the type of this op.
+
+ Returns:
+ An
+ [`OpDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_def.proto)
+ protocol buffer.
+ """
+ return self._op_def
+
+ @property
+ def traceback(self):
+ """Returns the call stack from when this operation was constructed."""
+ return _convert_stack(self._traceback)
+
+ def get_attr(self, name):
+ """Returns the value of the attr of this op with the given `name`.
+
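+    For example, a small sketch (assuming the default graph; a
+    `tf.constant()` op stores its value in the `"value"` attr and its
+    element type in the `"dtype"` attr):
+
+    ```python
+    c = tf.constant(3.0)
+    c.op.get_attr("dtype")   # ==> the `types_pb2.DataType` enum for float32.
+    c.op.get_attr("value")   # ==> a `TensorProto` containing 3.0.
+    ```
+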
+ Args:
+ name: The name of the attr to fetch.
+
+ Returns:
+ The value of the attr, as a Python object.
+
+ Raises:
+ ValueError: If this op does not have an attr with the given `name`.
+ """
+ fields = ["s", "i", "f", "b", "type", "shape", "tensor"]
+ if name not in self._node_def.attr:
+ raise ValueError("No attr named '" + name + "' in " +
+ str(self._node_def))
+ x = self._node_def.attr[name]
+ # Treat an empty oneof value as an empty list.
+ if not x.WhichOneof("value"):
+ return []
+ if x.HasField("list"):
+ for f in fields:
+ if getattr(x.list, f):
+ return list(getattr(x.list, f))
+ return []
+ else:
+ for f in fields:
+ if x.HasField(f):
+ return getattr(x, f)
+ assert False, "Unsupported field type in " + str(x)
+
+ def run(self, feed_dict=None, session=None):
+ """Runs this operation in a `Session`.
+
+ Calling this method will execute all preceding operations that
+ produce the inputs needed for this operation.
+
+ *N.B.* Before invoking `Operation.run()`, its graph must have been
+ launched in a session, and either a default session must be
+ available, or `session` must be specified explicitly.
+
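+    For example, a minimal sketch (assuming the `Session` class is exported
+    as `tf.Session` and that the op lives on the default graph):
+
+    ```python
+    op = tf.constant(1.0).op
+    sess = tf.Session()
+    with sess.as_default():
+      op.run()
+    sess.close()
+    ```
+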
+ Args:
+ feed_dict: A dictionary that maps `Tensor` objects to feed values.
+ See [`Session.run()`](client.md#Session.run) for a description of the
+ valid feed values.
+      session: (Optional.) The `Session` to be used to run this operation. If
+        none is specified, the default session will be used.
+ """
+ _run_using_default_session(self, feed_dict, self.graph, session)
+
+
+_gradient_registry = registry.Registry("gradient")
+
+
+class RegisterGradient(object):
+ """A decorator for registering the gradient function for an op type.
+
+ This decorator is only used when defining a new op type. For an op
+  with `m` inputs and `n` outputs, the gradient function is a function
+ that takes the original `Operation` and `n` `Tensor` objects
+ (representing the gradients with respect to each output of the op),
+ and returns `m` `Tensor` objects (representing the partial gradients
+ with respect to each input of the op).
+
+ For example, assuming that operations of type `"Sub"` take two
+ inputs `x` and `y`, and return a single output `x - y`, the
+ following gradient function would be registered:
+
+ ```python
+ @tf.RegisterGradient("Sub")
+ def _sub_grad(unused_op, grad):
+    return grad, tf.neg(grad)
+ ```
+
+ The decorator argument `op_type` is the string type of an
+ operation. This corresponds to the `OpDef.name` field for the proto
+ that defines the operation.
+
+ @@__init__
+ """
+
+ def __init__(self, op_type):
+ """Creates a new decorator with `op_type` as the Operation type.
+
+ Args:
+ op_type: The string type of an operation. This corresponds to the
+ `OpDef.name` field for the proto that defines the operation.
+ """
+ if not isinstance(op_type, basestring):
+ raise TypeError("op_type must be a string")
+ self._op_type = op_type
+
+ def __call__(self, f):
+ """Registers the function `f` as gradient function for `op_type`."""
+ _gradient_registry.register(f, self._op_type)
+ return f
+
+
+def NoGradient(op_type):
+ """Specifies that ops of type `op_type` do not have a defined gradient.
+
+ This function is only used when defining a new op type. It may be
+ used for ops such as `tf.size()` that are not differentiable. For
+ example:
+
+ ```python
+ tf.NoGradient("Size")
+ ```
+
+ Args:
+ op_type: The string type of an operation. This corresponds to the
+ `OpDef.name` field for the proto that defines the operation.
+
+ Raises:
+ TypeError: If `op_type` is not a string.
+
+ """
+ if not isinstance(op_type, basestring):
+ raise TypeError("op_type must be a string")
+ _gradient_registry.register(None, op_type)
+
+
+def get_gradient_function(op):
+ """Returns the function that computes gradients for "op"."""
+ if not op.inputs: return None
+ try:
+ op_type = op.get_attr("_gradient_op_type")
+ except ValueError:
+ op_type = op.type
+ return _gradient_registry.lookup(op_type)
+
+
+_shape_registry = registry.Registry("shape functions")
+_default_shape_function_registry = registry.Registry("default shape functions")
+
+class RegisterShape(object):
+ """A decorator for registering the shape function for an op type.
+
+ This decorator is only used when defining a new op type. A shape
+ function is a function from an `Operation` object to a list of
+ `TensorShape` objects, with one `TensorShape` for each output of the
+ operation.
+
+ For example, assuming that operations of type `"Sub"` take two
+ inputs `x` and `y`, and return a single output `x - y`, all with the
+ same shape, the following shape function would be registered:
+
+ ```python
+ @tf.RegisterShape("Sub")
+ def _sub_shape(op):
+ return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
+ ```
+
+ The decorator argument `op_type` is the string type of an
+ operation. This corresponds to the `OpDef.name` field for the proto
+ that defines the operation.
+
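+  Passing `None` instead of a function registers a weak default that
+  reports an unknown shape for every output of the op; it can later be
+  overridden by a non-`None` registration. A sketch, using a hypothetical
+  op type:
+
+  ```python
+  tf.RegisterShape("MyOpWithUnknownShapes")(None)
+  ```
+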
+ """
+
+ def __init__(self, op_type):
+ """Saves the "op_type" as the Operation type."""
+ if not isinstance(op_type, basestring):
+ raise TypeError("op_type must be a string")
+ self._op_type = op_type
+
+ def __call__(self, f):
+ """Registers "f" as the shape function for "op_type"."""
+ if f is None:
+ # None is a special "weak" value that provides a default shape function,
+ # and can be overridden by a non-None registration.
+ try:
+ _default_shape_function_registry.register(_no_shape_function,
+ self._op_type)
+ except KeyError:
+ # Ignore duplicate registrations of the weak value. This can
+ # occur if the op library input to wrapper generation
+ # inadvertently links in one or more of the standard op
+ # libraries.
+ pass
+ else:
+ _shape_registry.register(f, self._op_type)
+ return f
+
+
+def _no_shape_function(op):
+ return [tensor_shape.unknown_shape() for _ in op.outputs]
+
+
+def set_shapes_for_outputs(op):
+ """Uses the registered shape functions to set the shapes for op's outputs."""
+ try:
+ shape_func = _shape_registry.lookup(op.type)
+ except LookupError:
+ try:
+ shape_func = _default_shape_function_registry.lookup(op.type)
+ except LookupError:
+ raise RuntimeError("No shape function registered for standard op: %s"
+ % op.type)
+ shapes = shape_func(op)
+ if len(op.outputs) != len(shapes):
+    raise RuntimeError(
+        "Shape function for op %s returned %d shapes but expected %d" %
+        (op, len(shapes), len(op.outputs)))
+ for output, s in zip(op.outputs, shapes):
+ output.set_shape(s)
+
+
+class Graph(object):
+ """A TensorFlow computation, represented as a dataflow graph.
+
+ A `Graph` contains a set of [`Operation`](framework.md#Operation) objects,
+ which represent units of computation; and [`Tensor`](framework.md#Tensor)
+ objects, which represent the units of data that flow between operations.
+
+ A default `Graph` is always registered, and accessible by calling
+ [`tf.get_default_graph()`](framework.md#get_default_graph). To add an
+ operation to the default graph, simply call one of the functions that defines
+ a new `Operation`:
+
+  ```python
+ c = tf.constant(4.0)
+ assert c.graph is tf.get_default_graph()
+ ```
+
+ Another typical usage involves the
+ [`Graph.as_default()`](framework.md#Graph.as_default)
+ context manager, which overrides the current default graph for the
+ lifetime of the context:
+
+ ```python
+ g = tf.Graph()
+ with g.as_default():
+ # Define operations and tensors in `g`.
+ c = tf.constant(30.0)
+ assert c.graph is g
+ ```
+
+ Important note: This class *is not* thread-safe for graph construction. All
+ operations should be created from a single thread, or external
+ synchronization must be provided. Unless otherwise specified, all methods
+ are not thread-safe.
+
+ @@__init__
+ @@as_default
+ @@as_graph_def
+ @@finalize
+ @@finalized
+
+ @@control_dependencies
+ @@device
+ @@name_scope
+
+ A `Graph` instance supports an arbitrary number of "collections"
+ that are identified by name. For convenience when building a large
+ graph, collections can store groups of related objects: for
+  example, the `tf.Variable` class uses a collection (named
+ [`tf.GraphKeys.VARIABLES`](framework.md#GraphKeys)) for all variables that are
+ created during the construction of a graph. The caller may define
+ additional collections by specifying a new name.
+
+ @@add_to_collection
+ @@get_collection
+
+ @@as_graph_element
+ @@get_operation_by_name
+ @@get_tensor_by_name
+ @@get_operations
+
+ @@get_default_device
+ @@seed
+ @@unique_name
+ @@version
+
+ @@create_op
+ @@gradient_override_map
+ """
+
+ def __init__(self):
+ """Creates a new, empty Graph."""
+ self._nodes_by_id = dict()
+ self._next_node_id = [dict()]
+ self._next_id_counter = 0
+ self._nodes_by_name = dict()
+ # Current name stack: a pair of uniquified names and plain names.
+ self._name_stack = ("", "")
+ # Maps a name used in the graph to the next id to use for that name.
+ self._names_in_use = {}
+ # Default device applied to new ops.
+ self._default_device = None
+ # Functions that will be applied to choose a device if none is specified.
+ self._device_function_stack = []
+ # Default original_op applied to new ops.
+ self._default_original_op = None
+ # Current control flow context. It could be either CondContext or
+ # WhileContext defined in ops/control_flow_ops.py
+ self._control_flow_context = None
+    # A new node will depend on the union of all of the nodes in the stack.
+    self._control_dependencies_stack = []
+    # Arbitrary collections of objects.
+ self._collections = {}
+ # The graph-level random seed
+ self._seed = None
+ # A map from op type to the kernel label that should be used.
+ self._op_to_kernel_label_map = {}
+ # A map from op type to an alternative op type that should be used when
+ # computing gradients.
+ self._gradient_override_map = {}
+ # True if the graph is considered "finalized". In that case no
+ # new operations can be added.
+ self._finalized = False
+
+ def _check_not_finalized(self):
+ """Check if the graph is finalized.
+
+ Raises:
+      RuntimeError: If the graph is finalized.
+ """
+ if self._finalized:
+ raise RuntimeError("Graph is finalized and cannot be modified.")
+
+ def _add_op(self, op):
+ """Adds 'op' to the graph.
+
+ Args:
+      op: the Operation or Tensor to add.
+
+ Raises:
+ TypeError: if op is not an Operation or Tensor.
+ ValueError: if the op.name or op._id are already used.
+ """
+ self._check_not_finalized()
+ if not isinstance(op, (Tensor, Operation)):
+ raise TypeError("op must be a Tensor or Operation: %s" % op)
+
+ if op._id in self._nodes_by_id:
+ raise ValueError("cannot add an op with id %d as it already "
+ "exists in the graph" % op._id)
+ if op.name in self._nodes_by_name:
+ raise ValueError("cannot add op with name %s as that name "
+ "is already used" % op.name)
+ self._nodes_by_id[op._id] = op
+ self._nodes_by_name[op.name] = op
+
+ @property
+ def version(self):
+ """Returns a version number that increases as ops are added to the graph."""
+ return self._next_id_counter
+
+ @property
+  def seed(self):
+    """The graph-level random seed of this graph."""
+    return self._seed
+
+ @seed.setter
+ def seed(self, seed):
+ self._seed = seed
+
+ @property
+ def finalized(self):
+ """True if this graph has been finalized."""
+ return self._finalized
+
+ def finalize(self):
+ """Finalizes this graph, making it read-only.
+
+ After calling `g.finalize()`, no new operations can be added to
+ `g`. This method is used to ensure that no operations are added
+ to a graph when it is shared between multiple threads, for example
+ when using a [`QueueRunner`](train.md#QueueRunner).
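+
+    A small sketch (attempting to add an operation to a finalized graph
+    raises a `RuntimeError`):
+
+    ```python
+    g = tf.Graph()
+    g.finalize()
+    assert g.finalized
+    ```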
+ """
+ self._finalized = True
+
+ def _get_control_flow_context(self):
+ """Returns the current control flow context.
+
+ Returns:
+ A context object.
+ """
+ return self._control_flow_context
+
+ def _set_control_flow_context(self, context):
+ """Sets the current control flow context.
+
+ Args:
+ context: a context object.
+ """
+ self._control_flow_context = context
+
+ def as_graph_def(self, from_version=None):
+ """Returns a serialized `GraphDef` representation of this graph.
+
+ This method is thread-safe.
+
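+    For example, a sketch of the `from_version` behavior (assuming
+    `tf.constant()` adds a single node to the graph):
+
+    ```python
+    g = tf.Graph()
+    with g.as_default():
+      tf.constant(1.0, name="a")
+      v = g.version
+      tf.constant(2.0, name="b")
+    # Contains only the node added after `version` was read, i.e. "b".
+    partial = g.as_graph_def(from_version=v)
+    ```
+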
+ Args:
+ from_version: Optional. If this is set, returns a `GraphDef`
+ containing only the nodes that were added to this graph since
+ its `version` property had the given value.
+
+ Returns:
+ A
+ [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+ protocol buffer.
+ """
+ graph = graph_pb2.GraphDef()
+ bytesize = 0
+ for op_id in sorted(self._nodes_by_id):
+ op = self._nodes_by_id[op_id]
+ if from_version is None or op_id > from_version:
+ graph.node.extend([op.node_def])
+ bytesize += op.node_def.ByteSize()
+ if bytesize >= (1 << 31) or bytesize < 0:
+ raise ValueError("GraphDef cannot be larger than 2GB.")
+ return graph
+
+ # Helper functions to create operations.
+ def create_op(self, op_type, inputs, dtypes,
+ input_types=None, name=None, attrs=None, op_def=None,
+ compute_shapes=True):
+ """Creates an `Operation` in this graph.
+
+ This is a low-level interface for creating an `Operation`. Most
+ programs will not call this method directly, and instead use the
+ Python op constructors, such as `tf.constant()`, which add ops to
+ the default graph.
+
+ Args:
+ op_type: The `Operation` type to create. This corresponds to the
+ `OpDef.name` field for the proto that defines the operation.
+ inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+ dtypes: A list of `DType` objects that will be the types of the tensors
+ that the operation produces.
+ input_types: (Optional.) A list of `DType`s that will be the types of
+ the tensors that the operation consumes. By default, uses the base
+ `DType` of each input in `inputs`. Operations that expect
+ reference-typed inputs must specify `input_types` explicitly.
+ name: (Optional.) A string name for the operation. If not specified, a
+ name is generated based on `op_type`.
+ attrs: (Optional.) A list of `AttrValue` protos for the `attr` field of
+ the `NodeDef` proto that will represent the operation.
+ op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+ the operation will have.
+ compute_shapes: (Optional.) If True, shape inference will be performed
+ to compute the shapes of the outputs.
+
+ Raises:
+ TypeError: if any of the inputs is not a `Tensor`.
+
+ Returns:
+ An `Operation` object.
+
+ """
+ self._check_not_finalized()
+ for idx, a in enumerate(inputs):
+ if not isinstance(a, Tensor):
+ raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
+ if name is None:
+ name = op_type
+    # If a name ends with a '/' it is a "name scope", and we use it as-is,
+ # after removing the trailing '/'.
+ if name and name[-1] == "/":
+ name = name[:-1]
+ else:
+ name = self.unique_name(name)
+
+ node_def = _NodeDef(
+ op_type, name, device=self._default_device or None, attrs=attrs)
+
+ # Apply a kernel label if one has been specified for this op_type.
+ try:
+ kernel_label = self._op_to_kernel_label_map[op_type]
+ node_def.attr["_kernel"].CopyFrom(
+ attr_value_pb2.AttrValue(s=kernel_label))
+ except KeyError:
+ pass
+
+ # Apply the overriding op_type for gradients if one has been
+ # specified for this op_type.
+ try:
+ mapped_op_type = self._gradient_override_map[op_type]
+ node_def.attr["_gradient_op_type"].CopyFrom(
+ attr_value_pb2.AttrValue(s=mapped_op_type))
+ except KeyError:
+ pass
+
+ control_inputs = self._control_dependencies_for_inputs(inputs)
+ ret = Operation(node_def, self, inputs=inputs, output_types=dtypes,
+ control_inputs=control_inputs, input_types=input_types,
+ original_op=self._default_original_op, op_def=op_def)
+ if compute_shapes:
+ set_shapes_for_outputs(ret)
+ self._add_op(ret)
+ self._record_op_seen_by_control_dependencies(ret)
+ # Apply any device functions in reverse order, so that the most recently
+ # pushed function has the first chance to apply a device to the op.
+ # We apply here because the result can depend on the Operation's
+ # signature, which is computed in the Operation constructor.
+ for device_function in reversed(self._device_function_stack):
+ ret._set_device(device_function(ret))
+ return ret
+
+ def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
+ """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
+
+ This function validates that `obj` represents an element of this
+ graph, and gives an informative error message if it is not.
+
+ This function is the canonical way to get/validate an object of
+ one of the allowed types from an external argument reference in the
+ Session API.
+
+ This method may be called concurrently from multiple threads.
+
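+    For example, a sketch (a name containing ":" is parsed as a tensor name
+    of the form `"<op_name>:<output_index>"`; a name without ":" is treated
+    as an operation name):
+
+    ```python
+    with tf.Graph().as_default() as g:
+      c = tf.constant(5.0, name="c")
+    assert g.as_graph_element("c") is c.op
+    assert g.as_graph_element("c:0") is c
+    ```
+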
+ Args:
+ obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
+ Can also be any object with an `_as_graph_element()` method that returns
+ a value of one of these types.
+ allow_tensor: If true, `obj` may refer to a `Tensor`.
+ allow_operation: If true, `obj` may refer to an `Operation`.
+
+ Returns:
+ The `Tensor` or `Operation` in the Graph corresponding to `obj`.
+
+ Raises:
+ TypeError: If `obj` is not a type we support attempting to convert
+ to types.
+ ValueError: If `obj` is of an appropriate type but invalid. For
+ example, an invalid string.
+ KeyError: If `obj` is not an object in the graph.
+ """
+
+ # The vast majority of this function is figuring
+ # out what an API user might be doing wrong, so
+ # that we can give helpful error messages.
+ #
+ # Ideally, it would be nice to split it up, but we
+ # need context to generate nice error messages.
+
+ if allow_tensor and allow_operation:
+ types_str = "Tensor or Operation"
+ elif allow_tensor:
+ types_str = "Tensor"
+ elif allow_operation:
+ types_str = "Operation"
+ else:
+ raise ValueError("allow_tensor and allow_operation can't both be False.")
+
+ conv_fn = getattr(obj, "_as_graph_element", None)
+ if conv_fn and callable(conv_fn):
+ obj = conv_fn()
+
+ # If obj appears to be a name...
+ if isinstance(obj, basestring):
+ name = obj
+
+ if ":" in name and allow_tensor:
+ # Looks like a Tensor name and can be a Tensor.
+ try:
+ op_name, out_n = name.split(":")
+ out_n = int(out_n)
+ except:
+          raise ValueError("The name %s looks like a Tensor name, but is "
+ "not a valid one. Tensor names must be of the "
+ "form \"<op_name>:<output_index>\"." % repr(name))
+ if op_name in self._nodes_by_name:
+ op = self._nodes_by_name[op_name]
+ else:
+ raise KeyError("The name %s refers to a Tensor which does not "
+ "exist. The operation, %s, does not exist in the "
+ "graph." % (repr(name), repr(op_name)))
+ try:
+ return op.outputs[out_n]
+ except:
+ raise KeyError("The name %s refers to a Tensor which does not "
+ "exist. The operation, %s, exists but only has "
+ "%s outputs."
+ % (repr(name), repr(op_name), len(op.outputs)))
+
+ elif ":" in name and not allow_tensor:
+ # Looks like a Tensor name but can't be a Tensor.
+ raise ValueError("Name %s appears to refer to a Tensor, not a %s."
+ % (repr(name), types_str))
+
+ elif ":" not in name and allow_operation:
+ # Looks like an Operation name and can be an Operation.
+ if name not in self._nodes_by_name:
+ raise KeyError("The name %s refers to an Operation not in the "
+ "graph." % repr(name))
+ return self._nodes_by_name[name]
+
+ elif ":" not in name and not allow_operation:
+ # Looks like an Operation name but can't be an Operation.
+ if name in self._nodes_by_name:
+ # Yep, it's an Operation name
+ err_msg = ("The name %s refers to an Operation, not a %s."
+ % (repr(name), types_str))
+ else:
+ err_msg = ("The name %s looks like an (invalid) Operation name, "
+ "not a %s." % (repr(name), types_str))
+ err_msg += (" Tensor names must be of the form "
+ "\"<op_name>:<output_index>\".")
+ raise ValueError(err_msg)
+
+ elif isinstance(obj, Tensor) and allow_tensor:
+ # Actually obj is just the object it's referring to.
+ return obj
+ elif isinstance(obj, Operation) and allow_operation:
+ # Actually obj is just the object it's referring to.
+ return obj
+ else:
+ # We give up!
+ raise TypeError("Can not convert a %s into a %s."
+ % (type(obj).__name__, types_str))
+
+ def get_operations(self):
+ """Return the list of operations in the graph.
+
+ You can modify the operations in place, but modifications
+    to the list such as inserts/deletes have no effect on the
+ list of operations known to the graph.
+
+ This method may be called concurrently from multiple threads.
+
+ Returns:
+ A list of Operations.
+ """
+ return self._nodes_by_id.values()
+
+ def get_operation_by_name(self, name):
+ """Returns the `Operation` with the given `name`.
+
+ This method may be called concurrently from multiple threads.
+
+ Args:
+ name: The name of the `Operation` to return.
+
+ Returns:
+ The `Operation` with the given `name`.
+
+ Raises:
+ TypeError: If `name` is not a string.
+ KeyError: If `name` does not correspond to an operation in this graph.
+ """
+
+ if not isinstance(name, basestring):
+ raise TypeError("Operation names are strings (or similar), not %s."
+ % type(name).__name__)
+ return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
+
+ def get_tensor_by_name(self, name):
+ """Returns the `Tensor` with the given `name`.
+
+ This method may be called concurrently from multiple threads.
+
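+    For example, a sketch (tensor names have the form
+    `"<op_name>:<output_index>"`):
+
+    ```python
+    with tf.Graph().as_default() as g:
+      c = tf.constant(5.0, name="c")
+    assert g.get_tensor_by_name("c:0") is c
+    ```
+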
+ Args:
+ name: The name of the `Tensor` to return.
+
+ Returns:
+ The `Tensor` with the given `name`.
+
+ Raises:
+ TypeError: If `name` is not a string.
+ KeyError: If `name` does not correspond to a tensor in this graph.
+ """
+ # Names should be strings.
+ if not isinstance(name, basestring):
+ raise TypeError("Tensor names are strings (or similar), not %s."
+ % type(name).__name__)
+ return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
+
+ def _next_id(self):
+ """Id for next Operation instance. Also increments the internal id."""
+ self._check_not_finalized()
+ self._next_id_counter += 1
+ return self._next_id_counter
+
+ @property
+ def _last_id(self):
+ return self._next_id_counter
+
+ def as_default(self):
+ """Returns a context manager that makes this `Graph` the default graph.
+
+ This method should be used if you want to create multiple graphs
+ in the same process. For convenience, a global default graph is
+ provided, and all ops will be added to this graph if you do not
+    create a new graph explicitly. Use this method with the `with` keyword
+ to specify that ops created within the scope of a block should be
+ added to this graph.
+
+ The default graph is a property of the current thread. If you
+ create a new thread, and wish to use the default graph in that
+ thread, you must explicitly add a `with g.as_default():` in that
+ thread's function.
+
+ The following code examples are equivalent:
+
+ ```python
+ # 1. Using Graph.as_default():
+ g = tf.Graph()
+ with g.as_default():
+ c = tf.constant(5.0)
+ assert c.graph is g
+
+ # 2. Constructing and making default:
+ with tf.Graph().as_default() as g:
+ c = tf.constant(5.0)
+ assert c.graph is g
+ ```
+
+ Returns:
+ A context manager for using this graph as the default graph.
+ """
+ return _default_graph_stack.get_controller(self)
+
+ def add_to_collection(self, name, value):
+ """Stores `value` in the collection with the given `name`.
+
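+    For example, a small sketch using an ad-hoc collection name (the
+    `tf.Variable` class uses this mechanism with the standard
+    `tf.GraphKeys.VARIABLES` key):
+
+    ```python
+    g = tf.Graph()
+    with g.as_default():
+      c = tf.constant(1.0)
+      g.add_to_collection("my_ops", c)
+    assert g.get_collection("my_ops") == [c]
+    ```
+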
+ Args:
+ name: The key for the collection. For example, the `GraphKeys` class
+ contains many standard names for collections.
+ value: The value to add to the collection.
+ """
+ self._check_not_finalized()
+ if name not in self._collections:
+ self._collections[name] = [value]
+ else:
+ self._collections[name].append(value)
+
+ def get_collection(self, name, scope=None):
+ """Returns a list of values in the collection with the given `name`.
+
+ Args:
+      name: The key for the collection. For example, the `GraphKeys` class
+ contains many standard names for collections.
+ scope: (Optional.) If supplied, the resulting list is filtered to include
+ only items whose name begins with this string.
+
+ Returns:
+ The list of values in the collection with the given `name`, or
+ an empty list if no value has been added to that collection. The
+ list contains the values in the order under which they were
+ collected.
+ """
+ if scope is None:
+ return self._collections.get(name, list())
+ else:
+ c = []
+ for item in self._collections.get(name, list()):
+ if hasattr(item, 'name') and item.name.startswith(scope):
+ c.append(item)
+ return c
+
+ @contextlib.contextmanager
+ def _original_op(self, op):
+ """Python 'with' handler to help annotate ops with their originator.
+
+ An op may have an 'original_op' property that indicates the op on which
+ it was based. For example a replica op is based on the op that was
+ replicated and a gradient op is based on the op that was differentiated.
+
+ All ops created in the scope of this 'with' handler will have
+ the given 'op' as their original op.
+
+ Args:
+ op: The Operation that all ops created in this scope will have as their
+ original op.
+
+ Yields:
+ Nothing.
+ """
+ old_original_op = self._default_original_op
+ try:
+ self._default_original_op = op
+ yield
+ finally:
+ self._default_original_op = old_original_op
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def name_scope(self, name):
+ """Returns a context manager that creates hierarchical names for operations.
+
+ A graph maintains a stack of name scopes. A `with name_scope(...):`
+ statement pushes a new name onto the stack for the lifetime of the context.
+
+ The `name` argument will be interpreted as follows:
+
+ * A string (not ending with '/') will create a new name scope, in which
+ `name` is appended to the prefix of all operations created in the
+ context. If `name` has been used before, it will be made unique by
+ calling `self.unique_name(name)`.
+ * A scope previously captured from a `with g.name_scope(...) as
+ scope:` statement will be treated as an "absolute" name scope, which
+ makes it possible to re-enter existing scopes.
+ * A value of `None` or the empty string will reset the current name scope
+ to the top-level (empty) name scope.
+
+ For example:
+
+ ```python
+ with tf.Graph().as_default() as g:
+ c = tf.constant(5.0, name="c")
+      assert c.name == "c"
+ c_1 = tf.constant(6.0, name="c")
+ assert c_1.name == "c_1"
+
+ # Creates a scope called "nested"
+ with g.name_scope("nested") as scope:
+ nested_c = tf.constant(10.0, name="c")
+ assert nested_c.name == "nested/c"
+
+ # Creates a nested scope called "inner".
+ with g.name_scope("inner"):
+ nested_inner_c = tf.constant(20.0, name="c")
+ assert nested_inner_c.name == "nested/inner/c"
+
+ # Create a nested scope called "inner_1".
+ with g.name_scope("inner"):
+ nested_inner_1_c = tf.constant(30.0, name="c")
+ assert nested_inner_1_c.name == "nested/inner_1/c"
+
+ # Treats `scope` as an absolute name scope, and
+ # switches to the "nested/" scope.
+ with g.name_scope(scope):
+ nested_d = tf.constant(40.0, name="d")
+ assert nested_d.name == "nested/d"
+
+ with g.name_scope(""):
+ e = tf.constant(50.0, name="e")
+ assert e.name == "e"
+ ```
+
+ The name of the scope itself can be captured by `with
+ g.name_scope(...) as scope:`, which stores the name of the scope
+ in the variable `scope`. This value can be used to name an
+ operation that represents the overall result of executing the ops
+ in a scope. For example:
+
+ ```python
+ inputs = tf.constant(...)
+ with g.name_scope('my_layer') as scope:
+ weights = tf.Variable(..., name="weights")
+ biases = tf.Variable(..., name="biases")
+ affine = tf.matmul(inputs, weights) + biases
+ output = tf.nn.relu(affine, name=scope)
+ ```
+
+
+ Args:
+ name: A name for the scope.
+
+ Returns:
+ A context manager that installs `name` as a new name scope.
+ """
+ try:
+ old_stack = self._name_stack
+      if not name:  # Both for name=None and name="" we reset to empty scope.
+ new_stack = (None, None)
+ elif name and name[-1] == "/":
+ new_stack = (name[:-1], name[:-1])
+ else:
+ new_stack = (self.unique_name(name), self._plain_name(name))
+ self._name_stack = new_stack
+ yield "" if new_stack[0] is None else new_stack[0] + "/"
+ finally:
+ self._name_stack = old_stack
+ # pylint: enable=g-doc-return-or-yield
+
+ def unique_name(self, name):
+ """Return a unique Operation name for "name".
+
+ Note: You rarely need to call unique_name() directly. Most of the time you
+ just need to create "with g.name_scope()" blocks to generate structured
+ names.
+
+ `unique_name` is used to generate structured names, separated by "/",
+ to help identify Operations when debugging a Graph. Operation names
+ are displayed in error messages reported by the TensorFlow runtime,
+ and in various visualization tools such as TensorBoard.
+
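+    For example, a sketch of the uniquification behavior (assuming a fresh
+    graph and no enclosing name scope):
+
+    ```python
+    g = tf.Graph()
+    assert g.unique_name("foo") == "foo"
+    assert g.unique_name("foo") == "foo_1"
+    assert g.unique_name("foo") == "foo_2"
+    ```
+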
+ Args:
+ name: The name for an `Operation`.
+
+ Returns:
+ A string to be passed to `create_op()` that will be used
+ to name the operation being created.
+ """
+ if self._name_stack[0]:
+ name = self._name_stack[0] + "/" + name
+ i = self._names_in_use.get(name, 0)
+ # Increment the number for "name".
+ self._names_in_use[name] = i + 1
+ if i > 0:
+ base_name = name
+ # Make sure the composed name is not already used.
+ while name in self._names_in_use:
+ name = "%s_%d" % (base_name, i)
+ i += 1
+ # Mark the composed name as used in case someone wants
+ # to call unique_name("name_1").
+ self._names_in_use[name] = 1
+ return name
+
+ # TODO(mdevin): remove
+ def _plain_name(self, name):
+ """Return the fully scoped 'name'.
+
+ Args:
+ name: a string.
+
+ Returns:
+ 'name' scoped in the current name stack, without any uniquified
+ elements.
+ """
+ if self._name_stack[1]:
+ return self._name_stack[1] + "/" + name
+ else:
+ return name
+
+ def _set_default_device(self, dev):
+ """Set the default device properties.
+
+ Args:
+ dev: string or Device.
+ """
+ self._default_device = _device_string(dev)
+
+ def get_default_device(self):
+ """Returns the default device.
+
+ Returns:
+ A string.
+ """
+ return self._default_device
+
+ def _push_default_device_function(self, device_function):
+ """Pushes the given function onto the stack of device functions.
+
+ See Graph.device for more details.
+
+ Args:
+ device_function: The function to be pushed onto the stack of device
+ functions.
+ """
+ self._device_function_stack.append(device_function)
+
+ def _pop_default_device_function(self, device_function):
+ """Pops the given function from the stack of device functions.
+
+ See Graph.device for more details.
+
+ Args:
+ device_function: The function to be popped from the stack of device
+ functions.
+
+ Raises:
+ ValueError: if the device_function to be popped is not top of the stack,
+ or if the stack is empty.
+ """
+ if not self._device_function_stack:
+ raise ValueError("Tried to pop, but the device function stack is empty")
+ if self._device_function_stack[-1] is not device_function:
+ raise ValueError("Tried to pop device function, but it was not on top "
+ "of the stack")
+
+ self._device_function_stack.pop()
+
+ @contextlib.contextmanager
+ def device(self, device_name_or_function):
+ """Returns a context manager that specifies the default device to use.
+
+ The `device_name_or_function` argument may either be a device name
+ string, a device function, or None:
+
+ * If it is a device name string, all operations constructed in
+ this context will be assigned to the device with that name.
+    * If it is a function, it will be treated as a function from
+ Operation objects to device name strings, and invoked each time
+ a new Operation is created. The Operation will be assigned to
+ the device with the returned name.
+ * If it is None, the default device will be cleared.
+
+ For example:
+
+ ```python
+ with g.device('/gpu:0'):
+ # All operations constructed in this context will be placed
+ # on GPU 0.
+ with g.device(None):
+ # All operations constructed in this context will have no
+ # assigned device.
+
+ # Defines a function from `Operation` to device string.
+ def matmul_on_gpu(n):
+ if n.type == "MatMul":
+ return "/gpu:0"
+ else:
+ return "/cpu:0"
+
+ with g.device(matmul_on_gpu):
+ # All operations of type "MatMul" constructed in this context
+ # will be placed on GPU 0; all other operations will be placed
+ # on CPU 0.
+ ```
+
+ Args:
+ device_name_or_function: The device name or function to use in
+ the context.
+
+ Returns:
+ A context manager that specifies the default device to use for newly
+ created ops.
+ """
+ if callable(device_name_or_function):
+ try:
+ self._push_default_device_function(device_name_or_function)
+ yield
+ finally:
+ self._pop_default_device_function(device_name_or_function)
+ else:
+ try:
+ old_dev = self.get_default_device()
+ self._set_default_device(_device_string(device_name_or_function))
+ yield
+ finally:
+ self._set_default_device(old_dev)
+
+ class _ControlDependenciesController(object):
+ """Context manager for `control_dependencies()`."""
+
+ def __init__(self, graph, control_inputs):
+ self._graph = graph
+ self._control_inputs = control_inputs
+ self._seen_nodes = set()
+
+# pylint: disable=protected-access
+ def __enter__(self):
+ self._graph._push_control_dependencies_controller(self)
+
+ def __exit__(self, unused_type, unused_value, unused_traceback):
+ self._graph._pop_control_dependencies_controller(self)
+# pylint: enable=protected-access
+
+ @property
+ def control_inputs(self):
+ return self._control_inputs
+
+ def add_op(self, op):
+ self._seen_nodes.add(op)
+
+ def op_in_group(self, op):
+ return op in self._seen_nodes
+
+ def _push_control_dependencies_controller(self, controller):
+ self._control_dependencies_stack.append(controller)
+
+ def _pop_control_dependencies_controller(self, controller):
+ assert self._control_dependencies_stack[-1] is controller
+ self._control_dependencies_stack.pop()
+
+ def _current_control_dependencies(self):
+ ret = set()
+ for controller in self._control_dependencies_stack:
+ for op in controller.control_inputs:
+ ret.add(op)
+ return ret
+
+ def _control_dependencies_for_inputs(self, input_tensors):
+ """For an op that takes `input_tensors` as inputs, compute control inputs.
+
+ The returned control dependencies should yield an execution that
+ is equivalent to adding all control inputs in
+ self._control_dependencies_stack to a newly created op. However,
+ this function attempts to prune the returned control dependencies
+ by observing that nodes created within the same `with
+ control_dependencies(...):` block may have data dependencies that make
+ the explicit approach redundant.
+
+ Args:
+ input_tensors: The direct data dependencies for an op to be created.
+
+ Returns:
+ A list of control inputs for the op to be created.
+ """
+ ret = []
+ input_ops = set([t.op for t in input_tensors])
+ for controller in self._control_dependencies_stack:
+ # If any of the input_ops already depends on the inputs from controller,
+ # we say that the new op is dominated (by that input), and we therefore
+ # do not need to add control dependences for this controller's inputs.
+ dominated = False
+ for op in input_ops:
+ if controller.op_in_group(op):
+ dominated = True
+ break
+ if not dominated:
+        # Don't add a control input if we already have a data dependency on it.
+ # NOTE(mrry): We do not currently track transitive data dependencies,
+ # so we may add redundant control inputs.
+ ret.extend([c for c in controller.control_inputs if c not in input_ops])
+ return ret
+
+ def _record_op_seen_by_control_dependencies(self, op):
+ """Record that the given op depends on all registered control dependencies.
+
+ Args:
+ op: An Operation.
+ """
+ for controller in self._control_dependencies_stack:
+ controller.add_op(op)
+
+ def control_dependencies(self, control_inputs):
+ """Returns a context manager that specifies control dependencies.
+
+ Use with the `with` keyword to specify that all operations constructed
+ within the context should have control dependencies on
+ `control_inputs`. For example:
+
+ ```python
+ with g.control_dependencies([a, b, c]):
+ # `d` and `e` will only run after `a`, `b`, and `c` have executed.
+ d = ...
+ e = ...
+ ```
+
+ Multiple calls to `control_dependencies()` can be nested, and in
+ that case a new `Operation` will have control dependencies on the union
+ of `control_inputs` from all active contexts.
+
+ ```python
+ with g.control_dependencies([a, b]):
+ # Ops declared here run after `a` and `b`.
+ with g.control_dependencies([c, d]):
+ # Ops declared here run after `a`, `b`, `c`, and `d`.
+ ```
+
+ *N.B.* The control dependencies context applies *only* to ops that
+ are constructed within the context. Merely using an op or tensor
+ in the context does not add a control dependency. The following
+ example illustrates this point:
+
+ ```python
+ # WRONG
+ def my_func(pred, tensor):
+ t = tf.matmul(tensor, tensor)
+ with tf.control_dependencies([pred]):
+ # The matmul op is created outside the context, so no control
+ # dependency will be added.
+ return t
+
+ # RIGHT
+ def my_func(pred, tensor):
+ with tf.control_dependencies([pred]):
+ # The matmul op is created in the context, so a control dependency
+ # will be added.
+ return tf.matmul(tensor, tensor)
+ ```
+
+ Args:
+ control_inputs: A list of `Operation` or `Tensor` objects, which
+ must be executed or computed before running the operations
+ defined in the context.
+
+ Returns:
+ A context manager that specifies control dependencies for all
+ operations constructed within the context.
+
+ Raises:
+ TypeError: If `control_inputs` is not a list of `Operation` or
+ `Tensor` objects.
+ """
+ # First convert the inputs to ops, and deduplicate them.
+ # NOTE(mrry): Other than deduplication, we do not currently track direct
+ # or indirect dependencies between control_inputs, which may result in
+ # redundant control inputs.
+ control_ops = []
+ current = self._current_control_dependencies()
+ for c in control_inputs:
+ if isinstance(c, Tensor):
+ c = c.op
+ elif not isinstance(c, Operation):
+ raise TypeError("Control input must be Operation or Tensor: %s" % c)
+ if c not in current:
+ control_ops.append(c)
+ current.add(c)
+ return self._ControlDependenciesController(self, control_ops)
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def _kernel_label_map(self, op_to_kernel_label_map):
+ """EXPERIMENTAL: A context manager for setting kernel labels.
+
+ This context manager can be used to select particular
+ implementations of kernels within the scope of the context.
+
+ For example:
+
+ with ops.Graph().as_default() as g:
+ f_1 = Foo() # Uses the default registered kernel for the Foo op.
+          with g._kernel_label_map({"Foo": "v_2"}):
+ f_2 = Foo() # Uses the registered kernel with label "v_2"
+ # for the Foo op.
+            with g._kernel_label_map({"Foo": "v_3"}):
+ f_3 = Foo() # Uses the registered kernel with label "v_3"
+ # for the Foo op.
+              with g._kernel_label_map({"Foo": ""}):
+ f_4 = Foo() # Uses the default registered kernel
+ # for the Foo op.
+
+ Args:
+ op_to_kernel_label_map: A dictionary mapping op type strings to
+ kernel label strings.
+
+ Returns:
+ A context manager that sets the kernel label to be used for one or more
+ ops created in that context.
+
+ Raises:
+ TypeError: If op_to_kernel_label_map is not a dictionary mapping
+ strings to strings.
+ """
+ if not isinstance(op_to_kernel_label_map, dict):
+ raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
+ "strings to strings")
+ # The saved_labels dictionary stores any currently-set labels that
+ # will be overridden by this context manager.
+ saved_labels = {}
+ # Install the given label
+ for op_type, label in op_to_kernel_label_map.items():
+ if not (isinstance(op_type, basestring)
+ and isinstance(label, basestring)):
+ raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
+ "strings to strings")
+ try:
+ saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
+ except KeyError:
+ pass
+ self._op_to_kernel_label_map[op_type] = label
+ try:
+ yield # The code within the context runs here.
+ finally:
+ # Remove the labels set for this context, and restore any saved labels.
+ for op_type, label in op_to_kernel_label_map.items():
+ try:
+ self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
+ except KeyError:
+ del self._op_to_kernel_label_map[op_type]
+ # pylint: enable=g-doc-return-or-yield
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def gradient_override_map(self, op_type_map):
+ """EXPERIMENTAL: A context manager for overriding gradient functions.
+
+ This context manager can be used to override the gradient function
+ that will be used for ops within the scope of the context.
+
+ For example:
+
+ ```python
+ @tf.RegisterGradient("CustomSquare")
+    def _custom_square_grad(op, grad):
+ # ...
+
+ with tf.Graph().as_default() as g:
+ c = tf.constant(5.0)
+ s_1 = tf.square(c) # Uses the default gradient for tf.square.
+ with g.gradient_override_map({"Square": "CustomSquare"}):
+        s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
+                            # gradient of s_2.
+ ```
+
+ Args:
+ op_type_map: A dictionary mapping op type strings to alternative op
+ type strings.
+
+ Returns:
+ A context manager that sets the alternative op type to be used for one
+ or more ops created in that context.
+
+ Raises:
+ TypeError: If `op_type_map` is not a dictionary mapping strings to
+ strings.
+ """
+ if not isinstance(op_type_map, dict):
+ raise TypeError("op_type_map must be a dictionary mapping "
+ "strings to strings")
+ # The saved_mappings dictionary stores any currently-set mappings that
+ # will be overridden by this context manager.
+ saved_mappings = {}
+    # Install the given gradient override.
+ for op_type, mapped_op_type in op_type_map.items():
+ if not (isinstance(op_type, basestring)
+ and isinstance(mapped_op_type, basestring)):
+ raise TypeError("op_type_map must be a dictionary mapping "
+ "strings to strings")
+ try:
+ saved_mappings[op_type] = self._gradient_override_map[op_type]
+ except KeyError:
+ pass
+ self._gradient_override_map[op_type] = mapped_op_type
+ try:
+ yield # The code within the context runs here.
+ finally:
+      # Remove the overrides set for this context, and restore any saved overrides.
+ for op_type, mapped_op_type in op_type_map.items():
+ try:
+ self._gradient_override_map[op_type] = saved_mappings[op_type]
+ except KeyError:
+ del self._gradient_override_map[op_type]
+ # pylint: enable=g-doc-return-or-yield
+
+
+def device(dev):
+ """Wrapper for `Graph.device()` using the default graph.
+
+  See [`Graph.device()`](framework.md#Graph.device) for more details.
+
+ Args:
+    dev: The device name or function to use in
+      the context.
+
+ Returns:
+ A context manager that specifies the default device to use for newly
+ created ops.
+ """
+ return get_default_graph().device(dev)
+
+
+def name_scope(name):
+ """Wrapper for `Graph.name_scope()` using the default graph.
+
+ See [`Graph.name_scope()`](framework.md#Graph.name_scope) for more details.
+
+ Args:
+ name: A name for the scope.
+
+ Returns:
+ A context manager that installs `name` as a new name scope in the
+ default graph.
+ """
+ return get_default_graph().name_scope(name)
+
+
+def control_dependencies(control_inputs):
+ """Wrapper for `Graph.control_dependencies()` using the default graph.
+
+ See [`Graph.control_dependencies()`](framework.md#Graph.control_dependencies)
+ for more details.
+
+ Args:
+ control_inputs: A list of `Operation` or `Tensor` objects, which
+ must be executed or computed before running the operations
+ defined in the context.
+
+ Returns:
+ A context manager that specifies control dependencies for all
+ operations constructed within the context.
+ """
+ return get_default_graph().control_dependencies(control_inputs)
+
+
+class _DefaultStack(threading.local):
+ """A thread-local stack of objects for providing implicit defaults."""
+
+ def __init__(self):
+ super(_DefaultStack, self).__init__()
+ self.stack = []
+
+ def get_default(self):
+ return self.stack[-1] if len(self.stack) >= 1 else None
+
+ def reset(self):
+ self.stack = []
+
+ @contextlib.contextmanager
+ def get_controller(self, default):
+ """A context manager for manipulating a default stack."""
+ try:
+ self.stack.append(default)
+ yield default
+ finally:
+ assert self.stack[-1] is default
+ self.stack.pop()
+
+
+_default_session_stack = _DefaultStack()
+
+
+def default_session(session):
+ """Python "with" handler for defining a default session.
+
+ This function provides a means of registering a session for handling
+ Tensor.eval() and Operation.run() calls. It is primarily intended for use
+ by session.Session, but can be used with any object that implements
+ the Session.run() interface.
+
+ Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
+ invocations within the scope of a block should be executed by a particular
+ session.
+
+ The default session applies to the current thread only, so it is always
+ possible to inspect the call stack and determine the scope of a default
+ session. If you create a new thread, and wish to use the default session
+ in that thread, you must explicitly add a "with ops.default_session(sess):"
+ block in that thread's function.
+
+ Example:
+ The following code examples are equivalent:
+
+ # 1. Using the Session object directly:
+ sess = ...
+ c = tf.constant(5.0)
+ sess.run(c)
+
+ # 2. Using default_session():
+ sess = ...
+ with ops.default_session(sess):
+ c = tf.constant(5.0)
+ result = c.eval()
+
+ # 3. Overriding default_session():
+ sess = ...
+ with ops.default_session(sess):
+ c = tf.constant(5.0)
+ with ops.default_session(...):
+ c.eval(session=sess)
+
+ Args:
+ session: The session to be installed as the default session.
+
+ Returns:
+ A context manager for the default session.
+ """
+ return _default_session_stack.get_controller(weakref.ref(session))
+
+
+def get_default_session():
+ """Returns the default session for the current thread.
+
+ The returned `Session` will be the innermost session on which a
+ `Session` or `Session.as_default()` context has been entered.
+
+ *N.B.* The default session is a property of the current thread. If you
+ create a new thread, and wish to use the default session in that
+ thread, you must explicitly add a `with sess.as_default():` in that
+ thread's function.
+
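+  For example, a sketch in the style of the `default_session()` examples
+  above (assuming a `Session` object `sess`):
+
+  ```python
+  sess = ...
+  with sess.as_default():
+    assert ops.get_default_session() is sess
+  ```
+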
+ Returns:
+ The default `Session` being used in the current thread.
+ """
+ ref = _default_session_stack.get_default()
+ if ref is None:
+ # No default session has been registered.
+ return None
+ else:
+ # De-reference ref.
+ ret = ref()
+ if ret is None:
+ # This should never happen with the current session implementations.
+ raise RuntimeError("Default session has been garbage collected.")
+ return ret
+
+
+def _eval_using_default_session(tensors, feed_dict, graph, session=None):
+ """Uses the default session to evaluate one or more tensors.
+
+ Args:
+ tensors: A single Tensor, or a list of Tensor objects.
+ feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
+ numpy ndarrays, TensorProtos, or strings.
+ graph: The graph in which the tensors are defined.
+ session: (Optional) A different session to use to evaluate "tensors".
+
+ Returns:
+ Either a single numpy ndarray if "tensors" is a single tensor; or a list
+ of numpy ndarrays that each correspond to the respective element in
+ "tensors".
+
+ Raises:
+ ValueError: If no default session is available; the default session
+ does not have "graph" as its graph; or if "session" is specified,
+ and it does not have "graph" as its graph.
+ """
+ if session is None:
+ session = get_default_session()
+ if session is None:
+      raise ValueError("Cannot evaluate tensor using eval(): No default "
+                       "session is registered. Use 'with "
+                       "default_session(sess)' or pass an explicit session to "
+                       "eval(session=sess)")
+ if session.graph is not graph:
+ raise ValueError("Cannot use the default session to evaluate tensor: "
+ "the tensor's graph is different from the session's "
+ "graph. Pass an explicit session to "
+ "eval(session=sess).")
+ else:
+ if session.graph is not graph:
+ raise ValueError("Cannot use the given session to evaluate tensor: "
+ "the tensor's graph is different from the session's "
+ "graph.")
+ return session.run(tensors, feed_dict)
+
+
+def _run_using_default_session(operation, feed_dict, graph, session=None):
+ """Uses the default session to run "operation".
+
+ Args:
+ operation: The Operation to be run.
+ feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
+ numpy ndarrays, TensorProtos, or strings.
+ graph: The graph in which "operation" is defined.
+ session: (Optional) A different session to use to run "operation".
+
+ Raises:
+ ValueError: If no default session is available; the default session
+ does not have "graph" as its graph; or if "session" is specified,
+ and it does not have "graph" as its graph.
+ """
+ if session is None:
+ session = get_default_session()
+ if session is None:
+      raise ValueError("Cannot execute operation using run(): No default "
+                       "session is registered. Use 'with "
+                       "default_session(sess)' or pass an explicit session to "
+                       "run(session=sess)")
+ if session.graph is not graph:
+      raise ValueError("Cannot use the default session to execute operation: "
+                       "the operation's graph is different from the "
+                       "session's graph. Pass an explicit session to "
+                       "run(session=sess).")
+ else:
+ if session.graph is not graph:
+ raise ValueError("Cannot use the given session to execute operation: "
+ "the operation's graph is different from the session's "
+ "graph.")
+ session.run(operation, feed_dict)
+
+
+class _DefaultGraphStack(_DefaultStack):
+ """A thread-local stack of objects for providing an implicit default graph."""
+
+ def __init__(self):
+ super(_DefaultGraphStack, self).__init__()
+ self._global_default_graph = None
+
+ def get_default(self):
+ """Override that returns a global default if the stack is empty."""
+ ret = super(_DefaultGraphStack, self).get_default()
+ if ret is None:
+ ret = self._GetGlobalDefaultGraph()
+ return ret
+
+ def _GetGlobalDefaultGraph(self):
+ if self._global_default_graph is None:
+      # TODO(mrry): Perhaps log that the default graph is being used, or
+      # provide some other feedback to prevent confusion when a mixture of
+ # the global default graph and an explicit graph are combined in the
+ # same process.
+ self._global_default_graph = Graph()
+ return self._global_default_graph
+
+ def reset(self):
+ super(_DefaultGraphStack, self).reset()
+ self._global_default_graph = None
+
+_default_graph_stack = _DefaultGraphStack()
+
+
+def reset_default_graph():
+ """Clears the default graph stack and resets the global default graph.
+
+ *N.B.* The default graph is a property of the current thread. This
+ function applies only to the current thread.
+ """
+ _default_graph_stack.reset()
+
+
+def get_default_graph():
+ """Returns the default graph for the current thread.
+
+ The returned graph will be the innermost graph on which a
+ `Graph.as_default()` context has been entered, or a global default
+ graph if none has been explicitly created.
+
+ *N.B.* The default graph is a property of the current thread. If you
+ create a new thread, and wish to use the default graph in that
+ thread, you must explicitly add a `with g.as_default():` in that
+ thread's function.
+
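+  For example, a minimal sketch of the usual behavior (it assumes
+  `tf.constant` purely for illustration):
+
+  ```python
+  c = tf.constant(4.0)
+  assert c.graph is tf.get_default_graph()
+  ```
+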
+ Returns:
+ The default `Graph` being used in the current thread.
+ """
+ return _default_graph_stack.get_default()
+
+
+def _get_graph_from_inputs(op_input_list, graph=None):
+ """Returns the appropriate graph to use for the given inputs.
+
+ This library method provides a consistent algorithm for choosing the graph
+ in which an Operation should be constructed:
+
+ 1. If the "graph" is specified explicitly, we validate that all of the inputs
+ in "op_input_list" are compatible with that graph.
+ 2. Otherwise, we attempt to select a graph from the first Operation-
+ or Tensor-valued input in "op_input_list", and validate that all other
+ such inputs are in the same graph.
+ 3. If the graph was not specified and it could not be inferred from
+ "op_input_list", we attempt to use the default graph.
+
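+  For example, in the common case (a sketch only; `tf.constant` is used
+  purely for illustration):
+
+  ```python
+  a = tf.constant(1.0)
+  b = tf.constant(2.0)
+  # No explicit graph is given and both inputs share a graph, so that
+  # graph is selected.
+  assert _get_graph_from_inputs([a, b]) is a.graph
+  ```
+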
+ Args:
+ op_input_list: A list of inputs to an operation, which may include Tensor
+ and Operation objects.
+ graph: (Optional) The explicit graph to use.
+
+ Raises:
+ TypeError: If op_input_list is not a list or tuple, or if graph is not a
+ Graph.
+ ValueError: If a graph is explicitly passed and not all inputs are from it,
+ or if the inputs are from multiple graphs, or we could not find a graph
+ and there was no default graph.
+
+ Returns:
+ The appropriate graph to use for the given inputs.
+ """
+ if not isinstance(op_input_list, (list, tuple)):
+ raise TypeError("The op_input_list must be a list or tuple")
+
+ # 1. If the graph is specified explicitly, we validate that all of the inputs
+ # are compatible with that graph.
+ if graph is not None:
+ if not isinstance(graph, Graph):
+ raise TypeError("Input graph needs to be a Graph: %s" % graph)
+ for op_input in op_input_list:
+ if isinstance(op_input, Operation):
+ if op_input.graph is not graph:
+ raise ValueError("Operation %s is not from the passed-in graph"
+ % op_input)
+ elif isinstance(op_input, Tensor):
+ if op_input.graph is not graph:
+ raise ValueError("Tensor %s is not from the passed-in graph"
+ % op_input)
+ return graph
+
+ # 2. Otherwise, we attempt to select a graph from one of the Operation-
+ # or Tensor-valued inputs.
+ original_input = None
+ for op_input in op_input_list:
+ if isinstance(op_input, (Operation, Tensor)):
+ if original_input is None:
+ original_input = op_input
+ else:
+ assert_same_graph([original_input, op_input])
+ if original_input is not None:
+ return original_input.graph
+
+ # 3. If all else fails, we use the default graph, which is always there.
+ return get_default_graph()
+
+
+class GraphKeys(object):
+ """Standard names to use for graph collections.
+
+ The standard library uses various well-known names to collect and
+ retrieve values associated with a graph. For example, the
+ `tf.Optimizer` subclasses default to optimizing the variables
+ collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
+ specified, but it is also possible to pass an explicit list of
+ variables.
+
+ The following standard keys are defined:
+
+ * `VARIABLES`: the `Variable` objects that comprise a model, and
+ must be saved and restored together. See
+ [`tf.all_variables()`](state_ops.md#all_variables) for more details.
+ * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
+ be trained by an optimizer. See
+ [`tf.trainable_variables()`](state_ops.md#trainable_variables)
+ for more details.
+ * `SUMMARIES`: the summary `Tensor` objects that have been created
+ in the graph. See [`tf.merge_all_summaries()`](train.md#merge_all_summaries)
+ for more details.
+ * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
+ produce input for a computation. See
+ [`tf.start_queue_runners()`](train.md#start_queue_runners) for more details.
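+
+  For example, the trainable variables collected under the standard key can
+  be retrieved as follows (a sketch; it assumes that trainable variables have
+  already been created in the default graph):
+
+  ```python
+  train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+  ```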
+ """
+
+ # Key to collect variables.Variable objects that must be saved and restored
+ # by the model.
+ VARIABLES = "variables"
+ # Key to collect variables.Variable objects that will be trained by the
+ # optimizers.
+ TRAINABLE_VARIABLES = "trainable_variables"
+ # Key to collect summaries.
+ SUMMARIES = "summaries"
+ # Key to collect QueueRunners.
+ QUEUE_RUNNERS = "queue_runners"
+ # Key to collect table initializers.
+ TABLE_INITIALIZERS = "table_initializer"
+
+
+def add_to_collection(name, value):
+ """Wrapper for `Graph.add_to_collection()` using the default graph.
+
+ See [`Graph.add_to_collection()`](framework.md#Graph.add_to_collection)
+ for more details.
+
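+  For example (an illustrative sketch):
+
+  ```python
+  c = tf.constant(4.0)
+  tf.add_to_collection("my_collection", c)
+  assert tf.get_collection("my_collection") == [c]
+  ```
+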
+ Args:
+ name: The key for the collection. For example, the `GraphKeys` class
+ contains many standard names for collections.
+ value: The value to add to the collection.
+ """
+ get_default_graph().add_to_collection(name, value)
+
+
+def get_collection(key, scope=None):
+ """Wrapper for `Graph.get_collection()` using the default graph.
+
+ See [`Graph.get_collection()`](framework.md#Graph.get_collection)
+ for more details.
+
+ Args:
+ key: The key for the collection. For example, the `GraphKeys` class
+ contains many standard names for collections.
+ scope: (Optional.) If supplied, the resulting list is filtered to include
+ only items whose name begins with this string.
+
+ Returns:
+    The list of values in the collection with the given `key`, or
+    an empty list if no value has been added to that collection. The
+    list contains the values in the order in which they were
+    collected.
+ """
+ return get_default_graph().get_collection(key, scope)
+
+
+# pylint: disable=g-doc-return-or-yield
+@contextlib.contextmanager
+def op_scope(values, name, default_name):
+ """Returns a context manager for use when defining a Python op.
+
+ This context manager validates that the given `values` are from the
+ same graph, ensures that that graph is the default graph, and pushes a
+ name scope.
+
+ For example, to define a new Python op called `my_op`:
+
+ ```python
+ def my_op(a, b, c, name=None):
+ with tf.op_scope([a, b, c], name, "MyOp") as scope:
+ a = tf.convert_to_tensor(a, name="a")
+ b = tf.convert_to_tensor(b, name="b")
+ c = tf.convert_to_tensor(c, name="c")
+ # Define some computation that uses `a`, `b`, and `c`.
+ return foo_op(..., name=scope)
+ ```
+
+ Args:
+ values: The list of `Tensor` arguments that are passed to the op function.
+ name: The name argument that is passed to the op function.
+ default_name: The default name to use if the `name` argument is `None`.
+
+ Returns:
+ A context manager for use in defining a Python op.
+ """
+ g = _get_graph_from_inputs(values)
+ n = default_name if name is None else name
+ with g.as_default(), g.name_scope(n) as scope:
+ yield scope
+# pylint: enable=g-doc-return-or-yield
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
new file mode 100644
index 0000000000..a406c5e56e
--- /dev/null
+++ b/tensorflow/python/framework/ops_test.py
@@ -0,0 +1,825 @@
+"""Tests for tensorflow.python.framework.ops."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_kernel_label_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.platform import googletest
+
+
+class TensorTest(test_util.TensorFlowTestCase):
+
+ def testShape(self):
+ op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(),
+ [], [types.float32])
+ t = op.outputs[0]
+ self.assertEquals(tensor_shape.unknown_shape(), t.get_shape())
+ t.set_shape([1, 2, 3])
+ self.assertEquals([1, 2, 3], t.get_shape())
+
+
+class NodeDefConstructorTest(test_util.TensorFlowTestCase):
+
+ def testNoArgs(self):
+ nodedef = ops._NodeDef("noop", "bar")
+ self.assertProtoEquals("op: 'noop' name: 'bar'", nodedef)
+
+ def testArgs(self):
+ nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
+ self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
+ nodedef)
+ nodedef = ops._NodeDef("foo", "bar", device=pydev.Device(job="j"))
+ self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)
+
+
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+ops.RegisterShape("a")(None)
+ops.RegisterShape("b")(None)
+ops.RegisterShape("c")(None)
+ops.RegisterShape("add")(None)
+ops.RegisterShape("an_op")(None)
+ops.RegisterShape("const")(None)
+ops.RegisterShape("copy")(None)
+ops.RegisterShape("foo")(None)
+ops.RegisterShape("identity")(None)
+ops.RegisterShape("mul")(None)
+ops.RegisterShape("nonrefop")(None)
+ops.RegisterShape("noop")(None)
+ops.RegisterShape("refop")(None)
+
+
+def _apply_op(g, *args, **kwargs):
+ op = g.create_op(*args, **kwargs)
+ if len(op.outputs) == 1:
+ return op.outputs[0]
+ else:
+ return op.outputs
+
+
+class OperationTest(test_util.TensorFlowTestCase):
+
+ def testNoInputs(self):
+ op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(),
+ [],
+ [types.float32, types.string])
+ self.assertEquals(2, len(op.values()))
+ self.assertEquals(0, len(op.inputs))
+ self.assertEquals("myop", op.name)
+
+ float_t, label_str_t = op.values()
+ self.assertEquals(types.float32, float_t.dtype)
+ self.assertEquals(op, float_t.op)
+ self.assertEquals(0, float_t._value_index)
+ self.assertEquals(0, len(float_t._consumers))
+ self.assertEquals("myop", float_t._as_node_def_input())
+
+ self.assertEquals(types.string, label_str_t.dtype)
+ self.assertEquals(op, label_str_t.op)
+ self.assertEquals(1, label_str_t._value_index)
+ self.assertEquals(0, len(label_str_t._consumers))
+ self.assertEquals("myop:1", label_str_t._as_node_def_input())
+
+ self.assertProtoEquals("op:'noop' name:'myop'", op.node_def)
+
+ def testNoOutputs(self):
+ g = ops.Graph()
+ op1 = ops.Operation(
+ ops._NodeDef("noop", "myop1"), g, [], [types.float32])
+ float_t, = op1.values()
+ op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], [])
+ self.assertEquals(0, len(op2.values()))
+ self.assertEquals(1, len(op2.inputs))
+ self.assertIs(float_t, op2.inputs[0])
+
+ self.assertEquals(1, len(float_t._consumers))
+ self.assertEquals(op2, float_t._consumers[0])
+
+ self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def)
+ self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'",
+ op2.node_def)
+
+ def testInputsAndOutputs(self):
+ g = ops.Graph()
+ op1 = ops.Operation(
+ ops._NodeDef("noop", "myop1"), g, [], [types.float32])
+ self.assertEquals(1, len(op1.values()))
+ float1_t, = op1.values()
+
+ op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g,
+ [], [types.float32, types.string])
+ self.assertEquals(2, len(op2.values()))
+ float2_t, label2_str_t = op2.values()
+
+ # Note that we consume label2_str_t twice here.
+ op3 = ops.Operation(ops._NodeDef("add", "myop3"), g,
+ [float1_t, label2_str_t, label2_str_t],
+ [types.float32, types.int32])
+ self.assertEquals(2, len(op3.values()))
+
+ self.assertEquals(1, len(float1_t._consumers))
+ self.assertEquals(op3, float1_t._consumers[0])
+
+ self.assertEquals(0, len(float2_t._consumers))
+
+ self.assertEquals(2, len(label2_str_t._consumers))
+ self.assertEquals(op3, label2_str_t._consumers[0])
+ self.assertEquals(op3, label2_str_t._consumers[1])
+
+ self.assertProtoEquals("""
+ op:'add' name:'myop3'
+ input:'myop1' input:'myop2:1' input:'myop2:1'
+ """, op3.node_def)
+
+ def testDeviceObject(self):
+ op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
+ op._set_device("/job:goo/device:GPU:0")
+ self.assertProtoEquals(
+ "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
+ op.node_def)
+ op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
+ op._set_device(pydev.Device(job="muu", device_type="CPU", device_index=0))
+ self.assertProtoEquals(
+ "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
+ op.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ op1 = ops.Operation(ops._NodeDef("noop", "op1"), g, [],
+ [types.float32_ref, types.float32])
+ self.assertProtoEquals("op:'noop' name:'op1'",
+ op1.node_def)
+ ref_t, nonref_t = op1.values()
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ op2 = ops.Operation(
+ ops._NodeDef("refop", "op2"), g, [ref_t, nonref_t], [],
+ input_types=[types.float32_ref, types.float32])
+ self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
+ op2.node_def)
+ op3 = ops.Operation(
+ ops._NodeDef("nonrefop", "op3"), g, [ref_t, nonref_t], [])
+ self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
+ op3.node_def)
+
+ def testInvalidNames(self):
+ g = ops.Graph()
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", ""), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "_invalid"), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "-invalid"), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "/invalid"), g)
+
+ def testShapeFunctionAbsence(self):
+ def _test():
+ pass
+ g = ops.Graph()
+ with self.assertRaises(RuntimeError):
+ g.create_op("shapeless_op", [], [types.float32])
+
+ def testNoShapeFunction(self):
+ g = ops.Graph()
+    op = ops.Operation(ops._NodeDef("op", "an_op"), g,
+                       output_types=[types.float32])
+ self.assertEquals(tensor_shape.unknown_shape(),
+ _apply_op(g, "an_op", [], [types.float32]).get_shape())
+
+class CreateOpTest(test_util.TensorFlowTestCase):
+
+ def testNodeDefArgs(self):
+ g = ops.Graph()
+ op1 = g.create_op("const", [], [types.float32], None, name="myop1")
+ with g.device("/device:GPU"):
+ op2 = g.create_op("add",
+ [],
+ [types.float32, types.string], None,
+ name="myop2")
+ op3 = g.create_op(
+ "foo",
+ [op1.values()[0], op2.values()[1], op2.values()[0]],
+ [types.float32, types.int32], None,
+ name="myop3")
+ self.assertEquals(None, op1.device)
+ self.assertEquals("/device:GPU", op2.device)
+ self.assertEquals(None, op3.device)
+ self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def)
+ self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
+ op2.node_def)
+ self.assertProtoEquals(
+ "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
+ op3.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ op1 = g.create_op("noop", [],
+ [types.float32_ref, types.float32], name="op1")
+ self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def)
+ ref_t, nonref_t = op1.values()
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ op2 = g.create_op("refop", [ref_t, nonref_t], [],
+ input_types=[types.float32_ref, types.float32],
+ name="op2")
+ self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
+ op2.node_def)
+ op3 = g.create_op("nonrefop", [ref_t, nonref_t], [], name="op3")
+ self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
+ op3.node_def)
+
+ def testFinalized(self):
+ g = ops.Graph()
+ g.finalize()
+ with self.assertRaises(RuntimeError):
+ g.create_op("const", [], [types.float32], None, name="myop1")
+
+
+class ApplyOpTest(test_util.TensorFlowTestCase):
+
+ def testNodeDefArgs(self):
+ g = ops.Graph()
+ t1 = _apply_op(g, "const", [], [types.float32], name="myop1")
+ with g.device("/device:GPU"):
+ t2 = _apply_op(g, "add",
+ [],
+ [types.float32, types.string],
+ name="myop2")
+ t3 = _apply_op(g, "foo", [t1, t2[1], t2[0]],
+ [types.float32, types.int32], name="myop3")
+ self.assertTrue(isinstance(t1, ops.Tensor))
+ self.assertTrue(isinstance(t2, list))
+ self.assertTrue(isinstance(t3, list))
+ self.assertTrue(isinstance(t3[0], ops.Tensor))
+ self.assertEquals("myop1", t1._as_node_def_input())
+ self.assertEquals("myop2", t2[0]._as_node_def_input())
+ self.assertEquals("myop2:1", t2[1]._as_node_def_input())
+ self.assertEquals("myop3", t3[0]._as_node_def_input())
+ # Validate that we got the right ops as well
+ self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def)
+ self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
+ t2[0].op.node_def)
+ self.assertProtoEquals(
+ "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
+ t3[0].op.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ ref_t, nonref_t = _apply_op(
+ g, "noop", [], [types.float32_ref, types.float32], name="op1")
+ self.assertProtoEquals("op:'noop' name:'op1'", ref_t.op.node_def)
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ out_2 = _apply_op(g, "refop", [ref_t, nonref_t], [types.int32],
+ input_types=[types.float32_ref, types.float32],
+ name="op2")
+ self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
+ out_2.op.node_def)
+ out_3 = _apply_op(g, "nonrefop", [ref_t, nonref_t], [types.int32],
+ name="op3")
+ self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
+ out_3.op.node_def)
+
+
+class NameStackTest(test_util.TensorFlowTestCase):
+
+ def testBasics(self):
+ g = ops.Graph()
+ self.assertEquals("foo", g.unique_name("foo"))
+ self.assertEquals("foo_1", g.unique_name("foo"))
+ self.assertEquals("foo_2", g.unique_name("foo"))
+ self.assertEquals("foo_1_1", g.unique_name("foo_1"))
+ self.assertEquals("foo_1_2", g.unique_name("foo_1"))
+ self.assertEquals("foo_1_2_1", g.unique_name("foo_1_2"))
+ with g.name_scope("bar"):
+ self.assertEquals("bar/foo", g.unique_name("foo"))
+ self.assertEquals("bar/foo_1", g.unique_name("foo"))
+ with g.name_scope(None):
+ self.assertEquals("foo_3", g.unique_name("foo"))
+ with g.name_scope("baz"):
+ self.assertEquals("bar/baz/foo", g.unique_name("foo"))
+ self.assertEquals("bar/baz/foo_1", g.unique_name("foo"))
+ with g.name_scope("baz"):
+ self.assertEquals("bar/baz_1/foo", g.unique_name("foo"))
+ self.assertEquals("bar/baz_1/foo_1", g.unique_name("foo"))
+ with g.name_scope("quux"):
+ self.assertEquals("quux/foo", g.unique_name("foo"))
+ with g.name_scope("bar"):
+ with g.name_scope("baz"):
+ self.assertEquals("bar_1/baz/foo", g.unique_name("foo"))
+ self.assertEquals("foo_4", g.unique_name("foo"))
+ self.assertEquals("bar_2", g.unique_name("bar"))
+
+ def testOutOfOrderUniqueName(self):
+ g = ops.Graph()
+ self.assertEquals("foo_2", g.unique_name("foo_2"))
+ self.assertEquals("foo", g.unique_name("foo"))
+ self.assertEquals("foo_1", g.unique_name("foo"))
+ self.assertEquals("foo_3", g.unique_name("foo"))
+
+
+class NameTest(test_util.TensorFlowTestCase):
+
+ def testGenerateName(self):
+ g = ops.Graph()
+ op0 = g.create_op("const", [], [types.float32, types.float32])
+ self.assertEquals("const", op0.name)
+ self.assertEquals("const:0", op0.outputs[0].name)
+ self.assertEquals("const:1", op0.outputs[1].name)
+
+ op1 = g.create_op("const", [], [types.float32])
+ self.assertEquals("const_1", op1.name)
+ self.assertEquals("const_1:0", op1.outputs[0].name)
+
+ op2 = g.create_op("const", [], [types.float32], name="my_op")
+ self.assertEquals("my_op", op2.name)
+ self.assertEquals("my_op:0", op2.outputs[0].name)
+
+ def testname_scope(self):
+ g = ops.Graph()
+
+ with g.name_scope("foo") as foo:
+ self.assertEquals(foo, "foo/")
+ with g.name_scope("foo2") as foo2:
+ self.assertEquals(foo2, "foo/foo2/")
+ with g.name_scope(None) as empty1:
+ self.assertEquals(empty1, "")
+ with g.name_scope("foo3") as foo3:
+ self.assertEquals(foo3, "foo3/")
+ with g.name_scope("") as empty2:
+ self.assertEquals(empty2, "")
+
+ self.assertEquals("const",
+ g.create_op("const", [], [types.float32]).name)
+ with g.name_scope("bar") as scope:
+ self.assertEquals("bar/const",
+ g.create_op("const", [], [types.float32]).name)
+ self.assertEquals("bar/const_1",
+ g.create_op("const", [], [types.float32]).name)
+      # If you use the value from "with .. as", that value is used as-is.
+ self.assertEquals(
+ "bar",
+ g.create_op("const", [], [types.float32], name=scope).name)
+ with g.name_scope("baz") as scope:
+ with g.name_scope("quux"):
+ self.assertEquals("baz/quux/const",
+ g.create_op("const", [], [types.float32]).name)
+ # If you use the value from the enclosing "with .. as", nothing is pushed.
+ with g.name_scope(scope):
+ self.assertEquals("baz/const",
+ g.create_op("const", [], [types.float32]).name)
+ self.assertEquals("baz",
+ g.create_op("const", [], [types.float32],
+ name=scope).name)
+ self.assertEquals("trailing",
+ g.create_op("const", [], [types.float32],
+ name="trailing/").name)
+ with g.name_scope("bar"):
+ self.assertEquals("bar_1/const",
+ g.create_op("const", [], [types.float32]).name)
+ with g.name_scope("bar/"):
+ self.assertEquals("bar/const_2",
+ g.create_op("const", [], [types.float32]).name)
+
+
+class DeviceTest(test_util.TensorFlowTestCase):
+
+ def testNoDevice(self):
+ g = ops.Graph()
+ op = g.create_op("an_op", [], [types.float32])
+ self.assertEqual(None, op.device)
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op" }
+ """, gd)
+
+ def testDevicePartialString(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testDeviceFull(self):
+ g = ops.Graph()
+ with g.device(pydev.Device(job="worker", replica=2, task=0,
+ device_type="CPU",
+ device_index=3)):
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2/task:0/device:CPU:3" }
+ """, gd)
+
+ def testNesting(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device("/job:worker/replica:3/task:0"):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/replica:3/task:0" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testNestingString(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device("/job:worker/replica:3/task:0"):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/replica:3/task:0" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testNestingOverrideGpuCpu(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2/device:CPU:1"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device("/job:worker/replica:2/device:GPU:2"):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/replica:2/device:GPU:2" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+ def testNestingWithMergeDeviceFunction(self):
+ g = ops.Graph()
+
+ with g.device(pydev.merge_device("/device:GPU:0")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device("/job:worker")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device("/device:CPU:0")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device("/job:ps")):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(pydev.merge_device(None)):
+ g.create_op("an_op", [], [types.float32])
+
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/device:GPU:0" }
+ node { name: "an_op_1" op: "an_op"
+ device: "/job:worker/device:GPU:0" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/device:CPU:0" }
+ node { name: "an_op_3" op: "an_op"
+ device: "/job:ps/device:CPU:0" }
+ node { name: "an_op_4" op: "an_op"
+ device: "/job:ps/device:CPU:0" }
+ """, gd)
+
+ def testNoneClearsDefault(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2/device:CPU:1"):
+ g.create_op("an_op", [], [types.float32])
+ with g.device(None):
+ g.create_op("an_op", [], [types.float32])
+ g.create_op("an_op", [], [types.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "an_op" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "an_op_1" op: "an_op" }
+ node { name: "an_op_2" op: "an_op"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+
+class ObjectWithName(object):
+
+ def __init__(self, name):
+ self._name = name
+
+ @property
+ def name(self):
+ return self._name
+
+
+class CollectionTest(test_util.TensorFlowTestCase):
+
+ def testadd_to_collection(self):
+ g = ops.Graph()
+ g.add_to_collection("key", 12)
+ g.add_to_collection("other", "foo")
+ g.add_to_collection("key", 34)
+
+    # Note that only blank1 (whose name starts with "prefix/") matches the
+    # scoped get_collection("blah", "prefix") call below.
+ g.add_to_collection("blah", 27)
+ blank1 = ObjectWithName("prefix/foo")
+ g.add_to_collection("blah", blank1)
+ blank2 = ObjectWithName("junk/foo")
+ g.add_to_collection("blah", blank2)
+
+ self.assertEquals(["foo"], g.get_collection("other"))
+ self.assertEquals([12, 34], g.get_collection("key"))
+ self.assertEquals([], g.get_collection("nothing"))
+ self.assertEquals([27, blank1, blank2], g.get_collection("blah"))
+ self.assertEquals([blank1], g.get_collection("blah", "prefix"))
+
+  def testDefaultGraph(self):
+ with ops.Graph().as_default():
+ ops.add_to_collection("key", 90)
+ ops.add_to_collection("key", 100)
+ # Collections are ordered.
+ self.assertEquals([90, 100], ops.get_collection("key"))
+
+
+def an_op(g):
+ return _apply_op(g, "an_op", [], [types.float32])
+
+
+ops.NoGradient("an_op")
+
+
+def copy_op(x):
+ return _apply_op(x.graph, "copy", [x], [x.dtype])
+
+
+@ops.RegisterGradient("copy")
+def _CopyGrad(op, x_grad):
+ _ = op
+ return x_grad
+
+
+@ops.RegisterGradient("copy_override")
+def _CopyOverrideGrad(op, x_grad):
+ _ = op
+ return x_grad
+
+
+class RegistrationTest(test_util.TensorFlowTestCase):
+
+ def testRegisterGradients(self):
+ g = ops.Graph()
+ x = an_op(g)
+ y = copy_op(x)
+ fn = ops.get_gradient_function(y.op)
+ self.assertEquals(_CopyGrad, fn)
+
+ def testOverrideGradients(self):
+ g = ops.Graph()
+ x = an_op(g)
+ with g.gradient_override_map({"copy": "copy_override"}):
+ y = copy_op(x)
+ fn = ops.get_gradient_function(y.op)
+ self.assertEquals(_CopyOverrideGrad, fn)
+
+ def testNonExistentOverride(self):
+ g = ops.Graph()
+ x = an_op(g)
+ with g.gradient_override_map({"copy": "unknown_override"}):
+ y = copy_op(x)
+ with self.assertRaisesRegexp(LookupError, "unknown_override"):
+ fn = ops.get_gradient_function(y.op)
+
+
+class ComparisonTest(test_util.TensorFlowTestCase):
+
+ def testMembershipAllowed(self):
+ g = ops.Graph()
+ t1 = _apply_op(g, "const", [], [types.float32], name="myop1")
+ t2 = _apply_op(g, "const", [], [types.float32], name="myop2")
+ self.assertTrue(isinstance(t1, ops.Tensor))
+ self.assertTrue(isinstance(t2, ops.Tensor))
+ self.assertTrue(t1 in [t1])
+ self.assertTrue(t1 not in [t2])
+
+
+class ControlDependenciesTest(test_util.TensorFlowTestCase):
+
+ def testBasic(self):
+ g = ops.Graph()
+ a = _apply_op(g, "const", [], [types.float32])
+ b = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a]):
+ c = _apply_op(g, "const", [], [types.float32])
+ d = _apply_op(g, "identity", [b], [types.float32])
+ e = _apply_op(g, "identity", [c], [types.float32])
+
+ self.assertEqual(c.op.control_inputs, [a.op])
+ self.assertEqual(d.op.control_inputs, [a.op])
+ # e should be dominated by c.
+ self.assertEqual(e.op.control_inputs, [])
+
+ def testNested(self):
+ g = ops.Graph()
+ a_1 = _apply_op(g, "const", [], [types.float32])
+ a_2 = _apply_op(g, "const", [], [types.float32])
+ a_3 = _apply_op(g, "const", [], [types.float32])
+ a_4 = _apply_op(g, "const", [], [types.float32])
+
+ with g.control_dependencies([a_1, a_2, a_3, a_4]):
+ b_1 = _apply_op(g, "const", [], [types.float32])
+
+ with g.control_dependencies([a_1]):
+ with g.control_dependencies([a_2]):
+ with g.control_dependencies([a_3]):
+ with g.control_dependencies([a_4]):
+ b_2 = _apply_op(g, "const", [], [types.float32])
+
+ self.assertItemsEqual(
+ [a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs)
+ self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
+
+ def testComplex(self):
+ g = ops.Graph()
+
+ # Usage pattern:
+ # * Nodes a_i are constants defined at the outermost scope, and are used
+ # as control inputs for the ith nested scope.
+ # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
+ # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
+ # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
+ # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
+
+ a_1 = _apply_op(g, "const", [], [types.float32])
+ a_2 = _apply_op(g, "const", [], [types.float32])
+ a_3 = _apply_op(g, "const", [], [types.float32])
+ a_4 = _apply_op(g, "const", [], [types.float32])
+
+ with g.control_dependencies([a_1]):
+ b_1 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_1 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_1 = _apply_op(g, "mul", [b_1, c_1], [types.float32])
+ e_1 = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a_2]):
+ b_2 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_2 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_2 = _apply_op(g, "mul", [b_2, c_2], [types.float32])
+ e_2 = _apply_op(g, "mul", [e_1, e_1], [types.float32])
+ with g.control_dependencies([a_3]):
+ b_3 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_3 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_3 = _apply_op(g, "mul", [b_3, c_3], [types.float32])
+ e_3 = _apply_op(g, "mul", [e_2, e_2], [types.float32])
+ with g.control_dependencies([a_4]):
+ b_4 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+ c_4 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+ d_4 = _apply_op(g, "mul", [b_4, c_4], [types.float32])
+ e_4 = _apply_op(g, "mul", [e_3, e_3], [types.float32])
+
+ self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)
+
+ self.assertItemsEqual([], c_1.op.control_inputs)
+ self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
+ self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
+ self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)
+
+ self.assertItemsEqual([], d_1.op.control_inputs)
+ self.assertItemsEqual([], d_2.op.control_inputs)
+ self.assertItemsEqual([], d_3.op.control_inputs)
+ self.assertItemsEqual([], d_4.op.control_inputs)
+
+ self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
+ self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
+ self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
+ self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
+
+ def testRepeatedDependency(self):
+ g = ops.Graph()
+ a = g.create_op("foo", [], [types.float32, types.float32])
+ a_0, a_1 = a.outputs
+ with g.control_dependencies([a_0]):
+ b = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a_1]):
+ c = _apply_op(g, "const", [], [types.float32])
+
+ self.assertEqual(b.op.control_inputs, [a])
+ self.assertEqual(c.op.control_inputs, [a])
+
+ def testNoControlDependencyWithDataDependency(self):
+ g = ops.Graph()
+ a = _apply_op(g, "const", [], [types.float32])
+ with g.control_dependencies([a]):
+ b = _apply_op(g, "identity", [a], [types.float32])
+
+ self.assertEqual(b.op.control_inputs, [])
+
+
+class GraphTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ ops.reset_default_graph()
+
+ def _AssertDefault(self, expected):
+ self.assertIs(expected, ops.get_default_graph())
+
+ def testGraphContextManager(self):
+ g0 = ops.Graph()
+ with g0.as_default() as g1:
+ self.assertIs(g0, g1)
+
+ def testDefaultGraph(self):
+ orig = ops.get_default_graph()
+ self._AssertDefault(orig)
+ g0 = ops.Graph()
+ self._AssertDefault(orig)
+ context_manager_0 = g0.as_default()
+ self._AssertDefault(orig)
+ with context_manager_0 as g0:
+ self._AssertDefault(g0)
+ with ops.Graph().as_default() as g1:
+ self._AssertDefault(g1)
+ self._AssertDefault(g0)
+ self._AssertDefault(orig)
+
+ def testAsGraphElementConversions(self):
+ class ConvertibleObj(object):
+
+ def _as_graph_element(self):
+ return "const:0"
+
+ class NonConvertibleObj(object):
+
+ pass
+
+ g = ops.Graph()
+ a = _apply_op(g, "const", [], [types.float32])
+ self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
+ with self.assertRaises(TypeError):
+ g.as_graph_element(NonConvertibleObj())
+
+ def testAssertSameGraph(self):
+ g0 = ops.Graph()
+ a = g0.create_op("a", [], [types.float32])
+ b = g0.create_op("b", [], [types.float32])
+ ops.assert_same_graph([a, b])
+ ops.assert_same_graph([a, b], g0)
+ g1 = ops.Graph()
+ c = g1.create_op("c", [], [types.float32])
+ self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c])
+ self.assertRaises(ValueError, ops.assert_same_graph, [c], g0)
+ self.assertRaises(ValueError, ops.assert_same_graph, [a], g1)
+
+ sparse = ops.SparseTensor(
+ _apply_op(g0, "const", [], [types.int64]),
+ _apply_op(g0, "const", [], [types.float32]),
+ _apply_op(g0, "const", [], [types.int64]))
+ ops.assert_same_graph([sparse, a, b])
+ ops.assert_same_graph([sparse, a, b], g0)
+ self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c])
+ self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1)
+
+ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
+
+
+class KernelLabelTest(test_util.TensorFlowTestCase):
+
+ def testNoLabel(self):
+ with self.test_session():
+ self.assertAllEqual("My label is: default",
+ test_kernel_label_op.kernel_label().eval())
+
+ def testLabelMap(self):
+ with self.test_session() as sess:
+ default_1 = test_kernel_label_op.kernel_label()
+ # pylint: disable=protected-access
+ with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
+ overload_1_1 = test_kernel_label_op.kernel_label()
+ with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}):
+ overload_2 = test_kernel_label_op.kernel_label()
+ with sess.graph._kernel_label_map({"KernelLabel": ""}):
+ default_2 = test_kernel_label_op.kernel_label()
+ overload_1_2 = test_kernel_label_op.kernel_label()
+ # pylint: enable=protected-access
+ default_3 = test_kernel_label_op.kernel_label()
+
+ self.assertAllEqual("My label is: default", default_1.eval())
+ self.assertAllEqual("My label is: default", default_2.eval())
+ self.assertAllEqual("My label is: default", default_3.eval())
+ self.assertAllEqual("My label is: overload_1", overload_1_1.eval())
+ self.assertAllEqual("My label is: overload_1", overload_1_2.eval())
+ self.assertAllEqual("My label is: overload_2", overload_2.eval())
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
new file mode 100644
index 0000000000..5c1b4462d5
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -0,0 +1,678 @@
+#include "tensorflow/python/framework/python_op_gen.h"
+
+#include <stdio.h>
+#include <unordered_map>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+const int kRightMargin = 78;
+
+bool IsPythonReserved(const string& s) {
+ static const std::set<string>* const kPythonReserved = new std::set<string>(
+ {// Keywords in Python, from:
+ // import keyword
+ // print keyword.kwlist
+ "and", "as", "assert", "break", "class", "continue", "def", "del",
+ "elif", "else", "except", "exec", "finally", "for", "from", "global",
+ "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
+ "raise", "return", "try", "while", "with", "yield",
+ // Built-in functions and types in Python, from:
+ // [x for x in dir(__builtins__) if not x[0].islower()]
+ "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
+ "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
+ "Ellipsis", "EnvironmentError", "Exception", "False",
+ "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
+ "ImportError", "ImportWarning", "IndentationError", "IndexError",
+ "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
+ "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
+ "OverflowError", "PendingDeprecationWarning", "ReferenceError",
+ "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
+ "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
+ "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
+ "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
+ "UnicodeWarning", "UserWarning", "ValueError", "Warning",
+ "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
+ "__package__",
+ // Imports and symbols used in the generated code:
+ "_op_def_lib", "text_format", "op_def_pb2", "op_def_library", "ops"});
+
+ return kPythonReserved->count(s) > 0;
+}
+
+// Add a _ to the end of s if necessary to avoid a Python keyword or built-in.
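+// For example, "lambda" becomes "lambda_", while a non-reserved name such as
+// "matmul" is returned unchanged.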
+string AvoidPythonReserved(const string& s) {
+ if (IsPythonReserved(s)) return strings::StrCat(s, "_");
+ return s;
+}
+
+// Indent the first line by "initial" spaces and all following lines
+// by "rest" spaces.
+string Indent(int initial, int rest, StringPiece in) {
+ // TODO(josh11b): Also word-wrapping?
+ string copy(in.data(), in.size());
+ str_util::StripTrailingWhitespace(&copy);
+ std::vector<string> v = str_util::Split(copy, '\n');
+
+ string result;
+ bool first = true;
+ for (const string& line : v) {
+ if (first) {
+ result = strings::StrCat(Spaces(initial), line, "\n");
+ first = false;
+ } else {
+ if (line.empty()) {
+ strings::StrAppend(&result, "\n");
+ } else {
+ strings::StrAppend(&result, Spaces(rest), line, "\n");
+ }
+ }
+ }
+ return result;
+}
+
+// Adds append to *dest, with a space if the first line will be <= width,
+// or a newline otherwise.
+void AppendWithinWidth(string* dest, StringPiece append, int width) {
+ auto first_line = append.find('\n');
+ if (first_line == string::npos) first_line = append.size();
+ if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
+ strings::StrAppend(dest, "\n", append);
+ } else {
+ strings::StrAppend(dest, " ", append);
+ }
+}
+
+void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+ for (int i = 0; i < op_def->input_arg_size(); ++i) {
+ op_def->mutable_input_arg(i)->clear_description();
+ }
+ for (int i = 0; i < op_def->output_arg_size(); ++i) {
+ op_def->mutable_output_arg(i)->clear_description();
+ }
+ for (int i = 0; i < op_def->attr_size(); ++i) {
+ op_def->mutable_attr(i)->clear_description();
+ }
+ op_def->clear_summary();
+ op_def->clear_description();
+}
+
+// Like DataTypeString() but uses the Python names for the
+// float types.
+string PythonDataTypeString(DataType dtype) {
+ switch (dtype) {
+ case DT_FLOAT:
+ return "float32";
+ case DT_DOUBLE:
+ return "float64";
+ default:
+ return DataTypeString(dtype);
+ }
+}
+
+string TypeString(DataType dtype, bool ref) {
+ if (ref) {
+ return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
+ } else {
+ return strings::StrCat("`", PythonDataTypeString(dtype), "`");
+ }
+}
+
+string TypeListString(const AttrValue& value) {
+ string ret;
+ for (int t : value.list().type()) {
+ if (!ret.empty()) strings::StrAppend(&ret, ", ");
+ DataType dtype = static_cast<DataType>(t);
+ if (IsRefType(dtype)) {
+ strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
+ " mutable");
+ } else {
+ strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
+ }
+ }
+ return ret;
+}
+
+string SingleTensorName(DataType dtype, bool is_ref) {
+ const string type_str = TypeString(dtype, is_ref);
+ return strings::StrCat("A `Tensor` of type ", type_str, ".");
+}
+
+const char kUnknownTensorType[] = {"A `Tensor`."};
+
+string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
+ const std::unordered_map<string, string>& inferred_attrs,
+ bool is_output) {
+ if (!arg.number_attr().empty()) {
+ // N Tensors with the same type
+ const string* original_arg =
+ gtl::FindOrNull(inferred_attrs, arg.number_attr());
+ string prefix;
+ if (original_arg == nullptr) {
+ prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
+ } else if (*original_arg == arg.name()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
+ if (attr->has_minimum() && attr->minimum() > 0) {
+ prefix = strings::StrCat("A list of at least ", attr->minimum());
+ } else {
+ prefix = "A list of";
+ }
+ } else {
+ prefix = strings::StrCat(
+ "A list with the same number of `Tensor` objects as `",
+ AvoidPythonReserved(*original_arg), "` of");
+ }
+
+ if (arg.type() != DT_INVALID) {
+ return strings::StrCat(prefix, " `Tensor` objects of type ",
+ TypeString(arg.type(), arg.is_ref()), ".");
+ } else {
+ original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
+ if (arg.is_ref()) {
+ strings::StrAppend(&prefix, " mutable");
+ }
+ if (original_arg == nullptr) {
+ return strings::StrCat(prefix, " `Tensor` objects of type ",
+ arg.type_attr(), ".");
+ } else if (*original_arg == arg.name()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
+ if (attr->has_allowed_values()) {
+ return strings::StrCat(prefix,
+ " `Tensor` objects of the same type in: ",
+ TypeListString(attr->allowed_values()), ".");
+ } else {
+ return strings::StrCat(prefix, " `Tensor` objects of the same type.");
+ }
+ } else {
+ return strings::StrCat(prefix, " `Tensor` objects of the same type as ",
+ AvoidPythonReserved(*original_arg), ".");
+ }
+ }
+ } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
+ const bool is_list = !arg.type_list_attr().empty();
+ const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
+ const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
+ const string mutable_str = arg.is_ref() ? "mutable " : "";
+ const string prefix =
+ is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
+ : strings::StrCat("A ", mutable_str, "`Tensor`");
+ const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
+ if (original_arg == nullptr) {
+ return strings::StrCat(prefix, " of type `", attr_name, "`.");
+ } else if (*original_arg == arg.name()) {
+ if (attr->has_allowed_values()) {
+ if (is_list) {
+ return strings::StrCat(prefix, " with types from: ",
+ TypeListString(attr->allowed_values()), ".");
+ } else {
+ return strings::StrCat(
+ prefix, is_output ? ". Has one of the following types: "
+ : ". Must be one of the following types: ",
+ TypeListString(attr->allowed_values()), ".");
+ }
+ } else {
+ return strings::StrCat(prefix, ".");
+ }
+ } else {
+ return strings::StrCat(prefix,
+ is_output ? ". Has the same type as `"
+ : ". Must have the same type as `",
+ AvoidPythonReserved(*original_arg), "`.");
+ }
+ } else {
+ return SingleTensorName(arg.type(), arg.is_ref());
+ }
+}
+
+void PrintReturns(const OpDef& op_def,
+ const std::vector<string>& output_type_string) {
+ DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
+ const int num_outs = op_def.output_arg_size();
+ printf("\n Returns:\n");
+ if (num_outs == 0) {
+ printf(" The created Operation.\n");
+ } else {
+ if (num_outs == 1) {
+ StringPiece description = op_def.output_arg(0).description();
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ printf("%s", Indent(4, 4, description).c_str());
+ } else {
+ // Special case of one output, don't use the name of the output unless
+ // there is no description.
+ string desc = output_type_string.empty() ? kUnknownTensorType
+ : output_type_string[0];
+ if (desc == kUnknownTensorType) {
+ // Special case where we don't understand how the output tensor type
+ // depends on the input tensor types, just use the output arg
+ // description if we can.
+ if (!description.empty()) {
+ desc = op_def.output_arg(0).description();
+ } else if (!op_def.output_arg(0).name().empty()) {
+ desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
+ " `Tensor`.");
+ }
+ } else if (!description.empty()) {
+ AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+ }
+ printf("%s", Indent(4, 4, desc).c_str());
+ }
+ } else {
+ std::vector<string> out_names(num_outs);
+ for (int i = 0; i < num_outs; ++i) {
+ if (!op_def.output_arg(i).name().empty()) {
+ out_names[i] = op_def.output_arg(i).name();
+ } else {
+ out_names[i] = strings::StrCat("output", i);
+ }
+ }
+ printf(" A tuple of `Tensor` objects (%s).\n",
+ str_util::Join(out_names, ", ").c_str());
+ for (int i = 0; i < num_outs; ++i) {
+ string desc = strings::StrCat(out_names[i], ": ");
+ StringPiece description = op_def.output_arg(i).description();
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ strings::StrAppend(&desc, description);
+ } else {
+ const string type = static_cast<size_t>(i) < output_type_string.size()
+ ? output_type_string[i]
+ : kUnknownTensorType;
+ if (!description.empty()) {
+ if (type == kUnknownTensorType) {
+ // Special case where we don't understand how the output tensor
+ // type depends on the input tensor types, so we just use the
+ // output arg description.
+ strings::StrAppend(&desc, description);
+ } else {
+ strings::StrAppend(&desc, type, " ", description);
+ }
+ } else {
+ strings::StrAppend(&desc, type);
+ }
+ }
+ printf("%s", Indent(4, 6, desc).c_str());
+ }
+ }
+ }
+}
+
+string StringToPython(const string& str) {
+ return strings::StrCat("\"", str_util::CEscape(str), "\"");
+}
+
+string DataTypeToPython(DataType dtype) {
+ return strings::StrCat("tf.", PythonDataTypeString(dtype));
+}
+
+string ShapeToPython(const TensorShapeProto& shape) {
+ string python = "[";
+ for (const auto& dim : shape.dim()) {
+ if (python.size() > 1) strings::StrAppend(&python, ", ");
+ if (!dim.name().empty()) {
+ strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
+ dim.size(), ")");
+ } else {
+ strings::StrAppend(&python, dim.size());
+ }
+ }
+ strings::StrAppend(&python, "]");
+ return python;
+}
+
+string AttrListToPython(const AttrValue& value) {
+ string ret;
+ if (value.list().s_size() > 0) {
+ for (int i = 0; i < value.list().s_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, StringToPython(value.list().s(i)));
+ }
+ } else if (value.list().i_size() > 0) {
+ for (int i = 0; i < value.list().i_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, value.list().i(i));
+ }
+ } else if (value.list().f_size() > 0) {
+ for (int i = 0; i < value.list().f_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, value.list().f(i));
+ }
+ } else if (value.list().b_size() > 0) {
+ for (int i = 0; i < value.list().b_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, value.list().b(i) ? "True" : "False");
+ }
+ } else if (value.list().type_size() > 0) {
+ for (int i = 0; i < value.list().type_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, DataTypeToPython(value.list().type(i)));
+ }
+ } else if (value.list().shape_size() > 0) {
+ for (int i = 0; i < value.list().shape_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, ShapeToPython(value.list().shape(i)));
+ }
+ }
+ return ret;
+}
+
+string AttrValueToPython(const string& type, const AttrValue& value) {
+ if (type == "string") {
+ return StringToPython(value.s());
+ } else if (type == "int") {
+ return strings::StrCat(value.i());
+ } else if (type == "float") {
+ return strings::StrCat(value.f());
+ } else if (type == "bool") {
+ return value.b() ? "True" : "False";
+ } else if (type == "type") {
+ return DataTypeToPython(value.type());
+ } else if (type == "shape") {
+ return ShapeToPython(value.shape());
+ } else {
+ return strings::StrCat("[", AttrListToPython(value), "]");
+ }
+}
+
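+// For an op "Foo" with one input "x" and one default-valued attr "T", the
+// emitted wrapper has roughly this shape (an illustrative sketch inferred
+// from the printing logic below, not verbatim output):
+//
+//   def foo(x, T=None, name=None):
+//     r"""..."""
+//     return _op_def_lib.apply_op("Foo", x=x, T=T, name=name)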
+// Requires: ValidateOpDef(op_def).ok()
+void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
+ // Map from attr name to the first input arg it is inferred from.
+ std::unordered_map<string, string> inferred_attrs;
+ // This has all the input args followed by those attrs that don't have
+ // defaults.
+ std::vector<string> args_no_default;
+ // The parameters with defaults (these have to be listed after those without).
+ // No input args are included, just attrs and the graph ("g") parameter.
+ std::vector<string> args_with_defaults;
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ const auto& arg(op_def.input_arg(i));
+ args_no_default.push_back(arg.name());
+ if (!arg.type_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
+ } else if (!arg.type_list_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
+ arg.name());
+ }
+ if (!arg.number_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
+ }
+ }
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ const auto& attr(op_def.attr(i));
+ // Do not add inferred attrs to the Python function signature.
+ if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
+ if (attr.has_default_value()) {
+ args_with_defaults.push_back(attr.name());
+ } else {
+ args_no_default.push_back(attr.name());
+ }
+ }
+ }
+
+ // Save the list of attr parameters (attrs that won't be inferred),
+ // those with defaults go at the end.
+ std::vector<string> attrs;
+  // Get the attrs in the order we want by taking the attrs without defaults
+  // from the end of args_no_default (the entries after the input args), and
+  // then appending args_with_defaults.
+ attrs.reserve(args_no_default.size() - op_def.input_arg_size() +
+ args_with_defaults.size());
+ attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(),
+ args_no_default.end());
+ attrs.insert(attrs.end(), args_with_defaults.begin(),
+ args_with_defaults.end());
+
+ std::vector<string> param_names;
+ param_names.reserve(args_no_default.size() + args_with_defaults.size());
+ string parameters;
+ for (const string& name : args_no_default) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ const string param = AvoidPythonReserved(name);
+ strings::StrAppend(&parameters, param);
+ param_names.push_back(param);
+ }
+ for (const string& name : args_with_defaults) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ const string param = AvoidPythonReserved(name);
+ strings::StrAppend(&parameters, param, "=None");
+ param_names.push_back(param);
+ }
+ const bool has_args = args_no_default.size() + args_with_defaults.size() > 0;
+
+ // Print: def Function(parameters):
+ const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name);
+
+ const string def_prefix = strings::StrCat("def ", lower_op_name, "(");
+ const string def_suffix =
+ strings::StrCat(parameters, has_args ? ", " : "", "name=None):");
+
+ printf("%s\n", WordWrap(def_prefix, def_suffix, kRightMargin).c_str());
+
+ // Format the Op's descriptions so that it can be a Python docstring.
+ string comment;
+ if (op_def.summary().empty()) {
+ comment = "TODO: add doc.\n";
+ } else {
+ comment = strings::StrCat(op_def.summary(), "\n");
+ if (!op_def.description().empty()) {
+ strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description()));
+ }
+ }
+
+ printf(R"( r"""%s
+ Args:
+)",
+ comment.c_str());
+
+ // Inputs
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ const auto& arg(op_def.input_arg(i));
+ StringPiece description = op_def.input_arg(i).description();
+ string desc;
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ desc = strings::StrCat(param_names[i], ": ");
+ } else {
+ desc = strings::StrCat(param_names[i], ": ",
+ ArgTypeName(op_def, arg, inferred_attrs, false));
+ }
+ if (!description.empty()) {
+ AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+ }
+ printf("%s", Indent(4, 6, desc).c_str());
+ }
+
+ // Attrs
+ for (const string& name : attrs) {
+ const auto& attr = *FindAttr(name, op_def);
+ string desc = strings::StrCat(AvoidPythonReserved(name), ": ");
+
+ static const char* const kAttrTypeName[][2] = {
+ {"string", "`string`"},
+ {"list(string)", "list of `strings`"},
+ {"int", "`int`"},
+ {"list(int)", "list of `ints`"},
+ {"float", "`float`"},
+ {"list(float)", "list of `floats`"},
+ {"bool", "`bool`"},
+ {"list(bool)", "list of `bools`"},
+ {"type", "`tf.DType`"},
+ {"list(type)", "list of `tf.DTypes`"},
+ {"shape", "`tf.TensorShape` or list of `ints`"},
+ {"list(shape)",
+ "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
+ };
+ for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
+ if (attr.type() == kAttrTypeName[i][0]) {
+ string s;
+ if (attr.has_default_value()) {
+ s = strings::StrCat("optional ", kAttrTypeName[i][1]);
+ } else {
+ s = kAttrTypeName[i][1];
+ }
+ if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
+ strings::StrAppend(&desc, "An ", s);
+ } else {
+ strings::StrAppend(&desc, "A ", s);
+ }
+ break;
+ }
+ }
+
+ if (attr.has_allowed_values()) {
+ strings::StrAppend(&desc, " from: `",
+ AttrListToPython(attr.allowed_values()), "`");
+ }
+
+ if (attr.has_minimum()) {
+ if (attr.type() == "int") {
+ strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
+ } else if (attr.minimum() > 0) {
+ strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
+ }
+ }
+
+ strings::StrAppend(&desc, ".");
+
+ if (attr.has_default_value()) {
+ strings::StrAppend(&desc, " Defaults to `",
+ AttrValueToPython(attr.type(), attr.default_value()),
+ "`.");
+ }
+
+ if (!attr.description().empty()) {
+ AppendWithinWidth(&desc, attr.description(),
+ kRightMargin - 4 /* indent */);
+ }
+ printf("%s", Indent(4, 6, desc).c_str());
+ }
+
+ printf(" name: A name for the operation (optional).\n");
+
+ std::vector<string> output_type_string;
+ output_type_string.reserve(op_def.output_arg_size());
+ for (int i = 0; i < op_def.output_arg_size(); ++i) {
+ output_type_string.push_back(
+ ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true));
+ }
+ PrintReturns(op_def, output_type_string);
+
+ string return_prefix = strings::StrCat(" return _op_def_lib.apply_op(");
+ string return_args = strings::StrCat("\"", op_def.name(), "\", ");
+ for (size_t i = 0; i < param_names.size(); ++i) {
+ strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", ");
+ }
+ strings::StrAppend(&return_args, "name=name)");
+
+ printf(R"( """
+%s
+)",
+ // Wrap the arguments, and indent to the (.
+ WordWrap(return_prefix, return_args, kRightMargin).c_str());
+
+ printf("\n\n");
+}
+
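+// Converts a CamelCase op name to a lower_case Python function name by
+// inserting '_' at case transitions, e.g. (worked by hand from the rule
+// below) "MatMul" -> "mat_mul" and "CSVReader" -> "csv_reader".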
+void GenerateLowerCaseOpName(const string& str, string* result) {
+ char joiner = '_';
+ int last_index = str.size() - 1;
+ for (int i = 0; i <= last_index; ++i) {
+ char c = str[i];
+ // Emit a joiner only if a previous-lower-to-now-upper or a
+ // now-upper-to-next-lower transition happens.
+ if (isupper(c) && (i > 0)) {
+ if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
+ result->push_back(joiner);
+ }
+ }
+ result->push_back(tolower(c));
+ }
+}
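GenerateLowerCaseOpName turns CamelCase op names into the snake_case function names used in the generated module. A minimal Python re-statement of the same rule, with a few illustrative conversions (the op names are made up):

```python
def lower_case_op_name(name):
  """Sketch of the rule above: emit '_' at lower->UPPER and UPPER->lower
  boundaries, then lowercase everything."""
  out = []
  last = len(name) - 1
  for i, c in enumerate(name):
    if c.isupper() and i > 0:
      if name[i - 1].islower() or (i < last and name[i + 1].islower()):
        out.append("_")
    out.append(c.lower())
  return "".join(out)

assert lower_case_op_name("MatMul") == "mat_mul"
assert lower_case_op_name("BatchMatMul") == "batch_mat_mul"
assert lower_case_op_name("L2Loss") == "l2_loss"
```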
+
+} // namespace
+
+void PrintPythonOps(const OpList& ops, const string& hidden_ops,
+ bool require_shapes) {
+ // Header
+ // TODO(josh11b): Mention the library for which wrappers are being generated.
+ printf(R"("""Python wrappers around Brain.
+
+This file is MACHINE GENERATED! Do not edit.
+"""
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import op_def_library
+
+
+)");
+
+ std::vector<string> hidden_vec = str_util::Split(hidden_ops, ',');
+
+ // We'll make a copy of ops that filters out descriptions.
+ OpList cleaned_ops;
+ auto out = cleaned_ops.mutable_op();
+ out->Reserve(ops.op_size());
+ for (const auto& op_def : ops.op()) {
+ bool is_hidden = false;
+ for (const string& hidden : hidden_vec) {
+ if (op_def.name() == hidden) {
+ is_hidden = true;
+ break;
+ }
+ }
+
+ // PrintPythonOp(op_def, is_hidden, op_def.name());
+ string lower_case_name;
+ GenerateLowerCaseOpName(op_def.name(), &lower_case_name);
+
+ // When users create custom python wrappers, they may link in the
+ // default op registry by accident, and because they can't
+ // enumerate all 'hidden' symbols, this guard is to prevent
+ // instantiating a python reserved word in their wrapper.
+ if (!is_hidden && IsPythonReserved(lower_case_name)) {
+ continue;
+ }
+
+ PrintPythonOp(op_def, is_hidden, lower_case_name);
+
+ if (!require_shapes) {
+ printf("ops.RegisterShape(\"%s\")(None)\n", op_def.name().c_str());
+ }
+
+ auto added = out->Add();
+ *added = op_def;
+ RemoveDescriptionsFromOpDef(added);
+ }
+
+ printf(R"(def _InitOpDefLibrary():
+ op_list = op_def_pb2.OpList()
+ text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
+ op_def_registry.register_op_list(op_list)
+ op_def_lib = op_def_library.OpDefLibrary()
+ op_def_lib.add_op_list(op_list)
+ return op_def_lib
+
+
+_InitOpDefLibrary.op_list_ascii = """%s"""
+
+
+_op_def_lib = _InitOpDefLibrary()
+)",
+ cleaned_ops.DebugString().c_str());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h
new file mode 100644
index 0000000000..488f7431e0
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
+#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
+
+#include <string>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Result is printed to stdout. hidden_ops should be a comma-separated
+// list of Op names that should get a leading _ in the output.
+void PrintPythonOps(const OpList& ops, const string& hidden_ops,
+ bool require_shapes);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
new file mode 100644
index 0000000000..29afe35598
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -0,0 +1,30 @@
+#include "tensorflow/python/framework/python_op_gen.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace {
+
+void PrintAllPythonOps(const char* hidden, bool require_shapes) {
+ OpList ops;
+ OpRegistry::Global()->Export(false, &ops);
+ PrintPythonOps(ops, hidden, require_shapes);
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (argc == 2) {
+ tensorflow::PrintAllPythonOps("", std::string(argv[1]) == "1");
+ } else if (argc == 3) {
+ tensorflow::PrintAllPythonOps(argv[1], std::string(argv[2]) == "1");
+ } else {
+ return -1;
+ }
+ return 0;
+}
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
new file mode 100644
index 0000000000..d0ffee7042
--- /dev/null
+++ b/tensorflow/python/framework/random_seed.py
@@ -0,0 +1,136 @@
+"""For seeding individual ops based on a graph-level seed.
+"""
+
+from tensorflow.python.framework import ops
+
+
+_DEFAULT_GRAPH_SEED = 87654321
+
+
+def get_seed(op_seed):
+ """Returns the local seeds an operation should use given an op-specific seed.
+
+  Given an operation-specific seed, `op_seed`, this helper function returns
+  two seeds derived from the graph-level and op-level seeds. Many random
+  operations internally use the two seeds to allow the user to change the
+  seed globally for a graph, or only for specific operations.
+
+ For details on how the graph-level seed interacts with op seeds, see
+ [`set_random_seed`](constant_op.md#set_random_seed).
+
+ Args:
+ op_seed: integer.
+
+ Returns:
+ A tuple of two integers that should be used for the local seed of this
+ operation.
+ """
+ graph_seed = ops.get_default_graph().seed
+ if graph_seed is not None:
+ if op_seed is not None:
+ return graph_seed, op_seed
+ else:
+ return graph_seed, ops.get_default_graph()._last_id
+ else:
+ if op_seed is not None:
+ return _DEFAULT_GRAPH_SEED, op_seed
+ else:
+ return None, None
+
+
+def set_random_seed(seed):
+ """Sets the graph-level random seed.
+
+ Operations that rely on a random seed actually derive it from two seeds:
+ the graph-level and operation-level seeds. This sets the graph-level seed.
+
+ Its interactions with operation-level seeds is as follows:
+
+ 1. If neither the graph-level nor the operation seed is set:
+ A random seed is used for this op.
+ 2. If the graph-level seed is set, but the operation seed is not:
+ The system deterministically picks an operation seed in conjunction
+ with the graph-level seed so that it gets a unique random sequence.
+ 3. If the graph-level seed is not set, but the operation seed is set:
+ A default graph-level seed and the specified operation seed are used to
+ determine the random sequence.
+ 4. If both the graph-level and the operation seed are set:
+ Both seeds are used in conjunction to determine the random sequence.
+
+ To illustrate the user-visible effects, consider these examples:
+
+ To generate different sequences across sessions, set neither
+ graph-level nor op-level seeds:
+
+ ```python
+ a = tf.random_uniform([1])
+ b = tf.random_normal([1])
+
+ print "Session 1"
+ with tf.Session() as sess1:
+ print sess1.run(a) # generates 'A1'
+ print sess1.run(a) # generates 'A2'
+ print sess1.run(b) # generates 'B1'
+ print sess1.run(b) # generates 'B2'
+
+ print "Session 2"
+ with tf.Session() as sess2:
+ print sess2.run(a) # generates 'A3'
+ print sess2.run(a) # generates 'A4'
+ print sess2.run(b) # generates 'B3'
+ print sess2.run(b) # generates 'B4'
+ ```
+
+ To generate the same repeatable sequence for an op across sessions, set the
+ seed for the op:
+
+ ```python
+ a = tf.random_uniform([1], seed=1)
+ b = tf.random_normal([1])
+
+ # Repeatedly running this block with the same graph will generate the same
+ # sequence of values for 'a', but different sequences of values for 'b'.
+ print "Session 1"
+ with tf.Session() as sess1:
+ print sess1.run(a) # generates 'A1'
+ print sess1.run(a) # generates 'A2'
+ print sess1.run(b) # generates 'B1'
+ print sess1.run(b) # generates 'B2'
+
+ print "Session 2"
+ with tf.Session() as sess2:
+ print sess2.run(a) # generates 'A1'
+ print sess2.run(a) # generates 'A2'
+ print sess2.run(b) # generates 'B3'
+ print sess2.run(b) # generates 'B4'
+ ```
+
+ To make the random sequences generated by all ops be repeatable across
+ sessions, set a graph-level seed:
+
+ ```python
+ tf.set_random_seed(1234)
+ a = tf.random_uniform([1])
+ b = tf.random_normal([1])
+
+  # Repeatedly running this block with the same graph will generate the same
+  # sequences of 'a' and 'b'.
+ print "Session 1"
+ with tf.Session() as sess1:
+ print sess1.run(a) # generates 'A1'
+ print sess1.run(a) # generates 'A2'
+ print sess1.run(b) # generates 'B1'
+ print sess1.run(b) # generates 'B2'
+
+ print "Session 2"
+ with tf.Session() as sess2:
+ print sess2.run(a) # generates 'A1'
+ print sess2.run(a) # generates 'A2'
+ print sess2.run(b) # generates 'B1'
+ print sess2.run(b) # generates 'B2'
+ ```
+
+ Args:
+ seed: integer.
+ """
+ ops.get_default_graph().seed = seed
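Viewed from the caller's side, the four cases described in `set_random_seed` reduce to the pairs returned by `get_seed`. A small sketch, assuming the 2015-era import paths used in this file and an example op-level seed of 42:

```python
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed

# With no graph-level seed, an explicit op seed is paired with the default:
assert random_seed.get_seed(42) == (87654321, 42)
# ...and no seed at all stays nondeterministic:
assert random_seed.get_seed(None) == (None, None)

# Once a graph-level seed is set, the op seed (explicit or derived) is paired
# with it:
ops.get_default_graph().seed = 1234   # what set_random_seed() does
assert random_seed.get_seed(42) == (1234, 42)
```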
diff --git a/tensorflow/python/framework/registry.py b/tensorflow/python/framework/registry.py
new file mode 100644
index 0000000000..d9556f0a06
--- /dev/null
+++ b/tensorflow/python/framework/registry.py
@@ -0,0 +1,64 @@
+"""Registry mechanism for "registering" classes/functions for general use.
+
+This is typically used with a decorator that calls Register for adding
+a class or function to a registry.
+"""
+
+import traceback
+
+from tensorflow.python.platform import logging
+
+
+# Registry mechanism below is based on mapreduce.python.mrpython.Register.
+_LOCATION_TAG = "location"
+_TYPE_TAG = "type"
+
+
+class Registry(object):
+ """Provides a registry for saving objects."""
+
+ def __init__(self, name):
+ """Creates a new registry."""
+ self._name = name
+ self._registry = dict()
+
+ def register(self, candidate, name=None):
+ """Registers a Python object "candidate" for the given "name".
+
+ Args:
+ candidate: the candidate object to add to the registry.
+ name: an optional string specifying the registry key for the candidate.
+ If None, candidate.__name__ will be used.
+ Raises:
+      KeyError: If the same name is used twice.
+ """
+ if not name:
+ name = candidate.__name__
+ if name in self._registry:
+ (filename, line_number, function_name, _) = (
+ self._registry[name][_LOCATION_TAG])
+ raise KeyError("Registering two %s with name '%s' !"
+ "(Previous registration was in %s %s:%d)" %
+ (self._name, name, function_name, filename, line_number))
+
+ logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
+ # stack trace is [this_function, Register(), user_function,...]
+ # so the user function is #2.
+ stack = traceback.extract_stack()
+ self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: stack[2]}
+
+ def lookup(self, name):
+ """Looks up "name".
+
+ Args:
+ name: a string specifying the registry key for the candidate.
+ Returns:
+      The registered object if found.
+ Raises:
+ LookupError: if "name" has not been registered.
+ """
+ if name in self._registry:
+ return self._registry[name][_TYPE_TAG]
+ else:
+ raise LookupError(
+ "%s registry has no entry for: %s" % (self._name, name))
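The module docstring mentions that this registry is typically driven by a decorator. A sketch of that pattern, assuming a hypothetical registry name and decorated function (not part of this change):

```python
from tensorflow.python.framework import registry

_my_registry = registry.Registry("my_functions")

def RegisterMine(name=None):
  """Decorator factory that records the decorated function in _my_registry."""
  def decorator(fn):
    _my_registry.register(fn, name)
    return fn
  return decorator

@RegisterMine("square")
def _square(x):
  return x * x

assert _my_registry.lookup("square")(3) == 9
```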
diff --git a/tensorflow/python/framework/registry_test.py b/tensorflow/python/framework/registry_test.py
new file mode 100644
index 0000000000..5b4f261ceb
--- /dev/null
+++ b/tensorflow/python/framework/registry_test.py
@@ -0,0 +1,38 @@
+"""Tests for tensorflow.ops.registry."""
+
+from tensorflow.python.framework import registry
+from tensorflow.python.platform import googletest
+
+
+class RegistryTest(googletest.TestCase):
+
+ class Foo(object):
+ pass
+
+ def testRegisterClass(self):
+ myreg = registry.Registry('testfoo')
+ with self.assertRaises(LookupError):
+ myreg.lookup('Foo')
+ myreg.register(RegistryTest.Foo, 'Foo')
+ assert myreg.lookup('Foo') == RegistryTest.Foo
+
+ def testRegisterFunction(self):
+ myreg = registry.Registry('testbar')
+ with self.assertRaises(LookupError):
+ myreg.lookup('Bar')
+ myreg.register(bar, 'Bar')
+ assert myreg.lookup('Bar') == bar
+
+ def testDuplicate(self):
+ myreg = registry.Registry('testbar')
+ myreg.register(bar, 'Bar')
+ with self.assertRaises(KeyError):
+ myreg.register(bar, 'Bar')
+
+
+def bar():
+ pass
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
new file mode 100644
index 0000000000..d4f27696d4
--- /dev/null
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -0,0 +1,743 @@
+"""Helper classes for tensor shape inference."""
+import tensorflow.python.platform
+
+
+class Dimension(object):
+ """Represents the value of one dimension in a TensorShape."""
+
+ def __init__(self, value):
+ """Creates a new Dimension with the given value."""
+ if value is None:
+ self._value = None
+ else:
+ self._value = int(value)
+
+ def __repr__(self):
+ return "Dimension(%s)" % repr(self._value)
+
+ def __eq__(self, other):
+ """Returns true if `other` has the same known value as this Dimension."""
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ return self._value == other.value
+
+ def __ne__(self, other):
+ """Returns true if `other` has a different known value from `self`."""
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ return self._value != other.value
+
+ def __int__(self):
+ return self._value
+
+ @property
+ def value(self):
+ """The value of this dimension, or None if it is unknown."""
+ return self._value
+
+ def is_compatible_with(self, other):
+ """Returns true if `other` is compatible with this Dimension.
+
+ Two known Dimensions are compatible if they have the same value.
+ An unknown Dimension is compatible with all other Dimensions.
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ True if this Dimension and `other` are compatible.
+ """
+ other = as_dimension(other)
+ return (self._value is None
+ or other.value is None
+ or self._value == other.value)
+
+ def assert_is_compatible_with(self, other):
+ """Raises an exception if `other` is not compatible with this Dimension.
+
+ Args:
+ other: Another Dimension.
+
+ Raises:
+ ValueError: If `self` and `other` are not compatible (see
+ is_compatible_with).
+ """
+ if not self.is_compatible_with(other):
+ raise ValueError("Dimensions %s and %s are not compatible"
+ % (self, other))
+
+ def merge_with(self, other):
+ """Returns a Dimension that combines the information in `self` and `other`.
+
+ Dimensions are combined as follows:
+
+ Dimension(n) .merge_with(Dimension(n)) == Dimension(n)
+ Dimension(n) .merge_with(Dimension(None)) == Dimension(n)
+ Dimension(None).merge_with(Dimension(n)) == Dimension(n)
+ Dimension(None).merge_with(Dimension(None)) == Dimension(None)
+ Dimension(n) .merge_with(Dimension(m)) raises ValueError for n != m
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension containing the combined information of `self` and
+ `other`.
+
+ Raises:
+ ValueError: If `self` and `other` are not compatible (see
+ is_compatible_with).
+ """
+ other = as_dimension(other)
+ self.assert_is_compatible_with(other)
+ if self._value is None:
+ return Dimension(other.value)
+ else:
+ return Dimension(self._value)
+
+ def __add__(self, other):
+ """Returns the sum of `self` and `other`.
+
+ Dimensions are summed as follows:
+
+ Dimension(m) + Dimension(n) == Dimension(m + n)
+ Dimension(m) + Dimension(None) == Dimension(None)
+ Dimension(None) + Dimension(n) == Dimension(None)
+ Dimension(None) + Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension whose value is the sum of `self` and `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value + other.value)
+
+ def __sub__(self, other):
+ """Returns the subtraction of `other` from `self`.
+
+ Dimensions are subtracted as follows:
+
+ Dimension(m) - Dimension(n) == Dimension(m - n)
+ Dimension(m) - Dimension(None) == Dimension(None)
+ Dimension(None) - Dimension(n) == Dimension(None)
+ Dimension(None) - Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+      A Dimension whose value is the result of subtracting `other` from `self`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value - other.value)
+
+ def __mul__(self, other):
+ """Returns the product of `self` and `other`.
+
+    Dimensions are multiplied as follows:
+
+ Dimension(m) * Dimension(n) == Dimension(m * n)
+ Dimension(m) * Dimension(None) == Dimension(None)
+ Dimension(None) * Dimension(n) == Dimension(None)
+ Dimension(None) * Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+      A Dimension whose value is the product of `self` and `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value * other.value)
+
+ def __div__(self, other):
+ """Returns the quotient of `self` and `other`.
+
+    Dimensions are divided as follows:
+
+ Dimension(m) / Dimension(n) == Dimension(m / n)
+ Dimension(m) / Dimension(None) == Dimension(None)
+ Dimension(None) / Dimension(n) == Dimension(None)
+ Dimension(None) / Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+      A Dimension whose value is the quotient of `self` and `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value / other.value)
+
+ def __mod__(self, other):
+ """Returns `self` modulo `other.
+
+ Dimension moduli are computed as follows:
+
+ Dimension(m) % Dimension(n) == Dimension(m % n)
+ Dimension(m) % Dimension(None) == Dimension(None)
+ Dimension(None) % Dimension(n) == Dimension(None)
+ Dimension(None) % Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension whose value is `self` modulo `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value % other.value)
+
+ def __lt__(self, other):
+ """Returns True if `self` is known to be less than `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) < Dimension(n) == m < n
+ Dimension(m) < Dimension(None) == None
+ Dimension(None) < Dimension(n) == None
+ Dimension(None) < Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value < other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value < other.value
+
+ def __le__(self, other):
+ """Returns True if `self` is known to be less than or equal to `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) <= Dimension(n) == m <= n
+ Dimension(m) <= Dimension(None) == None
+ Dimension(None) <= Dimension(n) == None
+ Dimension(None) <= Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value <= other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value <= other.value
+
+ def __gt__(self, other):
+ """Returns True if `self` is known to be greater than `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) > Dimension(n) == m > n
+ Dimension(m) > Dimension(None) == None
+ Dimension(None) > Dimension(n) == None
+ Dimension(None) > Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value > other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value > other.value
+
+ def __ge__(self, other):
+ """Returns True if `self` is known to be greater than or equal to `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) >= Dimension(n) == m >= n
+ Dimension(m) >= Dimension(None) == None
+ Dimension(None) >= Dimension(n) == None
+ Dimension(None) >= Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value >= other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value >= other.value
+
+
+def as_dimension(value):
+ """Converts the given value to a Dimension.
+
+  A Dimension input will be returned unmodified.
+ An input of `None` will be converted to an unknown Dimension.
+ An integer input will be converted to a Dimension with that value.
+
+ Args:
+ value: The value to be converted.
+
+ Returns:
+ A Dimension corresponding to the given value.
+ """
+ if isinstance(value, Dimension):
+ return value
+ else:
+ return Dimension(value)
+
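A quick illustration of Dimension arithmetic and its three-valued comparisons, run against the class above (values are made up; the module is assumed importable under its usual path):

```python
from tensorflow.python.framework import tensor_shape

known = tensor_shape.Dimension(12)
unknown = tensor_shape.Dimension(None)

print (known + 3).value        # 15
print (known * unknown).value  # None -- unknown propagates through arithmetic
print known < unknown          # None, not False: the ordering is unknown
print known.merge_with(12)     # Dimension(12)
```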
+
+class TensorShape(object):
+ """Represents the shape of a `Tensor`.
+
+ A `TensorShape` represents a possibly-partial shape specification for a
+ `Tensor`. It may be one of the following:
+
+ * *Fully-known shape:* has a known number of dimensions and a known size
+ for each dimension.
+ * *Partially-known shape:* has a known number of dimensions, and an unknown
+    size for one or more dimensions.
+ * *Unknown shape:* has an unknown number of dimensions, and an unknown
+ size in all dimensions.
+
+ If a tensor is produced by an operation of type `"Foo"`, its shape
+ may be inferred if there is a registered shape function for
+ `"Foo"`. See [`tf.RegisterShape()`](framework.md#RegisterShape)
+ for details of shape
+ functions and how to register them. Alternatively, the shape may be set
+ explicitly using [`Tensor.set_shape()`](framework.md#Tensor.set_shape).
+
+ @@merge_with
+ @@concatenate
+
+ @@ndims
+ @@dims
+ @@as_list
+ @@is_compatible_with
+ @@is_fully_defined
+
+ @@with_rank
+ @@with_rank_at_least
+ @@with_rank_at_most
+
+ @@assert_has_rank
+ @@assert_same_rank
+ @@assert_is_compatible_with
+ @@assert_is_fully_defined
+ """
+
+ def __init__(self, dims):
+ """Creates a new TensorShape with the given dimensions.
+
+ Args:
+ dims: A list of Dimensions, or None if the shape is unspecified.
+ DEPRECATED: A single integer is treated as a singleton list.
+ """
+ # TODO(irving): Eliminate the single integer special case.
+ if dims is None:
+ self._dims = None
+ else:
+ try:
+ dims_iter = iter(dims)
+ except TypeError:
+ # Treat as a singleton dimension
+ self._dims = [as_dimension(dims)]
+ else:
+ # Got a list of dimensions
+ self._dims = map(as_dimension, dims_iter)
+
+ def __repr__(self):
+ return "TensorShape(%s)" % str(self._dims)
+
+ @property
+ def dims(self):
+ """Returns a list of Dimensions, or None if the shape is unspecified."""
+ return self._dims
+
+ @property
+ def ndims(self):
+ """Returns the rank of this shape, or None if it is unspecified."""
+ if self._dims is None:
+ return None
+ else:
+ return len(self._dims)
+
+ def __len__(self):
+ """Returns the rank of this shape, or raises ValueError if unspecified."""
+ if self._dims is None:
+ raise ValueError("Cannot take the length of Shape with unknown rank.")
+ return len(self._dims)
+
+ def __nonzero__(self):
+ """Returns True if this shape contains non-zero information."""
+ return self._dims is not None
+
+ def __getitem__(self, key):
+ """Returns the value of a dimension or a shape, depending on the key.
+
+ Args:
+ key: If `key` is an integer, returns the dimension at that index;
+ otherwise if `key` is a slice, returns a TensorShape whose
+ dimensions are those selected by the slice from `self`.
+
+ Returns:
+ A dimension if `key` is an integer, or a `TensorShape` if `key` is a
+ slice.
+
+ Raises:
+ ValueError: If `key` is a slice, and any of its elements are negative, or
+ if `self` is completely unknown and the step is set.
+ """
+ if self._dims is not None:
+ if isinstance(key, slice):
+ return TensorShape(self._dims[key])
+ else:
+ return self._dims[key]
+ else:
+ if isinstance(key, slice):
+ start = key.start if key.start is not None else 0
+ stop = key.stop
+
+ if key.step is not None:
+ # TODO(mrry): Handle these maybe.
+ raise ValueError("Steps are not yet handled")
+ if stop is None:
+ # NOTE(mrry): This implies that TensorShape(None) is compatible with
+ # TensorShape(None)[1:], which is obviously not true. It would be
+ # possible to track the number of dimensions symbolically,
+ # and perhaps we should do that.
+ return unknown_shape()
+ elif start < 0 or stop < 0:
+ # TODO(mrry): Handle this better, as it will be useful for handling
+ # suffixes of otherwise unknown shapes.
+ return unknown_shape()
+ else:
+ return unknown_shape(ndims=stop-start)
+ else:
+ return Dimension(None)
+
+ def num_elements(self):
+ """Returns the total number of elements, or none for incomplete shapes."""
+ if self.is_fully_defined():
+ size = 1
+ for dim in self._dims:
+ size *= dim.value
+ return size
+ else:
+ return None
+
+ def merge_with(self, other):
+ """Returns a `TensorShape` combining the information in `self` and `other`.
+
+ The dimensions in `self` and `other` are merged elementwise,
+ according to the rules defined for `Dimension.merge_with()`.
+
+ Args:
+ other: Another `TensorShape`.
+
+ Returns:
+ A `TensorShape` containing the combined information of `self` and
+ `other`.
+
+ Raises:
+ ValueError: If `self` and `other` are not compatible.
+ """
+ other = as_shape(other)
+ if self._dims is None:
+ return other
+ else:
+ self.assert_same_rank(other)
+ new_dims = []
+ for i, dim in enumerate(self._dims):
+ new_dims.append(dim.merge_with(other[i]))
+ return TensorShape(new_dims)
+
+ def concatenate(self, other):
+ """Returns the concatenation of the dimension in `self` and `other`.
+
+ *N.B.* If either `self` or `other` is completely unknown,
+ concatenation will discard information about the other shape. In
+ future, we might support concatenation that preserves this
+ information for use with slicing.
+
+ Args:
+ other: Another `TensorShape`.
+
+ Returns:
+ A `TensorShape` whose dimensions are the concatenation of the
+ dimensions in `self` and `other`.
+ """
+ # TODO(mrry): Handle the case where we concatenate a known shape with a
+ # completely unknown shape, so that we can use the partial information.
+ other = as_shape(other)
+ if self._dims is None or other.dims is None:
+ return unknown_shape()
+ else:
+ return TensorShape(self._dims + other.dims)
+
+ def assert_same_rank(self, other):
+ """Raises an exception if `self` and `other` do not have compatible ranks.
+
+ Args:
+ other: Another `TensorShape`.
+
+ Raises:
+ ValueError: If `self` and `other` do not represent shapes with the
+ same rank.
+ """
+ other = as_shape(other)
+ if self.ndims is not None and other.ndims is not None:
+ if self.ndims != other.ndims:
+ raise ValueError(
+ "Shapes %s and %s must have the same rank" % (self, other))
+
+ def assert_has_rank(self, rank):
+ """Raises an exception if `self` is not compatible with the given `rank`.
+
+ Args:
+ rank: An integer.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with the given `rank`.
+ """
+ if self.ndims not in (None, rank):
+ raise ValueError("Shape %s must have rank %d" % (self, rank))
+
+ def with_rank(self, rank):
+ """Returns a shape based on `self` with the given rank.
+
+ This method promotes a completely unknown shape to one with a
+ known rank.
+
+ Args:
+ rank: An integer.
+
+ Returns:
+ A shape that is at least as specific as `self` with the given rank.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with the given `rank`.
+ """
+ return self.merge_with(unknown_shape(ndims=rank))
+
+ def with_rank_at_least(self, rank):
+ """Returns a shape based on `self` with at least the given rank.
+
+ Args:
+ rank: An integer.
+
+ Returns:
+ A shape that is at least as specific as `self` with at least the given
+ rank.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with at least the given
+ `rank`.
+ """
+ if self.ndims is not None and self.ndims < rank:
+ raise ValueError("Shape %s must have rank at least %d" % (self, rank))
+ else:
+ return self
+
+ def with_rank_at_most(self, rank):
+ """Returns a shape based on `self` with at most the given rank.
+
+ Args:
+ rank: An integer.
+
+ Returns:
+ A shape that is at least as specific as `self` with at most the given
+ rank.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with at most the given
+ `rank`.
+ """
+ if self.ndims is not None and self.ndims > rank:
+ raise ValueError("Shape %s must have rank at most %d" % (self, rank))
+ else:
+ return self
+
+ def is_compatible_with(self, other):
+ """Returns True iff `self` is compatible with `other`.
+
+ Two possibly-partially-defined shapes are compatible if there
+ exists a fully-defined shape that both shapes can represent. Thus,
+ compatibility allows the shape inference code to reason about
+ partially-defined shapes. For example:
+
+ * TensorShape(None) is compatible with all shapes.
+
+ * TensorShape([None, None]) is compatible with all two-dimensional
+ shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
+ not compatible with, for example, TensorShape([None]) or
+ TensorShape([None, None, None]).
+
+ * TensorShape([32, None]) is compatible with all two-dimensional shapes
+ with size 32 in the 0th dimension, and also TensorShape([None, None])
+ and TensorShape(None). It is not compatible with, for example,
+ TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
+
+ * TensorShape([32, 784]) is compatible with itself, and also
+ TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
+ None]) and TensorShape(None). It is not compatible with, for example,
+ TensorShape([32, 1, 784]) or TensorShape([None]).
+
+ The compatibility relation is reflexive and symmetric, but not
+ transitive. For example, TensorShape([32, 784]) is compatible with
+ TensorShape(None), and TensorShape(None) is compatible with
+ TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
+ TensorShape([4, 4]).
+
+ Args:
+ other: Another TensorShape.
+
+ Returns:
+ True iff `self` is compatible with `other`.
+
+ """
+ other = as_shape(other)
+ if self._dims is not None and other.dims is not None:
+ if self.ndims != other.ndims:
+ return False
+ for x_dim, y_dim in zip(self._dims, other.dims):
+ if not x_dim.is_compatible_with(y_dim):
+ return False
+ return True
+
+ def assert_is_compatible_with(self, other):
+ """Raises exception if `self` and `other` do not represent the same shape.
+
+ This method can be used to assert that there exists a shape that both
+ `self` and `other` represent.
+
+ Args:
+ other: Another TensorShape.
+
+ Raises:
+ ValueError: If `self` and `other` do not represent the same shape.
+ """
+ if not self.is_compatible_with(other):
+ raise ValueError("Shapes %s and %s are incompatible" % (self, other))
+
+ def is_fully_defined(self):
+ """Returns True iff `self` is fully defined in every dimension."""
+ return (self._dims is not None
+ and all(dim.value is not None for dim in self._dims))
+
+ def assert_is_fully_defined(self):
+ """Raises an exception if `self` is not fully defined in every dimension.
+
+ Raises:
+ ValueError: If `self` does not have a known value for every dimension.
+ """
+ if not self.is_fully_defined():
+ raise ValueError("Shape %s is not fully defined" % self)
+
+ def as_dimension_list(self):
+ """DEPRECATED: use as_list()."""
+ self.assert_is_fully_defined()
+ return self.as_list()
+
+ def as_list(self):
+ """Returns a list of integers or None for each dimension."""
+ return [dim.value for dim in self._dims]
+
+ def __eq__(self, other):
+ """Returns True if `self` is equivalent to `other`."""
+ other = as_shape(other)
+ return self._dims == other.dims
+
+ def __ne__(self, other):
+ """Returns True if `self` is known to be different from `other`."""
+ other = as_shape(other)
+ if self.ndims is None or other.ndims is None:
+ raise ValueError("The inequality of unknown TensorShapes is undefined.")
+ if self.ndims != other.ndims:
+ return True
+ return self._dims != other.dims
+
+
+def as_shape(shape):
+ """Converts the given object to a TensorShape."""
+ if isinstance(shape, TensorShape):
+ return shape
+ else:
+ return TensorShape(shape)
+
+
+def unknown_shape(ndims=None):
+ """Returns an unknown TensorShape, optionally with a known rank.
+
+ Args:
+ ndims: (Optional) If specified, the number of dimensions in the shape.
+
+ Returns:
+ An unknown TensorShape.
+ """
+ if ndims is None:
+ return TensorShape(None)
+ else:
+ return TensorShape([Dimension(None) for _ in range(ndims)])
+
+
+def scalar():
+ """Returns a shape representing a scalar."""
+ return TensorShape([])
+
+
+def vector(length):
+ """Returns a shape representing a vector.
+
+ Args:
+ length: The length of the vector, which may be None if unknown.
+
+ Returns:
+ A TensorShape representing a vector of the given length.
+ """
+ return TensorShape([length])
+
+
+def matrix(rows, cols):
+ """Returns a shape representing a matrix.
+
+ Args:
+ rows: The number of rows in the matrix, which may be None if unknown.
+ cols: The number of columns in the matrix, which may be None if unknown.
+
+ Returns:
+ A TensorShape representing a matrix of the given size.
+ """
+ return TensorShape([rows, cols])
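To tie the helpers above together, a short usage sketch of partial-shape merging, compatibility, and concatenation (the concrete sizes are illustrative only):

```python
from tensorflow.python.framework import tensor_shape

partial = tensor_shape.TensorShape([32, None])
known = tensor_shape.TensorShape([32, 784])

print partial.is_compatible_with(known)    # True
print partial.merge_with(known).as_list()  # [32, 784]
print partial.concatenate([10]).as_list()  # [32, None, 10]
print tensor_shape.unknown_shape(ndims=2)  # TensorShape([Dimension(None), Dimension(None)])
```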
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
new file mode 100644
index 0000000000..9743a8d199
--- /dev/null
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -0,0 +1,232 @@
+"""Functional tests for shape inference helper classes."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class DimensionTest(test_util.TensorFlowTestCase):
+
+ def testDimension(self):
+ dim = tensor_shape.Dimension(12)
+ self.assertEqual(12, dim.value)
+ self.assertEqual(12, int(dim))
+ self.assertEqual(dim, tensor_shape.Dimension(12))
+ self.assertEqual(tensor_shape.Dimension(15),
+ dim + tensor_shape.Dimension(3))
+ self.assertEqual(tensor_shape.Dimension(15), dim + 3)
+ self.assertEqual(tensor_shape.Dimension(24),
+ dim * tensor_shape.Dimension(2))
+ self.assertEqual(tensor_shape.Dimension(24), dim * 2)
+ self.assertEqual(tensor_shape.Dimension(6), dim / tensor_shape.Dimension(2))
+ self.assertEqual(tensor_shape.Dimension(6), dim / 2)
+ self.assertEqual(tensor_shape.Dimension(12),
+ dim.merge_with(tensor_shape.Dimension(12)))
+ self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(12))
+ self.assertLess(tensor_shape.Dimension(12), tensor_shape.Dimension(13))
+ self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12))
+ self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(12))
+ self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(13))
+ self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12))
+ self.assertGreaterEqual(tensor_shape.Dimension(12),
+ tensor_shape.Dimension(12))
+ self.assertGreaterEqual(tensor_shape.Dimension(13),
+ tensor_shape.Dimension(12))
+ with self.assertRaises(ValueError):
+ dim.merge_with(tensor_shape.Dimension(13))
+
+ def testUnknownDimension(self):
+ dim = tensor_shape.Dimension(None)
+ self.assertIs(None, dim.value)
+ self.assertEqual(dim.value, tensor_shape.Dimension(None).value)
+ self.assertEqual(tensor_shape.Dimension(None).value,
+ (dim + tensor_shape.Dimension(None)).value)
+ self.assertEqual(tensor_shape.Dimension(None).value,
+ (dim * tensor_shape.Dimension(None)).value)
+ self.assertEqual(tensor_shape.Dimension(None).value,
+ (dim / tensor_shape.Dimension(None)).value)
+ self.assertEqual(tensor_shape.Dimension(None).value,
+ dim.merge_with(tensor_shape.Dimension(None)).value)
+ self.assertIs(None,
+ tensor_shape.Dimension(None) < tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) <= tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) > tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) >= tensor_shape.Dimension(None))
+
+ def testKnownAndUnknownDimensions(self):
+ known = tensor_shape.Dimension(12)
+ unknown = tensor_shape.Dimension(None)
+ self.assertEqual(
+ tensor_shape.Dimension(None).value, (known + unknown).value)
+ self.assertEqual(
+ tensor_shape.Dimension(None).value, (unknown + known).value)
+ self.assertEqual(
+ tensor_shape.Dimension(None).value, (known * unknown).value)
+ self.assertEqual(
+ tensor_shape.Dimension(None).value, (unknown * known).value)
+ self.assertEqual(
+ tensor_shape.Dimension(None).value, (known / unknown).value)
+ self.assertEqual(
+ tensor_shape.Dimension(None).value, (unknown / known).value)
+ self.assertEqual(
+ tensor_shape.Dimension(12), known.merge_with(unknown))
+ self.assertEqual(
+ tensor_shape.Dimension(12), unknown.merge_with(known))
+ self.assertIs(None,
+ tensor_shape.Dimension(12) < tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(12) <= tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(12) > tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(12) >= tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) < tensor_shape.Dimension(12))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) <= tensor_shape.Dimension(12))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) > tensor_shape.Dimension(12))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) >= tensor_shape.Dimension(12))
+
+ def testAsDimension(self):
+ self.assertEqual(tensor_shape.Dimension(12),
+ tensor_shape.as_dimension(tensor_shape.Dimension(12)))
+ self.assertEqual(tensor_shape.Dimension(12), tensor_shape.as_dimension(12))
+ self.assertEqual(
+ tensor_shape.Dimension(None).value,
+ tensor_shape.as_dimension(tensor_shape.Dimension(None)).value)
+ self.assertEqual(tensor_shape.Dimension(None).value,
+ tensor_shape.as_dimension(None).value)
+
+ def testEquality(self):
+ self.assertTrue(tensor_shape.Dimension(12) == tensor_shape.Dimension(12))
+ self.assertFalse(tensor_shape.Dimension(12) == tensor_shape.Dimension(13))
+ self.assertIs(None,
+ tensor_shape.Dimension(12) == tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) == tensor_shape.Dimension(12))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) == tensor_shape.Dimension(None))
+
+ def testInequality(self):
+ self.assertTrue(tensor_shape.Dimension(12) != tensor_shape.Dimension(13))
+ self.assertFalse(tensor_shape.Dimension(12) != tensor_shape.Dimension(12))
+ self.assertIs(None,
+ tensor_shape.Dimension(12) != tensor_shape.Dimension(None))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) != tensor_shape.Dimension(12))
+ self.assertIs(None,
+ tensor_shape.Dimension(None) != tensor_shape.Dimension(None))
+
+
+class ShapeTest(test_util.TensorFlowTestCase):
+
+ def testUnknownShape(self):
+ s = tensor_shape.TensorShape(None)
+ with self.assertRaises(ValueError):
+ s.assert_is_fully_defined()
+ self.assertIs(None, s.ndims)
+ with self.assertRaises(ValueError):
+ len(s)
+ self.assertFalse(s)
+ self.assertIs(None, s.dims)
+
+ def testFullyDefinedShape(self):
+ s = tensor_shape.TensorShape([tensor_shape.Dimension(3),
+ tensor_shape.Dimension(4),
+ tensor_shape.Dimension(7)])
+ s.assert_is_fully_defined()
+ self.assertEqual(3, s.ndims)
+ self.assertEqual(3, len(s))
+ self.assertTrue(s)
+ s.assert_has_rank(3)
+ self.assertEqual([tensor_shape.Dimension(3),
+ tensor_shape.Dimension(4),
+ tensor_shape.Dimension(7)], s.dims)
+ self.assertEqual(tensor_shape.Dimension(3), s[0])
+ self.assertEqual(tensor_shape.Dimension(4), s[1])
+ self.assertEqual(tensor_shape.Dimension(7), s[2])
+ self.assertEqual([3, 4, 7], s.as_list())
+ s.assert_is_compatible_with([3, 4, 7])
+ s.assert_same_rank([6, 3, 7])
+
+ def testPartiallyDefinedShape(self):
+ s = tensor_shape.TensorShape([tensor_shape.Dimension(3),
+ tensor_shape.Dimension(None),
+ tensor_shape.Dimension(7)])
+ with self.assertRaises(ValueError):
+ s.assert_is_fully_defined()
+ self.assertEqual(3, s.ndims)
+ self.assertEqual(3, len(s))
+ self.assertTrue(s)
+ s.assert_has_rank(3)
+ self.assertEqual(tensor_shape.Dimension(3), s[0])
+ self.assertEqual(tensor_shape.Dimension(None).value, s[1].value)
+ self.assertEqual(tensor_shape.Dimension(7), s[2])
+ s.assert_same_rank([6, 3, 7])
+
+ def testMergeFullShapes(self):
+ self.assertEqual([3, 4, 7],
+ tensor_shape.TensorShape([3, 4, 7]).merge_with(
+ tensor_shape.TensorShape([3, 4, 7])).as_list())
+ with self.assertRaises(ValueError):
+ tensor_shape.TensorShape([3, 4, 7]).merge_with(
+ tensor_shape.TensorShape([6, 3, 7]))
+
+ def testMergePartialShapes(self):
+ s1 = tensor_shape.TensorShape([tensor_shape.Dimension(3),
+ tensor_shape.Dimension(None),
+ tensor_shape.Dimension(7)])
+ s2 = tensor_shape.TensorShape([tensor_shape.Dimension(None),
+ tensor_shape.Dimension(4),
+ tensor_shape.Dimension(7)])
+ self.assertEqual([3, 4, 7], s1.merge_with(s2).as_list())
+
+ def testMergeFullAndUnknownShape(self):
+ self.assertEqual([3, 4, 7],
+ tensor_shape.TensorShape([3, 4, 7]).merge_with(
+ tensor_shape.TensorShape(None)).as_list())
+
+ def testSlice(self):
+ known = tensor_shape.TensorShape([0, 1, 2, 3, 4])
+ self.assertEqual(tensor_shape.Dimension(2), known[2])
+ tensor_shape.TensorShape([1, 2, 3]).assert_is_compatible_with(known[1:4])
+
+ unknown = tensor_shape.TensorShape(None)
+ self.assertEqual(tensor_shape.Dimension(None).value, unknown[2].value)
+ tensor_shape.TensorShape(
+ [None, None, None]).assert_is_compatible_with(unknown[1:4])
+
+ def testConcatenate(self):
+ tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
+ tensor_shape.TensorShape([1, 2]).concatenate(
+ tensor_shape.TensorShape([3, 4])))
+ tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
+ tensor_shape.TensorShape([1, 2]).concatenate(
+ tensor_shape.TensorShape(None)))
+ tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
+ tensor_shape.TensorShape(None).concatenate(
+ tensor_shape.TensorShape([3, 4])))
+ tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
+ tensor_shape.TensorShape(None).concatenate(
+ tensor_shape.TensorShape(None)))
+ tensor_shape.TensorShape([1, 2, 3]).assert_is_compatible_with(
+ tensor_shape.TensorShape([1, 2]).concatenate(
+ tensor_shape.Dimension(3)))
+
+ def testHelpers(self):
+ tensor_shape.TensorShape([]).assert_is_compatible_with(
+ tensor_shape.scalar())
+ tensor_shape.TensorShape([37]).assert_is_compatible_with(
+ tensor_shape.vector(37))
+ tensor_shape.TensorShape(
+ [94, 43]).assert_is_compatible_with(tensor_shape.matrix(94, 43))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
new file mode 100644
index 0000000000..81ed54c473
--- /dev/null
+++ b/tensorflow/python/framework/tensor_util.py
@@ -0,0 +1,511 @@
+"""Utilities to create TensorProtos."""
+import numbers
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.core.framework import tensor_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+
+# TODO(opensource): Add support for pyx_library in the open-source build.
+# For now, we use the slow versions that fast_tensor_util replaces.
+# pylint: disable=g-import-not-at-top
+try:
+ from tensorflow.python.framework import fast_tensor_util
+ _FAST_TENSOR_UTIL_AVAILABLE = True
+except ImportError:
+ _FAST_TENSOR_UTIL_AVAILABLE = False
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+# pylint: enable=g-import-not-at-top
+
+
+if _FAST_TENSOR_UTIL_AVAILABLE:
+ _NP_TO_APPEND_FN = {
+ np.float32: fast_tensor_util.AppendFloat32ArrayToTensorProto,
+ np.float64: fast_tensor_util.AppendFloat64ArrayToTensorProto,
+ np.int32: fast_tensor_util.AppendInt32ArrayToTensorProto,
+ np.int64: fast_tensor_util.AppendInt64ArrayToTensorProto,
+ np.uint8: fast_tensor_util.AppendUInt8ArrayToTensorProto,
+ np.int16: fast_tensor_util.AppendInt16ArrayToTensorProto,
+ np.int8: fast_tensor_util.AppendInt8ArrayToTensorProto,
+ np.complex64: fast_tensor_util.AppendComplex64ArrayToTensorProto,
+ np.complex128: fast_tensor_util.AppendComplex128ArrayToTensorProto,
+ np.object: fast_tensor_util.AppendObjectArrayToTensorProto,
+ np.bool: fast_tensor_util.AppendBoolArrayToTensorProto,
+ types.qint8.as_numpy_dtype:
+ fast_tensor_util.AppendInt8ArrayToTensorProto,
+ types.quint8.as_numpy_dtype:
+ fast_tensor_util.AppendUInt8ArrayToTensorProto,
+ types.qint32.as_numpy_dtype:
+ fast_tensor_util.AppendInt32ArrayToTensorProto,
+ # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
+ }
+else:
+
+ def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.float_val.extend([np.asscalar(x) for x in proto_values])
+
+ def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.double_val.extend([np.asscalar(x) for x in proto_values])
+
+ def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.int_val.extend([np.asscalar(x) for x in proto_values])
+
+ def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.int64_val.extend([np.asscalar(x) for x in proto_values])
+
+ def SlowAppendComplexArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.scomplex_val.extend([np.asscalar(v)
+ for x in proto_values
+ for v in [x.real, x.imag]])
+
+ def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.string_val.extend([str(x) for x in proto_values])
+
+ def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.bool_val.extend([np.asscalar(x) for x in proto_values])
+
+ _NP_TO_APPEND_FN = {
+ np.float32: SlowAppendFloat32ArrayToTensorProto,
+ np.float64: SlowAppendFloat64ArrayToTensorProto,
+ np.int32: SlowAppendIntArrayToTensorProto,
+ np.int64: SlowAppendInt64ArrayToTensorProto,
+ np.uint8: SlowAppendIntArrayToTensorProto,
+ np.int16: SlowAppendIntArrayToTensorProto,
+ np.int8: SlowAppendIntArrayToTensorProto,
+ np.complex64: SlowAppendComplexArrayToTensorProto,
+ np.complex128: SlowAppendComplexArrayToTensorProto,
+ np.object: SlowAppendObjectArrayToTensorProto,
+ np.bool: SlowAppendBoolArrayToTensorProto,
+ types.qint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto,
+ types.quint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto,
+ types.qint32.as_numpy_dtype: SlowAppendIntArrayToTensorProto,
+ # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
+ }
+
+
+def GetFromNumpyDTypeDict(dtype_dict, dtype):
+ # NOTE: dtype_dict.get(dtype) always returns None.
+ for key, val in dtype_dict.iteritems():
+ if key == dtype:
+ return val
+ return None
+
+
+def GetNumpyAppendFn(dtype):
+  # numpy dtypes for strings are variable length. We cannot compare
+  # dtype with a single constant (np.string does not exist) to decide
+  # whether dtype is a "string" type. We need to compare dtype.type to
+  # be sure it's a string type.
+ if dtype.type == np.string_ or dtype.type == np.unicode_:
+ if _FAST_TENSOR_UTIL_AVAILABLE:
+ return fast_tensor_util.AppendObjectArrayToTensorProto
+ else:
+ return SlowAppendObjectArrayToTensorProto
+ return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
+
+
+def MakeTensorShapeProto(shape):
+ """Create a TensorShapeProto.
+
+ Args:
+ shape: List of integers representing the dimensions of the tensor.
+
+ Returns:
+ A TensorShapeProto.
+ """
+ return tensor_shape_pb2.TensorShapeProto(
+ dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=x) for x in shape])
+
+
+def TensorShapeProtoToList(shape):
+ """Convert a TensorShape to a list.
+
+ Args:
+ shape: A TensorShapeProto.
+
+ Returns:
+ List of integers representing the dimensions of the tensor.
+ """
+ return [dim.size for dim in shape.dim]
+
+
+def _GetDenseDimensions(list_of_lists):
+ """Returns the inferred dense dimensions of a list of lists."""
+ if not isinstance(list_of_lists, (list, tuple)):
+ return []
+ elif not list_of_lists:
+ return [0]
+ else:
+ return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])
+
+
+def _FlattenToStrings(nested_strings):
+ if isinstance(nested_strings, list):
+ for inner in nested_strings:
+ for flattened_string in _FlattenToStrings(inner):
+ yield flattened_string
+ else:
+ yield nested_strings
+
+
+_TENSOR_CONTENT_TYPES = frozenset([
+ types.float32, types.float64, types.int32, types.uint8, types.int16,
+ types.int8, types.int64
+])
+
+
+def _FirstNotNone(l):
+ for x in l:
+ if x is not None:
+ return x
+ return None
+
+
+def _FilterInt(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterInt(x) for x in v])
+ return None if isinstance(v, numbers.Integral) else repr(v)
+
+
+def _FilterFloat(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterFloat(x) for x in v])
+ return None if isinstance(v, numbers.Real) else repr(v)
+
+
+def _FilterComplex(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterComplex(x) for x in v])
+ return None if isinstance(v, numbers.Complex) else repr(v)
+
+
+def _FilterStr(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterStr(x) for x in v])
+ return None if isinstance(v, basestring) else repr(v)
+
+
+def _FilterBool(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterBool(x) for x in v])
+ return None if isinstance(v, bool) else repr(v)
+
+
+def _FilterNotTensor(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterNotTensor(x) for x in v])
+ return repr(v) if isinstance(v, ops.Tensor) else None
+
+
+_TF_TO_IS_OK = {
+ types.float32: _FilterFloat,
+ types.float64: _FilterFloat,
+ types.int32: _FilterInt,
+ types.uint8: _FilterInt,
+ types.int16: _FilterInt,
+ types.int8: _FilterInt,
+ types.string: _FilterStr,
+ types.complex64: _FilterComplex,
+ types.int64: _FilterInt,
+ types.bool: _FilterBool,
+ types.qint32: _FilterInt,
+ types.quint8: _FilterInt,
+ types.qint8: _FilterInt,
+}
+
+
+def _AssertCompatible(values, dtype):
+ fn = _TF_TO_IS_OK.get(dtype, _FilterNotTensor)
+ mismatch = fn(values)
+ if mismatch is not None:
+ if dtype is None:
+ raise TypeError("List of Tensors when single Tensor expected")
+ else:
+ raise TypeError("Expected %s, got %s instead." %
+ (dtype.name, mismatch))
+
+
+def make_tensor_proto(values, dtype=None, shape=None):
+ """Create a TensorProto.
+
+ Args:
+ values: Values to put in the TensorProto.
+ dtype: Optional tensor_pb2 DataType value.
+ shape: List of integers representing the dimensions of tensor.
+
+ Returns:
+ A TensorProto. Depending on the type, it may contain data in the
+ "tensor_content" attribute, which is not directly useful to Python programs.
+ To access the values you should convert the proto back to a numpy ndarray
+ with tensor_util.MakeNdarray(proto).
+
+ Raises:
+ TypeError: if unsupported types are provided.
+ ValueError: if arguments have inappropriate values.
+
+ make_tensor_proto accepts "values" of a python scalar, a python list, a
+ numpy ndarray, or a numpy scalar.
+
+ If "values" is a python scalar or a python list, make_tensor_proto
+  first converts it to a numpy ndarray. If dtype is None, the
+  conversion tries its best to infer the right numpy data
+  type. Otherwise, the resulting numpy array has a data type
+  compatible with the given dtype.
+
+  In either case above, the numpy ndarray (either caller-provided or
+  auto-converted) must have a type compatible with dtype.
+
+ make_tensor_proto then converts the numpy array to a tensor proto.
+
+ If "shape" is None, the resulting tensor proto represents the numpy
+ array precisely.
+
+ Otherwise, "shape" specifies the tensor's shape and the numpy array
+ can not have more elements than what "shape" specifies.
+
+ """
+ if dtype:
+ dtype = types.as_dtype(dtype)
+
+ # We first convert value to a numpy array or scalar.
+ if isinstance(values, (np.ndarray, np.generic)):
+ if dtype:
+ nparray = values.astype(dtype.as_numpy_dtype)
+ else:
+ nparray = values
+ else:
+ if values is None:
+ raise ValueError("None values not supported.")
+ # if dtype is provided, forces numpy array to be the type
+ # provided if possible.
+ np_dt = dtype.as_numpy_dtype if dtype else None
+ if np.prod(shape) == 0:
+ nparray = np.empty(shape, dtype=np_dt)
+ else:
+ _AssertCompatible(values, dtype)
+ nparray = np.array(values, dtype=np_dt)
+ if list(nparray.shape) != _GetDenseDimensions(values):
+ raise ValueError("Argument must be a dense tensor: %s" % values)
+ # python/numpy default float type is float64. We prefer float32 instead.
+ if (nparray.dtype == np.float64) and dtype is None:
+ nparray = nparray.astype(np.float32)
+ # python/numpy default int type is int64. We prefer int32 instead.
+ elif (nparray.dtype == np.int64) and dtype is None:
+ nparray = nparray.astype(np.int32)
+
+ # if dtype is provided, it must be compatible with what numpy
+ # conversion says.
+ numpy_dtype = types.as_dtype(nparray.dtype)
+ if numpy_dtype is None:
+ raise TypeError("Unrecognized data type: %s" % nparray.dtype)
+
+ # If dtype was specified and is a quantized type, we convert
+ # numpy_dtype back into the quantized version.
+ if dtype in [types.qint8, types.quint8, types.qint32]:
+ numpy_dtype = dtype
+
+ if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype:
+ raise TypeError("Incompatible types: %s vs. %s" % (dtype, nparray.dtype))
+
+ # If shape is not given, get the shape from the numpy array.
+ if shape is None:
+ shape = nparray.shape
+ is_same_size = True
+ shape_size = nparray.size
+ else:
+ shape = [int(dim) for dim in shape]
+ shape_size = np.prod(shape)
+ is_same_size = shape_size == nparray.size
+
+ if nparray.size > shape_size:
+ raise ValueError(
+ "Too many elements provided. Needed at most %d, but received %d" %
+ (shape_size, nparray.size))
+
+ tensor_proto = tensor_pb2.TensorProto(
+ dtype=numpy_dtype.as_datatype_enum,
+ tensor_shape=MakeTensorShapeProto(shape))
+
+ if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
+ tensor_proto.tensor_content = nparray.tostring()
+ return tensor_proto
+
+ # If we were not given values as a numpy array, compute the proto_values
+ # from the given values directly, to avoid numpy trimming nulls from the
+ # strings. Since values could be a list of strings, or a multi-dimensional
+ # list of lists that might or might not correspond to the given shape,
+ # we flatten it conservatively.
+ if numpy_dtype == types.string and not isinstance(values, np.ndarray):
+ proto_values = _FlattenToStrings(values)
+ tensor_proto.string_val.extend([str(x) for x in proto_values])
+ return tensor_proto
+
+ # TensorFlow expects C order (a.k.a., eigen row major).
+ proto_values = nparray.ravel()
+
+ append_fn = GetNumpyAppendFn(proto_values.dtype)
+ if append_fn is None:
+ raise TypeError("Element type not supported in TensorProto: %s" %
+ numpy_dtype.name)
+ append_fn(tensor_proto, proto_values)
+
+ return tensor_proto
+
+
+def MakeNdarray(tensor):
+ """Create a numpy ndarray from a tensor.
+
+ Create a numpy ndarray with the same shape and data as the tensor.
+
+ Args:
+ tensor: A TensorProto.
+
+ Returns:
+ A numpy array with the tensor contents.
+
+ Raises:
+ TypeError: if tensor has unsupported type.
+
+ """
+ shape = [d.size for d in tensor.tensor_shape.dim]
+ num_elements = np.prod(shape)
+ tensor_dtype = types.as_dtype(tensor.dtype)
+ dtype = tensor_dtype.as_numpy_dtype
+
+ if tensor.tensor_content:
+ return np.fromstring(tensor.tensor_content, dtype=dtype).reshape(shape)
+ elif tensor_dtype == types.float32:
+ if len(tensor.float_val) == 1:
+ return np.repeat(np.array(tensor.float_val[0], dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
+ elif tensor_dtype == types.float64:
+ if len(tensor.double_val) == 1:
+ return np.repeat(np.array(tensor.double_val[0], dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
+ elif tensor_dtype in [types.int32, types.uint8, types.int16, types.int8,
+ types.qint32, types.quint8, types.qint8,
+ types.bfloat16]:
+ if len(tensor.int_val) == 1:
+ return np.repeat(np.array(tensor.int_val[0], dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
+ elif tensor_dtype == types.int64:
+ if len(tensor.int64_val) == 1:
+ return np.repeat(np.array(tensor.int64_val[0], dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
+ elif tensor_dtype == types.string:
+ if len(tensor.string_val) == 1:
+ return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.array([str(x) for x in tensor.string_val],
+ dtype=dtype).reshape(shape)
+ elif tensor_dtype == types.complex64:
+ it = iter(tensor.scomplex_val)
+ if len(tensor.scomplex_val) == 2:
+ return np.repeat(np.array(complex(tensor.scomplex_val[0],
+ tensor.scomplex_val[1]), dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.array([complex(x[0], x[1]) for x in zip(it, it)],
+ dtype=dtype).reshape(shape)
+ elif tensor_dtype == types.bool:
+ if len(tensor.bool_val) == 1:
+ return np.repeat(np.array(tensor.bool_val[0], dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
+ else:
+ raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
+
+
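+# A minimal usage sketch for the helpers above, assuming the public
+# `make_tensor_proto` wrapper defined earlier in this module:
+#
+#   proto = make_tensor_proto([[1.0, 2.0], [3.0, 4.0]], dtype=types.float32)
+#   arr = MakeNdarray(proto)
+#   # arr is an np.float32 ndarray of shape (2, 2) with the original values.
+
+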
+def ShapeEquals(tensor_proto, shape):
+ """Returns True if "tensor_proto" has the given "shape".
+
+ Args:
+ tensor_proto: A TensorProto.
+ shape: A tensor shape, expressed as a TensorShape, list, or tuple.
+
+ Returns:
+ True if "tensor_proto" has the given "shape", otherwise False.
+
+ Raises:
+ TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
+ TensorShape, list, or tuple.
+ """
+ if not isinstance(tensor_proto, tensor_pb2.TensorProto):
+ raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object")
+ if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
+ shape = [d.size for d in shape.dim]
+ elif not isinstance(shape, (list, tuple)):
+ raise TypeError("shape is not a list or tuple")
+ tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
+ if len(tensor_shape_list) != len(shape):
+ return False
+ return all(x == y for x, y in zip(tensor_shape_list, shape))
+
+
+def ConstantValue(tensor):
+ """Returns the constant value of the given tensor, if efficiently calculable.
+
+ This function attempts to partially evaluate the given tensor, and
+ returns its value as a numpy ndarray if this succeeds.
+
+ TODO(mrry): Consider whether this function should use a registration
+ mechanism like gradients and ShapeFunctions, so that it is easily
+ extensible.
+
+ Args:
+ tensor: The Tensor to be evaluated.
+
+ Returns:
+ A numpy ndarray containing the constant value of the given `tensor`,
+ or None if it cannot be calculated.
+
+ Raises:
+ TypeError: if tensor is not an ops.Tensor.
+ """
+ # TODO(mdevin): Support Variables?
+ if not isinstance(tensor, ops.Tensor):
+ raise TypeError("tensor is not a Tensor")
+ if tensor.op.type == "Const":
+ return MakeNdarray(tensor.op.get_attr("value"))
+ elif tensor.op.type == "Shape":
+ input_shape = tensor.op.inputs[0].get_shape()
+ if input_shape.is_fully_defined():
+ return np.array([dim.value for dim in input_shape.dims])
+ else:
+ return None
+ elif tensor.op.type == "Size":
+ input_shape = tensor.op.inputs[0].get_shape()
+ if input_shape.is_fully_defined():
+ return np.array([np.prod([dim.value for dim in input_shape.dims])])
+ else:
+ return None
+ elif tensor.op.type == "Rank":
+ input_shape = tensor.op.inputs[0].get_shape()
+ if input_shape.ndims is not None:
+ return np.array([input_shape.ndims])
+ else:
+ return None
+ elif tensor.op.type == "Range":
+ start = ConstantValue(tensor.op.inputs[0])
+ if start is None:
+ return None
+ limit = ConstantValue(tensor.op.inputs[1])
+ if limit is None:
+ return None
+ delta = ConstantValue(tensor.op.inputs[2])
+ if delta is None:
+ return None
+ return np.array(range(start, limit, delta),
+ dtype=tensor.dtype.as_numpy_dtype)
+ else:
+ return None
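+
+
+# A short sketch of the partial evaluation performed above, assuming
+# `constant_op` and `array_ops` are imported (as in the tests below; they are
+# not imported by this module):
+#
+#   c = constant_op.constant(0.0, shape=[1, 2, 3])
+#   ConstantValue(array_ops.shape(c))  # -> np.array([1, 2, 3])
+#   ConstantValue(array_ops.rank(c))   # -> np.array([3])
+#   ConstantValue(array_ops.size(c))   # -> np.array([6])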
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
new file mode 100644
index 0000000000..7c1c0b8d3e
--- /dev/null
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -0,0 +1,379 @@
+"""Functional tests for tensor_util."""
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import googletest
+
+
+class TensorUtilTest(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ t = tensor_util.make_tensor_proto(10.0)
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape {}
+ float_val: 10.0
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float32, a.dtype)
+ self.assertAllClose(np.array(10.0, dtype=np.float32), a)
+
+ def testFloatN(self):
+ t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\000\000 A\000\000\240A\000\000\360A"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float32, a.dtype)
+ self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)
+
+ def testFloatTyped(self):
+ t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=types.float32)
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\000\000 A\000\000\240A\000\000\360A"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float32, a.dtype)
+ self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)
+
+ def testFloatTypeCoerce(self):
+ t = tensor_util.make_tensor_proto([10, 20, 30], dtype=types.float32)
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\000\000 A\000\000\240A\000\000\360A"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float32, a.dtype)
+ self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)
+
+ def testFloatTypeCoerceNdarray(self):
+ arr = np.asarray([10, 20, 30], dtype="int")
+ t = tensor_util.make_tensor_proto(arr, dtype=types.float32)
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\000\000 A\000\000\240A\000\000\360A"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float32, a.dtype)
+ self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)
+
+ def testFloatSizes(self):
+ t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape { dim { size: 1 } dim { size: 3 } }
+ tensor_content: "\000\000 A\000\000\240A\000\000\360A"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float32, a.dtype)
+ self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float32), a)
+
+ def testFloatSizes2(self):
+ t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape { dim { size: 3 } dim { size: 1 } }
+ tensor_content: "\000\000 A\000\000\240A\000\000\360A"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float32, a.dtype)
+ self.assertAllClose(np.array([[10.0], [20.0], [30.0]], dtype=np.float32),
+ a)
+
+ def testFloatSizesLessValues(self):
+ t = tensor_util.make_tensor_proto(10.0, shape=[1, 3])
+ self.assertProtoEquals("""
+ dtype: DT_FLOAT
+ tensor_shape { dim { size: 1 } dim { size: 3 } }
+ float_val: 10.0
+ """, t)
+ # No conversion to Ndarray for this one: not enough values.
+
+ def testFloatNpArrayFloat64(self):
+ t = tensor_util.make_tensor_proto(
+ np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
+ self.assertProtoEquals("""
+ dtype: DT_DOUBLE
+ tensor_shape { dim { size: 1 } dim { size: 3 } }
+ tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.float64, a.dtype)
+ self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float64),
+ tensor_util.MakeNdarray(t))
+
+ def testFloatTypesWithImplicitRepeat(self):
+ for dtype, nptype in [
+ (types.float32, np.float32), (types.float64, np.float64)]:
+ t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype)
+ a = tensor_util.MakeNdarray(t)
+ self.assertAllClose(np.array([[10.0, 10.0, 10.0, 10.0],
+ [10.0, 10.0, 10.0, 10.0],
+ [10.0, 10.0, 10.0, 10.0]], dtype=nptype), a)
+
+ def testInt(self):
+ t = tensor_util.make_tensor_proto(10)
+ self.assertProtoEquals("""
+ dtype: DT_INT32
+ tensor_shape {}
+ int_val: 10
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.int32, a.dtype)
+ self.assertAllClose(np.array(10, dtype=np.int32), a)
+
+ def testIntNDefaultType(self):
+ t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
+ self.assertProtoEquals("""
+ dtype: DT_INT32
+ tensor_shape { dim { size: 2 } dim { size: 2 } }
+ tensor_content: "\\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.int32, a.dtype)
+ self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)
+
+ def testIntTypes(self):
+ for dtype, nptype in [
+ (types.int32, np.int32),
+ (types.uint8, np.uint8),
+ (types.int16, np.int16),
+ (types.int8, np.int8)]:
+ # Test with array.
+ t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
+ self.assertEquals(dtype, t.dtype)
+ self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(nptype, a.dtype)
+ self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
+ # Test with ndarray.
+ t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
+ self.assertEquals(dtype, t.dtype)
+ self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(nptype, a.dtype)
+ self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
+
+ def testIntTypesWithImplicitRepeat(self):
+ for dtype, nptype in [
+ (types.int64, np.int64),
+ (types.int32, np.int32),
+ (types.uint8, np.uint8),
+ (types.int16, np.int16),
+ (types.int8, np.int8)]:
+ t = tensor_util.make_tensor_proto([10], shape=[3, 4], dtype=dtype)
+ a = tensor_util.MakeNdarray(t)
+ self.assertAllEqual(np.array([[10, 10, 10, 10],
+ [10, 10, 10, 10],
+ [10, 10, 10, 10]], dtype=nptype), a)
+
+ def testLong(self):
+ t = tensor_util.make_tensor_proto(10, dtype=types.int64)
+ self.assertProtoEquals("""
+ dtype: DT_INT64
+ tensor_shape {}
+ int64_val: 10
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.int64, a.dtype)
+ self.assertAllClose(np.array(10, dtype=np.int64), a)
+
+ def testLongN(self):
+ t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3],
+ dtype=types.int64)
+ self.assertProtoEquals("""
+ dtype: DT_INT64
+ tensor_shape { dim { size: 1 } dim { size: 3 } }
+ tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.int64, a.dtype)
+ self.assertAllClose(np.array([[10, 20, 30]], dtype=np.int64), a)
+
+ def testLongNpArray(self):
+ t = tensor_util.make_tensor_proto(np.array([10, 20, 30]))
+ self.assertProtoEquals("""
+ dtype: DT_INT64
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.int64, a.dtype)
+ self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)
+
+ def testString(self):
+ t = tensor_util.make_tensor_proto("foo")
+ self.assertProtoEquals("""
+ dtype: DT_STRING
+ tensor_shape {}
+ string_val: "foo"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.object, a.dtype)
+ self.assertEquals(["foo"], a)
+
+ def testStringWithImplicitRepeat(self):
+ t = tensor_util.make_tensor_proto("f", shape=[3, 4])
+ a = tensor_util.MakeNdarray(t)
+ self.assertAllEqual(np.array([["f", "f", "f", "f"],
+ ["f", "f", "f", "f"],
+ ["f", "f", "f", "f"]], dtype=np.object), a)
+
+ def testStringN(self):
+ t = tensor_util.make_tensor_proto(["foo", "bar", "baz"], shape=[1, 3])
+ self.assertProtoEquals("""
+ dtype: DT_STRING
+ tensor_shape { dim { size: 1 } dim { size: 3 } }
+ string_val: "foo"
+ string_val: "bar"
+ string_val: "baz"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.object, a.dtype)
+ self.assertAllEqual(np.array([["foo", "bar", "baz"]]), a)
+
+ def testStringNpArray(self):
+ t = tensor_util.make_tensor_proto(np.array([["a", "ab"], ["abc", "abcd"]]))
+ self.assertProtoEquals("""
+ dtype: DT_STRING
+ tensor_shape { dim { size: 2 } dim { size: 2 } }
+ string_val: "a"
+ string_val: "ab"
+ string_val: "abc"
+ string_val: "abcd"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.object, a.dtype)
+ self.assertAllEqual(np.array([["a", "ab"], ["abc", "abcd"]]), a)
+
+ def testComplex(self):
+ t = tensor_util.make_tensor_proto((1+2j), dtype=types.complex64)
+ self.assertProtoEquals("""
+ dtype: DT_COMPLEX64
+ tensor_shape {}
+ scomplex_val: 1
+ scomplex_val: 2
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.complex64, a.dtype)
+ self.assertAllEqual(np.array(1 + 2j), a)
+
+ def testComplexWithImplicitRepeat(self):
+ t = tensor_util.make_tensor_proto((1+1j), shape=[3, 4],
+ dtype=types.complex64)
+ a = tensor_util.MakeNdarray(t)
+ self.assertAllClose(np.array([[(1+1j), (1+1j), (1+1j), (1+1j)],
+ [(1+1j), (1+1j), (1+1j), (1+1j)],
+ [(1+1j), (1+1j), (1+1j), (1+1j)]],
+ dtype=np.complex64), a)
+
+ def testComplexN(self):
+ t = tensor_util.make_tensor_proto([(1+2j), (3+4j), (5+6j)], shape=[1, 3],
+ dtype=types.complex64)
+ self.assertProtoEquals("""
+ dtype: DT_COMPLEX64
+ tensor_shape { dim { size: 1 } dim { size: 3 } }
+ scomplex_val: 1
+ scomplex_val: 2
+ scomplex_val: 3
+ scomplex_val: 4
+ scomplex_val: 5
+ scomplex_val: 6
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.complex64, a.dtype)
+ self.assertAllEqual(np.array([[(1+2j), (3+4j), (5+6j)]]), a)
+
+ def testComplexNpArray(self):
+ t = tensor_util.make_tensor_proto(
+ np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=types.complex64)
+ # scomplex_val are real_0, imag_0, real_1, imag_1, ...
+ self.assertProtoEquals("""
+ dtype: DT_COMPLEX64
+ tensor_shape { dim { size: 2 } dim { size: 2 } }
+ scomplex_val: 1
+ scomplex_val: 2
+ scomplex_val: 3
+ scomplex_val: 4
+ scomplex_val: 5
+ scomplex_val: 6
+ scomplex_val: 7
+ scomplex_val: 8
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(np.complex64, a.dtype)
+ self.assertAllEqual(np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), a)
+
+ def testUnsupportedDType(self):
+ with self.assertRaises(TypeError):
+ tensor_util.make_tensor_proto(np.array([1]), 0)
+
+ def testShapeTooLarge(self):
+ with self.assertRaises(ValueError):
+ tensor_util.make_tensor_proto(np.array([1, 2]), shape=[1])
+
+ def testLowRankSupported(self):
+ t = tensor_util.make_tensor_proto(np.array(7))
+ self.assertProtoEquals("""
+ dtype: DT_INT64
+ tensor_shape {}
+ int64_val: 7
+ """, t)
+
+ def testShapeEquals(self):
+ t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
+ self.assertTrue(tensor_util.ShapeEquals(t, [2, 2]))
+ self.assertTrue(tensor_util.ShapeEquals(t, (2, 2)))
+ self.assertTrue(
+ tensor_util.ShapeEquals(t, tensor_util.MakeTensorShapeProto([2, 2])))
+ self.assertFalse(tensor_util.ShapeEquals(t, [5, 3]))
+ self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
+ self.assertFalse(tensor_util.ShapeEquals(t, [4]))
+
+
+class ConstantValueTest(test_util.TensorFlowTestCase):
+
+ def testConstant(self):
+ np_val = np.random.rand(3, 4, 7).astype(np.float32)
+ tf_val = constant_op.constant(np_val)
+ self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val))
+
+ np_val = np.random.rand(3, 0, 7).astype(np.float32)
+ tf_val = constant_op.constant(np_val)
+ self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val))
+
+ def testUnknown(self):
+ tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=types.float32)
+ self.assertIs(None, tensor_util.ConstantValue(tf_val))
+
+ def testShape(self):
+ np_val = np.array([1, 2, 3])
+ tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
+ self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))
+
+ def testSize(self):
+ np_val = np.array([6])
+ tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
+ self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))
+
+ def testRank(self):
+ np_val = np.array([3])
+ tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
+ self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/framework/test_kernel_label_op.cc b/tensorflow/python/framework/test_kernel_label_op.cc
new file mode 100644
index 0000000000..50f8522e1b
--- /dev/null
+++ b/tensorflow/python/framework/test_kernel_label_op.cc
@@ -0,0 +1,47 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KernelLabel").Output("result: string");
+
+namespace {
+enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
+} // namespace
+
+template <KernelLabel KL>
+class KernelLabelOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* output;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output("result", TensorShape({}), &output));
+ switch (KL) {
+ case DEFAULT_LABEL:
+ output->scalar<string>()() = "My label is: default";
+ break;
+ case OVERLOAD_1_LABEL:
+ output->scalar<string>()() = "My label is: overload_1";
+ break;
+ case OVERLOAD_2_LABEL:
+ output->scalar<string>()() = "My label is: overload_2";
+ break;
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU),
+ KernelLabelOp<DEFAULT_LABEL>);
+REGISTER_KERNEL_BUILDER(Name("KernelLabel")
+ .Device(DEVICE_CPU)
+ .Label("overload_1"),
+ KernelLabelOp<OVERLOAD_1_LABEL>);
+REGISTER_KERNEL_BUILDER(Name("KernelLabel")
+ .Device(DEVICE_CPU)
+ .Label("overload_2"),
+ KernelLabelOp<OVERLOAD_2_LABEL>);
+
+} // end namespace tensorflow
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
new file mode 100644
index 0000000000..597a5ad829
--- /dev/null
+++ b/tensorflow/python/framework/test_util.py
@@ -0,0 +1,437 @@
+# pylint: disable=invalid-name
+"""Test utils for tensorflow."""
+import contextlib
+import math
+import re
+import threading
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import config_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.client import graph_util
+from tensorflow.python.client import session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import logging
+from tensorflow.python.util.protobuf import compare
+
+
+def IsGoogleCudaEnabled():
+ return pywrap_tensorflow.IsGoogleCudaEnabled()
+
+
+class TensorFlowTestCase(googletest.TestCase):
+ """Root class for tests that need to test tensor flow.
+ """
+
+ def __init__(self, methodName="runTest"):
+ super(TensorFlowTestCase, self).__init__(methodName)
+ self._threads = []
+ self._tempdir = None
+ self._cached_session = None
+
+ def setUp(self):
+ self._ClearCachedSession()
+ ops.reset_default_graph()
+
+ def tearDown(self):
+ for thread in self._threads:
+ self.assertFalse(thread.is_alive(), "A checkedThread did not terminate")
+ self._ClearCachedSession()
+
+ def _ClearCachedSession(self):
+ if self._cached_session is not None:
+ self._cached_session.close()
+ self._cached_session = None
+
+ def get_temp_dir(self):
+ if not self._tempdir:
+ self._tempdir = googletest.GetTempDir()
+ return self._tempdir
+
+ def _AssertProtoEquals(self, a, b):
+ """Asserts that a and b are the same proto.
+
+ Uses Proto2Cmp() first, as it returns correct results
+ for floating point attributes, and then uses assertProto2Equal()
+ in case of failure, as it provides good error messages.
+
+ Args:
+ a: a proto.
+ b: another proto.
+ """
+ if compare.Proto2Cmp(a, b) != 0:
+ compare.assertProto2Equal(self, a, b, normalize_numbers=True)
+
+ def assertProtoEquals(self, expected_message_maybe_ascii, message):
+ """Asserts that message is same as parsed expected_message_ascii.
+
+ Creates another prototype of message, reads the ascii message into it and
+ then compares them using self._AssertProtoEqual().
+
+ Args:
+ expected_message_maybe_ascii: proto message in original or ascii form
+ message: the message to validate
+ """
+
+ if type(expected_message_maybe_ascii) == type(message):
+ expected_message = expected_message_maybe_ascii
+ self._AssertProtoEquals(expected_message, message)
+ elif isinstance(expected_message_maybe_ascii, str):
+ expected_message = type(message)()
+ text_format.Merge(expected_message_maybe_ascii, expected_message)
+ self._AssertProtoEquals(expected_message, message)
+ else:
+ assert False, ("Can't compare protos of type " +
+ type(expected_message_maybe_ascii) + " and " +
+ type(message))
+
+ def assertStartsWith(self, actual, expected_start, msg=None):
+ """Assert that actual.startswith(expected_start) is True.
+
+ Args:
+ actual: str
+ expected_start: str
+ msg: Optional message to report on failure.
+ """
+ if not actual.startswith(expected_start):
+ fail_msg = "%r does not start with %r" % (actual, expected_start)
+ fail_msg += " : %r" % (msg) if msg else ""
+ self.fail(fail_msg)
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def test_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Returns a TensorFlow Session for use in executing tests.
+
+ This method should be used for all functional tests.
+
+ Use the `use_gpu` and `force_gpu` options to control where ops are run. If
+ `force_gpu` is True, all ops are pinned to `/gpu:0`. Otherwise, if `use_gpu`
+ is True, TensorFlow tries to run as many ops on the GPU as possible. If both
+ `force_gpu` and `use_gpu` are False, all ops are pinned to the CPU.
+
+ Example:
+
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ with self.test_session(use_gpu=True):
+ valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ result = MyOperator(valid_input).eval()
+ self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
+ invalid_input = [-1.0, 2.0, 7.0]
+ with self.assertRaisesOpError("negative input not supported"):
+ MyOperator(invalid_input).eval()
+
+ Args:
+ graph: Optional graph to use during the returned session.
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+ use_gpu: If True, attempt to run as many ops as possible on GPU.
+ force_gpu: If True, pin all ops to `/gpu:0`.
+
+ Returns:
+ A Session object that should be used as a context manager to surround
+ the graph building and execution code in a test case.
+ """
+ def prepare_config(config):
+ if config is None:
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = not force_gpu
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ elif force_gpu and config.allow_soft_placement:
+ # CopyFrom() returns None, so copy into a fresh proto and rebind config.
+ copied_config = config_pb2.ConfigProto()
+ copied_config.CopyFrom(config)
+ config = copied_config
+ config.allow_soft_placement = False
+ return config
+
+ if graph is None:
+ if self._cached_session is None:
+ self._cached_session = session.Session(graph=None,
+ config=prepare_config(config))
+ sess = self._cached_session
+ with sess.graph.as_default(), sess.as_default():
+ if force_gpu:
+ with sess.graph.device("/gpu:0"):
+ yield sess
+ elif use_gpu:
+ yield sess
+ else:
+ with sess.graph.device(graph_util.pin_to_cpu):
+ yield sess
+ else:
+ with session.Session(graph=graph, config=prepare_config(config)) as sess:
+ if force_gpu:
+ with sess.graph.device("/gpu:0"):
+ yield sess
+ elif use_gpu:
+ yield sess
+ else:
+ with sess.graph.device(graph_util.pin_to_cpu):
+ yield sess
+ # pylint: enable=g-doc-return-or-yield
+
+ class _CheckedThread(object):
+ """A wrapper class for Thread that asserts successful completion.
+
+ This class should be created using the TensorFlowTestCase.checkedThread()
+ method.
+ """
+
+ def __init__(self, testcase, target, args=None, kwargs=None):
+ """Constructs a new instance of _CheckedThread.
+
+ Args:
+ testcase: The TensorFlowTestCase for which this thread is being created.
+ target: A callable object representing the code to be executed in the
+ thread.
+ args: A tuple of positional arguments that will be passed to target.
+ kwargs: A dictionary of keyword arguments that will be passed to target.
+ """
+ self._testcase = testcase
+ self._target = target
+ self._args = () if args is None else args
+ self._kwargs = {} if kwargs is None else kwargs
+ self._thread = threading.Thread(target=self._protected_run)
+ self._exception = None
+
+ def _protected_run(self):
+ """Target for the wrapper thread. Sets self._exception on failure."""
+ try:
+ self._target(*self._args, **self._kwargs)
+# pylint: disable=broad-except
+ except Exception as e:
+ # pylint: enable=broad-except
+ self._exception = e
+
+ def start(self):
+ """Starts the thread's activity.
+
+ This must be called at most once per _CheckedThread object. It arranges
+ for the object's target to be invoked in a separate thread of control.
+ """
+ self._thread.start()
+
+ def join(self):
+ """Blocks until the thread terminates.
+
+ Raises:
+ self._testcase.failureException: If the thread terminates due to
+ an exception.
+ """
+ self._thread.join()
+ if self._exception is not None:
+ self._testcase.fail(
+ "Error in checkedThread: %s" % str(self._exception))
+
+ def is_alive(self):
+ """Returns whether the thread is alive.
+
+ This method returns True just before the run() method starts
+ until just after the run() method terminates.
+
+ Returns:
+ True if the thread is alive, otherwise False.
+ """
+ return self._thread.is_alive()
+
+ def checkedThread(self, target, args=None, kwargs=None):
+ """Returns a Thread wrapper that asserts 'target' completes successfully.
+
+ This method should be used to create all threads in test cases, as
+ otherwise there is a risk that a thread will silently fail, and/or
+ assertions made in the thread will not be respected.
+
+ Args:
+ target: A callable object to be executed in the thread.
+ args: The argument tuple for the target invocation. Defaults to ().
+ kwargs: A dictionary of keyword arguments for the target invocation.
+ Defaults to {}.
+
+ Returns:
+ A wrapper for threading.Thread that supports start() and join() methods.
+ """
+ ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
+ self._threads.append(ret)
+ return ret
+# pylint: enable=invalid-name
+
+ def assertNear(self, f1, f2, err):
+ """Asserts that two floats are near each other.
+
+ Checks that |f1 - f2| < err and asserts a test failure
+ if not.
+
+ Args:
+ f1: a float value.
+ f2: a float value.
+ err: a float value.
+ """
+ self.assertTrue(math.fabs(f1 - f2) < err)
+
+ def assertArrayNear(self, farray1, farray2, err):
+ """Asserts that two float arrays are near each other.
+
+ Checks that for all elements of farray1 and farray2
+ |f1 - f2| < err. Asserts a test failure if not.
+
+ Args:
+ farray1: a list of float values.
+ farray2: a list of float values.
+ err: a float value.
+ """
+ for f1, f2 in zip(farray1, farray2):
+ self.assertNear(f1, f2, err)
+
+ def _NDArrayNear(self, ndarray1, ndarray2, err):
+ return np.linalg.norm(ndarray1 - ndarray2) < err
+
+ def assertNDArrayNear(self, ndarray1, ndarray2, err):
+ """Asserts that two numpy arrays have near values.
+
+ Args:
+ ndarray1: a numpy ndarray.
+ ndarray2: a numpy ndarray.
+ err: a float. The maximum absolute difference allowed.
+ """
+ self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err))
+
+ def _GetNdArray(self, a):
+ if not isinstance(a, np.ndarray):
+ a = np.array(a)
+ return a
+
+ def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
+ """Asserts that two numpy arrays have near values.
+
+ Args:
+ a: a numpy ndarray or anything that can be converted to one.
+ b: a numpy ndarray or anything that can be converted to one.
+ rtol: relative tolerance
+ atol: absolute tolerance
+ """
+ a = self._GetNdArray(a)
+ b = self._GetNdArray(b)
+ self.assertEqual(
+ a.shape, b.shape,
+ "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+ if not np.allclose(a, b, rtol=rtol, atol=atol):
+ # Prints more details than np.testing.assert_allclose.
+ #
+ # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
+ # checks whether two arrays are element-wise equal within a
+ # tolerance. The relative difference (rtol * abs(b)) and the
+ # absolute difference atol are added together to compare against
+ # the absolute difference between a and b. Here, we want to
+ # print out which elements violate such conditions.
+ cond = np.abs(a - b) > atol + rtol * np.abs(b)
+ if a.ndim:
+ x = a[np.where(cond)]
+ y = b[np.where(cond)]
+ print "not close where = ", np.where(cond)
+ else:
+ # np.where is broken for scalars
+ x, y = a, b
+ print "not close lhs = ", x
+ print "not close rhs = ", y
+ print "not close dif = ", np.abs(x - y)
+ print "not close tol = ", atol + rtol * np.abs(y)
+ np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
+
+ def assertAllEqual(self, a, b):
+ """Asserts that two numpy arrays have the same values.
+
+ Args:
+ a: a numpy ndarray or anything that can be converted to one.
+ b: a numpy ndarray or anything that can be converted to one.
+ """
+ a = self._GetNdArray(a)
+ b = self._GetNdArray(b)
+ self.assertEqual(
+ a.shape, b.shape,
+ "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+ same = (a == b)
+
+ if a.dtype == np.float32 or a.dtype == np.float64:
+ same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
+ if not np.all(same):
+ # Prints more details than np.testing.assert_array_equal.
+ diff = np.logical_not(same)
+ if a.ndim:
+ x = a[np.where(diff)]
+ y = b[np.where(diff)]
+ print "not equal where = ", np.where(diff)
+ else:
+ # np.where is broken for scalars
+ x, y = a, b
+ print "not equal lhs = ", x
+ print "not equal rhs = ", y
+ np.testing.assert_array_equal(a, b)
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def assertRaisesWithPredicateMatch(self, exception_type,
+ expected_err_re_or_predicate):
+ """Returns a context manager to enclose code expected to raise an exception.
+
+ Args:
+ exception_type: The expected type of exception that should be raised.
+ expected_err_re_or_predicate: If this is callable, it should be a function
+ of one argument that inspects the passed-in OpError exception and
+ returns True (success) or False (failure). Otherwise, the
+ error message is expected to match this regular expression partially.
+
+ Returns:
+ A context manager to surround code that is expected to raise an
+ errors.OpError exception.
+ """
+ if callable(expected_err_re_or_predicate):
+ predicate = expected_err_re_or_predicate
+ else:
+ def predicate(e):
+ err_str = e.message
+ op = e.op
+ while op is not None:
+ err_str += "\nCaused by: " + op.name
+ op = op._original_op
+ logging.info("Searching within error strings: '%s' within '%s'",
+ expected_err_re_or_predicate, err_str)
+ return re.search(expected_err_re_or_predicate, err_str)
+ try:
+ yield
+ self.fail(exception_type.__name__ + " not raised")
+# pylint: disable=broad-except
+ except Exception as e:
+ # pylint: enable=broad-except
+ if not isinstance(e, exception_type) or not predicate(e):
+ raise AssertionError(e)
+ # pylint: enable=g-doc-return-or-yield
+
+ def assertRaisesOpError(self, expected_err_re_or_predicate):
+ return self.assertRaisesWithPredicateMatch(errors.OpError,
+ expected_err_re_or_predicate)
+
+ def assertShapeEqual(self, np_array, tf_tensor):
+ """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
+
+ Args:
+ np_array: A Numpy ndarray or Numpy scalar.
+ tf_tensor: A Tensor.
+
+ Raises:
+ TypeError: If the arguments have the wrong type.
+ """
+ if not isinstance(np_array, (np.ndarray, np.generic)):
+ raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
+ if not isinstance(tf_tensor, ops.Tensor):
+ raise TypeError("tf_tensor must be a Tensor")
+ self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list())
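+
+
+# A compact usage sketch of the helpers above; `constant_op` is assumed to be
+# imported by the test module (it is not imported here):
+#
+#   class MyTest(TensorFlowTestCase):
+#     def testValues(self):
+#       with self.test_session(use_gpu=False):
+#         c = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+#         self.assertAllClose([[1.0, 2.0], [3.0, 4.0]], c.eval())
+#         self.assertShapeEqual(np.array([[1.0, 2.0], [3.0, 4.0]]), c)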
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
new file mode 100644
index 0000000000..e0618cfea4
--- /dev/null
+++ b/tensorflow/python/framework/test_util_test.py
@@ -0,0 +1,128 @@
+"""Tests for tensorflow.ops.test_util."""
+import threading
+
+import tensorflow.python.platform
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.platform import googletest
+from tensorflow.python.ops import logging_ops
+
+class TestUtilTest(test_util.TensorFlowTestCase):
+
+ def testIsGoogleCudaEnabled(self):
+ # The test doesn't assert anything. It ensures the py wrapper
+ # function is generated correctly.
+ if test_util.IsGoogleCudaEnabled():
+ print "GoogleCuda is enabled"
+ else:
+ print "GoogleCuda is disabled"
+
+ def testAssertProtoEqualsStr(self):
+
+ graph_str = "node { name: 'w1' op: 'params' }"
+ graph_def = graph_pb2.GraphDef()
+ text_format.Merge(graph_str, graph_def)
+
+ # test string based comparison
+ self.assertProtoEquals(graph_str, graph_def)
+
+ # test original comparison
+ self.assertProtoEquals(graph_def, graph_def)
+
+ def testNDArrayNear(self):
+ a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
+ self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
+ self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))
+
+ def testCheckedThreadSucceeds(self):
+ def noop(ev):
+ ev.set()
+
+ event_arg = threading.Event()
+
+ self.assertFalse(event_arg.is_set())
+ t = self.checkedThread(target=noop, args=(event_arg,))
+ t.start()
+ t.join()
+ self.assertTrue(event_arg.is_set())
+
+ def testCheckedThreadFails(self):
+ def err_func():
+ return 1 / 0
+
+ t = self.checkedThread(target=err_func)
+ t.start()
+ with self.assertRaises(self.failureException) as fe:
+ t.join()
+ self.assertTrue("integer division or modulo by zero"
+ in fe.exception.message)
+
+ def testCheckedThreadWithWrongAssertionFails(self):
+ x = 37
+
+ def err_func():
+ self.assertTrue(x < 10)
+
+ t = self.checkedThread(target=err_func)
+ t.start()
+ with self.assertRaises(self.failureException) as fe:
+ t.join()
+ self.assertTrue("False is not true" in fe.exception.message)
+
+ def testMultipleThreadsWithOneFailure(self):
+ def err_func(i):
+ self.assertTrue(i != 7)
+
+ threads = [self.checkedThread(target=err_func, args=(i,))
+ for i in range(10)]
+ for t in threads:
+ t.start()
+ for i, t in enumerate(threads):
+ if i == 7:
+ with self.assertRaises(self.failureException):
+ t.join()
+ else:
+ t.join()
+
+ def _WeMustGoDeeper(self, msg):
+ with self.assertRaisesOpError(msg):
+ node_def = ops._NodeDef("op_type", "name")
+ node_def_orig = ops._NodeDef("op_type_orig", "orig")
+ op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
+ op = ops.Operation(node_def, ops.get_default_graph(), original_op=op_orig)
+ raise errors.UnauthenticatedError(node_def, op, "true_err")
+
+ def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
+ with self.assertRaises(AssertionError):
+ self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")
+
+ self._WeMustGoDeeper("true_err")
+ self._WeMustGoDeeper("name")
+ self._WeMustGoDeeper("orig")
+
+ def testAllCloseScalars(self):
+ self.assertAllClose(7, 7 + 1e-8)
+ with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
+ self.assertAllClose(7, 8)
+
+ def testForceGPU(self):
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Cannot assign a device to node"):
+ with self.test_session(force_gpu=True):
+ # this relies on us not having a GPU implementation for assert, which
+ # seems sensible
+ x = [True]
+ y = [15]
+ logging_ops.Assert(x, y).run()
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/framework/types.py b/tensorflow/python/framework/types.py
new file mode 100644
index 0000000000..6a8c629fe4
--- /dev/null
+++ b/tensorflow/python/framework/types.py
@@ -0,0 +1,418 @@
+"""Library of dtypes (Tensor element types)."""
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.core.framework import types_pb2
+
+
+class DType(object):
+ """Represents the type of the elements in a `Tensor`.
+
+ The following `DType` objects are defined:
+
+ * `tf.float32`: 32-bit single-precision floating-point.
+ * `tf.float64`: 64-bit double-precision floating-point.
+ * `tf.bfloat16`: 16-bit truncated floating-point.
+ * `tf.complex64`: 64-bit single-precision complex.
+
+ * `tf.int8`: 8-bit signed integer.
+ * `tf.uint8`: 8-bit unsigned integer.
+ * `tf.int32`: 32-bit signed integer.
+ * `tf.int64`: 64-bit signed integer.
+
+ * `tf.bool`: Boolean.
+
+ * `tf.string`: String.
+
+ * `tf.qint8`: Quantized 8-bit signed integer.
+ * `tf.quint8`: Quantized 8-bit unsigned integer.
+ * `tf.qint32`: Quantized 32-bit signed integer.
+
+ In addition, variants of these types with the `_ref` suffix are
+ defined for reference-typed tensors.
+
+ The `tf.as_dtype()` function converts numpy types and string type
+ names to a `DType` object.
+
+ @@is_compatible_with
+ @@name
+ @@base_dtype
+ @@is_ref_dtype
+ @@as_ref
+ @@is_integer
+ @@is_quantized
+
+ @@as_numpy_dtype
+ @@as_datatype_enum
+ """
+
+ def __init__(self, type_enum):
+ """Creates a new `DataType`.
+
+ NOTE(mrry): In normal circumstances, you should not need to
+ construct a DataType object directly. Instead, use the
+ types.as_dtype() function.
+
+ Args:
+ type_enum: A `types_pb2.DataType` enum value.
+
+ Raises:
+ TypeError: If `type_enum` is not a valid `types_pb2.DataType` value.
+
+ """
+ # TODO(mrry): Make the necessary changes (using __new__) to ensure
+ # that calling this returns one of the interned values.
+ type_enum = int(type_enum)
+ if (type_enum not in types_pb2.DataType.values()
+ or type_enum == types_pb2.DT_INVALID):
+ raise TypeError(
+ "type_enum is not a valid types_pb2.DataType: %s" % type_enum)
+ self._type_enum = type_enum
+
+ @property
+ def is_ref_dtype(self):
+ """Returns `True` if this `DType` represents a reference type."""
+ return self._type_enum > 100
+
+ @property
+ def as_ref(self):
+ """Returns a reference `DType` based on this `DType`."""
+ if self.is_ref_dtype:
+ return self
+ else:
+ return _INTERN_TABLE[self._type_enum + 100]
+
+ @property
+ def base_dtype(self):
+ """Returns a non-reference `DType` based on this `DType`."""
+ if self.is_ref_dtype:
+ return _INTERN_TABLE[self._type_enum - 100]
+ else:
+ return self
+
+ @property
+ def as_numpy_dtype(self):
+ """Returns a `numpy.dtype` based on this `DType`."""
+ return _TF_TO_NP[self._type_enum]
+
+ @property
+ def as_datatype_enum(self):
+ """Returns a `types_pb2.DataType` enum value based on this `DType`."""
+ return self._type_enum
+
+ @property
+ def is_integer(self):
+ """Returns whether this is a (non-quantized) integer type."""
+ return (not self.is_quantized and
+ issubclass(self.as_numpy_dtype, np.integer))
+
+ @property
+ def is_quantized(self):
+ """Returns whether this is a quantized data type."""
+ return self.base_dtype in [qint8, quint8, qint32, bfloat16]
+
+ @property
+ def min(self):
+ """Returns the minimum representable value in this data type.
+
+ Raises:
+ TypeError: if this is a non-numeric, unordered, or quantized type.
+
+ """
+ if (self.is_quantized or self.base_dtype == bool or
+ self.base_dtype == string or self.base_dtype == complex64):
+ raise TypeError("Cannot find minimum value of %s." % self)
+
+ # There is no simple way to get the min value of a dtype, so we check
+ # float and int types separately.
+ try:
+ return np.finfo(self.as_numpy_dtype()).min
+ except: # bare except, since the exceptions finfo may raise are not documented
+ try:
+ return np.iinfo(self.as_numpy_dtype()).min
+ except:
+ raise TypeError("Cannot find minimum value of %s." % self)
+
+ @property
+ def max(self):
+ """Returns the maximum representable value in this data type.
+
+ Raises:
+ TypeError: if this is a non-numeric, unordered, or quantized type.
+
+ """
+ if (self.is_quantized or self.base_dtype == bool or
+ self.base_dtype == string or self.base_dtype == complex64):
+ raise TypeError("Cannot find maximum value of %s." % self)
+
+ # There is no simple way to get the max value of a dtype, so we check
+ # float and int types separately.
+ try:
+ return np.finfo(self.as_numpy_dtype()).max
+ except: # bare except, since the exceptions finfo may raise are not documented
+ try:
+ return np.iinfo(self.as_numpy_dtype()).max
+ except:
+ raise TypeError("Cannot find maximum value of %s." % self)
+
+ def is_compatible_with(self, other):
+ """Returns True if the `other` DType will be converted to this DType.
+
+ The conversion rules are as follows:
+
+ ```
+ DType(T) .is_compatible_with(DType(T)) == True
+ DType(T) .is_compatible_with(DType(T).as_ref) == True
+ DType(T).as_ref.is_compatible_with(DType(T)) == False
+ DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True
+ ```
+
+ Args:
+ other: A `DType` (or object that may be converted to a `DType`).
+
+ Returns:
+ True if a Tensor of the `other` `DType` will be implicitly converted to
+ this `DType`.
+ """
+ other = as_dtype(other)
+ return self._type_enum in (
+ other.as_datatype_enum, other.base_dtype.as_datatype_enum)
+
+ def __eq__(self, other):
+ """Returns True iff this DType refers to the same type as `other`."""
+ return (other is not None
+ and self._type_enum == as_dtype(other).as_datatype_enum)
+
+ def __ne__(self, other):
+ """Returns True iff self != other."""
+ return not self.__eq__(other)
+
+ @property
+ def name(self):
+ """Returns the string name for this `DType`."""
+ return _TYPE_TO_STRING[self._type_enum]
+
+ def __str__(self):
+ return "<dtype: %r>" % self.name
+
+ def __repr__(self):
+ return "tf." + self.name
+
+
+# Define standard wrappers for the types_pb2.DataType enum.
+float32 = DType(types_pb2.DT_FLOAT)
+float64 = DType(types_pb2.DT_DOUBLE)
+double = float64
+int32 = DType(types_pb2.DT_INT32)
+uint8 = DType(types_pb2.DT_UINT8)
+int16 = DType(types_pb2.DT_INT16)
+int8 = DType(types_pb2.DT_INT8)
+string = DType(types_pb2.DT_STRING)
+complex64 = DType(types_pb2.DT_COMPLEX64)
+int64 = DType(types_pb2.DT_INT64)
+bool = DType(types_pb2.DT_BOOL)
+qint8 = DType(types_pb2.DT_QINT8)
+quint8 = DType(types_pb2.DT_QUINT8)
+qint32 = DType(types_pb2.DT_QINT32)
+bfloat16 = DType(types_pb2.DT_BFLOAT16)
+float32_ref = DType(types_pb2.DT_FLOAT_REF)
+float64_ref = DType(types_pb2.DT_DOUBLE_REF)
+double_ref = float64_ref
+int32_ref = DType(types_pb2.DT_INT32_REF)
+uint8_ref = DType(types_pb2.DT_UINT8_REF)
+int16_ref = DType(types_pb2.DT_INT16_REF)
+int8_ref = DType(types_pb2.DT_INT8_REF)
+string_ref = DType(types_pb2.DT_STRING_REF)
+complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
+int64_ref = DType(types_pb2.DT_INT64_REF)
+bool_ref = DType(types_pb2.DT_BOOL_REF)
+qint8_ref = DType(types_pb2.DT_QINT8_REF)
+quint8_ref = DType(types_pb2.DT_QUINT8_REF)
+qint32_ref = DType(types_pb2.DT_QINT32_REF)
+bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
+
+
+# Maintain an intern table so that we don't have to create a large
+# number of small objects.
+_INTERN_TABLE = {
+ types_pb2.DT_FLOAT: float32,
+ types_pb2.DT_DOUBLE: float64,
+ types_pb2.DT_INT32: int32,
+ types_pb2.DT_UINT8: uint8,
+ types_pb2.DT_INT16: int16,
+ types_pb2.DT_INT8: int8,
+ types_pb2.DT_STRING: string,
+ types_pb2.DT_COMPLEX64: complex64,
+ types_pb2.DT_INT64: int64,
+ types_pb2.DT_BOOL: bool,
+ types_pb2.DT_QINT8: qint8,
+ types_pb2.DT_QUINT8: quint8,
+ types_pb2.DT_QINT32: qint32,
+ types_pb2.DT_BFLOAT16: bfloat16,
+ types_pb2.DT_FLOAT_REF: float32_ref,
+ types_pb2.DT_DOUBLE_REF: float64_ref,
+ types_pb2.DT_INT32_REF: int32_ref,
+ types_pb2.DT_UINT8_REF: uint8_ref,
+ types_pb2.DT_INT16_REF: int16_ref,
+ types_pb2.DT_INT8_REF: int8_ref,
+ types_pb2.DT_STRING_REF: string_ref,
+ types_pb2.DT_COMPLEX64_REF: complex64_ref,
+ types_pb2.DT_INT64_REF: int64_ref,
+ types_pb2.DT_BOOL_REF: bool_ref,
+ types_pb2.DT_QINT8_REF: qint8_ref,
+ types_pb2.DT_QUINT8_REF: quint8_ref,
+ types_pb2.DT_QINT32_REF: qint32_ref,
+ types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
+}
+
+
+# Standard mappings between types_pb2.DataType values and string names.
+_TYPE_TO_STRING = {
+ types_pb2.DT_FLOAT: "float32",
+ types_pb2.DT_DOUBLE: "float64",
+ types_pb2.DT_INT32: "int32",
+ types_pb2.DT_UINT8: "uint8",
+ types_pb2.DT_INT16: "int16",
+ types_pb2.DT_INT8: "int8",
+ types_pb2.DT_STRING: "string",
+ types_pb2.DT_COMPLEX64: "complex64",
+ types_pb2.DT_INT64: "int64",
+ types_pb2.DT_BOOL: "bool",
+ types_pb2.DT_QINT8: "qint8",
+ types_pb2.DT_QUINT8: "quint8",
+ types_pb2.DT_QINT32: "qint32",
+ types_pb2.DT_BFLOAT16: "bfloat16",
+ types_pb2.DT_FLOAT_REF: "float32_ref",
+ types_pb2.DT_DOUBLE_REF: "float64_ref",
+ types_pb2.DT_INT32_REF: "int32_ref",
+ types_pb2.DT_UINT8_REF: "uint8_ref",
+ types_pb2.DT_INT16_REF: "int16_ref",
+ types_pb2.DT_INT8_REF: "int8_ref",
+ types_pb2.DT_STRING_REF: "string_ref",
+ types_pb2.DT_COMPLEX64_REF: "complex64_ref",
+ types_pb2.DT_INT64_REF: "int64_ref",
+ types_pb2.DT_BOOL_REF: "bool_ref",
+ types_pb2.DT_QINT8_REF: "qint8_ref",
+ types_pb2.DT_QUINT8_REF: "quint8_ref",
+ types_pb2.DT_QINT32_REF: "qint32_ref",
+ types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
+}
+_STRING_TO_TF = {value: _INTERN_TABLE[key]
+ for key, value in _TYPE_TO_STRING.iteritems()}
+# Add non-canonical aliases.
+_STRING_TO_TF["float"] = float32
+_STRING_TO_TF["float_ref"] = float32_ref
+_STRING_TO_TF["double"] = float64
+_STRING_TO_TF["double_ref"] = float64_ref
+
+
+# Numpy representation for quantized dtypes.
+#
+# These are magic strings that are used in the swig wrapper to identify
+# quantized types.
+# TODO(mrry,keveman): Investigate Numpy type registration to replace this
+# hard-coding of names.
+_np_qint8 = np.dtype([("qint8", np.int8, 1)])
+_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
+_np_qint32 = np.dtype([("qint32", np.int32, 1)])
+
+# Standard mappings between types_pb2.DataType values and numpy.dtypes.
+_NP_TO_TF = frozenset([
+ (np.float32, float32),
+ (np.float64, float64),
+ (np.int32, int32),
+ (np.int64, int64),
+ (np.uint8, uint8),
+ (np.int16, int16),
+ (np.int8, int8),
+ (np.complex64, complex64),
+ (np.object, string),
+ (np.bool, bool),
+ (_np_qint8, qint8),
+ (_np_quint8, quint8),
+ (_np_qint32, qint32),
+ # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
+])
+_TF_TO_NP = {
+ types_pb2.DT_FLOAT: np.float32,
+ types_pb2.DT_DOUBLE: np.float64,
+ types_pb2.DT_INT32: np.int32,
+ types_pb2.DT_UINT8: np.uint8,
+ types_pb2.DT_INT16: np.int16,
+ types_pb2.DT_INT8: np.int8,
+ # NOTE(mdevin): For strings we use np.object as it supports variable length
+ # strings.
+ types_pb2.DT_STRING: np.object,
+ types_pb2.DT_COMPLEX64: np.complex64,
+ types_pb2.DT_INT64: np.int64,
+ types_pb2.DT_BOOL: np.bool,
+ types_pb2.DT_QINT8: _np_qint8,
+ types_pb2.DT_QUINT8: _np_quint8,
+ types_pb2.DT_QINT32: _np_qint32,
+ types_pb2.DT_BFLOAT16: np.uint16,
+
+ # Ref types
+ types_pb2.DT_FLOAT_REF: np.float32,
+ types_pb2.DT_DOUBLE_REF: np.float64,
+ types_pb2.DT_INT32_REF: np.int32,
+ types_pb2.DT_UINT8_REF: np.uint8,
+ types_pb2.DT_INT16_REF: np.int16,
+ types_pb2.DT_INT8_REF: np.int8,
+ types_pb2.DT_STRING_REF: np.object,
+ types_pb2.DT_COMPLEX64_REF: np.complex64,
+ types_pb2.DT_INT64_REF: np.int64,
+ types_pb2.DT_BOOL_REF: np.bool,
+ types_pb2.DT_QINT8_REF: _np_qint8,
+ types_pb2.DT_QUINT8_REF: _np_quint8,
+ types_pb2.DT_QINT32_REF: _np_qint32,
+ types_pb2.DT_BFLOAT16_REF: np.uint16,
+}
+
+
+QUANTIZED_DTYPES = frozenset(
+ [qint8, quint8, qint32, qint8_ref, quint8_ref, qint32_ref])
+
+
+def as_dtype(type_value):
+ """Converts the given `type_value` to a `DType`.
+
+ Args:
+ type_value: A value that can be converted to a `tf.DType`
+ object. This may currently be a `tf.DType` object, a
+ [`DataType` enum](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/types.proto),
+ a string type name, or a `numpy.dtype`.
+
+ Returns:
+ A `DType` corresponding to `type_value`.
+
+ Raises:
+ TypeError: If `type_value` cannot be converted to a `DType`.
+ """
+ if isinstance(type_value, DType):
+ return type_value
+
+ try:
+ return _INTERN_TABLE[type_value]
+ except KeyError:
+ pass
+
+ try:
+ return _STRING_TO_TF[type_value]
+ except KeyError:
+ pass
+
+ if isinstance(type_value, np.dtype):
+ # The numpy dtype for strings is variable length. We cannot compare it
+ # against a single constant (np.string does not exist), so we compare
+ # dtype.type instead to decide whether it is a string type.
+ if type_value.type == np.string_ or type_value.type == np.unicode_:
+ return string
+
+ for key, val in _NP_TO_TF:
+ if key == type_value:
+ return val
+
+ raise TypeError(
+ "Cannot convert value %r to a TensorFlow DType." % type_value)
diff --git a/tensorflow/python/framework/types_test.py b/tensorflow/python/framework/types_test.py
new file mode 100644
index 0000000000..acd2994339
--- /dev/null
+++ b/tensorflow/python/framework/types_test.py
@@ -0,0 +1,174 @@
+"""Tests for tensorflow.python.framework.importer."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.platform import googletest
+
+
+class TypesTest(test_util.TensorFlowTestCase):
+
+ def testAllTypesConstructible(self):
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ self.assertEqual(
+ datatype_enum, types.DType(datatype_enum).as_datatype_enum)
+
+ def testAllTypesConvertibleToDType(self):
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ self.assertEqual(
+ datatype_enum, types.as_dtype(datatype_enum).as_datatype_enum)
+
+ def testAllTypesConvertibleToNumpyDtype(self):
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ dtype = types.as_dtype(datatype_enum)
+ numpy_dtype = dtype.as_numpy_dtype
+ _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype)
+ if dtype.base_dtype != types.bfloat16:
+ # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
+ self.assertEqual(
+ types.as_dtype(datatype_enum).base_dtype, types.as_dtype(numpy_dtype))
+
+ def testInvalid(self):
+ with self.assertRaises(TypeError):
+ types.DType(types_pb2.DT_INVALID)
+ with self.assertRaises(TypeError):
+ types.as_dtype(types_pb2.DT_INVALID)
+
+ def testNumpyConversion(self):
+ self.assertIs(types.float32, types.as_dtype(np.float32))
+ self.assertIs(types.float64, types.as_dtype(np.float64))
+ self.assertIs(types.int32, types.as_dtype(np.int32))
+ self.assertIs(types.int64, types.as_dtype(np.int64))
+ self.assertIs(types.uint8, types.as_dtype(np.uint8))
+ self.assertIs(types.int16, types.as_dtype(np.int16))
+ self.assertIs(types.int8, types.as_dtype(np.int8))
+ self.assertIs(types.complex64, types.as_dtype(np.complex64))
+ self.assertIs(types.string, types.as_dtype(np.object))
+ self.assertIs(types.string, types.as_dtype(np.array(["foo", "bar"]).dtype))
+ self.assertIs(types.bool, types.as_dtype(np.bool))
+ with self.assertRaises(TypeError):
+ types.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))
+
+ def testStringConversion(self):
+ self.assertIs(types.float32, types.as_dtype("float32"))
+ self.assertIs(types.float64, types.as_dtype("float64"))
+ self.assertIs(types.int32, types.as_dtype("int32"))
+ self.assertIs(types.uint8, types.as_dtype("uint8"))
+ self.assertIs(types.int16, types.as_dtype("int16"))
+ self.assertIs(types.int8, types.as_dtype("int8"))
+ self.assertIs(types.string, types.as_dtype("string"))
+ self.assertIs(types.complex64, types.as_dtype("complex64"))
+ self.assertIs(types.int64, types.as_dtype("int64"))
+ self.assertIs(types.bool, types.as_dtype("bool"))
+ self.assertIs(types.qint8, types.as_dtype("qint8"))
+ self.assertIs(types.quint8, types.as_dtype("quint8"))
+ self.assertIs(types.qint32, types.as_dtype("qint32"))
+ self.assertIs(types.bfloat16, types.as_dtype("bfloat16"))
+ self.assertIs(types.float32_ref, types.as_dtype("float32_ref"))
+ self.assertIs(types.float64_ref, types.as_dtype("float64_ref"))
+ self.assertIs(types.int32_ref, types.as_dtype("int32_ref"))
+ self.assertIs(types.uint8_ref, types.as_dtype("uint8_ref"))
+ self.assertIs(types.int16_ref, types.as_dtype("int16_ref"))
+ self.assertIs(types.int8_ref, types.as_dtype("int8_ref"))
+ self.assertIs(types.string_ref, types.as_dtype("string_ref"))
+ self.assertIs(types.complex64_ref, types.as_dtype("complex64_ref"))
+ self.assertIs(types.int64_ref, types.as_dtype("int64_ref"))
+ self.assertIs(types.bool_ref, types.as_dtype("bool_ref"))
+ self.assertIs(types.qint8_ref, types.as_dtype("qint8_ref"))
+ self.assertIs(types.quint8_ref, types.as_dtype("quint8_ref"))
+ self.assertIs(types.qint32_ref, types.as_dtype("qint32_ref"))
+ self.assertIs(types.bfloat16_ref, types.as_dtype("bfloat16_ref"))
+ with self.assertRaises(TypeError):
+ types.as_dtype("not_a_type")
+
+ def testDTypesHaveUniqueNames(self):
+ dtypes = []
+ names = set()
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ dtype = types.as_dtype(datatype_enum)
+ dtypes.append(dtype)
+ names.add(dtype.name)
+ self.assertEqual(len(dtypes), len(names))
+
+ def testIsInteger(self):
+ self.assertEqual(types.as_dtype("int8").is_integer, True)
+ self.assertEqual(types.as_dtype("int16").is_integer, True)
+ self.assertEqual(types.as_dtype("int32").is_integer, True)
+ self.assertEqual(types.as_dtype("int64").is_integer, True)
+ self.assertEqual(types.as_dtype("uint8").is_integer, True)
+ self.assertEqual(types.as_dtype("complex64").is_integer, False)
+ self.assertEqual(types.as_dtype("float").is_integer, False)
+ self.assertEqual(types.as_dtype("double").is_integer, False)
+ self.assertEqual(types.as_dtype("string").is_integer, False)
+ self.assertEqual(types.as_dtype("bool").is_integer, False)
+
+ def testMinMax(self):
+ # make sure min/max evaluates for all data types that have min/max
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ dtype = types.as_dtype(datatype_enum)
+ numpy_dtype = dtype.as_numpy_dtype
+
+ # ignore types for which there are no minimum/maximum (or we cannot
+ # compute it, such as for the q* types)
+ if (dtype.is_quantized or
+ dtype.base_dtype == types.bool or
+ dtype.base_dtype == types.string or
+ dtype.base_dtype == types.complex64):
+ continue
+
+ print "%s: %s - %s" % (dtype, dtype.min, dtype.max)
+
+ # check some values that are known
+ if numpy_dtype == np.bool_:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 1)
+ if numpy_dtype == np.int8:
+ self.assertEquals(dtype.min, -128)
+ self.assertEquals(dtype.max, 127)
+ if numpy_dtype == np.int16:
+ self.assertEquals(dtype.min, -32768)
+ self.assertEquals(dtype.max, 32767)
+ if numpy_dtype == np.int32:
+ self.assertEquals(dtype.min, -2147483648)
+ self.assertEquals(dtype.max, 2147483647)
+ if numpy_dtype == np.int64:
+ self.assertEquals(dtype.min, -9223372036854775808)
+ self.assertEquals(dtype.max, 9223372036854775807)
+ if numpy_dtype == np.uint8:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 255)
+ if numpy_dtype == np.uint16:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 65535)
+ if numpy_dtype == np.uint32:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 4294967295)
+ if numpy_dtype in (np.float16, np.float32, np.float64):
+ self.assertEquals(dtype.min, np.finfo(numpy_dtype).min)
+ self.assertEquals(dtype.max, np.finfo(numpy_dtype).max)
+
+ def testRepr(self):
+ for enum, name in types._TYPE_TO_STRING.iteritems():
+ dtype = types.DType(enum)
+ self.assertEquals(repr(dtype), 'tf.' + name)
+ dtype2 = eval(repr(dtype))
+ self.assertEquals(type(dtype2), types.DType)
+ self.assertEquals(dtype, dtype2)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/__init__.py b/tensorflow/python/kernel_tests/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/kernel_tests/__init__.py
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
new file mode 100644
index 0000000000..2cd6101a87
--- /dev/null
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -0,0 +1,61 @@
+"""Tests for tensorflow.ops.argmax_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+class ArgMaxTest(tf.test.TestCase):
+
+ def _testArg(self, method, x, dimension,
+ expected_values, use_gpu=False, expected_err_re=None):
+ with self.test_session(use_gpu=use_gpu):
+ ans = method(x, dimension=dimension)
+ if expected_err_re is None:
+ tf_ans = ans.eval()
+ self.assertAllEqual(tf_ans, expected_values)
+ self.assertShapeEqual(expected_values, ans)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ ans.eval()
+
+ def _testBothArg(self, method, x, dimension,
+ expected_values, expected_err_re=None):
+ self._testArg(method, x, dimension,
+ expected_values, True, expected_err_re)
+ self._testArg(method, x, dimension,
+ expected_values, False, expected_err_re)
+
+ def _testBasic(self, dtype):
+ x = np.asarray(100*np.random.randn(200), dtype=dtype)
+
+ # Check that argmin and argmax match numpy along the primary
+ # dimension
+ self._testBothArg(tf.argmax, x, 0, x.argmax())
+ self._testBothArg(tf.argmin, x, 0, x.argmin())
+
+ def _testDim(self, dtype):
+ x = np.asarray(100*np.random.randn(3, 2, 4, 5, 6), dtype=dtype)
+
+ # Check that argmin and argmax match numpy along all dimensions
+ for dim in range(5):
+ self._testBothArg(tf.argmax, x, dim, x.argmax(dim))
+ self._testBothArg(tf.argmin, x, dim, x.argmin(dim))
+
+ def testFloat(self):
+ self._testBasic(np.float32)
+ self._testDim(np.float32)
+
+ def testDouble(self):
+ self._testBasic(np.float64)
+ self._testDim(np.float64)
+
+ def testInt32(self):
+ self._testBasic(np.int32)
+ self._testDim(np.int32)
+
+ def testInt64(self):
+ self._testBasic(np.int64)
+ self._testDim(np.int64)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
new file mode 100644
index 0000000000..108cc7599e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -0,0 +1,45 @@
+"""Tests for array_ops."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class OperatorShapeTest(test_util.TensorFlowTestCase):
+
+ def testExpandScalar(self):
+ scalar = 'hello'
+ scalar_expanded = array_ops.expand_dims(scalar, [0])
+ self.assertEqual(scalar_expanded.get_shape(), (1,))
+
+ def testSqueeze(self):
+ scalar = 'hello'
+ scalar_squeezed = array_ops.squeeze(scalar, ())
+ self.assertEqual(scalar_squeezed.get_shape(), ())
+
+
+class ReverseTest(test_util.TensorFlowTestCase):
+
+ def testReverse0DimAuto(self):
+ x_np = 4
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = array_ops.reverse(x_np, []).eval()
+ self.assertAllEqual(x_tf, x_np)
+
+ def testReverse1DimAuto(self):
+ x_np = [1, 4, 9]
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = array_ops.reverse(x_np, [True]).eval()
+ self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py
new file mode 100644
index 0000000000..5541c541b2
--- /dev/null
+++ b/tensorflow/python/kernel_tests/attention_ops_test.py
@@ -0,0 +1,166 @@
+"""Tests for tensorflow.ops.attention_ops."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.python.ops import attention_ops
+
+
+class ExtractGlimpseTest(tf.test.TestCase):
+
+ def _VerifyValues(
+ self, tensor_in_sizes, glimpse_sizes, offsets, expected_rows,
+ expected_cols):
+ """Verifies the output values of the glimpse extraction kernel.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in [input_rows, input_cols].
+ glimpse_sizes: Dimensions of the glimpse in [glimpse_rows, glimpse_cols].
+ offsets: Relative location of the center of the glimpse in the input
+ image expressed as [row_offset, col_offset].
+ expected_rows: A list containing the expected row numbers (None for
+ out of bound entries that are expected to be replaced by uniform
+ random entries in [0,1) ).
+ expected_cols: Same as expected_rows, but for column numbers.
+ """
+
+ rows = tensor_in_sizes[0]
+ cols = tensor_in_sizes[1]
+ # Row Tensor with entries by row.
+ # [[ 1 1 1 ... ]
+ # [ 2 2 2 ... ]
+ # [ 3 3 3 ... ]
+ # [ ...
+ # ]
+ t_rows = tf.tile(
+ [[1.0 * r] for r in range(1, rows + 1)], [1, cols],
+ name='tile_rows')
+
+ # Shuffle to switch to a convention of (batch_size, height, width, depth).
+ t_rows_4d = tf.transpose(
+ tf.expand_dims(
+ tf.expand_dims(t_rows, 0), 3), [0, 2, 1, 3])
+
+ # Column Tensor with entries by column.
+ # [[ 1 2 3 4 ... ]
+ # [ 1 2 3 4 ... ]
+ # [ 1 2 3 4 ... ]
+ # [ ... ]
+ # ]
+ t_cols = tf.tile(
+ [[1.0 * r for r in range(1, cols + 1)]],
+ [rows, 1], name='tile_cols')
+
+ # Shuffle to switch to a convention of (batch_size, height, width, depth).
+ t_cols_4d = tf.transpose(
+ tf.expand_dims(
+ tf.expand_dims(t_cols, 0), 3), [0, 2, 1, 3])
+
+ # extract_glimpses from Row and Column Tensor, respectively.
+ # Switch order for glimpse_sizes and offsets to switch from (row, col)
+ # convention to TensorFlow's (height, width) convention.
+ t1 = tf.constant([glimpse_sizes[1], glimpse_sizes[0]], shape=[2])
+ t2 = tf.constant([offsets[1], offsets[0]], shape=[1, 2])
+ glimpse_rows = (tf.transpose(
+ attention_ops.extract_glimpse(t_rows_4d, t1, t2), [0, 2, 1, 3]))
+ glimpse_cols = (tf.transpose(
+ attention_ops.extract_glimpse(t_cols_4d, t1, t2), [0, 2, 1, 3]))
+
+ # Evaluate the TensorFlow graph.
+ with self.test_session() as sess:
+ value_rows, value_cols = sess.run([glimpse_rows, glimpse_cols])
+
+ # Check dimensions of returned glimpse.
+ self.assertEqual(value_rows.shape[1], glimpse_sizes[0])
+ self.assertEqual(value_rows.shape[2], glimpse_sizes[1])
+ self.assertEqual(value_cols.shape[1], glimpse_sizes[0])
+ self.assertEqual(value_cols.shape[2], glimpse_sizes[1])
+
+ # Check entries.
+ min_random_val = 0
+ max_random_val = max(rows, cols)
+ for i in range(0, glimpse_sizes[0]):
+ for j in range(0, glimpse_sizes[1]):
+ if expected_rows[i] is None or expected_cols[j] is None:
+ self.assertGreaterEqual(value_rows[0][i][j][0], min_random_val)
+ self.assertLessEqual(value_rows[0][i][j][0], max_random_val)
+ self.assertGreaterEqual(value_cols[0][i][j][0], min_random_val)
+ self.assertLessEqual(value_cols[0][i][j][0], max_random_val)
+ else:
+ self.assertEqual(value_rows[0][i][j][0], expected_rows[i])
+ self.assertEqual(value_cols[0][i][j][0], expected_cols[j])
+
+ def testCenterGlimpse(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[3, 5],
+ offsets=[0.0, 0.0],
+ expected_rows=[20, 21, 22],
+ expected_cols=[29, 30, 31, 32, 33])
+
+ def testLargeCenterGlimpse(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[41, 61],
+ offsets=[0.0, 0.0],
+ expected_rows=range(1, 42),
+ expected_cols=range(1, 62))
+
+ def testTooLargeCenterGlimpse(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[43, 63],
+ offsets=[0.0, 0.0],
+ expected_rows=[None] + range(1, 42) + [None],
+ expected_cols=[None] + range(1, 62) + [None])
+
+ def testGlimpseFullOverlap(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[3, 5],
+ offsets=[0.1, 0.3],
+ expected_rows=[22, 23, 24],
+ expected_cols=[38, 39, 40, 41, 42])
+
+ def testGlimpseFullOverlap2(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[11, 3],
+ offsets=[-0.7, -0.7],
+ expected_rows=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ expected_cols=[8, 9, 10])
+
+ def testGlimpseBeforeLeftMargin(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[11, 5],
+ offsets=[-0.7, -0.9],
+ expected_rows=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ expected_cols=[1, 2, 3, 4, 5])
+
+ def testGlimpseLowerRightCorner(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[7, 5],
+ offsets=[1.0, 1.0],
+ expected_rows=[38, 39, 40, 41, None, None, None],
+ expected_cols=[59, 60, 61, None, None])
+
+ def testGlimpseNoOverlap(self):
+ self._VerifyValues(tensor_in_sizes=[20, 30],
+ glimpse_sizes=[3, 3],
+ offsets=[-2.0, 2.0],
+ expected_rows=[None, None, None],
+ expected_cols=[None, None, None])
+
+ def testGlimpseOnLeftMargin(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[11, 7],
+ offsets=[-0.7, -1.0],
+ expected_rows=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+ expected_cols=[None, None, None, 1, 2, 3, 4])
+
+ def testGlimpseUpperMargin(self):
+ self._VerifyValues(tensor_in_sizes=[41, 61],
+ glimpse_sizes=[7, 5],
+ offsets=[-1, 0.9],
+ expected_rows=[None, None, None, 1, 2, 3, 4],
+ expected_cols=[56, 57, 58, 59, 60])
+
+
+if __name__ == '__main__':
+ tf.test.main()
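
The expected_rows/expected_cols constants in the glimpse tests above can be reproduced by a simple index rule: map the offset to a continuous center, truncate the window start toward zero, and mark out-of-range positions as None. The sketch below is illustrative only; the helper name is hypothetical and the kernel's exact rounding may differ, but the rule matches the cases used in this file.

def expected_indices(image_size, glimpse_size, offset):
  # Offset 0.0 is the image center, -1.0/1.0 the edges; indices are 1-based
  # and out-of-range entries are None (the op fills them with noise).
  center = (image_size / 2.0) * (1.0 + offset)
  start = int(center - glimpse_size / 2.0)  # truncate toward zero
  window = range(start + 1, start + glimpse_size + 1)
  return [i if 1 <= i <= image_size else None for i in window]

assert expected_indices(41, 3, 0.0) == [20, 21, 22]
assert expected_indices(61, 5, 0.0) == [29, 30, 31, 32, 33]
assert expected_indices(41, 7, 1.0) == [38, 39, 40, 41, None, None, None]
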
diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
new file mode 100644
index 0000000000..8ae37fec3a
--- /dev/null
+++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
@@ -0,0 +1,195 @@
+"""Tests for tensorflow.ops.tf.BatchMatMul."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class BatchMatmulOpTest(tf.test.TestCase):
+
+ # Uses numpy to compute batch_matmul(x, y, adj_x, adj_y).
+ def _npBatchMatmul(self, x, y, adj_x, adj_y):
+ assert x.ndim >= 3
+ assert y.ndim >= 3
+ # The output's shape depends on adj_x and adj_y.
+ d0 = x.shape[-2] if not adj_x else x.shape[-1]
+ d2 = y.shape[-1] if not adj_y else y.shape[-2]
+ batch_dims = x.shape[:-2]
+ num = np.prod(batch_dims)
+ z = np.empty(list(batch_dims) + [d0, d2], dtype=x.dtype)
+ xr = x.reshape([num, x.shape[-2], x.shape[-1]])
+ yr = y.reshape([num, y.shape[-2], y.shape[-1]])
+ zr = z.reshape([num, z.shape[-2], z.shape[-1]])
+ for i in range(num):
+ a = np.matrix(xr[i, :, :])
+ if adj_x:
+ a = a.transpose().conj()
+ b = np.matrix(yr[i, :, :])
+ if adj_y:
+ b = b.transpose().conj()
+ zr[i, :, :] = a * b
+ return z
+
+ # Tests that _npBatchMatmul works.
+ def testSimpleNpVersion(self):
+ x = np.array([0., 1., 2., 3.]).reshape([1, 2, 2])
+ y = np.array([1., 2., 3., 4.]).reshape([1, 2, 2])
+ z0 = self._npBatchMatmul(x, y, False, False)
+ z1 = np.array([3., 4., 11., 16.]).reshape([1, 2, 2])
+ self.assertTrue(np.array_equal(z0, z1))
+
+ x = np.array([1., (1j), (-1.), (-1j)]).reshape([1, 2, 2])
+ y = x * np.complex(1, 1) # scale x by sqrt(2) and rotate it by 45 degrees
+ z0 = self._npBatchMatmul(x, y, False, False)
+ z1 = np.array([2., (2.j), -2., (-2.j)]).reshape([1, 2, 2])
+ self.assertTrue(np.array_equal(z0, z1))
+
+ z0 = self._npBatchMatmul(x, y, False, True)
+ z1 = np.array([(2.-2.j), (-2.+2.j), (-2.+2.j), (2.-2.j)]).reshape([1, 2, 2])
+ self.assertTrue(np.array_equal(z0, z1))
+
+ z0 = self._npBatchMatmul(x, y, True, False)
+ z1 = np.array([(2.+2.j), (-2.+2.j), (2.-2.j), (2.+2.j)]).reshape([1, 2, 2])
+ self.assertTrue(np.array_equal(z0, z1))
+
+ # Compares tf.batch_matmul(x, y, adj_x, adj_y) against the reference
+ # implementation _npBatchMatmul(x, y, adj_x, adj_y).
+ def _compare(self, x, y, adj_x, adj_y, use_gpu=False):
+ with self.test_session(use_gpu=use_gpu):
+ z0 = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
+ z0_val = z0.eval()
+ z1 = self._npBatchMatmul(x, y, adj_x, adj_y)
+ self.assertShapeEqual(z1, z0)
+ if z0_val.size != 0:
+ err = (np.abs(z0_val - z1) / np.maximum(1, np.abs(z0_val))).max()
+ tf.logging.info("error = %f", err)
+ self.assertTrue(err < 1e-4)
+
+ # Returns a random float32 numpy array of the given "shape".
+ def _randFloat(self, shape):
+ vals = np.random.normal(0, 1, np.prod(shape)).reshape(shape)
+ return np.array(vals, dtype=np.float32)
+
+ def testSimpleFloat(self):
+ for use_gpu in [False, True]:
+ self._compare(self._randFloat([7, 2, 3]), self._randFloat([7, 3, 5]),
+ False, False, use_gpu)
+ self._compare(self._randFloat([7, 2, 3]), self._randFloat([7, 5, 3]),
+ False, True, use_gpu)
+ self._compare(self._randFloat([7, 3, 2]), self._randFloat([7, 3, 5]),
+ True, False, use_gpu)
+ self._compare(self._randFloat([7, 3, 2]), self._randFloat([7, 5, 3]),
+ True, True, use_gpu)
+
+ def testLargeFloat(self):
+ for use_gpu in [False, True]:
+ self._compare(self._randFloat([10, 64, 75]),
+ self._randFloat([10, 75, 30]), False, False, use_gpu)
+ self._compare(self._randFloat([10, 75, 64]),
+ self._randFloat([10, 75, 30]), True, False, use_gpu)
+ self._compare(self._randFloat([10, 64, 75]),
+ self._randFloat([10, 30, 75]), False, True, use_gpu)
+ self._compare(self._randFloat([10, 75, 64]),
+ self._randFloat([10, 30, 75]), True, True, use_gpu)
+
+ def testHighNDims(self):
+ for use_gpu in [False, True]:
+ self._compare(self._randFloat([5, 7, 2, 3]),
+ self._randFloat([5, 7, 3, 5]), False, False, use_gpu)
+ self._compare(self._randFloat([5, 7, 3, 2]),
+ self._randFloat([5, 7, 3, 5]), True, False, use_gpu)
+ self._compare(self._randFloat([5, 7, 2, 3]),
+ self._randFloat([5, 7, 5, 3]), False, True, use_gpu)
+ self._compare(self._randFloat([5, 7, 3, 2]),
+ self._randFloat([5, 7, 5, 3]), True, True, use_gpu)
+
+ # Returns a random complex numpy array of "shape".
+ def _randComplex(self, shape):
+ real = np.random.normal(0, 1, np.prod(shape))
+ imag = np.random.normal(0, 1, np.prod(shape))
+ vals = [np.complex(v[0], v[1]) for v in zip(real, imag)]
+ return np.array(vals, dtype=np.complex64).reshape(shape)
+
+ def testSimpleComplex(self):
+ self._compare(self._randComplex([7, 2, 3]),
+ self._randComplex([7, 3, 5]), False, False)
+ self._compare(self._randComplex([7, 2, 3]),
+ self._randComplex([7, 5, 3]), False, True)
+ self._compare(self._randComplex([7, 3, 2]),
+ self._randComplex([7, 3, 5]), True, False)
+ self._compare(self._randComplex([7, 3, 2]),
+ self._randComplex([7, 5, 3]), True, True)
+
+ def testLargeComplex(self):
+ self._compare(self._randComplex([10, 64, 75]),
+ self._randComplex([10, 75, 30]), False,
+ False)
+ self._compare(self._randComplex([10, 64, 75]),
+ self._randComplex([10, 30, 75]), False, True)
+ self._compare(self._randComplex([10, 75, 64]),
+ self._randComplex([10, 75, 30]), True, False)
+ self._compare(self._randComplex([10, 75, 64]),
+ self._randComplex([10, 30, 75]), True, True)
+
+ def testEmpty(self):
+ self._compare(np.empty([0, 3, 2]).astype(np.float32),
+ np.empty([0, 2, 4]).astype(np.float32), False, False)
+ self._compare(np.empty([3, 2, 0]).astype(np.float32),
+ np.empty([3, 0, 5]).astype(np.float32), False, False)
+ self._compare(np.empty([3, 0, 2]).astype(np.float32),
+ np.empty([3, 2, 5]).astype(np.float32), False, False)
+ self._compare(np.empty([3, 3, 2]).astype(np.float32),
+ np.empty([3, 2, 0]).astype(np.float32), False, False)
+
+
+class BatchMatmulGradientTest(tf.test.TestCase):
+
+ # loss = sum(batch_matmul(x, y)). Verify dl/dx and dl/dy via the
+ # gradient checker.
+ def _checkGrad(self, x, y, adj_x, adj_y):
+ assert 3 == x.ndim
+ assert 3 == y.ndim
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ z = tf.batch_matmul(inx, iny, adj_x, adj_y)
+ loss = tf.reduce_sum(z)
+ epsilon = 1e-2
+ ((x_jacob_t, x_jacob_n), (y_jacob_t, y_jacob_n)) = gc.ComputeGradient(
+ [inx, iny], [x.shape, y.shape], loss, [1],
+ x_init_value=[x, y], delta=epsilon)
+
+ tf.logging.info("x_jacob_t = %s", x_jacob_t.reshape(x.shape))
+ tf.logging.info("x_jacob_n = %s", x_jacob_n.reshape(x.shape))
+ self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=epsilon)
+ tf.logging.info("y_jacob_t = %s", y_jacob_t.reshape(y.shape))
+ tf.logging.info("y_jacob_n = %s", y_jacob_n.reshape(y.shape))
+ self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=epsilon)
+
+ # Tests a batched matmul of x and y: x is a 3D tensor of shape [b, n, k],
+ # y is a 3D tensor of shape [b, k, m], and the batched matmul computes z of
+ # shape [b, n, m], where z[i, :, :] = x[i, :, :] matmul y[i, :, :].
+ def _compare(self, b, n, k, m):
+ x = np.random.normal(0, 1, b * n * k).astype(np.float32).reshape([b, n, k])
+ y = np.random.normal(0, 1, b * k * m).astype(np.float32).reshape([b, k, m])
+ self._checkGrad(x, y, False, False)
+ self._checkGrad(x.reshape([b, k, n]), y, True, False)
+ self._checkGrad(x, y.reshape([b, m, k]), False, True)
+ self._checkGrad(x.reshape([b, k, n]), y.reshape([b, m, k]), True, True)
+
+ def testSmall(self):
+ self._compare(1, 2, 3, 5)
+
+ def testMedium(self):
+ self._compare(3, 4, 7, 10)
+
+ # Can't add a testLarge with very large inputs because the gradient
+ # checker would take far too long.
+
+
+if __name__ == "__main__":
+ tf.test.main()
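
The reference helper _npBatchMatmul above loops over the flattened batch; an equivalent, more compact NumPy formulation is sketched below for clarity (illustrative only, not part of the test): the adjoint flags conjugate-transpose the last two axes before a batched einsum.

import numpy as np

def np_batch_matmul(x, y, adj_x=False, adj_y=False):
  # Conjugate-transpose the last two dimensions when the adjoint flag is set.
  if adj_x:
    x = np.conj(np.swapaxes(x, -1, -2))
  if adj_y:
    y = np.conj(np.swapaxes(y, -1, -2))
  # Batched matrix product over the leading dimensions.
  return np.einsum('...ij,...jk->...ik', x, y)

a = np.random.randn(7, 2, 3)
b = np.random.randn(7, 5, 3)
print(np_batch_matmul(a, b, adj_y=True).shape)  # (7, 2, 5)
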
diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py
new file mode 100644
index 0000000000..c62a910496
--- /dev/null
+++ b/tensorflow/python/kernel_tests/bcast_ops_test.py
@@ -0,0 +1,76 @@
+"""Tests for tensorflow.kernels.bcast_ops."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.python.ops.gen_array_ops import _broadcast_gradient_args
+
+
+class BcastOpsTest(tf.test.TestCase):
+
+ def _GetGradientArgs(self, xs, ys):
+ with self.test_session() as sess:
+ return sess.run(_broadcast_gradient_args(xs, ys))
+
+ def testBasic(self):
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [1])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1, 2])
+
+ r0, r1 = self._GetGradientArgs([1], [2, 3, 5])
+ self.assertAllEqual(r0, [0, 1, 2])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1])
+
+ r0, r1 = self._GetGradientArgs([5], [2, 3, 5])
+ self.assertAllEqual(r0, [0, 1])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [3, 5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0])
+
+ r0, r1 = self._GetGradientArgs([3, 5], [2, 3, 5])
+ self.assertAllEqual(r0, [0])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 3, 5], [3, 1])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 2])
+
+ r0, r1 = self._GetGradientArgs([3, 1], [2, 3, 5])
+ self.assertAllEqual(r0, [0, 2])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 1, 5], [3, 1])
+ self.assertAllEqual(r0, [1])
+ self.assertAllEqual(r1, [0, 2])
+
+ r0, r1 = self._GetGradientArgs([3, 1], [2, 1, 5])
+ self.assertAllEqual(r0, [0, 2])
+ self.assertAllEqual(r1, [1])
+
+ def testZeroDims(self):
+ r0, r1 = self._GetGradientArgs([2, 0, 3, 0, 5], [3, 0, 5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1])
+
+ r0, r1 = self._GetGradientArgs([3, 0, 5], [2, 0, 3, 0, 5])
+ self.assertAllEqual(r0, [0, 1])
+ self.assertAllEqual(r1, [])
+
+ r0, r1 = self._GetGradientArgs([2, 0, 3, 0, 5], [3, 1, 5])
+ self.assertAllEqual(r0, [])
+ self.assertAllEqual(r1, [0, 1, 3])
+
+ r0, r1 = self._GetGradientArgs([3, 1, 5], [2, 0, 3, 0, 5])
+ self.assertAllEqual(r0, [0, 1, 3])
+ self.assertAllEqual(r1, [])
+
+
+if __name__ == "__main__":
+ tf.test.main()
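
The r0/r1 values asserted above are the axes along which each input's gradient has to be summed after broadcasting. A small pure-Python sketch of that rule (illustrative; the generated op may treat degenerate shapes differently):

def broadcast_gradient_args(x_shape, y_shape):
  # Left-pad the shorter shape with 1s, numpy-style, then reduce every axis
  # where one input has size 1 and the other does not.
  rank = max(len(x_shape), len(y_shape))
  xs = [1] * (rank - len(x_shape)) + list(x_shape)
  ys = [1] * (rank - len(y_shape)) + list(y_shape)
  r0 = [i for i in range(rank) if xs[i] == 1 and ys[i] != 1]
  r1 = [i for i in range(rank) if ys[i] == 1 and xs[i] != 1]
  return r0, r1

assert broadcast_gradient_args([2, 3, 5], [1]) == ([], [0, 1, 2])
assert broadcast_gradient_args([2, 1, 5], [3, 1]) == ([1], [0, 2])
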
diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py
new file mode 100644
index 0000000000..f3a26e2490
--- /dev/null
+++ b/tensorflow/python/kernel_tests/bias_op_test.py
@@ -0,0 +1,93 @@
+"""Functional tests for BiasAdd."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker
+
+
+class BiasAddTest(tf.test.TestCase):
+
+ def _npBias(self, inputs, bias):
+ assert len(bias.shape) == 1
+ print inputs.shape
+ print bias.shape
+ assert inputs.shape[-1] == bias.shape[0]
+ return inputs + bias.reshape(([1] * (len(inputs.shape) - 1))
+ + [bias.shape[0]])
+
+ def testNpBias(self):
+ self.assertAllClose(np.array([[11, 22, 33], [41, 52, 63]]),
+ self._npBias(np.array([[10, 20, 30], [40, 50, 60]]),
+ np.array([1, 2, 3])))
+
+ def _testBias(self, np_inputs, np_bias, use_gpu=False):
+ np_val = self._npBias(np_inputs, np_bias)
+ with self.test_session(use_gpu=use_gpu):
+ tf_val = tf.nn.bias_add(np_inputs, np_bias).eval()
+ self.assertAllClose(np_val, tf_val)
+
+ def _testAll(self, np_inputs, np_bias):
+ self._testBias(np_inputs, np_bias, use_gpu=False)
+ if np_inputs.dtype == np.float32 or np_inputs.dtype == np.float64:
+ self._testBias(np_inputs, np_bias, use_gpu=True)
+
+ def testInputDims(self):
+ with self.assertRaises(ValueError):
+ tf.nn.bias_add([1, 2], [1])
+
+ def testBiasVec(self):
+ with self.assertRaises(ValueError):
+ tf.nn.bias_add(tf.reshape([1, 2], shape=[1, 2]),
+ tf.reshape([1, 2], shape=[1, 2]))
+
+ def testBiasInputsMatch(self):
+ with self.assertRaises(ValueError):
+ tf.nn.bias_add(tf.reshape([1, 2], shape=[1, 2]),
+ tf.reshape([1], shape=[1]))
+
+ def testIntTypes(self):
+ for t in [np.int8, np.int16, np.int32, np.int64]:
+ self._testAll(np.array([[10, 20, 30], [40, 50, 60]]).astype(t),
+ np.array([1, 2, 3]).astype(t))
+
+ def testFloatTypes(self):
+ for t in [np.float32, np.float64]:
+ self._testAll(np.random.rand(4, 3, 3).astype(t),
+ np.random.rand(3).astype(t))
+
+ def testGradientTensor(self):
+ with self.test_session():
+ t = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2],
+ dtype=tf.float64)
+ b = tf.constant([1.3, 2.4], dtype=tf.float64)
+ bo = tf.nn.bias_add(t, b)
+ err = gradient_checker.ComputeGradientError(t, [3, 2], bo, [3, 2])
+ print "bias add tensor gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+ def testGradientBias(self):
+ with self.test_session():
+ t = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2],
+ dtype=tf.float64)
+ b = tf.constant([1.3, 2.4], dtype=tf.float64)
+ bo = tf.nn.bias_add(t, b)
+ err = gradient_checker.ComputeGradientError(b, [2], bo, [3, 2])
+ print "bias add bias gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+ def testGradientTensor4D(self):
+ with self.test_session():
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float32)
+ t = tf.constant(x, shape=s, dtype=tf.float32)
+ b = tf.constant([1.3, 2.4], dtype=tf.float32)
+ bo = tf.nn.bias_add(t, b)
+ err = gradient_checker.ComputeGradientError(t, s, bo, s, x_init_value=x)
+ print "bias add tensor gradient err = ", err
+ self.assertLess(err, 1e-3)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
new file mode 100644
index 0000000000..a36b8587d5
--- /dev/null
+++ b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
@@ -0,0 +1,114 @@
+"""Tests for CandidateSamplerOp."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class RangeSamplerOpsTest(tf.test.TestCase):
+
+ BATCH_SIZE = 3
+ NUM_TRUE = 2
+ RANGE = 5
+ NUM_SAMPLED = RANGE
+
+ TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]
+
+ def testTrueCandidates(self):
+ with self.test_session() as sess:
+ indices = tf.constant([0, 0, 1, 1, 2, 2])
+ true_candidates_vec = tf.constant([1, 2, 0, 4, 3, 3])
+ true_candidates_matrix = tf.reshape(
+ true_candidates_vec, [self.BATCH_SIZE, self.NUM_TRUE])
+ indices_val, true_candidates_val = sess.run(
+ [indices, true_candidates_matrix])
+
+ self.assertAllEqual(indices_val, [0, 0, 1, 1, 2, 2])
+ self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)
+
+ def testSampledCandidates(self):
+ with self.test_session():
+ true_classes = tf.constant([[1, 2], [0, 4], [3, 3]],
+ dtype=tf.int64)
+ sampled_candidates, _, _ = tf.nn.all_candidate_sampler(
+ true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
+ result = sampled_candidates.eval()
+
+ expected_ids = [0, 1, 2, 3, 4]
+ self.assertAllEqual(result, expected_ids)
+ self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])
+
+ def testTrueLogExpectedCount(self):
+ with self.test_session():
+ true_classes = tf.constant([[1, 2], [0, 4], [3, 3]],
+ dtype=tf.int64)
+ _, true_expected_count, _ = tf.nn.all_candidate_sampler(
+ true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
+ true_log_expected_count = tf.log(true_expected_count)
+ result = true_log_expected_count.eval()
+
+ self.assertAllEqual(result, [[0.0] * self.NUM_TRUE] * self.BATCH_SIZE)
+ self.assertEqual(true_expected_count.get_shape(), [self.BATCH_SIZE,
+ self.NUM_TRUE])
+ self.assertEqual(true_log_expected_count.get_shape(), [self.BATCH_SIZE,
+ self.NUM_TRUE])
+
+ def testSampledLogExpectedCount(self):
+ with self.test_session():
+ true_classes = tf.constant([[1, 2], [0, 4], [3, 3]],
+ dtype=tf.int64)
+ _, _, sampled_expected_count = tf.nn.all_candidate_sampler(
+ true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
+ sampled_log_expected_count = tf.log(sampled_expected_count)
+ result = sampled_log_expected_count.eval()
+
+ self.assertAllEqual(result, [0.0] * self.NUM_SAMPLED)
+ self.assertEqual(sampled_expected_count.get_shape(), [self.NUM_SAMPLED])
+ self.assertEqual(sampled_log_expected_count.get_shape(), [self.NUM_SAMPLED])
+
+ def testAccidentalHits(self):
+ with self.test_session() as sess:
+ true_classes = tf.constant([[1, 2], [0, 4], [3, 3]],
+ dtype=tf.int64)
+ sampled_candidates, _, _ = tf.nn.all_candidate_sampler(
+ true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
+ accidental_hits = tf.nn.compute_accidental_hits(
+ true_classes, sampled_candidates, self.NUM_TRUE)
+ indices, ids, weights = sess.run(accidental_hits)
+
+ self.assertEqual(1, accidental_hits[0].get_shape().ndims)
+ self.assertEqual(1, accidental_hits[1].get_shape().ndims)
+ self.assertEqual(1, accidental_hits[2].get_shape().ndims)
+ for index, id_, weight in zip(indices, ids, weights):
+ self.assertTrue(id_ in self.TRUE_LABELS[index])
+ self.assertLess(weight, -1.0e37)
+
+ def testSeed(self):
+
+ def draw(seed):
+ with self.test_session():
+ true_classes = tf.constant([[1, 2], [0, 4], [3, 3]],
+ dtype=tf.int64)
+ sampled, _, _ = tf.nn.log_uniform_candidate_sampler(
+ true_classes,
+ self.NUM_TRUE,
+ self.NUM_SAMPLED,
+ True,
+ 5,
+ seed=seed)
+ return sampled.eval()
+ # Non-zero seed. Repeatable.
+ for seed in [1, 12, 123, 1234]:
+ self.assertAllEqual(draw(seed), draw(seed))
+ # Passing seed=None means a different random seed is picked each time.
+ num_same = 0
+ for _ in range(10):
+ if np.allclose(draw(None), draw(None)):
+ num_same += 1
+ # Accounts for the fact that the same random seed may be picked
+ # twice very rarely.
+ self.assertLessEqual(num_same, 2)
+
+
+if __name__ == "__main__":
+ tf.test.main()
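
The accidental-hits check above boils down to finding collisions between the sampled ids and each row's true labels. A plain-Python sketch of that collision scan (illustrative only; the real op additionally returns a very negative weight per hit so the corresponding logit can be masked out):

TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]
sampled = [0, 1, 2, 3, 4]  # what all_candidate_sampler returns for RANGE=5

# Each accidental hit is a (batch row, colliding sampled id) pair.
hits = [(row, s)
        for row, labels in enumerate(TRUE_LABELS)
        for s in sampled if s in labels]
print(hits)  # [(0, 1), (0, 2), (1, 0), (1, 4), (2, 3)]
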
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
new file mode 100644
index 0000000000..21e8f71198
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -0,0 +1,165 @@
+"""Tests for tensorflow.ops.tf.cast."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class CastOpTest(tf.test.TestCase):
+
+ def _toDataType(self, dtype):
+ """Returns TensorFlow data type for numpy type."""
+ if dtype == np.float32:
+ return tf.float32
+ elif dtype == np.float64:
+ return tf.float64
+ elif dtype == np.int32:
+ return tf.int32
+ elif dtype == np.int64:
+ return tf.int64
+ elif dtype == np.bool:
+ return tf.bool
+ else:
+ return None
+
+ def _cast(self, x, dtype, use_gpu=False):
+ with self.test_session(use_gpu=use_gpu):
+ val = tf.constant(x, self._toDataType(np.array([x]).dtype))
+ return tf.cast(val, self._toDataType(dtype), name="cast").eval()
+
+ def _test(self, x, dtype, use_gpu=False):
+ """Tests cast(x) to dtype behaves the same as numpy.astype."""
+ np_ans = x.astype(dtype)
+ tf_ans = self._cast(x, dtype, use_gpu)
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def _testTypes(self, x, use_gpu=False):
+ """Tests cast(x) to different tf."""
+ if use_gpu:
+ type_list = [np.float32, np.float64, np.int64]
+ else:
+ type_list = [np.float32, np.float64, np.int32, np.int64]
+ for from_type in type_list:
+ for to_type in type_list:
+ self._test(x.astype(from_type), to_type, use_gpu)
+
+ self._test(x.astype(np.bool), np.float32, use_gpu)
+ self._test(x.astype(np.uint8), np.float32, use_gpu)
+ if not use_gpu:
+ self._test(x.astype(np.bool), np.int32, use_gpu)
+ self._test(x.astype(np.int32), np.int32, use_gpu)
+
+ def _testAll(self, x):
+ self._testTypes(x, use_gpu=False)
+ if x.dtype == np.float32 or x.dtype == np.float64:
+ self._testTypes(x, use_gpu=True)
+
+ def testBasic(self):
+ self._testAll(np.arange(-10, 10).reshape(2, 10))
+ self._testAll(np.linspace(-10, 10, 17))
+
+ def testSmallValues(self):
+ f4 = np.finfo(np.float32)
+ f8 = np.finfo(np.float64)
+ self._testAll(np.array([0, -1, 1, -f4.resolution, f4.resolution,
+ f8.resolution, -f8.resolution]))
+
+ def testBfloat16(self):
+ a = np.random.uniform(-100, 100, 100).astype(np.float32)
+ with self.test_session(use_gpu=False):
+ b = tf.cast(tf.cast(a, tf.bfloat16), tf.float32)
+ self.assertAllClose(a, b.eval(), rtol=1/128.)
+ with self.test_session(use_gpu=True):
+ b = tf.cast(tf.cast(a, tf.bfloat16), tf.float32)
+ self.assertAllClose(a, b.eval(), rtol=1/128.)
+
+ def testRandom(self):
+ self._testAll(np.random.normal(0, 10, 210).reshape([2, 3, 5, 7]))
+ self._testAll(np.random.normal(0, 1e6, 210).reshape([2, 3, 5, 7]))
+
+ # Special values like int32max, int64min, inf, -inf and nan are cast to
+ # integer values in somewhat unexpected ways, and they can behave
+ # differently on CPU and GPU.
+ def _compare(self, x, dst_dtype, expected, use_gpu=False):
+ np.testing.assert_equal(self._cast(x, dst_dtype, use_gpu=use_gpu),
+ dst_dtype(expected))
+
+ def testIntToFloatBoundary(self):
+ i4 = np.iinfo(np.int32)
+ i8 = np.iinfo(np.int64)
+
+ self._compare(i4.min, np.float32, i4.min, False)
+ self._compare(i4.max, np.float32, i4.max, False)
+ self._compare(i8.min, np.float32, i8.min, False)
+ self._compare(i8.max, np.float32, i8.max, False)
+ self._compare(i4.min, np.float64, i4.min, False)
+ self._compare(i4.max, np.float64, i4.max, False)
+ self._compare(i8.min, np.float64, i8.min, False)
+ self._compare(i8.max, np.float64, i8.max, False)
+ # NOTE: GPU does not support int32/int64 for casting.
+
+ def testInfNan(self):
+ i4 = np.iinfo(np.int32)
+ i8 = np.iinfo(np.int64)
+
+ self._compare(np.inf, np.float32, np.inf, False)
+ self._compare(np.inf, np.float64, np.inf, False)
+ self._compare(np.inf, np.int32, i4.min, False)
+ self._compare(np.inf, np.int64, i8.min, False)
+ self._compare(-np.inf, np.float32, -np.inf, False)
+ self._compare(-np.inf, np.float64, -np.inf, False)
+ self._compare(-np.inf, np.int32, i4.min, False)
+ self._compare(-np.inf, np.int64, i8.min, False)
+ self.assertAllEqual(np.isnan(self._cast(np.nan, np.float32, False)), True)
+ self.assertAllEqual(np.isnan(self._cast(np.nan, np.float64, False)), True)
+ self._compare(np.nan, np.int32, i4.min, False)
+ self._compare(np.nan, np.int64, i8.min, False)
+
+ self._compare(np.inf, np.float32, np.inf, True)
+ self._compare(np.inf, np.float64, np.inf, True)
+ self._compare(-np.inf, np.float32, -np.inf, True)
+ self._compare(-np.inf, np.float64, -np.inf, True)
+ self.assertAllEqual(np.isnan(self._cast(np.nan, np.float32, True)), True)
+ self.assertAllEqual(np.isnan(self._cast(np.nan, np.float64, True)), True)
+
+ def _OpError(self, x, dtype, err):
+ with self.test_session():
+ with self.assertRaisesOpError(err):
+ tf.cast(x, dtype).eval()
+
+ def testNotImplemented(self):
+ self._OpError(np.arange(0, 10), tf.string,
+ "Cast.*int64.*string.*")
+
+ def testGradients(self):
+ t = [tf.float32, tf.float64]
+ for src_t in t:
+ for dst_t in t:
+ with self.test_session():
+ x = tf.constant(1.0, src_t)
+ z = tf.identity(x)
+ y = tf.cast(z, dst_t)
+ err = gc.ComputeGradientError(x, [1], y, [1])
+ self.assertLess(err, 1e-3)
+
+
+class SparseTensorCastTest(tf.test.TestCase):
+
+ def testCast(self):
+ indices = tf.constant([[0L], [1L], [2L]])
+ values = tf.constant(np.array([1, 2, 3], np.int64))
+ shape = tf.constant([3L])
+ st = tf.SparseTensor(indices, values, shape)
+ st_cast = tf.cast(st, tf.float32)
+ with self.test_session():
+ self.assertAllEqual(st_cast.indices.eval(), [[0L], [1L], [2L]])
+ self.assertAllEqual(st_cast.values.eval(),
+ np.array([1, 2, 3], np.float32))
+ self.assertAllEqual(st_cast.shape.eval(), [3L])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
new file mode 100644
index 0000000000..17e8d116be
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -0,0 +1,74 @@
+"""Tests for tensorflow.ops.tf.Cholesky."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class CholeskyOpTest(tf.test.TestCase):
+
+ def _verifyCholesky(self, x):
+ with self.test_session() as sess:
+ # Verify that LL^T == x.
+ if x.ndim == 2:
+ chol = tf.cholesky(x)
+ verification = tf.matmul(chol,
+ chol,
+ transpose_a=False,
+ transpose_b=True)
+ else:
+ chol = tf.batch_cholesky(x)
+ verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
+ chol_np, verification_np = sess.run([chol, verification])
+ self.assertAllClose(x, verification_np)
+ self.assertShapeEqual(x, chol)
+ # Check that the cholesky is lower triangular, and has positive diagonal
+ # elements.
+ if chol_np.shape[-1] > 0:
+ chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2],
+ chol_np.shape[-1]))
+ for chol_matrix in chol_reshaped:
+ self.assertAllClose(chol_matrix, np.tril(chol_matrix))
+ self.assertTrue((np.diag(chol_matrix) > 0.0).all())
+
+ def testBasic(self):
+ self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))
+
+ def testBatch(self):
+ simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
+ self._verifyCholesky(simple_array)
+ self._verifyCholesky(np.vstack((simple_array, simple_array)))
+ odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
+ self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))
+
+ # Generate random positive-definite matrices.
+ matrices = np.random.rand(10, 5, 5)
+ for i in xrange(10):
+ matrices[i] = np.dot(matrices[i].T, matrices[i])
+ self._verifyCholesky(matrices)
+
+ def testNonSquareMatrix(self):
+ with self.assertRaises(ValueError):
+ tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
+
+ def testWrongDimensions(self):
+ tensor3 = tf.constant([1., 2.])
+ with self.assertRaises(ValueError):
+ tf.cholesky(tensor3)
+
+ def testNotInvertible(self):
+ # The input should be invertible.
+ with self.test_session():
+ with self.assertRaisesOpError("LLT decomposition was not successful. The "
+ "input might not be valid."):
+ # All rows of the matrix below add to zero
+ self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1.,
+ 1.]]))
+
+ def testEmpty(self):
+ self._verifyCholesky(np.empty([0, 2, 2]))
+ self._verifyCholesky(np.empty([2, 0, 0]))
+
+
+if __name__ == "__main__":
+ tf.test.main()
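
The verification in _verifyCholesky above has a direct NumPy analogue, sketched here for a single matrix (illustrative, not part of the test): the factor is lower triangular with a positive diagonal and reproduces the symmetric positive-definite input.

import numpy as np

x = np.array([[4., -1., 2.], [-1., 6., 0.], [2., 0., 5.]])
chol = np.linalg.cholesky(x)                 # lower-triangular factor L
assert np.allclose(np.dot(chol, chol.T), x)  # L * L^T == x
assert np.allclose(chol, np.tril(chol))      # L is lower triangular
assert (np.diag(chol) > 0.0).all()           # positive diagonal
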
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
new file mode 100644
index 0000000000..46bba7514d
--- /dev/null
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -0,0 +1,222 @@
+"""Tests for tensorflow.ops.clip_ops."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class ClipTest(tf.test.TestCase):
+
+ # ClipByValue test
+ def testClipByValue(self):
+ with self.test_session():
+ x = tf.constant([-5.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
+ np_ans = [[-4.4, 2.0, 3.0],
+ [4.0, 4.4, 4.4]]
+ clip_value = 4.4
+ ans = tf.clip_by_value(x, -clip_value, clip_value)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testClipByValueNonFinite(self):
+ with self.test_session():
+ x = tf.constant([float('NaN'), float('Inf'), -float('Inf')])
+ np_ans = [float('NaN'), 4.0, -4.0]
+ clip_value = 4.0
+ ans = tf.clip_by_value(x, -clip_value, clip_value)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+ # ClipByNorm tests
+ def testClipByNormClipped(self):
+ # Norm clipping when clip_norm < 5
+ with self.test_session():
+ x = tf.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ # Norm of x = sqrt(3^2 + 4^2) = 5
+ np_ans = [[-2.4, 0.0, 0.0],
+ [3.2, 0.0, 0.0]]
+ clip_norm = 4.0
+ ans = tf.clip_by_norm(x, clip_norm)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testClipByNormNotClipped(self):
+ # No norm clipping when clip_norm >= 5
+ with self.test_session():
+ x = tf.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ # Norm of x = sqrt(3^2 + 4^2) = 5
+ np_ans = [[-3.0, 0.0, 0.0],
+ [4.0, 0.0, 0.0]]
+ clip_norm = 6.0
+ ans = tf.clip_by_norm(x, clip_norm)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testClipByNormZero(self):
+ # No norm clipping when norm = 0
+ with self.test_session():
+ x = tf.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
+ # Norm = 0, no changes
+ np_ans = [[0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0]]
+ clip_norm = 6.0
+ ans = tf.clip_by_norm(x, clip_norm)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testClipByGlobalNormClipped(self):
+ # Norm clipping when clip_norm < 5
+ with self.test_session():
+ x0 = tf.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ x1 = tf.constant([1.0, -2.0])
+ # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
+ clip_norm = 4.0
+
+ # Answers are the original tensors scaled by 4.0/5.0
+ np_ans_0 = [[-1.6, 0.0, 0.0],
+ [3.2, 0.0, 0.0]]
+ np_ans_1 = [0.8, -1.6]
+
+ ans, norm = tf.clip_by_global_norm((x0, x1), clip_norm)
+ tf_ans_1 = ans[0].eval()
+ tf_ans_2 = ans[1].eval()
+ tf_norm = norm.eval()
+
+ self.assertAllClose(tf_norm, 5.0)
+ self.assertAllClose(np_ans_0, tf_ans_1)
+ self.assertAllClose(np_ans_1, tf_ans_2)
+
+ def testClipByGlobalNormSupportsNone(self):
+ # Norm clipping when clip_norm < 5
+ with self.test_session():
+ x0 = tf.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ x1 = tf.constant([1.0, -2.0])
+ # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
+ clip_norm = 4.0
+
+ # Answers are the original tensors scaled by 4.0/5.0
+ np_ans_0 = [[-1.6, 0.0, 0.0],
+ [3.2, 0.0, 0.0]]
+ np_ans_1 = [0.8, -1.6]
+
+ ans, norm = tf.clip_by_global_norm((x0, None, x1, None), clip_norm)
+ self.assertTrue(ans[1] is None)
+ self.assertTrue(ans[3] is None)
+ tf_ans_1 = ans[0].eval()
+ tf_ans_2 = ans[2].eval()
+ tf_norm = norm.eval()
+
+ self.assertAllClose(tf_norm, 5.0)
+ self.assertAllClose(np_ans_0, tf_ans_1)
+ self.assertAllClose(np_ans_1, tf_ans_2)
+
+ # ClipByGlobalNorm tests
+ def testClipByGlobalNormWithIndexedSlicesClipped(self):
+ # Norm clipping when clip_norm < 5
+ with self.test_session():
+ x0 = tf.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ x1 = tf.IndexedSlices(tf.constant([1.0, -2.0]),
+ tf.constant([3, 4]))
+ # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
+ clip_norm = 4.0
+
+ # Answers are the original tensors scaled by 4.0/5.0
+ np_ans_0 = [[-1.6, 0.0, 0.0],
+ [3.2, 0.0, 0.0]]
+ np_ans_1 = [0.8, -1.6]
+
+ ans, norm = tf.clip_by_global_norm([x0, x1], clip_norm)
+ tf_ans_1 = ans[0].eval()
+ tf_ans_2 = ans[1].values.eval()
+ tf_norm = norm.eval()
+
+ self.assertAllClose(tf_norm, 5.0)
+ self.assertAllClose(np_ans_0, tf_ans_1)
+ self.assertAllClose(np_ans_1, tf_ans_2)
+
+ def testClipByGlobalNormNotClipped(self):
+ # No norm clipping when clip_norm >= 5
+ with self.test_session():
+ x0 = tf.constant([-2.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ x1 = tf.constant([1.0, -2.0])
+ # Global norm of x0 and x1 = sqrt(1 + 4^2 + 2^2 + 2^2) = 5
+ np_ans_0 = [[-2.0, 0.0, 0.0],
+ [4.0, 0.0, 0.0]]
+ np_ans_1 = [1.0, -2.0]
+ clip_norm = 6.0
+
+ ans, norm = tf.clip_by_global_norm([x0, x1], clip_norm)
+ tf_ans_1 = ans[0].eval()
+ tf_ans_2 = ans[1].eval()
+ tf_norm = norm.eval()
+
+ self.assertAllClose(tf_norm, 5.0)
+ self.assertAllClose(np_ans_0, tf_ans_1)
+ self.assertAllClose(np_ans_1, tf_ans_2)
+
+ def testClipByGlobalNormZero(self):
+ # No norm clipping when norm = 0
+ with self.test_session():
+ x0 = tf.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
+ x1 = tf.constant([0.0, 0.0])
+ # Norm = 0, no changes
+ np_ans_0 = [[0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0]]
+ np_ans_1 = [0.0, 0.0]
+ clip_norm = 6.0
+
+ ans, norm = tf.clip_by_global_norm([x0, x1], clip_norm)
+ tf_ans_1 = ans[0].eval()
+ tf_ans_2 = ans[1].eval()
+ tf_norm = norm.eval()
+
+ self.assertAllClose(tf_norm, 0.0)
+ self.assertAllClose(np_ans_0, tf_ans_1)
+ self.assertAllClose(np_ans_1, tf_ans_2)
+
+ def testClipByAverageNormClipped(self):
+ # Norm clipping when clip_norm < the average norm of 0.83333333
+ with self.test_session():
+ x = tf.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ # Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
+ np_ans = [[-2.88, 0.0, 0.0],
+ [3.84, 0.0, 0.0]]
+ clip_norm = 0.8
+ ans = tf.clip_by_average_norm(x, clip_norm)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testClipByAverageNormNotClipped(self):
+ # No norm clipping when clip_norm >= the average norm of 0.83333333
+ with self.test_session():
+ x = tf.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3])
+ # Average norm of x = sqrt(3^2 + 4^2) / 6 = 0.83333333
+ np_ans = [[-3.0, 0.0, 0.0],
+ [4.0, 0.0, 0.0]]
+ clip_norm = 0.9
+ ans = tf.clip_by_average_norm(x, clip_norm)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testClipByAverageNormZero(self):
+ # No norm clipping when the average norm = 0
+ with self.test_session():
+ x = tf.constant([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], shape=[2, 3])
+ # Average norm = 0, no changes
+ np_ans = [[0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0]]
+ clip_norm = 0.9
+ ans = tf.clip_by_average_norm(x, clip_norm)
+ tf_ans = ans.eval()
+
+ self.assertAllClose(np_ans, tf_ans)
+
+if __name__ == "__main__":
+ tf.test.main()
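
The expected values in the clip-by-norm tests above follow from one scaling rule: when the L2 norm exceeds clip_norm, the tensor is multiplied by clip_norm / norm, otherwise it is left unchanged. A small NumPy sketch (illustrative only):

import numpy as np

x = np.array([[-3.0, 0.0, 0.0], [4.0, 0.0, 0.0]])
clip_norm = 4.0
norm = np.sqrt((x ** 2).sum())                  # 5.0
clipped = x * clip_norm / max(norm, clip_norm)  # scale by 4/5
print(clipped)  # approximately [[-2.4, 0, 0], [3.2, 0, 0]]
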
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
new file mode 100644
index 0000000000..3f6c43f0a6
--- /dev/null
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -0,0 +1,276 @@
+"""Functional tests for Concat Op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class ConcatOpTest(tf.test.TestCase):
+
+ def testHStack(self):
+ with self.test_session():
+ p1 = tf.placeholder(tf.float32, shape=[4, 4])
+ p2 = tf.placeholder(tf.float32, shape=[4, 4])
+ c = tf.concat(0, [p1, p2])
+ params = {
+ p1: np.random.rand(4, 4).astype("f"),
+ p2: np.random.rand(4, 4).astype("f")
+ }
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ self.assertAllEqual(result[:4, :], params[p1])
+ self.assertAllEqual(result[4:, :], params[p2])
+
+ def testVStack(self):
+ with self.test_session():
+ p1 = tf.placeholder(tf.float32, shape=[4, 4])
+ p2 = tf.placeholder(tf.float32, shape=[4, 4])
+ c = tf.concat(1, [p1, p2])
+ params = {
+ p1: np.random.rand(4, 4).astype("f"),
+ p2: np.random.rand(4, 4).astype("f")
+ }
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ self.assertAllEqual(result[:, :4], params[p1])
+ self.assertAllEqual(result[:, 4:], params[p2])
+
+ def testInt32GPU(self):
+ with self.test_session(use_gpu=True):
+ p1 = np.random.rand(2, 3).astype("i")
+ p2 = np.random.rand(2, 3).astype("i")
+ x1 = tf.constant(p1)
+ x2 = tf.constant(p2)
+ c = tf.concat(0, [x1, x2])
+ result = c.eval()
+ self.assertAllEqual(result[:2, :], p1)
+ self.assertAllEqual(result[2:, :], p2)
+
+ def testRefType(self):
+ with self.test_session():
+ p1 = tf.placeholder(tf.float32_ref, shape=[4, 4])
+ p2 = tf.placeholder(tf.float32_ref, shape=[4, 4])
+ c = tf.concat(0, [p1, p2])
+ params = {
+ p1: np.random.rand(4, 4).astype("f"),
+ p2: np.random.rand(4, 4).astype("f")
+ }
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ self.assertAllEqual(result[:4, :], params[p1])
+ self.assertAllEqual(result[4:, :], params[p2])
+
+ def _testRandom(self, dtype, use_gpu=False):
+ # Random dims of rank 5
+ shape = np.random.randint(1, 5, size=5)
+ # Random number of tensors, but always > 1.
+ num_tensors = np.random.randint(2, 10)
+ # Random dim to concat on
+ concat_dim = np.random.randint(5)
+ params = {}
+ with self.test_session(use_gpu=use_gpu):
+ p = []
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ input_shape[concat_dim] = np.random.randint(1, 5)
+ placeholder = tf.placeholder(dtype, shape=input_shape)
+ p.append(placeholder)
+
+ t = dtype.as_numpy_dtype
+ params[placeholder] = np.random.rand(*input_shape).astype(t)
+
+ c = tf.concat(concat_dim, p)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ ind = [slice(0, params[p[i]].shape[j]) for j in np.arange(5)]
+ ind[concat_dim] = slice(cur_offset,
+ cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[ind], params[p[i]])
+
+ def testRandom(self):
+ self._testRandom(tf.float32)
+ self._testRandom(tf.int16)
+ self._testRandom(tf.int32, use_gpu=True)
+ # Note that the following does not work since bfloat16 is not supported in
+ # numpy.
+ # self._testRandom(tf.bfloat16)
+
+ def _testGradientsSimple(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = []
+ inp_tensors = []
+ for x in [1, 2, 6]:
+ shape = [10, x, 2]
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ tf.constant([float(y) for y in t.flatten()],
+ shape=shape, dtype=tf.float32))
+ c = tf.concat(1, inp_tensors)
+ output_shape = [10, 9, 2]
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=output_shape)
+ grad = tf.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = tf.concat(1, grad)
+ result = concated_grad.eval()
+
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsSimpleAll(self):
+ self._testGradientsSimple(use_gpu=False)
+ self._testGradientsSimple(use_gpu=True)
+
+ def _testGradientsFirstDim(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = []
+ inp_tensors = []
+ for x in [1, 2, 6]:
+ shape = [x, 10, 2]
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ tf.constant([float(y) for y in t.flatten()],
+ shape=shape, dtype=tf.float32))
+ c = tf.concat(0, inp_tensors)
+ output_shape = [9, 10, 2]
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=output_shape)
+ grad = tf.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = tf.concat(0, grad)
+ result = concated_grad.eval()
+
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsFirstDimAll(self):
+ self._testGradientsFirstDim(use_gpu=False)
+ self._testGradientsFirstDim(use_gpu=True)
+
+ def _testGradientsLastDim(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = []
+ inp_tensors = []
+ for x in [1, 2, 6]:
+ shape = [10, 2, x]
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ tf.constant([float(y) for y in t.flatten()],
+ shape=shape, dtype=tf.float32))
+ c = tf.concat(2, inp_tensors)
+ output_shape = [10, 2, 9]
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=output_shape)
+ grad = tf.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = tf.concat(2, grad)
+ result = concated_grad.eval()
+
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsLastDimAll(self):
+ self._testGradientsLastDim(use_gpu=False)
+ self._testGradientsLastDim(use_gpu=True)
+
+ def _RunAndVerifyGradientsRandom(self, use_gpu):
+ # Random dims of rank 5
+ input_shape = np.random.randint(1, 5, size=5)
+ # Random number of tensors
+ num_tensors = np.random.randint(1, 10)
+ # Random dim to concat on
+ concat_dim = np.random.randint(5)
+ concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
+ with self.test_session(use_gpu=use_gpu):
+ inp = []
+ inp_tensors = []
+ for x in concat_dim_sizes:
+ shape = input_shape
+ shape[concat_dim] = x
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ tf.constant([float(y) for y in t.flatten()],
+ shape=shape, dtype=tf.float32))
+ c = tf.concat(concat_dim, inp_tensors)
+ output_shape = input_shape
+ output_shape[concat_dim] = concat_dim_sizes.sum()
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=output_shape)
+ grad = tf.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = tf.concat(concat_dim, grad)
+ result = concated_grad.eval()
+
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsRandom(self):
+ for _ in range(5):
+ self._RunAndVerifyGradientsRandom(use_gpu=False)
+ self._RunAndVerifyGradientsRandom(use_gpu=True)
+
+ def testShapeError(self):
+ # Rank doesn't match.
+ with self.assertRaises(ValueError):
+ tf.concat(1, [tf.constant(10.0, shape=[4, 4, 4, 4]),
+ tf.constant(20.0, shape=[4, 4, 4])])
+
+ # Dimensions don't match in a non-concat dim.
+ with self.assertRaises(ValueError):
+ tf.concat(1, [tf.constant(10.0, shape=[1, 2, 1]),
+ tf.constant(20.0, shape=[3, 2, 1])])
+
+ # concat_dim out of range.
+ with self.assertRaises(ValueError):
+ tf.concat(3, [tf.constant(10.0, shape=[4, 4, 4]),
+ tf.constant(20.0, shape=[4, 4, 4])])
+
+ def testShapeWithUnknownConcatDim(self):
+ p1 = tf.placeholder(tf.float32)
+ c1 = tf.constant(10.0, shape=[4, 4, 4, 4])
+ p2 = tf.placeholder(tf.float32)
+ c2 = tf.constant(20.0, shape=[4, 4, 4, 4])
+ dim = tf.placeholder(tf.int32)
+ concat = tf.concat(dim, [p1, c1, p2, c2])
+ self.assertEqual(4, concat.get_shape().ndims)
+
+ # Rank doesn't match.
+ c3 = tf.constant(30.0, shape=[4, 4, 4])
+ with self.assertRaises(ValueError):
+ tf.concat(dim, [p1, c1, p2, c3])
+
+ def testZeroSize(self):
+ # Verify that concat doesn't crash and burn for zero size inputs
+ np.random.seed(7)
+ for use_gpu in False, True:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ for shape0 in (), (2,):
+ axis = len(shape0)
+ for shape1 in (), (3,):
+ for n0 in 0, 1, 2:
+ for n1 in 0, 1, 2:
+ x0 = np.random.randn(*(shape0 + (n0,) + shape1))
+ x1 = np.random.randn(*(shape0 + (n1,) + shape1))
+ correct = np.concatenate([x0, x1], axis=axis)
+ xs = map(tf.constant, [x0, x1])
+ c = tf.concat(axis, xs)
+ self.assertAllEqual(c.eval(), correct)
+ # Check gradients
+ dc = np.random.randn(*c.get_shape().as_list())
+ dxs = sess.run(tf.gradients(c, xs, dc))
+ self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
+
+
+if __name__ == "__main__":
+ tf.test.main()
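
The gradient tests above rely on the fact that the gradient of concat with respect to each input is just the matching slice of the upstream gradient along the concat dimension, so re-concatenating the per-input gradients recovers it. A NumPy sketch of that property (illustrative only):

import numpy as np

grad_out = np.random.rand(10, 9, 2).astype(np.float32)  # upstream gradient
sizes = [1, 2, 6]                                        # input widths on axis 1
offsets = np.cumsum([0] + sizes)                         # [0, 1, 3, 9]
grads_in = [grad_out[:, offsets[i]:offsets[i + 1], :]
            for i in range(len(sizes))]
assert np.array_equal(np.concatenate(grads_in, axis=1), grad_out)
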
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
new file mode 100644
index 0000000000..92f9b5fe4a
--- /dev/null
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -0,0 +1,524 @@
+"""Tests for ConstantOp."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_array_ops
+
+
+class ConstantTest(tf.test.TestCase):
+
+ def _testCpu(self, x):
+ np_ans = np.array(x)
+ with self.test_session(use_gpu=False):
+ tf_ans = tf.convert_to_tensor(x).eval()
+ if np_ans.dtype in [np.float32, np.float64, np.complex64]:
+ self.assertAllClose(np_ans, tf_ans)
+ else:
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def _testGpu(self, x):
+ np_ans = np.array(x)
+ with self.test_session(use_gpu=True):
+ tf_ans = tf.convert_to_tensor(x).eval()
+ if np_ans.dtype in [np.float32, np.float64, np.complex64]:
+ self.assertAllClose(np_ans, tf_ans)
+ else:
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def _testAll(self, x):
+ self._testCpu(x)
+ self._testGpu(x)
+
+ def testFloat(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32))
+ self._testAll(
+ np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float32))
+ self._testAll(np.empty((2, 0, 5)).astype(np.float32))
+
+ def testDouble(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float64))
+ self._testAll(
+ np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float64))
+ self._testAll(np.empty((2, 0, 5)).astype(np.float64))
+
+ def testInt32(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int32))
+ self._testAll(
+ (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int32))
+ self._testAll(np.empty((2, 0, 5)).astype(np.int32))
+
+ def testInt64(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int64))
+ self._testAll(
+ (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int64))
+ self._testAll(np.empty((2, 0, 5)).astype(np.int64))
+
+ def testSComplex(self):
+ self._testAll(
+ np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5]).astype(
+ np.complex64))
+ self._testAll(np.complex(
+ 1, 2) * np.random.normal(size=30).reshape([2, 3, 5]).astype(
+ np.complex64))
+ self._testAll(np.empty((2, 0, 5)).astype(np.complex64))
+
+ def testString(self):
+ self._testCpu(np.array([str(x) for x in np.arange(-15, 15)]).reshape(
+ [2, 3, 5]))
+ self._testCpu(np.empty((2, 0, 5)).astype(np.str_))
+
+ def testStringWithNulls(self):
+ with self.test_session():
+ val = tf.convert_to_tensor("\0\0\0\0").eval()
+ self.assertEqual(len(val), 4)
+ self.assertEqual(val, "\0\0\0\0")
+
+ with self.test_session():
+ val = tf.convert_to_tensor("xx\0xx").eval()
+ self.assertEqual(len(val), 5)
+ self.assertAllEqual(val, "xx\0xx")
+ nested = [["\0\0\0\0", "xx\0xx"], ["\0_\0_\0_\0", "\0"]]
+
+ with self.test_session():
+ val = tf.convert_to_tensor(nested).eval()
+ # NOTE(mrry): Do not use assertAllEqual, because it converts nested to a
+ # numpy array, which loses the null terminators.
+ self.assertEqual(val.tolist(), nested)
+
+ def testExplicitShapeNumPy(self):
+ with tf.Graph().as_default():
+ c = tf.constant(
+ np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32),
+ shape=[2, 3, 5])
+ self.assertEqual(c.get_shape(), [2, 3, 5])
+
+ def testImplicitShapeNumPy(self):
+ with tf.Graph().as_default():
+ c = tf.constant(
+ np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32))
+ self.assertEqual(c.get_shape(), [2, 3, 5])
+
+ def testExplicitShapeList(self):
+ with tf.Graph().as_default():
+ c = tf.constant([1, 2, 3, 4, 5, 6, 7], shape=[7])
+ self.assertEqual(c.get_shape(), [7])
+
+ def testImplicitShapeList(self):
+ with tf.Graph().as_default():
+ c = tf.constant([1, 2, 3, 4, 5, 6, 7])
+ self.assertEqual(c.get_shape(), [7])
+
+ def testExplicitShapeNumber(self):
+ with tf.Graph().as_default():
+ c = tf.constant(1, shape=[1])
+ self.assertEqual(c.get_shape(), [1])
+
+ def testImplicitShapeNumber(self):
+ with tf.Graph().as_default():
+ c = tf.constant(1)
+ self.assertEqual(c.get_shape(), [])
+
+ def testShapeInconsistent(self):
+ with tf.Graph().as_default():
+ c = tf.constant([1, 2, 3, 4, 5, 6, 7], shape=[10])
+ self.assertEqual(c.get_shape(), [10])
+
+ # pylint: disable=g-long-lambda
+ def testShapeWrong(self):
+ with tf.Graph().as_default():
+ with self.assertRaisesWithPredicateMatch(
+ ValueError,
+ lambda e: ("Too many elements provided. Needed at most 5, "
+ "but received 7" == str(e))):
+ tf.constant([1, 2, 3, 4, 5, 6, 7], shape=[5])
+ # pylint: enable=g-long-lambda
+
+ def testTooLargeConstant(self):
+ with tf.Graph().as_default():
+ large_array = np.zeros((512, 1024, 1024), dtype=np.float32)
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Cannot create an Operation with a NodeDef larger than 2GB."):
+ c = tf.constant(large_array)
+
+ def testTooLargeGraph(self):
+ with tf.Graph().as_default() as g:
+ large_array = np.zeros((256, 1024, 1024), dtype=np.float32)
+ c = tf.constant(large_array)
+ d = tf.constant(large_array)
+ with self.assertRaisesRegexp(
+ ValueError, "GraphDef cannot be larger than 2GB."):
+ g.as_graph_def()
+
+ def testSparseValuesRaiseErrors(self):
+ with self.assertRaisesRegexp(ValueError,
+ "setting an array element with a sequence"):
+ c = tf.constant([[1, 2], [3]], dtype=tf.int32)
+
+ with self.assertRaisesRegexp(ValueError, "must be a dense"):
+ c = tf.constant([[1, 2], [3]])
+
+ with self.assertRaisesRegexp(ValueError, "must be a dense"):
+ c = tf.constant([[1, 2], [3], [4, 5]])
+
+
+class AsTensorTest(tf.test.TestCase):
+
+ def testAsTensorForTensorInput(self):
+ with tf.Graph().as_default():
+ t = tf.constant(10.0)
+ x = tf.convert_to_tensor(t)
+ self.assertIs(t, x)
+
+ def testAsTensorForNonTensorInput(self):
+ with tf.Graph().as_default():
+ x = tf.convert_to_tensor(10.0)
+ self.assertTrue(isinstance(x, tf.Tensor))
+
+ def testAsTensorForShapeInput(self):
+ with self.test_session():
+ x = tf.convert_to_tensor(tf.TensorShape([]))
+ self.assertEqual(tf.int32, x.dtype)
+ self.assertAllEqual([], x.eval())
+
+ x = tf.convert_to_tensor(tf.TensorShape([1, 2, 3]))
+ self.assertEqual(tf.int32, x.dtype)
+ self.assertAllEqual([1, 2, 3], x.eval())
+
+ x = tf.convert_to_tensor(tf.TensorShape([1, 2, 3]), dtype=tf.int64)
+ self.assertEqual(tf.int64, x.dtype)
+ self.assertAllEqual([1, 2, 3], x.eval())
+
+ x = tf.reshape(tf.zeros([6]), tf.TensorShape([2, 3]))
+ self.assertAllEqual([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], x.eval())
+
+ with self.assertRaisesRegexp(ValueError, "partially known"):
+ tf.convert_to_tensor(tf.TensorShape(None))
+
+ with self.assertRaisesRegexp(ValueError, "partially known"):
+ tf.convert_to_tensor(tf.TensorShape([1, None, 64]))
+
+ with self.assertRaises(TypeError):
+ tf.convert_to_tensor(tf.TensorShape([1, 2, 3]), dtype=tf.float32)
+
+ def testAsTensorForDimensionInput(self):
+ with self.test_session():
+ x = tf.convert_to_tensor(tf.TensorShape([1, 2, 3])[1])
+ self.assertEqual(tf.int32, x.dtype)
+ self.assertAllEqual(2, x.eval())
+
+ x = tf.convert_to_tensor(tf.TensorShape([1, 2, 3])[1], dtype=tf.int64)
+ self.assertEqual(tf.int64, x.dtype)
+ self.assertAllEqual(2, x.eval())
+
+ with self.assertRaisesRegexp(ValueError, "unknown Dimension"):
+ tf.convert_to_tensor(tf.TensorShape(None)[1])
+
+ with self.assertRaisesRegexp(ValueError, "unknown Dimension"):
+ tf.convert_to_tensor(tf.TensorShape([1, None, 64])[1])
+
+ with self.assertRaises(TypeError):
+ tf.convert_to_tensor(tf.TensorShape([1, 2, 3])[1], dtype=tf.float32)
+
+
+class IdentityOpTest(tf.test.TestCase):
+
+ def testIdTensor(self):
+ with tf.Graph().as_default():
+ x = tf.constant(2.0, shape=[6], name="input")
+ id_op = tf.identity(x, name="id")
+ self.assertTrue(isinstance(id_op.op.inputs[0], tf.Tensor))
+ self.assertProtoEquals(
+ "name: 'id' op: 'Identity' input: 'input' "
+ "attr { key: 'T' value { type: DT_FLOAT } }", id_op.op.node_def)
+
+
+class ZerosTest(tf.test.TestCase):
+
+ def _Zeros(self, shape):
+ with self.test_session():
+ ret = tf.zeros(shape)
+ self.assertEqual(shape, ret.get_shape())
+ return ret.eval()
+
+ def testConst(self):
+    self.assertTrue(np.array_equal(self._Zeros([2, 3]),
+                                   np.array([[0] * 3] * 2)))
+
+ def testDynamicSizes(self):
+ np_ans = np.array([[0] * 3] * 2)
+ with self.test_session():
+ # Creates a tensor of 2 x 3.
+ d = tf.fill([2, 3], 12., name="fill")
+ # Constructs a tensor of zeros of the same dimensions as "d".
+ z = tf.zeros(tf.shape(d))
+ out = z.eval()
+ self.assertAllEqual(np_ans, out)
+ self.assertShapeEqual(np_ans, d)
+ self.assertShapeEqual(np_ans, z)
+
+ def testDtype(self):
+ with self.test_session():
+ d = tf.fill([2, 3], 12., name="fill")
+ self.assertEqual(d.get_shape(), [2, 3])
+ # Test default type for both constant size and dynamic size
+ z = tf.zeros([2, 3])
+ self.assertEquals(z.dtype, tf.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ z = tf.zeros(tf.shape(d))
+ self.assertEquals(z.dtype, tf.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ # Test explicit type control
+ for dtype in [tf.float32, tf.float64, tf.int32,
+ tf.uint8, tf.int16, tf.int8,
+ tf.complex64, tf.int64]:
+ z = tf.zeros([2, 3], dtype=dtype)
+ self.assertEquals(z.dtype, dtype)
+ self.assertEquals([2, 3], z.get_shape())
+ z = tf.zeros(tf.shape(d), dtype=dtype)
+ self.assertEquals(z.dtype, dtype)
+ self.assertEquals([2, 3], z.get_shape())
+
+
+class ZerosLikeTest(tf.test.TestCase):
+
+ def testZerosLike(self):
+ for dtype in [tf.float32, tf.float64, tf.int32,
+ tf.uint8, tf.int16, tf.int8,
+ tf.complex64, tf.int64]:
+ numpy_dtype = dtype.as_numpy_dtype
+ with self.test_session():
+ # Creates a tensor of non-zero values with shape 2 x 3.
+ d = tf.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
+ # Constructs a tensor of zeros of the same dimensions and type as "d".
+ z_var = tf.zeros_like(d)
+ # Test that the type is correct
+ self.assertEquals(z_var.dtype, dtype)
+ z_value = z_var.eval()
+
+ # Test that the value is correct
+ self.assertTrue(np.array_equal(z_value, np.array([[0] * 3] * 2)))
+ self.assertEqual([2, 3], z_var.get_shape())
+
+ def testGenZerosLike(self):
+ for dtype in [tf.float32, tf.float64, tf.int32,
+ tf.uint8, tf.int16, tf.int8,
+ tf.complex64, tf.int64]:
+ numpy_dtype = dtype.as_numpy_dtype
+ with self.test_session():
+ # Creates a tensor of non-zero values with shape 2 x 3.
+ d = tf.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
+ # Constructs a tensor of zeros of the same dimensions and type as "d".
+ z_var = gen_array_ops._zeros_like(d)
+ # Test that the type is correct
+ self.assertEquals(z_var.dtype, dtype)
+ z_value = z_var.eval()
+
+ # Test that the value is correct
+ self.assertTrue(np.array_equal(z_value, np.array([[0] * 3] * 2)))
+ self.assertEqual([2, 3], z_var.get_shape())
+
+
+class OnesTest(tf.test.TestCase):
+
+ def _Ones(self, shape):
+ with self.test_session():
+ ret = tf.ones(shape)
+ self.assertEqual(shape, ret.get_shape())
+ return ret.eval()
+
+ def testConst(self):
+ self.assertTrue(np.array_equal(self._Ones([2, 3]), np.array([[1] * 3] * 2)))
+
+ def testDynamicSizes(self):
+ np_ans = np.array([[1] * 3] * 2)
+ with self.test_session():
+ # Creates a tensor of 2 x 3.
+ d = tf.fill([2, 3], 12., name="fill")
+ # Constructs a tensor of ones of the same dimensions as "d".
+ z = tf.ones(tf.shape(d))
+ out = z.eval()
+ self.assertAllEqual(np_ans, out)
+ self.assertShapeEqual(np_ans, d)
+ self.assertShapeEqual(np_ans, z)
+
+ def testDtype(self):
+ with self.test_session():
+ d = tf.fill([2, 3], 12., name="fill")
+ self.assertEqual(d.get_shape(), [2, 3])
+ # Test default type for both constant size and dynamic size
+ z = tf.ones([2, 3])
+ self.assertEquals(z.dtype, tf.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ z = tf.ones(tf.shape(d))
+ self.assertEquals(z.dtype, tf.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ # Test explicit type control
+ for dtype in [tf.float32, tf.float64, tf.int32,
+ tf.uint8, tf.int16, tf.int8,
+ tf.complex64, tf.int64]:
+ z = tf.ones([2, 3], dtype=dtype)
+ self.assertEquals(z.dtype, dtype)
+ self.assertEqual([2, 3], z.get_shape())
+ z = tf.ones(tf.shape(d), dtype=dtype)
+ self.assertEquals(z.dtype, dtype)
+ self.assertEqual([2, 3], z.get_shape())
+
+
+class OnesLikeTest(tf.test.TestCase):
+
+ def testOnesLike(self):
+ for dtype in [tf.float32, tf.float64, tf.int32,
+ tf.uint8, tf.int16, tf.int8,
+ tf.complex64, tf.int64]:
+ numpy_dtype = dtype.as_numpy_dtype
+ with self.test_session():
+ # Creates a tensor of non-zero values with shape 2 x 3.
+ d = tf.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
+        # Constructs a tensor of ones of the same dimensions and type as "d".
+ z_var = tf.ones_like(d)
+ # Test that the type is correct
+ self.assertEquals(z_var.dtype, dtype)
+ z_value = z_var.eval()
+
+ # Test that the value is correct
+ self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2)))
+ self.assertEqual([2, 3], z_var.get_shape())
+
+ def testGenOnesLike(self):
+ for dtype in [tf.float32, tf.float64, tf.int32,
+ tf.uint8, tf.int16, tf.int8,
+ tf.complex64, tf.int64]:
+ numpy_dtype = dtype.as_numpy_dtype
+ with self.test_session():
+ # Creates a tensor of non-zero values with shape 2 x 3.
+ d = tf.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
+        # Constructs a tensor of ones of the same dimensions and type as "d".
+ z_var = tf.ones_like(d)
+ # Test that the type is correct
+ self.assertEquals(z_var.dtype, dtype)
+ z_value = z_var.eval()
+
+ # Test that the value is correct
+ self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2)))
+ self.assertEqual([2, 3], z_var.get_shape())
+
+
+class FillTest(tf.test.TestCase):
+
+ def _compare(self, dims, val, np_ans, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.fill(dims, val, name="fill")
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ # Fill does not set the shape.
+ # self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, dims, val, np_ans):
+ self._compare(dims, val, np_ans, False)
+ self._compare(dims, val, np_ans, True)
+
+ def testFillFloat(self):
+ np_ans = np.array([[3.1415] * 3] * 2).astype(np.float32)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillDouble(self):
+ np_ans = np.array([[3.1415] * 3] * 2).astype(np.float64)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillInt32(self):
+ np_ans = np.array([[42] * 3] * 2).astype(np.int32)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillInt64(self):
+ np_ans = np.array([[-42] * 3] * 2).astype(np.int64)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillComplex(self):
+ np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64)
+ self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
+
+ def testFillString(self):
+ np_ans = np.array([["yolo"] * 3] * 2)
+ with self.test_session(use_gpu=False):
+ tf_ans = tf.fill([2, 3], np_ans[0][0], name="fill").eval()
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def testShapeFunctionEdgeCases(self):
+ # Non-vector dimensions.
+ with self.assertRaises(ValueError):
+ tf.fill([[0, 1], [2, 3]], 1.0)
+
+ # Non-scalar value.
+ with self.assertRaises(ValueError):
+ tf.fill([3, 2], [1.0, 2.0])
+
+ # Partial dimension information.
+ f = tf.fill(
+ tf.placeholder(tf.int32, shape=(4,)), 3.0)
+ self.assertEqual([None, None, None, None], f.get_shape().as_list())
+
+
+class PlaceholderTest(tf.test.TestCase):
+
+ def testDtype(self):
+ with self.test_session():
+ p = tf.placeholder(tf.float32, name="p")
+ p_identity = tf.identity(p)
+ feed_array = np.random.rand(10, 10)
+ self.assertAllClose(p_identity.eval(feed_dict={p: feed_array}),
+ feed_array)
+
+ with self.assertRaisesOpError(
+ "must feed a value for placeholder tensor 'p' with dtype float"):
+ p_identity.eval()
+
+ def testShape(self):
+ with self.test_session():
+ p = tf.placeholder(tf.float32, shape=(10, 10), name="p")
+ p_identity = tf.identity(p)
+ feed_array = np.random.rand(10, 10)
+ self.assertAllClose(p_identity.eval(feed_dict={p: feed_array}),
+ feed_array)
+
+ with self.assertRaisesOpError(
+ "must feed a value for placeholder tensor 'p' with dtype float and "
+ "shape dim { size: 10 } dim { size: 10 }"):
+ p_identity.eval()
+
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Cannot feed value of shape" in e.message):
+ p_identity.eval(feed_dict={p: feed_array[:5, :5]})
+
+ def testPartialShape(self):
+ with self.test_session():
+ p = tf.placeholder(tf.float32, shape=[None, 3], name="p")
+ p_identity = tf.identity(p)
+ feed_array = np.random.rand(10, 3)
+ self.assertAllClose(p_identity.eval(feed_dict={p: feed_array}),
+ feed_array)
+
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Cannot feed value of shape" in e.message):
+ p_identity.eval(feed_dict={p: feed_array[:5, :2]})
+
+ def testControlDependency(self):
+ with self.test_session():
+ p = tf.placeholder(tf.int32, shape=[], name="p")
+ with tf.control_dependencies([p]):
+ c = tf.constant(5, tf.int32)
+ d = tf.mul(p, c)
+ self.assertEqual(10, d.eval(feed_dict={p: 2}))
+
+ def testFillNegative(self):
+ with self.test_session():
+ for shape in (-1,), (2, -1), (-1, 2):
+ with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+ " must be nonnegative"):
+ tf.fill(shape, 7).eval()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
new file mode 100644
index 0000000000..adf3552739
--- /dev/null
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -0,0 +1,1260 @@
+# pylint: disable=g-long-lambda
+"""Tests for tensorflow.ops.control_flow_ops."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.pywrap_tensorflow import StatusNotOK
+
+def check_op_order(graph):
+ """Sanity check on the ordering of op id."""
+
+ for op in graph.get_operations():
+ for v in op.inputs:
+ assert v.op._id < op._id or op.type == "Merge", (
+ "The id of %s must be less than the id of %s" % (v.op.name, op.name))
+ return True
+
+
+def check_consumers(graph):
+ """Sanity check on the consumer list of the tensors."""
+
+ consumer_count = {}
+ for op in graph.get_operations():
+ for v in op.inputs:
+ cnt = consumer_count.get(v, 0)
+ consumer_count[v] = cnt + 1
+ for k, v in consumer_count.iteritems():
+ if len(k.consumers()) != v:
+ return False
+ return True
+
+
+def isum(s):
+ i = tf.constant(0, name="i")
+ c = lambda i, s: tf.less(i, 10)
+ b = lambda i, s: [tf.add(i, 1), tf.add(i, s)]
+ _, r_s = control_flow_ops.While(c, b, [i, s])
+ return r_s
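+
+# For reference, a plain-Python sketch of what isum computes: the While body
+# produces the new loop variables from the old ones simultaneously, so each
+# step performs i, s = i + 1, i + s.  Starting from s0, the loop runs with
+# i = 0..9 and isum(s0) evaluates to s0 + sum(range(10)) = s0 + 45.
+#
+#   def isum_py(s):
+#     i = 0
+#     while i < 10:
+#       i, s = i + 1, i + s
+#     return s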
+
+
+class ControlFlowTest(tf.test.TestCase):
+
+ def testRefIdentity(self):
+ with self.test_session():
+ v = tf.Variable(7)
+
+ v = control_flow_ops._Identity(v)
+ op = tf.assign(v, 9)
+ v2 = control_flow_ops.with_dependencies([op], v)
+
+ self.assertTrue(check_op_order(v.graph))
+ self.assertTrue(isinstance(v2, tf.Tensor))
+ tf.initialize_all_variables().run()
+ self.assertEqual(9, v2.eval())
+
+ def testRefEnter(self):
+ with self.test_session():
+ v = tf.Variable(7)
+
+ enter_v = control_flow_ops._Enter(v, "foo_1")
+ nine = tf.constant(9)
+ enter_nine = control_flow_ops.enter(nine, "foo_1")
+ op = tf.assign(enter_v, enter_nine)
+ v2 = control_flow_ops.with_dependencies([op], enter_v)
+ v3 = control_flow_ops.exit(v2)
+ tf.initialize_all_variables().run()
+ self.assertEqual(9, v3.eval())
+
+ def testRefSwitch(self):
+ with self.test_session():
+ v = tf.Variable(7)
+
+ p = tf.constant(True)
+ v1 = control_flow_ops._SwitchRefOrTensor(v, p)
+ v2 = tf.assign(v1[1], 9)
+ tf.initialize_all_variables().run()
+ self.assertEqual(9, v2.eval())
+
+ def testEnterExit_1(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ enter_op = control_flow_ops.enter(data, "foo_1", False)
+ exit_op = control_flow_ops.exit(enter_op)
+
+ result = exit_op.eval()
+ self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), result)
+
+ def testEnterMulExit_1(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ enter_data = control_flow_ops.enter(data, "foo_1", False)
+ five = tf.constant(5)
+ enter_five = control_flow_ops.enter(five, "foo_1", False)
+ mul_op = tf.mul(enter_data, enter_five)
+ exit_op = control_flow_ops.exit(mul_op)
+
+ result = exit_op.eval()
+ self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
+
+ def testEnterNextExit_1(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ enter_op = control_flow_ops.enter(data, "foo_1", False)
+ next_op = control_flow_ops.next_iteration(enter_op)
+ exit_op = control_flow_ops.exit(next_op)
+
+ result = exit_op.eval()
+ self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), result)
+
+ def testSwitchMergeIndexedSlices(self):
+ with self.test_session():
+ values = tf.constant([1, 2, 3, 4, 5, 6])
+ indices = tf.constant([0, 2, 4, 6, 8, 10])
+ data = tf.IndexedSlices(values, indices)
+ pred = tf.convert_to_tensor(True)
+ switch_op = control_flow_ops.switch(data, pred)
+ merge_op = control_flow_ops.merge(switch_op)[0]
+
+ val = merge_op.values.eval()
+ ind = merge_op.indices.eval()
+ self.assertAllEqual(np.arange(1, 7), val)
+ self.assertAllEqual(np.arange(0, 12, 2), ind)
+
+ def _testSwitchMerge_1(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ports = tf.convert_to_tensor(True, name="ports")
+ switch_op = control_flow_ops.switch(data, ports)
+ merge_op = control_flow_ops.merge(switch_op)[0]
+
+ result = merge_op.eval()
+ self.assertAllEqual(np.arange(1, 7), result)
+
+ def testSwitchMerge_1(self):
+ self._testSwitchMerge_1(use_gpu=False)
+ self._testSwitchMerge_1(use_gpu=True)
+
+ def testSwitchDeadBranch(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ports = tf.convert_to_tensor(True, name="ports")
+ switch_op = control_flow_ops.switch(data, ports)
+ dead_branch = tf.identity(switch_op[0])
+
+ with self.assertRaisesWithPredicateMatch(
+ StatusNotOK, lambda e: 'The tensor returned for' in str(e)):
+ dead_branch.eval()
+
+ def testSwitchMergeIdentity_1(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ports = tf.convert_to_tensor(True, name="ports")
+ switch_op = control_flow_ops.switch(data, ports)
+ merge_op = control_flow_ops.merge(switch_op)[0]
+ id_op = tf.identity(merge_op)
+
+ result = id_op.eval()
+ self.assertAllEqual(np.arange(1, 7), result)
+
+ def testSwitchMergeLess_0(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ zero = tf.constant(0)
+ one = tf.constant(1)
+ less_op = tf.less(zero, one)
+ switch_op = control_flow_ops.switch(data, less_op)
+ merge_op = control_flow_ops.merge(switch_op)[0]
+
+ result = merge_op.eval()
+ self.assertAllEqual(np.arange(1, 7), result)
+
+ def testSwitchMergeLess_1(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ zero = tf.convert_to_tensor(0)
+ one = tf.convert_to_tensor(1)
+ less_op = tf.less(zero, one)
+ switch_op = control_flow_ops.switch(data, less_op)
+ merge_op = control_flow_ops.merge(switch_op)[0]
+
+ result = merge_op.eval()
+ self.assertAllEqual(np.arange(1, 7), result)
+
+ def testSwitchMergeAddIdentity_0(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ports = tf.convert_to_tensor(False, name="ports")
+ switch_op = control_flow_ops.switch(data, ports)
+ one = tf.constant(1)
+ add_op = tf.add(switch_op[0], one)
+ id_op = tf.identity(switch_op[1])
+ merge_op = control_flow_ops.merge([add_op, id_op])[0]
+
+ result = merge_op.eval()
+ self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
+
+ def testSwitchMergeAddIdentity_1(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ports = tf.convert_to_tensor(True, name="ports")
+ switch_op = control_flow_ops.switch(data, ports)
+ one = tf.constant(1)
+ add_op = tf.add(switch_op[0], one)
+ id_op = tf.identity(switch_op[1])
+ merge_op = control_flow_ops.merge([add_op, id_op])[0]
+
+ result = merge_op.eval()
+ self.assertAllEqual(np.arange(1, 7), result)
+
+ def testSwitchMergeAddMul_0(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ports = tf.convert_to_tensor(False, name="ports")
+ switch_op = control_flow_ops.switch(data, ports)
+ one = tf.constant(1)
+ add_op = tf.add(switch_op[0], one)
+ five = tf.constant(5)
+ mul_op = tf.mul(switch_op[1], five)
+ merge_op = control_flow_ops.merge([add_op, mul_op])[0]
+
+ result = merge_op.eval()
+ self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
+
+ def testSwitchMergeAddMul_1(self):
+ with self.test_session():
+ data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ports = tf.convert_to_tensor(True, name="ports")
+ switch_op = control_flow_ops.switch(data, ports)
+ one = tf.constant(1)
+ add_op = tf.add(switch_op[0], one)
+ five = tf.constant(5)
+ mul_op = tf.mul(switch_op[1], five)
+ merge_op = control_flow_ops.merge([add_op, mul_op])[0]
+
+ result = merge_op.eval()
+ self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
+
+ def testLoop_false(self):
+ with self.test_session():
+ false = tf.convert_to_tensor(False)
+ n = tf.constant(10)
+
+ enter_false = control_flow_ops.enter(false, "foo_1", False)
+ enter_n = control_flow_ops.enter(n, "foo_1", False)
+
+ merge_n = control_flow_ops.merge([enter_n], name="merge_n")[0]
+ switch_n = control_flow_ops.switch(merge_n, enter_false)
+ exit_n = control_flow_ops.exit(switch_n[0])
+
+ result = exit_n.eval()
+ self.assertAllEqual(10, result)
+
+ def testLoop_false_1(self):
+ with self.test_session():
+ false = tf.convert_to_tensor(False)
+ n = tf.constant(10)
+
+ enter_false = control_flow_ops.enter(false, "foo_1", False)
+ enter_n = control_flow_ops.enter(n, "foo_1", False)
+
+ merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
+ switch_n = control_flow_ops.switch(merge_n, enter_false)
+ exit_n = control_flow_ops.exit(switch_n[0])
+ next_n = control_flow_ops.next_iteration(switch_n[0])
+ merge_n.op._update_input(1, next_n)
+
+ result = exit_n.eval()
+ self.assertAllEqual(10, result)
+
+ def testLoop_1(self):
+ with self.test_session():
+ zero = tf.convert_to_tensor(0)
+ one = tf.convert_to_tensor(1)
+ n = tf.constant(10)
+
+ enter_zero = control_flow_ops.enter(zero, "foo_1", False)
+ enter_one = control_flow_ops.enter(one, "foo_1", False)
+ enter_n = control_flow_ops.enter(n, "foo_1", False)
+ merge_zero = control_flow_ops.merge([enter_zero, enter_zero],
+ name="merge_zero")[0]
+ merge_one = control_flow_ops.merge([enter_one, enter_one],
+ name="merge_one")[0]
+ merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
+ less_op = tf.less(merge_n, merge_n)
+ cond_op = control_flow_ops.loop_cond(less_op)
+ switch_zero = control_flow_ops.switch(merge_zero, cond_op)
+ switch_one = control_flow_ops.switch(merge_one, cond_op)
+ switch_n = control_flow_ops.switch(merge_n, cond_op)
+ next_zero = control_flow_ops.next_iteration(switch_zero[1])
+ next_one = control_flow_ops.next_iteration(switch_one[1])
+ next_n = control_flow_ops.next_iteration(switch_n[1])
+ merge_zero.op._update_input(1, next_zero)
+ merge_one.op._update_input(1, next_one)
+ merge_n.op._update_input(1, next_n)
+ exit_n = control_flow_ops.exit(switch_n[0])
+
+ result = exit_n.eval()
+ self.assertAllEqual(10, result)
+
+ def testCondIndexedSlices(self):
+ with self.test_session():
+ values = tf.constant(10)
+ indices = tf.constant(0)
+ x = tf.IndexedSlices(values, indices)
+ pred = tf.less(1, 2)
+ fn1 = lambda: tf.IndexedSlices(tf.add(x.values, 1), indices)
+ fn2 = lambda: tf.IndexedSlices(tf.sub(x.values, 1), indices)
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ val = r.values.eval()
+ ind = r.indices.eval()
+ self.assertTrue(check_op_order(x.values.graph))
+ self.assertAllEqual(11, val)
+ self.assertAllEqual(0, ind)
+
+ def testCondIndexedSlicesDifferentTypes(self):
+ with self.test_session():
+ values = tf.constant(10)
+ i_32 = tf.convert_to_tensor(0, name="one", dtype=tf.int32)
+ i_64 = tf.convert_to_tensor(0, name="one", dtype=tf.int64)
+ x = tf.IndexedSlices(values, i_32)
+ pred = tf.less(1, 2)
+ fn1 = lambda: tf.IndexedSlices(tf.add(x.values, 1), i_32)
+ fn2 = lambda: tf.IndexedSlices(tf.sub(x.values, 1), i_64)
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ val = r.values.eval()
+ ind = r.indices.eval()
+ self.assertTrue(check_op_order(x.values.graph))
+ self.assertAllEqual(11, val)
+ self.assertAllEqual(0, ind)
+ self.assertTrue(ind.dtype == np.int64)
+
+ def _testCond_1(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ x = tf.constant(10)
+ pred = tf.less(1, 2)
+ fn1 = lambda: tf.add(x, 1)
+ fn2 = lambda: tf.sub(x, 1)
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ result = r.eval()
+ self.assertTrue(check_op_order(x.graph))
+ self.assertAllEqual(11, result)
+
+ def testCond_1(self):
+ self._testCond_1(use_gpu=False)
+ self._testCond_1(use_gpu=True)
+
+ def testCond_2(self):
+ with self.test_session():
+ x = tf.constant(10)
+ r = control_flow_ops.cond(tf.less(1, 0), lambda: tf.add(x, 1),
+ lambda: tf.sub(x, 1))
+ result = r.eval()
+ self.assertTrue(check_op_order(x.graph))
+ self.assertAllEqual(9, result)
+
+ def testCond_3(self):
+ with self.test_session():
+ x = tf.constant(10)
+ pred = tf.less(1, 2)
+ fn1 = lambda: tf.add(x, 1)
+ fn2 = lambda: tf.sub(x, 1)
+ fn3 = lambda: tf.add(control_flow_ops.cond(pred, fn1, fn2), 1)
+ r = control_flow_ops.cond(pred, fn3, fn2)
+
+ result = r.eval()
+ self.assertTrue(check_op_order(x.graph))
+ self.assertAllEqual(12, result)
+
+ def testCond_4(self):
+ with self.test_session():
+ v1 = tf.Variable(7)
+ v2 = tf.Variable(7)
+ v3 = tf.Variable(7)
+
+ age = tf.constant(3)
+ max_age = tf.constant(2)
+ pred = tf.greater(age, max_age)
+ fn1 = lambda: [tf.assign(v1, 1).op, tf.assign(v2, 2).op]
+ fn2 = lambda: [tf.assign(v3, 3).op, tf.constant(10).op]
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ tf.initialize_all_variables().run()
+ self.assertEqual(len(r), 2)
+ result = r[1].eval()
+ self.assertTrue(check_op_order(age.graph))
+ self.assertAllEqual(True, result)
+ self.assertAllEqual(7, v1.eval())
+ self.assertAllEqual(2, v2.eval())
+ self.assertAllEqual(7, v3.eval())
+
+ def testCond_5(self):
+ with self.test_session():
+ alive = tf.constant(True, name="alive")
+ count = tf.constant(0, name="count")
+
+ def body(i):
+ return control_flow_ops.cond(
+ alive, lambda: [tf.less(i, 3), tf.add(count, 1)],
+ lambda: [alive, count])
+
+ for i in range(10):
+ alive, count = body(i)
+ self.assertAllEqual(4, count.eval())
+
+ def testCond_6(self):
+ with self.test_session():
+ v1 = tf.Variable([7])
+
+ age = tf.constant(3)
+ pred = tf.greater(age, 4)
+ fn1 = lambda: age
+ fn2 = lambda: v1
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ tf.initialize_all_variables().run()
+ result = r.eval()
+ self.assertAllEqual(np.array([7]), result)
+
+ def testCondGrad_1(self):
+ with self.test_session():
+ x = tf.constant(10.0, name="x")
+ pred = tf.less(1, 2)
+ fn1 = lambda: tf.identity(x)
+ fn2 = lambda: tf.identity(x)
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ grad = tf.gradients(r, [x])[0]
+ result = grad.eval()
+ self.assertAllEqual(1.0, result)
+
+ def testCondGrad_2(self):
+ with self.test_session():
+ c = tf.placeholder(tf.int32, shape=[])
+ x = tf.constant(10.0)
+ pred = tf.less(c, 2)
+ fn1 = lambda: tf.mul(x, 42.0)
+ fn2 = lambda: tf.mul(x, 3.0)
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ grad = tf.gradients(r, [x])[0]
+ self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
+ self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+
+ def testCondGrad_Gather(self):
+ with self.test_session() as sess:
+ v1 = tf.Variable([1.0, 42.0])
+ c = tf.placeholder(tf.int32, shape=[])
+ pred = tf.less(c, 2)
+ fn1 = lambda: tf.identity(v1)
+ fn2 = lambda: tf.gather(v1, [1, 1])
+ r = control_flow_ops.cond(pred, fn1, fn2)
+ grad = tf.gradients(r, [v1])[0]
+ tf.initialize_all_variables().run()
+ # Should just be [1, 1], but possibly a sparse representation
+ gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 1})
+      dense_gv = [sum([y for (x, y) in zip(gi, gv) if x == i])
+                  for i in range(2)]
+ self.assertAllEqual(dense_gv, [1.0, 1.0])
+ # Should be [0, 2], as the else forwards v1[1] twice
+ gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 3})
+      dense_gv = [sum([y for (x, y) in zip(gi, gv) if x == i])
+                  for i in range(2)]
+ self.assertAllEqual(dense_gv, [0.0, 2.0])
+
+ def testWhileGrad_1(self):
+ with self.test_session():
+ v = tf.constant(2.0, name="v")
+ c = lambda v: tf.less(v, 100.0)
+ b = tf.square
+ r = control_flow_ops.While(c, b, [v], parallel_iterations=1)
+
+ r = tf.gradients(r, v)
+ result = r[0].eval()
+ self.assertEqual(1024.0, result)
+
+ def testWhileGrad_2(self):
+ with self.test_session():
+ a = tf.constant(3.0, name="a")
+ v = tf.constant(2.0, name="v")
+ c = lambda v: tf.less(v, 100.0)
+ b = lambda v: tf.mul(v, a)
+ r = control_flow_ops.While(c, b, [v], parallel_iterations=1)
+
+ r = tf.gradients(r, a)
+ result = r[0].eval()
+ self.assertEqual(216.0, result)
+
+ def testWhileGrad_3(self):
+ with self.test_session():
+ a = tf.constant(3.0, name="a")
+ v = tf.constant(2.0, name="v")
+ c = lambda v: tf.less(v, 100.0)
+ b = lambda v: tf.mul(v, a)
+ r = control_flow_ops.While(c, b, [v], parallel_iterations=1)
+
+ r = tf.gradients(r, v)
+ result = r[0].eval()
+ self.assertEqual(81.0, result)
+
+ def testWhileGrad_4(self):
+ with self.test_session():
+ a = tf.Variable(3.0)
+ v = tf.constant(2.0, name="v")
+ c = lambda v: tf.less(v, 100.0)
+ b = lambda v: tf.mul(v, a)
+ r = control_flow_ops.While(c, b, [v], parallel_iterations=1)
+
+ r = tf.gradients(r, a)
+ tf.initialize_all_variables().run()
+ result = r[0].eval()
+ self.assertEqual(216.0, result)
+
+ def testWhileGrad_5(self):
+ with self.test_session():
+ x = tf.constant(3.0, name="x")
+ y = tf.constant(2.0, name="y")
+ c = lambda x, y: tf.less(x, 100.0)
+
+ def b(x, y):
+ y1 = tf.add(x, y)
+ x1 = tf.mul(x, y1)
+ return x1, y1
+
+ r = control_flow_ops.While(c, b, [x, y], parallel_iterations=1)
+
+ # Must use the complete r.
+ r = tf.gradients(r, x)
+ result = r[0].eval()
+ self.assertEqual(304.0, result)
+
+ def testWhileGrad_6(self):
+ with self.test_session():
+ i = tf.constant(0, name="i")
+ x = tf.constant(2.0, name="x")
+ c = lambda i, x: tf.less(i, 10)
+
+ def b(i, x):
+ x = tf.mul(x, 2.0)
+ i = tf.add(i, 1)
+ return i, x
+
+ r = control_flow_ops.While(c, b, [i, x], parallel_iterations=1)
+
+ # Must use the complete r.
+ r = tf.gradients(r, x)
+ r = r[0].eval()
+ self.assertEqual(1024.0, r)
+
+ def testWhileGrad_7(self):
+ with self.test_session():
+ v = tf.constant(2.0, name="v")
+ c = lambda v: tf.less(v, 100.0)
+ b = tf.square
+ r = control_flow_ops.While(c, b, [v], parallel_iterations=1,
+ back_prop=False)
+ r = tf.add(r, v)
+ r = tf.gradients(r, v)
+ result = r[0].eval()
+ self.assertEqual(1.0, result)
+
+ # Microbenchmark: 10,000 iterations took 0.21s.
+ def testWhile_1(self):
+ with self.test_session():
+ n = tf.constant(0)
+ c = lambda x: tf.less(x, 10000)
+ b = lambda x: tf.add(x, 1)
+ r = control_flow_ops.While(c, b, [n], parallel_iterations=20)
+
+ result = r.eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertEqual(10000, result)
+
+ def testWhile_2(self):
+ with self.test_session():
+ s = tf.constant(0)
+ r = isum(s)
+
+ result = r.eval()
+ self.assertTrue(check_op_order(s.graph))
+ self.assertAllEqual(45, result)
+
+  # The loop runs for many more than 10 iterations, so it exercises the
+  # parallel-iteration bound (k) most of the time.
+ def testWhile_3(self):
+ with self.test_session():
+
+ def compute(i, m, c, o):
+ m, c = [tf.add(m, 1), tf.add(c, 1)]
+ o = tf.add(o, m)
+ o = tf.add(o, c)
+ i = tf.add(i, 1)
+ return [i, m, c, o]
+
+ i = tf.convert_to_tensor(0)
+ m = tf.convert_to_tensor(0)
+ c = tf.convert_to_tensor(0)
+ o = tf.convert_to_tensor(0)
+ d = tf.convert_to_tensor(100)
+ r = control_flow_ops.While(
+ lambda i, m, c, o: tf.less(i, d), compute, [i, m, c, o])
+ result = r[3].eval()
+ self.assertTrue(check_op_order(i.graph))
+ self.assertAllEqual(10100, result)
+
+ def testWhile_4(self):
+ with self.test_session():
+
+ def compute(i, m, c, o):
+ m, c = [tf.gather(x, i), tf.gather(x, i)]
+ o = tf.add(o, m)
+ o = tf.add(o, c)
+ i = tf.add(i, 1)
+ return [i, m, c, o]
+
+ i = tf.convert_to_tensor(0)
+ m = tf.convert_to_tensor(0)
+ c = tf.convert_to_tensor(0)
+ o = tf.convert_to_tensor(0)
+ x = tf.convert_to_tensor([1, 2, 3, 4, 5, 6])
+ s = tf.size(x)
+ r = control_flow_ops.While(
+ lambda i, m, c, o: tf.less(i, s), compute, [i, m, c, o])
+ result = r[3].eval()
+ self.assertTrue(check_op_order(i.graph))
+ self.assertAllEqual(42, result)
+
+ def testWhile_5(self):
+ with self.test_session():
+
+ def compute(i, c, o):
+ c = tf.slice(x, tf.expand_dims(i, 0), [1])
+ o = tf.concat(0, [o, c])
+ i = tf.add(i, 1)
+ return [i, c, o]
+
+ i = tf.convert_to_tensor(0)
+ c = tf.convert_to_tensor(0)
+ o = tf.convert_to_tensor([0])
+ x = tf.convert_to_tensor([1, 2, 3, 4, 5, 6])
+ s = tf.size(x)
+ r = control_flow_ops.While(
+ lambda i, c, o: tf.less(i, s), compute, [i, c, o])
+ result = r[2].eval()
+ self.assertTrue(check_op_order(i.graph))
+ self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+
+ def _testWhile_Gpu_1(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ n = tf.constant(1.0)
+ c = lambda x: tf.less(x, 10.0)
+ b = lambda x: tf.add(x, 1.0)
+ r = control_flow_ops.While(c, b, [n])
+
+ result = r.eval()
+ self.assertEqual(10.0, result)
+
+ def testWhile_Gpu_1(self):
+ self._testWhile_Gpu_1(use_gpu=False)
+ self._testWhile_Gpu_1(use_gpu=True)
+
+ def _testWhile_Gpu_2(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ n = tf.constant(1.0)
+ c = lambda x: tf.less(x, 10.0)
+ def b(x):
+ with tf.device("/cpu:0"):
+ return tf.add(x, 1.0)
+ r = control_flow_ops.While(c, b, [n])
+
+ result = r.eval()
+ self.assertEqual(10.0, result)
+
+ def testWhile_Gpu_2(self):
+    self._testWhile_Gpu_2(use_gpu=False)
+    self._testWhile_Gpu_2(use_gpu=True)
+
+ def testWhileWithControl_1(self):
+ with self.test_session():
+ n = tf.constant(0)
+ r = tf.constant(0)
+ condition = lambda n_, r_: tf.less(n_, 10)
+
+ def body(n_, r_):
+ n_ = tf.add(n_, 1)
+ with r_.graph.control_dependencies([r_]):
+ r_ = tf.constant(12)
+ return [n_, r_]
+
+ res = control_flow_ops.While(condition,
+ body,
+ [n, r],
+ parallel_iterations=1)
+ result = res[1].eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertAllEqual(12, result)
+
+ def testWhileWithControl_2(self):
+ with self.test_session():
+ r = tf.constant(0)
+ condition = lambda r_: tf.less(r_, 10)
+
+ def body(r_):
+ with r_.graph.control_dependencies([r_]):
+ r_ = tf.constant(12)
+ return [r_]
+
+ res = control_flow_ops.While(condition, body, [r], parallel_iterations=1)
+ result = res.eval()
+ self.assertTrue(check_op_order(r.graph))
+ self.assertAllEqual(12, result)
+
+ def testCondWhile_1(self):
+ with self.test_session():
+ n = tf.convert_to_tensor(0, name="n")
+ c = lambda x: tf.less(x, 10)
+ b = lambda x: tf.add(x, 1)
+ r = control_flow_ops.cond(tf.less(0, 1),
+ lambda: control_flow_ops.While(c, b, [n]),
+ lambda: n)
+
+ result = r.eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertAllEqual(10, result)
+
+ def testCondWhile_2(self):
+ with self.test_session():
+ n = tf.convert_to_tensor(0)
+ c = lambda x: tf.less(x, 10)
+ b = lambda x: tf.add(x, 1)
+ r = control_flow_ops.cond(tf.less(1, 0), lambda: tf.add(n, 1),
+ lambda: control_flow_ops.While(c, b, [n]))
+
+ result = r.eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertAllEqual(10, result)
+
+ def testWhileCond_1(self):
+ with self.test_session():
+ i = tf.convert_to_tensor(0, name="i")
+ n = tf.convert_to_tensor(10, name="n")
+ one = tf.convert_to_tensor(1, name="one")
+ c = lambda x: tf.less(x, n)
+ b = lambda x: control_flow_ops.cond(tf.constant(True),
+ lambda: tf.add(x, one),
+ lambda: tf.sub(x, one))
+ r = control_flow_ops.While(c, b, [i])
+
+ result = r.eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertAllEqual(10, result)
+
+ def testWhileCond_2(self):
+ with self.test_session():
+ n = tf.convert_to_tensor(0, name="n")
+ c = lambda x: tf.less(x, 10)
+ b = lambda x: control_flow_ops.cond(tf.constant(True),
+ lambda: tf.add(x, 1),
+ lambda: n)
+ r = control_flow_ops.While(c, b, [n])
+
+ result = r.eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertAllEqual(10, result)
+
+ def testWhileCond_3(self):
+ with self.test_session():
+ n = tf.convert_to_tensor(0)
+ c = lambda x: tf.less(x, 10)
+ b = lambda x: control_flow_ops.cond(tf.less(0, 1),
+ lambda: tf.add(x, 1),
+ lambda: tf.sub(x, 1))
+ r = control_flow_ops.While(c, b, [n])
+
+ result = r.eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertAllEqual(10, result)
+
+ # NOTE: It is ok to have parallel_iterations > 1
+ def testWhileUpdateVariable_1(self):
+ with self.test_session():
+ select = tf.Variable([3.0, 4.0, 5.0])
+ n = tf.constant(0)
+
+ def loop_iterator(j):
+ return tf.less(j, 3)
+
+ def loop_body(j):
+ ns = tf.scatter_update(select, j, 10.0)
+ nj = tf.add(j, 1)
+ op = control_flow_ops.group(ns)
+ nj = control_flow_ops.with_dependencies([op], nj)
+ return [nj]
+
+ r = control_flow_ops.While(loop_iterator,
+ loop_body,
+ [n],
+ parallel_iterations=1)
+ self.assertTrue(check_op_order(n.graph))
+ tf.initialize_all_variables().run()
+ self.assertEqual(3, r.eval())
+ result = select.eval()
+ self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result)
+
+ def testWhileUpdateVariable_2(self):
+ with self.test_session():
+ select1 = tf.Variable([3.0, 4.0, 5.0])
+ select2 = tf.Variable([3.0, 4.0, 5.0])
+ n = tf.constant(0)
+
+ def loop_iterator(j):
+ return tf.less(j, 3)
+
+ def loop_body(j):
+ ns1 = tf.scatter_update(select1, j, 10.0)
+ ns2 = tf.scatter_update(select2, j, 10.0)
+ nj = tf.add(j, 1)
+ op = control_flow_ops.group(ns1, ns2)
+ nj = control_flow_ops.with_dependencies([op], nj)
+ return [nj]
+
+ r = control_flow_ops.While(loop_iterator,
+ loop_body,
+ [n],
+ parallel_iterations=1)
+ self.assertTrue(check_op_order(n.graph))
+ tf.initialize_all_variables().run()
+ self.assertEqual(3, r.eval())
+ result1 = select1.eval()
+ self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result1)
+ result2 = select2.eval()
+ self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result2)
+
+ def testWhileUpdateVariable_3(self):
+ with self.test_session():
+ select = tf.Variable([3.0, 4.0, 5.0])
+ n = tf.constant(0)
+
+ def loop_iterator(j, _):
+ return tf.less(j, 3)
+
+ def loop_body(j, _):
+ ns = tf.scatter_update(select, j, 10.0)
+ nj = tf.add(j, 1)
+ return [nj, ns]
+
+ r = control_flow_ops.While(loop_iterator,
+ loop_body,
+ [n, tf.identity(select)],
+ parallel_iterations=1)
+ tf.initialize_all_variables().run()
+ result = r[1].eval()
+ self.assertTrue(check_op_order(n.graph))
+ self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result)
+
+ # b/24814703
+ def testWhileUpdateVariable_4(self):
+ with self.test_session():
+ var_a = tf.Variable(0, name="a")
+ var_b = tf.Variable(0, name="b")
+ tf.initialize_all_variables().run()
+
+ c = tf.constant(0, name="c")
+ asn1 = tf.assign_add(var_a, 1, name="a_add")
+ # Loop condition
+ def pred(i):
+ return tf.less(i, 10)
+ # Loop body
+ def loop_body(i):
+ asn2 = tf.assign_add(var_b, asn1, name="b_add")
+ with tf.control_dependencies([asn2]):
+ ni = tf.add(i, 1, name="i_add")
+ return ni
+
+ lpa = control_flow_ops.While(pred, loop_body, [c],
+ parallel_iterations=1)
+
+ self.assertEqual(0, var_b.eval())
+ lpa.eval() # Run the loop
+ self.assertEqual(10, var_b.eval())
+
+ # b/24736492
+ def testWhileUpdateVariable_5(self):
+ with self.test_session():
+ # Create some variables.
+ var_a = tf.Variable(0, name="a")
+ var_b = tf.Variable(0, name="b")
+ tf.initialize_all_variables().run()
+
+ # Change condition to check var_b
+ def pred(i):
+ return tf.less(var_b, 10)
+
+ # Change body to increment var_b
+ def loop_body(i):
+ asn1 = tf.assign_add(var_a, tf.constant(1), name="a_add")
+ asn2 = tf.assign_add(var_b, tf.constant(1), name="b_add")
+ with tf.control_dependencies([asn1, asn2]):
+ inc_b = tf.identity(var_b)
+ return inc_b
+
+ lpa = control_flow_ops.While(pred, loop_body, [var_b], 1, name="loop")
+
+ self.assertEqual(0, var_b.eval())
+ lpa.eval() # Run the loop
+ self.assertEqual(10, var_a.eval())
+ self.assertEqual(10, var_b.eval())
+
+ def testWhileQueue_1(self):
+ with self.test_session():
+ q = tf.FIFOQueue(-1, tf.int32)
+ i = tf.constant(0)
+
+ def c(i):
+ return tf.less(i, 10)
+
+ def b(i):
+ ni = tf.add(i, 1)
+ ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni)
+ return ni
+
+ r = control_flow_ops.While(c, b, [i], parallel_iterations=1)
+ self.assertEqual([10], r.eval())
+ for i in xrange(10):
+ self.assertEqual([i], q.dequeue().eval())
+
+ def testFold_1(self):
+ with self.test_session():
+ elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ r = control_flow_ops.fold(
+ lambda a, x: tf.mul(tf.add(a, x), 2), elems, [1])
+ result = r.eval()
+ self.assertTrue(check_op_order(elems.graph))
+ self.assertAllEqual(np.array([208]), result)
+
+ def testFold_2(self):
+ with self.test_session():
+ elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
+ ten = tf.convert_to_tensor(10)
+
+ def compute(a, x):
+ r = tf.mul(x, ten)
+ return tf.add(a, r)
+
+ r = control_flow_ops.fold(compute, elems, [1])
+ result = r.eval()
+ self.assertTrue(check_op_order(elems.graph))
+ self.assertAllEqual([201], result)
+
+ def testOneValueCond(self):
+ with self.test_session():
+ c = tf.placeholder(tf.int32, shape=[])
+ one = tf.convert_to_tensor(1, name="one")
+ two = tf.convert_to_tensor(2, name="two")
+ p = tf.greater_equal(c, 1)
+ i = control_flow_ops.cond(p, lambda: one, lambda: two)
+ self.assertTrue(isinstance(i, tf.Tensor))
+
+ # True case: c = 2 is >= 1
+ self.assertEqual([1], i.eval(feed_dict={c: 2}))
+
+ # False case: c = 0 is not >= 1
+ self.assertEqual([2], i.eval(feed_dict={c: 0}))
+
+ def testExampleCond(self):
+ with self.test_session():
+ x = tf.convert_to_tensor([-2.0, 2.0], name="x")
+ d = tf.placeholder(tf.int32, shape=[])
+
+ def l2():
+ return tf.sqrt(tf.reduce_sum(tf.square(x)))
+
+ def l1():
+ return tf.reduce_sum(tf.abs(x))
+
+ i = control_flow_ops.cond(tf.equal(d, 2), l2, l1)
+ self.assertEqual(4.0, i.eval(feed_dict={d: 1}))
+ self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
+
+ def testOneOpCond(self):
+ with self.test_session():
+ v = tf.Variable(0)
+ c = tf.convert_to_tensor(0)
+ one = tf.convert_to_tensor(1)
+ two = tf.convert_to_tensor(2)
+ p = tf.greater_equal(c, 1)
+
+ def a():
+ return tf.assign(v, one)
+
+ def b():
+ return tf.assign(v, two)
+
+ i = control_flow_ops.cond(p, a, b)
+ self.assertTrue(isinstance(i, tf.Tensor))
+ tf.initialize_all_variables().run()
+
+ self.assertEqual(0, v.eval())
+
+ # True case: c = 2 is >= 1, v is set to 1.
+ self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
+ self.assertEqual(1, v.eval())
+
+ # False case: c = 0 is not >= 1, v is set to 2.
+ self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
+ self.assertEqual(2, v.eval())
+
+ def testWithOpsDependencies(self):
+ with self.test_session() as sess:
+ v = tf.Variable(0.0)
+ c = tf.constant(10)
+
+ # Fetching v directly will result in an uninitialized error
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ sess.run([c, v])
+
+ # Use a control dependency to ensure init_variable is run
+ # while asking for c
+ real_v = control_flow_ops.with_dependencies(name="real_tensor",
+ output_tensor=v,
+ dependencies=[v.initializer])
+ c_val, real_v_val = sess.run([c, real_v])
+
+      # Ensure the fetched value of 'c' is unchanged
+ self.assertAllEqual(10, c_val)
+
+ # Ensure that 'v' is initialized
+ self.assertAllClose(0.0, real_v_val)
+
+ def testWithTensorDependencies(self):
+ with self.test_session():
+ v = tf.Variable(0.0)
+ c1 = tf.constant(10)
+ c2 = tf.constant(20)
+
+ # c1_with_init_v depends on the init op for v
+ c1_with_init_v = control_flow_ops.with_dependencies(
+ name="c1_with_init_v",
+ output_tensor=c1,
+ dependencies=[v.initializer])
+ # c2_with_c1 depends on the value of c1_with_init_v
+ c2_with_c1_dep = control_flow_ops.with_dependencies(
+ name="c2_with_c1_dep",
+ output_tensor=c2,
+ dependencies=[c1_with_init_v])
+
+ # Fetching v directly will result in an uninitialized error
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ v.eval()
+
+ # Get the value of 'c2_with_c1_dep', which should cause 'v'
+ # to be initialized.
+ self.assertAllEqual(20, c2_with_c1_dep.eval())
+
+ # Ensure that 'v' is initialized
+ self.assertAllClose(0.0, v.eval())
+
+ def testWithIndexedSlicesDependencies(self):
+ with self.test_session():
+ v = tf.Variable(
+ np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
+ v_at_1 = tf.IndexedSlices(v, tf.constant([1]))
+ gather_v_at_1 = tf.gather(v_at_1.values, v_at_1.indices)
+ v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
+ v_at_1)
+ gather_v_at_1_after_init = tf.gather(
+ v_at_1_after_init.values, v_at_1_after_init.indices)
+
+ # Fetching gather_v_at_1 will result in an uninitialized error
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ gather_v_at_1.eval()
+
+ # Getting gather_v_at_1_after_init will work, and initialize v.
+ self.assertAllEqual([[10.0, 11.0]], gather_v_at_1_after_init.eval())
+
+ # Double check that 'v' is initialized
+ self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], v.eval())
+
+ def testDependenciesDevice(self):
+ with tf.Graph().as_default():
+ # device set on tensor => same device on dep.
+ with tf.device("/job:ps"):
+ vd = tf.Variable([0.0])
+ with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
+ self.assertTrue("/job:ps" in with_vd_dep.device)
+
+ # No device set on tensor => no device on dep.
+ vnod = tf.Variable([0.0])
+ with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
+ vnod)
+ self.assertEquals(None, with_vnod_dep.device)
+
+ # device set on tensor, default device on graph => default device on dep.
+ vdef = tf.Variable([0.0])
+ with tf.device("/job:worker/gpu:1"):
+ with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
+ vdef)
+ self.assertEquals("/job:worker/gpu:1", with_vdef_dep.device)
+
+ def testGroup(self):
+ with self.test_session() as sess:
+ v1 = tf.Variable([0.0])
+ v2 = tf.Variable([1.0])
+
+ # Group init1 and init2 and run.
+ init = control_flow_ops.group(v1.initializer, v2.initializer)
+ # Fetching v1 directly will result in an uninitialized error
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ v1.eval()
+
+ # Runs "init" before fetching v1 and v2.
+ init.run()
+ v1_val, v2_val = sess.run([v1, v2])
+
+ # Ensure that v1 and v2 are initialized
+ self.assertAllClose([0.0], v1_val)
+ self.assertAllClose([1.0], v2_val)
+
+ def testMergeShapes(self):
+ # All inputs unknown.
+ p1 = tf.placeholder(tf.float32)
+ p2 = tf.placeholder(tf.float32)
+ p3 = tf.placeholder(tf.float32)
+ m, index = control_flow_ops.merge([p1, p2, p3])
+ self.assertIs(None, m.get_shape().ndims)
+ self.assertEqual([], index.get_shape())
+
+ # All inputs known but different.
+ p1 = tf.placeholder(tf.float32, shape=[1, 2])
+ p2 = tf.placeholder(tf.float32, shape=[2, 1])
+ m, index = control_flow_ops.merge([p1, p2])
+ self.assertIs(None, m.get_shape().ndims)
+ self.assertEqual([], index.get_shape())
+
+ # All inputs known but same.
+ p1 = tf.placeholder(tf.float32, shape=[1, 2])
+ p2 = tf.placeholder(tf.float32, shape=[1, 2])
+ m, index = control_flow_ops.merge([p1, p2])
+ self.assertEqual([1, 2], m.get_shape())
+ self.assertEqual([], index.get_shape())
+
+ # Possibly the same but not guaranteed.
+ p1 = tf.placeholder(tf.float32, shape=[1, 2])
+ p2 = tf.placeholder(tf.float32)
+ p2.set_shape([None, 2])
+ m, index = control_flow_ops.merge([p1, p2])
+ self.assertIs(None, m.get_shape().ndims)
+ self.assertEqual([], index.get_shape())
+
+ def testRefSelect(self):
+ index = tf.placeholder(tf.int32)
+
+ # All inputs unknown.
+ p1 = tf.placeholder(tf.float32_ref)
+ p2 = tf.placeholder(tf.float32_ref)
+ p3 = tf.placeholder(tf.float32_ref)
+ s = control_flow_ops.ref_select(index, [p1, p2, p3])
+ self.assertIs(None, s.get_shape().ndims)
+
+ # All inputs known but different.
+ p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
+ p2 = tf.placeholder(tf.float32_ref, shape=[2, 1])
+ s = control_flow_ops.ref_select(index, [p1, p2])
+ self.assertIs(None, s.get_shape().ndims)
+
+ # All inputs known but same.
+ p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
+ p2 = tf.placeholder(tf.float32_ref, shape=[1, 2])
+ s = control_flow_ops.ref_select(index, [p1, p2])
+ self.assertEqual([1, 2], s.get_shape())
+
+ # Possibly the same but not guaranteed.
+ p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
+ p2 = tf.placeholder(tf.float32_ref)
+ p2.set_shape([None, 2])
+ s = control_flow_ops.ref_select(index, [p1, p2])
+ self.assertEqual(None, s.get_shape())
+
+
+class TupleTest(tf.test.TestCase):
+
+ def testTensors(self):
+ for v1_first in [True, False]:
+ with self.test_session():
+ v1 = tf.Variable([1.0])
+ add1 = tf.add(
+ control_flow_ops.with_dependencies([v1.initializer], v1),
+ 2.0)
+ v2 = tf.Variable([10.0])
+ add2 = tf.add(control_flow_ops.with_dependencies([v2.initializer],
+ v2),
+ 20.0)
+ t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
+
+ # v1 is not initialized.
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ v1.eval()
+
+ # v2 is not initialized.
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ v2.eval()
+
+ if v1_first:
+ # Getting t1 initializes v2.
+ self.assertAllClose([3.0], t1.eval())
+ self.assertAllClose([10.0], v2.eval())
+ else:
+ # Getting t2 initializes v1.
+ self.assertAllClose([30.0], t2.eval())
+ self.assertAllClose([1.0], v1.eval())
+
+ def testIndexedSlices(self):
+ for v1_first in [True, False]:
+ with self.test_session():
+ v1 = tf.Variable(
+ np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
+ np.float32))
+ v1_at_1 = tf.IndexedSlices(
+ control_flow_ops.with_dependencies([v1.initializer], v1),
+ tf.constant([1]))
+
+ v2 = tf.Variable(
+ np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
+ np.float32))
+ v2_at_1 = tf.IndexedSlices(
+ control_flow_ops.with_dependencies([v2.initializer], v2),
+ tf.constant([1]))
+
+ st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
+ g1 = tf.gather(st1.values, st1.indices)
+ g2 = tf.gather(st2.values, st2.indices)
+
+ # v1 is not initialized.
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ v1.eval()
+
+ # v2 is not initialized.
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ v2.eval()
+
+ if v1_first:
+ # Getting g1 initializes v2.
+ self.assertAllClose([[10.0, 11.0]], g1.eval())
+ self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
+ v2.eval())
+ else:
+ # Getting g2 initializes v1.
+ self.assertAllClose([[10.1, 11.1]], g2.eval())
+ self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
+ v1.eval())
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
new file mode 100644
index 0000000000..7f5d419c98
--- /dev/null
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -0,0 +1,1009 @@
+"""Functional tests for convolutional operations."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+def GetInceptionShapes():
+ """Iterator for the convolution shapes used in the Inception 2015 model.
+
+ Yields:
+ Tuple (input_size, filter_size, out_size, stride, padding), the convolution
+ parameters of Inception layers.
+ """
+ input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384],
+ [4, 8, 8, 2048], [4, 8, 8, 448], [4, 8, 8, 2048],
+ [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 1760],
+ [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760],
+ [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 1248],
+ [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224],
+ [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 1216],
+ [4, 17, 17, 1216], [4, 17, 17, 224], [4, 17, 17, 192],
+ [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152],
+ [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 1152],
+ [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
+ [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128],
+ [4, 17, 17, 768], [4, 17, 17, 128], [4, 17, 17, 128],
+ [4, 17, 17, 768], [4, 17, 17, 768], [4, 35, 35, 96],
+ [4, 35, 35, 288], [4, 35, 35, 64], [4, 35, 35, 288],
+ [4, 35, 35, 256], [4, 35, 35, 48], [4, 35, 35, 256],
+ [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192],
+ [4, 35, 35, 192], [4, 73, 73, 64], [4, 73, 73, 64],
+ [4, 147, 147, 24]]
+ filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384],
+ [1, 1, 2048, 192], [3, 3, 448, 384], [1, 1, 2048, 320],
+ [1, 1, 2048, 448], [1, 1, 2048, 384], [1, 1, 1760, 384],
+ [1, 1, 1760, 192], [1, 1, 1760, 448], [1, 1, 1760, 320],
+ [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192],
+ [3, 3, 128, 320], [1, 1, 1248, 128], [1, 3, 224, 224],
+ [3, 1, 192, 256], [1, 3, 192, 256], [1, 1, 1216, 192],
+ [1, 1, 1216, 96], [3, 1, 224, 224], [3, 3, 192, 224],
+ [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128],
+ [3, 1, 192, 192], [3, 3, 160, 192], [1, 1, 1152, 160],
+ [1, 1, 1024, 128], [1, 3, 128, 192], [1, 1, 1024, 160],
+ [3, 1, 128, 192], [1, 1, 1024, 256], [3, 1, 128, 128],
+ [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128],
+ [1, 1, 768, 128], [1, 1, 768, 320], [3, 3, 96, 96],
+ [3, 3, 288, 384], [3, 3, 64, 96], [1, 1, 288, 64],
+ [1, 1, 256, 64], [5, 5, 48, 64], [1, 1, 256, 48],
+ [3, 3, 96, 96], [1, 1, 192, 32], [1, 1, 192, 64],
+ [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64, 64],
+ [1, 1, 24, 64]]
+ out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384],
+ [4, 8, 8, 192], [4, 8, 8, 384], [4, 8, 8, 320],
+ [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384],
+ [4, 8, 8, 192], [4, 8, 8, 448], [4, 8, 8, 320],
+ [4, 8, 8, 192], [4, 17, 17, 192], [4, 17, 17, 192],
+ [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224],
+ [4, 17, 17, 256], [4, 17, 17, 256], [4, 17, 17, 192],
+ [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224],
+ [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 128],
+ [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 160],
+ [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160],
+ [4, 17, 17, 192], [4, 17, 17, 256], [4, 17, 17, 128],
+ [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128],
+ [4, 17, 17, 128], [4, 17, 17, 320], [4, 17, 17, 96],
+ [4, 17, 17, 384], [4, 35, 35, 96], [4, 35, 35, 64],
+ [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48],
+ [4, 35, 35, 96], [4, 35, 35, 32], [4, 35, 35, 64],
+ [4, 35, 35, 48], [4, 71, 71, 192], [4, 73, 73, 64],
+ [4, 147, 147, 64]]
+ strides = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ # pylint: disable=invalid-name
+ VALID = "VALID"
+ SAME = "SAME"
+ # pylint: enable=invalid-name
+ paddings = [SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, VALID, SAME, SAME, VALID,
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, VALID, VALID, VALID]
+ for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes,
+ strides, paddings):
+ yield i, f, o, s, p
+
+
+class Conv2DTest(tf.test.TestCase):
+
+ def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride,
+ padding, use_gpu):
+ """Verifies the output values of the convolution function.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ stride: Stride.
+ padding: Padding type.
+      use_gpu: True if the operations should be run on GPU.
+
+    Returns:
+      Symbolic tensor value that can be used to execute the computation.
+    """
+ total_size_1 = 1
+ total_size_2 = 1
+ for s in tensor_in_sizes:
+ total_size_1 *= s
+ for s in filter_in_sizes:
+ total_size_2 *= s
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
+ x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t1 = tf.constant(x1, shape=tensor_in_sizes)
+ t2 = tf.constant(x2, shape=filter_in_sizes)
+ conv = tf.nn.conv2d(t1, t2,
+ strides=[1, stride, stride, 1],
+ padding=padding)
+ return conv
+
+ def _CompareFwdValues(self, tensor_in_sizes, filter_in_sizes,
+ stride, padding):
+ """Verifies that CPU and GPU produce the same values.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ stride: Stride.
+ padding: Padding type.
+ """
+ x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
+ x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
+ def _SetupVal(use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ t1 = tf.constant(x1, shape=tensor_in_sizes)
+ t2 = tf.constant(x2, shape=filter_in_sizes)
+ conv = tf.nn.conv2d(t1, t2, strides=[1, stride, stride, 1],
+ padding=padding)
+ return conv
+ gpu_tensor = _SetupVal(use_gpu=True)
+ cpu_tensor = _SetupVal(use_gpu=False)
+ with self.test_session() as sess:
+ (gpu_value, cpu_value) = sess.run([gpu_tensor, cpu_tensor])
+ self.assertAllClose(cpu_value, gpu_value, rtol=1e-5, atol=1e-5)
+
+ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride,
+ padding, expected):
+ tensor_cpu = self._SetupValuesForDevice(tensor_in_sizes, filter_in_sizes,
+ stride, padding, use_gpu=False)
+ tensor_gpu = self._SetupValuesForDevice(tensor_in_sizes, filter_in_sizes,
+ stride, padding, use_gpu=True)
+ with self.test_session() as sess:
+ tensors = [tensor_cpu, tensor_gpu]
+ (value_cpu, value_gpu) = sess.run(tensors)
+ values = [value_cpu, value_gpu]
+ for i in range(len(tensors)):
+ conv = tensors[i]
+ value = values[i]
+ print "expected = ", expected
+ print "actual = ", value
+ self.assertArrayNear(expected, np.ravel(value), 1e-5)
+ self.assertShapeEqual(value, conv)
+
+ def testConv2D1x1Filter(self):
+ expected_output = [30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0,
+ 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0,
+ 312.0]
+ self._VerifyValues(tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ stride=1, padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2Filter(self):
+ # The outputs are computed using third_party/py/IPython/notebook.
+ expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
+ self._VerifyValues(tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ stride=1, padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2Filter(self):
+ # The outputs are computed using third_party/py/IPython/notebook.
+ expected_output = [231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0,
+ 765.0, 840.0, 843.0, 936.0, 1029.0]
+ self._VerifyValues(tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 2, 3, 3],
+ stride=1, padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2(self):
+ expected_output = [2271.0, 2367.0, 2463.0]
+ self._VerifyValues(tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ stride=2, padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2Same(self):
+ expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
+ self._VerifyValues(tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ stride=2, padding="SAME",
+ expected=expected_output)
+
+ # Testing for backprops
+ def _RunAndVerifyBackpropInput(self, input_sizes, filter_sizes, output_sizes,
+ stride, padding, expected, use_gpu):
+ total_output_size = 1
+ total_filter_size = 1
+ for s in output_sizes:
+ total_output_size *= s
+ for s in filter_sizes:
+ total_filter_size *= s
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x1 = [f * 1.0 for f in range(1, total_filter_size + 1)]
+ x2 = [f * 1.0 for f in range(1, total_output_size + 1)]
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t0 = tf.constant(input_sizes, shape=[len(input_sizes)])
+ t1 = tf.constant(x1, shape=filter_sizes)
+ t2 = tf.constant(x2, shape=output_sizes)
+ conv = tf.nn.conv2d_backprop_input(t0, t1, t2,
+ strides=[1, stride, stride, 1],
+ padding=padding)
+ # "values" consists of two tensors for two backprops
+ value = sess.run(conv)
+ self.assertShapeEqual(value, conv)
+ print "expected = ", expected
+ print "actual = ", value
+ self.assertArrayNear(expected, value.flatten(), 1e-5)
+
+ def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
+ stride, padding):
+ x1 = np.random.rand(*filter_sizes).astype(np.float32)
+ x2 = np.random.rand(*output_sizes).astype(np.float32)
+ def _GetVal(use_gpu):
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t0 = tf.constant(input_sizes, shape=[len(input_sizes)])
+ t1 = tf.constant(x1, shape=filter_sizes)
+ t2 = tf.constant(x2, shape=output_sizes)
+ conv = tf.nn.conv2d_backprop_input(t0, t1, t2,
+ strides=[1, stride, stride, 1],
+ padding=padding)
+ ret = conv.eval()
+ self.assertShapeEqual(ret, conv)
+ return ret
+ gpu_value = _GetVal(use_gpu=True)
+ cpu_value = _GetVal(use_gpu=False)
+ self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
+
+ def testConv2D2x2Depth1ValidBackpropInput(self):
+ expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0]
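+    # Hand derivation (illustrative): the filter is [[1, 2], [3, 4]] and the
+    # output gradient is [1, 2] over the two VALID output positions, so the
+    # input gradient is 1 * [[1, 2, 0], [3, 4, 0]] + 2 * [[0, 1, 2], [0, 3, 4]]
+    # = [[1, 4, 4], [3, 10, 8]].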
+ self._RunAndVerifyBackpropInput(input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 2, 1],
+ stride=1, padding="VALID",
+ expected=expected_output, use_gpu=False)
+ self._RunAndVerifyBackpropInput(input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 2, 1],
+ stride=1, padding="VALID",
+ expected=expected_output, use_gpu=True)
+
+ def testConv2D2x2Depth3ValidBackpropInput(self):
+ expected_output = [14.0, 32.0, 50.0,
+ 100.0, 163.0, 226.0,
+ 167.0, 212.0, 257.0,
+ 122.0, 140.0, 158.0,
+ 478.0, 541.0, 604.0,
+ 437.0, 482.0, 527.0]
+ self._RunAndVerifyBackpropInput(input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ output_sizes=[1, 1, 2, 3],
+ stride=1, padding="VALID",
+ expected=expected_output, use_gpu=False)
+ self._RunAndVerifyBackpropInput(input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ output_sizes=[1, 1, 2, 3],
+ stride=1, padding="VALID",
+ expected=expected_output, use_gpu=True)
+
+ # Testing for backprops
+ def _RunAndVerifyBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
+ stride, padding, expected, use_gpu):
+ total_input_size = 1
+ total_output_size = 1
+ for s in input_sizes:
+ total_input_size *= s
+ for s in output_sizes:
+ total_output_size *= s
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x0 = [f * 1.0 for f in range(1, total_input_size + 1)]
+ x2 = [f * 1.0 for f in range(1, total_output_size + 1)]
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t0 = tf.constant(x0, shape=input_sizes)
+ t1 = tf.constant(filter_sizes, shape=[len(filter_sizes)])
+ t2 = tf.constant(x2, shape=output_sizes)
+ conv = tf.nn.conv2d_backprop_filter(t0, t1, t2,
+ strides=[1, stride, stride, 1],
+ padding=padding)
+ value = sess.run(conv)
+ self.assertShapeEqual(value, conv)
+ print "expected = ", expected
+ print "actual = ", value
+ self.assertArrayNear(expected, value.flatten(), 1e-5)
+
+ def _CompareBackFilter(self, input_sizes, filter_sizes, output_sizes,
+ stride, padding):
+ x0 = np.random.rand(*input_sizes).astype(np.float32)
+ x2 = np.random.rand(*output_sizes).astype(np.float32)
+ def _GetVal(use_gpu):
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t0 = tf.constant(x0, shape=input_sizes)
+ t1 = tf.constant(filter_sizes, shape=[len(filter_sizes)])
+ t2 = tf.constant(x2, shape=output_sizes)
+ conv = tf.nn.conv2d_backprop_filter(t0, t1, t2,
+ strides=[1, stride, stride, 1],
+ padding=padding)
+ ret = conv.eval()
+ self.assertShapeEqual(ret, conv)
+ return ret
+ gpu_value = _GetVal(use_gpu=True)
+ cpu_value = _GetVal(use_gpu=False)
+ self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
+
+ def testConv2D2x2Depth1ValidBackpropFilter(self):
+ expected = [5.0, 8.0, 14.0, 17.0]
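+    # Hand derivation (illustrative): the input is [[1, 2, 3], [4, 5, 6]] and
+    # the output gradient is [1, 2], so each filter tap accumulates
+    # input_patch . output_grad, e.g. grad[0, 0] = 1*1 + 2*2 = 5 and
+    # grad[1, 1] = 5*1 + 6*2 = 17.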
+ self._RunAndVerifyBackpropFilter(input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 2, 1],
+ stride=1, padding="VALID",
+ expected=expected, use_gpu=False)
+ self._RunAndVerifyBackpropFilter(input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 2, 1],
+ stride=1, padding="VALID",
+ expected=expected, use_gpu=True)
+
+ def testConv2D2x2Depth3ValidBackpropFilter(self):
+ expected = [17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0,
+ 32.0, 43.0, 54.0, 37.0, 50.0, 63.0, 42.0, 57.0, 72.0,
+ 62.0, 85.0, 108.0, 67.0, 92.0, 117.0, 72.0, 99.0, 126.0,
+ 77.0, 106.0, 135.0, 82.0, 113.0, 144.0, 87.0, 120.0, 153.0]
+ self._RunAndVerifyBackpropFilter(input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ output_sizes=[1, 1, 2, 3],
+ stride=1, padding="VALID",
+ expected=expected, use_gpu=False)
+ self._RunAndVerifyBackpropFilter(input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ output_sizes=[1, 1, 2, 3],
+ stride=1, padding="VALID",
+ expected=expected, use_gpu=True)
+
+ # Gradient checkers
+ def ConstructAndTestGradient(self, batch, input_rows, input_cols, filter_rows,
+ filter_cols, in_depth, out_depth, stride,
+ padding, test_input, use_gpu):
+ input_shape = [batch, input_rows, input_cols, in_depth]
+ filter_shape = [filter_rows, filter_cols, in_depth, out_depth]
+ # TODO(yangke): re-factor the computation of output shape.
+ if padding == "VALID":
+ output_rows = int(math.ceil((input_rows - filter_rows + 1.0) / stride))
+ output_cols = int(math.ceil((input_cols - filter_cols + 1.0) / stride))
+ else:
+ output_rows = int(math.ceil(float(input_rows) / stride))
+ output_cols = int(math.ceil(float(input_cols) / stride))
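+    # Example (illustrative): with input_rows=5, filter_rows=3, stride=2 this
+    # gives ceil((5 - 3 + 1) / 2) = 2 rows for VALID and ceil(5 / 2) = 3 for SAME.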
+ output_shape = [batch, output_rows, output_cols, out_depth]
+ input_size = 1
+ for x in input_shape:
+ input_size *= x
+ filter_size = 1
+ for x in filter_shape:
+ filter_size *= x
+ input_data = [x * 1.0 / input_size for x in range(0, input_size)]
+ filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
+ with self.test_session(use_gpu=use_gpu):
+ # Conv2DGrad functions are not compiled for double due to
+ # a problem in the way Eigen's Conv2DGrad works for double.
+ # So we disable the DOUBLE path. We should re-enable this
+ # when double support returns for CPU and/or GPU.
+ # data_type = tf.float64
+ # tolerance = 1e-8
+
+ data_type = tf.float32
+ tolerance = 0.002
+
+ input_tensor = tf.constant(input_data, shape=input_shape,
+ dtype=data_type, name="input")
+ filter_tensor = tf.constant(filter_data, shape=filter_shape,
+ dtype=data_type, name="filter")
+ conv = tf.nn.conv2d(input_tensor, filter_tensor,
+ [1, stride, stride, 1], padding,
+ name="conv")
+ self.assertEqual(output_shape, conv.get_shape())
+ if test_input:
+ err = gc.ComputeGradientError(input_tensor, input_shape,
+ conv, output_shape)
+ else:
+ err = gc.ComputeGradientError(filter_tensor, filter_shape,
+ conv, output_shape)
+ print "conv_2d gradient error = ", err
+ self.assertLess(err, tolerance)
+
+ def testInputGradientValidPaddingStrideOne(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=5,
+ input_cols=4,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="VALID",
+ test_input=True,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=5,
+ input_cols=4,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="VALID",
+ test_input=True,
+ use_gpu=True)
+
+ def testFilterGradientValidPaddingStrideOne(self):
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="VALID",
+ test_input=False,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="VALID",
+ test_input=False,
+ use_gpu=True)
+
+ def testInputGradientValidPaddingStrideTwo(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=4,
+ input_cols=5,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="VALID",
+ test_input=True,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=4,
+ input_cols=5,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="VALID",
+ test_input=True,
+ use_gpu=True)
+
+ def testFilterGradientValidPaddingStrideTwo(self):
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="VALID",
+ test_input=False,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="VALID",
+ test_input=False,
+ use_gpu=True)
+
+ def testInputGradientValidPaddingStrideThree(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=7,
+ input_cols=6,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=4,
+ out_depth=5,
+ stride=3,
+ padding="VALID",
+ test_input=True,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=7,
+ input_cols=6,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=4,
+ out_depth=5,
+ stride=3,
+ padding="VALID",
+ test_input=True,
+ use_gpu=True)
+
+ def testFilterGradientValidPaddingStrideThree(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=8,
+ input_cols=7,
+ filter_rows=4,
+ filter_cols=4,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="VALID",
+ test_input=False,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=8,
+ input_cols=7,
+ filter_rows=4,
+ filter_cols=4,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="VALID",
+ test_input=False,
+ use_gpu=True)
+
+ def testInputGradientSamePaddingStrideOne(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=7,
+ input_cols=6,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="SAME",
+ test_input=True,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=7,
+ input_cols=6,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="SAME",
+ test_input=True,
+ use_gpu=True)
+
+ def testFilterGradientSamePaddingStrideOne(self):
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="SAME",
+ test_input=False,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="SAME",
+ test_input=False,
+ use_gpu=True)
+
+ def testInputGradientSamePaddingStrideTwo(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=5,
+ input_cols=4,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=3,
+ out_depth=3,
+ stride=2,
+ padding="SAME",
+ test_input=True,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=5,
+ input_cols=4,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=3,
+ out_depth=3,
+ stride=2,
+ padding="SAME",
+ test_input=True,
+ use_gpu=True)
+
+ def testFilterGradientSamePaddingStrideTwo(self):
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="SAME",
+ test_input=False,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=4,
+ input_rows=6,
+ input_cols=5,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="SAME",
+ test_input=False,
+ use_gpu=True)
+
+ def testInputGradientSamePaddingStrideThree(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=7,
+ input_cols=6,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=4,
+ out_depth=5,
+ stride=3,
+ padding="SAME",
+ test_input=True,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=7,
+ input_cols=6,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=4,
+ out_depth=5,
+ stride=3,
+ padding="SAME",
+ test_input=True,
+ use_gpu=True)
+
+ def testFilterGradientSamePaddingStrideThree(self):
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=8,
+ input_cols=7,
+ filter_rows=4,
+ filter_cols=4,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="SAME",
+ test_input=False,
+ use_gpu=False)
+ self.ConstructAndTestGradient(
+ batch=2,
+ input_rows=8,
+ input_cols=7,
+ filter_rows=4,
+ filter_cols=4,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="SAME",
+ test_input=False,
+ use_gpu=True)
+
+ def testShapeFunctionEdgeCases(self):
+ # All shapes unknown.
+ c1 = tf.nn.conv2d(tf.placeholder(tf.float32),
+ tf.placeholder(tf.float32),
+ strides=[1, 1, 1, 1], padding="SAME")
+ self.assertEqual([None, None, None, None], c1.get_shape().as_list())
+
+ # Incorrect input shape.
+ with self.assertRaises(ValueError):
+ tf.nn.conv2d(tf.placeholder(tf.float32, shape=[1, 3]),
+ tf.placeholder(tf.float32),
+ strides=[1, 1, 1, 1], padding="SAME")
+
+ # Incorrect filter shape.
+ with self.assertRaises(ValueError):
+ tf.nn.conv2d(tf.placeholder(tf.float32),
+ tf.placeholder(tf.float32, shape=[1, 3]),
+ strides=[1, 1, 1, 1], padding="SAME")
+
+ # Depth mismatch.
+ with self.assertRaises(ValueError):
+ tf.nn.conv2d(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ tf.placeholder(tf.float32,
+ shape=[4, 4, 2, 2]),
+ strides=[1, 1, 1, 1], padding="SAME")
+
+ # Illegal strides.
+ with self.assertRaisesRegexp(ValueError, "strides in the batch and depth"):
+ tf.nn.conv2d(tf.placeholder(tf.float32),
+ tf.placeholder(tf.float32),
+ strides=[2, 1, 1, 1], padding="SAME")
+ with self.assertRaisesRegexp(ValueError, "strides in the batch and depth"):
+ tf.nn.conv2d(tf.placeholder(tf.float32),
+ tf.placeholder(tf.float32),
+ strides=[1, 1, 1, 2], padding="SAME")
+
+ # Filter larger than input.
+ with self.assertRaisesRegexp(ValueError,
+ "filter must not be larger than the input"):
+ tf.nn.conv2d(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ tf.placeholder(tf.float32,
+ shape=[20, 21, 3, 2]),
+ strides=[1, 1, 1, 1], padding="SAME")
+ with self.assertRaisesRegexp(ValueError,
+ "filter must not be larger than the input"):
+ tf.nn.conv2d(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ tf.placeholder(tf.float32,
+ shape=[21, 20, 3, 2]),
+ strides=[1, 1, 1, 1], padding="SAME")
+
+ # Stride larger than filter.
+ with self.assertRaisesRegexp(ValueError,
+ "stride must be less than or equal to filter"):
+ tf.nn.conv2d(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ tf.placeholder(tf.float32,
+ shape=[4, 5, 3, 2]),
+ strides=[1, 5, 5, 1], padding="SAME")
+ with self.assertRaisesRegexp(ValueError,
+ "stride must be less than or equal to filter"):
+ tf.nn.conv2d(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ tf.placeholder(tf.float32,
+ shape=[5, 4, 3, 2]),
+ strides=[1, 5, 5, 1], padding="SAME")
+
+ # Invalid rectangular stride.
+ with self.assertRaisesRegexp(ValueError,
+ "equal length strides in the row and column"):
+ tf.nn.conv2d(tf.placeholder(tf.float32),
+ tf.placeholder(tf.float32),
+ strides=[1, 3, 7, 1], padding="SAME")
+
+
+# This is only a very simple test. More comprehensive tests live in
+# //learning/dist_belief/experimental/brain_compatibility/conv_nn_test.py
+# where we compare the numeric results of the depthwise conv op with the
+# depthwise weighted sum transformer in dist_belief.
+class DepthwiseConv2DTest(tf.test.TestCase):
+
+ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride,
+ padding, expected):
+ """Verifies the output values of the convolution function.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [filter_rows, filter_cols, input_depth, depth_multiplier].
+ stride: Stride.
+ padding: Padding type.
+ expected: An array containing the expected operation outputs.
+ """
+ total_size_1 = 1
+ total_size_2 = 1
+ for s in tensor_in_sizes:
+ total_size_1 *= s
+ for s in filter_in_sizes:
+ total_size_2 *= s
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
+ x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
+ with self.test_session() as sess:
+ t1 = tf.constant(x1, shape=tensor_in_sizes)
+ t1.set_shape(tensor_in_sizes)
+ t2 = tf.constant(x2, shape=filter_in_sizes)
+ conv = tf.nn.depthwise_conv2d(t1, t2, strides=[1, stride, stride, 1],
+ padding=padding)
+ value = sess.run(conv)
+ print "value = ", value
+ self.assertArrayNear(expected, np.ravel(value), 1e-5)
+ self.assertShapeEqual(value, conv)
+
+ def testConv2D2x2Filter(self):
+    # The inputs look like this (a 2 x 3 spatial grid, each entry of depth 2):
+ #
+ # [ (1.0, 2.0), (3.0, 4.0), ( 5.0, 6.0) ]
+ # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ]
+ # We can view this as two inputs
+ #
+ # input depth 0:
+ #
+ # [ 1.0, 3.0, 5.0 ]
+ # [ 7.0, 9.0, 11.0 ]
+ #
+ # input depth 1:
+ #
+ # [ 2.0, 4.0, 6.0 ]
+ # [ 8.0, 10.0, 12.0 ]
+ #
+ # The filter looks like this (it has two 2 x 2 patches, each generating 2
+ # depths):
+ #
+ # filter #0:
+ #
+ # [ (1.0, 3.0), ( 5.0, 7.0)]
+ # [ (9.0, 11.0), (13.0, 15.0)]
+ #
+ # filter #1:
+ #
+ # [ ( 2.0, 4.0), ( 6.0, 8.0)]
+ # [ (10.0, 12.0), (14.0, 16.0)]
+ #
+ # So the outputs are:
+ #
+ # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0)
+ # 1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196
+ # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1)
+ # 1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216
+ # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0)
+ # 2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272
+ # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1)
+ # 2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296
+ #
+ # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0)
+ # 3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252
+ # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1)
+ # 3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280
+ # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0)
+ # 4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344
+ # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1)
+ # 4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376
+ expected_output = [196, 216, 272, 296, 252, 280, 344, 376]
+ self._VerifyValues(tensor_in_sizes=[1, 2, 3, 2],
+ filter_in_sizes=[2, 2, 2, 2],
+ stride=1, padding="VALID",
+ expected=expected_output)
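+    # NumPy sketch (illustrative, not part of the original test) of the
+    # depthwise rule spelled out above: input channel c is convolved only with
+    # filter[:, :, c, :], so the output depth is in_depth * depth_multiplier.
+    #   x = np.arange(1, 13, dtype=np.float32).reshape(1, 2, 3, 2)
+    #   w = np.arange(1, 17, dtype=np.float32).reshape(2, 2, 2, 2)
+    #   [np.sum(x[0, :2, :2, c] * w[:, :, c, m])
+    #    for c in range(2) for m in range(2)]  # -> [196, 216, 272, 296]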
+
+
+class SeparableConv2DTest(tf.test.TestCase):
+
+ def _InitValues(self, sizes):
+ """Initializes values for input tensors.
+
+ Args:
+ sizes: Tensor dimensions.
+
+ Returns:
+ Tensor initialized to values.
+ """
+ total_size = 1
+ for s in sizes:
+ total_size *= s
+ x = [f * 0.5 for f in range(1, total_size + 1)]
+ return tf.constant(x, shape=sizes)
+
+ def _VerifyValues(self, tensor_in_sizes, depthwise_filter_in_sizes,
+ pointwise_filter_in_sizes, stride, padding, expected):
+ """Verifies the output values of the separable convolution function.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions.
+ depthwise_filter_in_sizes: Depthwise filter tensor dimensions.
+ pointwise_filter_in_sizes: Pointwise filter tensor dimensions.
+ stride: Stride.
+ padding: Padding type.
+ expected: An array containing the expected operation outputs.
+ """
+ with self.test_session() as sess:
+ t1 = self._InitValues(tensor_in_sizes)
+ f1 = self._InitValues(depthwise_filter_in_sizes)
+ f1.set_shape(depthwise_filter_in_sizes)
+ f2 = self._InitValues(pointwise_filter_in_sizes)
+ conv = tf.nn.separable_conv2d(t1, f1, f2, strides=[1, stride, stride, 1],
+ padding=padding)
+ value = sess.run(conv)
+ print "value = ", value
+ self.assertArrayNear(expected, np.ravel(value), 1e-5)
+ self.assertShapeEqual(value, conv)
+
+ def testSeparableConv2D(self):
+    # The output is the result of two convolutions:
+    # First a depthwise pass: tensor_in[1, 4, 4, 2] * depthwise_filter[2, 2, 2, 3],
+    # which gives an intermediate of shape [1, 4, 4, 6].
+    # Second a pointwise pass: that intermediate * pointwise_filter[1, 1, 6, 7].
+    # Per output position this costs O(in*mult*k*k + in*mult*out) rather than
+    # O(in*out*k*k) for a full k x k convolution with the same output depth.
+ expected_output = [
+ 6644.5, 6971.5, 7298.5, 7625.5, 7952.5, 8279.5, 8606.5, 8154.5, 8556.5,
+ 8958.5, 9360.5, 9762.5, 10164.5, 10566.5, 9664.5, 10141.5, 10618.5,
+ 11095.5, 11572.5, 12049.5, 12526.5, 4145.5, 4346.5, 4547.5, 4748.5,
+ 4949.5, 5150.5, 5351.5, 12684.5, 13311.5, 13938.5, 14565.5, 15192.5,
+ 15819.5, 16446.5, 14194.5, 14896.5, 15598.5, 16300.5, 17002.5, 17704.5,
+ 18406.5, 15704.5, 16481.5, 17258.5, 18035.5, 18812.5, 19589.5, 20366.5,
+ 6499.5, 6814.5, 7129.5, 7444.5, 7759.5, 8074.5, 8389.5, 18724.5,
+ 19651.5, 20578.5, 21505.5, 22432.5, 23359.5, 24286.5, 20234.5, 21236.5,
+ 22238.5, 23240.5, 24242.5, 25244.5, 26246.5, 21744.5, 22821.5, 23898.5,
+ 24975.5, 26052.5, 27129.5, 28206.5, 8853.5, 9282.5, 9711.5, 10140.5,
+ 10569.5, 10998.5, 11427.5, 5746.75, 6010.75, 6274.75, 6538.75, 6802.75,
+ 7066.75, 7330.75, 6168.75, 6452.25, 6735.75, 7019.25, 7302.75, 7586.25,
+ 7869.75, 6590.75, 6893.75, 7196.75, 7499.75, 7802.75, 8105.75, 8408.75,
+ 2036.25, 2119.5, 2202.75, 2286.0, 2369.25, 2452.5, 2535.75]
+
+ self._VerifyValues(tensor_in_sizes=[1, 4, 4, 2],
+ depthwise_filter_in_sizes=[2, 2, 2, 3],
+ pointwise_filter_in_sizes=[1, 1, 6, 7],
+ stride=1, padding="SAME",
+ expected=expected_output)
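+    # Conceptually (illustrative sketch, not the code under test), the op above
+    # factors into the two TF calls it is named after:
+    #   intermediate = tf.nn.depthwise_conv2d(t1, f1, [1, 1, 1, 1], "SAME")  # [1, 4, 4, 6]
+    #   separable = tf.nn.conv2d(intermediate, f2, [1, 1, 1, 1], "SAME")     # [1, 4, 4, 7]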
+
+
+def GetInceptionFwdTest(input_size, filter_size, stride, padding):
+ def Test(self):
+ tf.logging.info("Testing InceptionFwd %s", (input_size, filter_size,
+ stride, padding))
+ self._CompareFwdValues(input_size, filter_size, stride, padding)
+ return Test
+
+
+def GetInceptionBackInputTest(input_size, filter_size, output_size,
+ stride, padding):
+ def Test(self):
+ tf.logging.info("Testing InceptionBackInput %s",
+ (input_size, filter_size, output_size, stride, padding))
+ self._CompareBackpropInput(input_size, filter_size, output_size,
+ stride, padding)
+ return Test
+
+
+def GetInceptionBackFilterTest(input_size, filter_size, output_size,
+ stride, padding):
+ def Test(self):
+ tf.logging.info("Testing InceptionBackFilter %s",
+ (input_size, filter_size, output_size, stride, padding))
+ self._CompareBackFilter(input_size, filter_size, output_size,
+ stride, padding)
+ return Test
+
+
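+# The __main__ block below generates one named test method per Inception layer
+# shape at load time (e.g. Conv2DTest.testInceptionFwd_0), by wrapping each
+# shape tuple in one of the closures defined above and attaching it to
+# Conv2DTest with setattr.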
+if __name__ == "__main__":
+ for index, (input_size_, filter_size_, output_size_, stride_,
+ padding_) in enumerate(GetInceptionShapes()):
+ setattr(Conv2DTest, "testInceptionFwd_" + str(index),
+ GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
+ setattr(Conv2DTest, "testInceptionBackInput_" + str(index),
+ GetInceptionBackInputTest(input_size_, filter_size_, output_size_,
+ stride_, padding_))
+ setattr(Conv2DTest, "testInceptionBackFilter_" + str(index),
+ GetInceptionBackFilterTest(input_size_, filter_size_, output_size_,
+ stride_, padding_))
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
new file mode 100644
index 0000000000..22491f231a
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -0,0 +1,1187 @@
+"""Functional tests for coefficient-wise operations.
+"""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+_ADD = lambda x, y: x + y
+_SUB = lambda x, y: x - y
+_MUL = lambda x, y: x * y
+_DIV = lambda x, y: x / y
+_MOD = lambda x, y: x % y
+_NEG = lambda x: -x
+_ABS = abs
+
+_LT = lambda x, y: x < y
+_LE = lambda x, y: x <= y
+_GT = lambda x, y: x > y
+_GE = lambda x, y: x >= y
+
+_AND = lambda x, y: x & y
+_OR = lambda x, y: x | y
+_XOR = lambda x, y: x ^ y
+_INV = lambda x: ~x
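+
+# These lambdas exercise the Python operator overloads on Tensors (e.g.
+# Tensor.__add__ / __radd__), so each test below can run both the named op and
+# its operator form. For instance (illustrative):
+#   t = tf.convert_to_tensor([1.0, 2.0])
+#   _ADD(t, 1.0)   # builds the same graph op as tf.add(t, 1.0)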
+
+
+class UnaryOpTest(tf.test.TestCase):
+
+ def _compareCpu(self, x, np_func, tf_func):
+ np_ans = np_func(x)
+ with self.test_session(use_gpu=False):
+ inx = tf.convert_to_tensor(x)
+ y = tf_func(inx)
+ tf_cpu = y.eval()
+ self.assertShapeEqual(np_ans, y)
+ self.assertAllClose(np_ans, tf_cpu)
+ if x.dtype == np.float32:
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gc.ComputeGradient(inx, s, y, s, x_init_value=x)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ elif x.dtype == np.float64:
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gc.ComputeGradient(inx, s, y, s, x_init_value=x)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _compareGpu(self, x, np_func, tf_func):
+ np_ans = np_func(x)
+ with self.test_session(use_gpu=True):
+ result = tf_func(tf.convert_to_tensor(x))
+ tf_gpu = result.eval()
+ self.assertShapeEqual(np_ans, result)
+ self.assertAllClose(np_ans, tf_gpu)
+ # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+ def _compareBoth(self, x, np_func, tf_func):
+ self._compareCpu(x, np_func, tf_func)
+ self._compareGpu(x, np_func, tf_func)
+
+ def _inv(self, x):
+ return 1.0 / x
+
+ def _rsqrt(self, x):
+ return self._inv(np.sqrt(x))
+
+ def _sigmoid(self, x):
+ return 1.0 / (1.0 + np.exp(-x))
+
+ def testFloatBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
+ y = (x + .5).astype(np.float32) # no zero
+ z = (x + 15.5).astype(np.float32) # all positive
+ self._compareBoth(x, np.abs, tf.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, tf.neg)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, tf.inv)
+ self._compareBoth(x, np.square, tf.square)
+ self._compareBoth(z, np.sqrt, tf.sqrt)
+ self._compareBoth(z, self._rsqrt, tf.rsqrt)
+ self._compareBoth(x, np.exp, tf.exp)
+ self._compareBoth(z, np.log, tf.log)
+ self._compareBoth(x, np.tanh, tf.tanh)
+ self._compareBoth(x, self._sigmoid, tf.sigmoid)
+ self._compareBoth(y, np.sign, tf.sign)
+ self._compareBoth(x, np.sin, tf.sin)
+ self._compareBoth(x, np.cos, tf.cos)
+
+ def testFloatTanhEdge(self):
+ x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
+ self._compareBoth(x, np.tanh, tf.tanh)
+ x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
+ self._compareBoth(x, np.tanh, tf.tanh)
+
+ def testFloatEmpty(self):
+ x = np.empty((2, 0, 5), dtype=np.float32)
+ self._compareBoth(x, np.abs, tf.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, tf.neg)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(x, self._inv, tf.inv)
+ self._compareBoth(x, np.square, tf.square)
+ self._compareBoth(x, np.sqrt, tf.sqrt)
+ self._compareBoth(x, self._rsqrt, tf.rsqrt)
+ self._compareBoth(x, np.exp, tf.exp)
+ self._compareBoth(x, np.log, tf.log)
+ self._compareBoth(x, np.tanh, tf.tanh)
+ self._compareBoth(x, self._sigmoid, tf.sigmoid)
+ self._compareBoth(x, np.sign, tf.sign)
+ self._compareBoth(x, np.sin, tf.sin)
+ self._compareBoth(x, np.cos, tf.cos)
+
+ def testDoubleBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
+ y = (x + .5).astype(np.float64) # no zero
+ z = (x + 15.5).astype(np.float64) # all positive
+ self._compareBoth(x, np.abs, tf.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, tf.neg)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, tf.inv)
+ self._compareBoth(x, np.square, tf.square)
+ self._compareBoth(z, np.sqrt, tf.sqrt)
+ self._compareBoth(z, self._rsqrt, tf.rsqrt)
+ self._compareBoth(x, np.exp, tf.exp)
+ self._compareBoth(z, np.log, tf.log)
+ self._compareBoth(x, np.tanh, tf.tanh)
+ self._compareBoth(x, self._sigmoid, tf.sigmoid)
+ self._compareBoth(y, np.sign, tf.sign)
+ self._compareBoth(x, np.sin, tf.sin)
+ self._compareBoth(x, np.cos, tf.cos)
+
+ def testInt32Basic(self):
+ x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
+ self._compareCpu(x, np.abs, tf.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareCpu(x, np.negative, tf.neg)
+ self._compareCpu(x, np.negative, _NEG)
+ self._compareCpu(x, np.square, tf.square)
+ self._compareCpu(x, np.sign, tf.sign)
+
+ def testInt64Basic(self):
+ x = np.arange(
+ -6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+ self._compareCpu(x, np.abs, tf.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareCpu(x, np.negative, tf.neg)
+ self._compareCpu(x, np.negative, _NEG)
+ self._compareCpu(x, np.square, tf.square)
+ self._compareCpu(x, np.sign, tf.sign)
+
+ def testComplex64Basic(self):
+ x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+ np.complex64)
+ y = x + 0.5 # no zeros
+ self._compareCpu(x, np.abs, tf.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareCpu(x, np.negative, tf.neg)
+ self._compareCpu(x, np.negative, _NEG)
+ self._compareCpu(y, self._inv, tf.inv)
+ self._compareCpu(x, np.square, tf.square)
+ self._compareCpu(x, np.sqrt, tf.sqrt)
+ self._compareCpu(y, self._rsqrt, tf.rsqrt)
+ self._compareCpu(x, np.exp, tf.exp)
+ self._compareCpu(y, np.log, tf.log)
+ self._compareCpu(x, np.tanh, tf.tanh)
+ self._compareCpu(x, self._sigmoid, tf.sigmoid)
+ self._compareCpu(x, np.sin, tf.sin)
+ self._compareCpu(x, np.cos, tf.cos)
+
+
+class BinaryOpTest(tf.test.TestCase):
+
+ def _compareCpu(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(use_gpu=False):
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ tf_cpu = out.eval()
+ # Test that the op takes precedence over numpy operators.
+ np_left = tf_func(x, iny).eval()
+ np_right = tf_func(inx, y).eval()
+
+ self.assertAllClose(np_ans, tf_cpu)
+ self.assertAllClose(np_ans, np_left)
+ self.assertAllClose(np_ans, np_right)
+ self.assertShapeEqual(np_ans, out)
+
+ def _compareGradientX(self, x, y, np_func, tf_func):
+ z = np_func(x, y)
+ zs = list(z.shape)
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ xs = list(x.shape)
+ jacob_t, jacob_n = gc.ComputeGradient(inx, xs, out, zs, x_init_value=x)
+ if x.dtype == np.float32:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ elif x.dtype == np.float64:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _compareGradientY(self, x, y, np_func, tf_func):
+ z = np_func(x, y)
+ zs = list(z.shape)
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ ys = list(np.shape(y))
+ jacob_t, jacob_n = gc.ComputeGradient(iny, ys, out, zs, x_init_value=y)
+ if x.dtype == np.float32:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ elif x.dtype == np.float64:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _compareGpu(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(use_gpu=True):
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ tf_gpu = out.eval()
+ self.assertAllClose(np_ans, tf_gpu)
+ self.assertShapeEqual(np_ans, out)
+ # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+ def _compareBoth(self, x, y, np_func, tf_func):
+ self._compareCpu(x, y, np_func, tf_func)
+ if x.dtype == np.float32 or x.dtype == np.float64:
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
+ self._compareGpu(x, y, np_func, tf_func)
+
+ def testFloatBasic(self):
+ x = np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(np.float32)
+ y = np.linspace(20, -20, 6).reshape(1, 3, 2).astype(np.float32)
+ self._compareBoth(x, y, np.add, tf.add)
+ self._compareBoth(x, y, np.subtract, tf.sub)
+ self._compareBoth(x, y, np.multiply, tf.mul)
+ self._compareBoth(x, y + 0.1, np.divide, tf.div)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.divide, _DIV)
+
+ def testFloatDifferentShapes(self):
+ x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
+ y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
+ with self.test_session() as sess:
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ s = tf.reduce_sum(inx * iny)
+ gx, gy = sess.run(tf.gradients(s, [inx, iny]))
+ # gx is simply the broadcasted y
+ self.assertAllEqual(gx, np.array([1, 1, 2, 2])
+ .reshape(2, 2).astype(np.float32))
+ # gy is x's column summed up
+ self.assertAllEqual(gy, np.array([3, 7]).
+ reshape(2, 1).astype(np.float32))
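+      # Sanity check of the numbers above: s = sum(x * broadcast(y)), so
+      # ds/dx is the broadcasted y, [[1, 1], [2, 2]], and ds/dy sums x over
+      # the broadcast column axis, [[1 + 2], [3 + 4]] = [[3], [7]].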
+
+ def testDoubleBasic(self):
+ x = np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(np.float64)
+ y = np.linspace(20, -20, 6).reshape(1, 3, 2).astype(np.float64)
+ self._compareBoth(x, y, np.add, tf.add)
+ self._compareBoth(x, y, np.subtract, tf.sub)
+ self._compareBoth(x, y, np.multiply, tf.mul)
+ self._compareBoth(x, y + 0.1, np.divide, tf.div)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.divide, _DIV)
+
+ def testInt8Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
+ self._compareBoth(x, y, np.multiply, tf.mul)
+ self._compareBoth(x, y, np.multiply, _MUL)
+
+ def testInt16Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
+ self._compareBoth(x, y, np.multiply, tf.mul)
+ self._compareBoth(x, y, np.multiply, _MUL)
+
+ def testInt32Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
+ self._compareBoth(x, y, np.add, tf.add)
+ self._compareBoth(x, y, np.subtract, tf.sub)
+ self._compareBoth(x, y, np.multiply, tf.mul)
+ # NOTE: int32 division is ill-defined.
+ self._compareBoth(x, y, np.divide, tf.div)
+ self._compareBoth(x, y, np.mod, tf.mod)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ # NOTE: int32 division is ill-defined.
+ self._compareBoth(x, y, np.divide, _DIV)
+ self._compareBoth(x, y, np.mod, _MOD)
+
+ def testInt64Basic(self):
+ x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
+ self._compareBoth(x, y, np.subtract, tf.sub)
+ self._compareBoth(x, y, np.multiply, tf.mul)
+ # NOTE: int64 division is ill-defined.
+ self._compareBoth(x, y, np.divide, tf.div)
+ self._compareBoth(x, y, np.mod, tf.mod)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ # NOTE: int64 division is ill-defined.
+ self._compareBoth(x, y, np.divide, _DIV)
+ self._compareBoth(x, y, np.mod, _MOD)
+
+ def testComplex64Basic(self):
+ x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+ np.complex64)
+ y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+ np.complex64)
+ self._compareCpu(x, y, np.add, tf.add)
+ self._compareCpu(x, y, np.subtract, tf.sub)
+ self._compareCpu(x, y, np.multiply, tf.mul)
+ self._compareCpu(x, y + 0.1, np.divide, tf.div)
+ self._compareCpu(x, y, np.add, _ADD)
+ self._compareCpu(x, y, np.subtract, _SUB)
+ self._compareCpu(x, y, np.multiply, _MUL)
+ self._compareCpu(x, y + 0.1, np.divide, _DIV)
+
+ def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+ x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
+ y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
+ self._compareCpu(x, y, np_func, tf_func)
+ if x.dtype == np.float32 or x.dtype == np.float64:
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
+ self._compareGpu(x, y, np_func, tf_func)
+
+ # TODO(josh11b,vrv): Refactor this to use parameterized tests.
+ def _testBCastByFunc(self, funcs, xs, ys):
+ dtypes = [
+ np.float32,
+ np.float64,
+ np.int32,
+ np.int64,
+ np.complex64
+ ]
+ for dtype in dtypes:
+ for (np_func, tf_func) in funcs:
+ self._compareBCast(xs, ys, dtype, np_func, tf_func)
+ self._compareBCast(ys, xs, dtype, np_func, tf_func)
+
+ def _testBCastA(self, xs, ys):
+ funcs = [
+ (np.add, tf.add),
+ (np.add, _ADD),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastB(self, xs, ys):
+ funcs = [
+ (np.subtract, tf.sub),
+ (np.subtract, _SUB),
+ (np.power, tf.pow),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastC(self, xs, ys):
+ funcs = [
+ (np.multiply, tf.mul),
+ (np.multiply, _MUL),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastD(self, xs, ys):
+ funcs = [
+ (np.divide, tf.div),
+ (np.divide, _DIV)
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def testBCast_0A(self):
+ self._testBCastA([1, 3, 2], [1])
+
+ def testBCast_0B(self):
+ self._testBCastB([1, 3, 2], [1])
+
+ def testBCast_0C(self):
+ self._testBCastC([1, 3, 2], [1])
+
+ def testBCast_0D(self):
+ self._testBCastD([1, 3, 2], [1])
+
+ def testBCast_1A(self):
+ self._testBCastA([1, 3, 2], [2])
+
+ def testBCast_1B(self):
+ self._testBCastB([1, 3, 2], [2])
+
+ def testBCast_1C(self):
+ self._testBCastC([1, 3, 2], [2])
+
+ def testBCast_1D(self):
+ self._testBCastD([1, 3, 2], [2])
+
+ def testBCast_2A(self):
+ self._testBCastA([1, 3, 2], [3, 2])
+
+ def testBCast_2B(self):
+ self._testBCastB([1, 3, 2], [3, 2])
+
+ def testBCast_2C(self):
+ self._testBCastC([1, 3, 2], [3, 2])
+
+ def testBCast_2D(self):
+ self._testBCastD([1, 3, 2], [3, 2])
+
+ def testBCast_3A(self):
+ self._testBCastA([1, 3, 2], [3, 1])
+
+ def testBCast_3B(self):
+ self._testBCastB([1, 3, 2], [3, 1])
+
+ def testBCast_3C(self):
+ self._testBCastC([1, 3, 2], [3, 1])
+
+ def testBCast_3D(self):
+ self._testBCastD([1, 3, 2], [3, 1])
+
+ def testBCast_4A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 2])
+
+ def testBCast_5A(self):
+ self._testBCastA([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5B(self):
+ self._testBCastB([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5C(self):
+ self._testBCastC([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5D(self):
+ self._testBCastD([1, 3, 2], [2, 3, 1])
+
+ def testBCast_6A(self):
+ self._testBCastA([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6B(self):
+ self._testBCastB([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6C(self):
+ self._testBCastC([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6D(self):
+ self._testBCastD([1, 3, 2], [2, 1, 1])
+
+ def testBCast_7A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 1])
+
+ def testBCast_8A(self):
+ self._testBCastA([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8B(self):
+ self._testBCastB([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8C(self):
+ self._testBCastC([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8D(self):
+ self._testBCastD([2, 1, 5], [2, 3, 1])
+
+ def testBCast_9A(self):
+ self._testBCastA([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9B(self):
+ self._testBCastB([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9C(self):
+ self._testBCastC([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9D(self):
+ self._testBCastD([2, 0, 5], [2, 0, 1])
+
+ def testBCast_10A(self):
+ self._testBCastA([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10B(self):
+ self._testBCastB([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10C(self):
+ self._testBCastC([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10D(self):
+ self._testBCastD([2, 3, 0], [2, 3, 1])
+
+ def testBCast_11A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 2])
+
+ def testBCast_12A(self):
+ self._testBCastA([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12B(self):
+ self._testBCastB([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12C(self):
+ self._testBCastC([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12D(self):
+ self._testBCastD([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_13A(self):
+ self._testBCastA([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13B(self):
+ self._testBCastB([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13C(self):
+ self._testBCastC([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13D(self):
+ self._testBCastD([1, 3, 2, 1, 1], [1])
+
+ def testBCast_14A(self):
+ self._testBCastA([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14B(self):
+ self._testBCastB([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14C(self):
+ self._testBCastC([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14D(self):
+ self._testBCastD([2, 3, 1, 1, 5], [1])
+
+ def testBCast_15A(self):
+ self._testBCastA([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15B(self):
+ self._testBCastB([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15C(self):
+ self._testBCastC([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15D(self):
+ self._testBCastD([10, 3, 1, 2], [3, 1, 2])
+
+ def testMismatchedDimensions(self):
+ for func in [tf.add, tf.sub, tf.mul, tf.div,
+ _ADD, _SUB, _MUL, _DIV]:
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Incompatible shapes" in e.message):
+ func(tf.convert_to_tensor([10.0, 20.0, 30.0]),
+ tf.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
+
+
+class ComparisonOpTest(tf.test.TestCase):
+
+ def _compare(self, func, x, y, dtype):
+ with self.test_session(use_gpu=False):
+ out = func(tf.convert_to_tensor(np.array([x]).astype(dtype)),
+ tf.convert_to_tensor(np.array([y]).astype(dtype)))
+ ret = out.eval()
+ return ret[0]
+
+ def testScalarCompareScalar(self):
+ dtypes = [np.float32, np.float64, np.int32, np.int64]
+ data = [-1, 0, 1]
+ for t in dtypes:
+ for x in data:
+ for y in data:
+ self.assertEqual(self._compare(tf.less, x, y, t),
+ x < y)
+ self.assertEqual(self._compare(tf.less_equal, x, y, t),
+ x <= y)
+ self.assertEqual(self._compare(tf.greater, x, y, t),
+ x > y)
+ self.assertEqual(self._compare(tf.greater_equal, x, y, t),
+ x >= y)
+ self.assertEqual(self._compare(tf.equal, x, y, t),
+ x == y)
+ self.assertEqual(self._compare(tf.not_equal, x, y, t),
+ x != y)
+
+ def _compareCpu(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(use_gpu=False):
+ out = tf_func(tf.convert_to_tensor(x), tf.convert_to_tensor(y))
+ tf_cpu = out.eval()
+ self.assertAllEqual(np_ans, tf_cpu)
+
+ def _compareGpu(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(use_gpu=True):
+ out = tf_func(tf.convert_to_tensor(x), tf.convert_to_tensor(y))
+ tf_gpu = out.eval()
+ self.assertAllEqual(np_ans, tf_gpu)
+
+ def _compareBoth(self, x, y, np_func, tf_func):
+ self._compareCpu(x, y, np_func, tf_func)
+ if x.dtype == np.float32 or x.dtype == np.float64:
+ self._compareGpu(x, y, np_func, tf_func)
+
+ def testTensorCompareTensor(self):
+ x = np.linspace(-15, 15, 6).reshape(1, 3, 2)
+ y = np.linspace(20, -10, 6).reshape(1, 3, 2)
+ for t in [np.float32, np.float64, np.int32, np.int64]:
+ xt = x.astype(t)
+ yt = y.astype(t)
+ self._compareBoth(xt, yt, np.less, tf.less)
+ self._compareBoth(xt, yt, np.less_equal, tf.less_equal)
+ self._compareBoth(xt, yt, np.greater, tf.greater)
+ self._compareBoth(xt, yt, np.greater_equal, tf.greater_equal)
+ self._compareBoth(xt, yt, np.equal, tf.equal)
+ self._compareBoth(xt, yt, np.not_equal, tf.not_equal)
+ # TODO(zhifengc): complex64 doesn't work on GPU yet.
+ self._compareCpu(x.astype(np.complex64), y.astype(np.complex64),
+ np.equal, tf.equal)
+ self._compareCpu(x.astype(np.complex64), y.astype(np.complex64),
+ np.not_equal, tf.not_equal)
+
+ def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+ x = np.linspace(-15, 15, np.prod(xs)).astype(dtype).reshape(xs)
+ y = np.linspace(20, -10, np.prod(ys)).astype(dtype).reshape(ys)
+ self._compareCpu(x, y, np_func, tf_func)
+ self._compareCpu(y, x, np_func, tf_func)
+ if x.dtype == np.float32 or x.dtype == np.float64:
+ self._compareGpu(x, y, np_func, tf_func)
+ self._compareGpu(y, x, np_func, tf_func)
+
+ def _testBCastByFunc(self, np_func, tf_func):
+ shapes = [
+ ([1, 3, 2], [1]),
+ ([1, 3, 2], [2]),
+ ([1, 3, 2], [3, 2]),
+ ([1, 3, 2], [3, 1]),
+ ([1, 3, 2], [1, 3, 2]),
+ ([1, 3, 2], [2, 3, 1]),
+ ([1, 3, 2], [2, 1, 1]),
+ ([1, 3, 2], [1, 3, 1]),
+ ([2, 1, 5], [2, 3, 1]),
+ ([2, 0, 5], [2, 0, 1]),
+ ([2, 3, 0], [2, 3, 1]),
+ ]
+ dtypes = [
+ np.float32,
+ np.float64,
+ np.int32,
+ np.int64,
+ ]
+ for (xs, ys) in shapes:
+ for dtype in dtypes:
+ self._compareBCast(xs, ys, dtype, np_func, tf_func)
+
+ def testBCastLess(self):
+ self._testBCastByFunc(np.less, tf.less)
+
+ def testBCastLessEqual(self):
+ self._testBCastByFunc(np.less_equal, tf.less_equal)
+
+ def testBCastGreater(self):
+ self._testBCastByFunc(np.greater, tf.greater)
+
+ def testBCastGreaterEqual(self):
+ self._testBCastByFunc(np.greater_equal, tf.greater_equal)
+
+ def testBCastEqual(self):
+ self._testBCastByFunc(np.equal, tf.equal)
+
+ def testBCastNotEqual(self):
+ self._testBCastByFunc(np.not_equal, tf.not_equal)
+
+ def testShapeMismatch(self):
+ dtypes = [np.float32, np.float64, np.int32, np.int64]
+ funcs = [tf.less, tf.less_equal, tf.greater,
+ tf.greater_equal, tf.equal, tf.not_equal]
+ x = np.arange(0, 10).reshape([2, 5])
+ y = np.arange(0, 10).reshape([5, 2])
+ for t in dtypes:
+ for f in funcs:
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Incompatible shapes" in e.message):
+ f(x.astype(t), y.astype(t))
+
+
+class LogicalOpTest(tf.test.TestCase):
+
+ def _compareBinary(self, x, y, np_func, tf_func, use_gpu=False):
+ np_ans = np_func(x, y)
+ with self.test_session(use_gpu=use_gpu):
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ tf_val = out.eval()
+ self.assertEqual(out.dtype, tf.bool)
+ self.assertAllEqual(np_ans, tf_val)
+ self.assertShapeEqual(np_ans, out)
+
+ def _not(self, x, use_gpu=False):
+ np_ans = np.logical_not(x)
+ with self.test_session(use_gpu=use_gpu):
+ out = tf.logical_not(tf.convert_to_tensor(x))
+ tf_val = out.eval()
+ self.assertEqual(out.dtype, tf.bool)
+ self.assertAllEqual(np_ans, tf_val)
+ self.assertShapeEqual(np_ans, out)
+
+ def testScalar(self):
+ data = [np.array([True]), np.array([False])]
+ for use_gpu in [True, False]:
+ for x in data:
+ self._not(x, use_gpu)
+ for x in data:
+ for y in data:
+ self._compareBinary(
+ x, y, np.logical_and, tf.logical_and, use_gpu)
+ self._compareBinary(
+ x, y, np.logical_or, tf.logical_or, use_gpu)
+ self._compareBinary(
+ x, y, np.logical_xor, tf.logical_xor, use_gpu)
+
+ def testTensor(self):
+ x = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+ y = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+ for use_gpu in [True, False]:
+ self._not(x, use_gpu)
+ self._compareBinary(x, y, np.logical_and, tf.logical_and, use_gpu)
+ self._compareBinary(x, y, np.logical_or, tf.logical_or, use_gpu)
+ self._compareBinary(x, y, np.logical_xor, tf.logical_xor, use_gpu)
+
+ def testBCast(self):
+ shapes = [
+ ([1, 3, 2], [1]),
+ ([1, 3, 2], [2]),
+ ([1, 3, 2], [3, 2]),
+ ([1, 3, 2], [3, 1]),
+ ([1, 3, 2], [1, 3, 2]),
+ ([1, 3, 2], [2, 3, 1]),
+ ([1, 3, 2], [2, 1, 1]),
+ ([1, 3, 2], [1, 3, 1]),
+ ([2, 1, 5], [2, 3, 1]),
+ ([2, 0, 5], [2, 0, 1]),
+ ([2, 3, 0], [2, 3, 1]),
+ ]
+ for (xs, ys) in shapes:
+ x = np.random.randint(0, 2, np.prod(xs)).astype(np.bool).reshape(xs)
+ y = np.random.randint(0, 2, np.prod(ys)).astype(np.bool).reshape(ys)
+ for use_gpu in [True, False]:
+ self._compareBinary(x, y, np.logical_and, tf.logical_and, use_gpu)
+ self._compareBinary(x, y, np.logical_or, tf.logical_or, use_gpu)
+ self._compareBinary(x, y, np.logical_xor, tf.logical_xor, use_gpu)
+
+ def testShapeMismatch(self):
+ x = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+ y = np.random.randint(0, 2, 6).astype(np.bool).reshape(3, 2, 1)
+ for f in [tf.logical_and, tf.logical_or, tf.logical_xor]:
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Incompatible shapes" in e.message):
+ f(x, y)
+
+
+class SelectOpTest(tf.test.TestCase):
+
+ def _compare(self, c, x, y, use_gpu):
+ np_ans = np.where(c, x, y)
+ with self.test_session(use_gpu=use_gpu):
+ out = tf.select(c, x, y)
+ tf_ans = out.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, out)
+
+ def _compareGradientX(self, c, x, y):
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = tf.select(c, inx, iny)
+ s = list(np.shape(c))
+ jacob_t, jacob_n = gc.ComputeGradient(inx, s, out, s, x_init_value=x)
+ if x.dtype == np.float32:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ elif x.dtype == np.float64:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _compareGradientY(self, c, x, y):
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = tf.select(c, inx, iny)
+ s = list(np.shape(c))
+ jacob_t, jacob_n = gc.ComputeGradient(iny, s, out, s, x_init_value=y)
+ if x.dtype == np.float32:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ elif x.dtype == np.float64:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def testBasic(self):
+ c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+ x = np.random.rand(1, 3, 2) * 100
+ y = np.random.rand(1, 3, 2) * 100
+ for t in [np.float32, np.float64, np.int32, np.int64, np.complex64]:
+ xt = x.astype(t)
+ yt = y.astype(t)
+ self._compare(c, xt, yt, use_gpu=False)
+ if t in [np.float32, np.float64]:
+ self._compare(c, xt, yt, use_gpu=True)
+
+ def testGradients(self):
+ c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+ x = np.random.rand(1, 3, 2) * 100
+ y = np.random.rand(1, 3, 2) * 100
+ for t in [np.float32, np.float64]:
+ xt = x.astype(t)
+ yt = y.astype(t)
+ self._compareGradientX(c, xt, yt)
+ self._compareGradientY(c, xt, yt)
+
+ def testShapeMismatch(self):
+ c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+ x = np.random.rand(1, 3, 2) * 100
+ y = np.random.rand(2, 5, 3) * 100
+ for t in [np.float32, np.float64, np.int32, np.int64, np.complex64]:
+ xt = x.astype(t)
+ yt = y.astype(t)
+ with self.assertRaises(ValueError):
+ tf.select(c, xt, yt)
+
+
+class MinMaxOpTest(tf.test.TestCase):
+
+ def _compare(self, x, y, use_gpu):
+ np_min, np_max = np.minimum(x, y), np.maximum(x, y)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ omin, omax = tf.minimum(inx, iny), tf.maximum(inx, iny)
+ tf_min, tf_max = sess.run([omin, omax])
+ self.assertAllEqual(np_min, tf_min)
+ self.assertAllEqual(np_max, tf_max)
+
+ def testBasic(self):
+ x = np.random.rand(1, 3, 2) * 100.
+ y = np.random.rand(1, 3, 2) * 100.
+ for t in [np.float32, np.float64, np.int32, np.int64]:
+ self._compare(x.astype(t), y.astype(t), use_gpu=False)
+ self._compare(x.astype(t), y.astype(t), use_gpu=True)
+
+ def testDifferentShapes(self):
+ x = np.random.rand(1, 3, 2) * 100.
+ y = np.random.rand(2) * 100. # should broadcast
+ for t in [np.float32, np.float64, np.int32, np.int64]:
+ self._compare(x.astype(t), y.astype(t), use_gpu=False)
+ self._compare(x.astype(t), y.astype(t), use_gpu=True)
+
+ def testScalar(self):
+ x = np.random.rand(1, 3, 2) * 100.
+ y = np.asscalar(np.random.rand(1) * 100.) # should broadcast
+ # dropped np.float64, int64 because TF automatically converts to 32 bit
+ for t in [np.float32, np.int32]:
+ self._compare(x.astype(t), t(y), use_gpu=False)
+ self._compare(x.astype(t), t(y), use_gpu=True)
+
+ def _compareGradientX(self, func, x, y):
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = func(inx, iny)
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gc.ComputeGradient(inx, s, out, s, x_init_value=x)
+ if x.dtype == np.float32:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ elif x.dtype == np.float64:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _compareGradientY(self, func, x, y):
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ iny = tf.convert_to_tensor(y)
+ out = func(inx, iny)
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gc.ComputeGradient(iny, s, out, s, x_init_value=y)
+ if x.dtype == np.float32:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+ elif x.dtype == np.float64:
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def testGradients(self):
+ x = np.random.rand(1, 3, 2) * 100.
+ # ensure x != y
+ y = x + (np.random.randint(2, size=x.shape) - .5) * 2 # -1 or +1
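+    # (Where x == y the min/max gradient is ambiguous, so the numeric and
+    # analytic Jacobians could legitimately disagree; the +/-1 offset avoids
+    # ties.)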
+ self._compareGradientX(tf.maximum, x, y)
+ self._compareGradientY(tf.maximum, x, y)
+ self._compareGradientX(tf.minimum, x, y)
+ self._compareGradientY(tf.minimum, x, y)
+
+
+class MathOpsOverloadTest(tf.test.TestCase):
+
+ def _computeTensorAndLiteral(self, x, y, dtype, func):
+ with self.test_session(use_gpu=False):
+ inx = tf.convert_to_tensor(x, dtype=dtype)
+ z = func(inx, y) # Should use __add__, __sub__, etc.
+ return z.eval()
+
+ def _computeLiteralAndTensor(self, x, y, dtype, func):
+ with self.test_session(use_gpu=False):
+ iny = tf.convert_to_tensor(y, dtype=dtype)
+ z = func(x, iny) # Should use __radd__, __rsub__, etc.
+ return z.eval()
+
+ def _compareBinary(self, x, y, dtype, np_func, tf_func):
+ np_ans = np_func(x, y)
+ self.assertAllClose(np_ans, self._computeTensorAndLiteral(
+ x, y, dtype, tf_func))
+ self.assertAllClose(np_ans, self._computeLiteralAndTensor(
+ x, y, dtype, tf_func))
+
+ def _compareUnary(self, x, dtype, np_func, tf_func):
+ np_ans = np_func(x)
+ with self.test_session(use_gpu=False):
+ self.assertAllClose(np_ans, tf_func(tf.convert_to_tensor(x, dtype=dtype)).eval())
+
+ def testOverload(self):
+ dtypes = [
+ tf.float32,
+ tf.float64,
+ tf.int32,
+ tf.int64,
+ tf.complex64,
+ ]
+ funcs = [
+ (np.add, _ADD),
+ (np.subtract, _SUB),
+ (np.multiply, _MUL),
+ (np.divide, _DIV)
+ ]
+ for dtype in dtypes:
+ for np_func, tf_func in funcs:
+ self._compareBinary(10, 5, dtype, np_func, tf_func)
+ # Mod only works for int32 and int64.
+ for dtype in [tf.int32, tf.int64]:
+ self._compareBinary(10, 3, dtype, np.mod, _MOD)
+
+ def testOverloadComparisons(self):
+ dtypes = [
+ tf.float32,
+ tf.float64,
+ tf.int32,
+ tf.int64,
+ ]
+ funcs = [
+ (np.less, _LT),
+ (np.less_equal, _LE),
+ (np.greater, _GT),
+ (np.greater_equal, _GE),
+ ]
+ for dtype in dtypes:
+ for np_func, tf_func in funcs:
+ self._compareBinary(10, 5, dtype, np_func, tf_func)
+ logical_funcs = [
+ (np.logical_and, _AND),
+ (np.logical_or, _OR),
+ (np.logical_xor, _XOR),
+ ]
+ for np_func, tf_func in logical_funcs:
+ self._compareBinary(True, False, tf.bool, np_func, tf_func)
+ self._compareBinary(True, True, tf.bool, np_func, tf_func)
+ self._compareBinary(False, False, tf.bool, np_func, tf_func)
+ self._compareBinary(False, True, tf.bool, np_func, tf_func)
+ self._compareBinary([True, True, False, False],
+ [True, False, True, False],
+ tf.bool, np_func, tf_func)
+ self._compareUnary(True, tf.bool, np.logical_not, _INV)
+ self._compareUnary(False, tf.bool, np.logical_not, _INV)
+ self._compareUnary([True, False], tf.bool, np.logical_not, _INV)
+
+
+class IsFiniteInfNanTest(tf.test.TestCase):
+
+ def _compare(self, x, use_gpu):
+ np_finite, np_inf, np_nan = np.isfinite(x), np.isinf(x), np.isnan(x)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inx = tf.convert_to_tensor(x)
+ ofinite, oinf, onan = tf.is_finite(inx), tf.is_inf(
+ inx), tf.is_nan(inx)
+ tf_finite, tf_inf, tf_nan = sess.run([ofinite, oinf, onan])
+ self.assertAllEqual(np_inf, tf_inf)
+ self.assertAllEqual(np_nan, tf_nan)
+ self.assertAllEqual(np_finite, tf_finite)
+ self.assertShapeEqual(np_inf, oinf)
+ self.assertShapeEqual(np_nan, onan)
+ self.assertShapeEqual(np_finite, ofinite)
+
+ def _testDtype(self, dtype):
+ fi = np.finfo(dtype)
+ data = np.array([0, -1, 1, fi.resolution, -fi.resolution, fi.min, fi.max,
+ -np.inf, np.inf, np.nan]).astype(dtype)
+ self._compare(data, use_gpu=False)
+ self._compare(data, use_gpu=True)
+
+ def testFloat(self):
+ self._testDtype(np.float32)
+
+ def testDouble(self):
+ self._testDtype(np.float64)
+
+
+class RoundingTest(tf.test.TestCase):
+
+ def _compare(self, x, use_gpu):
+ np_floor, np_ceil = np.floor(x), np.ceil(x)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inx = tf.convert_to_tensor(x)
+ ofloor, oceil = tf.floor(inx), tf.ceil(inx)
+ tf_floor, tf_ceil = sess.run([ofloor, oceil])
+ self.assertAllEqual(np_floor, tf_floor)
+ self.assertAllEqual(np_ceil, tf_ceil)
+ self.assertShapeEqual(np_floor, ofloor)
+ self.assertShapeEqual(np_ceil, oceil)
+
+ def _testDtype(self, dtype):
+ data = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(dtype)
+    self._compare(data, use_gpu=False)
+    self._compare(data, use_gpu=True)
+
+ def testTypes(self):
+ for dtype in [np.float32, np.float64]:
+ self._testDtype(dtype)
+
+
+class ComplexMakeRealImagTest(tf.test.TestCase):
+
+ def _compareMake(self, real, imag, use_gpu):
+ np_ans = real + (1j) * imag
+ with self.test_session(use_gpu=use_gpu):
+ real = tf.convert_to_tensor(real)
+ imag = tf.convert_to_tensor(imag)
+ tf_ans = tf.complex(real, imag)
+ out = tf_ans.eval()
+ self.assertAllEqual(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def testMake(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
+ for use_gpu in [False, True]:
+ self._compareMake(real, imag, use_gpu)
+ self._compareMake(real, 12.0, use_gpu)
+ self._compareMake(23.0, imag, use_gpu)
+
+ def _compareRealImag(self, cplx, use_gpu):
+ np_real, np_imag = np.real(cplx), np.imag(cplx)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inx = tf.convert_to_tensor(cplx)
+ tf_real = tf.real(inx)
+ tf_imag = tf.imag(inx)
+ tf_real_val, tf_imag_val = sess.run([tf_real, tf_imag])
+ self.assertAllEqual(np_real, tf_real_val)
+ self.assertAllEqual(np_imag, tf_imag_val)
+ self.assertShapeEqual(np_real, tf_real)
+ self.assertShapeEqual(np_imag, tf_imag)
+
+ def testRealImag(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
+ cplx = real + (1j) * imag
+ self._compareRealImag(cplx, use_gpu=False)
+ self._compareRealImag(cplx, use_gpu=True)
+
+ def _compareConj(self, cplx, use_gpu):
+ np_ans = np.conj(cplx)
+ with self.test_session(use_gpu=use_gpu):
+ inx = tf.convert_to_tensor(cplx)
+ tf_conj = tf.conj(inx)
+ tf_ans = tf_conj.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, tf_conj)
+
+ def testConj(self):
+ real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
+ imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
+ cplx = real + (1j) * imag
+ self._compareConj(cplx, use_gpu=False)
+ self._compareConj(cplx, use_gpu=True)
+
+ def _compareGradient(self, x):
+ # x[:, 0] is real, x[:, 1] is imag. We combine real and imag into
+ # complex numbers. Then, we extract real and imag parts and
+    # compute the squared sum. This is obviously the same as sum(real
+ # * real) + sum(imag * imag). We just want to make sure the
+ # gradient function is checked.
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ real, imag = tf.split(1, 2, inx)
+ real, imag = tf.reshape(real, [-1]), tf.reshape(imag, [-1])
+ cplx = tf.complex(real, imag)
+ cplx = tf.conj(cplx)
+ loss = tf.reduce_sum(
+ tf.square(tf.real(cplx))) + tf.reduce_sum(
+ tf.square(tf.imag(cplx)))
+ epsilon = 1e-3
+ jacob_t, jacob_n = gc.ComputeGradient(inx, list(x.shape), loss, [1],
+ x_init_value=x, delta=epsilon)
+ self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon)
+
+ def testGradient(self):
+ data = np.arange(1, 2, 0.10).reshape([5, 2]).astype(np.float32)
+ self._compareGradient(data)
+
+ def _compareMulGradient(self, data):
+ # data is a float matrix of shape [n, 4]. data[:, 0], data[:, 1],
+ # data[:, 2], data[:, 3] are real parts of x, imaginary parts of
+ # x, real parts of y and imaginary parts of y.
+ with self.test_session():
+ inp = tf.convert_to_tensor(data)
+ xr, xi, yr, yi = tf.split(1, 4, inp)
+
+ def vec(x): # Reshape to a vector
+ return tf.reshape(x, [-1])
+ xr, xi, yr, yi = vec(xr), vec(xi), vec(yr), vec(yi)
+
+ def cplx(r, i): # Combine to a complex vector
+ return tf.complex(r, i)
+ x, y = cplx(xr, xi), cplx(yr, yi)
+      # z is x times y in the complex plane.
+ z = x * y
+ # Defines the loss function as the sum of all coefficients of z.
+ loss = tf.reduce_sum(tf.real(z) + tf.imag(z))
+ epsilon = 0.005
+ jacob_t, jacob_n = gc.ComputeGradient(inp, list(data.shape), loss, [1],
+ x_init_value=data, delta=epsilon)
+ self.assertAllClose(jacob_t, jacob_n, rtol=epsilon, atol=epsilon)
+
+ def testMulGradient(self):
+ data = np.arange(1, 2, 0.125).reshape([2, 4]).astype(np.float32)
+ self._compareMulGradient(data)
+
+
+class AccumulateTest(tf.test.TestCase):
+
+ def testSimple(self):
+ with self.test_session():
+ random_arrays = [np.random.rand(16, 16, 16, 16).astype(np.float32)
+ for _ in range(20)]
+ random_tensors = [tf.convert_to_tensor(x, dtype=tf.float32)
+ for x in random_arrays]
+ tf_val = tf.accumulate_n(random_tensors)
+ np_val = random_arrays[0]
+ for random_array in random_arrays[1:]:
+ np_val += random_array
+ self.assertAllClose(np_val, tf_val.eval())
+
+ def testZeroArgs(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf_val = tf.accumulate_n([])
+ tf_val.eval()
+
+if __name__ == "__main__":
+ tf.test.main()
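The overload tests above rely on Python's arithmetic, comparison and logical operators dispatching to TensorFlow ops. A minimal sketch of that behavior, assuming the graph-mode session API these tests exercise (values are illustrative):

    import tensorflow as tf

    sess = tf.Session()
    x = tf.convert_to_tensor([1.0, 2.0, 3.0])
    # A plain Python number on either side is converted automatically,
    # so __add__/__radd__, __mul__ and __sub__ all build TensorFlow ops.
    y = (x + 2.0) * 3.0 - 1.0
    print sess.run(y)  # => [  8.  11.  14.]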
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
new file mode 100644
index 0000000000..ae0917f8c4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -0,0 +1,148 @@
+"""Tests for DecodeCSV op from parsing_ops."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DecodeCSVOpTest(tf.test.TestCase):
+
+ def _test(self, args, expected_out=None, expected_err_re=None):
+ with self.test_session() as sess:
+ decode = tf.decode_csv(**args)
+
+ if expected_err_re is None:
+ out = sess.run(decode)
+
+ for i, field in enumerate(out):
+ if field.dtype == np.float32:
+ self.assertAllClose(field, expected_out[i])
+ else:
+ self.assertAllEqual(field, expected_out[i])
+
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ sess.run(decode)
+
+ def testSimple(self):
+ args = {"records": ["1", "2", '"3"'], "record_defaults": [[1]],}
+
+ expected_out = [[1, 2, 3]]
+
+ self._test(args, expected_out)
+
+ def testScalar(self):
+ args = {"records": '1,""', "record_defaults": [[3], [4]]}
+
+ expected_out = [1, 4]
+
+ self._test(args, expected_out)
+
+ def test2D(self):
+ args = {"records": [["1", "2"], ['""', "4"]], "record_defaults": [[5]]}
+ expected_out = [[[1, 2], [5, 4]]]
+
+ self._test(args, expected_out)
+
+ def testInt64(self):
+ args = {
+ "records": ["1", "2", '"2147483648"'],
+ "record_defaults": [np.array([],
+ dtype=np.int64)],
+ }
+
+ expected_out = [[1, 2, 2147483648]]
+
+ self._test(args, expected_out)
+
+ def testComplexString(self):
+ args = {
+ "records": ['"1.0"', '"ab , c"', '"a\nbc"', '"ab""c"', " abc "],
+ "record_defaults": [["1"]]
+ }
+
+ expected_out = [["1.0", "ab , c", "a\nbc", 'ab"c', " abc "]]
+
+ self._test(args, expected_out)
+
+ def testMultiRecords(self):
+ args = {
+ "records": ["1.0,4,aa", "0.2,5,bb", "3,6,cc"],
+ "record_defaults": [[1.0], [1], ["aa"]]
+ }
+
+ expected_out = [[1.0, 0.2, 3], [4, 5, 6], ["aa", "bb", "cc"]]
+
+ self._test(args, expected_out)
+
+ def testWithDefaults(self):
+ args = {
+ "records": [",1,", "0.2,3,bcd", "3.0,,"],
+ "record_defaults": [[1.0], [0], ["a"]]
+ }
+
+ expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], ["a", "bcd", "a"]]
+
+ self._test(args, expected_out)
+
+ def testWithTabDelim(self):
+ args = {
+ "records": ["1\t1", "0.2\t3", "3.0\t"],
+ "record_defaults": [[1.0], [0]],
+ "field_delim": "\t"
+ }
+
+ expected_out = [[1.0, 0.2, 3.0], [1, 3, 0]]
+
+ self._test(args, expected_out)
+
+ def testWithoutDefaultsError(self):
+ args = {
+ "records": [",1", "0.2,3", "3.0,"],
+ "record_defaults": [[1.0], np.array([],
+ dtype=np.int32)]
+ }
+
+ self._test(args,
+ expected_err_re="Field 1 is required but missing in record 2!")
+
+ def testWrongFieldIntError(self):
+ args = {
+ "records": [",1", "0.2,234a", "3.0,2"],
+ "record_defaults": [[1.0], np.array([],
+ dtype=np.int32)]
+ }
+
+ self._test(args,
+ expected_err_re="Field 1 in record 1 is not a valid int32: 234a")
+
+ def testOutOfRangeError(self):
+ args = {
+ "records": ["1", "9999999999999999999999999", "3"],
+ "record_defaults": [[1]]
+ }
+
+ self._test(args,
+ expected_err_re="Field 0 in record 1 is not a valid int32: ")
+
+ def testWrongFieldFloatError(self):
+ args = {
+ "records": [",1", "0.2,2", "3.0adf,3"],
+ "record_defaults": [[1.0], np.array([],
+ dtype=np.int32)]
+ }
+
+ self._test(args,
+ expected_err_re="Field 0 in record 2 is not a valid float: ")
+
+ def testWrongFieldStringError(self):
+ args = {"records": ['"1,a,"', "0.22", 'a"bc'], "record_defaults": [["a"]]}
+
+ self._test(
+ args,
+ expected_err_re="Unquoted fields cannot have quotes/CRLFs inside")
+
+
+if __name__ == "__main__":
+ tf.test.main()
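For reference, a minimal usage sketch of the DecodeCSV op exercised above (illustrative records; record_defaults both fixes each column's dtype and supplies the value used when a field is empty):

    import tensorflow as tf

    records = tf.constant(["1.0,5,hello", ",7,"])
    # One default per column: float32, int32 and string. Empty fields
    # fall back to their default; an empty default makes a field required.
    cols = tf.decode_csv(records, record_defaults=[[0.0], [0], ["n/a"]])
    with tf.Session() as sess:
        print sess.run(cols)  # => [1., 0.], [5, 7], ['hello', 'n/a']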
diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py
new file mode 100644
index 0000000000..abd50a7527
--- /dev/null
+++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py
@@ -0,0 +1,44 @@
+"""Tests for DecodeRaw op from parsing_ops."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class DecodeRawOpTest(tf.test.TestCase):
+
+ def testToUint8(self):
+ with self.test_session():
+ in_bytes = tf.placeholder(tf.string, shape=[2])
+ decode = tf.decode_raw(in_bytes, out_type=tf.uint8)
+ self.assertEqual([2, None], decode.get_shape().as_list())
+
+ result = decode.eval(feed_dict={in_bytes: ["A", "a"]})
+ self.assertAllEqual([[ord("A")], [ord("a")]], result)
+
+ result = decode.eval(feed_dict={in_bytes: ["wer", "XYZ"]})
+ self.assertAllEqual([[ord("w"), ord("e"), ord("r")],
+ [ord("X"), ord("Y"), ord("Z")]], result)
+
+ with self.assertRaisesOpError(
+ "DecodeRaw requires input strings to all be the same size, but "
+ "element 1 has size 5 != 6"):
+ decode.eval(feed_dict={in_bytes: ["short", "longer"]})
+
+ def testToInt16(self):
+ with self.test_session():
+ in_bytes = tf.placeholder(tf.string, shape=[None])
+ decode = tf.decode_raw(in_bytes, out_type=tf.int16)
+ self.assertEqual([None, None], decode.get_shape().as_list())
+
+ result = decode.eval(feed_dict={in_bytes: ["AaBC"]})
+ self.assertAllEqual([[ord("A") + ord("a") * 256,
+ ord("B") + ord("C") * 256]], result)
+
+ with self.assertRaisesOpError(
+ "Input to DecodeRaw has length 3 that is not a multiple of 2, the "
+ "size of int16"):
+ decode.eval(feed_dict={in_bytes: ["123", "456"]})
+
+if __name__ == "__main__":
+ tf.test.main()
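A minimal sketch of the DecodeRaw op tested above (illustrative input; every string in a batch must have the same byte length, and each byte becomes one uint8 element):

    import tensorflow as tf

    raw = tf.constant(["abcd", "wxyz"])
    ints = tf.decode_raw(raw, out_type=tf.uint8)  # shape [2, 4]
    with tf.Session() as sess:
        print sess.run(ints)  # => [[ 97  98  99 100]
                              #     [119 120 121 122]]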
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
new file mode 100644
index 0000000000..ad0724931e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
@@ -0,0 +1,60 @@
+"""Tests for state updating ops that may have benign race conditions."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class AssignOpTest(tf.test.TestCase):
+
+  # NOTE(mrry): We exclude these tests from the TSAN TAP target, because they
+ # contain benign and deliberate data races when multiple threads update
+ # the same parameters without a lock.
+ def testParallelUpdateWithoutLocking(self):
+ with self.test_session() as sess:
+ ones_t = tf.fill([1024, 1024], 1.0)
+ p = tf.Variable(tf.zeros([1024, 1024]))
+ adds = [tf.assign_add(p, ones_t, use_locking=False)
+ for _ in range(20)]
+ tf.initialize_all_variables().run()
+
+ def run_add(add_op):
+ sess.run(add_op)
+ threads = [self.checkedThread(target=run_add, args=(add_op,))
+ for add_op in adds]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ vals = p.eval()
+ ones = np.ones((1024, 1024)).astype(np.float32)
+ self.assertTrue((vals >= ones).all())
+ self.assertTrue((vals <= ones * 20).all())
+
+ def testParallelAssignWithoutLocking(self):
+ with self.test_session() as sess:
+ ones_t = tf.fill([1024, 1024], float(1))
+ p = tf.Variable(tf.zeros([1024, 1024]))
+ assigns = [tf.assign(p, tf.mul(ones_t, float(i)), False)
+ for i in range(1, 21)]
+ tf.initialize_all_variables().run()
+
+ def run_assign(assign_op):
+ sess.run(assign_op)
+ threads = [self.checkedThread(target=run_assign, args=(assign_op,))
+ for assign_op in assigns]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ vals = p.eval()
+
+ # Assert every element is taken from one of the assignments.
+ self.assertTrue((vals > 0).all())
+ self.assertTrue((vals <= 20).all())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
new file mode 100644
index 0000000000..2e1ea468c3
--- /dev/null
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -0,0 +1,151 @@
+"""Tests for tensorflow.ops.tf.Assign*."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class AssignOpTest(tf.test.TestCase):
+
+ def _initAssignFetch(self, x, y, use_gpu=False):
+ """Initialize a param to init and update it with y."""
+ super(AssignOpTest, self).setUp()
+ with self.test_session(use_gpu=use_gpu):
+ p = tf.Variable(x)
+ assign = tf.assign(p, y)
+ p.initializer.run()
+ new_value = assign.eval()
+ return p.eval(), new_value
+
+ def _initAssignAddFetch(self, x, y, use_gpu=False):
+ """Initialize a param to init, and compute param += y."""
+ with self.test_session(use_gpu=use_gpu):
+ p = tf.Variable(x)
+ add = tf.assign_add(p, y)
+ p.initializer.run()
+ new_value = add.eval()
+ return p.eval(), new_value
+
+ def _initAssignSubFetch(self, x, y, use_gpu=False):
+ """Initialize a param to init, and compute param -= y."""
+ with self.test_session(use_gpu=use_gpu):
+ p = tf.Variable(x)
+ sub = tf.assign_sub(p, y)
+ p.initializer.run()
+ new_value = sub.eval()
+ return p.eval(), new_value
+
+ def _testTypes(self, vals):
+ for dtype in [np.float32, np.float64, np.int32, np.int64]:
+ x = np.zeros(vals.shape).astype(dtype)
+ y = vals.astype(dtype)
+ var_value, op_value = self._initAssignFetch(x, y, use_gpu=False)
+ self.assertAllEqual(y, var_value)
+ self.assertAllEqual(y, op_value)
+ var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=False)
+ self.assertAllEqual(x + y, var_value)
+ self.assertAllEqual(x + y, op_value)
+ var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False)
+ self.assertAllEqual(x - y, var_value)
+ self.assertAllEqual(x - y, op_value)
+ if tf.test.IsBuiltWithCuda() and dtype in [np.float32, np.float64]:
+ var_value, op_value = self._initAssignFetch(x, y, use_gpu=True)
+ self.assertAllEqual(y, var_value)
+ self.assertAllEqual(y, op_value)
+ var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=True)
+ self.assertAllEqual(x + y, var_value)
+ self.assertAllEqual(x + y, op_value)
+        var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=True)
+ self.assertAllEqual(x - y, var_value)
+ self.assertAllEqual(x - y, op_value)
+
+ def testBasic(self):
+ self._testTypes(np.arange(0, 20).reshape([4, 5]))
+
+ def testAssignNonStrictShapeChecking(self):
+ with self.test_session():
+ data = tf.fill([1024, 1024], 0)
+ p = tf.Variable([1])
+ a = tf.assign(p, data, validate_shape=False)
+ a.op.run()
+ self.assertAllEqual(p.eval(), data.eval())
+
+ # Assign to yet another shape
+ data2 = tf.fill([10, 10], 1)
+ a2 = tf.assign(p, data2, validate_shape=False)
+ a2.op.run()
+ self.assertAllEqual(p.eval(), data2.eval())
+
+ def testInitRequiredAssignAdd(self):
+ with self.test_session():
+      p = tf.Variable(tf.fill([1024, 1024], 1))
+ a = tf.assign_add(p, tf.fill([1024, 1024], 0))
+ with self.assertRaisesOpError("use uninitialized"):
+ a.op.run()
+
+ def testInitRequiredAssignSub(self):
+ with self.test_session():
+      p = tf.Variable(tf.fill([1024, 1024], 1))
+ a = tf.assign_sub(p, tf.fill([1024, 1024], 0))
+ with self.assertRaisesOpError("use uninitialized"):
+ a.op.run()
+
+ # NOTE(mrry): See also
+ # dense_update_ops_no_tsan_test.AssignOpTest, which contains a benign
+ # data race and must run without TSAN.
+ def testParallelUpdateWithLocking(self):
+ with self.test_session() as sess:
+ zeros_t = tf.fill([1024, 1024], 0.0)
+ ones_t = tf.fill([1024, 1024], 1.0)
+ p = tf.Variable(zeros_t)
+ adds = [tf.assign_add(p, ones_t, use_locking=True)
+ for _ in range(20)]
+ p.initializer.run()
+
+ def run_add(add_op):
+ sess.run(add_op)
+ threads = [
+ self.checkedThread(target=run_add, args=(add_op,)) for add_op in adds]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ vals = p.eval()
+ ones = np.ones((1024, 1024)).astype(np.float32)
+ self.assertAllEqual(vals, ones * 20)
+
+ # NOTE(mrry): See also
+ # dense_update_ops_no_tsan_test.[...].testParallelAssignWithoutLocking,
+ # which contains a benign data race and must run without TSAN.
+ def testParallelAssignWithLocking(self):
+ with self.test_session() as sess:
+ zeros_t = tf.fill([1024, 1024], 0.0)
+ ones_t = tf.fill([1024, 1024], 1.0)
+ p = tf.Variable(zeros_t)
+ assigns = [tf.assign(p, tf.mul(ones_t, float(i)),
+ use_locking=True)
+ for i in range(1, 21)]
+ p.initializer.run()
+
+ def run_assign(assign_op):
+ sess.run(assign_op)
+ threads = [self.checkedThread(target=run_assign, args=(assign_op,))
+ for assign_op in assigns]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ vals = p.eval()
+
+ # Assert every element is the same, and taken from one of the assignments.
+ self.assertTrue(vals[0, 0] > 0)
+ self.assertTrue(vals[0, 0] <= 20)
+ self.assertAllEqual(vals, np.ones([1024, 1024]) * vals[0, 0])
+
+
+if __name__ == "__main__":
+ tf.test.main()
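For readers skimming these tests, a minimal sketch of the Assign* state-update ops they cover (graph-mode, illustrative shapes; use_locking=True is what serializes the concurrent updates exercised above):

    import tensorflow as tf

    p = tf.Variable(tf.zeros([2, 2]))
    inc = tf.assign_add(p, tf.fill([2, 2], 1.0))  # p += 1; yields the new value
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print sess.run(inc)  # => all ones
        print sess.run(inc)  # => all twos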
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
new file mode 100644
index 0000000000..d4e2b88339
--- /dev/null
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -0,0 +1,72 @@
+"""Tests for tensorflow.ops.tf.MatrixDeterminant."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DeterminantOpTest(tf.test.TestCase):
+
+ def _compareDeterminant(self, matrix_x):
+ with self.test_session():
+ if matrix_x.ndim == 2:
+ tf_ans = tf.matrix_determinant(matrix_x)
+ else:
+ tf_ans = tf.batch_matrix_determinant(matrix_x)
+ out = tf_ans.eval()
+ shape = matrix_x.shape
+ if shape[-1] == 0 and shape[-2] == 0:
+ np_ans = np.ones(shape[:-2]).astype(matrix_x.dtype)
+ else:
+ np_ans = np.array(np.linalg.det(matrix_x)).astype(matrix_x.dtype)
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def testBasic(self):
+ # 2x2 matrices
+ self._compareDeterminant(np.array([[2., 3.], [3., 4.]]).astype(np.float32))
+ self._compareDeterminant(np.array([[0., 0.], [0., 0.]]).astype(np.float32))
+ # 5x5 matrices (Eigen forces LU decomposition)
+ self._compareDeterminant(np.array(
+ [[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [2., 5., 8., 3., 8.],
+ [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.float32))
+ # A multidimensional batch of 2x2 matrices
+ self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float32))
+
+ def testBasicDouble(self):
+ # 2x2 matrices
+ self._compareDeterminant(np.array([[2., 3.], [3., 4.]]).astype(np.float64))
+ self._compareDeterminant(np.array([[0., 0.], [0., 0.]]).astype(np.float64))
+ # 5x5 matrices (Eigen forces LU decomposition)
+ self._compareDeterminant(np.array(
+ [[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [2., 5., 8., 3., 8.],
+ [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.float64))
+ # A multidimensional batch of 2x2 matrices
+ self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float64))
+
+ def testOverflow(self):
+ max_double = np.finfo("d").max
+ huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]])
+ with self.assertRaisesOpError("not finite"):
+ self._compareDeterminant(huge_matrix)
+
+ def testNonSquareMatrix(self):
+    # Computing the determinant of a non-square matrix should raise an error.
+ with self.assertRaises(ValueError):
+ tf.matrix_determinant(
+ np.array([[1., 2., 3.], [3., 5., 4.]]).astype(np.float32))
+
+ def testWrongDimensions(self):
+ # The input to the determinant should be a 2-dimensional tensor.
+ tensor1 = tf.constant([1., 2.])
+ with self.assertRaises(ValueError):
+ tf.matrix_determinant(tensor1)
+
+ def testEmpty(self):
+ self._compareDeterminant(np.empty([0, 2, 2]))
+ self._compareDeterminant(np.empty([2, 0, 0]))
+
+
+if __name__ == "__main__":
+ tf.test.main()
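A minimal sketch of the MatrixDeterminant op under test (illustrative 2x2 input; tf.batch_matrix_determinant handles the batched cases above):

    import numpy as np
    import tensorflow as tf

    m = np.array([[2., 3.], [3., 4.]], dtype=np.float32)
    det = tf.matrix_determinant(m)  # scalar: 2*4 - 3*3
    with tf.Session() as sess:
        print sess.run(det)  # => -1.0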
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
new file mode 100644
index 0000000000..7b53ee26fa
--- /dev/null
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -0,0 +1,80 @@
+import tensorflow.python.platform
+
+import numpy
+import tensorflow as tf
+
+
+class GenerateIdentityTensorTest(tf.test.TestCase):
+
+ def _testDiagOp(self, diag, dtype, expected_ans, use_gpu=False,
+ expected_err_re=None):
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.diag(tf.convert_to_tensor(diag.astype(dtype)))
+ out = tf_ans.eval()
+ self.assertAllClose(out, expected_ans)
+ self.assertShapeEqual(expected_ans, tf_ans)
+
+ def testEmptyTensor(self):
+ x = numpy.array([])
+ expected_ans = numpy.empty([0, 0])
+ self._testDiagOp(x, numpy.int32, expected_ans)
+
+ def testRankOneIntTensor(self):
+ x = numpy.array([1, 2, 3])
+ expected_ans = numpy.array(
+ [[1, 0, 0],
+ [0, 2, 0],
+ [0, 0, 3]])
+ self._testDiagOp(x, numpy.int32, expected_ans)
+ self._testDiagOp(x, numpy.int64, expected_ans)
+
+ def testRankOneFloatTensor(self):
+ x = numpy.array([1.1, 2.2, 3.3])
+ expected_ans = numpy.array(
+ [[1.1, 0, 0],
+ [0, 2.2, 0],
+ [0, 0, 3.3]])
+ self._testDiagOp(x, numpy.float32, expected_ans)
+ self._testDiagOp(x, numpy.float64, expected_ans)
+
+ def testRankTwoIntTensor(self):
+ x = numpy.array([[1, 2, 3], [4, 5, 6]])
+ expected_ans = numpy.array(
+ [[[[1, 0, 0], [0, 0, 0]],
+ [[0, 2, 0], [0, 0, 0]],
+ [[0, 0, 3], [0, 0, 0]]],
+ [[[0, 0, 0], [4, 0, 0]],
+ [[0, 0, 0], [0, 5, 0]],
+ [[0, 0, 0], [0, 0, 6]]]])
+ self._testDiagOp(x, numpy.int32, expected_ans)
+ self._testDiagOp(x, numpy.int64, expected_ans)
+
+ def testRankTwoFloatTensor(self):
+ x = numpy.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
+ expected_ans = numpy.array(
+ [[[[1.1, 0, 0], [0, 0, 0]],
+ [[0, 2.2, 0], [0, 0, 0]],
+ [[0, 0, 3.3], [0, 0, 0]]],
+ [[[0, 0, 0], [4.4, 0, 0]],
+ [[0, 0, 0], [0, 5.5, 0]],
+ [[0, 0, 0], [0, 0, 6.6]]]])
+ self._testDiagOp(x, numpy.float32, expected_ans)
+ self._testDiagOp(x, numpy.float64, expected_ans)
+
+ def testRankThreeFloatTensor(self):
+ x = numpy.array([[[1.1, 2.2], [3.3, 4.4]],
+ [[5.5, 6.6], [7.7, 8.8]]])
+ expected_ans = numpy.array(
+ [[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
+ [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
+ [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
+ [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
+ [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
+ [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
+ [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
+ [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
+ self._testDiagOp(x, numpy.float32, expected_ans)
+ self._testDiagOp(x, numpy.float64, expected_ans)
+
+if __name__ == "__main__":
+ tf.test.main()
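As the expected tensors above illustrate, tf.diag of a rank-k input produces a rank-2k output with the input laid along the diagonal. A minimal rank-1 sketch (illustrative values):

    import tensorflow as tf

    d = tf.diag([1.0, 2.0, 3.0])  # rank-1 in, rank-2 out
    with tf.Session() as sess:
        print sess.run(d)
        # => [[ 1.  0.  0.]
        #     [ 0.  2.  0.]
        #     [ 0.  0.  3.]]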
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
new file mode 100644
index 0000000000..a7a276893d
--- /dev/null
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -0,0 +1,99 @@
+"""Tests for the DynamicPartition op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DynamicPartitionTest(tf.test.TestCase):
+
+ def testSimpleOneDimensional(self):
+ with self.test_session() as sess:
+ data = tf.constant([0, 13, 2, 39, 4, 17])
+ indices = tf.constant([0, 0, 2, 3, 2, 1])
+ partitions = tf.dynamic_partition(data, indices, num_partitions=4)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([0, 13], partition_vals[0])
+ self.assertAllEqual([17], partition_vals[1])
+ self.assertAllEqual([2, 4], partition_vals[2])
+ self.assertAllEqual([39], partition_vals[3])
+ # Vector data input to DynamicPartition results in
+ # `num_partitions` vectors of unknown length.
+ self.assertEqual([None], partitions[0].get_shape().as_list())
+ self.assertEqual([None], partitions[1].get_shape().as_list())
+ self.assertEqual([None], partitions[2].get_shape().as_list())
+ self.assertEqual([None], partitions[3].get_shape().as_list())
+
+ def testSimpleTwoDimensional(self):
+ with self.test_session() as sess:
+ data = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14], [15, 16, 17]])
+ indices = tf.constant([0, 0, 2, 3, 2, 1])
+ partitions = tf.dynamic_partition(data, indices, num_partitions=4)
+ partition_vals = sess.run(partitions)
+
+ self.assertAllEqual([[0, 1, 2], [3, 4, 5]], partition_vals[0])
+ self.assertAllEqual([[15, 16, 17]], partition_vals[1])
+ self.assertAllEqual([[6, 7, 8], [12, 13, 14]], partition_vals[2])
+ self.assertAllEqual([[9, 10, 11]], partition_vals[3])
+      # Matrix data input to DynamicPartition results in
+      # `num_partitions` matrices with an unknown number of rows and 3 columns.
+ self.assertEqual([None, 3], partitions[0].get_shape().as_list())
+ self.assertEqual([None, 3], partitions[1].get_shape().as_list())
+ self.assertEqual([None, 3], partitions[2].get_shape().as_list())
+ self.assertEqual([None, 3], partitions[3].get_shape().as_list())
+
+ def testHigherRank(self):
+ np.random.seed(7)
+ with self.test_session() as sess:
+ for n in 2, 3:
+ for shape in (4,), (4, 5), (4, 5, 2):
+ partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape)
+ for extra_shape in (), (6,), (6, 7):
+ data = np.random.randn(*(shape + extra_shape))
+ outputs = tf.dynamic_partition(data, partitions, num_partitions=n)
+ self.assertEqual(n, len(outputs))
+ for i, output in enumerate(sess.run(outputs)):
+ self.assertAllEqual(output, data[partitions == i])
+
+ def testErrorIndexOutOfRange(self):
+ with self.test_session() as sess:
+ data = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14]])
+ indices = tf.constant([0, 2, 99, 2, 2])
+ partitions = tf.dynamic_partition(data, indices, num_partitions=4)
+ with self.assertRaisesOpError(r"partitions\[2\] = 99 is not in \[0, 4\)"):
+ sess.run(partitions)
+
+ def testScalarIndexOutOfRange(self):
+ with self.test_session() as sess:
+ bad = 17
+ data = np.zeros(5)
+ partitions = tf.dynamic_partition(data, bad, num_partitions=7)
+ with self.assertRaisesOpError(r"partitions = 17 is not in \[0, 7\)"):
+ sess.run(partitions)
+
+ def testHigherRankIndexOutOfRange(self):
+ with self.test_session() as sess:
+ shape = (2, 3)
+ indices = tf.placeholder(shape=shape, dtype=np.int32)
+ data = np.zeros(shape + (5,))
+ partitions = tf.dynamic_partition(data, indices, num_partitions=7)
+ for i in xrange(2):
+ for j in xrange(3):
+ bad = np.zeros(shape, dtype=np.int32)
+ bad[i, j] = 17
+ with self.assertRaisesOpError(
+ r"partitions\[%d,%d\] = 17 is not in \[0, 7\)" % (i, j)):
+ sess.run(partitions, feed_dict={indices: bad})
+
+ def testErrorWrongDimsIndices(self):
+ data = tf.constant([[0], [1], [2]])
+ indices = tf.constant([[0], [0]])
+ with self.assertRaises(ValueError):
+ tf.dynamic_partition(data, indices, num_partitions=4)
+
+
+if __name__ == "__main__":
+ tf.test.main()
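A minimal sketch of DynamicPartition as exercised above (illustrative data; element i of partitions selects which output receives data[i]):

    import tensorflow as tf

    data = tf.constant([10, 20, 30, 40, 50])
    partitions = tf.constant([0, 1, 0, 1, 0])
    outs = tf.dynamic_partition(data, partitions, num_partitions=2)
    with tf.Session() as sess:
        print sess.run(outs)  # => [10 30 50] and [20 40]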
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
new file mode 100644
index 0000000000..9ac49390b9
--- /dev/null
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -0,0 +1,107 @@
+"""Tests for tensorflow.ops.data_flow_ops.dynamic_stitch."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DynamicStitchTest(tf.test.TestCase):
+
+ def testScalar(self):
+ with self.test_session():
+ indices = [tf.constant(0), tf.constant(1)]
+ data = [tf.constant(40), tf.constant(60)]
+ for step in -1, 1:
+ stitched_t = tf.dynamic_stitch(indices[::step], data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual([40, 60][::step], stitched_val)
+ # Dimension 0 is determined by the max index in indices, so we
+ # can only infer that the output is a vector of some unknown
+ # length.
+ self.assertEqual([None], stitched_t.get_shape().as_list())
+
+ def testSimpleOneDimensional(self):
+ with self.test_session():
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6, 2, 3, 5])]
+ data = [tf.constant([0, 40, 70]),
+ tf.constant([10, 60, 20, 30, 50])]
+ stitched_t = tf.dynamic_stitch(indices, data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
+ # Dimension 0 is determined by the max index in indices, so we
+ # can only infer that the output is a vector of some unknown
+ # length.
+ self.assertEqual([None], stitched_t.get_shape().as_list())
+
+ def testSimpleTwoDimensional(self):
+ with self.test_session():
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6]),
+ tf.constant([2, 3, 5])]
+ data = [tf.constant([[0, 1], [40, 41], [70, 71]]),
+ tf.constant([[10, 11], [60, 61]]),
+ tf.constant([[20, 21], [30, 31], [50, 51]])]
+ stitched_t = tf.dynamic_stitch(indices, data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual(
+ [[0, 1], [10, 11], [20, 21], [30, 31],
+ [40, 41], [50, 51], [60, 61], [70, 71]], stitched_val)
+ # Dimension 0 is determined by the max index in indices, so we
+ # can only infer that the output is a matrix with 2 columns and
+ # some unknown number of rows.
+ self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+
+ def testHigherRank(self):
+ with self.test_session() as sess:
+ indices = [tf.constant(6), tf.constant([4, 1]),
+ tf.constant([[5, 2], [0, 3]])]
+ data = [tf.constant([61, 62]), tf.constant([[41, 42], [11, 12]]),
+ tf.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])]
+ stitched_t = tf.dynamic_stitch(indices, data)
+ stitched_val = stitched_t.eval()
+ correct = 10 * np.arange(7)[:, None] + [1, 2]
+ self.assertAllEqual(correct, stitched_val)
+ self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+ # Test gradients
+ stitched_grad = 7 * stitched_val
+ grads = tf.gradients(stitched_t, indices + data, stitched_grad)
+ self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients
+ for datum, grad in zip(data, sess.run(grads[3:])):
+ self.assertAllEqual(7 * datum.eval(), grad)
+
+ def testErrorIndicesMultiDimensional(self):
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([[1, 6, 2, 3, 5]])]
+ data = [tf.constant([[0, 40, 70]]),
+ tf.constant([10, 60, 20, 30, 50])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+ def testErrorDataNumDimsMismatch(self):
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6, 2, 3, 5])]
+ data = [tf.constant([0, 40, 70]),
+ tf.constant([[10, 60, 20, 30, 50]])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+ def testErrorDataDimSizeMismatch(self):
+ indices = [tf.constant([0, 4, 5]),
+ tf.constant([1, 6, 2, 3])]
+ data = [tf.constant([[0], [40], [70]]),
+ tf.constant([[10, 11], [60, 61], [20, 21], [30, 31]])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+ def testErrorDataAndIndicesSizeMismatch(self):
+ indices = [tf.constant([0, 4, 7]),
+ tf.constant([1, 6, 2, 3, 5])]
+ data = [tf.constant([0, 40, 70]),
+ tf.constant([10, 60, 20, 30])]
+ with self.assertRaises(ValueError):
+ tf.dynamic_stitch(indices, data)
+
+
+if __name__ == "__main__":
+ tf.test.main()
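DynamicStitch is roughly the inverse of DynamicPartition: it interleaves the value tensors back according to their indices, with later values winning on overlapping indices. A minimal sketch (illustrative values):

    import tensorflow as tf

    indices = [tf.constant([0, 2, 4]), tf.constant([1, 3])]
    values = [tf.constant([10, 30, 50]), tf.constant([20, 40])]
    merged = tf.dynamic_stitch(indices, values)
    with tf.Session() as sess:
        print sess.run(merged)  # => [10 20 30 40 50]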
diff --git a/tensorflow/python/kernel_tests/edit_distance_op_test.py b/tensorflow/python/kernel_tests/edit_distance_op_test.py
new file mode 100644
index 0000000000..5919adcfaf
--- /dev/null
+++ b/tensorflow/python/kernel_tests/edit_distance_op_test.py
@@ -0,0 +1,153 @@
+"""Tests for tensorflow.kernels.edit_distance_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+def ConstantOf(x):
+ x = np.asarray(x)
+ # Convert to int64 if it's not a string
+ if x.dtype.char != "S": x = np.asarray(x, dtype=np.int64)
+ return tf.constant(x)
+
+
+class EditDistanceTest(tf.test.TestCase):
+
+ def _testEditDistance(self, hypothesis, truth, normalize,
+ expected_output, expected_err_re=None):
+ # hypothesis and truth are (index, value, shape) tuples
+ hypothesis_st = tf.SparseTensor(*[ConstantOf(x) for x in hypothesis])
+ truth_st = tf.SparseTensor(*[ConstantOf(x) for x in truth])
+ edit_distance = tf.edit_distance(
+ hypothesis=hypothesis_st, truth=truth_st, normalize=normalize)
+
+ with self.test_session():
+ if expected_err_re is None:
+ # Shape inference figures out the shape from the shape variables
+ expected_shape = [
+ max(h, t) for h, t in zip(hypothesis[2], truth[2])[:-1]]
+ self.assertEqual(edit_distance.get_shape(), expected_shape)
+ output = edit_distance.eval()
+ self.assertAllClose(output, expected_output)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ edit_distance.eval()
+
+ def testEditDistanceNormalized(self):
+ hypothesis_indices = [[0, 0], [0, 1],
+ [1, 0], [1, 1]]
+ hypothesis_values = [0, 1,
+ 1, -1]
+ hypothesis_shape = [2, 2]
+ truth_indices = [[0, 0],
+ [1, 0], [1, 1]]
+ truth_values = [0,
+ 1, 1]
+ truth_shape = [2, 2]
+ expected_output = [1.0, 0.5]
+
+ self._testEditDistance(
+ hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
+ truth=(truth_indices, truth_values, truth_shape),
+ normalize=True,
+ expected_output=expected_output)
+
+ def testEditDistanceUnnormalized(self):
+ hypothesis_indices = [[0, 0],
+ [1, 0], [1, 1]]
+ hypothesis_values = [10,
+ 10, 11]
+ hypothesis_shape = [2, 2]
+ truth_indices = [[0, 0], [0, 1],
+ [1, 0], [1, 1]]
+ truth_values = [1, 2,
+ 1, -1]
+ truth_shape = [2, 3]
+ expected_output = [2.0, 2.0]
+
+ self._testEditDistance(
+ hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
+ truth=(truth_indices, truth_values, truth_shape),
+ normalize=False,
+ expected_output=expected_output)
+
+ def testEditDistanceProperDistance(self):
+ # In this case, the values are individual characters stored in the
+ # SparseTensor (type DT_STRING)
+ hypothesis_indices = ([[0, i] for i, _ in enumerate("algorithm")] +
+ [[1, i] for i, _ in enumerate("altruistic")])
+ hypothesis_values = [x for x in "algorithm"] + [x for x in "altruistic"]
+ hypothesis_shape = [2, 11]
+ truth_indices = ([[0, i] for i, _ in enumerate("altruistic")] +
+ [[1, i] for i, _ in enumerate("algorithm")])
+ truth_values = [x for x in "altruistic"] + [x for x in "algorithm"]
+ truth_shape = [2, 11]
+ expected_unnormalized = [6.0, 6.0]
+ expected_normalized = [6.0/len("altruistic"),
+ 6.0/len("algorithm")]
+
+ self._testEditDistance(
+ hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
+ truth=(truth_indices, truth_values, truth_shape),
+ normalize=False,
+ expected_output=expected_unnormalized)
+
+ self._testEditDistance(
+ hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
+ truth=(truth_indices, truth_values, truth_shape),
+ normalize=True,
+ expected_output=expected_normalized)
+
+ def testEditDistance3D(self):
+ hypothesis_indices = [[0, 0, 0],
+ [1, 0, 0]]
+ hypothesis_values = [0, 1]
+ hypothesis_shape = [2, 1, 1]
+ truth_indices = [[0, 1, 0],
+ [1, 0, 0],
+ [1, 1, 0]]
+ truth_values = [0, 1, 1]
+ truth_shape = [2, 2, 1]
+ expected_output = [[np.inf, 1.0], # (0,0): no truth, (0,1): no hypothesis
+ [0.0, 1.0]] # (1,0): match, (1,1): no hypothesis
+
+ self._testEditDistance(
+ hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
+ truth=(truth_indices, truth_values, truth_shape),
+ normalize=True,
+ expected_output=expected_output)
+
+ def testEditDistanceMissingHypothesis(self):
+ hypothesis_indices = np.empty((0, 2), dtype=np.int64)
+ hypothesis_values = []
+ hypothesis_shape = [1, 0]
+ truth_indices = [[0, 0]]
+ truth_values = [0]
+ truth_shape = [1, 1]
+ expected_output = [1.0]
+
+ self._testEditDistance(
+ hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
+ truth=(truth_indices, truth_values, truth_shape),
+ normalize=True,
+ expected_output=expected_output)
+
+ def testEditDistanceMissingTruth(self):
+ hypothesis_indices = [[0, 0]]
+ hypothesis_values = [0]
+ hypothesis_shape = [1, 1]
+ truth_indices = np.empty((0, 2), dtype=np.int64)
+ truth_values = []
+ truth_shape = [1, 0]
+ expected_output = [np.inf] # Normalized, divide by zero
+
+ self._testEditDistance(
+ hypothesis=(hypothesis_indices, hypothesis_values, hypothesis_shape),
+ truth=(truth_indices, truth_values, truth_shape),
+ normalize=True,
+ expected_output=expected_output)
+
+
+if __name__ == "__main__":
+ tf.test.main()
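Both arguments to tf.edit_distance are SparseTensors whose last dimension indexes sequence positions, as the index/value/shape triples above show. A minimal sketch (illustrative token ids; a length-2 hypothesis against a length-1 truth gives an unnormalized distance of 1):

    import tensorflow as tf

    hyp = tf.SparseTensor(tf.constant([[0, 0], [0, 1]], tf.int64),
                          tf.constant([1, 2], tf.int64),
                          tf.constant([1, 2], tf.int64))
    truth = tf.SparseTensor(tf.constant([[0, 0]], tf.int64),
                            tf.constant([1], tf.int64),
                            tf.constant([1, 1], tf.int64))
    dist = tf.edit_distance(hypothesis=hyp, truth=truth, normalize=False)
    with tf.Session() as sess:
        print sess.run(dist)  # => [ 1.]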
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
new file mode 100644
index 0000000000..99aa2453dc
--- /dev/null
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -0,0 +1,422 @@
+"""Functional tests for ops used with embeddings."""
+import itertools
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+def _AsLong(array):
+ """Casts arrays elements to long type. Used to convert from numpy tf."""
+ return [long(x) for x in array]
+
+
+class ScatterAddSubTest(tf.test.TestCase):
+
+ def _TestCase(self, shape, indices, scatter_op=tf.scatter_add):
+ """Run a random test case with the given shape and indices.
+
+ Args:
+ shape: Shape of the parameters array.
+      indices: One-dimensional array of ints, the indices of the first
+        dimension of the parameters to update.
+ scatter_op: ScatterAdd or ScatterSub.
+ """
+ super(ScatterAddSubTest, self).setUp()
+ with self.test_session(use_gpu=False):
+ # Create a random parameter array of given shape
+ p_init = np.random.rand(*shape).astype("f")
+      # Create the shape of the update array. All dimensions except the first
+      # match the parameter array; the first dimension equals the # of indices.
+ vals_shape = [len(indices)] + shape[1:]
+ vals_init = np.random.rand(*vals_shape).astype("f")
+ v_i = [float(x) for x in vals_init.ravel()]
+ p = tf.Variable(p_init)
+ vals = tf.constant(v_i, shape=vals_shape, name="vals")
+ ind = tf.constant(indices, dtype=tf.int32)
+ p2 = scatter_op(p, ind, vals, name="updated_p")
+ # p = init
+ tf.initialize_all_variables().run()
+ # p += vals
+ result = p2.eval()
+ # Compute the expected 'p' using numpy operations.
+ for i, ind in enumerate(indices):
+ if scatter_op == tf.scatter_add:
+ p_init.reshape(shape[0], -1)[ind, :] += (
+ vals_init.reshape(vals_shape[0], -1)[i, :])
+ else:
+ p_init.reshape(shape[0], -1)[ind, :] -= (
+ vals_init.reshape(vals_shape[0], -1)[i, :])
+ self.assertTrue(all((p_init == result).ravel()))
+
+ def testNoRepetitions(self):
+ self._TestCase([2, 2], [1])
+ self._TestCase([4, 4, 4], [2, 0])
+ self._TestCase([43, 20, 10, 10], [42, 5, 6, 1, 3, 5, 7, 9])
+
+ def testWithRepetitions(self):
+ self._TestCase([2, 2], [1, 1])
+ self._TestCase([5, 3, 9, 5], [2, 0, 4, 1, 3, 1, 4, 0, 4, 3])
+ self._TestCase([32, 4, 4], [31] * 8)
+
+ def testRandom(self):
+ # Random shapes of rank 4, random indices
+ for _ in range(5):
+ shape = np.random.randint(1, 20, size=4)
+ indices = np.random.randint(shape[0], size=2 * shape[0])
+ self._TestCase(_AsLong(list(shape)), list(indices))
+
+ def testSubRandom(self):
+ # Random shapes of rank 4, random indices
+ for _ in range(5):
+ shape = np.random.randint(1, 20, size=4)
+ indices = np.random.randint(shape[0], size=2 * shape[0])
+ self._TestCase(_AsLong(list(shape)), list(indices),
+ tf.scatter_sub)
+
+ def testWrongShape(self):
+ # Indices and values mismatch.
+ var = tf.Variable(tf.zeros(shape=[1024, 64, 64], dtype=tf.float32))
+ indices = tf.placeholder(tf.int32, shape=[32])
+ values = tf.placeholder(tf.float32, shape=[33, 64, 64])
+ with self.assertRaises(ValueError):
+ tf.scatter_add(var, indices, values)
+
+ # Var and values mismatch.
+ values = tf.placeholder(tf.float32, shape=[32, 64, 63])
+ with self.assertRaises(ValueError):
+ tf.scatter_add(var, indices, values)
+
+
+def _PName(param_id):
+ return "p" + str(param_id)
+
+
+def _EmbeddingParams(num_shards, vocab_size,
+ dtype=tf.float32,
+ shape=None):
+ p = []
+ params = {}
+ feed_dict = {}
+ if not shape: shape = [10]
+ assert not vocab_size % num_shards
+ shape = [vocab_size / num_shards] + shape
+ for i in range(num_shards):
+ param_name = _PName(i)
+ constant_t = tf.constant(1.0, shape=shape, dtype=dtype,
+ name=param_name)
+ p.append(constant_t)
+ np_type = "f" if dtype == tf.float32 else "d"
+ val = (np.random.rand(*shape).astype(np_type)) + 1
+ params[param_name + ":0"] = val
+ feed_dict[constant_t.name] = val
+ return p, params, feed_dict
+
+
+def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None):
+ if weight_vals is None:
+ weight_vals = np.copy(id_vals)
+ weight_vals.fill(1)
+ values = []
+ weights = []
+ for ids, wts in zip(id_vals, weight_vals):
+ val_aggr = None
+ wt_aggr = None
+ if isinstance(ids, int):
+ ids = [ids]
+ wts = [wts]
+ for i, wt_val in zip(ids, wts):
+ val = np.copy(params[_PName(i % num_shards) + ":0"]
+ [i / num_shards, :]) * wt_val
+ if val_aggr is None:
+ assert wt_aggr is None
+ val_aggr = val
+ wt_aggr = wt_val
+ else:
+ assert wt_aggr is not None
+ val_aggr += val
+ wt_aggr += wt_val
+ values.append(val_aggr)
+ weights.append(wt_aggr)
+ values = np.array(values).astype(np.float32)
+ weights = np.array(weights).astype(np.float32)
+ return values, weights
+
+
+class EmbeddingLookupTest(tf.test.TestCase):
+
+ # This test looks up [0, 0] in a parameter matrix sharded 2 ways. Since
+ # both the ids are in the first shard, one of the resulting lookup
+  # vectors is going to be empty. The subsequent DivOp fails because of that.
+ # TODO(keveman): Disabling the test until the underlying problem is fixed.
+ def testSimpleSharded(self):
+ with self.test_session():
+ num_shards = 2
+ vocab_size = 4
+ p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
+
+ id_vals = np.array([0, 0])
+ ids = tf.constant(list(id_vals), dtype=tf.int32)
+ print "Construct ids", ids.get_shape()
+ embedding = tf.nn.embedding_lookup(p, ids)
+
+ tf_result = embedding.eval(feed_dict=feed_dict)
+ np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
+ self.assertAllEqual(np_result, tf_result)
+ self.assertShapeEqual(np_result, embedding)
+
+ def testSharded(self):
+ with self.test_session():
+ num_shards = 5
+ vocab_size = 25
+      # The embedding dimension is 10. The vocab_size x 10 embedding
+      # parameters are spread across num_shards matrices, so each
+      # matrix is (vocab_size / num_shards) x 10.
+ p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
+
+ num_vals = 30
+ # Fetch num_vals embeddings for random word ids. Since
+ # num_vals > vocab_size, this ought to have repetitions, so
+ # will test that aspect.
+ id_vals = np.random.randint(vocab_size, size=num_vals)
+ ids = tf.constant(list(id_vals), dtype=tf.int32)
+
+ embedding = tf.nn.embedding_lookup(p, ids)
+ tf_result = embedding.eval(feed_dict=feed_dict)
+ np_result, _ = _EmbeddingResult(params, id_vals, num_shards)
+ self.assertAllEqual(np_result, tf_result)
+ self.assertShapeEqual(np_result, embedding)
+
+ def testGradientsEmbeddingLookup(self):
+ vocab_size = 9
+ num_ids = 5
+ id_vals = list(np.random.randint(vocab_size, size=num_ids))
+ tf.logging.vlog(1, id_vals)
+ for num_shards in [1, 3]:
+ with self.test_session():
+ ids = tf.constant(id_vals, dtype=tf.int32)
+ x, params, _ = _EmbeddingParams(
+ num_shards, vocab_size, shape=[2])
+ y = tf.nn.embedding_lookup(x, ids)
+ y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
+ x_name = [_PName(i) for i in range(num_shards)]
+ x_init_value = [params[x_n + ":0"] for x_n in x_name]
+ x_shape = [i.shape for i in x_init_value]
+ err = gc.ComputeGradientError(x, x_shape, y, y_shape,
+ x_init_value=x_init_value)
+ self.assertLess(err, 1e-4)
+
+ def testGradientsEmbeddingLookupWithComputedParams(self):
+ vocab_size = 9
+ num_ids = 5
+ id_vals = list(np.random.randint(vocab_size, size=num_ids))
+ tf.logging.vlog(1, id_vals)
+ for num_shards in [1, 3]:
+ with self.test_session():
+ ids = tf.constant(id_vals, dtype=tf.int32)
+ x, params, _ = _EmbeddingParams(
+ num_shards, vocab_size, shape=[2])
+ # This will force a conversion from IndexedSlices to Tensor.
+ x_squared = [tf.square(elem) for elem in x]
+ y = tf.nn.embedding_lookup(x_squared, ids)
+ y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
+ x_name = [_PName(i) for i in range(num_shards)]
+ x_init_value = [params[x_n + ":0"] for x_n in x_name]
+ x_shape = [i.shape for i in x_init_value]
+ err = gc.ComputeGradientError(x, x_shape, y, y_shape,
+ x_init_value=x_init_value)
+ self.assertLess(err, 1e-3)
+
+ def testConstructionNonSharded(self):
+ with tf.Graph().as_default():
+ p = tf.Variable(tf.zeros(shape=[100, 100], dtype=tf.float32))
+ ids = tf.constant([0, 1, 1, 7], dtype=tf.int32)
+ tf.nn.embedding_lookup([p], ids)
+
+ def testConstructionSharded(self):
+ with tf.Graph().as_default():
+ p = []
+ for _ in range(2):
+ p += [tf.Variable(tf.zeros(shape=[100, 100], dtype=tf.float32))]
+ ids = tf.constant([0, 1, 1, 17], dtype=tf.int32)
+ tf.nn.embedding_lookup(p, ids)
+
+ def testHigherRank(self):
+ np.random.seed(8)
+ with self.test_session():
+ for params_shape in (12,), (6, 3):
+ params = np.random.randn(*params_shape)
+ for ids_shape in (3, 2), (4, 3):
+ ids = np.random.randint(params.shape[0],
+ size=np.prod(ids_shape)).reshape(ids_shape)
+ # Compare nonsharded to gather
+ simple = tf.nn.embedding_lookup(params, ids).eval()
+ self.assertAllEqual(simple, tf.gather(params, ids).eval())
+ # Run a few random sharded versions
+ for procs in 1, 2, 3:
+ stride = procs * tf.range(0, params.shape[0] / procs)
+ split_params = [tf.gather(params, stride + p)
+ for p in xrange(procs)]
+ sharded = tf.nn.embedding_lookup(split_params, ids).eval()
+ self.assertAllEqual(simple, sharded)
+
+
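The sharded lookups above place row i of the logical parameter matrix in shard i % num_shards at offset i / num_shards, as _EmbeddingResult computes. A minimal sketch comparing a two-way sharded lookup against a plain gather (illustrative sizes):

    import numpy as np
    import tensorflow as tf

    params = np.random.randn(6, 4).astype(np.float32)
    ids = tf.constant([0, 3, 5])
    single = tf.nn.embedding_lookup(params, ids)   # equivalent to tf.gather
    shards = [tf.constant(params[0::2]), tf.constant(params[1::2])]
    sharded = tf.nn.embedding_lookup(shards, ids)  # two-way sharding
    with tf.Session() as sess:
        a, b = sess.run([single, sharded])
        print np.allclose(a, b)  # => True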
+class EmbeddingLookupSparseTest(tf.test.TestCase):
+
+ def _RandomIdsAndWeights(self, batch_size, vocab_size):
+ max_val_per_entry = 6
+ vals_per_batch_entry = np.random.randint(
+ 1, max_val_per_entry, size=batch_size)
+ num_vals = np.sum(vals_per_batch_entry)
+
+ ids = np.random.randint(vocab_size, size=num_vals)
+ weights = 1 + np.random.rand(num_vals)
+
+ indices = []
+ for batch_entry, num_val in enumerate(vals_per_batch_entry):
+ for val_index in range(num_val):
+ indices.append([batch_entry, val_index])
+
+ shape = [batch_size, max_val_per_entry]
+
+ sp_ids = tf.SparseTensor(
+ tf.constant(indices, tf.int64),
+ tf.constant(ids, tf.int32),
+ tf.constant(shape, tf.int64))
+ sp_weights = tf.SparseTensor(
+ tf.constant(indices, tf.int64),
+ tf.constant(weights, tf.float32),
+ tf.constant(shape, tf.int64))
+
+ return sp_ids, sp_weights, ids, weights, vals_per_batch_entry
+
+ def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
+ grouped_vals = []
+ index = 0
+ for num_val in vals_per_batch_entry:
+ grouped_vals.append(list(vals[index: (index + num_val)]))
+ index += num_val
+ return grouped_vals
+
+ def testEmbeddingLookupSparse(self):
+ vocab_size = 25
+ batch_size = 10
+ param_shape = [2, 5]
+
+ sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
+ self._RandomIdsAndWeights(batch_size, vocab_size))
+
+ grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
+ grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
+ grouped_ignored_weights = self._GroupByBatchEntry(
+ np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
+
+ for num_shards, combiner, dtype, ignore_weights in itertools.product(
+ [1, 5],
+ ["sum", "mean"],
+ [tf.float32, tf.float64],
+ [True, False]):
+
+ with self.test_session():
+ p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size,
+ shape=param_shape,
+ dtype=dtype)
+ embedding_sum = tf.nn.embedding_lookup_sparse(
+ p, sp_ids, None if ignore_weights else sp_weights,
+ combiner=combiner)
+ tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
+
+ np_embedding_sum, np_weight_sum = _EmbeddingResult(
+ params, grouped_ids, num_shards,
+ weight_vals=grouped_ignored_weights
+ if ignore_weights else grouped_weights)
+ if combiner == "mean":
+ np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
+ self.assertAllClose(np_embedding_sum, tf_embedding_sum)
+
+ def testGradientsEmbeddingLookupSparse(self):
+ vocab_size = 12
+ batch_size = 4
+ param_shape = [2, 3]
+ sp_ids, sp_weights, _, _, _ = (
+ self._RandomIdsAndWeights(batch_size, vocab_size))
+
+ for num_shards, combiner, dtype, ignore_weights in itertools.product(
+ [1, 3],
+ ["sum", "mean"],
+ [tf.float32, tf.float64],
+ [True, False]):
+ with self.test_session():
+ x, params, _ = _EmbeddingParams(num_shards, vocab_size,
+ shape=param_shape,
+ dtype=dtype)
+
+ y = tf.nn.embedding_lookup_sparse(
+ x, sp_ids, None if ignore_weights else sp_weights,
+ combiner=combiner)
+ x_name = [_PName(i) for i in range(num_shards)]
+ x_init_value = [params[x_n + ":0"] for x_n in x_name]
+ x_shape = [i.shape for i in x_init_value]
+ y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
+ err = gc.ComputeGradientError(x, x_shape, y, y_shape,
+ x_init_value=x_init_value)
+ self.assertLess(err, 1e-5 if dtype == tf.float64 else 2e-3)
+
+
+class DynamicStitchOpTest(tf.test.TestCase):
+
+ def testCint32Cpu(self):
+ with self.test_session(use_gpu=False):
+ indices = [tf.convert_to_tensor([0, 1, 2]), tf.convert_to_tensor([2, 3])]
+ values = [tf.convert_to_tensor([12, 23, 34]), tf.convert_to_tensor([1, 2])]
+ self.assertAllEqual(
+ tf.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])
+
+ def testCint32Gpu(self):
+ with self.test_session(use_gpu=True):
+ indices = [tf.convert_to_tensor([0, 1, 2]), tf.convert_to_tensor([2, 3])]
+ values = [tf.convert_to_tensor([12, 23, 34]), tf.convert_to_tensor([1, 2])]
+ self.assertAllEqual(
+ tf.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])
+
+ def testInt32Cpu(self):
+ with self.test_session(use_gpu=False):
+ indices = [tf.convert_to_tensor([0, 1, 2]), tf.convert_to_tensor([2, 3])]
+ values = [tf.convert_to_tensor([12, 23, 34]), tf.convert_to_tensor([1, 2])]
+ self.assertAllEqual(
+ tf.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])
+
+ def testInt32Gpu(self):
+ with self.test_session(use_gpu=True):
+ indices = [tf.convert_to_tensor([0, 1, 2]), tf.convert_to_tensor([2, 3])]
+ values = [tf.convert_to_tensor([12, 23, 34]), tf.convert_to_tensor([1, 2])]
+ self.assertAllEqual(
+ tf.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])
+
+ def testSumGradArgs(self):
+ with self.test_session(use_gpu=False):
+ indices = [tf.convert_to_tensor([0, 1, 2, 3]),
+ tf.convert_to_tensor([2, 3])]
+ values = [tf.convert_to_tensor([2, 3, 5, 7]), tf.convert_to_tensor([1, 1])]
+ self.assertAllEqual(
+ tf.dynamic_stitch(indices, values).eval(), [2, 3, 1, 1])
+
+ # We expect that the values are merged in order.
+ def testStitchOrder(self):
+ with self.test_session():
+ indices = []
+ np_values = []
+ values = []
+ for _ in range(10):
+ indices.extend([tf.convert_to_tensor(np.arange(100).astype(np.int32))])
+ np_values.extend([np.random.uniform(size=100)])
+ values.extend([tf.convert_to_tensor(np_values[-1])])
+ stitched = tf.dynamic_stitch(indices, values).eval()
+ self.assertAllEqual(np_values[-1], stitched)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
new file mode 100644
index 0000000000..57448db433
--- /dev/null
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -0,0 +1,1043 @@
+"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
+import random
+import re
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class FIFOQueueTest(tf.test.TestCase):
+
+ def testConstructor(self):
+ with tf.Graph().as_default():
+ q = tf.FIFOQueue(10, tf.float32, name="Q")
+ self.assertTrue(isinstance(q.queue_ref, tf.Tensor))
+ self.assertEquals(tf.string_ref, q.queue_ref.dtype)
+ self.assertProtoEquals("""
+ name:'Q' op:'FIFOQueue'
+ attr { key: 'component_types' value { list { type: DT_FLOAT } } }
+ attr { key: 'shapes' value { list {} } }
+ attr { key: 'capacity' value { i: 10 } }
+ attr { key: 'container' value { s: '' } }
+ attr { key: 'shared_name' value { s: '' } }
+ """, q.queue_ref.op.node_def)
+
+ def testMultiQueueConstructor(self):
+ with tf.Graph().as_default():
+ q = tf.FIFOQueue(5, (tf.int32, tf.float32),
+ shared_name="foo", name="Q")
+ self.assertTrue(isinstance(q.queue_ref, tf.Tensor))
+ self.assertEquals(tf.string_ref, q.queue_ref.dtype)
+ self.assertProtoEquals("""
+ name:'Q' op:'FIFOQueue'
+ attr { key: 'component_types' value { list {
+ type: DT_INT32 type : DT_FLOAT
+ } } }
+ attr { key: 'shapes' value { list {} } }
+ attr { key: 'capacity' value { i: 5 } }
+ attr { key: 'container' value { s: '' } }
+ attr { key: 'shared_name' value { s: 'foo' } }
+ """, q.queue_ref.op.node_def)
+
+ def testConstructorWithShapes(self):
+ with tf.Graph().as_default():
+ q = tf.FIFOQueue(5, (tf.int32, tf.float32),
+ shapes=(tf.TensorShape([1, 1, 2, 3]),
+ tf.TensorShape([5, 8])), name="Q")
+ self.assertTrue(isinstance(q.queue_ref, tf.Tensor))
+ self.assertEquals(tf.string_ref, q.queue_ref.dtype)
+ self.assertProtoEquals("""
+ name:'Q' op:'FIFOQueue'
+ attr { key: 'component_types' value { list {
+ type: DT_INT32 type : DT_FLOAT
+ } } }
+ attr { key: 'shapes' value { list {
+ shape { dim { size: 1 }
+ dim { size: 1 }
+ dim { size: 2 }
+ dim { size: 3 } }
+ shape { dim { size: 5 }
+ dim { size: 8 } }
+ } } }
+ attr { key: 'capacity' value { i: 5 } }
+ attr { key: 'container' value { s: '' } }
+ attr { key: 'shared_name' value { s: '' } }
+ """, q.queue_ref.op.node_def)
+
+ def testEnqueue(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ enqueue_op = q.enqueue((10.0,))
+ enqueue_op.run()
+
+ def testEnqueueWithShape(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32, shapes=(3, 2))
+ enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
+ enqueue_correct_op.run()
+ with self.assertRaises(ValueError):
+ q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
+ self.assertEqual(1, q.size().eval())
+
+ def testEnqueueManyWithShape(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, [tf.int32, tf.int32],
+ shapes=[(), (2,)])
+ q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
+ self.assertEqual(4, q.size().eval())
+
+ def testParallelEnqueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Run one producer thread for each element in elems.
+ def enqueue(enqueue_op):
+ sess.run(enqueue_op)
+ threads = [self.checkedThread(target=enqueue, args=(e,))
+ for e in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ # Dequeue every element using a single thread.
+ results = []
+ for _ in xrange(len(elems)):
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems, results)
+
+ def testParallelDequeue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Enqueue every element using a single thread.
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ # Run one consumer thread for each element in elems.
+ results = []
+
+ def dequeue():
+ results.append(sess.run(dequeued_t))
+ threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, results)
+
+ def testDequeue(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ vals = dequeued_t.eval()
+ self.assertEqual([elems[i]], vals)
+
+ def testEnqueueAndBlockingDequeue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(3, tf.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ def enqueue():
+ # The enqueue_ops should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for enqueue_op in enqueue_ops:
+ sess.run(enqueue_op)
+
+ results = []
+
+ def dequeue():
+ for _ in xrange(len(elems)):
+ results.append(sess.run(dequeued_t))
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ for elem, result in zip(elems, results):
+ self.assertEqual([elem], result)
+
+ def testMultiEnqueueAndDequeue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, (tf.int32, tf.float32))
+ elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
+ enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ x_val, y_val = sess.run(dequeued_t)
+ x, y = elems[i]
+ self.assertEqual([x], x_val)
+ self.assertEqual([y], y_val)
+
+ def testQueueSizeEmpty(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ self.assertEqual([0], q.size().eval())
+
+ def testQueueSizeAfterEnqueueAndDequeue(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue()
+ size = q.size()
+ self.assertEqual([], size.get_shape())
+
+ enqueue_op.run()
+ self.assertEqual(1, size.eval())
+ dequeued_t.op.run()
+ self.assertEqual(0, size.eval())
+
+ def testEnqueueMany(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue()
+ enqueue_op.run()
+ enqueue_op.run()
+
+ for i in range(8):
+ vals = dequeued_t.eval()
+ self.assertEqual([elems[i % 4]], vals)
+
+ def testEmptyEnqueueMany(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ empty_t = tf.constant([], dtype=tf.float32,
+ shape=[0, 2, 3])
+ enqueue_op = q.enqueue_many((empty_t,))
+ size_t = q.size()
+
+ self.assertEqual([0], size_t.eval())
+ enqueue_op.run()
+ self.assertEqual([0], size_t.eval())
+
+ def testEmptyDequeueMany(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32, shapes=())
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue_many(0)
+
+ self.assertEqual([], dequeued_t.eval().tolist())
+ enqueue_op.run()
+ self.assertEqual([], dequeued_t.eval().tolist())
+
+ def testEmptyDequeueManyWithNoShape(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ # Expect the operation to fail due to the shape not being constrained.
+ with self.assertRaisesOpError("specified shapes"):
+ q.dequeue_many(0).eval()
+
+ def testMultiEnqueueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, (tf.float32, tf.int32))
+ float_elems = [10.0, 20.0, 30.0, 40.0]
+ int_elems = [[1, 2], [3, 4], [5, 6], [7, 8]]
+ enqueue_op = q.enqueue_many((float_elems, int_elems))
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+ enqueue_op.run()
+
+ for i in range(8):
+ float_val, int_val = sess.run(dequeued_t)
+ self.assertEqual(float_elems[i % 4], float_val)
+ self.assertAllEqual(int_elems[i % 4], int_val)
+
+ def testDequeueMany(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(4)
+
+ enqueue_op.run()
+
+ self.assertAllEqual(elems[0:4], dequeued_t.eval())
+ self.assertAllEqual(elems[4:8], dequeued_t.eval())
+
+ def testMultiDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, (tf.float32, tf.int32),
+ shapes=((), (2,)))
+ float_elems = [
+ 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ int_elems = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
+ [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]]
+ enqueue_op = q.enqueue_many((float_elems, int_elems))
+ dequeued_t = q.dequeue_many(4)
+ dequeued_single_t = q.dequeue()
+
+ enqueue_op.run()
+
+ float_val, int_val = sess.run(dequeued_t)
+ self.assertAllEqual(float_elems[0:4], float_val)
+ self.assertAllEqual(int_elems[0:4], int_val)
+ self.assertEqual(float_val.shape, dequeued_t[0].get_shape())
+ self.assertEqual(int_val.shape, dequeued_t[1].get_shape())
+
+ float_val, int_val = sess.run(dequeued_t)
+ self.assertAllEqual(float_elems[4:8], float_val)
+ self.assertAllEqual(int_elems[4:8], int_val)
+
+ float_val, int_val = sess.run(dequeued_single_t)
+ self.assertAllEqual(float_elems[8], float_val)
+ self.assertAllEqual(int_elems[8], int_val)
+ self.assertEqual(float_val.shape, dequeued_single_t[0].get_shape())
+ self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
+
+ def testHighDimension(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.int32, (4, 4, 4, 4))
+ elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(10)
+
+ enqueue_op.run()
+ self.assertAllEqual(dequeued_t.eval(), elems)
+
+ def testParallelEnqueueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(1000, tf.float32, shapes=())
+ elems = [10.0 * x for x in range(100)]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(1000)
+
+      # Enqueue the 100 items in parallel from each of 10 threads (1000 total).
+ def enqueue():
+ sess.run(enqueue_op)
+ threads = [self.checkedThread(target=enqueue) for _ in range(10)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ self.assertItemsEqual(dequeued_t.eval(), elems * 10)
+
+ def testParallelDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(1000, tf.float32, shapes=())
+ elems = [10.0 * x for x in range(1000)]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(100)
+
+ enqueue_op.run()
+
+ # Dequeue 100 items in parallel on 10 threads.
+ dequeued_elems = []
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t))
+ threads = [self.checkedThread(target=dequeue) for _ in range(10)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, dequeued_elems)
+
+ def testParallelEnqueueAndDequeue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(50, tf.float32, shapes=())
+ initial_elements = [10.0] * 49
+ q.enqueue_many((initial_elements,)).run()
+
+ enqueue_op = q.enqueue((20.0,))
+ dequeued_t = q.dequeue()
+
+ def enqueue():
+ for _ in xrange(100):
+ sess.run(enqueue_op)
+ def dequeue():
+ for _ in xrange(100):
+ self.assertTrue(sess.run(dequeued_t) in (10.0, 20.0))
+
+ enqueue_threads = [self.checkedThread(target=enqueue) for _ in range(10)]
+ dequeue_threads = [self.checkedThread(target=dequeue) for _ in range(10)]
+ for enqueue_thread in enqueue_threads:
+ enqueue_thread.start()
+ for dequeue_thread in dequeue_threads:
+ dequeue_thread.start()
+ for enqueue_thread in enqueue_threads:
+ enqueue_thread.join()
+ for dequeue_thread in dequeue_threads:
+ dequeue_thread.join()
+
+ # Dequeue the initial count of elements to clean up.
+ cleanup_elems = q.dequeue_many(49).eval()
+ for elem in cleanup_elems:
+ self.assertTrue(elem in (10.0, 20.0))
+
+ def testMixtureOfEnqueueAndEnqueueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.int32, shapes=())
+ enqueue_placeholder = tf.placeholder(tf.int32, shape=())
+ enqueue_op = q.enqueue((enqueue_placeholder,))
+ enqueuemany_placeholder = tf.placeholder(
+ tf.int32, shape=(None,))
+ enqueuemany_op = q.enqueue_many((enqueuemany_placeholder,))
+
+ dequeued_t = q.dequeue()
+ close_op = q.close()
+
+ def dequeue():
+ for i in xrange(250):
+ self.assertEqual(i, sess.run(dequeued_t))
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+
+ elements_enqueued = 0
+ while elements_enqueued < 250:
+ # With equal probability, run Enqueue or enqueue_many.
+ if random.random() > 0.5:
+ enqueue_op.run({enqueue_placeholder: elements_enqueued})
+ elements_enqueued += 1
+ else:
+ count = random.randint(0, min(20, 250 - elements_enqueued))
+ range_to_enqueue = range(elements_enqueued, elements_enqueued + count)
+ enqueuemany_op.run({enqueuemany_placeholder: range_to_enqueue})
+ elements_enqueued += count
+
+ close_op.run()
+ dequeue_thread.join()
+ self.assertEqual(0, q.size().eval())
+
+ def testMixtureOfDequeueAndDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.int32, shapes=())
+ enqueue_op = q.enqueue_many((range(250),))
+ dequeued_t = q.dequeue()
+ count_placeholder = tf.placeholder(tf.int32, shape=())
+ dequeuemany_t = q.dequeue_many(count_placeholder)
+
+ def enqueue():
+ sess.run(enqueue_op)
+ enqueue_thread = self.checkedThread(target=enqueue)
+ enqueue_thread.start()
+
+ elements_dequeued = 0
+ while elements_dequeued < 250:
+ # With equal probability, run Dequeue or dequeue_many.
+ if random.random() > 0.5:
+ self.assertEqual(elements_dequeued, dequeued_t.eval())
+ elements_dequeued += 1
+ else:
+ count = random.randint(0, min(20, 250 - elements_dequeued))
+ expected_range = range(elements_dequeued, elements_dequeued + count)
+ self.assertAllEqual(
+ expected_range, dequeuemany_t.eval({count_placeholder: count}))
+ elements_dequeued += count
+
+ q.close().run()
+ enqueue_thread.join()
+ self.assertEqual(0, q.size().eval())
+
+ def testBlockingDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(4)
+
+ dequeued_elems = []
+
+ def enqueue():
+ # The enqueue_op should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ sess.run(enqueue_op)
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t).tolist())
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ self.assertAllEqual(elems, dequeued_elems)
+
+ def testDequeueManyWithTensorParameter(self):
+ with self.test_session():
+ # Define a first queue that contains integer counts.
+ dequeue_counts = [random.randint(1, 10) for _ in range(100)]
+ count_q = tf.FIFOQueue(100, tf.int32, ())
+ enqueue_counts_op = count_q.enqueue_many((dequeue_counts,))
+ total_count = sum(dequeue_counts)
+
+ # Define a second queue that contains total_count elements.
+ elems = [random.randint(0, 100) for _ in range(total_count)]
+ q = tf.FIFOQueue(total_count, tf.int32, ())
+ enqueue_elems_op = q.enqueue_many((elems,))
+
+      # Define a subgraph that first dequeues a count, then dequeues that
+      # many elements from the second queue.
+ dequeued_t = q.dequeue_many(count_q.dequeue())
+
+ enqueue_counts_op.run()
+ enqueue_elems_op.run()
+
+ dequeued_elems = []
+ for _ in dequeue_counts:
+ dequeued_elems.extend(dequeued_t.eval())
+ self.assertEqual(elems, dequeued_elems)
+
+ def testDequeueFromClosedQueue(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+ close_op.run()
+ for elem in elems:
+ self.assertEqual([elem], dequeued_t.eval())
+
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ dequeued_t.eval()
+
+ def testBlockingDequeueFromClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+
+ def dequeue():
+ for elem in elems:
+ self.assertEqual([elem], sess.run(dequeued_t))
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
+ def testBlockingDequeueFromClosedEmptyQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32)
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+
+ def dequeue():
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
+ def testBlockingDequeueManyFromClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(4)
+
+ enqueue_op.run()
+
+ def dequeue():
+ self.assertAllEqual(elems, sess.run(dequeued_t))
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
+ def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(4, tf.float32, ())
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(3)
+ cleanup_dequeue_t = q.dequeue()
+
+ def enqueue():
+ sess.run(enqueue_op)
+
+ def dequeue():
+ self.assertAllEqual(elems[0:3], sess.run(dequeued_t))
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(dequeued_t)
+ self.assertEqual(elems[3], sess.run(cleanup_dequeue_t))
+
+ def close():
+ sess.run(close_op)
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ enqueue_thread.start()
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ close_thread = self.checkedThread(target=close)
+ close_thread.start()
+
+ enqueue_thread.join()
+ dequeue_thread.join()
+ close_thread.join()
+
+ def testClosedBlockingDequeueManyRestoresPartialBatch(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(4, (tf.float32, tf.float32), ((), ()))
+ elems_a = [1.0, 2.0, 3.0]
+ elems_b = [10.0, 20.0, 30.0]
+ enqueue_op = q.enqueue_many((elems_a, elems_b))
+ dequeued_a_t, dequeued_b_t = q.dequeue_many(4)
+ cleanup_dequeue_a_t, cleanup_dequeue_b_t = q.dequeue()
+ close_op = q.close()
+
+ enqueue_op.run()
+
+ def dequeue():
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run([dequeued_a_t, dequeued_b_t])
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ close_op.run()
+ dequeue_thread.join()
+ # Test that the elements in the partially-dequeued batch are
+ # restored in the correct order.
+ for elem_a, elem_b in zip(elems_a, elems_b):
+ val_a, val_b = sess.run([cleanup_dequeue_a_t, cleanup_dequeue_b_t])
+ self.assertEqual(elem_a, val_a)
+ self.assertEqual(elem_b, val_b)
+ self.assertEqual(0, q.size().eval())
+
+ def testBlockingDequeueManyFromClosedEmptyQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(10, tf.float32, ())
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(4)
+
+ def dequeue():
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
+ def testEnqueueToClosedQueue(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ enqueue_op = q.enqueue((10.0,))
+ close_op = q.close()
+
+ enqueue_op.run()
+ close_op.run()
+
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.AbortedError, "is closed"):
+ enqueue_op.run()
+
+ def testEnqueueManyToClosedQueue(self):
+ with self.test_session():
+ q = tf.FIFOQueue(10, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+
+ enqueue_op.run()
+ close_op.run()
+
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.AbortedError, "is closed"):
+ enqueue_op.run()
+
+ def testBlockingEnqueueToFullQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(4, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue((50.0,))
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+
+ def blocking_enqueue():
+ sess.run(blocking_enqueue_op)
+ thread = self.checkedThread(target=blocking_enqueue)
+ thread.start()
+ # The dequeue ops should run after the blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for elem in elems:
+ self.assertEqual([elem], dequeued_t.eval())
+ self.assertEqual([50.0], dequeued_t.eval())
+ thread.join()
+
+ def testBlockingEnqueueManyToFullQueue(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(4, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue_many(([50.0, 60.0],))
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+
+ def blocking_enqueue():
+ sess.run(blocking_enqueue_op)
+ thread = self.checkedThread(target=blocking_enqueue)
+ thread.start()
+ # The dequeue ops should run after the blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for elem in elems:
+ self.assertEqual([elem], dequeued_t.eval())
+ time.sleep(0.01)
+ self.assertEqual([50.0], dequeued_t.eval())
+      self.assertEqual([60.0], dequeued_t.eval())
+      thread.join()
+
+ def testBlockingEnqueueBeforeClose(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(4, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue((50.0,))
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+
+ def blocking_enqueue():
+ # Expect the operation to succeed once the dequeue op runs.
+ sess.run(blocking_enqueue_op)
+ enqueue_thread = self.checkedThread(target=blocking_enqueue)
+ enqueue_thread.start()
+
+ # The close_op should run after the blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ def close():
+ sess.run(close_op)
+ close_thread = self.checkedThread(target=close)
+ close_thread.start()
+
+ # The dequeue will unblock both threads.
+ self.assertEqual(10.0, dequeued_t.eval())
+ enqueue_thread.join()
+ close_thread.join()
+
+ for elem in [20.0, 30.0, 40.0, 50.0]:
+ self.assertEqual(elem, dequeued_t.eval())
+ self.assertEqual(0, q.size().eval())
+
+ def testBlockingEnqueueManyBeforeClose(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(4, tf.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue_many(([50.0, 60.0],))
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+ enqueue_op.run()
+
+ def blocking_enqueue():
+ sess.run(blocking_enqueue_op)
+ enqueue_thread = self.checkedThread(target=blocking_enqueue)
+ enqueue_thread.start()
+
+ # The close_op should run after the blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ def close():
+ sess.run(close_op)
+ close_thread = self.checkedThread(target=close)
+ close_thread.start()
+
+ # The dequeue will unblock both threads.
+ self.assertEqual(10.0, dequeued_t.eval())
+ enqueue_thread.join()
+ close_thread.join()
+ for elem in [20.0, 30.0, 50.0, 60.0]:
+ self.assertEqual(elem, dequeued_t.eval())
+
+ def testDoesNotLoseValue(self):
+ with self.test_session():
+ q = tf.FIFOQueue(1, tf.float32)
+ enqueue_op = q.enqueue((10.0,))
+ size_t = q.size()
+
+ enqueue_op.run()
+ for _ in range(500):
+ self.assertEqual(size_t.eval(), [1])
+
+ def testSharedQueueSameSession(self):
+ with self.test_session():
+ q1 = tf.FIFOQueue(
+ 1, tf.float32, shared_name="shared_queue")
+ q1.enqueue((10.0,)).run()
+
+ q2 = tf.FIFOQueue(
+ 1, tf.float32, shared_name="shared_queue")
+
+ q1_size_t = q1.size()
+ q2_size_t = q2.size()
+
+ self.assertEqual(q1_size_t.eval(), [1])
+ self.assertEqual(q2_size_t.eval(), [1])
+
+ self.assertEqual(q2.dequeue().eval(), [10.0])
+
+ self.assertEqual(q1_size_t.eval(), [0])
+ self.assertEqual(q2_size_t.eval(), [0])
+
+ q2.enqueue((20.0,)).run()
+
+ self.assertEqual(q1_size_t.eval(), [1])
+ self.assertEqual(q2_size_t.eval(), [1])
+
+ self.assertEqual(q1.dequeue().eval(), [20.0])
+
+ self.assertEqual(q1_size_t.eval(), [0])
+ self.assertEqual(q2_size_t.eval(), [0])
+
+ def testIncompatibleSharedQueueErrors(self):
+ with self.test_session():
+ q_a_1 = tf.FIFOQueue(10, tf.float32, shared_name="q_a")
+ q_a_2 = tf.FIFOQueue(15, tf.float32, shared_name="q_a")
+ q_a_1.queue_ref.eval()
+ with self.assertRaisesOpError("capacity"):
+ q_a_2.queue_ref.eval()
+
+ q_b_1 = tf.FIFOQueue(10, tf.float32, shared_name="q_b")
+ q_b_2 = tf.FIFOQueue(10, tf.int32, shared_name="q_b")
+ q_b_1.queue_ref.eval()
+ with self.assertRaisesOpError("component types"):
+ q_b_2.queue_ref.eval()
+
+ q_c_1 = tf.FIFOQueue(10, tf.float32, shared_name="q_c")
+ q_c_2 = tf.FIFOQueue(
+ 10, tf.float32, shapes=[(1, 1, 2, 3)], shared_name="q_c")
+ q_c_1.queue_ref.eval()
+ with self.assertRaisesOpError("component shapes"):
+ q_c_2.queue_ref.eval()
+
+ q_d_1 = tf.FIFOQueue(
+ 10, tf.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d")
+ q_d_2 = tf.FIFOQueue(10, tf.float32, shared_name="q_d")
+ q_d_1.queue_ref.eval()
+ with self.assertRaisesOpError("component shapes"):
+ q_d_2.queue_ref.eval()
+
+ q_e_1 = tf.FIFOQueue(
+ 10, tf.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e")
+ q_e_2 = tf.FIFOQueue(
+ 10, tf.float32, shapes=[(1, 1, 2, 4)], shared_name="q_e")
+ q_e_1.queue_ref.eval()
+ with self.assertRaisesOpError("component shapes"):
+ q_e_2.queue_ref.eval()
+
+ q_f_1 = tf.FIFOQueue(10, tf.float32, shared_name="q_f")
+ q_f_2 = tf.FIFOQueue(
+ 10, (tf.float32, tf.int32), shared_name="q_f")
+ q_f_1.queue_ref.eval()
+ with self.assertRaisesOpError("component types"):
+ q_f_2.queue_ref.eval()
+
+ def testSelectQueue(self):
+ with self.test_session():
+ num_queues = 10
+ qlist = list()
+ for _ in xrange(num_queues):
+ qlist.append(tf.FIFOQueue(10, tf.float32))
+ # Enqueue/Dequeue into a dynamically selected queue
+ for _ in xrange(20):
+ index = np.random.randint(num_queues)
+ q = tf.FIFOQueue.from_list(index, qlist)
+ q.enqueue((10.,)).run()
+ self.assertEqual(q.dequeue().eval(), 10.0)
+
+ def testSelectQueueOutOfRange(self):
+ with self.test_session():
+ q1 = tf.FIFOQueue(10, tf.float32)
+ q2 = tf.FIFOQueue(15, tf.float32)
+ enq_q = tf.FIFOQueue.from_list(3, [q1, q2])
+ with self.assertRaisesOpError("Index must be in the range"):
+ enq_q.dequeue().eval()
+
+ def _blockingDequeue(self, sess, dequeue_op):
+ with self.assertRaisesOpError("Dequeue operation was cancelled"):
+ sess.run(dequeue_op)
+
+ def _blockingDequeueMany(self, sess, dequeue_many_op):
+ with self.assertRaisesOpError("Dequeue operation was cancelled"):
+ sess.run(dequeue_many_op)
+
+ def _blockingEnqueue(self, sess, enqueue_op):
+ with self.assertRaisesOpError("Enqueue operation was cancelled"):
+ sess.run(enqueue_op)
+
+ def _blockingEnqueueMany(self, sess, enqueue_many_op):
+ with self.assertRaisesOpError("Enqueue operation was cancelled"):
+ sess.run(enqueue_many_op)
+
+ def testResetOfBlockingOperation(self):
+ with self.test_session() as sess:
+ q_empty = tf.FIFOQueue(5, tf.float32, ())
+ dequeue_op = q_empty.dequeue()
+ dequeue_many_op = q_empty.dequeue_many(1)
+
+ q_full = tf.FIFOQueue(5, tf.float32)
+ sess.run(q_full.enqueue_many(([1.0, 2.0, 3.0, 4.0, 5.0],)))
+ enqueue_op = q_full.enqueue((6.0,))
+ enqueue_many_op = q_full.enqueue_many(([6.0],))
+
+ threads = [
+ self.checkedThread(self._blockingDequeue, args=(sess, dequeue_op)),
+ self.checkedThread(self._blockingDequeueMany, args=(sess,
+ dequeue_many_op)),
+ self.checkedThread(self._blockingEnqueue, args=(sess, enqueue_op)),
+ self.checkedThread(self._blockingEnqueueMany, args=(sess,
+ enqueue_many_op))]
+ for t in threads:
+ t.start()
+ time.sleep(0.1)
+ sess.close() # Will cancel the blocked operations.
+ for t in threads:
+ t.join()
+
+ def testBigEnqueueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(5, tf.int32, ((),))
+ elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ enq = q.enqueue_many((elem,))
+ deq = q.dequeue()
+ size_op = q.size()
+
+ enq_done = []
+ def blocking_enqueue():
+ enq_done.append(False)
+ # This will fill the queue and then block until enough dequeues happen.
+ sess.run(enq)
+ enq_done.append(True)
+ thread = self.checkedThread(target=blocking_enqueue)
+ thread.start()
+
+ # The enqueue should start and then block.
+ results = []
+ results.append(deq.eval()) # Will only complete after the enqueue starts.
+ self.assertEqual(len(enq_done), 1)
+ self.assertEqual(sess.run(size_op), 5)
+
+ for _ in range(3):
+ results.append(deq.eval())
+
+ time.sleep(0.1)
+ self.assertEqual(len(enq_done), 1)
+ self.assertEqual(sess.run(size_op), 5)
+
+ # This dequeue will unblock the thread.
+ results.append(deq.eval())
+ time.sleep(0.1)
+ self.assertEqual(len(enq_done), 2)
+ thread.join()
+
+ for i in range(5):
+ self.assertEqual(size_op.eval(), 5 - i)
+ results.append(deq.eval())
+ self.assertEqual(size_op.eval(), 5 - i - 1)
+
+ self.assertAllEqual(elem, results)
+
+ def testBigDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.FIFOQueue(2, tf.int32, ((),))
+ elem = range(4)
+ enq_list = [q.enqueue((e,)) for e in elem]
+ deq = q.dequeue_many(4)
+
+ results = []
+ def blocking_dequeue():
+ # Will only complete after 4 enqueues complete.
+ results.extend(sess.run(deq))
+ thread = self.checkedThread(target=blocking_dequeue)
+ thread.start()
+ # The dequeue should start and then block.
+ for enq in enq_list:
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ self.assertEqual(len(results), 0)
+ sess.run(enq)
+
+ # Enough enqueued to unblock the dequeue
+ thread.join()
+ self.assertAllEqual(elem, results)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
new file mode 100644
index 0000000000..39e97531d2
--- /dev/null
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -0,0 +1,71 @@
+"""Tests for tensorflow.ops.tf.gather."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class GatherTest(tf.test.TestCase):
+
+ def testScalar1D(self):
+ with self.test_session():
+ params = tf.constant([0, 1, 2, 3, 7, 5])
+ indices = tf.constant(4)
+ gather_t = tf.gather(params, indices)
+ gather_val = gather_t.eval()
+ self.assertAllEqual(7, gather_val)
+ self.assertEqual([], gather_t.get_shape())
+
+ def testScalar2D(self):
+ with self.test_session():
+ params = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14]])
+ indices = tf.constant(2)
+ gather_t = tf.gather(params, indices)
+ gather_val = gather_t.eval()
+ self.assertAllEqual([6, 7, 8], gather_val)
+ self.assertEqual([3], gather_t.get_shape())
+
+ def testSimpleTwoD32(self):
+ with self.test_session():
+ params = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
+ [9, 10, 11], [12, 13, 14]])
+ indices = tf.constant([0, 4, 0, 2])
+ gather_t = tf.gather(params, indices)
+ gather_val = gather_t.eval()
+ self.assertAllEqual([[0, 1, 2], [12, 13, 14], [0, 1, 2], [6, 7, 8]],
+ gather_val)
+ self.assertEqual([4, 3], gather_t.get_shape())
+
+ def testHigherRank(self):
+ np.random.seed(1)
+ shape = (4, 3, 2)
+ params = np.random.randn(*shape)
+ indices = np.random.randint(shape[0], size=15).reshape(3, 5)
+ with self.test_session():
+ tf_params = tf.constant(params)
+ tf_indices = tf.constant(indices)
+ gather = tf.gather(tf_params, tf_indices)
+ self.assertAllEqual(params[indices], gather.eval())
+ self.assertEqual(indices.shape + params.shape[1:], gather.get_shape())
+ # Test gradients
+ gather_grad = np.random.randn(*gather.get_shape().as_list())
+ params_grad, indices_grad = tf.gradients(gather, [tf_params, tf_indices],
+ gather_grad)
+ self.assertEqual(indices_grad, None)
+ self.assertEqual(type(params_grad), tf.IndexedSlices)
+ params_grad = tf.convert_to_tensor(params_grad)
+ correct_params_grad = np.zeros(shape)
+ for i, g in zip(indices.ravel(), gather_grad.reshape((15,) + shape[1:])):
+ correct_params_grad[i] += g
+ self.assertAllEqual(correct_params_grad, params_grad.eval())
+
+ def testUnknownIndices(self):
+ params = tf.constant([[0, 1, 2]])
+ indices = tf.placeholder(tf.int32)
+ gather_t = tf.gather(params, indices)
+ self.assertEqual(None, gather_t.get_shape())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/gradient_checker.py b/tensorflow/python/kernel_tests/gradient_checker.py
new file mode 100644
index 0000000000..fe74768986
--- /dev/null
+++ b/tensorflow/python/kernel_tests/gradient_checker.py
@@ -0,0 +1,251 @@
+"""Gradient checker for any ops, graphs.
+
+The gradient checker verifies numerically that an op/graph properly
+computes the gradients.
+"""
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gradients
+from tensorflow.python.platform import logging
+
+
+def _Product(t):
+ if isinstance(t, int):
+ return t
+ else:
+ y = 1
+ for x in t:
+ y *= x
+ return y
+
+
+def _ComputeTheoreticalJacobian(x, x_shape, x_data, dy, dy_shape, dx):
+ """Computes the theoretical Jacobian for dy/dx.
+
+ Computes the theoretical Jacobian using the ops generated by
+ ComputeGradient().
+
+ Args:
+ x: the tensor "x".
+ x_shape: the dimensions of x as a tuple or an array of ints.
+    x_data: a numpy array as the input data for x
+ dy: the tensor "dy".
+ dy_shape: the dimensions of dy as a tuple or an array of ints.
+ dx: Tensor or IndexedSlices representing dx
+
+ Returns:
+ A 2-d numpy array representing the Jacobian for dy/dx. It has "x_size" rows
+ and "dy_size" columns where "x_size" is the number of elements in x and
+ "dy_size" is the number of elements in dy.
+ """
+  # To compute the Jacobian, we treat x and y as one-dimensional vectors.
+ x_size = _Product(x_shape)
+ x_val_size = _Product(x_shape[1:]) # This is used for sparse gradients
+ dy_size = _Product(dy_shape)
+
+ jacobian = np.zeros((x_size, dy_size), dtype=x_data.dtype)
+  # For each entry of dy, we set it to 1 and everything else to 0, then
+  # compute the backprop -- this gives us one column of the Jacobian matrix.
+ for col in range(0, dy_size):
+ dy_data = np.zeros(dy_shape, dtype=x_data.dtype)
+ dy_data.flat[col] = 1
+ sess = ops.get_default_session()
+ if isinstance(dx, ops.IndexedSlices):
+ backprop_indices, backprop_values = sess.run(
+ [dx.indices, dx.values], feed_dict={x: x_data, dy: dy_data})
+ for i, v in zip(backprop_indices, backprop_values):
+ r_begin = i * x_val_size
+ r_end = r_begin + x_val_size
+ jacobian[r_begin:r_end, col] += v.flat
+ else:
+ assert isinstance(dx, ops.Tensor), "dx = " + str(dx)
+ backprop = sess.run(dx, feed_dict={x: x_data, dy: dy_data})
+ jacobian[:, col] = backprop.reshape(x_size)
+
+ logging.vlog(1, "Theoretical Jacobian =\n%s", jacobian)
+ return jacobian
+
+
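+# A worked example of the one-hot dy trick above (an illustrative sketch, not
+# used by the checker): for y = x[0] + 2 * x[1], feeding dy = [1.0] backprops
+# to dx = [1.0, 2.0], which fills the single column of the (x_size=2,
+# dy_size=1) Jacobian:
+#
+#   jacobian = [[1.0],
+#               [2.0]]
+#
+# Repeating this with every one-hot dy fills the matrix column by column.
+
+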
+def _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta):
+ """Computes the numeric Jacobian for dy/dx.
+
+  Computes the numeric Jacobian by slightly perturbing the inputs and
+  measuring the differences in the outputs.
+
+ Args:
+ x: the tensor "x".
+ x_shape: the dimensions of x as a tuple or an array of ints.
+ x_data: a numpy array as the input data for x
+ y: the tensor "y".
+ y_shape: the dimensions of y as a tuple or an array of ints.
+ delta: the amount of perturbation we give to the input
+
+ Returns:
+ A 2-d numpy array representing the Jacobian for dy/dx. It has "x_size" rows
+ and "y_size" columns where "x_size" is the number of elements in x and
+ "y_size" is the number of elements in y.
+ """
+
+  # To compute the Jacobian, we treat x and y as one-dimensional vectors.
+ x_size = _Product(x_shape)
+ y_size = _Product(y_shape)
+
+ jacobian = np.zeros((x_size, y_size), dtype=x_data.dtype)
+  # For each entry of x, we slightly perturb it by adding and subtracting a
+  # delta, then compute the difference between the outputs. This gives us one
+  # row of the Jacobian matrix.
+ for row in range(0, x_size):
+ x_pos = x_data.copy()
+ x_pos.flat[row] += delta
+ y_pos = y.eval(feed_dict={x: x_pos})
+ x_neg = x_data.copy()
+ x_neg.flat[row] -= delta
+ y_neg = y.eval(feed_dict={x: x_neg})
+ diff = (y_pos - y_neg) / (2 * delta)
+ jacobian[row, :] = diff.reshape(y_size)
+
+ logging.vlog(1, "Numeric Jacobian =\n%s", jacobian)
+ return jacobian
+
+
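+# An illustrative sketch of the same central-difference idea applied to a
+# plain NumPy function, independent of TensorFlow. The helper below is not
+# used by the checker; its name is made up for this example and it only
+# demonstrates the formula (f(x + delta) - f(x - delta)) / (2 * delta).
+def _NumericJacobianSketch(f, x, delta=1e-3):
+  x = np.array(x, dtype=np.float64)
+  y = np.asarray(f(x))
+  jacobian = np.zeros((x.size, y.size))
+  for row in range(x.size):
+    # Perturb one input entry in each direction and difference the outputs.
+    x_pos = x.copy()
+    x_pos.flat[row] += delta
+    x_neg = x.copy()
+    x_neg.flat[row] -= delta
+    diff = (np.asarray(f(x_pos)) - np.asarray(f(x_neg))) / (2 * delta)
+    jacobian[row, :] = diff.reshape(y.size)
+  return jacobian
+
+
+# For instance, _NumericJacobianSketch(lambda v: v ** 2, [1.0, 2.0]) is close
+# to [[2.0, 0.0], [0.0, 4.0]], i.e. the exact Jacobian diag(2 * x).
+
+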
+def _ComputeDxAndDy(x, y, y_shape):
+  """Returns nodes for the gradient of y with respect to x, and the fed dy."""
+ # We make up a dy so that we can compute the gradients. We don't really use
+ # the value of dy -- we will always feed it. We need to add an identity node
+ # so that we can always feed it properly. Otherwise, for the Add operation,
+ # dx is the same as dy and we cannot fetch the tensor that we are feeding.
+ with x.graph.as_default():
+ dy_orig = constant_op.constant(1.0, shape=y_shape, dtype=y.dtype)
+ dy = array_ops.identity(dy_orig)
+    # We compute the gradients of y with respect to x, weighted by dy.
+ grads = gradients.gradients(y, x, dy)
+ assert len(grads) == 1
+ return grads[0], dy_orig
+
+
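+# A concrete case of the identity trick above (sketch only): for y = x + c,
+# gradients(y, x, dy) is just dy, so without the identity node the returned dx
+# would be the very tensor that gets fed -- and a tensor cannot be both fed
+# and fetched in the same run. Routing dy through array_ops.identity() gives
+# the feed its own distinct tensor, so dx can always be fetched.
+
+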
+def _ComputeGradient(x, x_shape, dx, y, y_shape, dy,
+ x_init_value=None, delta=1e-3):
+  """Computes the theoretical and numeric Jacobians."""
+ t = types.as_dtype(x.dtype)
+ allowed_types = [types.float32, types.float64]
+ assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name
+ t2 = types.as_dtype(y.dtype)
+ assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name
+
+ if x_init_value is not None:
+ i_shape = list(x_init_value.shape)
+ assert(list(x_shape) == i_shape), "x_shape = %s, init_data shape = %s" % (
+ x_shape, i_shape)
+ x_data = x_init_value
+ else:
+ if t == types.float32:
+ dtype = np.float32
+ else:
+ dtype = np.float64
+ x_data = np.asfarray(np.random.random_sample(x_shape), dtype=dtype)
+
+  jacob_t = _ComputeTheoreticalJacobian(x, x_shape, x_data, dy, y_shape, dx)
+ jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta)
+ return jacob_t, jacob_n
+
+
+def _ComputeGradientList(
+ x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None):
+ """Compute gradients for a list of x values."""
+ assert isinstance(x, list)
+ dx, dy = zip(*[_ComputeDxAndDy(xi, y, y_shape) for xi in x])
+
+ if init_targets is not None:
+ assert isinstance(init_targets, (list, tuple))
+ for init in init_targets:
+ init.run()
+ if x_init_value is None:
+ x_init_value = [None] * len(x)
+ ret = [_ComputeGradient(xi, x_shapei, dxi, y, y_shape, dyi,
+ x_init_valuei, delta)
+ for xi, x_shapei, dxi, dyi, x_init_valuei in
+ zip(x, x_shape, dx, dy, x_init_value)]
+ return ret
+
+
+def ComputeGradient(
+ x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None):
+ """Computes and returns the theoretical and numerical Jacobian.
+
+ Args:
+ x: a tensor or list of tensors
+ x_shape: the dimensions of x as a tuple or an array of ints. If x is a list,
+ then this is the list of shapes.
+ y: a tensor
+ y_shape: the dimensions of y as a tuple or an array of ints.
+ x_init_value: (optional) a numpy array of the same shape as "x"
+ representing the initial value of x. If x is a list, this should be a list
+    of numpy arrays. If this is None, the function will pick a random tensor
+ as the initial value.
+ delta: (optional) the amount of perturbation.
+ init_targets: list of targets to run to initialize model params.
+ TODO(mrry): remove this argument.
+
+ Returns:
+ Two 2-d numpy arrays representing the theoretical and numerical
+ Jacobian for dy/dx. Each has "x_size" rows and "y_size" columns
+ where "x_size" is the number of elements in x and "y_size" is the
+    number of elements in y. If x is a list, a list of such
+    (theoretical, numeric) pairs is returned, one pair per element of x.
+ """
+ if isinstance(x, list):
+ return _ComputeGradientList(x, x_shape, y, y_shape, x_init_value,
+ delta, init_targets)
+ else:
+ if init_targets is not None:
+ assert isinstance(init_targets, (list, tuple))
+ for init in init_targets:
+ init.run()
+ dx, dy = _ComputeDxAndDy(x, y, y_shape)
+ ret = _ComputeGradient(x, x_shape, dx, y, y_shape, dy, x_init_value, delta)
+ return ret
+
+
+def ComputeGradientError(
+ x, x_shape, y, y_shape, x_init_value=None, delta=1e-3, init_targets=None):
+ """Computes the gradient error.
+
+ Computes the maximum error for dy/dx between the computed Jacobian and the
+ numerically estimated Jacobian.
+
+  This function modifies the graph of the tensors passed in, since it adds
+  more operations and hence changes the consumers of the input tensors'
+  operations.
+
+  This function adds operations to the default graph and evaluates them in the
+  default session. To compute the error using a particular device, such as a
+  GPU, use the standard methods for setting a device (e.g. using
+  with sess.graph.device() or setting a device function in the session
+  constructor).
+
+ Args:
+ x: a tensor or list of tensors
+ x_shape: the dimensions of x as a tuple or an array of ints. If x is a list,
+ then this is the list of shapes.
+ y: a tensor
+ y_shape: the dimensions of y as a tuple or an array of ints.
+ x_init_value: (optional) a numpy array of the same shape as "x"
+ representing the initial value of x. If x is a list, this should be a list
+    of numpy arrays. If this is None, the function will pick a random tensor
+ as the initial value.
+ delta: (optional) the amount of perturbation.
+ init_targets: list of targets to run to initialize model params.
+ TODO(mrry): Remove this argument.
+
+ Returns:
+    The maximum error between the two Jacobians.
+ """
+ grad = ComputeGradient(x, x_shape, y, y_shape, x_init_value,
+ delta, init_targets)
+ if isinstance(grad, tuple):
+ grad = [grad]
+ return max(np.fabs(j_t - j_n).max() for j_t, j_n in grad)
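+
+
+# Typical usage, sketched from gradient_checker_test.py below (it assumes a
+# default session is active, e.g. inside a test's `with self.test_session():`):
+#
+#   x1 = tf.constant(2.0, shape=(2, 3), name="x1")
+#   x2 = tf.constant(3.0, shape=(2, 3), name="x2")
+#   y = tf.add(x1, x2, name="y")
+#   error = ComputeGradientError(x1, (2, 3), y, (2, 3))
+#   assert error < 1e-4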
diff --git a/tensorflow/python/kernel_tests/gradient_checker_test.py b/tensorflow/python/kernel_tests/gradient_checker_test.py
new file mode 100644
index 0000000000..a844b7c637
--- /dev/null
+++ b/tensorflow/python/kernel_tests/gradient_checker_test.py
@@ -0,0 +1,178 @@
+"""Tests for tensorflow.kernels.gradient_checker."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests.gradient_checker import ComputeGradientError
+
+
+class GradientCheckerTest(tf.test.TestCase):
+
+ def testAddSimple(self):
+ with self.test_session(use_gpu=False):
+ # a test case for Add operation
+ size = (2, 3)
+ x1 = tf.constant(2.0, shape=size, name="x1")
+ x2 = tf.constant(3.0, shape=size, name="x2")
+ y = tf.add(x1, x2, name="y")
+
+ # checking gradients for x1
+ error = ComputeGradientError(x1, size, y, size)
+ tf.logging.info("x1 error = %f", error)
+ assert error < 1e-4
+
+ def testAddSimpleGPU(self):
+ with self.test_session(use_gpu=True):
+ # a test case for Add operation
+ size = (2, 3)
+ x1 = tf.constant(2.0, shape=size, name="x1")
+ x2 = tf.constant(3.0, shape=size, name="x2")
+ y = tf.add(x1, x2, name="y")
+
+ # checking gradients for x1
+ error = ComputeGradientError(x1, size, y, size)
+ tf.logging.info("x1 error = %f", error)
+ assert error < 1e-4
+
+ def testAddCustomized(self):
+ with self.test_session():
+ # a test case for Add operation
+ size = (2, 3)
+ x1 = tf.constant(2.0, shape=size, dtype=tf.float64,
+ name="x1")
+ x2 = tf.constant(3.0, shape=size, dtype=tf.float64,
+ name="x2")
+ y = tf.add(x1, x2, name="y")
+
+      # checking gradients for x2 using a special init_value and delta
+ x_init_value = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3))
+ error = ComputeGradientError(x2, size, y, size, x_init_value=x_init_value,
+ delta=1e-2)
+ tf.logging.info("x2 error = %f", error)
+ assert error < 1e-10
+
+ def testGather(self):
+ with self.test_session():
+ p_shape = (4, 2)
+ p_size = 8
+ index_values = [1, 3]
+ y_shape = [2, 2]
+ params = tf.constant(np.arange(p_size).astype(np.float),
+ shape=p_shape, name="p")
+ indices = tf.constant(index_values, name="i")
+ y = tf.gather(params, indices, name="y")
+
+ error = ComputeGradientError(params, p_shape, y, y_shape)
+ tf.logging.info("gather error = %f", error)
+ assert error < 1e-4
+
+ def testNestedGather(self):
+ with self.test_session():
+ p_shape = (8, 2)
+ p_size = 16
+ index_values = [1, 3, 5, 6]
+ index_values2 = [0, 2]
+ y2_shape = [2, 2]
+
+ params = tf.constant(np.arange(p_size).astype(np.float),
+ shape=p_shape, name="p")
+ indices = tf.constant(index_values, name="i")
+ y = tf.gather(params, indices, name="y")
+ indices2 = tf.constant(index_values2, name="i2")
+ y2 = tf.gather(y, indices2, name="y2")
+
+ error = ComputeGradientError(params, p_shape, y2, y2_shape)
+ tf.logging.info("nested gather error = %f", error)
+ assert error < 1e-4
+
+
+# Gradient checker for MNIST.
+def BuildAndTestMiniMNIST(param_index, tag):
+ # Hyperparameters
+ batch = 3
+ inputs = 16
+ features = 32
+ classes = 10
+
+ # Define the parameters
+ inp_data = np.random.random_sample(inputs * batch)
+ hidden_weight_data = np.random.randn(inputs * features) / np.sqrt(inputs)
+ hidden_bias_data = np.random.random_sample(features)
+ sm_weight_data = np.random.randn(features * classes) / np.sqrt(features)
+ sm_bias_data = np.random.random_sample(classes)
+
+  # Special care for labels, since they need to be normalized per batch.
+ label_data = np.random.random(batch * classes).reshape((batch, classes))
+ s = label_data.sum(axis=1)
+ label_data /= s[:, None]
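+  # For example, a raw row that sums to 4.0, such as [1.0, 3.0, ...], becomes
+  # [0.25, 0.75, ...] after the division, so every row of label_data sums to 1.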
+
+ with tf.Session():
+ # We treat the inputs as "parameters" here
+ inp = tf.constant(inp_data.tolist(), shape=[batch, inputs],
+ dtype=tf.float64, name="inp")
+ hidden_weight = tf.constant(hidden_weight_data.tolist(),
+ shape=[inputs, features],
+ dtype=tf.float64,
+ name="hidden_weight")
+ hidden_bias = tf.constant(hidden_bias_data.tolist(),
+ shape=[features],
+ dtype=tf.float64,
+ name="hidden_bias")
+ softmax_weight = tf.constant(sm_weight_data.tolist(),
+ shape=[features, classes],
+ dtype=tf.float64,
+ name="softmax_weight")
+ softmax_bias = tf.constant(sm_bias_data.tolist(), shape=[classes],
+ dtype=tf.float64,
+ name="softmax_bias")
+
+    # List all the parameters so that we can test them one at a time.
+ all_params = [inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias]
+ param_sizes = [[batch, inputs], # inp
+ [inputs, features], # hidden_weight,
+ [features], # hidden_bias
+ [features, classes], # softmax_weight,
+ [classes]] # softmax_bias
+
+    # Now, build the mini MNIST network.
+ features = tf.nn.relu(tf.nn.xw_plus_b(inp, hidden_weight, hidden_bias),
+ name="features")
+ logits = tf.nn.xw_plus_b(features, softmax_weight, softmax_bias,
+ name="logits")
+ labels = tf.constant(label_data.tolist(),
+ shape=[batch, classes],
+ dtype=tf.float64,
+ name="labels")
+ cost = tf.nn.softmax_cross_entropy_with_logits(logits, labels, name="cost")
+
+ # Test the gradients.
+ err = ComputeGradientError(all_params[param_index],
+ param_sizes[param_index],
+ cost, [batch], delta=1e-5)
+
+ tf.logging.info("Mini MNIST: %s gradient error = %g", tag, err)
+ return err
+
+
+class MiniMNISTTest(tf.test.TestCase):
+
+ def testInputGradient(self):
+ self.assertLess(BuildAndTestMiniMNIST(0, "input"), 1e-8)
+
+ def testHiddenWeightGradient(self):
+ self.assertLess(BuildAndTestMiniMNIST(1, "hidden_weight"), 1e-8)
+
+ def testHiddenBiasGradient(self):
+ self.assertLess(BuildAndTestMiniMNIST(2, "hidden_bias"), 1e-8)
+
+ def testSoftmaxWeightGradient(self):
+ self.assertLess(BuildAndTestMiniMNIST(3, "softmax_weight"), 1e-8)
+
+ def testSoftmaxBiasGradient(self):
+ self.assertLess(BuildAndTestMiniMNIST(4, "softmax_bias"), 1e-8)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
new file mode 100644
index 0000000000..2209cf08ad
--- /dev/null
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -0,0 +1,47 @@
+"""Tests for IdentityOp."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_array_ops
+
+
+class IdentityOpTest(tf.test.TestCase):
+
+ def testInt32_6(self):
+ with self.test_session():
+ value = tf.identity([1, 2, 3, 4, 5, 6]).eval()
+ self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value)
+
+ def testInt32_2_3(self):
+ with self.test_session():
+ inp = tf.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
+ value = tf.identity(inp).eval()
+ self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value)
+
+ def testString(self):
+ with self.test_session():
+ value = tf.identity(["A", "b", "C", "d", "E", "f"]).eval()
+ self.assertAllEqual(["A", "b", "C", "d", "E", "f"], value)
+
+ def testIdentityShape(self):
+ with self.test_session():
+ shape = [2, 3]
+ array_2x3 = [[1, 2, 3], [6, 5, 4]]
+ tensor = tf.constant(array_2x3)
+ self.assertEquals(shape, tensor.get_shape())
+ self.assertEquals(shape, tf.identity(tensor).get_shape())
+ self.assertEquals(shape, tf.identity(array_2x3).get_shape())
+ self.assertEquals(shape, tf.identity(np.array(array_2x3)).get_shape())
+
+ def testRefIdentityShape(self):
+ with self.test_session():
+ shape = [2, 3]
+ tensor = tf.Variable(tf.constant([[1, 2, 3], [6, 5, 4]], dtype=tf.int32))
+ self.assertEquals(shape, tensor.get_shape())
+ self.assertEquals(shape, gen_array_ops._ref_identity(tensor).get_shape())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
new file mode 100644
index 0000000000..d2a51788c4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -0,0 +1,36 @@
+"""Tests for PrecisionOp."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class InTopKTest(tf.test.TestCase):
+
+ def _validateInTopK(self, predictions, target, k, expected):
+ np_ans = np.array(expected)
+ with self.test_session():
+ precision = tf.nn.in_top_k(predictions, target, k)
+ out = precision.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, precision)
+
+ def testInTop1(self):
+ predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ target = [3, 1]
+ self._validateInTopK(predictions, target, 1, [True, False])
+
+ def testInTop2(self):
+ predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ target = [0, 2]
+ self._validateInTopK(predictions, target, 2, [False, True])
+
+ def testInTop2Tie(self):
+ # Class 2 and 3 tie for 2nd, so both are considered in top 2.
+ predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]]
+ target = [2, 3]
+ self._validateInTopK(predictions, target, 2, [True, True])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
new file mode 100644
index 0000000000..4ce6081b7b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -0,0 +1,252 @@
+"""Tests for tensorflow.ops.ops."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import init_ops
+
+
+# Returns True iff the two initializers produce the same tensor to
+# within a tiny tolerance.
+def identicaltest(tc, init1, init2, use_gpu):
+ """Tests if two initializations are identical to within tiny tolerances.
+
+ Args:
+ tc: An instance of TensorFlowTestCase.
+ init1: An Initializer that generates a tensor of a given shape
+ init2: An Initializer that generates a tensor of a given shape
+ use_gpu: Use gpu if true.
+ Returns:
+ True or False as determined by test.
+ """
+ num = 100
+ with tc.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ t1 = init1([num]).eval()
+ with tc.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ t2 = init2([num]).eval()
+ return np.allclose(t1, t2, rtol=1e-15, atol=1e-15)
+
+
+def duplicated_initializer(tc, init, use_gpu, graph_seed):
+ """Tests duplicated random initializer within the same graph.
+
+  This test generates two random kernels from the same initializer in the same
+  graph and checks whether their results are close. Even given the same
+  graph-level seed, two different instances of random kernels should generate
+  different results.
+
+ Args:
+ tc: An instance of TensorFlowTestCase.
+ init: An Initializer that generates a tensor of a given shape
+ use_gpu: Use gpu if true.
+ graph_seed: A graph-level seed to use.
+ Returns:
+ True or False as determined by test.
+ """
+ num = 100
+ with tc.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ random_seed.set_random_seed(graph_seed)
+ t1 = init([num]).eval()
+ t2 = init([num]).eval()
+ return np.allclose(t1, t2, rtol=1e-15, atol=1e-15)
+
+
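+# For example (a sketch of what the tests below check): two op-seeded
+# initializers such as tf.random_normal_initializer(0.0, 1.0, seed=1), each
+# evaluated in a fresh graph, produce identical tensors, whereas evaluating an
+# unseeded initializer twice within one graph produces different tensors even
+# under the same graph-level seed.
+
+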
+def _init_sampler(tc, init, num, use_gpu):
+ """Returns a func to generate a random tensor of shape [num].
+
+ Args:
+ tc: An instance of TensorFlowTestCase.
+ init: An Initializer that generates a tensor of a given shape
+ num: Size of 1D tensor to create.
+ use_gpu: Use gpu if true.
+ Returns:
+ Function to generate a random tensor.
+ """
+ def func():
+ with tc.test_session(use_gpu=use_gpu):
+ return init([num]).eval()
+ return func
+
+
+class RandomNormalInitializationTest(tf.test.TestCase):
+
+ def testInitializerIdentical(self):
+ for use_gpu in [False, True]:
+ init1 = tf.random_normal_initializer(0.0, 1.0, seed=1)
+ init2 = tf.random_normal_initializer(0.0, 1.0, seed=1)
+ self.assertTrue(identicaltest(self, init1, init2, use_gpu))
+
+ def testInitializerDifferent(self):
+ for use_gpu in [False, True]:
+ init1 = tf.random_normal_initializer(0.0, 1.0, seed=1)
+ init2 = tf.random_normal_initializer(0.0, 1.0, seed=2)
+ self.assertFalse(identicaltest(self, init1, init2, use_gpu=use_gpu))
+
+ def testDuplicatedInitializer(self):
+ for use_gpu in [False, True]:
+ init = tf.random_normal_initializer(0.0, 1.0)
+ self.assertFalse(duplicated_initializer(self, init, use_gpu, 1))
+
+
+class TruncatedNormalInitializationTest(tf.test.TestCase):
+
+ def testInitializerIdentical(self):
+ for use_gpu in [False, True]:
+ init1 = tf.truncated_normal_initializer(0.0, 1.0, seed=1)
+ init2 = tf.truncated_normal_initializer(0.0, 1.0, seed=1)
+ self.assertTrue(identicaltest(self, init1, init2, use_gpu))
+
+ def testInitializerDifferent(self):
+ for use_gpu in [False, True]:
+ init1 = tf.truncated_normal_initializer(0.0, 1.0, seed=1)
+ init2 = tf.truncated_normal_initializer(0.0, 1.0, seed=2)
+ self.assertFalse(identicaltest(self, init1, init2, use_gpu=use_gpu))
+
+ def testDuplicatedInitializer(self):
+ for use_gpu in [False, True]:
+ init = tf.truncated_normal_initializer(0.0, 1.0)
+ self.assertFalse(duplicated_initializer(self, init, use_gpu, 1))
+
+
+class RandomUniformInitializationTest(tf.test.TestCase):
+
+ def testInitializerIdentical(self):
+ for use_gpu in [False, True]:
+ init1 = tf.random_uniform_initializer(0.0, 1.0, seed=1)
+ init2 = tf.random_uniform_initializer(0.0, 1.0, seed=1)
+ self.assertTrue(identicaltest(self, init1, init2, use_gpu))
+
+ def testInitializerDifferent(self):
+ for use_gpu in [False, True]:
+ init1 = tf.random_uniform_initializer(0.0, 1.0, seed=1)
+ init2 = tf.random_uniform_initializer(0.0, 1.0, seed=2)
+ self.assertFalse(identicaltest(self, init1, init2, use_gpu))
+
+ def testDuplicatedInitializer(self):
+ for use_gpu in [False, True]:
+ init = tf.random_uniform_initializer(0.0, 1.0)
+ self.assertFalse(duplicated_initializer(self, init, use_gpu, 1))
+
+
+class UniformUnitScalingInitializationTest(tf.test.TestCase):
+
+ def testInitializerIdentical(self):
+ for use_gpu in [False, True]:
+ init1 = tf.uniform_unit_scaling_initializer(seed=1)
+ init2 = tf.uniform_unit_scaling_initializer(seed=1)
+ self.assertTrue(identicaltest(self, init1, init2, use_gpu))
+ init3 = tf.uniform_unit_scaling_initializer(1.5, seed=1)
+ init4 = tf.uniform_unit_scaling_initializer(1.5, seed=1)
+ self.assertTrue(identicaltest(self, init3, init4, use_gpu))
+
+ def testInitializerDifferent(self):
+ for use_gpu in [False, True]:
+ init1 = tf.uniform_unit_scaling_initializer(seed=1)
+ init2 = tf.uniform_unit_scaling_initializer(seed=2)
+ init3 = tf.uniform_unit_scaling_initializer(1.5, seed=1)
+ self.assertFalse(identicaltest(self, init1, init2, use_gpu))
+ self.assertFalse(identicaltest(self, init1, init3, use_gpu))
+ self.assertFalse(identicaltest(self, init2, init3, use_gpu))
+
+ def testDuplicatedInitializer(self):
+ for use_gpu in [False, True]:
+ init = tf.uniform_unit_scaling_initializer()
+ self.assertFalse(duplicated_initializer(self, init, use_gpu, 1))
+
+
+class RandomWalkShapeTest(tf.test.TestCase):
+
+ def testRandomWalk(self):
+ # Fully known shape.
+ rnd1 = init_ops._random_walk([1, 2], tf.nn.relu)
+ self.assertEqual([1, 2], rnd1.get_shape())
+
+
+# TODO(vrv): move to sequence_ops_test?
+class RangeTest(tf.test.TestCase):
+
+ def _Range(self, start, limit, delta):
+ with self.test_session():
+ tf_ans = tf.range(start, limit, delta, name="range")
+ self.assertEqual([len(range(start, limit, delta))], tf_ans.get_shape())
+ return tf_ans.eval()
+
+ def testBasic(self):
+ self.assertTrue(np.array_equal(
+ self._Range(0, 5, 1), np.array([0, 1, 2, 3, 4])))
+ self.assertTrue(np.array_equal(
+ self._Range(0, 5, 2), np.array([0, 2, 4])))
+ self.assertTrue(np.array_equal(
+ self._Range(0, 6, 2), np.array([0, 2, 4])))
+ self.assertTrue(np.array_equal(
+ self._Range(13, 32, 7), np.array([13, 20, 27])))
+ self.assertTrue(np.array_equal(
+ self._Range(100, 500, 100), np.array([100, 200, 300, 400])))
+ self.assertEqual(tf.range(0, 5, 1).dtype, tf.int32)
+
+ def testEmpty(self):
+ for start in 0, 5:
+ self.assertTrue(np.array_equal(self._Range(start, start, 1), []))
+
+
+# TODO(vrv): move to sequence_ops_test?
+class LinSpaceTest(tf.test.TestCase):
+
+ def _LinSpace(self, start, stop, num):
+ with self.test_session():
+ tf_ans = tf.linspace(start, stop, num, name="linspace")
+ self.assertEqual([num], tf_ans.get_shape())
+ return tf_ans.eval()
+
+ def testPositive(self):
+ self.assertArrayNear(self._LinSpace(1., 5., 1), np.array([1.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(1., 5., 2), np.array([1., 5.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(1., 5., 3),
+ np.array([1., 3., 5.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(1., 5., 4),
+ np.array([1., 7. / 3., 11. / 3., 5.]), 1e-5)
+
+ def testNegative(self):
+ self.assertArrayNear(self._LinSpace(-1., -5., 1), np.array([-1.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(-1., -5., 2),
+ np.array([-1., -5.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(-1., -5., 3),
+ np.array([-1., -3., -5.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(-1., -5., 4),
+ np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5)
+
+ def testNegativeToPositive(self):
+ self.assertArrayNear(self._LinSpace(-1., 5., 1), np.array([-1.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(-1., 5., 2), np.array([-1., 5.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(-1., 5., 3),
+ np.array([-1., 2., 5.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(-1., 5., 4),
+ np.array([-1., 1., 3., 5.]), 1e-5)
+
+ def testPoint(self):
+ self.assertArrayNear(self._LinSpace(5., 5., 1), np.array([5.]), 1e-5)
+ self.assertArrayNear(self._LinSpace(5., 5., 2), np.array([5.] * 2), 1e-5)
+ self.assertArrayNear(self._LinSpace(5., 5., 3), np.array([5.] * 3), 1e-5)
+ self.assertArrayNear(self._LinSpace(5., 5., 4), np.array([5.] * 4), 1e-5)
+
+
+class DeviceTest(tf.test.TestCase):
+
+ def testNoDevice(self):
+ with tf.Graph().as_default():
+ var = tf.Variable([[1.0, 1.0]])
+ self.assertEqual(None, var.device)
+ self.assertEqual(None, var.initializer.device)
+
+ def testDevice(self):
+ with tf.Graph().as_default():
+ with tf.device("/job:ps"):
+ var = tf.Variable([[1.0, 1.0]])
+ self.assertEqual("/job:ps", var.device)
+ self.assertEqual("/job:ps", var.initializer.device)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
new file mode 100644
index 0000000000..2eb8bdd26f
--- /dev/null
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -0,0 +1,53 @@
+"""Tests for tensorflow.python.ops.io_ops."""
+# -*- coding: utf-8 -*-
+
+import tempfile
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class IoOpsTest(tf.test.TestCase):
+
+ def testReadFile(self):
+ cases = ['', 'Some contents', 'Неки садржаји на српском']
+ for contents in cases:
+ temp = tempfile.NamedTemporaryFile(prefix='ReadFileTest')
+ open(temp.name, 'wb').write(contents)
+ with self.test_session():
+ read = tf.read_file(temp.name)
+ self.assertEqual([], read.get_shape())
+ self.assertEqual(read.eval(), contents)
+
+ def _subset(self, files, indices):
+ return set([files[i].name for i in range(len(files)) if i in indices])
+
+ def testMatchingFiles(self):
+ cases = ['ABcDEF.GH', 'ABzDEF.GH', 'ABasdfjklDEF.GH', 'AB3DEF.GH',
+ 'AB4DEF.GH', 'ABDEF.GH', 'XYZ']
+ files = [tempfile.NamedTemporaryFile(prefix=c) for c in cases]
+
+ with self.test_session():
+ # Test exact match without wildcards.
+ for f in files:
+ self.assertEqual(tf.matching_files(f.name).eval(), f.name)
+
+ # We will look for files matching "ABxDEF.GH*" where "x" is some wildcard.
+ pos = files[0].name.find(cases[0])
+ pattern = files[0].name[:pos] + 'AB%sDEF.GH*'
+
+ self.assertEqual(set(tf.matching_files(pattern % 'z').eval()),
+ self._subset(files, [1]))
+ self.assertEqual(set(tf.matching_files(pattern % '?').eval()),
+ self._subset(files, [0, 1, 3, 4]))
+ self.assertEqual(set(tf.matching_files(pattern % '*').eval()),
+ self._subset(files, [0, 1, 2, 3, 4, 5]))
+ self.assertEqual(set(tf.matching_files(pattern % '[cxz]').eval()),
+ self._subset(files, [0, 1]))
+ self.assertEqual(set(tf.matching_files(pattern % '[0-9]').eval()),
+ self._subset(files, [3, 4]))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
new file mode 100644
index 0000000000..50e5328c3e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -0,0 +1,49 @@
+"""Tests for tensorflow.ops.linalg_grad."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class MatrixInverseGradientTest(tf.test.TestCase):
+ pass # Filled in below
+
+def _GetMatrixInverseGradientTest(dtype, shape):
+ def Test(self):
+ with self.test_session():
+ np.random.seed(1)
+ m = np.random.uniform(low=1.0, high=100.0, size=np.prod(shape)).reshape(
+ shape).astype(dtype)
+ a = tf.constant(m)
+ epsilon = np.finfo(dtype).eps
+ # Optimal stepsize for central difference is O(epsilon^{1/3}).
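+      # (Truncation error of a central difference is O(delta^2) while the
+      # floating-point cancellation error is O(epsilon / delta); balancing
+      # the two gives delta ~ epsilon^{1/3}.)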
+ delta = epsilon ** (1.0 / 3.0)
+ tol = 1e-3
+
+ if len(shape) == 2:
+ ainv = tf.matrix_inverse(a)
+ else:
+ ainv = tf.batch_matrix_inverse(a)
+
+ theoretical, numerical = gc.ComputeGradient(a, shape, ainv, shape,
+ delta=delta)
+ self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
+ return Test
+
+
+if __name__ == "__main__":
+ # TODO(rmlarsen,irving): Reenable float32 once tolerances are fixed
+ # The test used to loop over (np.float, np.double), both of which are float64.
+  for dtype in (np.float64,):
+ for size in 2, 3, 5, 10:
+ # We skip the rank 4, size 10 case: it is slow and conceptually covered
+ # by the other cases.
+ for extra in [(), (2,), (3,)] + [(3, 2)] * (size < 10):
+ shape = extra + (size, size)
+ name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
+ setattr(MatrixInverseGradientTest, 'testMatrixInverseGradient_' + name,
+ _GetMatrixInverseGradientTest(dtype, shape))
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
new file mode 100644
index 0000000000..b4607be1fb
--- /dev/null
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -0,0 +1,117 @@
+"""Tests for tensorflow.kernels.listdiff_op."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class ListDiffTest(tf.test.TestCase):
+
+ def _testListDiff(self, x, y, out, idx, dtype=np.int32):
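+    # tf.listdiff returns the elements of x that do not appear in y, in
+    # order, together with their indices in x; `out` and `idx` hold the
+    # expected values and indices.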
+ x = np.array(x, dtype=dtype)
+ y = np.array(y, dtype=dtype)
+ out = np.array(out, dtype=dtype)
+ idx = np.array(idx, dtype=dtype)
+
+ with self.test_session() as sess:
+ x_tensor = tf.convert_to_tensor(x)
+ y_tensor = tf.convert_to_tensor(y)
+ out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
+
+ def testBasic1(self):
+ x = [1, 2, 3, 4]
+ y = [1, 2]
+ out = [3, 4]
+ idx = [2, 3]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testBasic2(self):
+ x = [1, 2, 3, 4]
+ y = [2]
+ out = [1, 3, 4]
+ idx = [0, 2, 3]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testBasic3(self):
+ x = [1, 4, 3, 2]
+ y = [4, 2]
+ out = [1, 3]
+ idx = [0, 2]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testDuplicates(self):
+ x = [1, 2, 4, 3, 2, 3, 3, 1]
+ y = [4, 2]
+ out = [1, 3, 3, 3, 1]
+ idx = [0, 3, 5, 6, 7]
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testRandom(self):
+ num_random_tests = 10
+ int_low = -7
+ int_high = 8
+ max_size = 50
+ for _ in xrange(num_random_tests):
+ x_size = np.random.randint(max_size + 1)
+ x = np.random.randint(int_low, int_high, size=x_size)
+ y_size = np.random.randint(max_size + 1)
+ y = np.random.randint(int_low, int_high, size=y_size)
+ out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y]
+ if out_idx:
+ out_idx = map(list, zip(*out_idx))
+ out = out_idx[0]
+ idx = out_idx[1]
+ else:
+ out = []
+ idx = []
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testListDiff(x, y, out, idx, dtype=t)
+
+ def testInt32FullyOverlapping(self):
+ x = [1, 2, 3, 4]
+ y = [1, 2, 3, 4]
+ out = []
+ idx = []
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32NonOverlapping(self):
+ x = [1, 2, 3, 4]
+ y = [5, 6]
+ out = x
+ idx = range(len(x))
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32EmptyX(self):
+ x = []
+ y = [1, 2]
+ out = []
+ idx = []
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32EmptyY(self):
+ x = [1, 2, 3, 4]
+ y = []
+ out = x
+ idx = range(len(x))
+ self._testListDiff(x, y, out, idx)
+
+ def testInt32EmptyXY(self):
+ x = []
+ y = []
+ out = []
+ idx = []
+ self._testListDiff(x, y, out, idx)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
new file mode 100644
index 0000000000..18ca441b23
--- /dev/null
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -0,0 +1,50 @@
+"""Tests for tensorflow.kernels.logging_ops."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class LoggingOpsTest(tf.test.TestCase):
+
+ def testAssertDivideByZero(self):
+ with self.test_session() as sess:
+ epsilon = tf.convert_to_tensor(1e-20)
+ x = tf.convert_to_tensor(0.0)
+ y = tf.convert_to_tensor(1.0)
+ z = tf.convert_to_tensor(2.0)
+ # assert(epsilon < y)
+ # z / y
+ with sess.graph.control_dependencies(
+ [tf.Assert(tf.less(epsilon, y), ["Divide-by-zero"])]):
+ out = tf.div(z, y)
+ self.assertAllEqual(2.0, out.eval())
+ # assert(epsilon < x)
+ # z / x
+ #
+ # This tests printing out multiple tensors
+ with sess.graph.control_dependencies(
+ [tf.Assert(tf.less(epsilon, x),
+ ["Divide-by-zero", "less than x"])]):
+ out = tf.div(z, x)
+ with self.assertRaisesOpError("less than x"):
+ out.eval()
+
+
+class PrintGradientTest(tf.test.TestCase):
+
+ def testPrintGradient(self):
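+    # tf.Print acts as an identity op with a logging side effect, so the
+    # gradient of wx_print should equal the gradient of wx.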
+ with self.test_session():
+ inp = tf.constant(2.0, shape=[100, 32], name="in")
+ w = tf.constant(4.0, shape=[10, 100], name="w")
+ wx = tf.matmul(w, inp, name="wx")
+ wx_print = tf.Print(wx, [w, w, w])
+ wx_grad = tf.gradients(wx, w)[0]
+ wx_print_grad = tf.gradients(wx_print, w)[0]
+ wxg = wx_grad.eval()
+ wxpg = wx_print_grad.eval()
+ self.assertAllEqual(wxg, wxpg)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/lookup_table_op_test.py b/tensorflow/python/kernel_tests/lookup_table_op_test.py
new file mode 100644
index 0000000000..cd170876e6
--- /dev/null
+++ b/tensorflow/python/kernel_tests/lookup_table_op_test.py
@@ -0,0 +1,195 @@
+"""Tests for lookup table ops from tf."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class HashTableOpTest(tf.test.TestCase):
+
+ def testHashTable(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ # Initialize with keys and values tensors.
+ keys = tf.constant(['brain', 'salad', 'surgery'])
+ values = tf.constant([0, 1, 2], tf.int64)
+ init = table.initialize_from(keys, values)
+ init.run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant(['brain', 'salad', 'tank'])
+ output = table.lookup(input_string)
+
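+      # Keys missing from the table ('tank') look up to default_val (-1).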
+ result = output.eval()
+ self.assertAllEqual([0, 1, -1], result)
+
+ def testHashTableInitWithPythonArrays(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+ # Empty table.
+ self.assertAllEqual(0, table.size().eval())
+
+ # Initialize with keys and values tensors.
+ keys = ['brain', 'salad', 'surgery']
+ values = [0, 1, 2]
+ init = table.initialize_from(keys, values)
+ init.run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant(['brain', 'salad', 'tank'])
+ output = table.lookup(input_string)
+
+ result = output.eval()
+ self.assertAllEqual([0, 1, -1], result)
+
+ def testHashTableInitWithNumPyArrays(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ # Initialize with keys and values tensors.
+ keys = np.array(['brain', 'salad', 'surgery'], dtype=np.str)
+ values = np.array([0, 1, 2], dtype=np.int64)
+ init = table.initialize_from(keys, values)
+ init.run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant(['brain', 'salad', 'tank'])
+ output = table.lookup(input_string)
+
+ result = output.eval()
+ self.assertAllEqual([0, 1, -1], result)
+
+ def testMultipleHashTables(self):
+ with self.test_session() as sess:
+ shared_name = ''
+ default_val = -1
+ table1 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+ table2 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+ table3 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ keys = tf.constant(['brain', 'salad', 'surgery'])
+ values = tf.constant([0, 1, 2], tf.int64)
+ table1.initialize_from(keys, values)
+ table2.initialize_from(keys, values)
+ table3.initialize_from(keys, values)
+
+ tf.initialize_all_tables().run()
+ self.assertAllEqual(3, table1.size().eval())
+ self.assertAllEqual(3, table2.size().eval())
+ self.assertAllEqual(3, table3.size().eval())
+
+ input_string = tf.constant(['brain', 'salad', 'tank'])
+ output1 = table1.lookup(input_string)
+ output2 = table2.lookup(input_string)
+ output3 = table3.lookup(input_string)
+
+ out1, out2, out3 = sess.run([output1, output2, output3])
+ self.assertAllEqual([0, 1, -1], out1)
+ self.assertAllEqual([0, 1, -1], out2)
+ self.assertAllEqual([0, 1, -1], out3)
+
+ def testHashTableWithTensorDefault(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = tf.constant(-1, tf.int64)
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ # Initialize with keys and values tensors.
+ keys = tf.constant(['brain', 'salad', 'surgery'])
+ values = tf.constant([0, 1, 2], tf.int64)
+ init = table.initialize_from(keys, values)
+ init.run()
+
+ input_string = tf.constant(['brain', 'salad', 'tank'])
+ output = table.lookup(input_string)
+
+ result = output.eval()
+ self.assertAllEqual([0, 1, -1], result)
+
+ def testSignatureMismatch(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ # Initialize with keys and values tensors.
+ keys = tf.constant(['brain', 'salad', 'surgery'])
+ values = tf.constant([0, 1, 2], tf.int64)
+ init = table.initialize_from(keys, values)
+ init.run()
+
+ input_string = tf.constant([1, 2, 3], tf.int64)
+ with self.assertRaises(TypeError):
+ table.lookup(input_string)
+
+ with self.assertRaises(TypeError):
+ tf.HashTable(tf.string, tf.int64, 'UNK', shared_name)
+
+ def testDTypes(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ with self.assertRaises(TypeError):
+ tf.HashTable([tf.string], tf.string, default_val, shared_name)
+
+ def testNotInitialized(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ input_string = tf.constant(['brain', 'salad', 'surgery'])
+ output = table.lookup(input_string)
+
+ with self.assertRaisesOpError('Table not initialized'):
+ output.eval()
+
+ def testInitializeTwice(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ # Initialize with keys and values tensors.
+ keys = tf.constant(['brain', 'salad', 'surgery'])
+ values = tf.constant([0, 1, 2], tf.int64)
+ init = table.initialize_from(keys, values)
+ init.run()
+
+ with self.assertRaisesOpError('Table already initialized'):
+ init.run()
+
+ def testInitializationWithInvalidDimensions(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ # Initialize with keys and values tensors.
+ keys = tf.constant(['brain', 'salad', 'surgery'])
+ values = tf.constant([0, 1, 2, 3, 4], tf.int64)
+ with self.assertRaises(ValueError):
+ table.initialize_from(keys, values)
+
+ def testInitializationWithInvalidDataTypes(self):
+ with self.test_session():
+ shared_name = ''
+ default_val = -1
+ table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
+
+ # Initialize with keys and values tensors.
+ keys = [0, 1, 2]
+ values = ['brain', 'salad', 'surgery']
+ with self.assertRaises(TypeError):
+ table.initialize_from(keys, values)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/lrn_op_test.py b/tensorflow/python/kernel_tests/lrn_op_test.py
new file mode 100644
index 0000000000..7a3bb67938
--- /dev/null
+++ b/tensorflow/python/kernel_tests/lrn_op_test.py
@@ -0,0 +1,101 @@
+"""Tests for local response normalization."""
+import copy
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests.gradient_checker import ComputeGradientError
+
+
+class LRNOpTest(tf.test.TestCase):
+
+ def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0,
+ alpha=1.0, beta=0.5):
+ """Compute expected result."""
+ output = copy.deepcopy(input_image)
+ batch_size = input_image.shape[0]
+ rows = input_image.shape[1]
+ cols = input_image.shape[2]
+ depth = input_image.shape[3]
+ for b in range(batch_size):
+ for r in range(rows):
+ for c in range(cols):
+ for d in range(depth):
+ begin = max(0, d - lrn_depth_radius)
+ end = min(depth, d + lrn_depth_radius + 1)
+ patch = input_image[b, r, c, begin:end]
+ output[b, r, c, d] /= (
+ np.power(bias + alpha * np.sum(patch * patch), beta))
+ return output
+
+ def _RunAndVerify(self):
+ with self.test_session():
+ # random shape
+ shape = np.random.randint(1, 16, size=4)
+ # Make depth at least 2 to make it meaningful
+ shape[3] += 1
+ p = tf.placeholder(tf.float32, shape=shape)
+ # random depth_radius, bias, alpha, beta
+ lrn_depth_radius = np.random.randint(1, shape[3])
+ bias = 1.0 + np.random.rand()
+ alpha = 2.0 * np.random.rand()
+ beta = 2.0 * np.random.rand()
+ lrn_t = tf.nn.local_response_normalization(
+ p, name="lrn", depth_radius=lrn_depth_radius, bias=bias,
+ alpha=alpha, beta=beta)
+ params = {p: np.random.rand(*shape).astype("f")}
+ result = lrn_t.eval(feed_dict=params)
+ expected = self._LRN(
+ params[p], lrn_depth_radius=lrn_depth_radius, bias=bias, alpha=alpha,
+ beta=beta)
+ self.assertTrue(np.amax(np.abs(result - expected)) < 1e-4)
+ self.assertShapeEqual(expected, lrn_t)
+
+ def testCompute(self):
+ for _ in range(2):
+ self._RunAndVerify()
+
+ def testGradientsZeroInput(self):
+ with self.test_session():
+ shape = [4, 4, 4, 4]
+ p = tf.placeholder(tf.float32, shape=shape)
+ inp_array = np.zeros(shape).astype("f")
+ lrn_op = tf.nn.local_response_normalization(p, 2, 1.0, 0.0,
+ 1.0, name="lrn")
+ grad = tf.gradients([lrn_op], [p])[0]
+ params = {p: inp_array}
+ r = grad.eval(feed_dict=params)
+ expected = np.ones(shape).astype("f")
+ self.assertAllClose(r, expected)
+ self.assertShapeEqual(expected, grad)
+
+ def _RunAndVerifyGradients(self):
+ with self.test_session():
+ # random shape
+ shape = np.random.randint(1, 5, size=4)
+ # Make depth at least 2 to make it meaningful
+ shape[3] += 1
+ # random depth_radius, bias, alpha, beta
+ lrn_depth_radius = np.random.randint(1, shape[3])
+ bias = 1.0 + np.random.rand()
+ alpha = 1.0 * np.random.rand()
+ beta = 1.0 * np.random.rand()
+ inp_array = np.random.rand(*shape).astype("f")
+ inp = tf.constant(list(inp_array.ravel(order="C")), shape=shape)
+ lrn_op = tf.nn.local_response_normalization(
+ inp, name="lrn", depth_radius=lrn_depth_radius, bias=bias,
+ alpha=alpha, beta=beta)
+ err = ComputeGradientError(inp, shape, lrn_op, shape)
+ print "LRN Gradient error ", err
+ self.assertLess(err, 1e-4)
+
+ def testGradients(self):
+ for _ in range(2):
+ self._RunAndVerifyGradients()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
new file mode 100644
index 0000000000..5aeb736b9b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -0,0 +1,206 @@
+"""Tests for tensorflow.ops.math_ops.matmul."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class MatMulTest(tf.test.TestCase):
+
+ def _testCpuMatmul(self, x, y, transpose_x=False, transpose_y=False):
+ x_mat = np.matrix(x).T if transpose_x else np.matrix(x)
+ y_mat = np.matrix(y).T if transpose_y else np.matrix(y)
+ np_ans = x_mat * y_mat
+ with self.test_session(use_gpu=False):
+ tf_ans = tf.matmul(x, y, transpose_x, transpose_y).eval()
+ self.assertAllClose(np_ans, tf_ans)
+ self.assertAllEqual(np_ans.shape, tf_ans.shape)
+
+ def _testGpuMatmul(self, x, y, transpose_x=False, transpose_y=False):
+ x_mat = np.matrix(x).T if transpose_x else np.matrix(x)
+ y_mat = np.matrix(y).T if transpose_y else np.matrix(y)
+ np_ans = x_mat * y_mat
+ with self.test_session(use_gpu=True):
+ tf_ans = tf.matmul(x, y, transpose_x, transpose_y).eval()
+ self.assertAllClose(np_ans, tf_ans)
+ self.assertAllEqual(np_ans.shape, tf_ans.shape)
+
+ def _randMatrix(self, rows, cols, dtype):
+ if dtype is np.complex64:
+ real = self._randMatrix(rows, cols, np.float32)
+ imag = self._randMatrix(rows, cols, np.float32)
+ return real + np.complex(0, 1) * imag
+ else:
+ return np.random.uniform(low=1.0, high=100.0, size=rows * cols).reshape(
+ [rows, cols]).astype(dtype)
+
+ # Basic test:
+ # [ [1],
+ # [2],
+ # [3], * [1, 2]
+ # [4] ]
+ def testFloatBasic(self):
+ x = np.arange(1., 5.).reshape([4, 1]).astype(np.float32)
+ y = np.arange(1., 3.).reshape([1, 2]).astype(np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
+
+ def testDoubleBasic(self):
+ x = np.arange(1., 5.).reshape([4, 1]).astype(np.float64)
+ y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)
+ self._testCpuMatmul(x, y)
+
+ def testInt32Basic(self):
+ x = np.arange(1., 5.).reshape([4, 1]).astype(np.int32)
+ y = np.arange(1., 3.).reshape([1, 2]).astype(np.int32)
+ self._testCpuMatmul(x, y)
+
+ def testSComplexBasic(self):
+ x = np.arange(1., 5.).reshape([4, 1]).astype(np.complex64)
+ y = np.arange(1., 3.).reshape([1, 2]).astype(np.complex64)
+ self._testCpuMatmul(x, y)
+
+ # Tests testing random sized matrices.
+ def testFloatRandom(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = self._randMatrix(n, k, np.float32)
+ y = self._randMatrix(k, m, np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
+
+ def testDoubleRandom(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = self._randMatrix(n, k, np.float64)
+ y = self._randMatrix(k, m, np.float64)
+ self._testCpuMatmul(x, y)
+
+ def testInt32Random(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = self._randMatrix(n, k, np.int32)
+ y = self._randMatrix(k, m, np.int32)
+ self._testCpuMatmul(x, y)
+
+ def testSComplexRandom(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = self._randMatrix(n, k, np.complex64)
+ y = self._randMatrix(k, m, np.complex64)
+ self._testCpuMatmul(x, y)
+
+ # Test the cases that transpose the matrices before multiplying.
+ # NOTE(keveman): The cases where only one of the inputs is
+ # transposed are covered by tf.matmul's gradient function.
+ def testFloatRandomTransposeBoth(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = self._randMatrix(k, n, np.float32)
+ y = self._randMatrix(m, k, np.float32)
+ self._testCpuMatmul(x, y, True, True)
+ self._testGpuMatmul(x, y, True, True)
+
+ def testDoubleRandomTranposeBoth(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = self._randMatrix(k, n, np.float64)
+ y = self._randMatrix(m, k, np.float64)
+ self._testCpuMatmul(x, y, True, True)
+
+ def testMatMul_OutEmpty_A(self):
+ n, k, m = 0, 8, 3
+ x = self._randMatrix(n, k, np.float32)
+ y = self._randMatrix(k, m, np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
+
+ def testMatMul_OutEmpty_B(self):
+ n, k, m = 3, 8, 0
+ x = self._randMatrix(n, k, np.float32)
+ y = self._randMatrix(k, m, np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
+
+ def testMatMul_Inputs_Empty(self):
+ n, k, m = 3, 0, 4
+ x = self._randMatrix(n, k, np.float32)
+ y = self._randMatrix(k, m, np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
+
+
+# TODO(zhifengc): Figures out how to test matmul gradients on GPU.
+class MatMulGradientTest(tf.test.TestCase):
+
+ def testGradientInput0(self):
+ with self.test_session(use_gpu=False):
+ x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2],
+ dtype=tf.float64, name="x")
+ y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
+ shape=[2, 4], dtype=tf.float64, name="y")
+ m = tf.matmul(x, y, name="matmul")
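+      # ComputeGradientError returns the maximum difference between the
+      # numerically and symbolically computed Jacobians.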
+ err = gc.ComputeGradientError(x, [3, 2], m, [3, 4])
+ print "matmul input0 gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+ def testGradientInput1(self):
+ with self.test_session(use_gpu=False):
+ x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2],
+ dtype=tf.float64, name="x")
+ y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
+ shape=[2, 4], dtype=tf.float64, name="y")
+ m = tf.matmul(x, y, name="matmul")
+ err = gc.ComputeGradientError(y, [2, 4], m, [3, 4])
+ print "matmul input1 gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+ def _VerifyInput0(self, transpose_a, transpose_b):
+ shape_x = [3, 2]
+ shape_y = [2, 4]
+ if transpose_a:
+ shape_x = list(reversed(shape_x))
+ if transpose_b:
+ shape_y = list(reversed(shape_y))
+ with self.test_session(use_gpu=False):
+ x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=shape_x,
+ dtype=tf.float64, name="x")
+ y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
+ shape=shape_y, dtype=tf.float64, name="y")
+ m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul")
+ err = gc.ComputeGradientError(x, shape_x, m, [3, 4])
+ print "matmul input0 gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+ def testGradientInput0WithTranspose(self):
+ self._VerifyInput0(transpose_a=True, transpose_b=False)
+ self._VerifyInput0(transpose_a=False, transpose_b=True)
+ self._VerifyInput0(transpose_a=True, transpose_b=True)
+
+ def _VerifyInput1(self, transpose_a, transpose_b):
+ shape_x = [3, 2]
+ shape_y = [2, 4]
+ if transpose_a:
+ shape_x = list(reversed(shape_x))
+ if transpose_b:
+ shape_y = list(reversed(shape_y))
+ with self.test_session(use_gpu=False):
+ x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=shape_x,
+ dtype=tf.float64, name="x")
+ y = tf.constant([1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
+ shape=shape_y, dtype=tf.float64, name="y")
+ m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul")
+ err = gc.ComputeGradientError(y, shape_y, m, [3, 4])
+ print "matmul input1 gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+ def testGradientInput1WithTranspose(self):
+ self._VerifyInput1(transpose_a=True, transpose_b=False)
+ self._VerifyInput1(transpose_a=False, transpose_b=True)
+ self._VerifyInput1(transpose_a=True, transpose_b=True)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
new file mode 100644
index 0000000000..541a937185
--- /dev/null
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -0,0 +1,79 @@
+"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class InverseOpTest(tf.test.TestCase):
+
+ def _verifyInverse(self, x):
+ for np_type in [np.float32, np.float64]:
+ y = x.astype(np_type)
+ with self.test_session():
+ # Verify that x^{-1} * x == Identity matrix.
+ if x.ndim == 2:
+ inv = tf.matrix_inverse(y)
+ tf_ans = tf.matmul(inv, y)
+ np_ans = np.identity(y.shape[-1])
+ else:
+ inv = tf.batch_matrix_inverse(y)
+ tf_ans = tf.batch_matmul(inv, y)
+ tiling = list(y.shape)
+ tiling[-2:] = [1, 1]
+ np_ans = np.tile(np.identity(y.shape[-1]), tiling)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(y, tf_ans)
+
+ def testBasic(self):
+ # 2x2 matrices
+ matrix1 = np.array([[1., 2.], [3., 4.]])
+ matrix2 = np.array([[1., 3.], [3., 5.]])
+ self._verifyInverse(matrix1)
+ self._verifyInverse(matrix2)
+ # A multidimensional batch of 2x2 matrices
+ matrix_batch = np.concatenate([np.expand_dims(matrix1, 0),
+ np.expand_dims(matrix2, 0)])
+ matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
+ self._verifyInverse(matrix_batch)
+
+ def testNonSquareMatrix(self):
+ # When the inverse of a non-square matrix is attempted we should return
+ # an error
+ with self.assertRaises(ValueError):
+ tf.matrix_inverse(np.array([[1., 2., 3.], [3., 4., 5.]]))
+
+ def testWrongDimensions(self):
+ # The input to the inverse should be at least a 2-dimensional tensor.
+ tensor3 = tf.constant([1., 2.])
+ with self.assertRaises(ValueError):
+ tf.matrix_inverse(tensor3)
+
+ def testNotInvertible(self):
+ # The input should be invertible.
+ with self.test_session():
+ with self.assertRaisesOpError("Input is not invertible."):
+ # All rows of the matrix below add to zero
+ tensor3 = tf.constant([[1., 0., -1.], [-1., 1., 0.], [0., -1., 1.]])
+ tf.matrix_inverse(tensor3).eval()
+
+ with self.test_session():
+ with self.assertRaisesOpError("Input is not invertible."):
+ # Determinant of the matrix below is zero
+ tensor3 = tf.constant([[1., 1.], [1., 1.]])
+ tf.matrix_inverse(tensor3).eval()
+
+ with self.test_session():
+ with self.assertRaisesOpError("Input is not invertible."):
+ # Determinant of the matrix below is zero
+ tensor3 = tf.constant([[np.inf, 1.], [1., 1.]])
+ tf.matrix_inverse(tensor3).eval()
+
+ def testEmpty(self):
+ self._verifyInverse(np.empty([0, 2, 2]))
+ self._verifyInverse(np.empty([2, 0, 0]))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py
new file mode 100644
index 0000000000..8cb2fe2f8b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/numerics_test.py
@@ -0,0 +1,91 @@
+"""Tests for tensorflow.ops.numerics."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import control_flow_ops
+
+
+class VerifyTensorAllFiniteTest(tf.test.TestCase):
+
+ def testVerifyTensorAllFiniteSucceeds(self):
+ x_shape = [5, 4]
+ x = np.random.random_sample(x_shape).astype(np.float32)
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ t = tf.constant(x, shape=x_shape, dtype=tf.float32)
+ t_verified = tf.verify_tensor_all_finite(t, "Input is not a number.")
+ self.assertAllClose(x, t_verified.eval())
+
+ def testVerifyTensorAllFiniteFails(self):
+ x_shape = [5, 4]
+ x = np.random.random_sample(x_shape).astype(np.float32)
+ my_msg = "Input is not a number."
+
+ # Test NaN.
+ x[0] = np.nan
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ with self.assertRaisesOpError(my_msg):
+ t = tf.constant(x, shape=x_shape, dtype=tf.float32)
+ t_verified = tf.verify_tensor_all_finite(t, my_msg)
+ t_verified.eval()
+
+ # Test Inf.
+ x[0] = np.inf
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ with self.assertRaisesOpError(my_msg):
+ t = tf.constant(x, shape=x_shape, dtype=tf.float32)
+ t_verified = tf.verify_tensor_all_finite(t, my_msg)
+ t_verified.eval()
+
+
+class NumericsTest(tf.test.TestCase):
+
+ def testInf(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ t1 = tf.constant(1.0)
+ t2 = tf.constant(0.0)
+ a = tf.div(t1, t2)
+ check = tf.add_check_numerics_ops()
+ a = control_flow_ops.with_dependencies([check], a)
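+        # Making `a` depend on the check op forces the numeric checks to run
+        # before `a` is produced, so evaluating `a` surfaces the Inf error.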
+ with self.assertRaisesOpError("Inf"):
+ a.eval()
+
+ def testNaN(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ t1 = tf.constant(0.0)
+ t2 = tf.constant(0.0)
+ a = tf.div(t1, t2)
+ check = tf.add_check_numerics_ops()
+ a = control_flow_ops.with_dependencies([check], a)
+ with self.assertRaisesOpError("NaN"):
+ a.eval()
+
+ def testBoth(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ t1 = tf.constant([1.0, 0.0])
+ t2 = tf.constant([0.0, 0.0])
+ a = tf.div(t1, t2)
+ check = tf.add_check_numerics_ops()
+ a = control_flow_ops.with_dependencies([check], a)
+ with self.assertRaisesOpError("Inf and NaN"):
+ a.eval()
+
+ def testPassThrough(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()):
+ t1 = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
+ checked = tf.check_numerics(t1, message="pass through test")
+ value = checked.eval()
+ self.assertAllEqual(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), value)
+ self.assertEqual([2, 3], checked.get_shape())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py
new file mode 100644
index 0000000000..5f3b1823c0
--- /dev/null
+++ b/tensorflow/python/kernel_tests/pack_op_test.py
@@ -0,0 +1,47 @@
+"""Functional tests for Pack Op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker
+
+
+class PackOpTest(tf.test.TestCase):
+
+ def testSimple(self):
+ np.random.seed(7)
+ for use_gpu in False, True:
+ with self.test_session(use_gpu=use_gpu):
+ for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ # Convert [data[0], data[1], ...] separately to tensorflow
+ xs = map(tf.constant, data)
+ # Pack back into a single tensorflow tensor
+ c = tf.pack(xs)
+ self.assertAllEqual(c.eval(), data)
+
+ def testGradients(self):
+ np.random.seed(7)
+ for use_gpu in False, True:
+ for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ shapes = [shape[1:]] * shape[0]
+ with self.test_session(use_gpu=use_gpu):
+ xs = map(tf.constant, data)
+ c = tf.pack(xs)
+ err = gradient_checker.ComputeGradientError(xs, shapes, c, shape)
+ self.assertLess(err, 1e-6)
+
+ def testZeroSize(self):
+ # Verify that pack doesn't crash for zero size inputs
+ for use_gpu in False, True:
+ with self.test_session(use_gpu=use_gpu):
+        for shape in (0,), (3, 0), (0, 3):
+ x = np.zeros((2,) + shape)
+ p = tf.pack(list(x)).eval()
+ self.assertAllEqual(p, x)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
new file mode 100644
index 0000000000..113aeb1ccf
--- /dev/null
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -0,0 +1,140 @@
+"""Tests for tensorflow.ops.nn_ops.Pad."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class PadOpTest(tf.test.TestCase):
+
+ def _npPad(self, inp, paddings):
+ return np.pad(inp, paddings, mode="constant")
+
+ def testNpPad(self):
+ self.assertAllClose(
+ np.array([[0, 0, 0, 0, 0, 0],
+ [0, 3, 3, 0, 0, 0],
+ [0, 4, 4, 0, 0, 0],
+ [0, 5, 5, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]]),
+ self._npPad(np.array([[3, 3], [4, 4], [5, 5]]), [[1, 2], [1, 3]]))
+
+ def _testPad(self, np_inputs, paddings, use_gpu=False):
+ np_val = self._npPad(np_inputs, paddings)
+ with self.test_session(use_gpu=use_gpu):
+ tf_val = tf.pad(np_inputs, paddings)
+ out = tf_val.eval()
+ self.assertAllClose(np_val, out)
+ self.assertShapeEqual(np_val, tf_val)
+
+ def _testGradient(self, x, a):
+ with self.test_session():
+ inx = tf.convert_to_tensor(x)
+ xs = list(x.shape)
+ ina = tf.convert_to_tensor(a)
+ y = tf.pad(inx, ina)
+      # y's expected shape is x's shape plus the total padding added on each
+      # dimension.
+ ys = list(np.array(x.shape) + np.sum(np.array(a), axis=1))
+ jacob_t, jacob_n = gc.ComputeGradient(inx, xs, y, ys, x_init_value=x)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _testAll(self, np_inputs, paddings):
+ self._testPad(np_inputs, paddings, use_gpu=False)
+ self._testPad(np_inputs, paddings, use_gpu=True)
+ if np_inputs.dtype == np.float32:
+ self._testGradient(np_inputs, paddings)
+
+ def testInputDims(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.pad(
+ tf.reshape([1, 2], shape=[1, 2, 1, 1, 1, 1]),
+ tf.reshape([1, 2], shape=[1, 2]))
+
+ def testPaddingsDim(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.pad(
+ tf.reshape([1, 2], shape=[1, 2]),
+ tf.reshape([1, 2], shape=[2]))
+
+ def testPaddingsDim2(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.pad(
+ tf.reshape([1, 2], shape=[1, 2]),
+ tf.reshape([1, 2], shape=[2, 1]))
+
+ def testPaddingsDim3(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.pad(
+ tf.reshape([1, 2], shape=[1, 2]),
+ tf.reshape([1, 2], shape=[1, 2]))
+
+ def testPaddingsDim4(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.pad(
+ tf.reshape([1, 2], shape=[1, 2]),
+ tf.reshape([1, 2, 3, 4, 5, 6], shape=[3, 2]))
+
+ def testPaddingsNonNegative(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, "must be non-negative"):
+ tf.pad(
+ tf.constant([1], shape=[1]),
+ tf.constant([-1, 0], shape=[1, 2]))
+
+ def testPaddingsNonNegative2(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(ValueError, "must be non-negative"):
+ tf.pad(
+ tf.constant([1], shape=[1]),
+ tf.constant([-1, 0], shape=[1, 2]))
+
+ def testIntTypes(self):
+ # TODO(mdevin): Figure out why the padding tests do not work on GPU
+ # for int types and rank > 2.
+ for t in [np.int32, np.int64]:
+ self._testPad((np.random.rand(4, 3, 3) * 100).astype(t),
+ [[1, 0], [2, 3], [0, 2]])
+
+ def testFloatTypes(self):
+ for t in [np.float32, np.float64]:
+ self._testAll(np.random.rand(2, 5).astype(t),
+ [[1, 0], [2, 0]])
+
+ def testShapeFunctionEdgeCases(self):
+ # Unknown paddings shape.
+ inp = tf.constant(0.0, shape=[4, 4, 4, 4])
+ padded = tf.pad(inp, tf.placeholder(tf.int32))
+ self.assertEqual([None, None, None, None], padded.get_shape().as_list())
+
+ # Unknown input shape.
+ inp = tf.placeholder(tf.float32)
+ padded = tf.pad(inp, [[2, 2], [2, 2]])
+ self.assertEqual([None, None], padded.get_shape().as_list())
+
+ # Unknown input and paddings shape.
+ inp = tf.placeholder(tf.float32)
+ padded = tf.pad(inp, tf.placeholder(tf.int32))
+ self.assertAllEqual(None, padded.get_shape().ndims)
+
+ def testScalars(self):
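+    # A rank-0 input takes an empty 0x2 paddings matrix, so padding leaves
+    # the scalar unchanged.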
+ paddings = np.zeros((0, 2), dtype=np.int32)
+ inp = np.asarray(7)
+ for use_gpu in False, True:
+ with self.test_session(use_gpu=use_gpu):
+ tf_val = tf.pad(inp, paddings)
+ out = tf_val.eval()
+ self.assertAllClose(inp, out)
+ self.assertShapeEqual(inp, tf_val)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
new file mode 100644
index 0000000000..fba7c705fb
--- /dev/null
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -0,0 +1,414 @@
+"""Tests for tensorflow.ops.parsing_ops."""
+
+import itertools
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+# Helpers for creating Example objects
+example = tf.train.Example
+feature = tf.train.Feature
+features = lambda d: tf.train.Features(feature=d)
+bytes_feature = lambda v: feature(bytes_list=tf.train.BytesList(value=v))
+int64_feature = lambda v: feature(int64_list=tf.train.Int64List(value=v))
+float_feature = lambda v: feature(float_list=tf.train.FloatList(value=v))
+
+
+def flatten(list_of_lists):
+ """Flatten one level of nesting."""
+ return itertools.chain.from_iterable(list_of_lists)
+
+
+def flatten_values_tensors_or_sparse(tensors_list):
+ """Flatten each SparseTensor object into 3 Tensors for session.run()."""
+ return list(flatten([[v.indices, v.values, v.shape]
+ if isinstance(v, tf.SparseTensor) else [v]
+ for v in tensors_list]))
+
+
+def _compare_output_to_expected(
+ tester, dict_tensors, expected_tensors, flat_output):
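+  # Walks flat_output (the flattened session.run result) in step with
+  # dict_tensors: a SparseTensor consumes three consecutive outputs
+  # (indices, values, shape), while a dense Tensor consumes one.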
+ tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
+
+ i = 0 # Index into the flattened output of session.run()
+ for k, v in dict_tensors.iteritems():
+ expected_v = expected_tensors[k]
+ tf.logging.info("Comparing key: %s", k)
+ if isinstance(v, tf.SparseTensor):
+ # Three outputs for SparseTensor : indices, values, shape.
+ tester.assertEqual([k, 3], [k, len(expected_v)])
+ tester.assertAllEqual(flat_output[i], expected_v[0])
+ tester.assertAllEqual(flat_output[i + 1], expected_v[1])
+ tester.assertAllEqual(flat_output[i + 2], expected_v[2])
+ i += 3
+ else:
+ # One output for standard Tensor.
+ tester.assertAllEqual(flat_output[i], expected_v)
+ i += 1
+
+
+class ParseExampleTest(tf.test.TestCase):
+
+ def _test(self, kwargs, expected_values=None, expected_err_re=None):
+ with self.test_session() as sess:
+ # Pull out some keys to check shape inference
+ serialized = kwargs["serialized"]
+ dense_keys = kwargs["dense_keys"] if "dense_keys" in kwargs else []
+ sparse_keys = kwargs["sparse_keys"] if "sparse_keys" in kwargs else []
+ dense_shapes = kwargs["dense_shapes"] if "dense_shapes" in kwargs else []
+
+ # Returns dict w/ Tensors and SparseTensors
+ out = tf.parse_example(**kwargs)
+
+ # Check shapes; if serialized is a Tensor we need its size to
+ # properly check.
+ batch_size = (
+ serialized.eval().size if isinstance(serialized, tf.Tensor)
+ else np.asarray(serialized).size)
+ self.assertEqual(len(dense_keys), len(dense_shapes))
+ for (k, s) in zip(dense_keys, dense_shapes):
+ self.assertEqual(tuple(out[k].get_shape().as_list()), (batch_size,) + s)
+ for k in sparse_keys:
+ self.assertEqual(tuple(out[k].indices.get_shape().as_list()), (None, 2))
+ self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,))
+ self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (2,))
+
+ # Check values
+ result = flatten_values_tensors_or_sparse(out.values()) # flatten values
+ if expected_err_re is None:
+ tf_result = sess.run(result)
+ _compare_output_to_expected(self, out, expected_values, tf_result)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ sess.run(result)
+
+ def testEmptySerializedWithAllDefaults(self):
+ dense_keys = ["a", "b", "c"]
+ dense_shapes = [(1, 3), (3, 3), (2,)]
+ dense_types = [tf.int64, tf.string, tf.float32]
+ dense_defaults = {
+ "a": [0, 42, 0],
+ "b": np.random.rand(3, 3).astype(np.str),
+ "c": np.random.rand(2).astype(np.float32),
+ }
+
+ expected_st_a = ( # indices, values, shape
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+
+ expected_output = {
+ "st_a": expected_st_a,
+ "a": np.array(2 * [[dense_defaults["a"]]]),
+ "b": np.array(2 * [dense_defaults["b"]]),
+ "c": np.array(2 * [dense_defaults["c"]]),
+ }
+
+ self._test(
+ {
+ "names": np.empty((0,), dtype=np.str),
+ # empty serialized input Examples
+ "serialized": tf.convert_to_tensor(["", ""]),
+ "dense_defaults": dense_defaults,
+ "sparse_keys": ["st_a"],
+ "sparse_types": [tf.int64],
+ "dense_keys": dense_keys,
+ "dense_types": dense_types,
+ "dense_shapes": dense_shapes
+ }, expected_output)
+
+ def testEmptySerializedWithoutDefaultsShouldFail(self):
+ dense_shapes = [(1, 3), (3, 3), (2,)]
+ dense_defaults = {
+ "a": [0, 42, 0],
+ "b": np.random.rand(3, 3).astype(np.str),
+ # Feature "c" is missing, since there's gaps it will cause failure.
+ }
+ self._test(
+ {
+ "serialized": ["", ""], # empty serialized input Examples
+ "names": ["in1", "in2"],
+ "dense_defaults": dense_defaults,
+ "sparse_keys": ["st_a"],
+ "sparse_types": [tf.int64],
+ "dense_keys": ["a", "b", "c"],
+ "dense_types": [tf.int64, tf.string, tf.float32],
+ "dense_shapes": dense_shapes
+ },
+ expected_err_re="Name: in1, Feature: c is required")
+
+ def testDenseNotMatchingShapeShouldFail(self):
+ dense_shapes = [(1, 3)]
+ dense_defaults = {
+ # no default!
+ }
+
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })),
+ example(features=features({
+ "a": float_feature([-1, -1]),
+ }))
+ ]
+
+ names = ["passing", "failing"]
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "names": names,
+ "dense_defaults": dense_defaults,
+ "dense_keys": ["a"],
+ "dense_types": [tf.float32],
+ "dense_shapes": dense_shapes,
+ },
+ expected_err_re="Name: failing, Key: a. Number of float values")
+
+ def testSerializedContainingSparse(self):
+ original = [
+ example(features=features({
+ "st_c": float_feature([3, 4])
+ })),
+ example(features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature(["hi"])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_st_c = ( # indices, values, shape
+ np.array([[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64),
+ np.array([3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32),
+        np.array([4, 3], dtype=np.int64)) # batch == 4, max_elems = 3
+
+ expected_st_d = ( # indices, values, shape
+ np.array([[3, 0]], dtype=np.int64),
+ np.array(["hi"], dtype=np.str),
+        np.array([4, 1], dtype=np.int64)) # batch == 4, max_elems = 1
+
+ expected_output = {
+ "st_c": expected_st_c,
+ "st_d": expected_st_d,
+ }
+
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "sparse_keys": ["st_c", "st_d"],
+ "sparse_types": [tf.float32, tf.string],
+ }, expected_output)
+
+ def testSerializedContainingDense(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1]),
+ "b": bytes_feature(["b0_str"]),
+ })),
+ example(features=features({
+ "a": float_feature([-1, -1]),
+ "b": bytes_feature(["b1"]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ dense_shapes = [(1, 2, 1), (1, 1, 1, 1)]
+
+ expected_output = {
+ "a": np.array([[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ "b": np.array(["b0_str", "b1"], dtype=np.str).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "dense_keys": ["a", "b"],
+ "dense_types": [tf.float32, tf.string],
+ "dense_shapes": dense_shapes,
+ }, expected_output)
+
+ def testSerializedContainingDenseScalar(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1]),
+ })),
+ example(features=features({}))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a": np.array([[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ }
+
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "dense_defaults": {"a": -1},
+ "dense_shapes": [(1,)],
+ "dense_keys": ["a"],
+ "dense_types": [tf.float32],
+ }, expected_output)
+
+ def testSerializedContainingDenseWithDefaults(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1]),
+ })),
+ example(features=features({
+ "b": bytes_feature(["b1"]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ dense_shapes = [(1, 2, 1), (1, 1, 1, 1)]
+ dense_types = [tf.float32, tf.string]
+ dense_defaults = {
+ "a": [3.0, -3.0],
+ "b": "tmp_str",
+ }
+
+ expected_output = {
+ "a": np.array([[1, 1], [3, -3]], dtype=np.float32).reshape(2, 1, 2, 1),
+ "b": np.array(["tmp_str", "b1"], dtype=np.str).reshape(2, 1, 1, 1, 1),
+ }
+
+ self._test(
+ {
+ "serialized": tf.convert_to_tensor(serialized),
+ "dense_defaults": dense_defaults,
+ "dense_keys": ["a", "b"],
+ "dense_types": dense_types,
+ "dense_shapes": dense_shapes,
+ }, expected_output)
+
+ def testSerializedContainingSparseAndDenseWithNoDefault(self):
+ dense_defaults = {
+ "a": [1, 2, 3],
+ "b": np.random.rand(3, 3).astype(np.str),
+ # Feature "c" must be provided
+ }
+ dense_shapes = [(1, 3), (3, 3), (2,)]
+
+ expected_st_a = ( # indices, values, shape
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+
+ original = [
+ example(features=features({
+ "c": float_feature([3, 4])
+ })),
+ example(features=features({
+ "c": float_feature([1, 2])
+ }))
+ ]
+
+ names = ["in1", "in2"]
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "st_a": expected_st_a,
+ "a": np.array(2 * [[dense_defaults["a"]]]),
+ "b": np.array(2 * [dense_defaults["b"]]),
+ "c": np.array([[3, 4], [1, 2]], dtype=np.float32),
+ }
+
+ self._test(
+ {
+ "names": names,
+ "serialized": tf.convert_to_tensor(serialized),
+ "dense_defaults": dense_defaults,
+ "sparse_keys": ["st_a"],
+ "sparse_types": [tf.int64],
+ "dense_keys": ["a", "b", "c"],
+ "dense_types": [tf.int64, tf.string, tf.float32],
+ "dense_shapes": dense_shapes
+ }, expected_output)
+
+
+class ParseSingleExampleTest(tf.test.TestCase):
+
+ def _test(self, kwargs, expected_values=None, expected_err_re=None):
+ with self.test_session() as sess:
+ # Pull out some keys to check shape inference
+ dense_keys = kwargs["dense_keys"] if "dense_keys" in kwargs else []
+ sparse_keys = kwargs["sparse_keys"] if "sparse_keys" in kwargs else []
+ dense_shapes = kwargs["dense_shapes"] if "dense_shapes" in kwargs else []
+
+ # Returns dict w/ Tensors and SparseTensors
+ out = tf.parse_single_example(**kwargs)
+
+ # Check shapes
+ self.assertEqual(len(dense_keys), len(dense_shapes))
+ for (k, s) in zip(dense_keys, dense_shapes):
+ self.assertEqual(tuple(out[k].get_shape()), s)
+ for k in sparse_keys:
+ self.assertEqual(tuple(out[k].indices.get_shape().as_list()), (None, 1))
+ self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,))
+ self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (1,))
+
+ # Check values
+ result = flatten_values_tensors_or_sparse(out.values()) # flatten values
+ if expected_err_re is None:
+ tf_result = sess.run(result)
+ _compare_output_to_expected(self, out, expected_values, tf_result)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ sess.run(result)
+
+ def testSingleExampleWithSparseAndDense(self):
+ dense_types = [tf.int64, tf.string, tf.float32]
+ dense_shapes = [(1, 3), (3, 3), (2,)]
+ dense_defaults = {
+ "a": [1, 2, 3],
+ "b": np.random.rand(3, 3).astype(np.str),
+ # Feature "c" must be provided
+ }
+
+ original = example(features=features(
+ {"c": float_feature([3, 4]),
+ "st_a": float_feature([3.0, 4.0])}))
+
+ serialized = original.SerializeToString()
+
+ expected_st_a = (
+ np.array([[0], [1]], dtype=np.int64), # indices
+ np.array([3.0, 4.0], dtype=np.float32), # values
+ np.array([2], dtype=np.int64)) # shape: max_values = 2
+
+ expected_output = {
+ "st_a": expected_st_a,
+ "a": [dense_defaults["a"]],
+ "b": dense_defaults["b"],
+ "c": np.array([3, 4], dtype=np.float32),
+ }
+
+ self._test(
+ {
+ "names": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "dense_defaults": dense_defaults,
+ "dense_types": dense_types,
+ "sparse_keys": ["st_a"],
+ "sparse_types": [tf.float32],
+ "dense_keys": ["a", "b", "c"],
+ "dense_shapes": dense_shapes
+ }, expected_output)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
new file mode 100644
index 0000000000..b9a65726ee
--- /dev/null
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -0,0 +1,819 @@
+"""Functional tests for pooling operations."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+from tensorflow.python.ops import gen_nn_ops
+
+
+def GetInceptionMaxPoolShapes():
+ """Iterator for some of the max pool ops in the Inception 2015 model.
+
+ Yields:
+ Tuple (name, input_size, filter_size, out_size, strides, padding)
+ """
+ names = ["maxpool2", "maxpool3", "maxpool4", "maxpool5"]
+ input_sizes = [[32, 71, 71, 192],
+ [32, 35, 35, 288], [32, 17, 17, 1248], [32, 8, 8, 2048]]
+ filter_sizes = [[1, 3, 3, 1], [1, 3, 3, 1],
+ [1, 3, 3, 1], [1, 3, 3, 1]]
+ output_sizes = [[32, 35, 35, 192], [32, 17, 17, 288],
+ [32, 8, 8, 1248], [32, 8, 8, 2048]]
+ strides = [[1, 2, 2, 1], [1, 2, 2, 1], [1, 2, 2, 1],
+ [1, 1, 1, 1]]
+ paddings = ["VALID", "VALID", "VALID", "SAME"]
+ for n, i, f, o, s, p in zip(names, input_sizes, filter_sizes, output_sizes,
+ strides, paddings):
+ yield n, i, f, o, s, p
+
+
+class PoolingTest(tf.test.TestCase):
+
+ def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
+ expected, use_gpu):
+ """Verifies the output values of the pooling function.
+
+ Args:
+      pool_func: Function to be called, e.g. tf.nn.max_pool or tf.nn.avg_pool.
+      input_sizes: Input tensor dimensions.
+      ksize: The kernel size dimensions.
+      strides: The stride dimensions.
+ padding: Padding type.
+ expected: An array containing the expected operation outputs.
+ use_gpu: Whether we are running on GPU.
+ """
+ total_size = 1
+ for s in input_sizes:
+ total_size *= s
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x = [f * 1.0 for f in range(1, total_size + 1)]
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t = tf.constant(x, shape=input_sizes)
+ t = pool_func(t, ksize=ksize, strides=strides, padding=padding)
+ actual = t.eval()
+ self.assertAllClose(expected, actual.flatten())
+ self.assertShapeEqual(actual, t)
+
+ def _testAvgPoolValidPadding(self, use_gpu):
+ expected_output = [7.0, 8.0, 9.0]
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+ padding="VALID",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def _testAvgPoolSamePadding(self, use_gpu):
+ expected_output = [8.5, 9.5, 10.5, 14.5, 15.5, 16.5]
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 2, 4, 3],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def _testAvgPoolSamePaddingNonSquareWindow(self, use_gpu):
+ # input is:
+ # [1.0, 2.0
+ # 3.0 4.0]
+ #
+ # Window of [x, x] should do:
+ # [avg(1.0, 2.0), avg(2.0, padded0),
+ # avg(3.0, 4.0), avg(4.0, padded0)]
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 2, 2, 1],
+ ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1],
+ padding="SAME",
+ expected=[1.5, 2.0, 3.5, 4.0], use_gpu=use_gpu)
+
+ # Window of [x,
+ # x] should do:
+ # [avg(1.0, 3.0), avg(2.0, 4.0)
+ # avg(3.0, padded0), avg(4.0, padded0)]
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 2, 2, 1],
+ ksize=[1, 2, 1, 1], strides=[1, 1, 1, 1],
+ padding="SAME",
+ expected=[2.0, 3.0, 3.0, 4.0], use_gpu=use_gpu)
+
+ def _testAvgPoolSamePaddingNonSquareWindowMultiBatch(self, use_gpu):
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[2, 2, 2, 2],
+ ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1],
+ padding="SAME",
+ expected=[2.0, 3.0, 3.0, 4.0,
+ 6.0, 7.0, 7.0, 8.0,
+ 10.0, 11.0, 11.0, 12.0,
+ 14.0, 15.0, 15.0, 16.0],
+ use_gpu=use_gpu)
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[2, 2, 2, 2],
+ ksize=[1, 2, 1, 1], strides=[1, 1, 1, 1],
+ padding="SAME",
+ expected=[3.0, 4.0, 5.0, 6.0,
+ 5.0, 6.0, 7.0, 8.0,
+ 11.0, 12.0, 13.0, 14.0,
+ 13.0, 14.0, 15.0, 16.0],
+ use_gpu=use_gpu)
+
+ def _testAvgPoolValidPaddingUnevenStride(self, use_gpu):
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1],
+ padding="VALID",
+ expected=[7.0, 8.0, 9.0, 16.0, 17.0, 18.0],
+ use_gpu=use_gpu)
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1],
+ padding="VALID",
+ expected=[7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
+ use_gpu=use_gpu)
+
+ def _testAvgPoolSamePadding4(self, use_gpu):
+ expected_output = [11.0, 12.0, 13.0, 14.0, 19.0, 20.0, 21.0, 22.0, 43.0,
+ 44.0, 45.0, 46.0, 51.0, 52.0, 53.0, 54.0]
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 4, 4, 4],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def _testAvgPoolSamePaddingPacket4(self, use_gpu):
+ expected_output = [21.0, 22.0, 23.0, 24.0, 27.0, 28.0, 29.0, 30.0,
+ 45.0, 46.0, 47.0, 48.0, 51.0, 52.0, 53.0, 54.0]
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 4, 4, 4],
+ ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def _testAvgPoolSamePaddingPacket8(self, use_gpu):
+ expected_output = [73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 89.0,
+ 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 105.0, 106.0,
+ 107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 117.0, 118.0,
+ 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 201.0, 202.0,
+ 203.0, 204.0, 205.0, 206.0, 207.0, 208.0, 217.0, 218.0,
+ 219.0, 220.0, 221.0, 222.0, 223.0, 224.0, 233.0, 234.0,
+ 235.0, 236.0, 237.0, 238.0, 239.0, 240.0, 245.0, 246.0,
+ 247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 329.0, 330.0,
+ 331.0, 332.0, 333.0, 334.0, 335.0, 336.0, 345.0, 346.0,
+ 347.0, 348.0, 349.0, 350.0, 351.0, 352.0, 361.0, 362.0,
+ 363.0, 364.0, 365.0, 366.0, 367.0, 368.0, 373.0, 374.0,
+ 375.0, 376.0, 377.0, 378.0, 379.0, 380.0, 425.0, 426.0,
+ 427.0, 428.0, 429.0, 430.0, 431.0, 432.0, 441.0, 442.0,
+ 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 457.0, 458.0,
+ 459.0, 460.0, 461.0, 462.0, 463.0, 464.0, 469.0, 470.0,
+ 471.0, 472.0, 473.0, 474.0, 475.0, 476.0]
+ self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 8, 8, 8],
+ ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def testAvgPooling(self):
+ for use_gpu in True, False:
+ self._testAvgPoolValidPadding(use_gpu)
+ self._testAvgPoolSamePadding(use_gpu)
+ self._testAvgPoolSamePaddingNonSquareWindow(use_gpu)
+ self._testAvgPoolSamePaddingNonSquareWindowMultiBatch(use_gpu)
+ self._testAvgPoolValidPaddingUnevenStride(use_gpu)
+ self._testAvgPoolSamePadding4(use_gpu)
+ self._testAvgPoolSamePaddingPacket4(use_gpu)
+ self._testAvgPoolSamePaddingPacket8(use_gpu)
+
+ def _testMaxPoolValidPadding(self, use_gpu):
+ expected_output = [13.0, 14.0, 15.0]
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+ padding="VALID",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def _testMaxPoolSamePadding(self, use_gpu):
+ expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 2, 3, 3],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def _testMaxPoolSamePaddingNonSquareWindow(self, use_gpu):
+ # input is:
+ # [1.0, 2.0
+ # 3.0 4.0]
+ #
+ # Window of [x, x] should do:
+ #
+ # [max(1.0, 2.0), max(2.0, padded0),
+ # max(3.0, 4.0), max(4.0, padded0)]
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 2, 2, 1],
+ ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1],
+ padding="SAME",
+ expected=[2.0, 2.0, 4.0, 4.0], use_gpu=use_gpu)
+
+ def _testMaxPoolValidPaddingUnevenStride(self, use_gpu):
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1],
+ padding="VALID",
+ expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0],
+ use_gpu=use_gpu)
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1],
+ padding="VALID",
+ expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0],
+ use_gpu=use_gpu)
+
+ def _testMaxPoolSamePaddingPacket4(self, use_gpu):
+ expected_output = [21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0,
+ 54.0, 55.0, 56.0, 61.0, 62.0, 63.0, 64.0]
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 4, 4, 4],
+ ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def _testMaxPoolSamePaddingPacket8(self, use_gpu):
+ expected_output = [145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0,
+ 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0,
+ 177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0,
+ 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 192.0,
+ 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
+ 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0,
+ 305.0, 306.0, 307.0, 308.0, 309.0, 310.0, 311.0, 312.0,
+ 313.0, 314.0, 315.0, 316.0, 317.0, 318.0, 319.0, 320.0,
+ 401.0, 402.0, 403.0, 404.0, 405.0, 406.0, 407.0, 408.0,
+ 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
+ 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0,
+ 441.0, 442.0, 443.0, 444.0, 445.0, 446.0, 447.0, 448.0,
+ 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0,
+ 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, 488.0,
+ 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
+ 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0]
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 8, 8, 8],
+ ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output, use_gpu=use_gpu)
+
+ def testMaxPooling(self):
+ for use_gpu in True, False:
+ self._testMaxPoolValidPadding(use_gpu)
+ self._testMaxPoolSamePadding(use_gpu)
+ self._testMaxPoolSamePaddingNonSquareWindow(use_gpu)
+ self._testMaxPoolValidPaddingUnevenStride(use_gpu)
+ self._testMaxPoolSamePaddingPacket4(use_gpu)
+ self._testMaxPoolSamePaddingPacket8(use_gpu)
+
+ # Tests for DepthwiseMaxPooling on CPU only.
+ def testDepthwiseMaxPool1x1DepthWindow1(self):
+ # input is:
+ # [1.0, ..., 10.0] along depth,
+ #
+ # We maxpool by depth in patches of 2.
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 1, 1, 10],
+ ksize=[1, 1, 1, 2], strides=[1, 1, 1, 2],
+ padding="SAME",
+ expected=[2.0, 4.0, 6.0, 8.0, 10.0], use_gpu=False)
+
+ def testDepthwiseMaxPool2x2DepthWindow3(self):
+    # The input is a 2x2x6 cube; we max-pool depthwise over windows of 3 to
+    # produce a 2x2x2 output. Each spatial location holds contiguous depth
+    # values, so every depthwise max should be a multiple of 3.0.
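+    # As an illustration (the test input is the incrementing sequence
+    # 1.0, 2.0, ...): the first spatial location holds depths 1.0..6.0, so
+    # pooling over depth windows of 3 yields 3.0 and 6.0; the next location
+    # holds 7.0..12.0, yielding 9.0 and 12.0, and so on.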
+ self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 2, 2, 6],
+ ksize=[1, 1, 1, 3], strides=[1, 1, 1, 3],
+ padding="SAME",
+ expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0],
+ use_gpu=False)
+
+ def _testDepthwiseMaxPoolInvalidConfig(self, in_size, ksize, strides,
+ error_msg, use_gpu=False):
+ t = tf.constant(1.0, shape=in_size)
+ with self.assertRaisesRegexp(ValueError, error_msg):
+ t = tf.nn.max_pool(t, ksize=ksize, strides=strides, padding="SAME")
+
+ def testDepthwiseMaxPoolInvalidConfigs(self):
+ self._testDepthwiseMaxPoolInvalidConfig(
+ [1, 2, 2, 4], [1, 2, 2, 2],
+ [1, 1, 1, 2], "exactly one of pooling across depth")
+ self._testDepthwiseMaxPoolInvalidConfig(
+ [1, 2, 2, 4], [1, 1, 1, 2],
+ [1, 1, 1, 1], "depth window to equal the depth stride")
+ self._testDepthwiseMaxPoolInvalidConfig(
+ [1, 2, 2, 4], [1, 1, 1, 3],
+ [1, 1, 1, 3], "evenly divide")
+ if tf.test.IsBuiltWithCuda():
+ with self.test_session(use_gpu=True):
+ t = tf.constant(1.0, shape=[1, 2, 2, 4])
+ with self.assertRaisesOpError("for CPU devices"):
+ tf.nn.max_pool(t, ksize=[1, 1, 1, 2], strides=[1, 1, 1, 2],
+ padding="SAME").eval()
+
+  # The following are tests that verify that the CPU and GPU implementations
+  # produce the same results.
+ def _CompareMaxPoolingFwd(self, input_shape, ksize, strides, padding):
+ tensor_input = np.random.rand(*input_shape).astype(np.float32)
+ with self.test_session(use_gpu=True):
+ t = tf.constant(tensor_input, shape=input_shape)
+ out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
+ gpu_val = out_op.eval()
+ with self.test_session(use_gpu=False):
+ t = tf.constant(tensor_input, shape=input_shape)
+ out_op = tf.nn.max_pool(t, ksize, strides, padding)
+ cpu_val = out_op.eval()
+ self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
+
+ def _CompareMaxPoolingBk(self, input_shape, output_shape, ksize, strides,
+ padding):
+ # Generate numbers in a narrow range, so that there are many duplicates
+ # in the input.
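+    # (Presumably this stresses tie-breaking: with many duplicated maxima,
+    # the CPU and GPU kernels must route the gradient to the same positions
+    # for the results to match.)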
+ tensor_input = np.random.random_integers(0, 3,
+ input_shape).astype(np.float32)
+ tensor_output = np.random.rand(*output_shape).astype(np.float32)
+ with self.test_session(use_gpu=True):
+ t = tf.constant(tensor_input, shape=input_shape)
+ _, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
+ argmax = argmax_op.eval()
+ grad_in = tf.constant(tensor_output, shape=output_shape)
+ out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax,
+ ksize, strides, padding)
+ gpu_val = out_op.eval()
+ self.assertShapeEqual(gpu_val, out_op)
+ with self.test_session(use_gpu=False):
+ t = tf.constant(tensor_input, shape=input_shape)
+ out_op = tf.nn.max_pool(t, ksize, strides, padding)
+ orig_out = out_op.eval()
+ grad_in = tf.constant(tensor_output, shape=output_shape)
+ out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize,
+ strides, padding)
+ cpu_val = out_op.eval()
+ self.assertShapeEqual(cpu_val, out_op)
+ self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
+
+ def testMaxPoolingWithArgmax(self):
+ # MaxPoolWithArgMax is implemented only on GPU.
+ if not tf.test.IsBuiltWithCuda():
+ return
+ tensor_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
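+    # The input is a 3x3 plane of ones with a single 0.0 at the center (flat
+    # index 4). Every 2x2 stride-1 window therefore has max 1.0, and the
+    # argmax values are flattened indices into the input; 0, 1, 3 and 5 are
+    # (presumably the lowest) positions holding the maximum in each of the
+    # four windows.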
+ with self.test_session(use_gpu=True) as sess:
+ t = tf.constant(tensor_input, shape=[1, 3, 3, 1])
+ out_op, argmax_op = tf.nn.max_pool_with_argmax(t,
+ ksize=[1, 2, 2, 1],
+ strides=[1, 1, 1, 1],
+ Targmax=tf.int64,
+ padding="VALID")
+ out, argmax = sess.run([out_op, argmax_op])
+ self.assertShapeEqual(out, out_op)
+ self.assertShapeEqual(argmax, argmax_op)
+ self.assertAllClose(out.ravel(), [1.0, 1.0, 1.0, 1.0])
+ self.assertAllEqual(argmax.ravel(), [0, 1, 3, 5])
+
+ def testMaxPoolingGradWithArgmax(self):
+ # MaxPoolWithArgMax is implemented only on GPU.
+ if not tf.test.IsBuiltWithCuda():
+ return
+ orig_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
+ tensor_input = [11.0, 12.0, 13.0, 14.0]
+ tensor_argmax = list(np.array([0, 1, 3, 5], dtype=np.int64))
+ with self.test_session(use_gpu=True) as sess:
+ orig_in = tf.constant(orig_input, shape=[1, 3, 3, 1])
+ t = tf.constant(tensor_input, shape=[1, 2, 2, 1])
+ argmax = tf.constant(tensor_argmax, shape=[1, 2, 2, 1],
+ dtype=tf.int64)
+ out_op = gen_nn_ops._max_pool_grad_with_argmax(orig_in, t, argmax,
+ ksize=[1, 2, 2, 1],
+ strides=[1, 1, 1, 1],
+ padding="VALID")
+ out = out_op.eval().flatten()
+ self.assertAllClose(out, [11.0, 12.0, 0.0, 13.0, 0.0,
+ 14.0, 0.0, 0.0, 0.0])
+
+ def _ConstructAndTestGradient(self, pool_func, input_sizes, output_sizes,
+ window_rows, window_cols, row_stride,
+ col_stride, padding, use_gpu,
+ x_init_value=None):
+ """Verifies the gradients of the avg pooling function.
+
+ Args:
+ pool_func: Function to be called, co.MaxPool, co.AvgPool,
+ or the Lua version.
+ input_sizes: Input tensor dimensions.
+ output_sizes: Output tensor dimensions.
+ window_rows: kernel size in row dim
+ window_cols: kernel size in col dim
+ row_stride: Row Stride.
+ col_stride: Col Stride.
+ padding: Padding type.
+ use_gpu: whether we are running on GPU
+ x_init_value: Values to be passed to the gradient checker.
+ """
+ total_size = 1
+ for s in input_sizes:
+ total_size *= s
+    # Initialize the input tensor with an array of incrementing numbers
+    # starting from 1.
+ x = [f * 1.0 for f in range(1, total_size + 1)]
+ with self.test_session(use_gpu=use_gpu):
+ input_tensor = tf.constant(x, shape=input_sizes, name="input")
+ if pool_func == tf.nn.avg_pool:
+ func_name = "avg_pool"
+ err_margin = 1e-4
+ else:
+ if x_init_value is None:
+ x_init_value = np.asfarray(
+ np.arange(1, total_size + 1),
+ dtype=np.float32).reshape(input_sizes)
+ func_name = "max_pool"
+ err_margin = 1e-3
+ t = pool_func(input_tensor, ksize=[1, window_rows, window_rows, 1],
+ strides=[1, row_stride, col_stride, 1],
+ padding=padding, name=func_name)
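+      # gc.ComputeGradientError presumably compares the op's symbolic
+      # gradient against a numeric finite-difference estimate (step size
+      # delta) and returns the maximum elementwise error.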
+ err = gc.ComputeGradientError(
+ input_tensor, input_sizes, t, output_sizes,
+ x_init_value=x_init_value, delta=1e-2)
+ print "%s gradient error = " % func_name, err
+ self.assertLess(err, err_margin)
+
+ def _testMaxPoolGradValidPadding1_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[1, 3, 3, 1],
+ output_sizes=[1, 3, 3, 1], window_rows=1, window_cols=1, row_stride=1,
+ col_stride=1, padding="VALID", use_gpu=use_gpu)
+
+ def _testMaxPoolGradValidPadding2_1_6(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[2, 6, 6, 3],
+ output_sizes=[2, 5, 5, 3], window_rows=2, window_cols=2, row_stride=1,
+ col_stride=1, padding="VALID", use_gpu=use_gpu)
+
+ def _testMaxPoolGradValidPadding2_1_7(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[2, 7, 7, 3],
+ output_sizes=[2, 6, 6, 3], window_rows=2, window_cols=2, row_stride=1,
+ col_stride=1, padding="VALID", use_gpu=use_gpu)
+
+ def _testMaxPoolGradValidPadding2_2(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[2, 2, 2, 3],
+ output_sizes=[2, 1, 1, 3], window_rows=2, window_cols=2, row_stride=2,
+ col_stride=2, padding="VALID", use_gpu=use_gpu)
+
+ def _testMaxPoolGradSamePadding1_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3], window_rows=1, window_cols=1, row_stride=1,
+ col_stride=1, padding="SAME", use_gpu=use_gpu)
+
+ def _testMaxPoolGradSamePadding2_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3], window_rows=2, window_cols=2, row_stride=1,
+ col_stride=1, padding="SAME", use_gpu=use_gpu)
+
+ def _testMaxPoolGradSamePadding2_2(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 1, 2, 3], window_rows=2, window_cols=2, row_stride=2,
+ col_stride=2, padding="SAME", use_gpu=use_gpu)
+
+ def _testMaxPoolGradSamePadding3_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.max_pool, input_sizes=[1, 7, 7, 1],
+ output_sizes=[1, 7, 7, 1], window_rows=3, window_cols=3, row_stride=1,
+ col_stride=1, padding="SAME", use_gpu=use_gpu)
+
+ def testMaxPoolGrad(self):
+ for use_gpu in True, False:
+ self._testMaxPoolGradValidPadding1_1(use_gpu=use_gpu)
+ self._testMaxPoolGradValidPadding2_1_6(use_gpu=use_gpu)
+ self._testMaxPoolGradValidPadding2_1_7(use_gpu=use_gpu)
+ self._testMaxPoolGradValidPadding2_2(use_gpu=use_gpu)
+ self._testMaxPoolGradSamePadding1_1(use_gpu=use_gpu)
+ self._testMaxPoolGradSamePadding2_1(use_gpu=use_gpu)
+ self._testMaxPoolGradSamePadding2_2(use_gpu=use_gpu)
+ self._testMaxPoolGradSamePadding3_1(use_gpu=use_gpu)
+
+ def _MaxPoolGrad(self, orig_input, orig_output, grad, window_rows,
+ window_cols, row_stride, col_stride, padding):
+ """Max Pooling Gradient.
+
+ Args:
+ orig_input: A float Tensor. The original input tensor.
+ orig_output: A float Tensor. The original output tensor.
+ grad: A float Tensor.
+ The 4D (batch x rows x cols x depth) output backprop.
+ window_rows: integer. Kernel size along rows dimension.
+ window_cols: integer. Kernel size along cols dimension.
+      row_stride: integer. Stride along rows dimension.
+      col_stride: integer. Stride along cols dimension.
+ padding: PoolingOpDef.Padding. Padding type.
+
+ Returns:
+ A Tensor.
+ """
+ return gen_nn_ops._max_pool_grad(
+ orig_input, orig_output, grad,
+ [1, window_rows, window_cols, 1], [1, row_stride, col_stride, 1],
+ padding)
+
+ def _testMaxPoolGradDirect(self, input_data, output_backprop,
+ expected_input_backprop, input_sizes, output_sizes,
+ window_rows, window_cols, row_stride, col_stride,
+ padding, use_gpu):
+ with self.test_session(use_gpu=use_gpu) as sess:
+ input_tensor = tf.constant(input_data, shape=input_sizes)
+ output_tensor = tf.nn.max_pool(
+ input_tensor, [1, window_rows, window_cols, 1],
+ [1, row_stride, col_stride, 1], padding)
+ output_backprop_tensor = tf.constant(output_backprop,
+ shape=output_sizes)
+
+ input_backprop_tensor = self._MaxPoolGrad(
+ input_tensor, output_tensor, output_backprop_tensor,
+ window_rows, window_cols, row_stride, col_stride, padding)
+
+ actual_input_backprop = input_backprop_tensor.eval()
+ self.assertShapeEqual(actual_input_backprop, input_backprop_tensor)
+ actual_input_backprop = actual_input_backprop.flatten()
+ actual_input_backprop = self._GetNdArray(actual_input_backprop)
+
+ actual_output = output_tensor.eval().flatten()
+ actual_output = self._GetNdArray(actual_output)
+
+ self.assertAllClose(expected_input_backprop, actual_input_backprop,
+ rtol=1e-6, atol=1e-6)
+
+ def _testMaxPoolGradDirect1_1(self):
+ input_data = [
+ 1.0, 1.0, 1.0, 1.0,
+ 1.0, 1.0, 1.0, 1.0,
+ 1.0, 1.0, 1.0, 1.0,
+ 1.0, 1.0, 1.0, 1.0]
+ output_backprop = [
+ 11.0, 12.0, 13.0,
+ 15.0, 16.0, 17.0,
+ 19.0, 20.0, 21.0]
+ expected_input_backprop = [
+ 11.0, 12.0, 13.0, 0.0,
+ 15.0, 16.0, 17.0, 0.0,
+ 19.0, 20.0, 21.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0]
+
+ for use_gpu in True, False:
+ self._testMaxPoolGradDirect(
+ input_data, output_backprop, expected_input_backprop,
+ input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
+ window_rows=2, window_cols=2, row_stride=1, col_stride=1,
+ padding="VALID", use_gpu=use_gpu)
+
+ def _testMaxPoolGradDirect1_2(self):
+ input_data = [
+ 1.0, 0.0, 1.0, 0.0,
+ 0.0, 1.0, 0.0, 1.0,
+ 1.0, 0.0, 1.0, 0.0,
+ 0.0, 1.0, 0.0, 1.0]
+ output_backprop = [
+ 11.0, 12.0, 13.0,
+ 15.0, 16.0, 17.0,
+ 19.0, 20.0, 21.0]
+ expected_input_backprop = [
+ 11.0, 0.0, 25.0, 0.0,
+ 0.0, 31.0, 0.0, 17.0,
+ 19.0, 0.0, 41.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0]
+
+ for use_gpu in True, False:
+ self._testMaxPoolGradDirect(
+ input_data, output_backprop, expected_input_backprop,
+ input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
+ window_rows=2, window_cols=2, row_stride=1, col_stride=1,
+ padding="VALID", use_gpu=use_gpu)
+
+ def _testMaxPoolGradDirect1_3(self):
+ input_data = [
+ 1.0, 0.0, 1.0, 0.0,
+ 0.0, 1.0, 0.0, 1.0,
+ 1.0, 0.0, 1.0, 0.0,
+ 0.0, 1.0, 0.0, 1.0,]
+ output_backprop = [
+ 11.0, 12.0, 13.0, 14.0,
+ 15.0, 16.0, 17.0, 18.0,
+ 19.0, 20.0, 21.0, 22.0,
+ 23.0, 24.0, 25.0, 26.0]
+ expected_input_backprop = [
+ 54, 0.0, 62, 0.0,
+ 0.0, 60, 0.0, 22.0,
+ 47, 0.0, 51, 0.0,
+ 0.0, 0.0, 0.0, 0.0,]
+
+ for use_gpu in True, False:
+ self._testMaxPoolGradDirect(
+ input_data, output_backprop, expected_input_backprop,
+ input_sizes=[1, 4, 4, 1], output_sizes=[1, 4, 4, 1],
+ window_rows=3, window_cols=3, row_stride=1, col_stride=1,
+ padding="SAME", use_gpu=use_gpu)
+
+ def _testMaxPoolGradDirectWithNans2_1(self):
+ input_data = [float("nan")] * 16
+ output_backprop = [
+ 11.0, 12.0, 13.0,
+ 15.0, 16.0, 17.0,
+ 19.0, 20.0, 21.0]
+    # Test the CPU implementation, which propagates diffs in the presence of NaNs.
+ expected_input_backprop_tf_cpu = [
+ 11.0, 12.0, 13.0, 0.0,
+ 15.0, 16.0, 17.0, 0.0,
+ 19.0, 20.0, 21.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0]
+ self._testMaxPoolGradDirect(
+ input_data, output_backprop, expected_input_backprop_tf_cpu,
+ input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
+ window_rows=2, window_cols=2, row_stride=1, col_stride=1,
+ padding="VALID", use_gpu=False)
+
+ if not tf.test.IsBuiltWithCuda():
+ return
+
+    # Test the GPU implementation, which uses cudnn for now.
+    # It does not propagate the diff in the presence of NaNs.
+ expected_input_backprop_cudnn = [
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0]
+ self._testMaxPoolGradDirect(
+ input_data, output_backprop, expected_input_backprop_cudnn,
+ input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
+ window_rows=2, window_cols=2, row_stride=1, col_stride=1,
+ padding="VALID", use_gpu=True)
+
+ def _testMaxPoolGradDirectWithNans2_2(self):
+ input_data = [float("nan")] * 16
+ output_backprop = [
+ float("nan"), 12.0, 13.0,
+ 15.0, float("nan"), 17.0,
+ 19.0, 20.0, float("nan")]
+    # Test the CPU implementation, which propagates diffs in the presence of NaNs.
+ expected_input_backprop_tf_cpu = [
+ float("nan"), 12.0, 13.0, 0.0,
+ 15.0, float("nan"), 17.0, 0.0,
+ 19.0, 20.0, float("nan"), 0.0,
+ 0.0, 0.0, 0.0, 0.0]
+ self._testMaxPoolGradDirect(
+ input_data, output_backprop, expected_input_backprop_tf_cpu,
+ input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
+ window_rows=2, window_cols=2, row_stride=1, col_stride=1,
+ padding="VALID", use_gpu=False)
+
+ if not tf.test.IsBuiltWithCuda():
+ return
+
+    # Test the GPU implementation, which uses cudnn for now.
+    # It does not propagate the diff in the presence of NaNs.
+ expected_input_backprop_cudnn = [
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0]
+ self._testMaxPoolGradDirect(
+ input_data, output_backprop, expected_input_backprop_cudnn,
+ input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
+ window_rows=2, window_cols=2, row_stride=1, col_stride=1,
+ padding="VALID", use_gpu=True)
+
+ def testMaxPoolGradDirect(self):
+ self._testMaxPoolGradDirect1_1()
+ self._testMaxPoolGradDirect1_2()
+ self._testMaxPoolGradDirect1_3()
+ self._testMaxPoolGradDirectWithNans2_1()
+ self._testMaxPoolGradDirectWithNans2_2()
+
+ def testAvgPoolGrad(self):
+ for use_gpu in False, True:
+ self._testAvgPoolGradValidPadding1_1(use_gpu)
+ self._testAvgPoolGradValidPadding2_1(use_gpu)
+ self._testAvgPoolGradValidPadding2_2(use_gpu)
+ self._testAvgPoolGradSamePadding1_1(use_gpu)
+ self._testAvgPoolGradSamePadding2_1(use_gpu)
+ self._testAvgPoolGradSamePadding2_2(use_gpu)
+ self._testAvgPoolGradSamePadding3_1(use_gpu)
+
+ def _testAvgPoolGradValidPadding1_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.avg_pool, input_sizes=[2, 3, 3, 3],
+ output_sizes=[2, 3, 3, 3], window_rows=1, window_cols=1, row_stride=1,
+ col_stride=1, padding="VALID", use_gpu=use_gpu)
+
+ def _testAvgPoolGradValidPadding2_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.avg_pool, input_sizes=[2, 3, 3, 3],
+ output_sizes=[2, 2, 2, 3], window_rows=2, window_cols=2, row_stride=1,
+ col_stride=1, padding="VALID", use_gpu=use_gpu)
+
+ def _testAvgPoolGradValidPadding2_2(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.avg_pool, input_sizes=[2, 2, 2, 3],
+ output_sizes=[2, 1, 1, 3], window_rows=2, window_cols=2, row_stride=2,
+ col_stride=2, padding="VALID", use_gpu=use_gpu)
+
+ def _testAvgPoolGradSamePadding1_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.avg_pool, input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3], window_rows=1, window_cols=1, row_stride=1,
+ col_stride=1, padding="SAME", use_gpu=use_gpu)
+
+ def _testAvgPoolGradSamePadding2_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.avg_pool, input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 2, 4, 3], window_rows=2, window_cols=2, row_stride=1,
+ col_stride=1, padding="SAME", use_gpu=use_gpu)
+
+ def _testAvgPoolGradSamePadding2_2(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.avg_pool, input_sizes=[2, 2, 4, 3],
+ output_sizes=[2, 1, 2, 3], window_rows=2, window_cols=2, row_stride=2,
+ col_stride=2, padding="SAME", use_gpu=use_gpu)
+
+ def _testAvgPoolGradSamePadding3_1(self, use_gpu):
+ self._ConstructAndTestGradient(
+ tf.nn.avg_pool, input_sizes=[1, 7, 7, 1],
+ output_sizes=[1, 7, 7, 1], window_rows=3, window_cols=3, row_stride=1,
+ col_stride=1, padding="SAME", use_gpu=use_gpu)
+
+ def testShapeFunctionEdgeCases(self):
+ # All shapes unknown.
+ for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
+      p = pool_func(tf.placeholder(tf.float32),
+ ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+ padding="SAME")
+ self.assertEqual([None, None, None, None], p.get_shape().as_list())
+ p, am = tf.nn.max_pool_with_argmax(
+ tf.placeholder(tf.float32),
+ ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+ padding="SAME")
+ self.assertEqual([None, None, None, None], p.get_shape().as_list())
+ self.assertEqual([None, None, None, None], am.get_shape().as_list())
+
+ # Incorrect input shape.
+ for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
+ tf.nn.max_pool_with_argmax]:
+ with self.assertRaises(ValueError):
+ pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
+ ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
+
+ # Illegal strides.
+ for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
+ tf.nn.max_pool_with_argmax]:
+ with self.assertRaisesRegexp(ValueError, "strides in the batch"):
+ pool_func(tf.placeholder(tf.float32),
+ ksize=[1, 1, 1, 1], strides=[2, 1, 1, 1], padding="SAME")
+ with self.assertRaisesRegexp(ValueError, "strides in the batch and depth"):
+ tf.nn.avg_pool(tf.placeholder(tf.float32),
+ ksize=[1, 1, 1, 1], strides=[1, 1, 1, 2], padding="SAME")
+
+ # Filter larger than input.
+ for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
+ tf.nn.max_pool_with_argmax]:
+ with self.assertRaisesRegexp(ValueError,
+ "filter must not be larger than the input"):
+ pool_func(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
+ with self.assertRaisesRegexp(ValueError,
+ "filter must not be larger than the input"):
+ pool_func(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
+
+ # Stride larger than filter.
+ for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
+ tf.nn.max_pool_with_argmax]:
+ with self.assertRaisesRegexp(
+ ValueError, "stride must be less than or equal to filter"):
+ pool_func(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ ksize=[1, 5, 3, 1], strides=[1, 5, 5, 1], padding="SAME")
+ with self.assertRaisesRegexp(
+ ValueError, "stride must be less than or equal to filter"):
+ pool_func(tf.placeholder(tf.float32,
+ shape=[32, 20, 20, 3]),
+ ksize=[1, 3, 5, 1], strides=[1, 5, 5, 1], padding="SAME")
+
+
+def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
+ def Test(self):
+ # MaxPoolWithArgMax is implemented only on GPU.
+ if not tf.test.IsBuiltWithCuda():
+ return
+ self._CompareMaxPoolingFwd(input_size, filter_size, strides, padding)
+ return Test
+
+
+def GetMaxPoolGradTest(input_size, filter_size, output_size, strides, padding):
+ def Test(self):
+ # MaxPoolWithArgMax is implemented only on GPU.
+ if not tf.test.IsBuiltWithCuda():
+ return
+ self._CompareMaxPoolingBk(input_size, output_size,
+ filter_size, strides, padding)
+ return Test
+
+
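+# A note on the parametrization below: each Inception-style pooling shape is
+# attached to PoolingTest via setattr as its own test method, so every
+# configuration reports as a separate test case in the output.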
+if __name__ == "__main__":
+ for (name_, input_size_, filter_size_, output_size_, stride_,
+ padding_) in GetInceptionMaxPoolShapes():
+ setattr(PoolingTest, "testMaxPoolFwd_" + name_,
+ GetMaxPoolFwdTest(input_size_, filter_size_, stride_, padding_))
+ setattr(PoolingTest, "testMaxPoolGrad_" + name_,
+ GetMaxPoolGradTest(input_size_, filter_size_, output_size_,
+ stride_, padding_))
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/random_ops_test.py b/tensorflow/python/kernel_tests/random_ops_test.py
new file mode 100644
index 0000000000..311f0e3e5e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/random_ops_test.py
@@ -0,0 +1,242 @@
+"""Tests for tensorflow.ops.random_ops."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class RandomNormalTest(tf.test.TestCase):
+
+ def _Sampler(self, num, mu, sigma, dtype, use_gpu, seed=None):
+ def func():
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ rng = tf.random_normal(
+ [num], mean=mu, stddev=sigma, dtype=dtype, seed=seed)
+ ret = np.empty([10, num])
+ for i in xrange(10):
+ ret[i, :] = sess.run(rng)
+ return ret
+ return func
+
+  # Asserts that different trials (1000 samples per trial) are unlikely
+  # to see the same sequence of values. Will catch buggy
+  # implementations that use the same random number seed.
+ def testDistinct(self):
+ for use_gpu in [False, True]:
+ for dt in tf.float32, tf.float64:
+ sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu)
+ x = sampler()
+ y = sampler()
+ # Number of different samples.
+ count = (x == y).sum()
+ if count >= 10:
+ print "x = ", x
+ print "y = ", y
+ print "count = ", count
+ self.assertTrue(count < 10)
+
+  # Checks that the CPU and GPU implementations return the same results,
+  # given the same random seed.
+ def testCPUGPUMatch(self):
+ for dt in tf.float32, tf.float64:
+ results = {}
+ for use_gpu in [False, True]:
+ sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=12345)
+ results[use_gpu] = sampler()
+ self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6)
+
+ def testSeed(self):
+ for use_gpu in [False, True]:
+ for dt in tf.float32, tf.float64:
+ sx = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
+ sy = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
+ self.assertAllEqual(sx(), sy())
+
+ def testNoCSE(self):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ shape = [2, 3, 4]
+ rnd1 = tf.random_normal(shape, 0.0, 1.0, tf.float32)
+ rnd2 = tf.random_normal(shape, 0.0, 1.0, tf.float32)
+ diff = rnd2 - rnd1
+ self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
+
+
+class TruncatedNormalTest(tf.test.TestCase):
+
+ def _Sampler(self, num, mu, sigma, dtype, use_gpu, seed=None):
+ def func():
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ rng = tf.truncated_normal(
+ [num], mean=mu, stddev=sigma, dtype=dtype, seed=seed)
+ ret = np.empty([10, num])
+ for i in xrange(10):
+ ret[i, :] = sess.run(rng)
+ return ret
+ return func
+
+  # Asserts that different trials (1000 samples per trial) are unlikely
+  # to see the same sequence of values. Will catch buggy
+  # implementations that use the same random number seed.
+ def testDistinct(self):
+ # NOTE: RandomParameters on GPU is not supported.
+ for use_gpu in [False]:
+ for dt in tf.float32, tf.float64:
+ sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu)
+ x = sampler()
+ y = sampler()
+ # Number of different samples.
+ count = (x == y).sum()
+ if count >= 10:
+ print "x = ", x
+ print "y = ", y
+ print "count = ", count
+ self.assertTrue(count < 10)
+
+  # Checks that the CPU and GPU implementations return the same results,
+  # given the same random seed.
+ def testCPUGPUMatch(self):
+ for dt in tf.float32, tf.float64:
+ results = {}
+ for use_gpu in [False, True]:
+        # We need a particularly large number of samples to test multiple
+        # rounds on the GPU.
+ sampler = self._Sampler(1000000, 0.0, 1.0, dt, use_gpu=use_gpu,
+ seed=12345)
+ results[use_gpu] = sampler()
+ self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6)
+
+ def testSeed(self):
+ for use_gpu in [False, True]:
+ for dt in tf.float32, tf.float64:
+ sx = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
+ sy = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
+ self.assertAllEqual(sx(), sy())
+
+  # The effective standard deviation of the truncated normal is roughly 85%
+  # of the requested one.
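+  # (Rough sketch, assuming values more than two standard deviations from the
+  # mean are re-drawn: the variance of a standard normal truncated to [-2, 2]
+  # is 1 - 4 * pdf(2) / (cdf(2) - cdf(-2)) ~= 0.774, so the std ratio is
+  # sqrt(0.774) ~= 0.88, within the 0.04 tolerance used below.)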
+ def testStdDev(self):
+ for use_gpu in [False, True]:
+ for dt in tf.float32, tf.float64:
+ stddev = 3.0
+ sampler = self._Sampler(100000, 0.0, stddev, dt, use_gpu=use_gpu)
+ x = sampler()
+ print "std(x)", np.std(x), abs(np.std(x) / stddev - 0.85)
+ self.assertTrue(abs(np.std(x) / stddev - 0.85) < 0.04)
+
+ def testNoCSE(self):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ shape = [2, 3, 4]
+ rnd1 = tf.truncated_normal(shape, 0.0, 1.0, tf.float32)
+ rnd2 = tf.truncated_normal(shape, 0.0, 1.0, tf.float32)
+ diff = rnd2 - rnd1
+ self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
+
+
+class RandomUniformTest(tf.test.TestCase):
+
+ def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None):
+ def func():
+ with self.test_session(use_gpu=use_gpu, graph=tf.Graph()) as sess:
+ rng = tf.random_uniform(
+ [num], minval=minv, maxval=maxv, dtype=dtype, seed=seed)
+ ret = np.empty([10, num])
+ for i in xrange(10):
+ ret[i, :] = sess.run(rng)
+ return ret
+ return func
+
+ def testRange(self):
+ for use_gpu in [False, True]:
+ for dt in tf.float32, tf.float64:
+ sampler = self._Sampler(1000, -2., 8., dt, use_gpu=use_gpu)
+ x = sampler()
+ self.assertTrue(-2 <= np.min(x))
+ self.assertTrue(np.max(x) <= 8)
+
+  # Asserts that different trials (1000 samples per trial) are unlikely
+  # to see the same sequence of values. Will catch buggy
+  # implementations that use the same random number seed.
+ def testDistinct(self):
+ for use_gpu in [False, True]:
+ for dt in tf.float32, tf.float64:
+ sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu)
+ x = sampler()
+ y = sampler()
+ count = (x == y).sum()
+ if count >= 10:
+ print "x = ", x
+ print "y = ", y
+ print "count = ", count
+ self.assertTrue(count < 10)
+
+  # Checks that the CPU and GPU implementations return the same results,
+  # given the same random seed.
+ def testCPUGPUMatch(self):
+ for dt in tf.float32, tf.float64:
+ results = {}
+ for use_gpu in [False, True]:
+ sampler = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=12345)
+ results[use_gpu] = sampler()
+ self.assertAllClose(results[False], results[True], rtol=1e-6, atol=1e-6)
+
+ def testSeed(self):
+ for use_gpu in [False, True]:
+ for dt in tf.float32, tf.float64:
+ sx = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
+ sy = self._Sampler(1000, 0.0, 1.0, dt, use_gpu=use_gpu, seed=345)
+ self.assertAllEqual(sx(), sy())
+
+ def testNoCSE(self):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ shape = [2, 3, 4]
+ rnd1 = tf.random_uniform(shape, 0.0, 1.0,
+ dtype=tf.float32)
+ rnd2 = tf.random_uniform(shape, 0.0, 1.0,
+ dtype=tf.float32)
+ diff = (rnd2 - rnd1).eval()
+ self.assertTrue(np.linalg.norm(diff) > 0.1)
+
+
+class RandomShapeTest(tf.test.TestCase):
+
+ def testRandomParameters(self):
+ # Fully known shape.
+ rnd1 = tf.truncated_normal([1, 2, 3])
+ self.assertEqual([1, 2, 3], rnd1.get_shape())
+ # Partially known shape.
+ rnd2 = tf.truncated_normal(tf.placeholder(tf.int32, shape=(3,)))
+ self.assertEqual([None, None, None], rnd2.get_shape().as_list())
+ # Unknown shape.
+ rnd3 = tf.truncated_normal(tf.placeholder(tf.int32))
+ self.assertIs(None, rnd3.get_shape().ndims)
+
+ def testRandomNormal(self):
+ # Fully known shape.
+ rnd1 = tf.random_normal([1, 2, 3])
+ self.assertEqual([1, 2, 3], rnd1.get_shape())
+ # Partially known shape.
+ rnd2 = tf.random_normal(tf.placeholder(tf.int32, shape=(3,)))
+ self.assertEqual([None, None, None], rnd2.get_shape().as_list())
+ # Unknown shape.
+ rnd3 = tf.random_normal(tf.placeholder(tf.int32))
+ self.assertIs(None, rnd3.get_shape().ndims)
+
+ def testRandomUniform(self):
+ # Fully known shape.
+ rnd1 = tf.random_uniform([1, 2, 3])
+ self.assertEqual([1, 2, 3], rnd1.get_shape())
+ # Partially known shape.
+ rnd2 = tf.random_uniform(
+ tf.placeholder(tf.int32, shape=(3,)))
+ self.assertEqual([None, None, None], rnd2.get_shape().as_list())
+ # Unknown shape.
+ rnd3 = tf.random_uniform(tf.placeholder(tf.int32))
+ self.assertIs(None, rnd3.get_shape().ndims)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
new file mode 100644
index 0000000000..343ffdcb76
--- /dev/null
+++ b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
@@ -0,0 +1,1054 @@
+"""Tests for tensorflow.ops.data_flow_ops.Queue."""
+import random
+import re
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class RandomShuffleQueueTest(tf.test.TestCase):
+
+ def setUp(self):
+ # Useful for debugging when a test times out.
+ super(RandomShuffleQueueTest, self).setUp()
+ tf.logging.error("Starting: %s", self._testMethodName)
+
+ def tearDown(self):
+ super(RandomShuffleQueueTest, self).tearDown()
+ tf.logging.error("Finished: %s", self._testMethodName)
+
+ def testEnqueue(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 5, tf.float32)
+ enqueue_op = q.enqueue((10.0,))
+ self.assertAllEqual(0, q.size().eval())
+ enqueue_op.run()
+ self.assertAllEqual(1, q.size().eval())
+
+ def testEnqueueWithShape(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shapes=tf.TensorShape([3, 2]))
+ enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
+ enqueue_correct_op.run()
+ self.assertAllEqual(1, q.size().eval())
+ with self.assertRaises(ValueError):
+ q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
+
+ def testEnqueueManyWithShape(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(
+ 10, 5, [tf.int32, tf.int32],
+ shapes=[(), (2,)])
+ q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
+ self.assertAllEqual(4, q.size().eval())
+
+ q2 = tf.RandomShuffleQueue(10, 5, tf.int32, shapes=tf.TensorShape([3]))
+ q2.enqueue(([1, 2, 3],))
+ q2.enqueue_many(([[1, 2, 3]],))
+
+ def testScalarShapes(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(
+ 10, 0, [tf.int32, tf.int32],
+ shapes=[(), (1,)])
+ q.enqueue_many([[1, 2, 3, 4], [[5], [6], [7], [8]]]).run()
+ q.enqueue([9, [10]]).run()
+ dequeue_t = q.dequeue()
+ results = []
+ for _ in range(2):
+ a, b = sess.run(dequeue_t)
+ results.append((a, b))
+ a, b = sess.run(q.dequeue_many(3))
+ for i in range(3):
+ results.append((a[i], b[i]))
+ self.assertItemsEqual([(1, [5]), (2, [6]), (3, [7]), (4, [8]), (9, [10])],
+ results)
+
+ def testParallelEnqueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 0, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Run one producer thread for each element in elems.
+ def enqueue(enqueue_op):
+ sess.run(enqueue_op)
+ threads = [self.checkedThread(target=enqueue, args=(e,))
+ for e in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ # Dequeue every element using a single thread.
+ results = []
+ for _ in xrange(len(elems)):
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems, results)
+
+ def testParallelDequeue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 0, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Enqueue every element using a single thread.
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ # Run one consumer thread for each element in elems.
+ results = []
+
+ def dequeue():
+ results.append(sess.run(dequeued_t))
+ threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, results)
+
+ def testDequeue(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 0, tf.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ vals = [dequeued_t.eval() for _ in xrange(len(elems))]
+ self.assertItemsEqual(elems, vals)
+
+ def testEnqueueAndBlockingDequeue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(3, 0, tf.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ def enqueue():
+ # The enqueue_ops should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for enqueue_op in enqueue_ops:
+ sess.run(enqueue_op)
+
+ results = []
+
+ def dequeue():
+ for _ in xrange(len(elems)):
+ results.append(sess.run(dequeued_t))
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ self.assertItemsEqual(elems, results)
+
+ def testMultiEnqueueAndDequeue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(
+ 10, 0, (tf.int32, tf.float32))
+ elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
+ enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ results = []
+ for _ in xrange(len(elems)):
+ x, y = sess.run(dequeued_t)
+ results.append((x, y))
+ self.assertItemsEqual(elems, results)
+
+ def testQueueSizeEmpty(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 5, tf.float32)
+ self.assertEqual(0, q.size().eval())
+
+ def testQueueSizeAfterEnqueueAndDequeue(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 0, tf.float32)
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue()
+ size = q.size()
+ self.assertEqual([], size.get_shape())
+
+ enqueue_op.run()
+ self.assertEqual([1], size.eval())
+ dequeued_t.op.run()
+ self.assertEqual([0], size.eval())
+
+ def testEnqueueMany(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 0, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue()
+ enqueue_op.run()
+ enqueue_op.run()
+
+ results = []
+ for _ in range(8):
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems + elems, results)
+
+ def testEmptyEnqueueMany(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 5, tf.float32)
+ empty_t = tf.constant([], dtype=tf.float32,
+ shape=[0, 2, 3])
+ enqueue_op = q.enqueue_many((empty_t,))
+ size_t = q.size()
+
+ self.assertEqual(0, size_t.eval())
+ enqueue_op.run()
+ self.assertEqual(0, size_t.eval())
+
+ def testEmptyDequeueMany(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 0, tf.float32, shapes=())
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue_many(0)
+
+ self.assertEqual([], dequeued_t.eval().tolist())
+ enqueue_op.run()
+ self.assertEqual([], dequeued_t.eval().tolist())
+
+ def testEmptyDequeueManyWithNoShape(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 0, tf.float32)
+ enqueue_op = q.enqueue(
+ (tf.constant([10.0, 20.0], shape=(1, 2)),))
+ dequeued_t = q.dequeue_many(0)
+
+ # Expect the operation to fail due to the shape not being constrained.
+ with self.assertRaisesOpError(
+ "requires the components to have specified shapes"):
+ dequeued_t.eval()
+
+ enqueue_op.run()
+
+ # Unlike tf.Queue, RandomShuffleQueue does not make any
+ # attempt to support DequeueMany with unspecified shapes, even if
+ # a shape could be inferred from the elements enqueued.
+ with self.assertRaisesOpError(
+ "requires the components to have specified shapes"):
+ dequeued_t.eval()
+
+ def testMultiEnqueueMany(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(
+ 10, 0, (tf.float32, tf.int32))
+ float_elems = [10.0, 20.0, 30.0, 40.0]
+ int_elems = [[1, 2], [3, 4], [5, 6], [7, 8]]
+ enqueue_op = q.enqueue_many((float_elems, int_elems))
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+ enqueue_op.run()
+
+ results = []
+ for _ in range(8):
+ float_val, int_val = sess.run(dequeued_t)
+ results.append((float_val, [int_val[0], int_val[1]]))
+ expected = zip(float_elems, int_elems) + zip(float_elems, int_elems)
+ self.assertItemsEqual(expected, results)
+
+ def testDequeueMany(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(5)
+
+ enqueue_op.run()
+
+ results = dequeued_t.eval().tolist()
+ results.extend(dequeued_t.eval())
+ self.assertItemsEqual(elems, results)
+
+ def testMultiDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(
+ 10, 0, (tf.float32, tf.int32),
+ shapes=((), (2,)))
+ float_elems = [
+ 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ int_elems = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
+ [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]]
+ enqueue_op = q.enqueue_many((float_elems, int_elems))
+ dequeued_t = q.dequeue_many(4)
+ dequeued_single_t = q.dequeue()
+
+ enqueue_op.run()
+
+ results = []
+ float_val, int_val = sess.run(dequeued_t)
+ self.assertEqual(float_val.shape, dequeued_t[0].get_shape())
+ self.assertEqual(int_val.shape, dequeued_t[1].get_shape())
+ results.extend(zip(float_val, int_val.tolist()))
+
+ float_val, int_val = sess.run(dequeued_t)
+ results.extend(zip(float_val, int_val.tolist()))
+
+ float_val, int_val = sess.run(dequeued_single_t)
+ self.assertEqual(float_val.shape, dequeued_single_t[0].get_shape())
+ self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
+ results.append((float_val, int_val.tolist()))
+
+ float_val, int_val = sess.run(dequeued_single_t)
+ results.append((float_val, int_val.tolist()))
+
+ self.assertItemsEqual(zip(float_elems, int_elems), results)
+
+ def testHighDimension(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(
+ 10, 0, tf.int32, ((4, 4, 4, 4)))
+ elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(10)
+
+ enqueue_op.run()
+ self.assertItemsEqual(dequeued_t.eval().tolist(), elems.tolist())
+
+ def testParallelEnqueueMany(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(1000, 0, tf.float32, shapes=())
+ elems = [10.0 * x for x in range(100)]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(1000)
+
+ # Enqueue 100 items in parallel on 10 threads.
+ def enqueue():
+ sess.run(enqueue_op)
+ threads = [self.checkedThread(target=enqueue) for _ in range(10)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ self.assertItemsEqual(dequeued_t.eval(), elems * 10)
+
+ def testParallelDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(1000, 0, tf.float32, shapes=())
+ elems = [10.0 * x for x in range(1000)]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(100)
+
+ enqueue_op.run()
+
+ # Dequeue 100 items in parallel on 10 threads.
+ dequeued_elems = []
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t))
+ threads = [self.checkedThread(target=dequeue) for _ in range(10)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, dequeued_elems)
+
+ def testBlockingDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ dequeued_t = q.dequeue_many(4)
+
+ dequeued_elems = []
+
+ def enqueue():
+ # The enqueue_op should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ sess.run(enqueue_op)
+
+ def dequeue():
+ dequeued_elems.extend(sess.run(dequeued_t).tolist())
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ self.assertItemsEqual(elems, dequeued_elems)
+
+ def testDequeueManyWithTensorParameter(self):
+ with self.test_session():
+ # Define a first queue that contains integer counts.
+ dequeue_counts = [random.randint(1, 10) for _ in range(100)]
+ count_q = tf.RandomShuffleQueue(100, 0, tf.int32)
+ enqueue_counts_op = count_q.enqueue_many((dequeue_counts,))
+ total_count = sum(dequeue_counts)
+
+ # Define a second queue that contains total_count elements.
+ elems = [random.randint(0, 100) for _ in range(total_count)]
+ q = tf.RandomShuffleQueue(
+ total_count, 0, tf.int32, ((),))
+ enqueue_elems_op = q.enqueue_many((elems,))
+
+ # Define a subgraph that first dequeues a count, then DequeuesMany
+ # that number of elements.
+ dequeued_t = q.dequeue_many(count_q.dequeue())
+
+ enqueue_counts_op.run()
+ enqueue_elems_op.run()
+
+ dequeued_elems = []
+ for _ in dequeue_counts:
+ dequeued_elems.extend(dequeued_t.eval())
+ self.assertItemsEqual(elems, dequeued_elems)
+
+ def testDequeueFromClosedQueue(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 2, tf.float32)
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+ close_op.run()
+ results = [dequeued_t.eval() for _ in elems]
+ expected = [[elem] for elem in elems]
+ self.assertItemsEqual(expected, results)
+
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ dequeued_t.eval()
+
+ def testBlockingDequeueFromClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 2, tf.float32)
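+      # Capacity is 10 and min_after_dequeue is 2, so dequeues block once
+      # only two elements remain, until the queue is closed.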
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+
+ results = []
+ def dequeue():
+ for _ in elems:
+ results.append(sess.run(dequeued_t))
+ self.assertItemsEqual(elems, results)
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ # The dequeue thread blocked when it hit the min_size requirement.
+ self.assertEqual(len(results), 2)
+ close_op.run()
+ dequeue_thread.join()
+ # Once the queue is closed, the min_size requirement is lifted.
+ self.assertEqual(len(results), 4)
+
+ def testBlockingDequeueFromClosedEmptyQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 0, tf.float32)
+ close_op = q.close()
+ dequeued_t = q.dequeue()
+
+ finished = [] # Needs to be a mutable type
+ def dequeue():
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+ finished.append(True)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ self.assertEqual(len(finished), 0)
+ close_op.run()
+ dequeue_thread.join()
+ self.assertEqual(len(finished), 1)
+
+ def testBlockingDequeueManyFromClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(4)
+
+ enqueue_op.run()
+
+ progress = [] # Must be mutable
+ def dequeue():
+ self.assertItemsEqual(elems, sess.run(dequeued_t))
+ progress.append(1)
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+ progress.append(2)
+
+ self.assertEqual(len(progress), 0)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ for _ in range(100):
+ time.sleep(0.01)
+ if len(progress) == 1: break
+ self.assertEqual(len(progress), 1)
+ time.sleep(0.01)
+ close_op.run()
+ dequeue_thread.join()
+ self.assertEqual(len(progress), 2)
+
+ def testBlockingDequeueManyFromClosedQueueWithElementsRemaining(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(3)
+ cleanup_dequeue_t = q.dequeue_many(q.size())
+
+ enqueue_op.run()
+
+ results = []
+ def dequeue():
+ results.extend(sess.run(dequeued_t))
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+ # However, the last result was dequeued before the queue was closed,
+ # so nothing more is added to results.
+ results.extend(sess.run(cleanup_dequeue_t))
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ self.assertEqual(len(results), 3)
+ close_op.run()
+ dequeue_thread.join()
+ self.assertEqual(len(results), 3)
+
+ def testBlockingDequeueManyFromClosedEmptyQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(10, 5, tf.float32, ((),))
+ close_op = q.close()
+ dequeued_t = q.dequeue_many(4)
+
+ def dequeue():
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError,
+ "is closed and has insufficient"):
+ sess.run(dequeued_t)
+
+ dequeue_thread = self.checkedThread(target=dequeue)
+ dequeue_thread.start()
+ # The close_op should run after the dequeue_thread has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ close_op.run()
+ dequeue_thread.join()
+
+ def testEnqueueToClosedQueue(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 4, tf.float32)
+ enqueue_op = q.enqueue((10.0,))
+ close_op = q.close()
+
+ enqueue_op.run()
+ close_op.run()
+
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.AbortedError, "is closed"):
+ enqueue_op.run()
+
+ def testEnqueueManyToClosedQueue(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 5, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ close_op = q.close()
+
+ enqueue_op.run()
+ close_op.run()
+
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.AbortedError, "is closed"):
+ enqueue_op.run()
+
+ def testBlockingEnqueueToFullQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(4, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue((50.0,))
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+
+ def blocking_enqueue():
+ sess.run(blocking_enqueue_op)
+ thread = self.checkedThread(target=blocking_enqueue)
+ thread.start()
+ # The dequeue ops should run after the blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ results = []
+ for _ in elems:
+ results.append(dequeued_t.eval())
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems + [50.0], results)
+ # There wasn't room for 50.0 in the queue when the first element was
+ # dequeued.
+ self.assertNotEqual(50.0, results[0])
+ thread.join()
+
+ def testBlockingEnqueueManyToFullQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(4, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue_many(([50.0, 60.0],))
+ dequeued_t = q.dequeue()
+
+ enqueue_op.run()
+
+ def blocking_enqueue():
+ sess.run(blocking_enqueue_op)
+ thread = self.checkedThread(target=blocking_enqueue)
+ thread.start()
+ # The dequeue ops should run after the blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ results = []
+ for _ in elems:
+ time.sleep(0.01)
+ results.append(dequeued_t.eval())
+ results.append(dequeued_t.eval())
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems + [50.0, 60.0], results)
+ # There wasn't room for 50.0 or 60.0 in the queue when the first
+ # element was dequeued.
+ self.assertNotEqual(50.0, results[0])
+ self.assertNotEqual(60.0, results[0])
+ # Similarly for 60.0 and the second element.
+ self.assertNotEqual(60.0, results[1])
+
+ def testBlockingEnqueueToClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(4, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0, 40.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue((50.0,))
+ dequeued_t = q.dequeue()
+ close_op = q.close()
+
+ enqueue_op.run()
+
+ def blocking_enqueue():
+ # Expect the operation to succeed since it will complete
+ # before the queue is closed.
+ sess.run(blocking_enqueue_op)
+
+ # Expect the operation to fail due to the queue being closed.
+ with self.assertRaisesRegexp(tf.errors.AbortedError, "closed"):
+ sess.run(blocking_enqueue_op)
+ thread1 = self.checkedThread(target=blocking_enqueue)
+ thread1.start()
+
+ # The close_op should run after the first blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ def blocking_close():
+ sess.run(close_op)
+ thread2 = self.checkedThread(target=blocking_close)
+ thread2.start()
+
+ # Wait for the close op to block before unblocking the enqueue.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ results = []
+ # Dequeue to unblock the first blocking_enqueue_op, after which the
+ # close will complete.
+ results.append(dequeued_t.eval())
+ self.assertTrue(results[0] in elems)
+ thread2.join()
+ thread1.join()
+
+ def testBlockingEnqueueManyToClosedQueue(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(4, 0, tf.float32, ((),))
+ elems = [10.0, 20.0, 30.0]
+ enqueue_op = q.enqueue_many((elems,))
+ blocking_enqueue_op = q.enqueue_many(([50.0, 60.0],))
+ close_op = q.close()
+ size_t = q.size()
+
+ enqueue_op.run()
+ self.assertEqual(size_t.eval(), 3)
+
+ def blocking_enqueue():
+ # This will block until the dequeue after the close.
+ sess.run(blocking_enqueue_op)
+ # At this point the close operation will become unblocked, so the
+ # next enqueue will fail.
+ with self.assertRaisesRegexp(tf.errors.AbortedError, "closed"):
+ sess.run(blocking_enqueue_op)
+ thread1 = self.checkedThread(target=blocking_enqueue)
+ thread1.start()
+ # The close_op should run after the blocking_enqueue_op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+      # The first blocking_enqueue_op of blocking_enqueue has enqueued 1 of 2
+      # elements, and is blocked waiting for one more element to be dequeued.
+ self.assertEqual(size_t.eval(), 4)
+
+ def blocking_close():
+ sess.run(close_op)
+ thread2 = self.checkedThread(target=blocking_close)
+ thread2.start()
+
+ # The close_op should run before the second blocking_enqueue_op
+ # has started.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+
+ # Unblock the first blocking_enqueue_op in blocking_enqueue.
+ q.dequeue().eval()
+
+ thread2.join()
+ thread1.join()
+
+ def testSharedQueueSameSession(self):
+ with self.test_session():
+ q1 = tf.RandomShuffleQueue(
+ 1, 0, tf.float32, ((),), shared_name="shared_queue")
+ q1.enqueue((10.0,)).run()
+
+ q2 = tf.RandomShuffleQueue(
+ 1, 0, tf.float32, ((),), shared_name="shared_queue")
+
+ q1_size_t = q1.size()
+ q2_size_t = q2.size()
+
+ self.assertEqual(q1_size_t.eval(), 1)
+ self.assertEqual(q2_size_t.eval(), 1)
+
+ self.assertEqual(q2.dequeue().eval(), 10.0)
+
+ self.assertEqual(q1_size_t.eval(), 0)
+ self.assertEqual(q2_size_t.eval(), 0)
+
+ q2.enqueue((20.0,)).run()
+
+ self.assertEqual(q1_size_t.eval(), 1)
+ self.assertEqual(q2_size_t.eval(), 1)
+
+ self.assertEqual(q1.dequeue().eval(), 20.0)
+
+ self.assertEqual(q1_size_t.eval(), 0)
+ self.assertEqual(q2_size_t.eval(), 0)
+
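+  # Queues constructed with the same shared_name refer to a single
+  # underlying queue, so their capacity, min_after_dequeue, component types,
+  # component shapes and random seeds must all match; each mismatch below is
+  # expected to raise when the second queue's queue_ref is evaluated.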
+ def testIncompatibleSharedQueueErrors(self):
+ with self.test_session():
+ q_a_1 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shared_name="q_a")
+ q_a_2 = tf.RandomShuffleQueue(
+ 15, 5, tf.float32, shared_name="q_a")
+ q_a_1.queue_ref.eval()
+ with self.assertRaisesOpError("capacity"):
+ q_a_2.queue_ref.eval()
+
+ q_b_1 = tf.RandomShuffleQueue(
+ 10, 0, tf.float32, shared_name="q_b")
+ q_b_2 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shared_name="q_b")
+ q_b_1.queue_ref.eval()
+ with self.assertRaisesOpError("min_after_dequeue"):
+ q_b_2.queue_ref.eval()
+
+ q_c_1 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shared_name="q_c")
+ q_c_2 = tf.RandomShuffleQueue(
+ 10, 5, tf.int32, shared_name="q_c")
+ q_c_1.queue_ref.eval()
+ with self.assertRaisesOpError("component types"):
+ q_c_2.queue_ref.eval()
+
+ q_d_1 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shared_name="q_d")
+ q_d_2 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shapes=[(1, 1, 2, 3)], shared_name="q_d")
+ q_d_1.queue_ref.eval()
+ with self.assertRaisesOpError("component shapes"):
+ q_d_2.queue_ref.eval()
+
+ q_e_1 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shapes=[(1, 1, 2, 3)], shared_name="q_e")
+ q_e_2 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shared_name="q_e")
+ q_e_1.queue_ref.eval()
+ with self.assertRaisesOpError("component shapes"):
+ q_e_2.queue_ref.eval()
+
+ q_f_1 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shapes=[(1, 1, 2, 3)], shared_name="q_f")
+ q_f_2 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shapes=[(1, 1, 2, 4)], shared_name="q_f")
+ q_f_1.queue_ref.eval()
+ with self.assertRaisesOpError("component shapes"):
+ q_f_2.queue_ref.eval()
+
+ q_g_1 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, shared_name="q_g")
+ q_g_2 = tf.RandomShuffleQueue(
+ 10, 5, (tf.float32, tf.int32), shared_name="q_g")
+ q_g_1.queue_ref.eval()
+ with self.assertRaisesOpError("component types"):
+ q_g_2.queue_ref.eval()
+
+ q_h_1 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, seed=12, shared_name="q_h")
+ q_h_2 = tf.RandomShuffleQueue(
+ 10, 5, tf.float32, seed=21, shared_name="q_h")
+ q_h_1.queue_ref.eval()
+ with self.assertRaisesOpError("random seeds"):
+ q_h_2.queue_ref.eval()
+
+ def testSelectQueue(self):
+ with self.test_session():
+ num_queues = 10
+ qlist = list()
+ for _ in xrange(num_queues):
+ qlist.append(
+ tf.RandomShuffleQueue(10, 0, tf.float32))
+      # Enqueue into and dequeue from a dynamically selected queue.
+ for _ in xrange(20):
+ index = np.random.randint(num_queues)
+ q = tf.RandomShuffleQueue.from_list(index, qlist)
+ q.enqueue((10.,)).run()
+ self.assertEqual(q.dequeue().eval(), 10.0)
+
+ def testSelectQueueOutOfRange(self):
+ with self.test_session():
+ q1 = tf.RandomShuffleQueue(10, 0, tf.float32)
+ q2 = tf.RandomShuffleQueue(15, 0, tf.float32)
+ enq_q = tf.RandomShuffleQueue.from_list(3, [q1, q2])
+ with self.assertRaisesOpError("Index must be in the range"):
+ enq_q.dequeue().eval()
+
+ def _blockingDequeue(self, sess, dequeue_op):
+ with self.assertRaisesOpError("Dequeue operation was cancelled"):
+ sess.run(dequeue_op)
+
+ def _blockingDequeueMany(self, sess, dequeue_many_op):
+ with self.assertRaisesOpError("Dequeue operation was cancelled"):
+ sess.run(dequeue_many_op)
+
+ def _blockingEnqueue(self, sess, enqueue_op):
+ with self.assertRaisesOpError("Enqueue operation was cancelled"):
+ sess.run(enqueue_op)
+
+ def _blockingEnqueueMany(self, sess, enqueue_many_op):
+ with self.assertRaisesOpError("Enqueue operation was cancelled"):
+ sess.run(enqueue_many_op)
+
+ def testResetOfBlockingOperation(self):
+ with self.test_session() as sess:
+ q_empty = tf.RandomShuffleQueue(
+ 5, 0, tf.float32, ((),))
+ dequeue_op = q_empty.dequeue()
+ dequeue_many_op = q_empty.dequeue_many(1)
+
+ q_full = tf.RandomShuffleQueue(5, 0, tf.float32, ((),))
+ sess.run(q_full.enqueue_many(([1.0, 2.0, 3.0, 4.0, 5.0],)))
+ enqueue_op = q_full.enqueue((6.0,))
+ enqueue_many_op = q_full.enqueue_many(([6.0],))
+
+ threads = [
+ self.checkedThread(self._blockingDequeue, args=(sess, dequeue_op)),
+ self.checkedThread(self._blockingDequeueMany, args=(sess,
+ dequeue_many_op)),
+ self.checkedThread(self._blockingEnqueue, args=(sess, enqueue_op)),
+ self.checkedThread(self._blockingEnqueueMany, args=(sess,
+ enqueue_many_op))]
+ for t in threads:
+ t.start()
+ time.sleep(0.1)
+ sess.close() # Will cancel the blocked operations.
+ for t in threads:
+ t.join()
+
+ def testDequeueManyInDifferentOrders(self):
+ with self.test_session():
+ # Specify seeds to make the test deterministic
+ # (https://en.wikipedia.org/wiki/Taxicab_number).
+ q1 = tf.RandomShuffleQueue(10, 5, tf.int32,
+ ((),), seed=1729)
+ q2 = tf.RandomShuffleQueue(10, 5, tf.int32,
+ ((),), seed=87539319)
+ enq1 = q1.enqueue_many(([1, 2, 3, 4, 5],))
+ enq2 = q2.enqueue_many(([1, 2, 3, 4, 5],))
+ deq1 = q1.dequeue_many(5)
+ deq2 = q2.dequeue_many(5)
+
+ enq1.run()
+ enq1.run()
+ enq2.run()
+ enq2.run()
+
+ results = [[], [], [], []]
+
+ results[0].extend(deq1.eval())
+ results[1].extend(deq2.eval())
+
+ q1.close().run()
+ q2.close().run()
+
+ results[2].extend(deq1.eval())
+ results[3].extend(deq2.eval())
+
+ # No two should match
+ for i in range(1, 4):
+ for j in range(i):
+ self.assertNotEqual(results[i], results[j])
+
+ def testDequeueInDifferentOrders(self):
+ with self.test_session():
+ # Specify seeds to make the test deterministic
+ # (https://en.wikipedia.org/wiki/Taxicab_number).
+ q1 = tf.RandomShuffleQueue(10, 5, tf.int32,
+ ((),), seed=1729)
+ q2 = tf.RandomShuffleQueue(10, 5, tf.int32,
+ ((),), seed=87539319)
+ enq1 = q1.enqueue_many(([1, 2, 3, 4, 5],))
+ enq2 = q2.enqueue_many(([1, 2, 3, 4, 5],))
+ deq1 = q1.dequeue()
+ deq2 = q2.dequeue()
+
+ enq1.run()
+ enq1.run()
+ enq2.run()
+ enq2.run()
+
+ results = [[], [], [], []]
+
+ for _ in range(5):
+ results[0].append(deq1.eval())
+ results[1].append(deq2.eval())
+
+ q1.close().run()
+ q2.close().run()
+
+ for _ in range(5):
+ results[2].append(deq1.eval())
+ results[3].append(deq2.eval())
+
+ # No two should match
+ for i in range(1, 4):
+ for j in range(i):
+ self.assertNotEqual(results[i], results[j])
+
+ def testBigEnqueueMany(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(
+ 5, 0, tf.int32, ((),))
+ elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ enq = q.enqueue_many((elem,))
+ deq = q.dequeue()
+ size_op = q.size()
+
+ enq_done = []
+ def blocking_enqueue():
+ enq_done.append(False)
+ # This will fill the queue and then block until enough dequeues happen.
+ sess.run(enq)
+ enq_done.append(True)
+ thread = self.checkedThread(target=blocking_enqueue)
+ thread.start()
+
+ # The enqueue should start and then block.
+ results = []
+ results.append(deq.eval()) # Will only complete after the enqueue starts.
+ self.assertEqual(len(enq_done), 1)
+ self.assertEqual(sess.run(size_op), 5)
+
+ for _ in range(3):
+ results.append(deq.eval())
+
+ time.sleep(0.1)
+ self.assertEqual(len(enq_done), 1)
+ self.assertEqual(sess.run(size_op), 5)
+
+ # This dequeue will unblock the thread.
+ results.append(deq.eval())
+ time.sleep(0.1)
+ self.assertEqual(len(enq_done), 2)
+ thread.join()
+
+ for i in range(5):
+ self.assertEqual(size_op.eval(), 5 - i)
+ results.append(deq.eval())
+ self.assertEqual(size_op.eval(), 5 - i - 1)
+
+ self.assertItemsEqual(elem, results)
+
+ def testBigDequeueMany(self):
+ with self.test_session() as sess:
+ q = tf.RandomShuffleQueue(2, 0, tf.int32, ((),))
+ elem = range(4)
+ enq_list = [q.enqueue((e,)) for e in elem]
+ deq = q.dequeue_many(4)
+
+ results = []
+ def blocking_dequeue():
+ # Will only complete after 4 enqueues complete.
+ results.extend(sess.run(deq))
+ thread = self.checkedThread(target=blocking_dequeue)
+ thread.start()
+ # The dequeue should start and then block.
+ for enq in enq_list:
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ self.assertEqual(len(results), 0)
+ sess.run(enq)
+
+ # Enough enqueued to unblock the dequeue
+ thread.join()
+ self.assertItemsEqual(elem, results)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
new file mode 100644
index 0000000000..484e3eca43
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -0,0 +1,362 @@
+"""Tests for Reader ops from io_ops."""
+
+import os
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
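+# The tests below exercise the common Reader pattern: a reader pulls work
+# items (strings, typically filenames) from a queue and returns one
+# (key, value) record per read.  A rough usage sketch (the reader name is
+# illustrative only):
+#
+#   queue = tf.FIFOQueue(99, [tf.string], shapes=())
+#   reader = tf.IdentityReader("my_reader")
+#   key, value = reader.read(queue)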
+class IdentityReaderTest(tf.test.TestCase):
+
+ def _ExpectRead(self, sess, key, value, expected):
+ k, v = sess.run([key, value])
+ self.assertAllEqual(expected, k)
+ self.assertAllEqual(expected, v)
+
+ def testOneEpoch(self):
+ with self.test_session() as sess:
+ reader = tf.IdentityReader("test_reader")
+ work_completed = reader.num_work_units_completed()
+ produced = reader.num_records_produced()
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ queued_length = queue.size()
+ key, value = reader.read(queue)
+
+ self.assertAllEqual(0, work_completed.eval())
+ self.assertAllEqual(0, produced.eval())
+ self.assertAllEqual(0, queued_length.eval())
+
+ queue.enqueue_many([["A", "B", "C"]]).run()
+ queue.close().run()
+ self.assertAllEqual(3, queued_length.eval())
+
+ self._ExpectRead(sess, key, value, "A")
+ self.assertAllEqual(1, produced.eval())
+
+ self._ExpectRead(sess, key, value, "B")
+
+ self._ExpectRead(sess, key, value, "C")
+ self.assertAllEqual(3, produced.eval())
+ self.assertAllEqual(0, queued_length.eval())
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+
+ self.assertAllEqual(3, work_completed.eval())
+ self.assertAllEqual(3, produced.eval())
+ self.assertAllEqual(0, queued_length.eval())
+
+ def testMultipleEpochs(self):
+ with self.test_session() as sess:
+ reader = tf.IdentityReader("test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ enqueue = queue.enqueue_many([["DD", "EE"]])
+ key, value = reader.read(queue)
+
+ enqueue.run()
+ self._ExpectRead(sess, key, value, "DD")
+ self._ExpectRead(sess, key, value, "EE")
+ enqueue.run()
+ self._ExpectRead(sess, key, value, "DD")
+ self._ExpectRead(sess, key, value, "EE")
+ enqueue.run()
+ self._ExpectRead(sess, key, value, "DD")
+ self._ExpectRead(sess, key, value, "EE")
+ queue.close().run()
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+
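+  # serialize_state() returns an opaque string snapshot of the reader's
+  # progress, and restore_state() rewinds the reader to that snapshot; the
+  # num_records_produced() counter is used below to verify the rewind.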
+ def testSerializeRestore(self):
+ with self.test_session() as sess:
+ reader = tf.IdentityReader("test_reader")
+ produced = reader.num_records_produced()
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ queue.enqueue_many([["X", "Y", "Z"]]).run()
+ key, value = reader.read(queue)
+
+ self._ExpectRead(sess, key, value, "X")
+ self.assertAllEqual(1, produced.eval())
+ state = reader.serialize_state().eval()
+
+ self._ExpectRead(sess, key, value, "Y")
+ self._ExpectRead(sess, key, value, "Z")
+ self.assertAllEqual(3, produced.eval())
+
+ queue.enqueue_many([["Y", "Z"]]).run()
+ queue.close().run()
+ reader.restore_state(state).run()
+ self.assertAllEqual(1, produced.eval())
+ self._ExpectRead(sess, key, value, "Y")
+ self._ExpectRead(sess, key, value, "Z")
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+ self.assertAllEqual(3, produced.eval())
+
+ self.assertEqual(str, type(state))
+
+ with self.assertRaises(ValueError):
+ reader.restore_state([])
+
+ with self.assertRaises(ValueError):
+ reader.restore_state([state, state])
+
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ reader.restore_state(state[1:]).run()
+
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ reader.restore_state(state[:-1]).run()
+
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ reader.restore_state(state + "ExtraJunk").run()
+
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ reader.restore_state("PREFIX" + state).run()
+
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ reader.restore_state("BOGUS" + state[5:]).run()
+
+ def testReset(self):
+ with self.test_session() as sess:
+ reader = tf.IdentityReader("test_reader")
+ work_completed = reader.num_work_units_completed()
+ produced = reader.num_records_produced()
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ queued_length = queue.size()
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([["X", "Y", "Z"]]).run()
+ self._ExpectRead(sess, key, value, "X")
+ self.assertLess(0, queued_length.eval())
+ self.assertAllEqual(1, produced.eval())
+
+ self._ExpectRead(sess, key, value, "Y")
+ self.assertLess(0, work_completed.eval())
+ self.assertAllEqual(2, produced.eval())
+
+ reader.reset().run()
+ self.assertAllEqual(0, work_completed.eval())
+ self.assertAllEqual(0, produced.eval())
+ self.assertAllEqual(1, queued_length.eval())
+ self._ExpectRead(sess, key, value, "Z")
+
+ queue.enqueue_many([["K", "L"]]).run()
+ self._ExpectRead(sess, key, value, "K")
+
+
+class WholeFileReaderTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(WholeFileReaderTest, self).setUp()
+ self._filenames = [os.path.join(self.get_temp_dir(), "whole_file.%d.txt" % i)
+ for i in range(3)]
+ self._content = ["One\na\nb\n", "Two\nC\nD", "Three x, y, z"]
+ for fn, c in zip(self._filenames, self._content):
+ open(fn, "w").write(c)
+
+ def tearDown(self):
+ super(WholeFileReaderTest, self).tearDown()
+ for fn in self._filenames:
+ os.remove(fn)
+
+ def _ExpectRead(self, sess, key, value, index):
+ k, v = sess.run([key, value])
+ self.assertAllEqual(self._filenames[index], k)
+ self.assertAllEqual(self._content[index], v)
+
+ def testOneEpoch(self):
+ with self.test_session() as sess:
+ reader = tf.WholeFileReader("test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ queue.enqueue_many([self._filenames]).run()
+ queue.close().run()
+ key, value = reader.read(queue)
+
+ self._ExpectRead(sess, key, value, 0)
+ self._ExpectRead(sess, key, value, 1)
+ self._ExpectRead(sess, key, value, 2)
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ sess.run([key, value])
+
+ def testInfiniteEpochs(self):
+ with self.test_session() as sess:
+ reader = tf.WholeFileReader("test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ enqueue = queue.enqueue_many([self._filenames])
+ key, value = reader.read(queue)
+
+ enqueue.run()
+ self._ExpectRead(sess, key, value, 0)
+ self._ExpectRead(sess, key, value, 1)
+ enqueue.run()
+ self._ExpectRead(sess, key, value, 2)
+ self._ExpectRead(sess, key, value, 0)
+ self._ExpectRead(sess, key, value, 1)
+ enqueue.run()
+ self._ExpectRead(sess, key, value, 2)
+ self._ExpectRead(sess, key, value, 0)
+
+
+class TextLineReaderTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(TextLineReaderTest, self).setUp()
+ self._num_files = 2
+ self._num_lines = 5
+
+ def _LineText(self, f, l):
+ return "%d: %d" % (f, l)
+
+ def _CreateFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
+ filenames.append(fn)
+ f = open(fn, "w")
+ for j in range(self._num_lines):
+ f.write(self._LineText(i, j))
+        # Always include a newline after the record, unless it is the last
+        # record in the file; in that case only the first file gets a
+        # trailing newline, so both cases are covered.
+ if j + 1 != self._num_lines or i == 0:
+ f.write("\n")
+ return filenames
+
+ def testOneEpoch(self):
+ files = self._CreateFiles()
+ with self.test_session() as sess:
+ reader = tf.TextLineReader(name="test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_lines):
+ k, v = sess.run([key, value])
+ self.assertAllEqual("%s:%d" % (files[i], j + 1), k)
+ self.assertAllEqual(self._LineText(i, j), v)
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = sess.run([key, value])
+
+ def testSkipHeaderLines(self):
+ files = self._CreateFiles()
+ with self.test_session() as sess:
+ reader = tf.TextLineReader(skip_header_lines=1, name="test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_lines - 1):
+ k, v = sess.run([key, value])
+ self.assertAllEqual("%s:%d" % (files[i], j + 2), k)
+ self.assertAllEqual(self._LineText(i, j + 1), v)
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = sess.run([key, value])
+
+
+class FixedLengthRecordReaderTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(FixedLengthRecordReaderTest, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self._header_bytes = 5
+ self._record_bytes = 3
+ self._footer_bytes = 2
+
+ def _Record(self, f, r):
+ return str(f * 2 + r) * self._record_bytes
+
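+  # Each file written below has the layout:
+  #   [header_bytes of "H"] [num_records records of record_bytes each]
+  #   [footer_bytes of "F"]
+  # The reader is configured with the same byte counts, so it should yield
+  # exactly the record payloads.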
+ def _CreateFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
+ filenames.append(fn)
+ f = open(fn, "w")
+ f.write("H" * self._header_bytes)
+ for j in range(self._num_records):
+ f.write(self._Record(i, j))
+ f.write("F" * self._footer_bytes)
+ return filenames
+
+ def testOneEpoch(self):
+ files = self._CreateFiles()
+ with self.test_session() as sess:
+ reader = tf.FixedLengthRecordReader(
+ header_bytes=self._header_bytes,
+ record_bytes=self._record_bytes,
+ footer_bytes=self._footer_bytes,
+ name="test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = sess.run([key, value])
+ self.assertAllEqual("%s:%d" % (files[i], j), k)
+ self.assertAllEqual(self._Record(i, j), v)
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = sess.run([key, value])
+
+
+class TFRecordReaderTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(TFRecordReaderTest, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ def _Record(self, f, r):
+ return "Record %d of file %d" % (r, f)
+
+ def _CreateFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = tf.python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._Record(i, j))
+ return filenames
+
+ def testOneEpoch(self):
+ files = self._CreateFiles()
+ with self.test_session() as sess:
+ reader = tf.TFRecordReader(name="test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ key, value = reader.read(queue)
+
+ queue.enqueue_many([files]).run()
+ queue.close().run()
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = sess.run([key, value])
+ self.assertTrue(k.startswith("%s:" % files[i]))
+ self.assertAllEqual(self._Record(i, j), v)
+
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = sess.run([key, value])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
new file mode 100644
index 0000000000..e5cab62c09
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -0,0 +1,533 @@
+"""Functional tests for reduction ops."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.kernel_tests import gradient_checker
+
+
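+# Each test below compares a tf.reduce_* op against the corresponding numpy
+# reduction applied one axis at a time, roughly:
+#
+#   np.sum(x, axis=1, keepdims=True)  ~  tf.reduce_sum(x, [1], keep_dims=True)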
+class SumReductionTest(tf.test.TestCase):
+
+ def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ np_ans = x
+ if reduction_axes is None:
+ np_ans = np.sum(np_ans, keepdims=keep_dims)
+ else:
+ reduction_axes = np.array(reduction_axes).astype(np.int32)
+ for ra in reduction_axes.ravel()[::-1]:
+ np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims)
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.reduce_sum(x, reduction_axes, keep_dims)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, x, reduction_axes):
+ if reduction_axes is not None and np.shape(reduction_axes) == (1,):
+ # Test scalar reduction_axes argument
+ self._compareAll(x, reduction_axes[0])
+ self._compare(x, reduction_axes, False, use_gpu=True)
+ self._compare(x, reduction_axes, False, use_gpu=False)
+ self._compare(x, reduction_axes, True, use_gpu=True)
+ self._compare(x, reduction_axes, True, use_gpu=False)
+
+ def testFloatReduce1D(self):
+ # Create a 1D array of floats
+ np_arr = np.arange(1, 6).reshape([5]).astype(np.float32)
+ self._compareAll(np_arr, [0])
+
+ def testFloatReduce2D(self):
+ # Create a 2D array of floats and reduce across all possible
+ # dimensions
+ np_arr = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [0, 1])
+
+ def testFloatReduce3D(self):
+ # Create a 3D array of floats and reduce across all possible
+ # dimensions
+ np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+
+ def testFloatReduce4D(self):
+ # Create a 4D array of floats and reduce across some
+ # dimensions
+ np_arr = np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float32)
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ # Need specialization for reduce(4D, [0, 2])
+ # self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+ self._compareAll(np_arr, [1, 2, 3])
+ self._compareAll(np_arr, [0, 1, 2, 3])
+
+ def testFloatReduce5D(self):
+ # Create a 5D array of floats and reduce across some dimensions
+ np_arr = np.arange(0, 840).reshape([2, 3, 5, 7, 4]).astype(np.float32)
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ # Need specialization for reduce(4D, [0, 2])
+ # self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+ self._compareAll(np_arr, [1, 2, 3])
+ self._compareAll(np_arr, [0, 1, 2, 3])
+ self._compareAll(np_arr, [1, 2, 3, 4])
+ self._compareAll(np_arr, [0, 1, 2, 3, 4])
+
+  # Simple tests for other data types.
+ def testDoubleReduce1D(self):
+ np_arr = np.arange(1, 6).reshape([5]).astype(np.float64)
+ self._compare(np_arr, [], False)
+ self._compare(np_arr, [0], False)
+
+ def testInt32Reduce1D(self):
+ np_arr = np.arange(1, 6).reshape([5]).astype(np.int32)
+ self._compare(np_arr, [], False)
+ self._compare(np_arr, [0], False)
+
+ def testInvalidIndex(self):
+ np_arr = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
+ input_tensor = tf.convert_to_tensor(np_arr)
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Invalid reduction dimension" in e.message):
+ tf.reduce_sum(input_tensor, [-1])
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Invalid reduction dimension" in e.message):
+ tf.reduce_sum(input_tensor, [2])
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Invalid reduction dimension" in e.message):
+ tf.reduce_sum(input_tensor, [0, 2])
+
+  # TODO: Add tests for int64 reductions.
+
+ def _compareGradient(self, shape, sum_shape, reduction_axes):
+ if reduction_axes is not None and np.shape(reduction_axes) == (1,):
+ # Test scalar reduction_axes argument
+ self._compareGradient(shape, sum_shape, reduction_axes[0])
+ x = np.arange(1.0, 49.0).reshape(shape).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_sum(t, reduction_axes)
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t,
+ shape,
+ su,
+ sum_shape,
+ x_init_value=x,
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+ def testGradient(self):
+ self._compareGradient([2, 3, 4, 2], [2, 2], [1, 2])
+
+ def testGradient2(self):
+ self._compareGradient([2, 3, 4, 2], [2, 4, 2], [1])
+
+ def testGradient3(self):
+ self._compareGradient([2, 3, 4, 2], [2, 3, 2], [2])
+
+ def testGradient4(self):
+ self._compareGradient([2, 3, 4, 2], [], None)
+
+
+class MeanReductionTest(tf.test.TestCase):
+
+ def _compare(self, x, reduction_axes, keep_dims):
+ np_sum = x
+ count = 1
+ for ra in reduction_axes[::-1]:
+ np_sum = np.sum(np_sum, axis=ra, keepdims=keep_dims)
+ count *= x.shape[ra]
+ np_ans = np_sum / count
+ with self.test_session():
+ reduction_axes = np.array(reduction_axes).astype(np.int32)
+ tf_ans = tf.reduce_mean(x, reduction_axes, keep_dims)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, x, reduction_axes):
+ self._compare(x, reduction_axes, False)
+ self._compare(x, reduction_axes, True)
+
+ def testFloatReduce3D(self):
+ # Create a 3D array of floats and reduce across all possible
+ # dimensions
+ np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+
+ def testGradient(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float32)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_mean(t, [1, 2])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+ su = tf.reduce_mean(t, [0, 1, 2, 3])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [1], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+ su = tf.reduce_mean(t, [])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+
+class ProdReductionTest(tf.test.TestCase):
+
+ def _compare(self, x, reduction_axes, keep_dims):
+ np_ans = x
+ if reduction_axes is None:
+ np_ans = np.prod(np_ans, keepdims=keep_dims)
+ else:
+ for ra in reduction_axes[::-1]:
+ np_ans = np.prod(np_ans, axis=ra, keepdims=keep_dims)
+ with self.test_session():
+ if reduction_axes is not None:
+ reduction_axes = np.array(reduction_axes).astype(np.int32)
+ tf_ans = tf.reduce_prod(x, reduction_axes, keep_dims)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, x, reduction_axes):
+ self._compare(x, reduction_axes, False)
+ self._compare(x, reduction_axes, True)
+
+ def testFloatReduce3D(self):
+ # Create a 3D array of floats and reduce across all possible
+ # dimensions
+ np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+
+ def testGradient(self):
+ s = [2, 3, 4, 2]
+ # NOTE(kearnes): divide by 20 so product is a reasonable size
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float32) / 20.
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+
+ su = tf.reduce_prod(t, [])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+ su = tf.reduce_prod(t, [1, 2])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+ su = tf.reduce_prod(t, [0, 1, 2, 3])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [1], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+ # NOTE(kearnes): the current gradient calculation gives NaNs for 0 inputs
+ x = np.arange(0.0, 48.0).reshape(s).astype(np.float32) / 20.
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_prod(t, [])
+ jacob_t, _ = gradient_checker.ComputeGradient(
+ t, s, su, [2, 3, 4, 2], x_init_value=x, delta=1)
+ with self.assertRaisesOpError("Tensor had NaN values"):
+ tf.check_numerics(jacob_t, message="_ProdGrad NaN test").op.run()
+
+
+class MinReductionTest(tf.test.TestCase):
+
+ def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ np_ans = x
+ if reduction_axes is None:
+ np_ans = np.amin(np_ans, keepdims=keep_dims)
+ else:
+ for ra in reduction_axes[::-1]:
+ np_ans = np.amin(np_ans, axis=ra, keepdims=keep_dims)
+ with self.test_session(use_gpu=use_gpu):
+ if reduction_axes is not None:
+ reduction_axes = np.array(reduction_axes).astype(np.int32)
+ tf_ans = tf.reduce_min(x, reduction_axes, keep_dims)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, x, reduction_axes):
+ self._compare(x, reduction_axes, False, use_gpu=True)
+ self._compare(x, reduction_axes, False, use_gpu=False)
+ self._compare(x, reduction_axes, True, use_gpu=True)
+ self._compare(x, reduction_axes, True, use_gpu=False)
+
+ def testFloatReduce3D(self):
+ # Create a 3D array of floats and reduce across all possible
+ # dimensions
+ np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+
+ def testGradient(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_min(t, [1, 2])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+ def testGradient2(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_min(t, [1])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 4, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+ def testGradient3(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_min(t, [2])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 3, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+ def testGradient4(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_min(t)
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [1], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+
+class MaxReductionTest(tf.test.TestCase):
+
+ def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ np_ans = x
+ if reduction_axes is None:
+ np_ans = np.amax(np_ans, keepdims=keep_dims)
+ else:
+ for ra in reduction_axes[::-1]:
+ np_ans = np.amax(np_ans, axis=ra, keepdims=keep_dims)
+ with self.test_session(use_gpu=use_gpu):
+ if reduction_axes is not None:
+ reduction_axes = np.array(reduction_axes).astype(np.int32)
+ tf_ans = tf.reduce_max(x, reduction_axes, keep_dims)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, x, reduction_axes):
+ self._compare(x, reduction_axes, False, use_gpu=True)
+ self._compare(x, reduction_axes, False, use_gpu=False)
+ self._compare(x, reduction_axes, True, use_gpu=True)
+ self._compare(x, reduction_axes, True, use_gpu=False)
+
+ def testFloatReduce3D(self):
+ # Create a 3D array of floats and reduce across all possible
+ # dimensions
+ np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+
+ def testGradient(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_max(t, [1, 2])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+ def testGradient2(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_max(t, [1])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 4, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+ def testGradient3(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_max(t, [2])
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [2, 3, 2], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+ def testGradient4(self):
+ s = [2, 3, 4, 2]
+ x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+ with self.test_session():
+ t = tf.convert_to_tensor(x)
+ su = tf.reduce_max(t)
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ t, s, su, [1], x_init_value=x, delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
+
+
+class AllReductionTest(tf.test.TestCase):
+
+ def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ np_ans = x
+ if reduction_axes is None:
+ np_ans = np.all(np_ans, keepdims=keep_dims)
+ else:
+ for ra in reduction_axes[::-1]:
+ np_ans = np.all(np_ans, axis=ra, keepdims=keep_dims)
+ with self.test_session(use_gpu=use_gpu):
+ if reduction_axes is not None:
+ reduction_axes = np.array(reduction_axes).astype(np.int32)
+ tf_ans = tf.reduce_all(x, reduction_axes, keep_dims)
+ out = tf_ans.eval()
+ self.assertAllEqual(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, x, reduction_axes):
+ self._compare(x, reduction_axes, False, use_gpu=True)
+ self._compare(x, reduction_axes, False, use_gpu=False)
+ self._compare(x, reduction_axes, True, use_gpu=True)
+ self._compare(x, reduction_axes, True, use_gpu=False)
+
+ def testAll3D(self):
+ # Create a 3D array of bools and reduce across all possible
+ # dimensions
+ np_arr = (np.random.uniform(0, 1, 30) > 0.1).reshape([2, 3, 5])
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+
+
+class AnyReductionTest(tf.test.TestCase):
+
+ def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ np_ans = x
+ if reduction_axes is None:
+ np_ans = np.any(np_ans, keepdims=keep_dims)
+ else:
+ for ra in reduction_axes[::-1]:
+ np_ans = np.any(np_ans, axis=ra, keepdims=keep_dims)
+ with self.test_session(use_gpu=use_gpu):
+ if reduction_axes is not None:
+ reduction_axes = np.array(reduction_axes).astype(np.int32)
+ tf_ans = tf.reduce_any(x, reduction_axes, keep_dims)
+ out = tf_ans.eval()
+ self.assertAllEqual(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareAll(self, x, reduction_axes):
+ self._compare(x, reduction_axes, False, use_gpu=True)
+ self._compare(x, reduction_axes, False, use_gpu=False)
+ self._compare(x, reduction_axes, True, use_gpu=True)
+ self._compare(x, reduction_axes, True, use_gpu=False)
+
+  def testAny3D(self):
+ # Create a 3D array of bools and reduce across all possible
+ # dimensions
+ np_arr = (np.random.uniform(0, 1, 30) > 0.9).reshape([2, 3, 5])
+ self._compareAll(np_arr, None)
+ self._compareAll(np_arr, [])
+ self._compareAll(np_arr, [0])
+ self._compareAll(np_arr, [1])
+ self._compareAll(np_arr, [2])
+ self._compareAll(np_arr, [0, 1])
+ self._compareAll(np_arr, [1, 2])
+ self._compareAll(np_arr, [0, 2])
+ self._compareAll(np_arr, [0, 1, 2])
+
+ def testPartialShapes(self):
+ # Input shape is unknown.
+ c_unknown = tf.placeholder(tf.float32)
+ s_unknown = tf.reduce_sum(c_unknown, [1, 2])
+ self.assertEqual(tensor_shape.unknown_shape(), s_unknown.get_shape())
+
+ # Input shape only has known rank.
+ c_known_rank = tf.placeholder(tf.float32)
+ c_known_rank.set_shape(tensor_shape.unknown_shape(ndims=3))
+ s_known_rank = tf.reduce_sum(c_known_rank, [1, 2], keep_dims=True)
+ self.assertEqual(3, s_known_rank.get_shape().ndims)
+
+ # Reduction indices are unknown.
+ unknown_indices = tf.placeholder(tf.int32)
+ c_unknown_indices = tf.constant([[10.0], [20.0]])
+ s_unknown_indices = tf.reduce_sum(c_unknown_indices, unknown_indices,
+ keep_dims=False)
+ self.assertEqual(tensor_shape.unknown_shape(),
+ s_unknown_indices.get_shape())
+ s_unknown_indices_keep = tf.reduce_sum(c_unknown_indices, unknown_indices,
+ keep_dims=True)
+ self.assertEqual(2, s_unknown_indices_keep.get_shape().ndims)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
new file mode 100644
index 0000000000..a4b353f253
--- /dev/null
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -0,0 +1,181 @@
+"""Tests for Relu and ReluGrad."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
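+# Reference definitions used by the numpy helpers below:
+#   relu(x)  = max(x, 0)
+#   relu6(x) = min(max(x, 0), 6)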
+class ReluTest(tf.test.TestCase):
+
+ def _npRelu(self, np_features):
+ return np.maximum(np_features, np.zeros(np_features.shape))
+
+ def testNpRelu(self):
+ self.assertAllClose(
+ np.array([[0.0, 0.7, 0.0, 0.3, 0.0],
+ [0.1, 0.0, 0.5, 0.0, 0.9]]),
+ self._npRelu(np.array([[-0.9, 0.7, -0.5, 0.3, -0.1],
+ [0.1, -0.3, 0.5, -0.7, 0.9]])))
+
+ def _testRelu(self, np_features, use_gpu=False):
+ np_relu = self._npRelu(np_features)
+ with self.test_session(use_gpu=use_gpu):
+ relu = tf.nn.relu(np_features)
+ tf_relu = relu.eval()
+ self.assertAllClose(np_relu, tf_relu)
+ self.assertShapeEqual(np_relu, relu)
+
+ def testNumbers(self):
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=False)
+ if t in [np.float, np.double]:
+ self._testRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=True)
+
+  # The gradient test for ReLU is a bit tricky because the derivative is not
+  # well defined at zero, so we keep the input values away from zero.
+ def testGradientFloat(self):
+ with self.test_session():
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5], name="x")
+ y = tf.nn.relu(x, name="relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32, order="F")
+ err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print "relu (float) gradient err = ", err
+ self.assertLess(err, 1e-4)
+
+ def testGradientNaN(self):
+ with self.test_session():
+ # Note the NaN is injected as an input to the gradient calculation.
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, np.nan, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5], name="x")
+ y = tf.nn.relu(x, name="relu")
+ grad_ys = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5], name="ys")
+ g_op = tf.gradients(
+ [y], [x], grad_ys=[grad_ys], name="gradients")[0]
+ try:
+ g_op.op.run()
+ assert False, "ReluGrad should have failed due to CheckNumerics."
+ except Exception as e: # pylint: disable=broad-except
+ assert "ReluGrad input is not finite." in str(e)
+
+ def testGradientDouble(self):
+ with self.test_session():
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5], dtype=tf.float64, name="x")
+ y = tf.nn.relu(x, name="relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64, order="F")
+ err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print "relu (double) gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+ def testGradGradFloat(self):
+ with self.test_session():
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5], name="x")
+ y = tf.nn.relu(x, name="relu")
+ z = tf.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32, order="F")
+ err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5],
+ x_init_value=x_init)
+ print "relu (float) gradient of gradient err = ", err
+ self.assertLess(err, 1e-4)
+
+ def testGradGradDouble(self):
+ with self.test_session():
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5], dtype=tf.float64, name="x")
+ y = tf.nn.relu(x, name="relu")
+ z = tf.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64, order="F")
+ err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5],
+ x_init_value=x_init)
+ print "relu (double) gradient of gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+
+class Relu6Test(tf.test.TestCase):
+
+ def _npRelu6(self, np_features):
+ sixes = np.copy(np_features)
+ sixes.fill(6.0)
+ return np.minimum(np.maximum(np_features, np.zeros(np_features.shape)),
+ sixes)
+
+ def testNpRelu6(self):
+ self.assertAllClose(
+ np.array([[0.0, 0.7, 0.0, 0.3, 6.0],
+ [0.1, 0.0, 6.0, 0.0, 0.9]]),
+ self._npRelu6(np.array([[-0.9, 0.7, -0.5, 0.3, 6.0],
+ [0.1, -0.3, 6.5, -0.7, 0.9]])))
+
+ def _testRelu6(self, np_features, use_gpu=False):
+ np_relu6 = self._npRelu6(np_features)
+ with self.test_session(use_gpu=use_gpu):
+ relu6 = tf.nn.relu6(np_features)
+ tf_relu6 = relu6.eval()
+ self.assertAllClose(np_relu6, tf_relu6)
+ self.assertShapeEqual(np_relu6, relu6)
+
+ def testNumbers(self):
+ for t in [np.int32, np.int64, np.float, np.double]:
+ self._testRelu6(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=False)
+ if t in [np.float, np.double]:
+ self._testRelu6(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=True)
+
+  # The gradient test for ReLU6 is a bit tricky because the derivative is not
+  # well defined at zero and six, so we keep the input values away from both.
+ def testGradientFloat(self):
+ with self.test_session():
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
+ shape=[2, 5], name="x")
+ y = tf.nn.relu6(x, name="relu6")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
+ dtype=np.float32, order="F")
+ err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print "relu6 (float) gradient err = ", err
+ self.assertLess(err, 1e-4)
+
+ def testGradientDouble(self):
+ with self.test_session():
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
+ shape=[2, 5], dtype=tf.float64, name="x")
+ y = tf.nn.relu6(x, name="relu6")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
+ dtype=np.float64, order="F")
+ err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print "relu6 (double) gradient err = ", err
+ self.assertLess(err, 1e-10)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
new file mode 100644
index 0000000000..65b0e6d4bf
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -0,0 +1,106 @@
+"""Tests for tensorflow.ops.reshape_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
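+# tf.reshape accepts at most one -1 in the target shape; that dimension is
+# inferred so that the total number of elements is preserved (e.g. reshaping
+# 6 elements to [-1, 2] yields shape [3, 2]), as the "UnspecifiedDim" tests
+# below exercise.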
+class ReshapeTest(tf.test.TestCase):
+
+ def _testReshape(self, x, y, use_gpu=False):
+ with self.test_session(use_gpu=use_gpu):
+ np_ans = x.reshape(y)
+ tf_ans = tf.reshape(x, y)
+ out = tf_ans.eval()
+      # Check the values as well as the shapes against the numpy reshape.
+      self.assertAllEqual(np_ans, out)
+      self.assertEqual(tf_ans.get_shape(), out.shape)
+      self.assertShapeEqual(np_ans, tf_ans)
+
+ def _testBothReshape(self, x, y):
+ self._testReshape(x, y, False)
+ self._testReshape(x, y, True)
+
+ def testFloatBasic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.float32)
+ self._testBothReshape(x, [2, 3])
+
+ def testDoubleBasic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.float64)
+ self._testBothReshape(x, [2, 3])
+
+ def testInt32Basic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.int32)
+ self._testBothReshape(x, [2, 3])
+
+ def testSComplexBasic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.complex64)
+ self._testBothReshape(x, [2, 3])
+
+ def testFloatReshapeThreeDimensions(self):
+ x = np.arange(1., 28.).reshape([1, 27]).astype(np.float32)
+ self._testBothReshape(x, [3, 3, 3])
+
+ def testFloatUnspecifiedDimOnly(self):
+ x = np.arange(1., 7.).reshape([6]).astype(np.float32)
+ self._testBothReshape(x, [-1])
+
+ def testFloatUnspecifiedDimBegin(self):
+ x = np.arange(1., 7.).reshape([6]).astype(np.float32)
+ self._testBothReshape(x, [-1, 2])
+
+ def testFloatUnspecifiedDimEnd(self):
+ x = np.arange(1., 7.).reshape([6]).astype(np.float32)
+ self._testBothReshape(x, [3, -1])
+
+ # TODO(vrv): Add tests for failure conditions once python test_util
+ # reports errors.
+
+ def testFloatReshapeGradThreeDimensions(self):
+ x = np.arange(1., 25.).reshape([1, 24]).astype(np.float32)
+ s = list(np.shape(x))
+ with self.test_session():
+ input_tensor = tf.constant(x, shape=[2, 3, 4])
+ reshape_out = tf.reshape(input_tensor, [1, 8, 3])
+ err = gc.ComputeGradientError(input_tensor, s,
+ reshape_out, s, x_init_value=x)
+      print "Reshape gradient error = %g" % err
+ self.assertLess(err, 1e-3)
+
+ def testFloatEmpty(self):
+ x = np.empty((0, 0, 0, 0), dtype=np.float32)
+ self._testBothReshape(x, [1, 2, 3, 0])
+ self._testBothReshape(x, [1, 0, 0, 4])
+ self._testBothReshape(x, [0, 0, 0, 0])
+ self._testBothReshape(x, [1, 2, 0])
+ self._testBothReshape(x, [0, 0, 0])
+ self._testBothReshape(x, [1, -1, 5])
+
+ def testErrors(self):
+ x = tf.constant(0.0, shape=[1, 0, 3])
+ with self.assertRaisesRegexp(
+ ValueError, "cannot infer the missing input size"):
+ tf.reshape(x, [0, -1, 5])
+
+ y = tf.constant(0.0, shape=[23, 29, 31])
+ with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"):
+ tf.reshape(y, [17, -1])
+
+ def testPartialShapes(self):
+ x = tf.placeholder(tf.float32)
+
+ # Unknown input shape, partial new shape.
+ y = tf.reshape(x, [1, 1, -1, 1])
+ self.assertEqual([1, 1, None, 1], y.get_shape().as_list())
+
+ # Unknown input shape, unknown new shape.
+ y = tf.reshape(x, tf.placeholder(tf.int32))
+ self.assertEqual(None, y.get_shape().ndims)
+
+ # Unknown input shape, known rank for new shape.
+ y = tf.reshape(x, tf.placeholder(tf.int32, shape=(3,)))
+ self.assertEqual([None, None, None], y.get_shape().as_list())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
new file mode 100644
index 0000000000..7cfbcd7946
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -0,0 +1,109 @@
+"""Tests for tensorflow.ops.reverse_sequence_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
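+# tf.reverse_sequence(x, seq_dim, seq_lengths) reverses the first
+# seq_lengths[i] entries of x along dimension seq_dim for each batch entry i,
+# leaving the remaining entries untouched; the hand-built "truth" tensors
+# below encode exactly this.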
+class ReverseSequenceTest(tf.test.TestCase):
+
+ def _testReverseSequence(self, x, seq_dim, seq_lengths,
+ truth, use_gpu=False, expected_err_re=None):
+ with self.test_session(use_gpu=use_gpu):
+ ans = tf.reverse_sequence(x,
+ seq_dim=seq_dim,
+ seq_lengths=seq_lengths)
+ if expected_err_re is None:
+ tf_ans = ans.eval()
+ self.assertAllClose(tf_ans, truth, atol=1e-10)
+ self.assertShapeEqual(truth, ans)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ ans.eval()
+
+ def _testBothReverseSequence(self, x, seq_dim, seq_lengths,
+ truth, expected_err_re=None):
+ self._testReverseSequence(x, seq_dim, seq_lengths,
+ truth, True, expected_err_re)
+ self._testReverseSequence(x, seq_dim, seq_lengths,
+ truth, False, expected_err_re)
+
+ def _testBasic(self, dtype):
+ x = np.asarray([
+ [[1, 2, 3, 4], [5, 6, 7, 8]],
+ [[9, 10, 11, 12], [13, 14, 15, 16]],
+ [[17, 18, 19, 20], [21, 22, 23, 24]]], dtype=dtype)
+ x = x.reshape(3, 2, 4, 1, 1)
+
+ # reverse dim 2 up to (0:3, none, 0:4) along dim=0
+ seq_dim = 2
+ seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
+
+ truth = np.asarray(
+ [[[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3
+ [[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none
+ [[20, 19, 18, 17], [24, 23, 22, 21]]], # reverse 0:4 (all)
+ dtype=dtype)
+ truth = truth.reshape(3, 2, 4, 1, 1)
+ self._testBothReverseSequence(x, seq_dim, seq_lengths, truth)
+
+ def testFloatBasic(self):
+ self._testBasic(np.float32)
+
+ def testDoubleBasic(self):
+ self._testBasic(np.float64)
+
+ def testInt32Basic(self):
+ self._testBasic(np.int32)
+
+ def testInt64Basic(self):
+ self._testBasic(np.int64)
+
+ def testSComplexBasic(self):
+ self._testBasic(np.complex64)
+
+ def testFloatReverseSequenceGrad(self):
+ x = np.asarray([
+ [[1, 2, 3, 4], [5, 6, 7, 8]],
+ [[9, 10, 11, 12], [13, 14, 15, 16]],
+ [[17, 18, 19, 20], [21, 22, 23, 24]]], dtype=np.float)
+ x = x.reshape(3, 2, 4, 1, 1)
+
+ # reverse dim 2 up to (0:3, none, 0:4) along dim=0
+ seq_dim = 2
+ seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
+
+ with self.test_session():
+ input_t = tf.constant(x, shape=x.shape)
+ seq_lengths_t = tf.constant(seq_lengths, shape=seq_lengths.shape)
+ reverse_sequence_out = tf.reverse_sequence(input_t,
+ seq_dim=seq_dim,
+ seq_lengths=seq_lengths_t)
+ err = gc.ComputeGradientError(input_t,
+ x.shape,
+ reverse_sequence_out,
+ x.shape,
+ x_init_value=x)
+ print "ReverseSequence gradient error = %g" % err
+ self.assertLess(err, 1e-8)
+
+ def testShapeFunctionEdgeCases(self):
+ # Batch size mismatched between input and seq_lengths.
+ with self.assertRaises(ValueError):
+ tf.reverse_sequence(
+ tf.placeholder(tf.float32, shape=(32, 2, 3)),
+ seq_lengths=tf.placeholder(tf.int64, shape=(33,)),
+ seq_dim=3)
+
+ # seq_dim out of bounds.
+ with self.assertRaisesRegexp(ValueError, "seq_dim must be < input.dims()"):
+ tf.reverse_sequence(
+ tf.placeholder(tf.float32, shape=(32, 2, 3)),
+ seq_lengths=tf.placeholder(tf.int64, shape=(32,)),
+ seq_dim=3)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/save_restore_ops_test.py b/tensorflow/python/kernel_tests/save_restore_ops_test.py
new file mode 100644
index 0000000000..d59d76c58f
--- /dev/null
+++ b/tensorflow/python/kernel_tests/save_restore_ops_test.py
@@ -0,0 +1,21 @@
+"""Tests for tensorflow.ops.io_ops."""
+import tensorflow.python.platform
+
+import tensorflow as tf
+from tensorflow.python.ops import gen_io_ops
+
+
+class ShardedFileOpsTest(tf.test.TestCase):
+
+ def testShardedFileName(self):
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})):
+ self.assertEqual(gen_io_ops._sharded_filename("foo", 4, 100).eval(),
+ "foo-00004-of-00100")
+ self.assertEqual(gen_io_ops._sharded_filespec("foo", 100).eval(),
+ "foo-?????-of-00100")
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
new file mode 100644
index 0000000000..dd645819a3
--- /dev/null
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -0,0 +1,49 @@
+"""Tests for tensorflow.ops.tf.scatter."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
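+# The numpy reference ops below mirror the TensorFlow scatter semantics:
+# tf.scatter_update(ref, indices, updates) behaves like ref[indices] = updates
+# on the variable's value, and scatter_add / scatter_sub behave like += / -=.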
+class ScatterTest(tf.test.TestCase):
+
+ def _VariableRankTest(self, np_scatter, tf_scatter):
+ np.random.seed(8)
+ with self.test_session():
+ for indices_shape in (), (2,), (2, 3), (2, 3, 4):
+ for extra_shape in (), (5,), (5, 6):
+ # Generate random indices with no duplicates for easy numpy comparison
+ size = np.prod(indices_shape, dtype=np.int32)
+ indices = np.arange(2 * size)
+ np.random.shuffle(indices)
+ indices = indices[:size].reshape(indices_shape)
+ updates = np.random.randn(*(indices_shape + extra_shape))
+ old = np.random.randn(*((2 * size,) + extra_shape))
+ # Scatter via numpy
+ new = old.copy()
+ np_scatter(new, indices, updates)
+ # Scatter via tensorflow
+ ref = tf.Variable(old)
+ ref.initializer.run()
+ tf_scatter(ref, indices, updates).eval()
+ # Compare
+ self.assertAllClose(ref.eval(), new)
+
+ def testVariableRankUpdate(self):
+ def update(ref, indices, updates):
+ ref[indices] = updates
+ self._VariableRankTest(update, tf.scatter_update)
+
+ def testVariableRankAdd(self):
+ def add(ref, indices, updates):
+ ref[indices] += updates
+ self._VariableRankTest(add, tf.scatter_add)
+
+ def testVariableRankSub(self):
+ def sub(ref, indices, updates):
+ ref[indices] -= updates
+ self._VariableRankTest(sub, tf.scatter_sub)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
new file mode 100644
index 0000000000..558ce06285
--- /dev/null
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -0,0 +1,269 @@
+"""Functional tests for segment reduction ops."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker
+
+
+class SegmentReductionHelper(tf.test.TestCase):
+
+ def _input(self, input_shape, dtype=tf.int32):
+ num_elem = 1
+ for x in input_shape:
+ num_elem *= x
+ values = range(1, num_elem + 1)
+ np_values = np.array(values).reshape(input_shape).astype(
+ dtype.as_numpy_dtype)
+ return tf.constant(values, shape=input_shape,
+ dtype=dtype), np_values
+
+ def _segmentReduce(self, indices, x, op1, op2=None, num_out_rows=None):
+ if not x.size: return np.array([])
+ indices = np.asarray(indices)
+ if num_out_rows is None:
+ num_out_rows = indices[-1] + 1
+ output = [None] * num_out_rows
+ slice_shape = x.shape[indices.ndim:]
+ x_flat = x.reshape((indices.size,) + slice_shape)
+ for i, index in enumerate(indices.ravel()):
+ if output[index] is not None:
+ output[index] = op1(output[index], x_flat[i])
+ else:
+ output[index] = x_flat[i]
+    # Zero-initialize values that are still uncalculated.
+ output = [o if o is not None else np.zeros(slice_shape) for o in output]
+ if op2 is not None:
+ output = [op2(o) for o in output]
+ output = [o.reshape(slice_shape) for o in output]
+ return np.array(output)
+
+ def _assertAllClose(self, indices, np_x, tf_x):
+ for i in set(np.asarray(indices).ravel()):
+ self.assertAllClose(np_x[i], tf_x[i])
+
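+  # Mean is emulated with a running (sum, count) pair: _mean_cum_op accumulates
+  # the pair across slices and _mean_reduce_op divides to produce the mean.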
+ def _mean_cum_op(self, x, y):
+ return (x[0] + y, x[1] + 1) if isinstance(x, tuple) else (x + y, 2)
+
+ def _mean_reduce_op(self, x):
+ return x[0] / x[1] if isinstance(x, tuple) else x
+
+
+class SegmentReductionOpTest(SegmentReductionHelper):
+
+ def testValues(self):
+ dtypes = [tf.float32,
+ tf.float64,
+ tf.int64,
+ tf.int32]
+
+ # Each item is np_op1, np_op2, tf_op
+ ops_list = [(np.add, None, tf.segment_sum),
+ (self._mean_cum_op, self._mean_reduce_op,
+ tf.segment_mean),
+ (np.ndarray.__mul__, None, tf.segment_prod),
+ (np.minimum, None, tf.segment_min),
+ (np.maximum, None, tf.segment_max)]
+
+ n = 10
+ shape = [n, 2]
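+    # Group consecutive rows into segments of three: ids 0, 0, 0, 1, 1, 1, ...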
+ indices = [int(i / 3) for i in range(n)]
+ for dtype in dtypes:
+ with self.test_session(use_gpu=False):
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ for np_op1, np_op2, tf_op in ops_list:
+ np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
+ s = tf_op(data=tf_x, segment_ids=indices)
+ tf_ans = s.eval()
+ self._assertAllClose(indices, np_ans, tf_ans)
+ # NOTE(mrry): The static shape inference that computes
+          # `tf_ans.shape` can only infer the sizes from dimension 1
+ # onwards, because the size of dimension 0 is data-dependent
+ # and may therefore vary dynamically.
+ self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
+
+ def testSegmentIdsShape(self):
+ shape = [4, 4]
+ tf_x, _ = self._input(shape)
+ indices = tf.constant([0, 1, 2, 2], shape=[2, 2])
+ with self.assertRaises(ValueError):
+ tf.segment_sum(data=tf_x, segment_ids=indices)
+
+ def testSegmentIdsSize(self):
+ shape = [4, 4]
+ with self.test_session():
+ tf_x, _ = self._input(shape)
+ indices = [0, 1]
+ s = tf.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment_ids should be the same size"):
+ s.eval()
+
+ def testGradient(self):
+ shape = [4, 4]
+ indices = [0, 1, 2, 2]
+ for tf_op in [tf.segment_sum,
+ tf.segment_mean,
+ tf.segment_min,
+ tf.segment_max]:
+ with self.test_session():
+ tf_x, np_x = self._input(shape, dtype=tf.float64)
+ s = tf_op(data=tf_x, segment_ids=indices)
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ tf_x, shape, s, [3, 4], x_init_value=np_x.astype(np.double),
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+
+class UnsortedSegmentSumTest(SegmentReductionHelper):
+
+ def testValues(self):
+ dtypes = [tf.float32,
+ tf.float64,
+ tf.int64,
+ tf.int32]
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = 12
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (2,)
+ for dtype in dtypes:
+ with self.test_session(use_gpu=False):
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ np_ans = self._segmentReduce(indices,
+ np_x,
+ np.add,
+ op2=None,
+ num_out_rows=num_segments)
+ s = tf.unsorted_segment_sum(data=tf_x,
+ segment_ids=indices,
+ num_segments=num_segments)
+ tf_ans = s.eval()
+ self._assertAllClose(indices, np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, s)
+
+ def testGradient(self):
+ num_cols = 2
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = max(indices_flat) + 3
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (num_cols,)
+ with self.test_session():
+ tf_x, np_x = self._input(shape, dtype=tf.float64)
+ s = tf.unsorted_segment_sum(data=tf_x,
+ segment_ids=indices,
+ num_segments=num_segments)
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ tf_x,
+ shape,
+ s,
+ [num_segments, num_cols],
+ x_init_value=np_x.astype(np.double),
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+ def testGradientMatchesSegmentSum(self):
+ # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum
+ # and compare the outputs, which should be identical.
+ # NB: for this test to work, indices must be valid for SegmentSum, namely
+    # they must be sorted, the indices must be contiguous, and num_segments
+ # must be max(indices) + 1.
+ indices = [0, 0, 1, 1, 1, 2, 3, 4, 5]
+ n = len(indices)
+ num_cols = 2
+ shape = [n, num_cols]
+ num_segments = max(indices) + 1
+ with self.test_session():
+ tf_x, np_x = self._input(shape, dtype=tf.float64)
+ # Results from UnsortedSegmentSum
+ unsorted_s = tf.unsorted_segment_sum(data=tf_x,
+ segment_ids=indices,
+ num_segments=num_segments)
+ unsorted_jacob_t, unsorted_jacob_n = gradient_checker.ComputeGradient(
+ tf_x, shape, unsorted_s, [num_segments, num_cols],
+ x_init_value=np_x.astype(np.double),
+ delta=1)
+ # Results from SegmentSum
+ sorted_s = tf.segment_sum(data=tf_x, segment_ids=indices)
+ sorted_jacob_t, sorted_jacob_n = gradient_checker.ComputeGradient(
+ tf_x, shape, sorted_s, [num_segments, num_cols],
+ x_init_value=np_x.astype(np.double),
+ delta=1)
+ self.assertAllClose(unsorted_jacob_t, sorted_jacob_t, rtol=1e-3, atol=1e-3)
+ self.assertAllClose(unsorted_jacob_n, sorted_jacob_n, rtol=1e-3, atol=1e-3)
+
+
+class SparseSegmentReductionHelper(SegmentReductionHelper):
+
+ def _sparse_input(self, input_shape, num_indices,
+ dtype=tf.int32):
+ a, b = super(SparseSegmentReductionHelper, self)._input(input_shape,
+ dtype)
+ indices = np.random.randint(0, input_shape[0], num_indices).astype(np.int32)
+ return (tf.constant(indices, dtype=tf.int32),
+ indices, a, b)
+
+ def _sparseSegmentReduce(self, x, indices, segment_indices, op1, op2=None):
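+    # A sparse segment reduction is a gather of the selected rows (x[indices])
+    # followed by an ordinary segment reduction over the gathered rows.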
+ return self._segmentReduce(segment_indices, x[indices], op1, op2)
+
+
+class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
+
+ def testValues(self):
+ dtypes = [tf.float32,
+ tf.float64,
+ tf.int64,
+ tf.int32]
+
+ mean_dtypes = [tf.float32,
+ tf.float64]
+
+ # Each item is np_op1, np_op2, tf_op
+ ops_list = [(np.add, None, tf.sparse_segment_sum),
+ (self._mean_cum_op, self._mean_reduce_op,
+ tf.sparse_segment_mean)]
+
+ n = 400
+ shape = [n, 2]
+ segment_indices = []
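+    # Segment i is repeated i + 1 times, so segment sizes grow from 1 to 20.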
+ for i in range(20):
+ for _ in range(i + 1):
+ segment_indices.append(i)
+ num_indices = len(segment_indices)
+ for dtype in dtypes:
+ with self.test_session(use_gpu=False):
+ tf_indices, np_indices, tf_x, np_x = self._sparse_input(shape,
+ num_indices,
+ dtype=dtype)
+ for np_op1, np_op2, tf_op in ops_list:
+ if tf_op == tf.sparse_segment_mean and dtype not in mean_dtypes:
+ continue
+ np_ans = self._sparseSegmentReduce(np_x, np_indices, segment_indices,
+ np_op1, np_op2)
+ s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
+ tf_ans = s.eval()
+ self._assertAllClose(segment_indices, np_ans, tf_ans)
+ # NOTE(mrry): The static shape inference that computes
+          # `tf_ans.shape` can only infer the sizes from dimension 1
+ # onwards, because the size of dimension 0 is data-dependent
+ # and may therefore vary dynamically.
+ self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
+
+ def testGradient(self):
+ shape = [10, 4]
+
+ segment_indices = [0, 1, 2, 2]
+ num_indices = len(segment_indices)
+ for tf_op in [tf.sparse_segment_sum,
+ tf.sparse_segment_mean]:
+ with self.test_session():
+ tf_indices, _, tf_x, np_x = self._sparse_input(
+ shape, num_indices, dtype=tf.float64)
+ s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
+ jacob_t, jacob_n = gradient_checker.ComputeGradient(
+ tf_x, shape, s, [3, 4], x_init_value=np_x.astype(np.double),
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
new file mode 100644
index 0000000000..ac97180dbe
--- /dev/null
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -0,0 +1,389 @@
+"""Tests for various tensorflow.ops.tf."""
+import tensorflow.python.platform
+
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class ShapeOpsTest(tf.test.TestCase):
+
+ def _compareShape(self, x, use_gpu=False):
+ np_ans = np.array(np.shape(x))
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.shape(x)
+ result = tf_ans.eval()
+ self.assertAllEqual(np_ans, result)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareRank(self, x, use_gpu=False):
+ np_ans = np.asarray(np.ndim(x))
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.rank(x)
+ result = tf_ans.eval()
+ self.assertAllEqual(np_ans, result)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _compareSize(self, x, use_gpu=False):
+ np_ans = np.asarray(np.size(x))
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.size(x)
+ result = tf_ans.eval()
+ self.assertAllEqual(np_ans, result)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _testCpu(self, x):
+ self._compareShape(x, use_gpu=False)
+ self._compareRank(x, use_gpu=False)
+ self._compareSize(x, use_gpu=False)
+
+ def _testGpu(self, x):
+ self._compareShape(x, use_gpu=True)
+ self._compareRank(x, use_gpu=True)
+ self._compareSize(x, use_gpu=True)
+
+ def _testAll(self, x):
+ self._testCpu(x)
+ self._testGpu(x)
+
+ def testBasic(self):
+ self._testAll(np.zeros([2]))
+ self._testAll(np.zeros([2, 3]))
+ self._testAll(np.zeros([2, 3, 5]))
+ self._testAll(np.zeros([2, 3, 5, 7]))
+ self._testAll(np.zeros([2, 3, 5, 7, 11]))
+ self._testAll(np.zeros([2, 3, 5, 7, 11, 13]))
+
+ def _compareExpandDims(self, x, dim, use_gpu):
+ np_ans = np.expand_dims(x, axis=dim)
+ with self.test_session(use_gpu=use_gpu):
+ tensor = tf.expand_dims(x, dim)
+ tf_ans = tensor.eval()
+ self.assertShapeEqual(np_ans, tensor)
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def _compareExpandDimsAll(self, x, dim):
+ self._compareExpandDims(x, dim, False)
+ self._compareExpandDims(x, dim, True)
+
+ def testExpandDims(self):
+ self._compareExpandDimsAll(np.zeros([2]), 0)
+ self._compareExpandDimsAll(np.zeros([2]), 1)
+ self._compareExpandDimsAll(np.zeros([2]), -1)
+
+ self._compareExpandDimsAll(np.zeros([2, 3]), 0)
+ self._compareExpandDimsAll(np.zeros([2, 3]), 1)
+ self._compareExpandDimsAll(np.zeros([2, 3]), 2)
+ self._compareExpandDimsAll(np.zeros([2, 3]), -1)
+ self._compareExpandDimsAll(np.zeros([2, 3]), -2)
+
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), 0)
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), 1)
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), 2)
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), 3)
+
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), -1)
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), -2)
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), -3)
+ self._compareExpandDimsAll(np.zeros([2, 3, 5]), -4)
+
+ def testExpandDimsErrors(self):
+ with self.test_session():
+ self.assertRaises(ValueError, tf.expand_dims, np.zeros([2, 3, 5]), -5)
+ self.assertRaises(ValueError, tf.expand_dims, np.zeros([2, 3, 5]), 4)
+
+ def testExpandDimsGradient(self):
+ with self.test_session():
+ inp = tf.constant(np.random.rand(4, 2).astype("f"),
+ dtype=tf.float32)
+ squeezed = tf.expand_dims(inp, 1)
+
+ err = gc.ComputeGradientError(inp, [4, 2], squeezed, [4, 1, 2])
+ self.assertLess(err, 1e-3)
+
+ def testExpandDimsScalar(self):
+ with self.test_session():
+ inp = tf.constant(7)
+ self.assertAllEqual([7], tf.expand_dims(inp, 0).eval())
+ self.assertAllEqual([7], tf.expand_dims(inp, -1).eval())
+
+ def _compareSqueeze(self, x, squeeze_dims, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ if squeeze_dims:
+ np_ans = np.squeeze(x, axis=tuple(squeeze_dims))
+ tensor = tf.squeeze(x, squeeze_dims)
+ tf_ans = tensor.eval()
+ else:
+ np_ans = np.squeeze(x)
+ tensor = tf.squeeze(x)
+ tf_ans = tensor.eval()
+ self.assertShapeEqual(np_ans, tensor)
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def _compareSqueezeAll(self, x, squeeze_dims=None):
+ if squeeze_dims is None:
+ squeeze_dims = []
+ self._compareSqueeze(x, squeeze_dims, False)
+ self._compareSqueeze(x, squeeze_dims, True)
+
+ def testSqueeze(self):
+ # Nothing to squeeze.
+ self._compareSqueezeAll(np.zeros([2]))
+ self._compareSqueezeAll(np.zeros([2, 3]))
+
+ # Squeeze the middle element away.
+ self._compareSqueezeAll(np.zeros([2, 1, 2]))
+
+ # Squeeze on both ends.
+ self._compareSqueezeAll(np.zeros([1, 2, 1, 3, 1]))
+
+ def testSqueezeSpecificDimension(self):
+ # Positive squeeze dim index.
+ self._compareSqueezeAll(np.zeros([1, 2, 1, 3, 1]), [0])
+ self._compareSqueezeAll(np.zeros([1, 2, 1, 3, 1]), [2, 4])
+ self._compareSqueezeAll(np.zeros([1, 2, 1, 3, 1]), [0, 4, 2])
+
+ # Negative squeeze dim index.
+ self._compareSqueezeAll(np.zeros([1, 2, 1, 3, 1]), [-1])
+ self._compareSqueezeAll(np.zeros([1, 2, 1, 3, 1]), [-3, -5])
+ self._compareSqueezeAll(np.zeros([1, 2, 1, 3, 1]), [-3, -5, -1])
+
+ def testSqueezeAllOnes(self):
+ # Numpy squeezes a 1 element tensor into a zero dimensional tensor.
+ # Verify that we do the same.
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ tensor = tf.squeeze(np.zeros([1, 1, 1]), [])
+ self.assertEqual(np.shape(1), tensor.get_shape())
+ tf_ans = tensor.eval()
+ self.assertEqual(np.shape(1), tf_ans.shape)
+
+ def testSqueezeOnlyOnes(self):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ input_1x1x3 = np.zeros([1, 1, 3])
+ self._compareSqueezeAll(input_1x1x3)
+ self._compareSqueezeAll(input_1x1x3, [0])
+ self._compareSqueezeAll(input_1x1x3, [1])
+ self.assertRaises(ValueError, tf.squeeze, input_1x1x3, [2])
+
+ def testSqueezeErrors(self):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ self.assertRaises(ValueError, tf.squeeze, np.zeros([1, 2, 1]), [-4])
+ self.assertRaises(ValueError, tf.squeeze, np.zeros([1, 2, 1]), [0, -4])
+ self.assertRaises(ValueError, tf.squeeze, np.zeros([1, 2, 1]), [3])
+ self.assertRaises(ValueError, tf.squeeze, np.zeros([1, 2, 1]), [2, 3])
+
+ def testSqueezeGradient(self):
+ with self.test_session():
+ inp = np.random.rand(4, 2).astype("f")
+ a = tf.reshape(inp, [4, 1, 2])
+ squeezed = tf.squeeze(a, [])
+
+ err = gc.ComputeGradientError(a, [4, 1, 2], squeezed, [4, 2])
+ self.assertLess(err, 1e-3)
+
+ def testSqueezeGradientWithSqueezeDims(self):
+ with self.test_session():
+ inp = np.random.rand(4, 2).astype("f")
+ a = tf.reshape(inp, [4, 1, 2, 1])
+ squeezed = tf.squeeze(a, [1])
+
+ err = gc.ComputeGradientError(a, [4, 1, 2, 1], squeezed, [4, 2, 1])
+ self.assertLess(err, 1e-3)
+
+
+class TileTest(tf.test.TestCase):
+
+ def testScalar(self):
+ with self.test_session():
+ a = tf.constant(7, shape=[], dtype=tf.float32)
+ tiled = tf.tile(a, [])
+ result = tiled.eval()
+ self.assertEqual(result.shape, ())
+ self.assertEqual([], tiled.get_shape())
+ self.assertEqual(7, result)
+
+ def testSimple(self):
+ with self.test_session():
+ inp = np.random.rand(4, 1).astype("f")
+ a = tf.constant([float(x) for x in inp.ravel(order="C")],
+ shape=[4, 1], dtype=tf.float32)
+ tiled = tf.tile(a, [1, 4])
+ result = tiled.eval()
+ self.assertEqual(result.shape, (4, 4))
+ self.assertEqual([4, 4], tiled.get_shape())
+ self.assertTrue((result == np.tile(inp, (1, 4))).all())
+
+ def testTypes(self):
+ types_to_test = {
+ "bool": (tf.bool, bool),
+ "float32": (tf.float32, float),
+ "float64": (tf.float64, float),
+ "uint8": (tf.uint8, int),
+ "int32": (tf.int32, int),
+ "int64": (tf.int64, int),
+ "string": (tf.string, str)
+ }
+ for dtype_np, v in types_to_test.iteritems():
+ with self.test_session():
+ dtype_tf = v[0]
+ cast = v[1]
+ inp = np.random.rand(4, 1).astype(dtype_np)
+ a = tf.constant([cast(x) for x in inp.ravel(order="C")],
+ shape=[4, 1],
+ dtype=dtype_tf)
+ tiled = tf.tile(a, [1, 4])
+ result = tiled.eval()
+ self.assertEqual(result.shape, (4, 4))
+ self.assertEqual([4, 4], tiled.get_shape())
+ self.assertTrue((result == np.tile(inp, (1, 4))).all())
+
+ def testInvalidDim(self):
+ with self.test_session():
+ inp = np.random.rand(4, 1).astype("f")
+ a = tf.constant([float(x) for x in inp.ravel(order="C")],
+ shape=[4, 1], dtype=tf.float32)
+ # Wrong length of multiples.
+ with self.assertRaises(ValueError):
+ tf.tile(a, [1, 4, 2])
+ # Wrong rank for multiples.
+ with self.assertRaises(ValueError):
+ tf.tile(a, [[2, 3], [3, 4]]).eval()
+
+ def _RunAndVerifyResult(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ # Random dims of rank 5
+ input_shape = np.random.randint(1, 4, size=5)
+ inp = np.random.rand(*input_shape).astype("f")
+ a = tf.constant([float(x) for x in inp.ravel(order="C")],
+ shape=input_shape, dtype=tf.float32)
+ multiples = np.random.randint(1, 4, size=5).astype(np.int32)
+ tiled = tf.tile(a, multiples)
+ result = tiled.eval()
+ self.assertTrue((np.array(multiples) * np.array(inp.shape) ==
+ np.array(result.shape)).all())
+ self.assertAllEqual(result, np.tile(inp, tuple(multiples)))
+ self.assertShapeEqual(result, tiled)
+
+ def testRandom(self):
+ for _ in range(5):
+ self._RunAndVerifyResult(use_gpu=False)
+ for _ in range(5):
+ self._RunAndVerifyResult(use_gpu=True)
+
+ def testGradientSimpleReduction(self):
+ with self.test_session():
+ inp = np.random.rand(4, 1).astype("f")
+ a = tf.constant([float(x) for x in inp.flatten()],
+ shape=[4, 1], dtype=tf.float32)
+ tiled = tf.tile(a, [1, 4])
+ grad_shape = [4, 4]
+ grad_inp = np.random.rand(*grad_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=grad_shape)
+ grad = tf.gradients([tiled], [a], [grad_tensor])[0]
+ self.assertShapeEqual(inp, grad)
+ result = grad.eval()
+ self.assertAllClose(np.sum(grad_inp, axis=1).reshape(4, 1), result, 1e-3)
+
+ def testGradientStridedReduction(self):
+ with self.test_session():
+ inp = np.random.rand(4, 2).astype("f")
+ a = tf.constant([float(x) for x in inp.flatten()],
+ shape=[4, 2], dtype=tf.float32)
+ tiled = tf.tile(a, [1, 2])
+ grad_shape = [4, 4]
+ grad_inp = np.random.rand(*grad_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=grad_shape)
+ grad = tf.gradients([tiled], [a], [grad_tensor])[0]
+ self.assertShapeEqual(inp, grad)
+ result = grad.eval()
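+      # Tiling [1, 2] lays the input columns out as [0, 1, 0, 1], so the
+      # gradient for input column j sums output gradient columns j and j + 2.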
+ expected_shape = [4, 2]
+ expected = np.zeros(expected_shape)
+ expected[:, 0] = grad_inp[:, 0] + grad_inp[:, 2]
+ expected[:, 1] = grad_inp[:, 1] + grad_inp[:, 3]
+ self.assertTrue((np.abs(expected - result) < 1e-3).all())
+
+ def testGradientSimpleReductionOnGPU(self):
+ with self.test_session(use_gpu=True):
+ inp = np.random.rand(4, 1).astype("f")
+ a = tf.constant([float(x) for x in inp.flatten()],
+ shape=[4, 1], dtype=tf.float32)
+ tiled = tf.tile(a, [1, 4])
+ grad_shape = [4, 4]
+ grad_inp = np.random.rand(*grad_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=grad_shape)
+ grad = tf.gradients([tiled], [a], [grad_tensor])[0]
+ result = grad.eval()
+ self.assertAllClose(np.sum(grad_inp, axis=1).reshape(4, 1), result, 1e-3)
+
+ def testGradientStridedReductionOnGPU(self):
+ with self.test_session(use_gpu=True):
+ inp = np.random.rand(4, 2).astype("f")
+ a = tf.constant([float(x) for x in inp.flatten()],
+ shape=[4, 2], dtype=tf.float32)
+ tiled = tf.tile(a, [1, 2])
+ grad_shape = [4, 4]
+ grad_inp = np.random.rand(*grad_shape).astype("f")
+ grad_tensor = tf.constant([float(x) for x in grad_inp.flatten()],
+ shape=grad_shape)
+ grad = tf.gradients([tiled], [a], [grad_tensor])[0]
+ result = grad.eval()
+ expected_shape = [4, 2]
+ expected = np.zeros(expected_shape)
+ expected[:, 0] = grad_inp[:, 0] + grad_inp[:, 2]
+ expected[:, 1] = grad_inp[:, 1] + grad_inp[:, 3]
+ self.assertAllClose(expected, result, 1e-3)
+
+ def _RunAndVerifyGradientResult(self, input_shape, multiples):
+ with self.test_session():
+ # Random values
+ inp = np.random.rand(*input_shape)
+ a = tf.constant([float(x) for x in inp.flatten()],
+ shape=input_shape, dtype=tf.float64)
+ tiled = tf.tile(a, multiples)
+ grad_shape = list(np.array(multiples) * np.array(inp.shape))
+ err = gc.ComputeGradientError(a, list(input_shape), tiled, grad_shape,
+ x_init_value=inp)
+ print "tile(float) error = ", err
+ self.assertLess(err, 1e-3)
+
+ def testGradientRandom(self):
+ self._RunAndVerifyGradientResult([2, 2, 1, 1, 3], [1, 2, 1, 3, 1])
+ self._RunAndVerifyGradientResult([2, 3, 1, 1, 3], [3, 1, 1, 2, 2])
+ self._RunAndVerifyGradientResult([2, 1, 3, 3, 2], [1, 3, 3, 1, 2])
+
+ def testGradientStridedReductionGC(self):
+ with self.test_session():
+ inp = np.random.rand(4, 2).astype("f")
+ a = tf.constant([float(x) for x in inp.flatten()],
+ shape=[4, 2], dtype=tf.float32)
+ tiled = tf.tile(a, [1, 2])
+ err = gc.ComputeGradientError(a, [4, 2], tiled, [4, 4])
+ self.assertLess(err, 1e-3)
+
+ def testShapeFunctionEdgeCases(self):
+ # Unknown multiples shape.
+ inp = tf.constant(0.0, shape=[4, 4, 4, 4])
+ tiled = tf.tile(inp, tf.placeholder(tf.int32))
+ self.assertEqual([None, None, None, None], tiled.get_shape().as_list())
+
+ # Unknown input shape.
+ inp = tf.placeholder(tf.float32)
+ tiled = tf.tile(inp, [2, 2, 2, 2])
+ self.assertEqual([None, None, None, None], tiled.get_shape().as_list())
+
+ # Unknown input and multiples shape.
+ inp = tf.placeholder(tf.float32)
+ tiled = tf.tile(inp, tf.placeholder(tf.int32))
+ self.assertIs(None, tiled.get_shape().ndims)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
new file mode 100644
index 0000000000..62d7e31dfc
--- /dev/null
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -0,0 +1,235 @@
+"""Functional tests for slice op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class SliceTest(tf.test.TestCase):
+
+ def _testEmpty(self, use_gpu):
+ inp = np.random.rand(4, 4).astype("f")
+ for k in xrange(4):
+ with self.test_session(use_gpu=use_gpu):
+ a = tf.constant(inp, shape=[4, 4], dtype=tf.float32)
+ slice_t = a[2, k:k]
+ slice_val = slice_t.eval()
+ self.assertAllEqual(slice_val, inp[2, k:k])
+
+ def testEmptyAll(self):
+ self._testEmpty(use_gpu=False)
+ self._testEmpty(use_gpu=True)
+
+ def _testInt32(self, use_gpu):
+ inp = np.random.rand(4, 4).astype("i")
+ for k in xrange(4):
+ with self.test_session(use_gpu=use_gpu):
+ a = tf.constant(inp, shape=[4, 4], dtype=tf.int32)
+ slice_t = a[2, k:k]
+ slice_val = slice_t.eval()
+ self.assertAllEqual(slice_val, inp[2, k:k])
+
+ def testInt32(self):
+    self._testInt32(use_gpu=False)
+    self._testInt32(use_gpu=True)
+
+ def _testSelectAll(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = np.random.rand(4, 4, 4, 4).astype("f")
+ a = tf.constant(inp, shape=[4, 4, 4, 4],
+ dtype=tf.float32)
+
+ slice_explicit_t = tf.slice(a, [0, 0, 0, 0], [-1, -1, -1, -1])
+ slice_implicit_t = a[:, :, :, :]
+
+ self.assertAllEqual(inp, slice_explicit_t.eval())
+ self.assertAllEqual(inp, slice_implicit_t.eval())
+ self.assertEqual(inp.shape, slice_explicit_t.get_shape())
+ self.assertEqual(inp.shape, slice_implicit_t.get_shape())
+
+ def testSelectAll(self):
+ for _ in range(10):
+ self._testSelectAll(use_gpu=False)
+ self._testSelectAll(use_gpu=True)
+
+ def _testSingleDimension(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = np.random.rand(10).astype("f")
+ a = tf.constant(inp, shape=[10], dtype=tf.float32)
+
+ hi = np.random.random_integers(0, 9)
+ scalar_t = a[hi]
+ scalar_val = scalar_t.eval()
+ self.assertAllEqual(scalar_val, inp[hi])
+
+ lo = np.random.random_integers(0, hi)
+ slice_t = a[lo:hi]
+ slice_val = slice_t.eval()
+ self.assertAllEqual(slice_val, inp[lo:hi])
+
+ def testSingleDimension(self):
+ for _ in range(10):
+ self._testSingleDimension(use_gpu=False)
+ self._testSingleDimension(use_gpu=True)
+
+ def _testSliceMatrixDim0(self, x, begin, size, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ tf_ans = tf.slice(x, [begin, 0], [size, x.shape[1]]).eval()
+ np_ans = x[begin:begin+size, :]
+ self.assertAllEqual(tf_ans, np_ans)
+
+ def testSliceMatrixDim0(self):
+ for use_gpu in [False, True]:
+ x = np.random.rand(8, 4).astype("f")
+ self._testSliceMatrixDim0(x, 1, 2, use_gpu)
+ self._testSliceMatrixDim0(x, 3, 3, use_gpu)
+ y = np.random.rand(8, 7).astype("f") # 7 * sizeof(float) is not aligned
+ self._testSliceMatrixDim0(y, 1, 2, use_gpu)
+ self._testSliceMatrixDim0(y, 3, 3, use_gpu)
+
+ def _testIndexAndSlice(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = np.random.rand(4, 4).astype("f")
+ a = tf.constant(inp, shape=[4, 4], dtype=tf.float32)
+
+ x, y = np.random.random_integers(0, 3, size=2).tolist()
+ slice_t = a[x, 0:y]
+ slice_val = slice_t.eval()
+ self.assertAllEqual(slice_val, inp[x, 0:y])
+
+ def testSingleElementAll(self):
+ for _ in range(10):
+ self._testIndexAndSlice(use_gpu=False)
+ self._testIndexAndSlice(use_gpu=True)
+
+ def _testSimple(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inp = np.random.rand(4, 4).astype("f")
+ a = tf.constant([float(x) for x in inp.ravel(order="C")],
+ shape=[4, 4], dtype=tf.float32)
+ slice_t = tf.slice(a, [0, 0], [2, 2])
+ slice2_t = a[:2, :2]
+ slice_val, slice2_val = sess.run([slice_t, slice2_t])
+ self.assertAllEqual(slice_val, inp[:2, :2])
+ self.assertAllEqual(slice2_val, inp[:2, :2])
+ self.assertEqual(slice_val.shape, slice_t.get_shape())
+ self.assertEqual(slice2_val.shape, slice2_t.get_shape())
+
+ def testSimpleAll(self):
+ self._testSimple(use_gpu=False)
+ self._testSimple(use_gpu=True)
+
+ def _testComplex(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = np.random.rand(4, 10, 10, 4).astype("f")
+ a = tf.constant(inp, dtype=tf.float32)
+
+ x = np.random.random_integers(0, 9)
+ z = np.random.random_integers(0, 9)
+ y = np.random.random_integers(0, z)
+ slice_t = a[:, x, y:z, :]
+ self.assertAllEqual(slice_t.eval(), inp[:, x, y:z, :])
+
+ def testComplex(self):
+ for _ in range(10):
+ self._testComplex(use_gpu=False)
+ self._testComplex(use_gpu=True)
+
+ def _RunAndVerifyResult(self, use_gpu):
+ # Random dims of rank 5
+ input_shape = np.random.randint(0, 20, size=5)
+ inp = np.random.rand(*input_shape).astype("f")
+ with self.test_session(use_gpu=use_gpu) as sess:
+ a = tf.constant([float(x) for x in inp.ravel(order="C")],
+ shape=input_shape, dtype=tf.float32)
+ indices = [0 if x == 0 else np.random.randint(x) for x in input_shape]
+ sizes = [np.random.randint(0, input_shape[i] - indices[i] + 1)
+ for i in range(5)]
+ slice_t = tf.slice(a, indices, sizes)
+ slice2_t = a[indices[0]:indices[0]+sizes[0],
+ indices[1]:indices[1]+sizes[1],
+ indices[2]:indices[2]+sizes[2],
+ indices[3]:indices[3]+sizes[3],
+ indices[4]:indices[4]+sizes[4]]
+
+ slice_val, slice2_val = sess.run([slice_t, slice2_t])
+
+ expected_val = inp[indices[0]:indices[0]+sizes[0],
+ indices[1]:indices[1]+sizes[1],
+ indices[2]:indices[2]+sizes[2],
+ indices[3]:indices[3]+sizes[3],
+ indices[4]:indices[4]+sizes[4]]
+ self.assertAllEqual(slice_val, expected_val)
+ self.assertAllEqual(slice2_val, expected_val)
+ self.assertEqual(expected_val.shape, slice_t.get_shape())
+ self.assertEqual(expected_val.shape, slice2_t.get_shape())
+
+ def testRandom(self):
+ for _ in range(10):
+ self._RunAndVerifyResult(use_gpu=False)
+ self._RunAndVerifyResult(use_gpu=True)
+
+ def _testGradientSlice(self, input_shape, slice_begin, slice_size, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ num_inputs = np.prod(input_shape)
+ num_grads = np.prod(slice_size)
+ inp = np.random.rand(num_inputs).astype("f").reshape(input_shape)
+ a = tf.constant([float(x) for x in inp.ravel(order="C")],
+ shape=input_shape, dtype=tf.float32)
+ slice_t = tf.slice(a, slice_begin, slice_size)
+ grads = np.random.rand(num_grads).astype("f").reshape(slice_size)
+ grad_tensor = tf.constant(grads)
+ grad = tf.gradients(slice_t, [a], grad_tensor)[0]
+ result = grad.eval()
+
+      # Create a zero tensor of the input shape and place
+ # the grads into the right location to compare against TensorFlow.
+ np_ans = np.zeros(input_shape)
+ slices = []
+ for i in xrange(len(input_shape)):
+ slices.append(slice(slice_begin[i], slice_begin[i] + slice_size[i]))
+ np_ans[slices] = grads
+
+ self.assertAllClose(np_ans, result)
+
+ def _testGradientVariableSize(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ inp = tf.constant([1.0, 2.0, 3.0], name="in")
+ out = tf.slice(inp, [1], [-1])
+ grad_actual = tf.gradients(out, inp)[0].eval()
+ self.assertAllClose([0., 1., 1.], grad_actual)
+
+ def _testGradientsSimple(self, use_gpu):
+ # Slice the middle square out of a 4x4 input
+ self._testGradientSlice([4, 4], [1, 1], [2, 2], use_gpu)
+
+ # Slice the upper left square out of a 4x4 input
+ self._testGradientSlice([4, 4], [0, 0], [2, 2], use_gpu)
+
+ # Slice a non-square input starting from (2,1)
+ self._testGradientSlice([4, 4], [2, 1], [1, 2], use_gpu)
+
+ # Slice a 3D tensor
+ self._testGradientSlice([3, 3, 3], [0, 1, 0], [2, 1, 1], use_gpu)
+
+ # Use -1 as a slice dimension.
+ self._testGradientVariableSize(use_gpu)
+
+ def testGradientsAll(self):
+ self._testGradientsSimple(use_gpu=False)
+ self._testGradientsSimple(use_gpu=True)
+
+ def testNotIterable(self):
+ # NOTE(mrry): If we register __getitem__ as an overloaded
+ # operator, Python will valiantly attempt to iterate over the
+ # Tensor from 0 to infinity. This test ensures that this
+ # unintended behavior is prevented.
+ c = tf.constant(5.0)
+ with self.assertRaisesWithPredicateMatch(
+ TypeError,
+ lambda e: "'Tensor' object is not iterable" in e.message):
+ for _ in c:
+ pass
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
new file mode 100644
index 0000000000..fd25970093
--- /dev/null
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -0,0 +1,65 @@
+"""Tests for SoftmaxOp."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class SoftmaxTest(tf.test.TestCase):
+
+ def _npSoftmax(self, features):
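+    # Subtract the per-row max before exponentiating for numerical stability;
+    # the shift cancels in the ratio, so the softmax value is unchanged.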
+ batch_dim = 0
+ class_dim = 1
+ batch_size = features.shape[batch_dim]
+ e = np.exp(features -
+ np.reshape(np.amax(features, axis=class_dim), [batch_size, 1]))
+ return e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1])
+
+ def _testSoftmax(self, np_features, use_gpu=False):
+ np_softmax = self._npSoftmax(np_features)
+ with self.test_session(use_gpu=use_gpu):
+ tf_softmax = tf.nn.softmax(np_features)
+ out = tf_softmax.eval()
+ self.assertAllClose(np_softmax, out)
+ self.assertShapeEqual(np_softmax, tf_softmax)
+ # Bonus check: the softmaxes should add to one in each
+ # batch element.
+ self.assertAllClose(np.ones(out.shape[0]),
+ np.sum(out, axis=1))
+
+ def _testAll(self, features):
+ self._testSoftmax(features, use_gpu=False)
+ self._testSoftmax(features, use_gpu=True)
+
+ def testNpSoftmax(self):
+ features = [[1., 1., 1., 1.], [1., 2., 3., 4.]]
+ # Batch 0: All exps are 1. The expected result is
+ # [0.25, 0.25, 0.25, 0.25]
+ #
+ # Batch 1:
+ # exps = [1., 2.718, 7.389, 20.085]
+ # sum = 31.192
+ # Softmaxes = exps / sum = [0.0320586, 0.08714432, 0.23688282, 0.64391426]
+ np_sm = self._npSoftmax(np.array(features))
+ self.assertAllClose(
+ np.array([[0.25, 0.25, 0.25, 0.25],
+ [0.0320586, 0.08714432, 0.23688282, 0.64391426]]),
+ np_sm,
+ rtol=1.e-5, atol=1.e-5)
+
+ def testShapeMismatch(self):
+ with self.assertRaises(ValueError):
+ tf.nn.softmax([0., 1., 2., 3.])
+
+ def testFloat(self):
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32))
+
+ def testDouble(self):
+ self._testSoftmax(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
+ use_gpu=False)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
new file mode 100644
index 0000000000..25b68aa659
--- /dev/null
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -0,0 +1,47 @@
+"""Tests for Softplus and SoftplusGrad."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class SoftplusTest(tf.test.TestCase):
+
+ def _npSoftplus(self, np_features):
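+    # Softplus is log(1 + exp(x)), a smooth approximation to relu.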
+ return np.log(1 + np.exp(np_features))
+
+ def _testSoftplus(self, np_features, use_gpu=False):
+ np_softplus = self._npSoftplus(np_features)
+ with self.test_session(use_gpu=use_gpu):
+ softplus = tf.nn.softplus(np_features)
+ tf_softplus = softplus.eval()
+ self.assertAllClose(np_softplus, tf_softplus)
+ self.assertShapeEqual(np_softplus, softplus)
+
+ def testNumbers(self):
+ for t in [np.float, np.double]:
+ self._testSoftplus(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=False)
+ self._testSoftplus(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=True)
+
+ def testGradient(self):
+ with self.test_session():
+ x = tf.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5], name="x")
+ y = tf.nn.softplus(x, name="softplus")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32, order="F")
+ err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print "softplus (float) gradient err = ", err
+ self.assertLess(err, 1e-4)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
new file mode 100644
index 0000000000..0f5650b89c
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
@@ -0,0 +1,260 @@
+"""Tests for SparseConcat."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class SparseConcatTest(tf.test.TestCase):
+
+ def _SparseTensor_UnknownShape(self, ind_shape=None, val_shape=None,
+ shape_shape=None):
+ return tf.SparseTensor(
+ tf.placeholder(tf.int64, shape=ind_shape),
+ tf.placeholder(tf.float32, shape=val_shape),
+ tf.placeholder(tf.int64, shape=shape_shape))
+
+ def _SparseTensor_3x3(self):
+ # [ 1]
+ # [2 ]
+ # [3 4]
+ ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]])
+ val = np.array([1, 2, 3, 4])
+ shape = np.array([3, 3])
+ return tf.SparseTensor(
+ tf.constant(ind, tf.int64),
+ tf.constant(val, tf.float32),
+ tf.constant(shape, tf.int64))
+
+ def _SparseTensor_3x5(self):
+ # [ ]
+ # [ 1 ]
+ # [2 1 0]
+ ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]])
+ val = np.array([1, 2, 1, 0])
+ shape = np.array([3, 5])
+ return tf.SparseTensor(
+ tf.constant(ind, tf.int64),
+ tf.constant(val, tf.float32),
+ tf.constant(shape, tf.int64))
+
+ def _SparseTensor_3x2(self):
+ # [ ]
+ # [1 ]
+ # [2 ]
+ ind = np.array([[1, 0], [2, 0]])
+ val = np.array([1, 2])
+ shape = np.array([3, 2])
+ return tf.SparseTensor(
+ tf.constant(ind, tf.int64),
+ tf.constant(val, tf.float32),
+ tf.constant(shape, tf.int64))
+
+ def _SparseTensor_2x3(self):
+ # [ 1 ]
+ # [1 2]
+ ind = np.array([[0, 1], [1, 0], [1, 2]])
+ val = np.array([1, 1, 2])
+ shape = np.array([2, 3])
+ return tf.SparseTensor(
+ tf.constant(ind, tf.int64),
+ tf.constant(val, tf.float32),
+ tf.constant(shape, tf.int64))
+
+ def _SparseTensor_2x3x4(self):
+ ind = np.array([
+ [0, 0, 1],
+ [0, 1, 0], [0, 1, 2],
+ [1, 0, 3],
+ [1, 1, 1], [1, 1, 3],
+ [1, 2, 2]])
+ val = np.array([1, 10, 12, 103, 111, 113, 122])
+ shape = np.array([2, 3, 4])
+ return tf.SparseTensor(
+ tf.constant(ind, tf.int64),
+ tf.constant(val, tf.float32),
+ tf.constant(shape, tf.int64))
+
+ def _SparseTensor_String3x3(self):
+ # [ a]
+ # [b ]
+ # [c d]
+ ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]])
+ val = np.array(["a", "b", "c", "d"])
+ shape = np.array([3, 3])
+ return tf.SparseTensor(
+ tf.constant(ind, tf.int64),
+ tf.constant(val, tf.string),
+ tf.constant(shape, tf.int64))
+
+ def _SparseTensor_String3x5(self):
+ # [ ]
+ # [ e ]
+ # [f g h]
+ ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]])
+ val = np.array(["e", "f", "g", "h"])
+ shape = np.array([3, 5])
+ return tf.SparseTensor(
+ tf.constant(ind, tf.int64),
+ tf.constant(val, tf.string),
+ tf.constant(shape, tf.int64))
+
+ def testConcat1(self):
+ with self.test_session(use_gpu=False) as sess:
+ # concat(A):
+ # [ 1]
+ # [2 ]
+ # [3 4]
+ sp_a = self._SparseTensor_3x3()
+
+ sp_concat = tf.sparse_concat(1, [sp_a])
+
+ self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
+ self.assertEqual(sp_concat.values.get_shape(), [4])
+ self.assertEqual(sp_concat.shape.get_shape(), [2])
+
+ concat_out = sess.run(sp_concat)
+
+ self.assertAllEqual(
+ concat_out.indices, [[0, 2], [1, 0], [2, 0], [2, 2]])
+ self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
+ self.assertAllEqual(concat_out.shape, [3, 3])
+
+ def testConcat2(self):
+ with self.test_session(use_gpu=False) as sess:
+ # concat(A, B):
+ # [ 1 ]
+ # [2 1 ]
+ # [3 4 2 1 0]
+ sp_a = self._SparseTensor_3x3()
+ sp_b = self._SparseTensor_3x5()
+
+ sp_concat = tf.sparse_concat(1, [sp_a, sp_b])
+
+ self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
+ self.assertEqual(sp_concat.values.get_shape(), [8])
+ self.assertEqual(sp_concat.shape.get_shape(), [2])
+
+ concat_out = sess.run(sp_concat)
+
+ self.assertAllEqual(
+ concat_out.indices,
+ [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
+ self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
+ self.assertAllEqual(concat_out.shape, [3, 8])
+
+ def testConcatDim0(self):
+ with self.test_session(use_gpu=False) as sess:
+ # concat(A, D):
+ # [ 1]
+ # [2 ]
+ # [3 4]
+ # [ 1 ]
+ # [1 2]
+ sp_a = self._SparseTensor_3x3()
+ sp_d = self._SparseTensor_2x3()
+
+ sp_concat = tf.sparse_concat(0, [sp_a, sp_d])
+
+ self.assertEqual(sp_concat.indices.get_shape(), [7, 2])
+ self.assertEqual(sp_concat.values.get_shape(), [7])
+ self.assertEqual(sp_concat.shape.get_shape(), [2])
+
+ concat_out = sess.run(sp_concat)
+
+ self.assertAllEqual(
+ concat_out.indices,
+ [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]])
+ self.assertAllEqual(
+ concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2]))
+ self.assertAllEqual(
+ concat_out.shape, np.array([5, 3]))
+
+ def testConcat3(self):
+ with self.test_session(use_gpu=False) as sess:
+ # concat(A, B, C):
+ # [ 1 ]
+ # [2 1 1 ]
+ # [3 4 2 1 0 2 ]
+ sp_a = self._SparseTensor_3x3()
+ sp_b = self._SparseTensor_3x5()
+ sp_c = self._SparseTensor_3x2()
+
+ sp_concat = tf.sparse_concat(1, [sp_a, sp_b, sp_c])
+
+ self.assertEqual(sp_concat.indices.get_shape(), [10, 2])
+ self.assertEqual(sp_concat.values.get_shape(), [10])
+ self.assertEqual(sp_concat.shape.get_shape(), [2])
+
+ concat_out = sess.run(sp_concat)
+
+ self.assertAllEqual(
+ concat_out.indices,
+ [[0, 2], [1, 0], [1, 4], [1, 8], [2, 0], [2, 2], [2, 3], [2, 6],
+ [2, 7], [2, 8]])
+ self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2])
+ self.assertAllEqual(concat_out.shape, [3, 10])
+
+ def testConcatNonNumeric(self):
+ with self.test_session(use_gpu=False) as sess:
+ # concat(A, B):
+ # [ a ]
+ # [b e ]
+ # [c d f g h]
+ sp_a = self._SparseTensor_String3x3()
+ sp_b = self._SparseTensor_String3x5()
+
+ sp_concat = tf.sparse_concat(1, [sp_a, sp_b])
+
+ self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
+ self.assertEqual(sp_concat.values.get_shape(), [8])
+ self.assertEqual(sp_concat.shape.get_shape(), [2])
+
+ concat_out = sess.run(sp_concat)
+
+ self.assertAllEqual(
+ concat_out.indices,
+ [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
+ self.assertAllEqual(
+ concat_out.values, ["a", "b", "e", "c", "d", "f", "g", "h"])
+ self.assertAllEqual(concat_out.shape, [3, 8])
+
+ def testMismatchedRank(self):
+ with self.test_session(use_gpu=False):
+ sp_a = self._SparseTensor_3x3()
+ sp_e = self._SparseTensor_2x3x4()
+
+ # Rank mismatches can be caught at shape-inference time
+ with self.assertRaises(ValueError):
+ tf.sparse_concat(1, [sp_a, sp_e])
+
+ def testMismatchedShapes(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_a = self._SparseTensor_3x3()
+ sp_b = self._SparseTensor_3x5()
+ sp_c = self._SparseTensor_3x2()
+ sp_d = self._SparseTensor_2x3()
+ sp_concat = tf.sparse_concat(1, [sp_a, sp_b, sp_c, sp_d])
+
+ # Shape mismatches can only be caught when the op is run
+ with self.assertRaisesOpError("Input shapes must match"):
+ sess.run(sp_concat)
+
+ def testShapeInferenceUnknownShapes(self):
+ with self.test_session(use_gpu=False):
+ sp_inputs = [
+ self._SparseTensor_UnknownShape(),
+ self._SparseTensor_UnknownShape(val_shape=[3]),
+ self._SparseTensor_UnknownShape(ind_shape=[1, 3]),
+ self._SparseTensor_UnknownShape(shape_shape=[3])]
+
+ sp_concat = tf.sparse_concat(0, sp_inputs)
+
+ self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3])
+ self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
+ self.assertEqual(sp_concat.shape.get_shape(), [3])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
new file mode 100644
index 0000000000..d87d15cae9
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -0,0 +1,82 @@
+"""Tests for tensorflow.ops.tf.matmul."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+def RandMatrix(rows, cols, tr):
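+  # Uniform values in [-100, 100] clipped below at 0 leave roughly half the
+  # entries exactly zero, exercising the a_is_sparse/b_is_sparse hints.
+  # When tr is set, rows and cols are swapped so the caller can transpose.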
+ if tr:
+ rows, cols = cols, rows
+ return (np.clip(np.random.uniform(low=-100.0, high=100.0, size=rows * cols),
+ 0, 100) / 100).reshape([rows, cols]).astype(np.float32)
+
+
+class SparseMatMulTest(tf.test.TestCase):
+
+ def _testCpuMatmul(self, x, y, tr_a=False, tr_b=False,
+ sp_a=True, sp_b=False):
+ x_mat = np.matrix(x)
+ if tr_a:
+ x_mat = np.transpose(x_mat)
+ y_mat = np.matrix(y)
+ if tr_b:
+ y_mat = np.transpose(y_mat)
+ np_ans = x_mat * y_mat
+ with self.test_session(use_gpu=False):
+ tf_ans = tf.matmul(x, y,
+ transpose_a=tr_a, transpose_b=tr_b,
+ a_is_sparse=sp_a,
+ b_is_sparse=sp_b)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def testFloatBasic(self):
+ x = np.arange(0., 4.).reshape([4, 1]).astype(np.float32)
+ y = np.arange(-1., 1.).reshape([1, 2]).astype(np.float32)
+ self._testCpuMatmul(x, y)
+
+ # Tests testing random sized matrices.
+ def testFloatRandom(self):
+ for _ in range(10):
+ for tr_a in [True, False]:
+ for tr_b in [True, False]:
+ for sp_a in [True, False]:
+ for sp_b in [True, False]:
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = RandMatrix(n, k, tr_a)
+ y = RandMatrix(k, m, tr_b)
+ self._testCpuMatmul(x, y, tr_a, tr_b, sp_a, sp_b)
+
+
+class MatMulGradientTest(tf.test.TestCase):
+
+ def _testGradients(self, tr_a, tr_b, sp_a, sp_b, name):
+ with self.test_session():
+ a = tf.constant(RandMatrix(3, 2, tr_a), dtype=tf.float32)
+ b = tf.constant(RandMatrix(2, 4, tr_b), dtype=tf.float32)
+ m = tf.matmul(a, b,
+ name=name,
+ transpose_a=tr_a,
+ transpose_b=tr_b,
+ a_is_sparse=sp_a,
+ b_is_sparse=sp_b)
+ err = (gc.ComputeGradientError(a, [2, 3] if tr_a else [3, 2], m, [3, 4]) +
+ gc.ComputeGradientError(b, [4, 2] if tr_b else [2, 4], m, [3, 4]))
+ print "sparse_matmul gradient err = ", err
+ self.assertLess(err, 1e-3)
+
+ def testGradientInput(self):
+ for tr_a in [True, False]:
+ for tr_b in [True, False]:
+ for sp_a in [True, False]:
+ for sp_b in [True, False]:
+ name = "sparse_matmul_%s_%s_%s_%s" % (tr_a, tr_b, sp_a, sp_b)
+ self._testGradients(tr_a, tr_b, sp_a, sp_b, name)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/sparse_reorder_op_test.py b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
new file mode 100644
index 0000000000..c3bcc25311
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_reorder_op_test.py
@@ -0,0 +1,56 @@
+"""Tests for SparseReorder."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class SparseReorderTest(tf.test.TestCase):
+
+ def _SparseTensorPlaceholder(self):
+ return tf.SparseTensor(
+ tf.placeholder(tf.int64),
+ tf.placeholder(tf.int32),
+ tf.placeholder(tf.int64))
+
+ def _SparseTensorValue_5x6(self, permutation):
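+    # Builds a 5x6 sparse tensor whose (index, value) pairs are listed in the
+    # order given by `permutation`; sparse_reorder should restore row-major order.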
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]]).astype(np.int64)
+ val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32)
+
+ ind = ind[permutation]
+ val = val[permutation]
+
+ shape = np.array([5, 6]).astype(np.int64)
+ return tf.SparseTensorValue(ind, val, shape)
+
+ def testAlreadyInOrder(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensorPlaceholder()
+ input_val = self._SparseTensorValue_5x6(np.arange(6))
+ sp_output = tf.sparse_reorder(sp_input)
+
+ output_val = sess.run(sp_output, {sp_input: input_val})
+ self.assertAllEqual(output_val.indices, input_val.indices)
+ self.assertAllEqual(output_val.values, input_val.values)
+ self.assertAllEqual(output_val.shape, input_val.shape)
+
+ def testOutOfOrder(self):
+ expected_output_val = self._SparseTensorValue_5x6(np.arange(6))
+ with self.test_session(use_gpu=False) as sess:
+ for _ in range(5): # To test various random permutations
+ sp_input = self._SparseTensorPlaceholder()
+ input_val = self._SparseTensorValue_5x6(np.random.permutation(6))
+ sp_output = tf.sparse_reorder(sp_input)
+
+ output_val = sess.run(sp_output, {sp_input: input_val})
+ self.assertAllEqual(output_val.indices, expected_output_val.indices)
+ self.assertAllEqual(output_val.values, expected_output_val.values)
+ self.assertAllEqual(output_val.shape, expected_output_val.shape)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
new file mode 100644
index 0000000000..2bab89923e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
@@ -0,0 +1,111 @@
+"""Tests for tensorflow.kernels.sparse_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+def _SparseToDense(sparse_indices, output_size, sparse_values,
+ default_value):
+ return tf.sparse_to_dense(sparse_indices, output_size,
+ sparse_values, default_value)
+
+
+class SparseToDenseTest(tf.test.TestCase):
+
+ def testInt(self):
+ with self.test_session(use_gpu=False):
+ tf_ans = _SparseToDense([1, 3], [5], 1, 0).eval()
+ np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testFloat(self):
+ with self.test_session(use_gpu=False):
+ tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0).eval()
+ np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testString(self):
+ with self.test_session(use_gpu=False):
+ tf_ans = _SparseToDense([1, 3], [5], "a", "b").eval()
+ np_ans = np.array(["b", "a", "b", "a", "b"]).astype(np.string_)
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def testSetValue(self):
+ with self.test_session(use_gpu=False):
+ tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1).eval()
+ np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testSetSingleValue(self):
+ with self.test_session(use_gpu=False):
+ tf_ans = _SparseToDense([1, 3], [5], 1, -1).eval()
+ np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def test2d(self):
+ # pylint: disable=bad-whitespace
+ with self.test_session(use_gpu=False):
+ tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1).eval()
+ np_ans = np.array([[-1, -1, -1, -1],
+ [-1, -1, -1, 1],
+ [ 1, -1, -1, -1]]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def test3d(self):
+ with self.test_session(use_gpu=False):
+ tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1).eval()
+ np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1
+ np_ans[1, 3, 0] = 1
+ np_ans[2, 0, 1] = 1
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testBadShape(self):
+ with self.test_session():
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: ("Input shape should be a vector" == str(e))):
+ _SparseToDense([1, 3], [[5], [3]], 1, -1)
+
+ def testBadValue(self):
+ with self.test_session():
+ dense = _SparseToDense([1, 3], [5], [[5], [3]], -1)
+ with self.assertRaisesOpError(
+ r"sparse_values has incorrect shape \[2,1\], "
+ r"should be \[\] or \[2\]"):
+ dense.eval()
+
+ def testBadNumValues(self):
+ with self.test_session():
+ dense = _SparseToDense([1, 3], [5], [1, 2, 3], -1)
+ with self.assertRaisesOpError(
+ r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
+ dense.eval()
+
+ def testBadDefault(self):
+ with self.test_session():
+ dense = _SparseToDense([1, 3], [5], [1, 2], [1, 2])
+ with self.assertRaisesOpError("default_value should be a scalar"):
+ dense.eval()
+
+ def testShapeInferenceKnownShape(self):
+ with self.test_session(use_gpu=False):
+ indices = tf.placeholder(tf.int64)
+
+ shape = [4, 5, 6]
+ output = tf.sparse_to_dense(indices, shape, 1, 0)
+ self.assertEqual(output.get_shape(), [4, 5, 6])
+
+ shape = tf.placeholder(tf.int64, shape=(3,))
+ output = tf.sparse_to_dense(indices, shape, 1, 0)
+ self.assertEqual(output.get_shape().as_list(), [None, None, None])
+
+ def testShapeInferenceUnknownShape(self):
+ with self.test_session(use_gpu=False):
+ indices = tf.placeholder(tf.int64)
+ shape = tf.placeholder(tf.int64)
+ output = tf.sparse_to_dense(indices, shape, 1, 0)
+ self.assertEqual(output.get_shape().ndims, None)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/sparsemask_op_test.py b/tensorflow/python/kernel_tests/sparsemask_op_test.py
new file mode 100644
index 0000000000..ffde8f7944
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparsemask_op_test.py
@@ -0,0 +1,32 @@
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class SparseMaskTest(tf.test.TestCase):
+
+ def testBasic(self):
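+    # tf.sparse_mask drops the IndexedSlices entries whose indices appear in
+    # mask_indices; masking index 0 here removes the first row of values.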
+ values = np.random.rand(4, 4).astype(np.single)
+ indices = np.array([0, 2, 3, 4], dtype=np.int32)
+ mask_indices = np.array([0], dtype=np.int32)
+
+ out_values = values[1:, :]
+ out_indices = np.array([2, 3, 4], dtype=np.int32)
+
+ with self.test_session() as sess:
+ values_tensor = tf.convert_to_tensor(values)
+ indices_tensor = tf.convert_to_tensor(indices)
+ mask_indices_tensor = tf.convert_to_tensor(mask_indices)
+
+ t = tf.IndexedSlices(values_tensor, indices_tensor)
+ masked_t = tf.sparse_mask(t, mask_indices_tensor)
+
+ tf_out_values, tf_out_indices = sess.run([masked_t.values,
+ masked_t.indices])
+
+ self.assertAllEqual(tf_out_values, out_values)
+ self.assertAllEqual(tf_out_indices, out_indices)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
new file mode 100644
index 0000000000..19906aa02b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -0,0 +1,132 @@
+"""Functional tests for Split Op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class SplitOpTest(tf.test.TestCase):
+
+ def _compare(self, x, dim, num, use_gpu):
+ np_ans = np.split(x, num, dim)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ tf_ans = tf.split(dim, num, x)
+ out = sess.run(tf_ans)
+ self.assertEqual(num, len(np_ans))
+ self.assertEqual(num, len(np_ans))
+ self.assertEqual(num, len(out))
+ for i in range(num):
+ self.assertAllEqual(np_ans[i], out[i])
+ self.assertShapeEqual(np_ans[i], tf_ans[i])
+
+ def _testSplitRows(self, use_gpu):
+ inp = np.random.rand(4, 4).astype("f")
+ self._compare(inp, 0, 4, use_gpu)
+
+ def testSplitRowsAll(self):
+ self._testSplitRows(use_gpu=False)
+ self._testSplitRows(use_gpu=True)
+
+ def _testSplitCols(self, use_gpu):
+ inp = np.random.rand(4, 4).astype("f")
+ self._compare(inp, 1, 4, use_gpu)
+
+ def testSplitColsAll(self):
+    self._testSplitCols(use_gpu=False)
+ self._testSplitCols(use_gpu=True)
+
+ def _testEmpty(self, x, dim, num, expected_shape):
+ with self.test_session() as sess:
+ tf_ans = tf.split(dim, num, x)
+ out = sess.run(tf_ans)
+ self.assertEqual(x.size, 0)
+ self.assertEqual(len(out), num)
+ for i in range(num):
+ self.assertEqual(out[i].shape, expected_shape)
+ self.assertEqual(expected_shape, tf_ans[i].get_shape())
+
+ def testEmpty(self):
+ # Note: np.split returns a rank-0 empty ndarray
+ # if the input ndarray is empty.
+ inp = np.random.rand(8, 0, 21).astype("f")
+ self._testEmpty(inp, 0, 2, (4, 0, 21))
+ self._testEmpty(inp, 0, 4, (2, 0, 21))
+ self._testEmpty(inp, 1, 4, (8, 0, 21))
+ self._testEmpty(inp, 2, 3, (8, 0, 7))
+ self._testEmpty(inp, 2, 7, (8, 0, 3))
+
+ def testIdentity(self):
+ inp = np.random.rand(2, 2, 2).astype("f")
+ for use_gpu in [False, True]:
+ self._compare(inp, 0, 1, use_gpu)
+ self._compare(inp, 1, 1, use_gpu)
+ self._compare(inp, 2, 1, use_gpu)
+
+ def testSplitDim0(self):
+ for use_gpu in [False, True]:
+ self._compare(np.random.rand(6, 10, 18).astype("f"), 0, 3, use_gpu)
+ self._compare(np.random.rand(6, 7, 18).astype("f"), 0, 3, use_gpu)
+ self._compare(np.random.rand(6, 7, 9).astype("f"), 0, 3, use_gpu)
+
+ def _RunAndVerify(self, use_gpu):
+ # Random dims of rank 5
+ shape = np.random.randint(0, 5, size=5)
+ split_dim = np.random.randint(0, 5)
+ num_split = np.random.randint(2, 8)
+ shape[split_dim] = np.random.randint(2, 5) * num_split
+ inp = np.random.rand(*shape).astype("f")
+ with self.test_session(use_gpu=use_gpu) as sess:
+ result = sess.run(tf.split(split_dim, num_split, inp))
+ slices = [slice(0, x) for x in shape]
+ offset = 0
+ length = shape[split_dim] / num_split
+ for i in range(num_split):
+ slices[split_dim] = slice(offset, offset + length)
+ offset += length
+ self.assertAllEqual(result[i], inp[slices])
+
+ def testRandom(self):
+ for _ in range(5):
+ self._RunAndVerify(use_gpu=False)
+ self._RunAndVerify(use_gpu=True)
+
+ def _testGradientsSimple(self, use_gpu):
+ inp = np.random.rand(4, 4).astype("f")
+ with self.test_session(use_gpu=use_gpu):
+ inp_tensor = tf.convert_to_tensor(inp)
+ s = tf.split(1, 4, inp_tensor)
+ inp_grads = [np.random.rand(4, 1).astype("f") for _ in range(4)]
+ grad_tensors = [tf.constant(x) for x in inp_grads]
+ grad = tf.gradients(s, [inp_tensor], grad_tensors)[0]
+ result = grad.eval()
+ for i in range(4):
+ self.assertAllEqual(result[:, i:i+1], inp_grads[i])
+
+ def testGradientsAll(self):
+ self._testGradientsSimple(use_gpu=False)
+ self._testGradientsSimple(use_gpu=True)
+
+ def testShapeFunctionEdgeCases(self):
+ # split_dim greater than rank of input.
+ with self.assertRaises(ValueError):
+ tf.split(2, 4, [[0, 1], [2, 3]])
+
+ # num_split does not evenly divide the size in split_dim.
+ with self.assertRaisesRegexp(ValueError, "should evenly divide"):
+ tf.split(0, 3, [0, 1, 2, 3])
+
+ # Unknown split_dim.
+ splits = tf.split(tf.placeholder(tf.int32),
+ 4, [[0, 1, 2, 3]])
+ for s in splits:
+ self.assertEqual([None, None], s.get_shape().as_list())
+
+ # Unknown split_dim and input shape.
+ splits = tf.split(tf.placeholder(tf.int32),
+ 4, tf.placeholder(tf.float32))
+ for s in splits:
+ self.assertEqual(None, s.get_shape().ndims)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
new file mode 100644
index 0000000000..8615b271b8
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
@@ -0,0 +1,34 @@
+"""Tests for StringToHashBucket op from string_ops."""
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class StringToHashBucketOpTest(tf.test.TestCase):
+
+ def testStringToOneHashBucket(self):
+ with self.test_session():
+ input_string = tf.placeholder(tf.string)
+ output = tf.string_to_hash_bucket(input_string, 1)
+ result = output.eval(feed_dict={
+ input_string: ['a', 'b', 'c']
+ })
+
+ self.assertAllEqual([0, 0, 0], result)
+
+ def testStringToHashBuckets(self):
+ with self.test_session():
+ input_string = tf.placeholder(tf.string)
+ output = tf.string_to_hash_bucket(input_string, 10)
+ result = output.eval(feed_dict={
+ input_string: ['a', 'b', 'c']
+ })
+
+ # Hash64('a') -> 2996632905371535868 -> mod 10 -> 8
+ # Hash64('b') -> 5795986006276551370 -> mod 10 -> 0
+ # Hash64('c') -> 14899841994519054197 -> mod 10 -> 7
+ self.assertAllEqual([8, 0, 7], result)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/string_to_number_op_test.py b/tensorflow/python/kernel_tests/string_to_number_op_test.py
new file mode 100644
index 0000000000..39505e18ba
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_to_number_op_test.py
@@ -0,0 +1,66 @@
+"""Tests for StringToNumber op from parsing_ops."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+_ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "
+
+
+class StringToNumberOpTest(tf.test.TestCase):
+
+ def testToFloat(self):
+ with self.test_session():
+ input_string = tf.placeholder(tf.string)
+ output = tf.string_to_number(
+ input_string,
+ out_type=tf.float32)
+
+ result = output.eval(feed_dict={
+ input_string: ["0",
+ "3",
+ "-1",
+ "1.12",
+ "0xF",
+ " -10.5",
+ "3.40282e+38",
+ # The next two exceed maximum value for float, so we
+ # expect +/-INF to be returned instead.
+ "3.40283e+38",
+ "-3.40283e+38",
+ "NAN",
+ "INF"]
+ })
+
+ self.assertAllClose([0, 3, -1, 1.12, 0xF, -10.5, 3.40282e+38,
+ float("INF"), float("-INF"), float("NAN"),
+ float("INF")], result)
+
+ with self.assertRaisesOpError(_ERROR_MESSAGE + "10foobar"):
+ output.eval(feed_dict={input_string: ["10foobar"]})
+
+ def testToInt32(self):
+ with self.test_session():
+ input_string = tf.placeholder(tf.string)
+ output = tf.string_to_number(
+ input_string,
+ out_type=tf.int32)
+
+ result = output.eval(feed_dict={
+ input_string: ["0", "3", "-1", " -10", "-2147483648", "2147483647"]
+ })
+
+ self.assertAllEqual([0, 3, -1, -10, -2147483648, 2147483647], result)
+
+ with self.assertRaisesOpError(_ERROR_MESSAGE + "2.9"):
+ output.eval(feed_dict={input_string: ["2.9"]})
+
+ # The next two exceed maximum value of int32.
+ for in_string in ["-2147483649", "2147483648"]:
+ with self.assertRaisesOpError(_ERROR_MESSAGE + in_string):
+ output.eval(feed_dict={input_string: [in_string]})
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/summary_image_op_test.py b/tensorflow/python/kernel_tests/summary_image_op_test.py
new file mode 100644
index 0000000000..dfdb2c8938
--- /dev/null
+++ b/tensorflow/python/kernel_tests/summary_image_op_test.py
@@ -0,0 +1,63 @@
+"""Tests for summary image op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import image_ops
+
+
+class SummaryImageOpTest(tf.test.TestCase):
+
+ def _AsSummary(self, s):
+ summ = tf.Summary()
+ summ.ParseFromString(s)
+ return summ
+
+ def testImageSummary(self):
+ np.random.seed(7)
+ with self.test_session() as sess:
+ for depth in 1, 3, 4:
+ shape = (4, 5, 7) + (depth,)
+ bad_color = [255, 0, 0, 255][:depth]
+ for positive in False, True:
+ # Build a mostly random image with one nan
+ const = np.random.randn(*shape)
+ const[0, 1, 2] = 0 # Make the nan entry not the max
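+          # Compute the expected normalization: positive-only data is scaled
+          # into [0, 255] with no offset; signed data is scaled by
+          # 127 / max(|x|) per image and shifted by 128.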
+ if positive:
+ const = 1 + np.maximum(const, 0)
+ scale = 255 / const.reshape(4, -1).max(axis=1)
+ offset = 0
+ else:
+ scale = 127 / np.abs(const.reshape(4, -1)).max(axis=1)
+ offset = 128
+ adjusted = np.floor(scale[:, None, None, None] * const + offset)
+ const[0, 1, 2, depth / 2] = np.nan
+
+ # Summarize
+ summ = tf.image_summary("img", const)
+ value = sess.run(summ)
+ self.assertEqual([], summ.get_shape())
+ image_summ = self._AsSummary(value)
+
+ # Decode the first image and check consistency
+ image = image_ops.decode_png(
+ image_summ.value[0].image.encoded_image_string).eval()
+ self.assertAllEqual(image[1, 2], bad_color)
+ image[1, 2] = adjusted[0, 1, 2]
+ self.assertAllClose(image, adjusted[0])
+
+ # Check the rest of the proto
+ # Only the first 3 images are returned.
+ for v in image_summ.value:
+ v.image.ClearField("encoded_image_string")
+ expected = '\n'.join("""
+ value {
+ tag: "img/image/%d"
+ image { height: %d width: %d colorspace: %d }
+ }""" % ((i,) + shape[1:]) for i in xrange(3))
+ self.assertProtoEquals(expected, image_summ)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
new file mode 100644
index 0000000000..13e5021ccc
--- /dev/null
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -0,0 +1,83 @@
+"""Tests for summary ops."""
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+class SummaryOpsTest(tf.test.TestCase):
+
+ def _AsSummary(self, s):
+ summ = tf.Summary()
+ summ.ParseFromString(s)
+ return summ
+
+ def testScalarSummary(self):
+ with self.test_session() as sess:
+ const = tf.constant([10.0, 20.0])
+ summ = tf.scalar_summary(["c1", "c2"], const, name="mysumm")
+ value = sess.run(summ)
+ self.assertEqual([], summ.get_shape())
+ self.assertProtoEquals("""
+ value { tag: "c1" simple_value: 10.0 }
+ value { tag: "c2" simple_value: 20.0 }
+ """, self._AsSummary(value))
+
+ def testScalarSummaryDefaultName(self):
+ with self.test_session() as sess:
+ const = tf.constant([10.0, 20.0])
+ summ = tf.scalar_summary(["c1", "c2"], const)
+ value = sess.run(summ)
+ self.assertEqual([], summ.get_shape())
+ self.assertProtoEquals("""
+ value { tag: "c1" simple_value: 10.0 }
+ value { tag: "c2" simple_value: 20.0 }
+ """, self._AsSummary(value))
+
+ def testMergeSummary(self):
+ with self.test_session() as sess:
+ const = tf.constant(10.0)
+ summ1 = tf.histogram_summary("h", const, name="histo")
+ summ2 = tf.scalar_summary("c", const, name="summ")
+ merge = tf.merge_summary([summ1, summ2])
+ value = sess.run(merge)
+ self.assertEqual([], merge.get_shape())
+ self.assertProtoEquals("""
+ value {
+ tag: "h"
+ histo {
+ min: 10.0
+ max: 10.0
+ num: 1.0
+ sum: 10.0
+ sum_squares: 100.0
+ bucket_limit: 9.93809490288
+ bucket_limit: 10.9319043932
+ bucket_limit: 1.79769313486e+308
+ bucket: 0.0
+ bucket: 1.0
+ bucket: 0.0
+ }
+ }
+ value { tag: "c" simple_value: 10.0 }
+ """, self._AsSummary(value))
+
+ def testMergeAllSummaries(self):
+ with tf.Graph().as_default():
+ const = tf.constant(10.0)
+ summ1 = tf.histogram_summary("h", const, name="histo")
+ summ2 = tf.scalar_summary("o", const, name="oops",
+ collections=["foo_key"])
+ summ3 = tf.scalar_summary("c", const, name="summ")
+ merge = tf.merge_all_summaries()
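+      # summ2 was added only to the "foo_key" collection, so the default
+      # merge should pick up summ1 and summ3 but not summ2.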
+ self.assertEqual("MergeSummary", merge.op.type)
+ self.assertEqual(2, len(merge.op.inputs))
+ self.assertEqual(summ1, merge.op.inputs[0])
+ self.assertEqual(summ3, merge.op.inputs[1])
+ merge = tf.merge_all_summaries("foo_key")
+ self.assertEqual("MergeSummary", merge.op.type)
+ self.assertEqual(1, len(merge.op.inputs))
+ self.assertEqual(summ2, merge.op.inputs[0])
+ self.assertTrue(tf.merge_all_summaries("bar_key") is None)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
new file mode 100644
index 0000000000..497dc9ac1e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -0,0 +1,52 @@
+"""Tests for TopK op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class TopKTest(tf.test.TestCase):
+
+ def _validateTopK(self, inputs, k, expected_values, expected_indices):
+ np_values = np.array(expected_values)
+ np_indices = np.array(expected_indices)
+ with self.test_session():
+ values_op, indices_op = tf.nn.top_k(inputs, k)
+ values = values_op.eval()
+ indices = indices_op.eval()
+ self.assertAllClose(np_values, values)
+ self.assertAllEqual(np_indices, indices)
+ self.assertShapeEqual(np_values, values_op)
+ self.assertShapeEqual(np_indices, indices_op)
+
+ def testTop1(self):
+ inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
+ self._validateTopK(inputs, 1,
+ [[0.4], [0.3]],
+ [[3], [1]])
+
+ def testTop2(self):
+ inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
+ self._validateTopK(inputs, 2,
+ [[0.4, 0.3], [0.3, 0.3]],
+ [[3, 1], [1, 2]])
+
+ def testTopAll(self):
+ inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
+ self._validateTopK(inputs, 4,
+ [[0.4, 0.3, 0.2, 0.1], [0.3, 0.3, 0.2, 0.1]],
+ [[3, 1, 2, 0], [1, 2, 3, 0]])
+
+ def testKNegative(self):
+ inputs = [[0.1, 0.2], [0.3, 0.4]]
+ with self.assertRaisesRegexp(ValueError, "less than minimum 1"):
+ tf.nn.top_k(inputs, -1)
+
+ def testKTooLarge(self):
+ inputs = [[0.1, 0.2], [0.3, 0.4]]
+ with self.assertRaisesRegexp(ValueError, "input must have at least k"):
+ tf.nn.top_k(inputs, 4)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
new file mode 100644
index 0000000000..2786eaf37b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -0,0 +1,176 @@
+"""Functional tests for Transpose op."""
+import itertools
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests.gradient_checker import ComputeGradient
+
+
+class TransposeTest(tf.test.TestCase):
+
+ def _np_transpose(self, x, perm):
+ ret = np.copy(x)
+ ret = ret.transpose(perm)
+ return ret
+
+ def _compareCpu(self, x, p):
+ np_ans = self._np_transpose(x, p)
+ with self.test_session(use_gpu=False):
+ inx = tf.convert_to_tensor(x)
+ y = tf.transpose(inx, p)
+ tf_ans = y.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, y)
+
+ jacob_t = None
+ # Gradient check on CPU.
+ xs = list(np.shape(x))
+ ys = list(np.shape(tf_ans))
+ if x.dtype == np.float32:
+ jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
+ self.assertAllClose(jacob_t, jacob_n, 1e-3, 1e-3)
+ elif x.dtype == np.float64:
+ jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
+ self.assertAllClose(jacob_t, jacob_n, 1e-6, 1e-6)
+
+ return tf_ans, jacob_t
+
+ def _compareGpu(self, x, p):
+ np_ans = self._np_transpose(x, p)
+ with self.test_session(use_gpu=True):
+ inx = tf.convert_to_tensor(x)
+ y = tf.transpose(inx, p)
+ tf_ans = y.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, y)
+
+ jacob_t = None
+ # Gradient check on GPU.
+ xs = list(np.shape(x))
+ ys = list(np.shape(tf_ans))
+ if x.dtype == np.float32:
+ jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
+ self.assertAllClose(jacob_t, jacob_n, 1e-3, 1e-3)
+ elif x.dtype == np.float64:
+ jacob_t, jacob_n = ComputeGradient(inx, xs, y, ys, x, 1e-2)
+ self.assertAllClose(jacob_t, jacob_n, 1e-6, 1e-6)
+
+ return tf_ans, jacob_t
+
+ def _compare(self, x, use_gpu=False):
+ n = np.ndim(x)
+ # generate all permutations of [0, 1, ... n-1] in random order.
+ all_perm = np.random.permutation(
+ [p for p in itertools.permutations(range(n))]).astype(np.int32)
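+    # Only the first two of the shuffled permutations are exercised below.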
+ for p in all_perm[0:2]:
+ self._compareCpu(x, p)
+ if use_gpu:
+ self._compareGpu(x, p)
+
+ def _compare_cpu_gpu(self, x):
+ n = np.ndim(x)
+    # generate all permutations of [0, 1, ... n-1] in random order.
+ all_perm = np.random.permutation(
+ [p for p in itertools.permutations(range(n))]).astype(np.int32)
+ for p in all_perm[0:2]:
+ tf_a_cpu, tf_g_cpu = self._compareCpu(x, p)
+ tf_a_gpu, tf_g_gpu = self._compareGpu(x, p)
+ assert tf_g_cpu is not None
+ assert tf_g_gpu is not None
+ if x.dtype == np.float32:
+ self.assertAllClose(tf_a_cpu, tf_a_gpu, 1e-3, 1e-3)
+ self.assertAllClose(tf_g_cpu, tf_g_gpu, 1e-3, 1e-3)
+ elif x.dtype == np.float64:
+ self.assertAllClose(tf_a_cpu, tf_a_gpu, 1e-6, 1e-6)
+ self.assertAllClose(tf_g_cpu, tf_g_gpu, 1e-6, 1e-6)
+
+ def _testCpu(self, x):
+ self._compare(x, use_gpu=False)
+
+ def test1D(self):
+ self._compareCpu(np.arange(0., 2), [0])
+
+ def testNop(self):
+ self._compareCpu(np.arange(0, 6).reshape([3, 2]).astype(np.float32), [0, 1])
+
+ def testSimple(self):
+ self._compareCpu(np.arange(0, 8).reshape([2, 4]).astype(np.float32),
+ np.array([1, 0]).astype(np.int32))
+
+ def testFloat(self):
+ self._compare_cpu_gpu(np.arange(0, 21).reshape([3, 7]).astype(np.float32))
+ self._compare_cpu_gpu(
+ np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float32))
+
+ def testDouble(self):
+ self._compare_cpu_gpu(np.arange(0, 21).reshape([3, 7]).astype(np.float64))
+ self._compare_cpu_gpu(
+ np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float64))
+
+ def testSComplex(self):
+ self._testCpu(np.complex(1, 2) * np.arange(0, 21).reshape(
+ [3, 7]).astype(np.complex64))
+ self._testCpu(np.complex(1, 2) * np.arange(0, 210).reshape(
+ [2, 3, 5, 7]).astype(np.complex64))
+
+ def testInt8(self):
+ self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
+ self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))
+
+ def testInt16(self):
+ self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int16))
+ self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int16))
+
+ def testInt32(self):
+ self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int32))
+ self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int32))
+
+ def testInt64(self):
+ self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int64))
+ self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int64))
+
+ def testTranspose2DAuto(self):
+ x_np = [[1, 2, 3], [4, 5, 6]]
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = tf.transpose(x_np).eval()
+ self.assertAllEqual(x_tf, [[1, 4], [2, 5], [3, 6]])
+
+ def testTransposeShapes(self):
+ self.assertEqual([], tf.transpose(
+ tf.placeholder(tf.int32, shape=[])).get_shape().dims)
+ self.assertEqual([100], tf.transpose(
+ tf.placeholder(tf.int32, shape=[100])).get_shape().dims)
+ self.assertEqual([37, 100], tf.transpose(
+ tf.placeholder(tf.int32, shape=[100, 37])).get_shape().dims)
+ self.assertEqual([100, 37], tf.transpose(
+ tf.placeholder(tf.int32, shape=[100, 37]), [0, 1]).get_shape().dims)
+ self.assertEqual([15, 37, 100], tf.transpose(
+ tf.placeholder(tf.int32, shape=[100, 37, 15])).get_shape().dims)
+ self.assertEqual([15, 100, 37], tf.transpose(
+ tf.placeholder(tf.int32,
+ shape=[100, 37, 15]), [2, 0, 1]).get_shape().dims)
+ self.assertEqual(tf.TensorShape(None), tf.transpose(
+ tf.placeholder(tf.int32)).get_shape())
+
+ def _testError(self, x, p, err):
+ with self.test_session():
+ with self.assertRaisesOpError(err):
+ tf.transpose(x, p).eval()
+
+ def testError(self):
+ with self.assertRaises(ValueError):
+ tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [[0, 1], [2, 3]])
+ self._testError(np.arange(0., 2 ** 10).reshape([2] * 10),
+ range(10),
+ "not implemented")
+ with self.assertRaises(IndexError):
+ tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3])
+ self._testError(np.arange(0., 30).reshape([2, 3, 5]),
+ [0, 1, 1],
+ "2 is missing")
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
new file mode 100644
index 0000000000..4d6543a206
--- /dev/null
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -0,0 +1,22 @@
+"""Tests for tensorflow.kernels.unique_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class UniqueTest(tf.test.TestCase):
+
+ def testInt32(self):
+ x = list(np.random.randint(2, high=10, size=7000))
+ with self.test_session() as sess:
+ y, idx = tf.unique(x)
+ tf_y, tf_idx = sess.run([y, idx])
+
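+    # idx maps each element of x to its position in the unique-values tensor y.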
+ self.assertEqual(len(x), len(tf_idx))
+ self.assertEqual(len(tf_y), len(np.unique(x)))
+ for i in range(len(x)):
+ self.assertEqual(x[i], tf_y[tf_idx[i]])
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py
new file mode 100644
index 0000000000..4929af035f
--- /dev/null
+++ b/tensorflow/python/kernel_tests/unpack_op_test.py
@@ -0,0 +1,56 @@
+"""Functional tests for Unpack Op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker
+
+
+class UnpackOpTest(tf.test.TestCase):
+
+ def testSimple(self):
+ np.random.seed(7)
+ for use_gpu in False, True:
+ with self.test_session(use_gpu=use_gpu):
+ for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ # Convert data to a single tensorflow tensor
+ x = tf.constant(data)
+ # Unpack into a list of tensors
+ cs = tf.unpack(x, num=shape[0])
+ self.assertEqual(type(cs), list)
+ self.assertEqual(len(cs), shape[0])
+ cs = [c.eval() for c in cs]
+ self.assertAllEqual(cs, data)
+
+ def testGradients(self):
+ for use_gpu in False, True:
+ for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
+ data = np.random.randn(*shape)
+ shapes = [shape[1:]] * shape[0]
+ for i in xrange(shape[0]):
+ with self.test_session(use_gpu=use_gpu):
+ x = tf.constant(data)
+ cs = tf.unpack(x, num=shape[0])
+ err = gradient_checker.ComputeGradientError(x, shape, cs[i],
+ shapes[i])
+ self.assertLess(err, 1e-6)
+
+ def testInferNum(self):
+ with self.test_session():
+ for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
+ x = tf.placeholder(np.float32, shape=shape)
+ cs = tf.unpack(x)
+ self.assertEqual(type(cs), list)
+ self.assertEqual(len(cs), shape[0])
+
+ def testCannotInferNum(self):
+ x = tf.placeholder(np.float32)
+ with self.assertRaisesRegexp(
+ ValueError, r'Cannot infer num from shape TensorShape\(None\)'):
+ tf.unpack(x)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/variable_ops_test.py b/tensorflow/python/kernel_tests/variable_ops_test.py
new file mode 100644
index 0000000000..aaa4237260
--- /dev/null
+++ b/tensorflow/python/kernel_tests/variable_ops_test.py
@@ -0,0 +1,225 @@
+"""Tests for tensorflow.ops.tf.variable_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_state_ops
+from tensorflow.python.ops import state_ops
+
+
+_NP_TO_TF = {
+ np.float32: tf.float32,
+ np.float64: tf.float64,
+ np.int32: tf.int32,
+ np.int64: tf.int64,
+}
+
+
+class VariableOpTest(tf.test.TestCase):
+
+ def _initFetch(self, x, tftype, use_gpu=None):
+ with self.test_session(use_gpu=use_gpu):
+ p = state_ops.variable_op(x.shape, tftype)
+ op = tf.assign(p, x)
+ op.op.run()
+ return p.eval()
+
+ def _testTypes(self, vals):
+ for dtype in [np.float32, np.float64, np.int32, np.int64]:
+ self.setUp()
+ x = vals.astype(dtype)
+ tftype = _NP_TO_TF[dtype]
+ self.assertAllEqual(x, self._initFetch(x, tftype, use_gpu=False))
+      # NOTE(mdevin): the GPU test should pass for all types, whether or not
+      # the Variable op has an implementation for that type on GPU, as we
+      # expect Variable and Assign to have GPU implementations for the
+      # matching tf types.
+ self.assertAllEqual(x, self._initFetch(x, tftype, use_gpu=True))
+
+ def testBasic(self):
+ self._testTypes(np.arange(0, 20).reshape([4, 5]))
+
+ def testset_shape(self):
+ p = state_ops.variable_op([1, 2], tf.float32)
+ self.assertEqual([1, 2], p.get_shape())
+ p = state_ops.variable_op([1, 2], tf.float32, set_shape=False)
+ self.assertEqual(tensor_shape.unknown_shape(), p.get_shape())
+
+ def testAssign(self):
+ value = np.array([[42.0, 43.0]])
+ var = state_ops.variable_op(value.shape, tf.float32)
+ self.assertShapeEqual(value, var)
+ assigned = tf.assign(var, value)
+ self.assertShapeEqual(value, assigned)
+
+ def testAssignNoValidateShape(self):
+ value = np.array([[42.0, 43.0]])
+ var = state_ops.variable_op(value.shape, tf.float32)
+ self.assertShapeEqual(value, var)
+ assigned = tf.assign(var, value, validate_shape=False)
+ self.assertShapeEqual(value, assigned)
+
+ def testAssignNoVarShape(self):
+ value = np.array([[42.0, 43.0]])
+ var = state_ops.variable_op(value.shape, tf.float32, set_shape=False)
+ self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
+ assigned = tf.assign(var, value)
+ self.assertShapeEqual(value, assigned)
+
+ def testAssignNoVarShapeNoValidateShape(self):
+ value = np.array([[42.0, 43.0]])
+ var = state_ops.variable_op(value.shape, tf.float32, set_shape=False)
+ self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
+ assigned = tf.assign(var, value, validate_shape=False)
+ self.assertShapeEqual(value, assigned)
+
+ def _NewShapelessTensor(self):
+ tensor = tf.placeholder(tf.float32)
+ self.assertEqual(tensor_shape.unknown_shape(), tensor.get_shape())
+ return tensor
+
+ def testAssignNoValueShape(self):
+ value = self._NewShapelessTensor()
+ shape = [1, 2]
+ var = state_ops.variable_op(shape, tf.float32)
+ assigned = tf.assign(var, value)
+ self.assertEqual(shape, var.get_shape())
+ self.assertEqual(shape, assigned.get_shape())
+
+ def testAssignNoValueShapeNoValidateShape(self):
+ value = self._NewShapelessTensor()
+ shape = [1, 2]
+ var = state_ops.variable_op(shape, tf.float32)
+ self.assertEqual(shape, var.get_shape())
+ assigned = tf.assign(var, value, validate_shape=False)
+ self.assertEqual(tensor_shape.unknown_shape(), assigned.get_shape())
+
+ def testAssignNoShape(self):
+ with self.test_session():
+ value = self._NewShapelessTensor()
+ var = state_ops.variable_op([1, 2], tf.float32, set_shape=False)
+ self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
+ self.assertEqual(tensor_shape.unknown_shape(),
+ tf.assign(var, value).get_shape())
+
+ def testAssignNoShapeNoValidateShape(self):
+ with self.test_session():
+ value = self._NewShapelessTensor()
+ var = state_ops.variable_op([1, 2], tf.float32, set_shape=False)
+ self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
+ self.assertEqual(tensor_shape.unknown_shape(),
+ tf.assign(var, value, validate_shape=False).get_shape())
+
+ def testAssignUpdate(self):
+ var = state_ops.variable_op([1, 2], tf.float32)
+ added = tf.assign_add(var, [[2.0, 3.0]])
+ self.assertEqual([1, 2], added.get_shape())
+ subbed = tf.assign_sub(var, [[12.0, 13.0]])
+ self.assertEqual([1, 2], subbed.get_shape())
+
+ def testAssignUpdateNoVarShape(self):
+ var = state_ops.variable_op([1, 2], tf.float32, set_shape=False)
+ added = tf.assign_add(var, [[2.0, 3.0]])
+ self.assertEqual([1, 2], added.get_shape())
+ subbed = tf.assign_sub(var, [[12.0, 13.0]])
+ self.assertEqual([1, 2], subbed.get_shape())
+
+ def testAssignUpdateNoValueShape(self):
+ var = state_ops.variable_op([1, 2], tf.float32)
+ added = tf.assign_add(var, self._NewShapelessTensor())
+ self.assertEqual([1, 2], added.get_shape())
+ subbed = tf.assign_sub(var, self._NewShapelessTensor())
+ self.assertEqual([1, 2], subbed.get_shape())
+
+ def testAssignUpdateNoShape(self):
+ var = state_ops.variable_op([1, 2], tf.float32, set_shape=False)
+ added = tf.assign_add(var, self._NewShapelessTensor())
+ self.assertEqual(tensor_shape.unknown_shape(), added.get_shape())
+ subbed = tf.assign_sub(var, self._NewShapelessTensor())
+ self.assertEqual(tensor_shape.unknown_shape(), subbed.get_shape())
+
+ def testTemporaryVariable(self):
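+    # Create a temporary variable, update it twice, and read back its final
+    # value via _destroy_temporary_variable.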
+ with self.test_session(use_gpu=True):
+ var = gen_state_ops._temporary_variable(
+ [1, 2],
+ tf.float32,
+ var_name="foo")
+ var = tf.assign(var, [[4.0, 5.0]])
+ var = tf.assign_add(var, [[6.0, 7.0]])
+ final = gen_state_ops._destroy_temporary_variable(var, var_name="foo")
+ self.assertAllClose([[10.0, 12.0]], final.eval())
+
+ def testDestroyNonexistentTemporaryVariable(self):
+ with self.test_session(use_gpu=True):
+ var = gen_state_ops._temporary_variable([1, 2], tf.float32)
+ final = gen_state_ops._destroy_temporary_variable(var, var_name="bad")
+ with self.assertRaises(errors.NotFoundError):
+ final.eval()
+
+ def testDuplicateTemporaryVariable(self):
+ with self.test_session(use_gpu=True):
+ var1 = gen_state_ops._temporary_variable(
+ [1, 2],
+ tf.float32,
+ var_name="dup")
+ var1 = tf.assign(var1, [[1.0, 2.0]])
+ var2 = gen_state_ops._temporary_variable(
+ [1, 2],
+ tf.float32,
+ var_name="dup")
+ var2 = tf.assign(var2, [[3.0, 4.0]])
+ final = var1 + var2
+ with self.assertRaises(errors.AlreadyExistsError):
+ final.eval()
+
+ def testDestroyTemporaryVariableTwice(self):
+ with self.test_session(use_gpu=True):
+ var = gen_state_ops._temporary_variable([1, 2], tf.float32)
+ val1 = gen_state_ops._destroy_temporary_variable(var, var_name="dup")
+ val2 = gen_state_ops._destroy_temporary_variable(var, var_name="dup")
+ final = val1 + val2
+ with self.assertRaises(errors.NotFoundError):
+ final.eval()
+
+ def testTemporaryVariableNoLeak(self):
+ with self.test_session(use_gpu=True):
+ var = gen_state_ops._temporary_variable(
+ [1, 2],
+ tf.float32,
+ var_name="bar")
+ final = tf.identity(var)
+ final.eval()
+
+ def testTwoTemporaryVariablesNoLeaks(self):
+ with self.test_session(use_gpu=True):
+ var1 = gen_state_ops._temporary_variable(
+ [1, 2],
+ tf.float32,
+ var_name="var1")
+ var2 = gen_state_ops._temporary_variable(
+ [1, 2],
+ tf.float32,
+ var_name="var2")
+ final = var1 + var2
+ final.eval()
+
+ def testAssignDependencyAcrossDevices(self):
+ with self.test_session(use_gpu=True):
+ # The variable and an op to increment it are on the GPU.
+ var = state_ops.variable_op([1], tf.float32)
+ tf.assign(var, [1.0]).eval()
+ increment = tf.assign_add(var, [1.0])
+ with tf.control_dependencies([increment]):
+ with tf.device("/cpu:0"):
+ # This mul op is pinned to the CPU, but reads the variable from the
+ # GPU. The test ensures that the dependency on 'increment' is still
+ # honored, i.e., the Send and Recv from GPU to CPU should take place
+ # only after the increment.
+ result = tf.mul(var, var)
+ self.assertAllClose([4.0], result.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
new file mode 100644
index 0000000000..bb538198ea
--- /dev/null
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -0,0 +1,160 @@
+"""Tests for variable store."""
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.python.ops import variable_scope
+
+
+class VariableStoreTest(tf.test.TestCase):
+
+ def testGetVar(self):
+ vs = variable_scope._get_default_variable_store()
+ v = vs.get_variable("v", [1])
+ v1 = vs.get_variable("v", [1])
+ assert v == v1
+
+ def testNameExists(self):
+ vs = variable_scope._get_default_variable_store()
+ # No check by default, so we can both create and get existing names.
+ v = vs.get_variable("v", [1])
+ v1 = vs.get_variable("v", [1])
+ assert v == v1
+ # When reuse is False, we fail when variables are already there.
+ vs.get_variable("w", [1], reuse=False) # That's ok.
+ with self.assertRaises(ValueError):
+ vs.get_variable("v", [1], reuse=False) # That fails.
+ # When reuse is True, we fail when variables are new.
+ vs.get_variable("v", [1], reuse=True) # That's ok.
+ with self.assertRaises(ValueError):
+ vs.get_variable("u", [1], reuse=True) # That fails.
+
+ def testNamelessStore(self):
+ vs = variable_scope._get_default_variable_store()
+ vs.get_variable("v1", [2])
+ vs.get_variable("v2", [2])
+ expected_names = ["%s:0" % name for name in ["v1", "v2"]]
+ self.assertEqual(set(expected_names),
+ set([v.name for v in vs._vars.values()]))
+
+  def testVarScopeInitializer(self):
+ with self.test_session() as sess:
+ init = tf.constant_initializer(0.3)
+ with variable_scope.variable_scope("tower") as tower:
+ with variable_scope.variable_scope("foo", initializer=init):
+ v = variable_scope.get_variable("v", [])
+ sess.run(tf.initialize_variables([v]))
+ self.assertAllClose(v.eval(), 0.3)
+ with variable_scope.variable_scope(tower, initializer=init):
+ w = variable_scope.get_variable("w", [])
+ sess.run(tf.initialize_variables([w]))
+ self.assertAllClose(w.eval(), 0.3)
+
+ def testGetVariableScope(self):
+ # Test the get_variable_scope() function and setting properties of result.
+ with self.test_session() as sess:
+ init = tf.constant_initializer(0.3)
+ with variable_scope.variable_scope("foo"):
+ new_init1 = variable_scope.get_variable_scope().initializer
+ self.assertEqual(new_init1, None)
+ # Check that we can set initializer like this.
+ variable_scope.get_variable_scope().set_initializer(init)
+ v = variable_scope.get_variable("v", [])
+ sess.run(tf.initialize_variables([v]))
+ self.assertAllClose(v.eval(), 0.3)
+ # Check that we can set reuse.
+ variable_scope.get_variable_scope().reuse_variables()
+ with self.assertRaises(ValueError): # Fail, w does not exist yet.
+ variable_scope.get_variable("w", [1])
+ # Check that the set initializer goes away.
+ new_init = variable_scope.get_variable_scope().initializer
+ self.assertEqual(new_init, None)
+
+ def testVarScope(self):
+ with self.test_session():
+ with variable_scope.variable_scope("tower") as tower:
+ self.assertEqual(tower.name, "tower")
+ with tf.name_scope("scope") as sc:
+ self.assertEqual(sc, "tower/scope/")
+
+ with variable_scope.variable_scope("foo"):
+ with variable_scope.variable_scope("bar") as bar:
+ self.assertEqual(bar.name, "foo/bar")
+ with tf.name_scope("scope") as sc:
+ self.assertEqual(sc, "foo/bar/scope/")
+
+ with variable_scope.variable_scope("foo"):
+ with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
+ self.assertEqual(tower_shared.name, "tower")
+ with tf.name_scope("scope") as sc:
+ self.assertEqual(sc, "foo_1/scope/")
+
+ def testVarScopeNameScope(self):
+ with self.test_session():
+ with tf.name_scope("scope1"):
+ with variable_scope.variable_scope("tower") as tower:
+ with tf.name_scope("scope2") as sc2:
+ self.assertEqual(sc2, "scope1/tower/scope2/")
+ with variable_scope.variable_scope("tower"): # Re-enter adds suffix.
+ with tf.name_scope("scope2") as sc2:
+ self.assertEqual(sc2, "scope1/tower_1/scope2/")
+
+ with tf.name_scope("scope3"):
+ with variable_scope.variable_scope("tower"):
+ with tf.name_scope("scope2") as sc2:
+ self.assertEqual(sc2, "scope3/tower/scope2/")
+ with variable_scope.variable_scope(tower):
+ with tf.name_scope("scope2") as sc2:
+ self.assertEqual(sc2, "scope3/scope2/")
+
+ def testVarScopeGetVar(self):
+ with self.test_session():
+ with variable_scope.variable_scope("root"):
+ with variable_scope.variable_scope("towerA") as tower_a:
+ va = variable_scope.get_variable("v", [1])
+ self.assertEqual(va.name, "root/towerA/v:0")
+
+ with variable_scope.variable_scope(tower_a, reuse=True):
+ va2 = variable_scope.get_variable("v", [1])
+ self.assertEqual(va2, va)
+
+ with variable_scope.variable_scope("towerB"):
+ vb = variable_scope.get_variable("v", [1])
+ self.assertEqual(vb.name, "root/towerB/v:0")
+
+ with self.assertRaises(ValueError) as exc:
+ with variable_scope.variable_scope("towerA"):
+ va2 = variable_scope.get_variable("v", [1])
+ self.assertEqual(exc.exception.message[:12], "Over-sharing")
+
+ with variable_scope.variable_scope("towerA", reuse=True):
+ va2 = variable_scope.get_variable("v", [1])
+ self.assertEqual(va2, va)
+
+ with variable_scope.variable_scope("foo"):
+ with variable_scope.variable_scope("bar"):
+ v = variable_scope.get_variable("v", [1])
+ self.assertEqual(v.name, "root/foo/bar/v:0")
+ with variable_scope.variable_scope(tower_a, reuse=True):
+ va3 = variable_scope.get_variable("v", [1])
+ self.assertEqual(va, va3)
+
+ with self.assertRaises(ValueError) as exc:
+ with variable_scope.variable_scope(tower_a, reuse=True):
+ with variable_scope.variable_scope("baz"):
+ variable_scope.get_variable("v", [1])
+ self.assertEqual(exc.exception.message[:13], "Under-sharing")
+
+ with self.assertRaises(ValueError) as exc:
+ with variable_scope.variable_scope(tower_a, reuse=True):
+ variable_scope.get_variable("v", [2]) # Different shape.
+ self.assertEqual("shape" in exc.exception.message, True)
+
+ with self.assertRaises(ValueError) as exc:
+ with variable_scope.variable_scope(tower_a, reuse=True):
+ variable_scope.get_variable("v", [1], dtype=tf.int32)
+ self.assertEqual("dtype" in exc.exception.message, True)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
new file mode 100644
index 0000000000..f2a7ea0af8
--- /dev/null
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -0,0 +1,242 @@
+"""Tests for tf.py."""
+import operator
+
+import tensorflow.python.platform
+
+import numpy as np
+
+import tensorflow as tf
+from tensorflow.python.ops import random_ops
+
+
+class VariablesTestCase(tf.test.TestCase):
+
+ def testInitialization(self):
+ with self.test_session():
+ var0 = tf.Variable(0.0)
+ self.assertEqual("Variable:0", var0.name)
+ self.assertEqual([], var0.get_shape())
+ self.assertEqual([], var0.get_shape())
+
+ var1 = tf.Variable(1.1)
+ self.assertEqual("Variable_1:0", var1.name)
+ self.assertEqual([], var1.get_shape())
+ self.assertEqual([], var1.get_shape())
+
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ var0.eval()
+
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ var1.eval()
+
+ tf.initialize_all_variables().run()
+
+ self.assertAllClose(0.0, var0.eval())
+ self.assertAllClose(1.1, var1.eval())
+
+ def testInitializationOrder(self):
+ with self.test_session():
+ rnd = tf.Variable(random_ops.random_uniform([3, 6]), name="rnd")
+ self.assertEqual("rnd:0", rnd.name)
+ self.assertEqual([3, 6], rnd.get_shape())
+ self.assertEqual([3, 6], rnd.get_shape())
+
+ dep = tf.Variable(rnd.initialized_value(), name="dep")
+ self.assertEqual("dep:0", dep.name)
+ self.assertEqual([3, 6], dep.get_shape())
+ self.assertEqual([3, 6], dep.get_shape())
+
+ # Currently have to set the shape manually for Add.
+ added_val = rnd.initialized_value() + dep.initialized_value() + 2.0
+ added_val.set_shape(rnd.get_shape())
+
+ depdep = tf.Variable(added_val, name="depdep")
+ self.assertEqual("depdep:0", depdep.name)
+ self.assertEqual([3, 6], depdep.get_shape())
+ self.assertEqual([3, 6], depdep.get_shape())
+
+ tf.initialize_all_variables().run()
+
+ self.assertAllClose(rnd.eval(), dep.eval())
+ self.assertAllClose(rnd.eval() + dep.eval() + 2.0,
+ depdep.eval())
+
+ def testAssignments(self):
+ with self.test_session():
+ var = tf.Variable(0.0)
+ plus_one = var.assign_add(1.0)
+ minus_one = var.assign_sub(2.0)
+ four = var.assign(4.0)
+ tf.initialize_all_variables().run()
+ self.assertAllClose(0.0, var.eval())
+
+ self.assertAllClose(1.0, plus_one.eval())
+ self.assertAllClose(1.0, var.eval())
+
+ self.assertAllClose(-1.0, minus_one.eval())
+ self.assertAllClose(-1.0, var.eval())
+
+ self.assertAllClose(4.0, four.eval())
+ self.assertAllClose(4.0, var.eval())
+
+ def _countUpToTest(self, dtype):
+ with self.test_session():
+ zero = tf.constant(0, dtype=dtype)
+ var = tf.Variable(zero)
+ count_up_to = var.count_up_to(3)
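+      # count_up_to returns the value observed before the increment; once the
+      # limit of 3 is reached, further evaluations fail.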
+
+ tf.initialize_all_variables().run()
+ self.assertEqual(0, var.eval())
+
+ self.assertEqual(0, count_up_to.eval())
+ self.assertEqual(1, var.eval())
+
+ self.assertEqual(1, count_up_to.eval())
+ self.assertEqual(2, var.eval())
+
+ self.assertEqual(2, count_up_to.eval())
+ self.assertEqual(3, var.eval())
+
+ with self.assertRaisesOpError("Reached limit of 3"):
+ count_up_to.eval()
+ self.assertEqual(3, var.eval())
+
+ with self.assertRaisesOpError("Reached limit of 3"):
+ count_up_to.eval()
+ self.assertEqual(3, var.eval())
+
+ def testCountUpToInt32(self):
+ self._countUpToTest(tf.int32)
+
+ def testCountUpToInt64(self):
+ self._countUpToTest(tf.int64)
+
+ def testUseVariableAsTensor(self):
+ with self.test_session():
+ var_x = tf.Variable(2.0)
+ var_y = tf.Variable(3.0)
+ tf.initialize_all_variables().run()
+ self.assertAllClose(2.0, var_x.eval())
+ self.assertAllClose(3.0, var_y.eval())
+ self.assertAllClose(5.0, tf.add(var_x, var_y).eval())
+
+ def testCollections(self):
+ with self.test_session():
+ var_x = tf.Variable(2.0)
+ var_y = tf.Variable(2.0, trainable=False)
+ var_z = tf.Variable(2.0, trainable=True)
+ var_t = tf.Variable(
+ 2.0, trainable=True,
+ collections=[tf.GraphKeys.TRAINABLE_VARIABLES,
+ tf.GraphKeys.VARIABLES])
+ self.assertEqual([var_x, var_y, var_z, var_t], tf.all_variables())
+ self.assertEqual([var_x, var_z, var_t], tf.trainable_variables())
+
+ def testOperators(self):
+ with self.test_session():
+ var_f = tf.Variable([2.0])
+ add = var_f + 0.0
+ radd = 1.0 + var_f
+ sub = var_f - 1.0
+ rsub = 1.0 - var_f
+ mul = var_f * 10.0
+ rmul = 10.0 * var_f
+ div = var_f / 10.0
+ rdiv = 10.0 / var_f
+ lt = var_f < 3.0
+ rlt = 3.0 < var_f
+ le = var_f <= 2.0
+ rle = 2.0 <= var_f
+ gt = var_f > 3.0
+ rgt = 3.0 > var_f
+ ge = var_f >= 2.0
+ rge = 2.0 >= var_f
+ neg = -var_f
+ abs_v = abs(var_f)
+
+ var_i = tf.Variable([20])
+ mod = var_i % 7
+ rmod = 103 % var_i
+
+ var_b = tf.Variable([True, False])
+ and_v = operator.and_(var_b, [True, True])
+ or_v = operator.or_(var_b, [False, True])
+ xor_v = operator.xor(var_b, [False, False])
+ invert_v = ~var_b
+
+ rnd = np.random.rand(4, 4).astype("f")
+ var_t = tf.Variable(rnd)
+ slice_v = var_t[2, 0:0]
+
+ tf.initialize_all_variables().run()
+ self.assertAllClose([2.0], add.eval())
+ self.assertAllClose([3.0], radd.eval())
+ self.assertAllClose([1.0], sub.eval())
+ self.assertAllClose([-1.0], rsub.eval())
+ self.assertAllClose([20.0], mul.eval())
+ self.assertAllClose([20.0], rmul.eval())
+ self.assertAllClose([0.2], div.eval())
+ self.assertAllClose([5.0], rdiv.eval())
+ self.assertAllClose([-2.0], neg.eval())
+ self.assertAllClose([2.0], abs_v.eval())
+ self.assertAllClose([True], lt.eval())
+ self.assertAllClose([False], rlt.eval())
+ self.assertAllClose([True], le.eval())
+ self.assertAllClose([True], rle.eval())
+ self.assertAllClose([False], gt.eval())
+ self.assertAllClose([True], rgt.eval())
+ self.assertAllClose([True], ge.eval())
+ self.assertAllClose([True], rge.eval())
+
+ self.assertAllClose([6], mod.eval())
+ self.assertAllClose([3], rmod.eval())
+
+ self.assertAllClose([True, False], and_v.eval())
+ self.assertAllClose([True, True], or_v.eval())
+ self.assertAllClose([True, False], xor_v.eval())
+ self.assertAllClose([False, True], invert_v.eval())
+
+ self.assertAllClose(rnd[2, 0:0], slice_v.eval())
+
+ def testSession(self):
+ with self.test_session() as sess:
+ var = tf.Variable([1, 12])
+ tf.initialize_all_variables().run()
+ self.assertAllClose([1, 12], sess.run(var))
+
+
+class IsInitializedTest(tf.test.TestCase):
+
+ def testNoVars(self):
+ with tf.Graph().as_default():
+ self.assertEqual(None, tf.assert_variables_initialized())
+
+ def testVariables(self):
+ with tf.Graph().as_default(), self.test_session() as sess:
+ v = tf.Variable([1, 2])
+ w = tf.Variable([3, 4])
+ _ = v, w
+ inited = tf.assert_variables_initialized()
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ sess.run(inited)
+ tf.initialize_all_variables().run()
+ sess.run(inited)
+
+ def testVariableList(self):
+ with tf.Graph().as_default(), self.test_session() as sess:
+ v = tf.Variable([1, 2])
+ w = tf.Variable([3, 4])
+ inited = tf.assert_variables_initialized([v])
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ inited.op.run()
+ sess.run(w.initializer)
+ with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ inited.op.run()
+ v.initializer.run()
+ inited.op.run()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
new file mode 100644
index 0000000000..263f98f622
--- /dev/null
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -0,0 +1,43 @@
+"""Tests for tensorflow.ops.reverse_sequence_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class WhereOpTest(tf.test.TestCase):
+
+ def _testWhere(self, x, truth, expected_err_re=None):
+ with self.test_session():
+ ans = tf.where(x)
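+      # tf.where returns the coordinates of the True entries, one row per
+      # entry, in row-major order.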
+ self.assertEqual([None, x.ndim], ans.get_shape().as_list())
+ if expected_err_re is None:
+ tf_ans = ans.eval()
+ self.assertAllClose(tf_ans, truth, atol=1e-10)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ ans.eval()
+
+ def testBasicMat(self):
+ x = np.asarray([[True, False], [True, False]])
+
+ # Ensure RowMajor mode
+ truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
+
+ self._testWhere(x, truth)
+
+ def testBasic3Tensor(self):
+ x = np.asarray(
+ [[[True, False], [True, False]], [[False, True], [False, True]],
+ [[False, False], [False, True]]])
+
+ # Ensure RowMajor mode
+ truth = np.asarray(
+ [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]],
+ dtype=np.int64)
+
+ self._testWhere(x, truth)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
new file mode 100644
index 0000000000..4e44472c0d
--- /dev/null
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -0,0 +1,110 @@
+"""Tests for SoftmaxCrossEntropyWithLogits op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class XentTest(tf.test.TestCase):
+
+ def _npXent(self, features, labels):
+ batch_dim = 0
+ class_dim = 1
+ batch_size = features.shape[batch_dim]
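+    # Numerically stable softmax: subtract the per-row max before
+    # exponentiating.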
+ e = np.exp(features -
+ np.reshape(np.amax(features, axis=class_dim), [batch_size, 1]))
+ probs = e / np.reshape(np.sum(e, axis=class_dim), [batch_size, 1])
+ bp = (probs - labels)
+ l = -np.sum(labels * np.log(probs + 1.0e-20), axis=1)
+ return l, bp
+
+ def _testXent(self, np_features, np_labels, use_gpu=False):
+ np_loss, np_backprop = self._npXent(np_features, np_labels)
+ with self.test_session(use_gpu=use_gpu) as sess:
+ loss = tf.nn.softmax_cross_entropy_with_logits(np_features, np_labels)
+ backprop = loss.op.outputs[1]
+ tf_loss, tf_backprop = sess.run([loss, backprop])
+ self.assertAllClose(np_loss, tf_loss)
+ self.assertAllClose(np_backprop, tf_backprop)
+
+ def _testAll(self, features, labels):
+ self._testXent(features, labels, use_gpu=False)
+ self._testXent(features, labels, use_gpu=True)
+
+ def testNpXent(self):
+ # We create 2 batches of logits for testing.
+ # batch 0 is the boring uniform distribution: 1, 1, 1, 1, with target 3.
+ # batch 1 has a bit of difference: 1, 2, 3, 4, with soft targets (1, 2).
+ features = [[1., 1., 1., 1.], [1., 2., 3., 4.]]
+ labels = [[0., 0., 0., 1.], [0., .5, .5, 0.]]
+
+ # For batch 0, we expect the uniform distribution: 0.25, 0.25, 0.25, 0.25
+ # With a hard target 3, the backprop is [0.25, 0.25, 0.25, -0.75]
+ # The loss for this batch is -log(0.25) = 1.386
+ #
+ # For batch 1, we have:
+ # exp(0) = 1
+ # exp(1) = 2.718
+ # exp(2) = 7.389
+ # exp(3) = 20.085
+ # SUM = 31.192
+ # So we have as probabilities:
+ # exp(0) / SUM = 0.032
+ # exp(1) / SUM = 0.087
+ # exp(2) / SUM = 0.237
+ # exp(3) / SUM = 0.644
+ # With a soft target (1, 2), the backprop is
+ # [0.032, 0.087 - 0.5 = -0.413, 0.237 - 0.5 = -0.263, 0.644]
+    # The loss for this batch is 0.5 * -log(0.087) + 0.5 * -log(0.237) = 1.9401,
+    # so the expected per-batch losses are [1.3862, 1.9401].
+ np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
+ self.assertAllClose(np.array([[0.25, 0.25, 0.25, -0.75],
+ [0.0321, -0.4129, -0.2632, 0.6439]]),
+ np_backprop,
+ rtol=1.e-3, atol=1.e-3)
+ self.assertAllClose(np.array([1.3862, 1.9401]), np_loss,
+ rtol=1.e-3, atol=1.e-3)
+
+ def testShapeMismatch(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.nn.softmax_cross_entropy_with_logits(
+ [[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]])
+
+ def testNotMatrix(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.nn.softmax_cross_entropy_with_logits([0., 1., 2., 3.],
+ [0., 1., 0., 1.])
+
+ def testFloat(self):
+ self._testAll(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32),
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float32))
+
+ def testDouble(self):
+ self._testXent(
+ np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
+ np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64),
+ use_gpu=False)
+
+ def testGradient(self):
+ with self.test_session():
+ l = tf.constant([0.0, 0.0, 1.0, 0.0,
+ 1.0, 0.0, 0.0, 0.0,
+ 0.0, 0.5, 0.0, 0.5], shape=[3, 4],
+ dtype=tf.float64, name="l")
+ f = tf.constant([0.1, 0.2, 0.3, 0.4,
+ 0.1, 0.4, 0.9, 1.6,
+ 0.1, 0.8, 2.7, 6.4], shape=[3, 4],
+ dtype=tf.float64, name="f")
+ x = tf.nn.softmax_cross_entropy_with_logits(f, l, name="xent")
+ err = gc.ComputeGradientError(f, [3, 4], x, [3])
+ print "cross entropy gradient err = ", err
+ self.assertLess(err, 5e-8)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/lib/__init__.py b/tensorflow/python/lib/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/lib/__init__.py
diff --git a/tensorflow/python/lib/core/__init__.py b/tensorflow/python/lib/core/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/lib/core/__init__.py
diff --git a/tensorflow/python/lib/core/pywrap_status_test.py b/tensorflow/python/lib/core/pywrap_status_test.py
new file mode 100644
index 0000000000..000a784b6c
--- /dev/null
+++ b/tensorflow/python/lib/core/pywrap_status_test.py
@@ -0,0 +1,35 @@
+"""Tests for SWIG wrapped brain::Status."""
+
+from tensorflow.core.lib.core import error_codes_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.platform import googletest
+
+
+class StatusTest(googletest.TestCase):
+
+ def testDefaultOk(self):
+ status = pywrap_tensorflow.Status()
+ self.assertTrue(status.ok())
+
+ def testCodeAndMessage(self):
+ status = pywrap_tensorflow.Status(error_codes_pb2.INVALID_ARGUMENT, 'foo')
+ self.assertEqual(error_codes_pb2.INVALID_ARGUMENT, status.code())
+ self.assertEqual('foo', status.error_message())
+
+ def testToString(self):
+ status = pywrap_tensorflow.Status()
+ # .ToString was remapped in the .swig file, hence will not work
+ # self.assertIn('OK', status.ToString())
+ self.assertIn('OK', str(status))
+
+ def testException(self):
+ with self.assertRaises(pywrap_tensorflow.StatusNotOK) as context:
+ pywrap_tensorflow.NotOkay()
+ self.assertEqual(context.exception.code, error_codes_pb2.INVALID_ARGUMENT)
+ self.assertEqual(context.exception.error_message, 'Testing 1 2 3')
+ self.assertEqual(None, pywrap_tensorflow.Okay(),
+ 'Status wrapper should not return anything upon OK.')
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/lib/core/status.i b/tensorflow/python/lib/core/status.i
new file mode 100644
index 0000000000..fddbc31e24
--- /dev/null
+++ b/tensorflow/python/lib/core/status.i
@@ -0,0 +1,116 @@
+// SWIG wrapper for lib::tensorflow::Status
+
+%include "tensorflow/python/platform/base.i"
+%include "tensorflow/python/lib/core/strings.i"
+
+%apply int { tensorflow::error::Code }; // Treat the enum as an integer.
+
+%{
+#include "tensorflow/core/public/status.h"
+%}
+
+%typemap(out, fragment="StatusNotOK") tensorflow::Status {
+ if ($1.ok()) {
+ $result = SWIG_Py_Void();
+ } else {
+ RaiseStatusNotOK($1, $descriptor(tensorflow::Status*));
+ SWIG_fail;
+ }
+}
+
+%init %{
+// Setup the StatusNotOK exception class.
+PyObject *pywrap_status = PyImport_ImportModuleNoBlock(
+ "tensorflow.python.pywrap_tensorflow");
+if (pywrap_status) {
+ PyObject *exception = PyErr_NewException(
+ "tensorflow.python.pywrap_tensorflow.StatusNotOK",
+ NULL, NULL);
+ if (exception) {
+ PyModule_AddObject(pywrap_status, "StatusNotOK", exception); // Steals ref.
+ }
+ Py_DECREF(pywrap_status);
+}
+%}
+
+%fragment("StatusNotOK", "header") %{
+#include "tensorflow/core/public/status.h"
+
+namespace {
+// Initialized on the first call to RaiseStatusNotOK().
+static PyObject *StatusNotOKError = nullptr;
+
+inline void Py_DECREF_wrapper(PyObject *o) { Py_DECREF(o); }
+typedef std::unique_ptr<PyObject, decltype(&Py_DECREF_wrapper)> SafePyObjectPtr;
+SafePyObjectPtr make_safe(PyObject* o) {
+ return SafePyObjectPtr(o, Py_DECREF_wrapper);
+}
+
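+// Raises a Python StatusNotOK exception for a non-OK Status, attaching the
+// error code, message, and a wrapped copy of the Status. Falls back to a
+// SystemError if the exception object cannot be constructed.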
+void RaiseStatusNotOK(const tensorflow::Status& status, swig_type_info *type) {
+ const int code = status.code();
+ string fullmsg = status.ToString();
+
+ PyObject *exception = nullptr;
+
+ // We're holding the Python GIL, so we don't need to synchronize
+ // access to StatusNotOKError with a Mutex of our own.
+ if (!StatusNotOKError) {
+ PyObject *cls = nullptr;
+ auto pywrap = make_safe(PyImport_ImportModule(
+ "tensorflow.python.pywrap_tensorflow"));
+ if (pywrap) {
+ cls = PyObject_GetAttrString(pywrap.get(), "StatusNotOK");
+ }
+ if (!cls) {
+ cls = Py_None;
+ Py_INCREF(cls);
+ }
+ StatusNotOKError = cls;
+ }
+
+ if (StatusNotOKError != Py_None) {
+ auto fullmsg_ptr = make_safe(_SwigString_FromString(fullmsg));
+ auto exception_ptr = make_safe(PyObject_CallFunctionObjArgs(
+ StatusNotOKError, fullmsg_ptr.get(), NULL));
+ exception = exception_ptr.get();
+ if (exception) {
+ auto pycode = make_safe(PyInt_FromLong(static_cast<long>(code)));
+ auto pymsg = make_safe(_SwigString_FromString(status.error_message()));
+ auto pystatus = make_safe(SWIG_NewPointerObj(
+ SWIG_as_voidptr(new tensorflow::Status(status)), type, SWIG_POINTER_OWN));
+ PyObject_SetAttrString(exception, "code", pycode.get());
+ PyObject_SetAttrString(exception, "error_message", pymsg.get());
+ PyErr_SetObject(StatusNotOKError, exception);
+ }
+ }
+ if (!exception) {
+    fullmsg =
+        ("could not construct StatusNotOK (original error was: " +
+         fullmsg + ")");
+ PyErr_SetString(PyExc_SystemError, fullmsg.c_str());
+ }
+}
+
+} // namespace
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::lib;
+%unignore tensorflow::Status;
+%unignore tensorflow::Status::Status;
+%unignore tensorflow::Status::Status(tensorflow::error::Code, StringPiece);
+%unignore tensorflow::Status::~Status;
+%unignore tensorflow::Status::code;
+%unignore tensorflow::Status::ok;
+%unignore tensorflow::Status::error_message;
+%unignore tensorflow::Status::ToString;
+%ignore tensorflow::Status::operator=;
+
+%rename(__str__) tensorflow::Status::ToString;
+
+%include "tensorflow/core/public/status.h"
+
+%unignoreall
diff --git a/tensorflow/python/lib/core/status_helper.i b/tensorflow/python/lib/core/status_helper.i
new file mode 100644
index 0000000000..2e01e79ebd
--- /dev/null
+++ b/tensorflow/python/lib/core/status_helper.i
@@ -0,0 +1,16 @@
+// SWIG test helper for lib::tensorflow::Status
+
+%include "tensorflow/python/platform/base.i"
+%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i"
+
+%inline %{
+#include "tensorflow/core/public/status.h"
+
+tensorflow::Status NotOkay() {
+ return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, "Testing 1 2 3");
+}
+
+tensorflow::Status Okay() {
+ return tensorflow::Status();
+}
+%}
diff --git a/tensorflow/python/lib/core/strings.i b/tensorflow/python/lib/core/strings.i
new file mode 100644
index 0000000000..c88e426a54
--- /dev/null
+++ b/tensorflow/python/lib/core/strings.i
@@ -0,0 +1,94 @@
+// Wrapper functions to provide a scripting-language-friendly interface
+// to our string libraries.
+//
+// NOTE: as of 2005-01-13, this SWIG file is not used to generate a pywrap
+// library for manipulation of various string-related types or access
+// to the special string functions (Python has plenty). This SWIG file
+// should be %import'd so that other SWIG wrappers have proper access
+// to the types in //strings (such as the StringPiece object). We may
+// generate a pywrap at some point in the future.
+//
+// NOTE: (Dan Ardelean) as of 2005-11-15 added typemaps to convert Java String
+// arguments to C++ StringPiece& objects. This is required because a
+// StringPiece class does not make sense - the code SWIG generates for a
+// StringPiece class is useless, because it releases the buffer set in
+// StringPiece after creating the object. C++ StringPiece objects rely on
+// the buffer holding the data being allocated externally.
+
+// NOTE: for now, we'll just start with what is needed, and add stuff
+// as it comes up.
+
+%{
+#include "tensorflow/core/lib/core/stringpiece.h"
+%}
+
+%typemap(typecheck) tensorflow::StringPiece = char *;
+%typemap(typecheck) const tensorflow::StringPiece & = char *;
+
+// "tensorflow::StringPiece" arguments can be provided by a simple Python 'str' string
+// or a 'unicode' object. If 'unicode', it's translated using the default
+// encoding, i.e., sys.getdefaultencoding(). If passed None, a tensorflow::StringPiece
+// of zero length with a NULL pointer is provided.
+%typemap(in) tensorflow::StringPiece {
+ if ($input != Py_None) {
+ char * buf;
+ Py_ssize_t len;
+%#if PY_VERSION_HEX >= 0x03030000
+ /* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
+ if (PyUnicode_Check($input)) {
+ buf = PyUnicode_AsUTF8AndSize($input, &len);
+ if (buf == NULL)
+ SWIG_fail;
+ } else {
+%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
+%# error "Unsupported Python 3.x C API version (3.3 or later required)."
+%#endif
+ if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+%#if PY_VERSION_HEX >= 0x03030000
+ }
+%#endif
+ $1.set(buf, len);
+ }
+}
+
+// "const tensorflow::StringPiece&" arguments can be provided the same as
+// "tensorflow::StringPiece", whose typemap is defined above.
+%typemap(in) const tensorflow::StringPiece & (tensorflow::StringPiece temp) {
+ if ($input != Py_None) {
+ char * buf;
+ Py_ssize_t len;
+%#if PY_VERSION_HEX >= 0x03030000
+ /* Do unicode handling as PyBytes_AsStringAndSize doesn't in Python 3. */
+ if (PyUnicode_Check($input)) {
+ buf = PyUnicode_AsUTF8AndSize($input, &len);
+ if (buf == NULL)
+ SWIG_fail;
+ } else {
+%#elif PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 3
+%# error "Unsupported Python 3.x C API version (3.3 or later required)."
+%#endif
+ if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+%#if PY_VERSION_HEX >= 0x03030000
+ }
+%#endif
+ temp.set(buf, len);
+ }
+ $1 = &temp;
+}
+
+// C++ functions returning tensorflow::StringPiece will simply return bytes in Python,
+// or None if the StringPiece contained a NULL pointer.
+%typemap(out) tensorflow::StringPiece {
+ if ($1.data()) {
+ $result = PyString_FromStringAndSize($1.data(), $1.size());
+ } else {
+ Py_INCREF(Py_None);
+ $result = Py_None;
+ }
+}
diff --git a/tensorflow/python/lib/io/__init__.py b/tensorflow/python/lib/io/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/lib/io/__init__.py
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
new file mode 100644
index 0000000000..5cc5229a8b
--- /dev/null
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/python/lib/io/py_record_reader.h"
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+class RandomAccessFile;
+
+namespace io {
+
+PyRecordReader::PyRecordReader() {}
+
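+// Returns nullptr if the file cannot be opened; otherwise the caller takes
+// ownership of the returned reader.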
+PyRecordReader* PyRecordReader::New(const string& filename,
+ uint64 start_offset) {
+ RandomAccessFile* file;
+ Status s = Env::Default()->NewRandomAccessFile(filename, &file);
+ if (!s.ok()) {
+ return nullptr;
+ }
+ PyRecordReader* reader = new PyRecordReader;
+ reader->offset_ = start_offset;
+ reader->file_ = file;
+ reader->reader_ = new RecordReader(reader->file_);
+ return reader;
+}
+
+PyRecordReader::~PyRecordReader() {
+ delete reader_;
+ delete file_;
+}
+
+bool PyRecordReader::GetNext() {
+ if (reader_ == nullptr) return false;
+ Status s = reader_->ReadRecord(&offset_, &record_);
+ return s.ok();
+}
+
+void PyRecordReader::Close() {
+ delete reader_;
+ delete file_;
+ file_ = nullptr;
+ reader_ = nullptr;
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/python/lib/io/py_record_reader.h b/tensorflow/python/lib/io/py_record_reader.h
new file mode 100644
index 0000000000..5a775761df
--- /dev/null
+++ b/tensorflow/python/lib/io/py_record_reader.h
@@ -0,0 +1,50 @@
+#ifndef TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_READER_H_
+#define TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_READER_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class RandomAccessFile;
+
+namespace io {
+
+class RecordReader;
+
+// A wrapper around io::RecordReader that is more easily SWIG wrapped for
+// Python. An instance of this class is not safe for concurrent access
+// by multiple threads.
+class PyRecordReader {
+ public:
+ static PyRecordReader* New(const string& filename, uint64 start_offset);
+ ~PyRecordReader();
+
+ // Attempt to get the next record at "current_offset()". If
+  // successful, returns true, and the record contents can be retrieved
+ // with "this->record()". Otherwise, returns false.
+ bool GetNext();
+ // Return the current record contents. Only valid after the preceding call
+  // to GetNext() returned true.
+ string record() const { return record_; }
+ // Return the current offset in the file.
+ uint64 offset() const { return offset_; }
+
+ // Close the underlying file and release its resources.
+ void Close();
+
+ private:
+ PyRecordReader();
+
+ uint64 offset_;
+ RandomAccessFile* file_; // Owned
+ io::RecordReader* reader_; // Owned
+ string record_;
+ TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
+};
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_READER_H_
diff --git a/tensorflow/python/lib/io/py_record_reader.i b/tensorflow/python/lib/io/py_record_reader.i
new file mode 100644
index 0000000000..19f911bd52
--- /dev/null
+++ b/tensorflow/python/lib/io/py_record_reader.i
@@ -0,0 +1,39 @@
+%nothread tensorflow::io::PyRecordReader::GetNext;
+
+%include "tensorflow/python/platform/base.i"
+
+%feature("except") tensorflow::io::PyRecordReader::New {
+ // Let other threads run while we read
+ Py_BEGIN_ALLOW_THREADS
+ $action
+ Py_END_ALLOW_THREADS
+}
+
+%newobject tensorflow::io::PyRecordReader::New;
+
+%feature("except") tensorflow::io::PyRecordReader::GetNext {
+ // Let other threads run while we read
+ Py_BEGIN_ALLOW_THREADS
+ $action
+ Py_END_ALLOW_THREADS
+}
+
+%{
+#include "tensorflow/python/lib/io/py_record_reader.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::io;
+%unignore tensorflow::io::PyRecordReader;
+%unignore tensorflow::io::PyRecordReader::~PyRecordReader;
+%unignore tensorflow::io::PyRecordReader::GetNext;
+%unignore tensorflow::io::PyRecordReader::offset;
+%unignore tensorflow::io::PyRecordReader::record;
+%unignore tensorflow::io::PyRecordReader::Close;
+%unignore tensorflow::io::PyRecordReader::New;
+
+%include "tensorflow/python/lib/io/py_record_reader.h"
+
+%unignoreall
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
new file mode 100644
index 0000000000..e557756cbc
--- /dev/null
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -0,0 +1,44 @@
+#include "tensorflow/python/lib/io/py_record_writer.h"
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace io {
+
+PyRecordWriter::PyRecordWriter() {}
+
+PyRecordWriter* PyRecordWriter::New(const string& filename) {
+ WritableFile* file;
+ Status s = Env::Default()->NewWritableFile(filename, &file);
+ if (!s.ok()) {
+ return nullptr;
+ }
+ PyRecordWriter* writer = new PyRecordWriter;
+ writer->file_ = file;
+ writer->writer_ = new RecordWriter(writer->file_);
+ return writer;
+}
+
+PyRecordWriter::~PyRecordWriter() {
+ delete writer_;
+ delete file_;
+}
+
+bool PyRecordWriter::WriteRecord(::tensorflow::StringPiece record) {
+ if (writer_ == nullptr) return false;
+ Status s = writer_->WriteRecord(record);
+ return s.ok();
+}
+
+void PyRecordWriter::Close() {
+ delete writer_;
+ delete file_;
+ writer_ = nullptr;
+ file_ = nullptr;
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h
new file mode 100644
index 0000000000..e3fd05bd9a
--- /dev/null
+++ b/tensorflow/python/lib/io/py_record_writer.h
@@ -0,0 +1,38 @@
+#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_WRITER_H_
+#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_WRITER_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class WritableFile;
+
+namespace io {
+
+class RecordWriter;
+
+// A wrapper around io::RecordWriter that is more easily SWIG wrapped for
+// Python. An instance of this class is not safe for concurrent access
+// by multiple threads.
+class PyRecordWriter {
+ public:
+ static PyRecordWriter* New(const string& filename);
+ ~PyRecordWriter();
+
+ bool WriteRecord(::tensorflow::StringPiece record);
+ void Close();
+
+ private:
+ PyRecordWriter();
+
+ WritableFile* file_; // Owned
+ io::RecordWriter* writer_; // Owned
+ TF_DISALLOW_COPY_AND_ASSIGN(PyRecordWriter);
+};
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_WRITER_H_
diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i
new file mode 100644
index 0000000000..20fe52c495
--- /dev/null
+++ b/tensorflow/python/lib/io/py_record_writer.i
@@ -0,0 +1,38 @@
+%nothread tensorflow::io::PyRecordWriter::WriteRecord;
+
+%include "tensorflow/python/platform/base.i"
+%include "tensorflow/python/lib/core/strings.i"
+
+%feature("except") tensorflow::io::PyRecordWriter::New {
+ // Let other threads run while we write
+ Py_BEGIN_ALLOW_THREADS
+ $action
+ Py_END_ALLOW_THREADS
+}
+
+%newobject tensorflow::io::PyRecordWriter::New;
+
+%feature("except") tensorflow::io::PyRecordWriter::WriteRecord {
+ // Let other threads run while we write
+ Py_BEGIN_ALLOW_THREADS
+ $action
+ Py_END_ALLOW_THREADS
+}
+
+%{
+#include "tensorflow/python/lib/io/py_record_writer.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::io;
+%unignore tensorflow::io::PyRecordWriter;
+%unignore tensorflow::io::PyRecordWriter::~PyRecordWriter;
+%unignore tensorflow::io::PyRecordWriter::WriteRecord;
+%unignore tensorflow::io::PyRecordWriter::Close;
+%unignore tensorflow::io::PyRecordWriter::New;
+
+%include "tensorflow/python/lib/io/py_record_writer.h"
+
+%unignoreall
diff --git a/tensorflow/python/lib/io/python_io.py b/tensorflow/python/lib/io/python_io.py
new file mode 100644
index 0000000000..aedcd2ef03
--- /dev/null
+++ b/tensorflow/python/lib/io/python_io.py
@@ -0,0 +1,29 @@
+"""## Data IO (Python Functions)
+
+A TFRecords file represents a sequence of (binary) strings. The format is not
+random access, so it is suitable for streaming large amounts of data but not
+suitable if fast sharding or other non-sequential access is desired.
+
+@@TFRecordWriter
+@@tf_record_iterator
+
+- - -
+
+### TFRecords Format Details
+
+A TFRecords file contains a sequence of strings with CRC hashes. Each record
+has the format
+
+ uint64 length
+ uint32 masked_crc32_of_length
+ byte data[length]
+ uint32 masked_crc32_of_data
+
+and the records are concatenated together to produce the file. The CRC32s
+are [described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check),
+and the mask of a CRC is
+
+ masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
+"""
+
+from tensorflow.python.lib.io.tf_record import *
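To make the record framing above concrete, here is a minimal sketch of how one record could be encoded. This is illustrative only, not the library implementation: `crc32c_fn` stands in for a CRC32C (Castagnoli) checksum function supplied by the caller, and the little-endian fixed-width layout is an assumption.

```python
import struct

def masked_crc(crc):
    # Apply the mask described above: rotate the 32-bit CRC right by 15 bits
    # and add a constant, staying in unsigned 32-bit range.
    crc &= 0xFFFFFFFF
    return (((crc >> 15) | ((crc << 17) & 0xFFFFFFFF)) + 0xa282ead8) & 0xFFFFFFFF

def encode_record(data, crc32c_fn):
    # Frame a single record: length, masked CRC of the length bytes, the data,
    # and a masked CRC of the data. Little-endian packing is an assumption.
    header = struct.pack("<Q", len(data))
    return (header +
            struct.pack("<I", masked_crc(crc32c_fn(header))) +
            data +
            struct.pack("<I", masked_crc(crc32c_fn(data))))
```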
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
new file mode 100644
index 0000000000..00825bbda2
--- /dev/null
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -0,0 +1,68 @@
+"""For reading and writing TFRecords files."""
+
+from tensorflow.python import pywrap_tensorflow
+
+
+def tf_record_iterator(path):
+ """An iterator that read the records from a TFRecords file.
+
+ Args:
+ path: The path to the TFRecords file.
+
+ Yields:
+ Strings.
+
+ Raises:
+ IOError: If `path` cannot be opened for reading.
+ """
+ reader = pywrap_tensorflow.PyRecordReader_New(path, 0)
+ if reader is None:
+ raise IOError("Could not open %s." % path)
+ while reader.GetNext():
+ yield reader.record()
+ reader.Close()
+
+
+class TFRecordWriter(object):
+ """A class to write records to a TFRecords file.
+
+ This class implements `__enter__` and `__exit__`, and can be used
+ in `with` blocks like a normal file.
+
+ @@__init__
+ @@write
+ @@close
+ """
+ # TODO(josh11b): Support appending?
+ def __init__(self, path):
+ """Opens file `path` and creates a `TFRecordWriter` writing to it.
+
+ Args:
+ path: The path to the TFRecords file.
+
+ Raises:
+ IOError: If `path` cannot be opened for writing.
+ """
+ self._writer = pywrap_tensorflow.PyRecordWriter_New(path)
+ if self._writer is None:
+ raise IOError("Could not write to %s." % path)
+
+ def __enter__(self):
+ """Enter a `with` block."""
+    return self
+
+ def __exit__(self, unused_type, unused_value, unused_traceback):
+ """Exit a `with` block, closing the file."""
+ self.close()
+
+ def write(self, record):
+ """Write a string record to the file.
+
+ Args:
+ record: str
+ """
+ self._writer.WriteRecord(record)
+
+ def close(self):
+ """Close the file."""
+ self._writer.Close()
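A minimal round-trip sketch of the wrappers above, assuming the extension module (`pywrap_tensorflow`) is built and importable; the path is illustrative and error handling is omitted:

```python
from tensorflow.python.lib.io import tf_record

path = "/tmp/example.tfrecords"  # illustrative path

writer = tf_record.TFRecordWriter(path)
for i in range(3):
    writer.write("record-%d" % i)
writer.close()

for record in tf_record.tf_record_iterator(path):
    print(record)  # record-0, record-1, record-2
```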
diff --git a/tensorflow/python/ops/__init__.py b/tensorflow/python/ops/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/ops/__init__.py
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
new file mode 100644
index 0000000000..2a463940d6
--- /dev/null
+++ b/tensorflow/python/ops/array_grad.py
@@ -0,0 +1,187 @@
+"""Gradients for operators defined in array_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import gen_array_ops
+
+
+@ops.RegisterGradient("Pack")
+def _PackGrad(op, grad):
+ """Gradient for pack op."""
+ return array_ops.unpack(grad, num=op.get_attr('N'))
+
+
+@ops.RegisterGradient("Unpack")
+def _UnpackGrad(_, *grads):
+ """Gradient for unpack op."""
+ return array_ops.pack(grads)
+
+
+@ops.RegisterGradient("Concat")
+def _ConcatGrad(op, grad):
+ """Gradient for concat op."""
+ assert isinstance(grad, ops.Tensor)
+ # Degenerate concatenation, just return grad.
+ if len(op.inputs) == 2:
+ return [None, grad]
+ # Get the inputs' tensor shapes
+ sizes = [array_ops.shape(x) for x in op.inputs[1:]]
+ concat_dim = op.inputs[0]
+ # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
+ shape_of_shape = array_ops.shape(sizes[0])
+ # Make a vector of length equal to the input's dimensions,
+ # with 0's everywhere and 1 in the concat dim position.
+ # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
+ mask = array_ops.concat(0,
+ [array_ops.fill(
+ array_ops.expand_dims(concat_dim, 0), 0), [1],
+ array_ops.fill(shape_of_shape - concat_dim - 1, 0)])
+ out_grads = []
+ begin = array_ops.fill(shape_of_shape, 0)
+ for i in range(len(sizes)):
+ out_grads.append(array_ops.slice(grad, begin, sizes[i]))
+ # Lint complains begin = begin + ...
+ begin = math_ops.add(begin, sizes[i] * mask)
+ return [None] + out_grads
+
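The begin/mask bookkeeping in `_ConcatGrad` above can be checked with plain numpy; this sketch (illustration only) shows the slice offsets produced for two inputs of shape [2, 3] and [2, 4] concatenated along dimension 1:

```python
import numpy as np

sizes = [np.array([2, 3]), np.array([2, 4])]  # shapes of the concat inputs
concat_dim = 1
mask = np.array([0, 1])                       # 1 only in the concat dimension
begin = np.array([0, 0])
for size in sizes:
    # Each input's gradient is the slice grad[begin : begin + size].
    print(begin, size)
    begin = begin + size * mask               # advance along concat_dim only
# prints: [0 0] [2 3]  then  [0 3] [2 4]
```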
+
+@ops.RegisterGradient("Slice")
+def _SliceGrad(op, grad):
+ """Gradient for Slice op."""
+ # Create an Nx2 padding where the first column represents how many
+ # zeros are to be prepended for each dimension, and the second
+ # column indicates how many zeros are appended.
+ #
+ # The number of zeros to append is the shape of the input
+ # elementwise-subtracted by both the begin vector and sizes vector.
+ #
+ # Some more reshaping is needed to assemble this tensor with the
+ # right dimensions.
+ input_vec = op.inputs[0]
+ begin_vec = op.inputs[1]
+ input_rank = array_ops.rank(input_vec)
+ slice_size = array_ops.shape(op.outputs[0])
+
+ shape = array_ops.pack([input_rank, 1])
+ before_pad = array_ops.reshape(begin_vec, shape)
+ after_pad = array_ops.reshape(
+ array_ops.shape(input_vec) - slice_size - begin_vec, shape)
+ paddings = array_ops.concat(1, [before_pad, after_pad])
+ return array_ops.pad(grad, paddings), None, None
+
+
+@ops.RegisterGradient("Split")
+def _SplitGrad(op, *grads):
+ return None, array_ops.concat(op.inputs[0], list(grads))
+
+
+ops.NoGradient("Const")
+
+# TODO(liqzhang): The gradient for Diag operator would be
+# the diagonal of the backprop. Implement if there is a need.
+ops.NoGradient("Diag")
+
+# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
+ops.NoGradient("EditDistance")
+
+ops.NoGradient("Fill")
+
+
+@ops.RegisterGradient("Gather")
+def _GatherGrad(op, grad):
+ return [
+ ops.IndexedSlices(grad, op.inputs[1], array_ops.shape(op.inputs[0])), None
+ ]
+
+
+@ops.RegisterGradient("Identity")
+def _IdGrad(_, grad):
+ return grad
+
+
+@ops.RegisterGradient("RefIdentity")
+def _RefIdGrad(_, grad):
+ return grad
+
+
+ops.NoGradient("StopGradient")
+
+
+@ops.RegisterGradient("Reshape")
+def _ReshapeGrad(op, grad):
+ return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
+
+
+ops.NoGradient("InvertPermutation")
+
+
+def _ReshapeToInput(op, grad):
+ """Reshapes the gradient to the shape of the original input."""
+ return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
+
+
+@ops.RegisterGradient("ExpandDims")
+def _ExpandDimsGrad(op, grad):
+ return [_ReshapeToInput(op, grad), None]
+
+
+@ops.RegisterGradient("Squeeze")
+def _SqueezeGrad(op, grad):
+ return _ReshapeToInput(op, grad)
+
+
+@ops.RegisterGradient("Transpose")
+def _TransposeGrad(op, grad):
+ """Returns unshuffle(grad)."""
+ p = op.inputs[1]
+ return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]
+
+
+ops.NoGradient("Shape")
+
+
+ops.NoGradient("Rank")
+
+
+ops.NoGradient("Size")
+
+
+@ops.RegisterGradient("Tile")
+def _TileGrad(op, grad):
+ """Sum reduces grad along the tiled dimensions."""
+ assert isinstance(grad, ops.Tensor)
+ return [gen_array_ops._tile_grad(grad, op.inputs[1]), None]
+
+
+ops.NoGradient("TileGrad")
+
+
+ops.NoGradient("BroadcastGradientArgs")
+
+
+@ops.RegisterGradient("Pad")
+def _PadGrad(op, grad):
+ """Gradient for Pad."""
+ # Pad introduces values around the original tensor, so the gradient function
+  # slices the original shape out of the gradient.
+ x = op.inputs[0]
+ a = op.inputs[1] # [Rank(x), 2]
+ # Takes a slice of a. The 1st column. [Rank(x), 1].
+ pad_before = array_ops.slice(a, [0, 0],
+ array_ops.pack([array_ops.rank(x), 1]))
+ # Make it a 1-D tensor.
+ begin = array_ops.reshape(pad_before, [-1])
+ sizes = array_ops.shape(x)
+ return array_ops.slice(grad, begin, sizes), None
+
+
+# ReverseSequence is just a permutation. The gradient permutes back.
+@ops.RegisterGradient("ReverseSequence")
+def _ReverseSequenceGrad(op, grad):
+ seq_lengths = op.inputs[1]
+ return [array_ops.reverse_sequence(grad,
+ seq_dim=op.get_attr("seq_dim"),
+ seq_lengths=seq_lengths),
+ None]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
new file mode 100644
index 0000000000..ed780db625
--- /dev/null
+++ b/tensorflow/python/ops/array_ops.py
@@ -0,0 +1,1207 @@
+"""## Casting
+
+TensorFlow provides several operations that you can use to cast tensor data
+types in your graph.
+
+@@string_to_number
+@@to_double
+@@to_float
+@@to_bfloat16
+@@to_int32
+@@to_int64
+@@cast
+
+## Shapes and Shaping
+
+TensorFlow provides several operations that you can use to determine the shape
+of a tensor and change the shape of a tensor.
+
+@@shape
+@@size
+@@rank
+@@reshape
+@@squeeze
+@@expand_dims
+
+## Slicing and Joining
+
+TensorFlow provides several operations to slice or extract parts of a tensor,
+or join multiple tensors together.
+
+@@slice
+@@split
+@@tile
+@@pad
+@@concat
+@@pack
+@@unpack
+@@reverse_sequence
+@@reverse
+@@transpose
+@@gather
+@@dynamic_partition
+@@dynamic_stitch
+"""
+import sys
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
+# pylint: disable=wildcard-import
+# 'Constant' gets imported in the module 'array_ops'.
+from tensorflow.python.ops.constant_op import constant
+from tensorflow.python.ops.gen_array_ops import *
+
+
+# We override Python's built-in 'slice' with the "slice" op below, so we keep
+# the built-in 'slice' as '_baseslice' for later use in this module.
+_baseslice = slice
+
+
+# Aliases for some automatically-generated names.
+listdiff = gen_array_ops.list_diff
+
+
+# pylint: disable=undefined-variable,protected-access
+def _SliceHelper(tensor, slice_spec):
+ """Overload for Tensor.__getitem__.
+
+ Currently the size of the slice must be statically known in each dimension,
+ i.e. the "stop" of the slice must not be omitted.
+
+ TODO(mrry): Support slices where the sizes are not specified.
+ TODO(mrry): Support negative indices in slices with numpy/Python semantics.
+
+ Args:
+ tensor: An ops.Tensor object.
+ slice_spec: The arguments to Tensor.__getitem__.
+
+ Returns:
+ The appropriate slice of "tensor", based on "slice_spec".
+
+ Raises:
+    ValueError: If a slice range has a negative size.
+ TypeError: If the slice indices aren't int, slice, or Ellipsis.
+ """
+ if not isinstance(slice_spec, (list, tuple)):
+ slice_spec = [slice_spec]
+ indices = []
+ sizes = []
+ squeeze_dims = []
+ for dim, s in enumerate(slice_spec):
+ if isinstance(s, int):
+ if s < 0:
+ raise NotImplementedError("Negative indices are currently unsupported")
+ indices.append(s)
+ sizes.append(1)
+ squeeze_dims.append(dim)
+ elif isinstance(s, _baseslice):
+ if s.step not in (None, 1):
+ raise NotImplementedError(
+ "Steps other than 1 are not currently supported")
+ start = s.start if s.start is not None else 0
+ if start < 0:
+ raise NotImplementedError(
+ "Negative start indices are not currently supported")
+ indices.append(start)
+ if s.stop is not None and s.stop < 0:
+ raise NotImplementedError(
+ "Negative stop indices are not currently supported")
+ # NOTE(mrry): If the stop is not specified, Python substitutes
+ # sys.maxsize, which is typically (2 ** 63) - 1. Since Slice currently
+ # supports signed DT_INT32 arguments, we use -1 to specify that all
+ # elements should be captured.
+ if s.stop is None or s.stop == sys.maxsize:
+ sizes.append(-1)
+ else:
+ if start > s.stop:
+ raise ValueError("Stop must be at least start")
+ sizes.append(s.stop - start)
+ elif s is Ellipsis:
+ raise NotImplementedError("Ellipsis is not currently supported")
+ else:
+ raise TypeError("Bad slice index %s of type %s" % (s, type(s)))
+ sliced = slice(tensor, indices, sizes)
+ if squeeze_dims:
+ return squeeze(sliced, squeeze_dims=squeeze_dims)
+ else:
+ return sliced
+
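The translation performed by `_SliceHelper` above can be illustrated without TensorFlow; this standalone sketch mirrors its loop for the supported cases (integer indices and forward slices with step 1):

```python
import sys

def spec_to_args(slice_spec):
    # Mirror of the loop above: integers become size-1 slices that are later
    # squeezed; slice objects contribute (start, stop - start), or -1 for
    # "to the end" when the stop is omitted.
    indices, sizes, squeeze_dims = [], [], []
    for dim, s in enumerate(slice_spec):
        if isinstance(s, int):
            indices.append(s)
            sizes.append(1)
            squeeze_dims.append(dim)
        else:
            start = s.start or 0
            indices.append(start)
            sizes.append(-1 if s.stop in (None, sys.maxsize) else s.stop - start)
    return indices, sizes, squeeze_dims

print(spec_to_args((1, slice(2, 5))))                   # ([1, 2], [1, 3], [0])
print(spec_to_args((slice(1, 3), slice(None, None))))   # ([1, 0], [2, -1], [])
```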
+
+def slice(input_, begin, size, name=None):
+ """Extracts a slice from a tensor.
+
+ This operation extracts a slice of size `size` from a tensor `input` starting
+ at the location specified by `begin`. The slice `size` is represented as a
+ tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
+ of `input` that you want to slice. The starting location (`begin`) for the
+ slice is represented as an offset in each dimension of `input`. In other
+ words, `begin[i]` is the offset into the 'i'th dimension of `input` that you
+ want to slice from.
+
+ `begin` is zero-based; `size` is one-based. If `size[i]` is -1,
+ all remaining elements in dimension i are included in the
+ slice. In other words, this is equivalent to setting:
+
+ `size[i] = input.dim_size(i) - begin[i]`
+
+ This operation requires that:
+
+ `0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]`
+
+ For example:
+
+ ```
+ # 'input' is [[[1, 1, 1], [2, 2, 2]],
+ # [[3, 3, 3], [4, 4, 4]],
+ # [[5, 5, 5], [6, 6, 6]]]
+ tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
+ tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
+ [4, 4, 4]]]
+ tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
+ [[5, 5, 5]]]
+ ```
+
+ Args:
+ input_: A `Tensor`.
+ begin: An `int32` or `int64` `Tensor`.
+ size: An `int32` or `int64` `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` the same type as `input`.
+ """
+ return gen_array_ops._slice(input_, begin, size, name=name)
+
+
+ops.Tensor._override_operator("__getitem__", _SliceHelper)
+
+
+def pack(values, name="pack"):
+ """Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
+
+ Packs tensors in `values` into a tensor with rank one higher than each tensor
+ in `values` and shape `[len(values)] + values[0].shape`. The output satisfies
+ `output[i, ...] = values[i][...]`.
+
+ This is the opposite of unpack. The numpy equivalent is
+
+ tf.pack([x, y, z]) = np.asarray([x, y, z])
+
+ Args:
+ values: A list of `Tensor` objects with the same shape and type.
+ name: A name for this operation (optional).
+
+ Returns:
+ output: A packed `Tensor` with the same type as `values`.
+ """
+ return gen_array_ops._pack(values, name=name)
+
+
+def unpack(value, num=None, name="unpack"):
+ """Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
+
+ Unpacks `num` tensors from `value` along the first dimension.
+ If `num` is not specified (the default), it is inferred from `value`'s shape.
+ If `value.shape[0]` is not known, `ValueError` is raised.
+
+ The ith tensor in `output` is the slice `value[i, ...]`. Each tensor in
+ `output` has shape `value.shape[1:]`.
+
+ This is the opposite of pack. The numpy equivalent is
+
+ tf.unpack(x, n) = list(x)
+
+ Args:
+ value: A rank `R > 0` `Tensor` to be unpacked.
+ num: An `int`. The first dimension of value. Automatically inferred if
+ `None` (the default).
+ name: A name for the operation (optional).
+
+ Returns:
+ The list of `Tensor` objects unpacked from `value`.
+
+ Raises:
+ ValueError: If `num` is unspecified and cannot be inferred.
+ """
+ if num is None:
+ value = ops.convert_to_tensor(value)
+ shape = value.get_shape()
+ num = shape[0].value
+ if num is None:
+ raise ValueError("Cannot infer num from shape %s" % shape)
+ return gen_array_ops._unpack(value, num=num, name=name)
+
+
+def concat(concat_dim, values, name="concat"):
+ """Concatenates tensors along one dimension.
+
+ Concatenates the list of tensors `values` along dimension `concat_dim`. If
+ `values[i].shape = [D0, D1, ... Dconcat_dim(i), ...Dn]`, the concatenated
+ result has shape
+
+ [D0, D1, ... Rconcat_dim, ...Dn]
+
+ where
+
+ Rconcat_dim = sum(Dconcat_dim(i))
+
+ That is, the data from the input tensors is joined along the `concat_dim`
+ dimension.
+
+ The number of dimensions of the input tensors must match, and all dimensions
+ except `concat_dim` must be equal.
+
+ For example:
+
+ ```python
+ t1 = [[1, 2, 3], [4, 5, 6]]
+ t2 = [[7, 8, 9], [10, 11, 12]]
+ tf.concat(0, [t1, t2]) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
+ tf.concat(1, [t1, t2]) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
+
+ # tensor t3 with shape [2, 3]
+ # tensor t4 with shape [2, 3]
+ tf.shape(tf.concat(0, [t3, t4])) ==> [4, 3]
+ tf.shape(tf.concat(1, [t3, t4])) ==> [2, 6]
+ ```
+
+ Args:
+ concat_dim: 0-D `int32` `Tensor`. Dimension along which to concatenate.
+ values: A list of `Tensor` objects or a single `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` resulting from concatenation of the input tensors.
+ """
+  if not isinstance(values, list):
+ values = [values]
+ # TODO(mrry): Change to return values?
+ if len(values) == 1: # Degenerate case of one tensor.
+ return identity(values[0], name=name)
+ return gen_array_ops._concat(concat_dim=concat_dim,
+ values=values,
+ name=name)
+
+
+@ops.RegisterShape("Pack")
+def _PackShape(op):
+ input_shape = op.inputs[0].get_shape()
+ for inp in op.inputs[1:]:
+ input_shape = input_shape.merge_with(inp.get_shape())
+ return [tensor_shape.TensorShape([len(op.inputs)]).concatenate(input_shape)]
+
+
+@ops.RegisterShape("Unpack")
+def _UnpackShape(op):
+ input_shape = op.inputs[0].get_shape()
+ return [input_shape[1:]] * op.get_attr("num")
+
+
+@ops.RegisterShape("Concat")
+def _ConcatShape(op):
+ concat_dim = tensor_util.ConstantValue(op.inputs[0])
+ if concat_dim is None:
+ # Return an unknown shape with the same rank as the inputs, or an
+ # unknown rank if no input's rank is known.
+ rank = None
+ for value in op.inputs[1:]:
+ if rank is not None:
+ value.get_shape().assert_has_rank(rank)
+ else:
+ rank = value.get_shape().ndims
+ return [tensor_shape.unknown_shape(ndims=max(rank, 1))]
+
+ else:
+ # Merge all the non-concat dims, and sum the concat dim to make an
+ # output shape.
+ concat_dim = int(concat_dim)
+ output_shape = op.inputs[1].get_shape()
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ if output_shape.ndims == 0:
+ output_shape = tensor_shape.TensorShape([1])
+ for value in op.inputs[2:]:
+ value_shape = value.get_shape()
+ if value_shape.ndims is not None and concat_dim >= value_shape.ndims:
+ if value_shape.ndims == 0 and concat_dim == 0:
+ # Let concat handle scalars
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ value_shape = tensor_shape.TensorShape([1])
+ else:
+ raise ValueError("concat_dim is out of range (values rank = %d)" %
+ value_shape.ndims)
+ before = output_shape[:concat_dim].merge_with(value_shape[:concat_dim])
+ at = output_shape[concat_dim] + value_shape[concat_dim]
+ after = output_shape[
+ concat_dim + 1:].merge_with(value_shape[concat_dim + 1:])
+ output_shape = before.concatenate(at).concatenate(after)
+ return [output_shape]
+
+
+def sparse_mask(a, mask_indices, name=None):
+ """Masks elements of `IndexedSlices`.
+
+ Given an `IndexedSlices` instance `a`, returns another `IndexedSlices` that
+ contains a subset of the slices of `a`. Only the slices at indices specified
+ in `mask_indices` are returned.
+
+ This is useful when you need to extract a subset of slices in an
+ `IndexedSlices` object.
+
+ For example:
+
+ ```python
+ # `a` contains slices at indices [12, 26, 37, 45] from a large tensor
+ # with shape [1000, 10]
+ a.indices => [12, 26, 37, 45]
+ tf.shape(a.values) => [4, 10]
+
+ # `b` will be the subset of `a` slices at its second and third indices, so
+  # we want to mask off its first and last indices (which are at absolute
+ # indices 12, 45)
+ b = tf.sparse_mask(a, [12, 45])
+
+ b.indices => [26, 37]
+ tf.shape(b.values) => [2, 10]
+
+ ```
+
+ Args:
+    a: An `IndexedSlices` instance.
+    mask_indices: Indices of elements to mask.
+    name: A name for the operation (optional).
+
+ Returns:
+ The masked `IndexedSlices` instance.
+ """
+ with ops.op_scope([a, mask_indices], name, "sparse_mask") as name:
+ indices = a.indices
+ out_indices, to_gather = listdiff(indices, mask_indices)
+ out_values = gather(a.values, to_gather, name=name)
+ return ops.IndexedSlices(out_values, out_indices, a.dense_shape)
+
+
+def split(split_dim, num_split, value, name="split"):
+ """Splits a tensor into `num_split` tensors along one dimension.
+
+ Splits `value` along dimension `split_dim` into `num_split` smaller tensors.
+ Requires that `num_split` evenly divide `value.shape[split_dim]`.
+
+ For example:
+
+ ```python
+ # 'value' is a tensor with shape [5, 30]
+ # Split 'value' into 3 tensors along dimension 1
+ split0, split1, split2 = tf.split(1, 3, value)
+ tf.shape(split0) ==> [5, 10]
+ ```
+
+ Args:
+ split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
+ Must be in the range `[0, rank(value))`.
+ num_split: A 0-D `int32` `Tensor`. The number of ways to split.
+ value: The `Tensor` to split.
+ name: A name for the operation (optional).
+
+ Returns:
+ `num_split` `Tensor` objects resulting from splitting `value`.
+ """
+ return gen_array_ops._split(split_dim=split_dim,
+ num_split=num_split,
+ value=value,
+ name=name)
+
+
+@ops.RegisterShape("Reverse")
+def _ReverseShape(op):
+ return [op.inputs[0].get_shape().with_rank_at_most(8)]
+
+
+def transpose(a, perm=None, name="transpose"):
+ """Transposes `a`. Permutes the dimensions according to `perm`.
+
+ The returned tensor's dimension i will correspond to the input dimension
+ `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
+ the rank of the input tensor. Hence by default, this operation performs a
+ regular matrix transpose on 2-D input Tensors.
+
+ For example:
+
+ ```python
+ # 'x' is [[1 2 3]
+ # [4 5 6]]
+ tf.transpose(x) ==> [[1 4]
+ [2 5]
+ [3 6]]
+
+ # Equivalently
+  tf.transpose(x, perm=[1, 0]) ==> [[1 4]
+ [2 5]
+ [3 6]]
+
+ # 'perm' is more useful for n-dimensional tensors, for n > 2
+ # 'x' is [[[1 2 3]
+ # [4 5 6]]
+ # [[7 8 9]
+ # [10 11 12]]]
+ # Take the transpose of the matrices in dimension-0
+  tf.transpose(x, perm=[0, 2, 1]) ==> [[[1 4]
+ [2 5]
+ [3 6]]
+
+ [[7 10]
+ [8 11]
+ [9 12]]]
+ ```
+
+ Args:
+ a: A `Tensor`.
+ perm: A permutation of the dimensions of `a`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A transposed `Tensor`.
+ """
+ with ops.op_scope([a], name, "transpose") as name:
+ if perm is None:
+ dims = gen_math_ops._range(0, gen_array_ops.rank(a), 1)
+ perm = gen_array_ops.reverse(dims, [True])
+ ret = gen_array_ops.transpose(a, perm, name=name)
+ # NOTE(mrry): Setting the shape explicitly because
+ # reverse is not handled by the shape function.
+ input_shape = ret.op.inputs[0].get_shape().dims
+ if input_shape is not None:
+ ret.set_shape(input_shape[::-1])
+ else:
+ ret = gen_array_ops.transpose(a, perm, name=name)
+ return ret
+
+
+def zeros(shape, dtype=types.float32, name=None):
+ """Creates a tensor with all elements set to zero.
+
+ This operation returns a tensor of type `dtype` with shape `shape` and
+ all elements set to zero.
+
+ For example:
+
+ ```python
+ tf.zeros([3, 4], int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
+ ```
+
+ Args:
+ shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
+ dtype: The type of an element in the resulting `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to zero.
+ """
+ with ops.op_scope([shape], name, "zeros") as name:
+ if isinstance(shape, list):
+ output = constant(0, shape=shape, dtype=dtype, name=name)
+ else:
+ shape = ops.convert_to_tensor(shape, name="shape")
+ output = fill(shape, constant(0, dtype=dtype), name=name)
+ assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ return output
+
+
+def zeros_like(tensor, dtype=None, name=None):
+ """Creates a tensor with all elements set to zero.
+
+ Given a single tensor (`tensor`), this operation returns a tensor of the
+ same type and shape as `tensor` with all elements set to zero. Optionally,
+ you can use `dtype` to specify a new type for the returned tensor.
+
+ For example:
+
+ ```python
+ # 'tensor' is [[1, 2, 3], [4, 5, 6]]
+ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
+ ```
+
+ Args:
+ tensor: A `Tensor`.
+ dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
+ `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to zero.
+ """
+ with ops.op_scope([tensor], name, "zeros_like") as name:
+ tensor = ops.convert_to_tensor(tensor, name="tensor")
+ zeros_shape = shape(tensor)
+ if dtype is None:
+ dtype = tensor.dtype
+ return zeros(zeros_shape, dtype=dtype, name=name)
+
+
+def ones_like(tensor, dtype=None, name=None):
+ """Creates a tensor with all elements set to 1.
+
+ Given a single tensor (`tensor`), this operation returns a tensor of the same
+ type and shape as `tensor` with all elements set to 1. Optionally, you can
+ specify a new type (`dtype`) for the returned tensor.
+
+ For example:
+
+ ```python
+ # 'tensor' is [[1, 2, 3], [4, 5, 6]]
+ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
+ ```
+
+ Args:
+ tensor: A `Tensor`.
+ dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
+ `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to 1.
+ """
+ with ops.op_scope([tensor], name, "ones_like") as name:
+ tensor = ops.convert_to_tensor(tensor, name="tensor")
+ ones_shape = shape(tensor)
+ if dtype is None:
+ dtype = tensor.dtype
+ return ones(ones_shape, dtype=dtype, name=name)
+
+
+def zeros_initializer(shape, dtype=types.float32):
+ """An adaptor for zeros() to match the Initializer spec."""
+ return zeros(shape, dtype)
+
+
+def ones(shape, dtype=types.float32, name=None):
+ """Creates a tensor with all elements set to 1.
+
+ This operation returns a tensor of type `dtype` with shape `shape` and all
+ elements set to 1.
+
+ For example:
+
+ ```python
+ tf.ones([2, 3], int32) ==> [[1, 1, 1], [1, 1, 1]]
+ ```
+
+ Args:
+ shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
+ dtype: The type of an element in the resulting `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to 1.
+ """
+ with ops.op_scope([shape], name, "ones") as name:
+ if isinstance(shape, list):
+ output = constant(1, shape=shape, dtype=dtype, name=name)
+ else:
+ shape = ops.convert_to_tensor(shape, name="shape")
+ output = fill(shape, constant(1, dtype=dtype), name=name)
+ assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ return output
+
+
+def placeholder(dtype, shape=None, name=None):
+ """Inserts a placeholder for a tensor that will be always fed.
+
+ **Important**: This tensor will produce an error if evaluated. Its value must
+ be fed using the `feed_dict` optional argument to `Session.run()`,
+ `Tensor.eval()`, or `Operation.run()`.
+
+ For example:
+
+ ```python
+ x = tf.placeholder(float, shape=(1024, 1024))
+ y = tf.matmul(x, x)
+
+ with tf.Session() as sess:
+ print sess.run(y) # ERROR: will fail because x was not fed.
+
+ rand_array = np.random.rand(1024, 1024)
+ print sess.run(y, feed_dict={x: rand_array}) # Will succeed.
+ ```
+
+ Args:
+ dtype: The type of elements in the tensor to be fed.
+ shape: The shape of the tensor to be fed (optional). If the shape is not
+ specified, you can feed a tensor of any shape.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` that may be used as a handle for feeding a value, but not
+ evaluated directly.
+ """
+ shape = tensor_shape.as_shape(shape)
+ if shape.is_fully_defined():
+ dim_list = shape.as_list()
+ else:
+ dim_list = []
+ ret = gen_array_ops._placeholder(
+ dtype=dtype,
+ shape=dim_list,
+ name=name)
+ ret.set_shape(shape)
+ return ret
+
+
+@ops.RegisterShape("Placeholder")
+def _PlaceholderShape(op):
+ given_shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))
+ if given_shape:
+ return [tensor_shape.TensorShape(given_shape)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("CheckNumerics")
+@ops.RegisterShape("Identity")
+@ops.RegisterShape("RefIdentity")
+@ops.RegisterShape("StopGradient")
+def _UnchangedShape(op):
+ return [op.inputs[0].get_shape()]
+
+
+@ops.RegisterShape("Rank")
+@ops.RegisterShape("Size")
+def _ScalarShape(unused_op):
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("Slice")
+def _SliceShape(op):
+ """Shape function for array_ops.slice."""
+ input_shape = op.inputs[0].get_shape()
+ begin_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ sizes_shape = op.inputs[2].get_shape().with_rank_at_most(1)
+ rank_vector_shape = begin_shape.merge_with(sizes_shape)
+ ndims = rank_vector_shape.num_elements()
+ if ndims is not None:
+ input_shape.assert_has_rank(ndims)
+ begin_value = tensor_util.ConstantValue(op.inputs[1])
+ sizes_value = tensor_util.ConstantValue(op.inputs[2])
+ if sizes_value is not None:
+ returned_dims = []
+ for i, slice_size in enumerate(sizes_value.ravel()):
+ if slice_size != -1:
+ returned_dims.append(slice_size)
+ elif begin_value is not None:
+ returned_dims.append(input_shape[i] - begin_value[i])
+ else:
+ returned_dims.append(None)
+ return [tensor_shape.TensorShape(returned_dims)]
+ else:
+ if input_shape.ndims is not None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ elif ndims is not None:
+ return [tensor_shape.unknown_shape(ndims=ndims)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("Gather")
+def _GatherShape(op):
+ """Shape function for array_ops.gather."""
+ params_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape()
+ return [indices_shape.concatenate(params_shape[1:])]
+
+
+@ops.RegisterShape("Unique")
+def _UniqueShape(op):
+ """Shape function for array_ops.Unique."""
+ # The output is a vector with data-dependent length.
+ input_shape = op.inputs[0].get_shape()
+ input_shape.assert_has_rank(1)
+ return [tensor_shape.vector(None), input_shape]
+
+
+@ops.RegisterShape("Diag")
+def _DiagShape(op):
+ """Shape function for array_ops.diag.
+
+ This op has one input (of rank k <= 3), and one output (of rank 2k),
+ where the shape of the output is the concatenation of the input
+ shape with itself.
+
+ Args:
+ op: A Diag Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank_at_most(3)
+ return [input_shape.concatenate(input_shape)]
+
+
+@ops.RegisterShape("ExpandDims")
+def _ExpandDimsShape(op):
+ """Determine shape for expand op's output tensor.
+
+ Args:
+ op: Operation for which to determine shape.
+ op.inputs[0] is the input tensor.
+ op.inputs[1] is the dimension in which to expand.
+ Returns:
+ Shape of op's output tensor.
+ Raises:
+ ValueError: If dim is outside of [-rank - 1, rank], where rank is the number
+ of dimensions in the input tensor.
+ """
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.dims is None:
+ return [tensor_shape.unknown_shape()]
+ dim = tensor_util.ConstantValue(op.inputs[1])
+ input_ndims = input_shape.ndims
+ if dim < -input_ndims - 1 or dim > input_ndims:
+ raise ValueError(
+ "dim %d not in [%d, %d]." % (dim, -input_ndims, input_ndims))
+ if dim < 0:
+ dim += (input_ndims + 1)
+ result_shape = list(input_shape.dims)
+ result_shape.insert(dim, 1)
+ return [tensor_shape.TensorShape(result_shape)]
+
+
+@ops.RegisterShape("Squeeze")
+def _SqueezeShape(op):
+ """Determine shape for squeeze op's output tensor.
+
+ Args:
+ op: Operation for which to determine shape.
+ Returns:
+ Shape of op's output tensor.
+ Raises:
+ ValueError: if squeeze_dims includes a dimension outside of [-rank, rank),
+ where rank is the number of dimensions in the input tensor. Or, if
+ squeeze_dims includes a dimension for which input shape has a value
+ not equal to 1.
+ """
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.dims is None:
+ return [tensor_shape.unknown_shape()]
+
+ squeeze_dims = op.get_attr("squeeze_dims") or []
+ wrapped_squeeze_dims = []
+ input_ndims = input_shape.ndims
+ for i, squeeze_dim in enumerate(squeeze_dims):
+ if squeeze_dim < -input_ndims or squeeze_dim >= input_ndims:
+ raise ValueError(
+ "squeeze_dims[%d]=%d not in [%d, %d)." % (
+ i, squeeze_dim, -input_ndims, input_ndims))
+ if squeeze_dim < 0:
+ squeeze_dim += input_ndims
+ wrapped_squeeze_dims.append(squeeze_dim)
+
+ result_shape = []
+ for i, dim in enumerate([d.value for d in input_shape.dims]):
+ is_explicit_match = i in wrapped_squeeze_dims
+ if is_explicit_match or not wrapped_squeeze_dims:
+ if dim is None:
+ return [tensor_shape.unknown_shape()]
+ if dim != 1:
+ if is_explicit_match:
+ raise ValueError(
+ "Can not squeeze dim[%d], expected a dimension of 1, got %d." % (
+ i, dim))
+ result_shape.append(dim)
+ else:
+ result_shape.append(dim)
+ return [tensor_shape.TensorShape(result_shape)]
+
+
+@ops.RegisterShape("Reshape")
+def _ReshapeShape(op):
+ """Shape function for Reshape op."""
+ input_shape = op.inputs[0].get_shape()
+ new_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ new_shape = tensor_util.ConstantValue(op.inputs[1])
+ if new_shape is None:
+ # Attempt to infer the rank of the output from the length of
+ # new_shape.
+ return [tensor_shape.unknown_shape(ndims=new_shape_shape.num_elements())]
+ new_shape = np.reshape(new_shape, -1).tolist()
+ if -1 not in new_shape:
+ # The new shape is fully defined.
+ return [tensor_shape.TensorShape(new_shape)]
+ elif input_shape.is_fully_defined():
+ # We know the input shape, so we can calculate the missing
+ # dimension in the new_shape.
+ num_elements = 1
+ for dim in input_shape.dims:
+ num_elements *= dim.value
+ known_elements = 1
+ unknown_index = None
+ for i, dim in enumerate(new_shape):
+ if dim == -1:
+ unknown_index = i
+ else:
+ known_elements *= dim
+ if known_elements == 0:
+ raise ValueError("cannot infer the missing input size for "
+ "an empty tensor unless all specified "
+ "input sizes are non-zero")
+ if num_elements % known_elements != 0:
+ raise ValueError("input has %s elements, which isn't divisible by %d" %
+ (num_elements, known_elements))
+ new_shape[unknown_index] = num_elements / known_elements
+ return [tensor_shape.TensorShape(new_shape)]
+ else:
+ # We don't know the input shape, but we know n-1 of the dimensions
+ # in the new shape.
+ new_shape[new_shape.index(-1)] = None
+ return [tensor_shape.TensorShape(new_shape)]
+
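The `-1` inference above amounts to balancing element counts; a standalone sketch (illustration only):

```python
def infer_missing_dim(input_dims, new_shape):
    # The unknown (-1) dimension is whatever makes the element counts match.
    num_elements = 1
    for d in input_dims:
        num_elements *= d
    known = 1
    for d in new_shape:
        if d != -1:
            known *= d
    out = list(new_shape)
    out[out.index(-1)] = num_elements // known
    return out

print(infer_missing_dim([2, 3, 4], [4, -1]))  # [4, 6]
```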
+
+@ops.RegisterShape("BroadcastGradientArgs")
+def _BroadcastGradientArgsShape(op):
+ """Shape function for the BroadcastGradientArgs op."""
+ # TODO(mrry): Implement ConstantValue for BroadcastGradientArgs?
+ op.inputs[0].get_shape().assert_has_rank(1)
+ op.inputs[1].get_shape().assert_has_rank(1)
+ return [tensor_shape.vector(None), tensor_shape.vector(None)]
+
+
+@ops.RegisterShape("Fill")
+def _FillShape(op):
+ """Shape function for the Fill op.
+
+ This op takes a vector of dimensions and a scalar, and produces a
+ tensor with the given dimensions.
+
+ Args:
+ op: A Fill Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ dimensions_shape = op.inputs[0].get_shape().with_rank_at_most(1)
+ op.inputs[1].get_shape().assert_is_compatible_with(tensor_shape.scalar())
+ fill_dims = tensor_util.ConstantValue(op.inputs[0])
+ if fill_dims is None:
+ # Attempt to infer the rank of the output from the length of
+ # dimensions.
+ return [tensor_shape.unknown_shape(ndims=dimensions_shape.num_elements())]
+ else:
+ return [tensor_shape.TensorShape(fill_dims.tolist())]
+
+
+@ops.RegisterShape("InvertPermutation")
+def _InvertPermutationShape(op):
+ """Shape function for the InvertPermutation op."""
+ return [op.inputs[0].get_shape().with_rank(1)]
+
+
+@ops.RegisterShape("ListDiff")
+def _ListDiffShape(op):
+ """Shape function for the ListDiff op."""
+ op.inputs[0].get_shape().assert_has_rank(1)
+ op.inputs[1].get_shape().assert_has_rank(1)
+ # TODO(mrry): Indicate that the length falls within an interval?
+ return [tensor_shape.vector(None)] * 2
+
+
+@ops.RegisterShape("Pad")
+def _PadShape(op):
+ """Shape function for the Pad op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor.
+ * paddings: An N-by-2 matrix, in which the i^th row contains the
+ number of padding elements to add before and after `input` in the
+ i^th dimension.
+
+ It has one output, which has the same rank as input, and additional
+ elements according to the values in paddings.
+
+ Args:
+ op: A Pad Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the input shapes are incompatible.
+ """
+ paddings_shape = op.inputs[1].get_shape().with_rank(2)
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims == 0 and paddings_shape[0].value == 1:
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ input_shape = tensor_shape.TensorShape([1])
+ else:
+ input_shape = input_shape.with_rank(paddings_shape[0].value)
+ paddings_shape = paddings_shape.merge_with(
+ tensor_shape.matrix(input_shape.ndims, 2))
+ paddings = tensor_util.ConstantValue(op.inputs[1])
+ if paddings is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ for i, dim in enumerate(input_shape.dims):
+ if paddings[i, 0] < 0 or paddings[i, 1] < 0:
+ raise ValueError("paddings must be non-negative")
+ output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
+ return [tensor_shape.TensorShape(output_dims)]
+
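Numerically, the Pad shape rule above is just `output_dim = input_dim + pad_before + pad_after`; for example:

```python
input_dims = [2, 3]
paddings = [[1, 1], [0, 2]]  # [pad_before, pad_after] per dimension
print([d + b + a for d, (b, a) in zip(input_dims, paddings)])  # [4, 5]
```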
+
+@ops.RegisterShape("ReverseSequence")
+def _ReverseSequenceShape(op):
+ """Shape function for the ReverseSequence op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor with size B in the 0th dimension.
+ * seq_lens: A vector of length B.
+
+ It has one output, with the same size as input.
+
+ Args:
+ op: A ReverseSequence Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the input shapes are incompatible.
+ """
+ input_shape = op.inputs[0].get_shape()
+ seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
+ batch_size = input_shape[0].merge_with(seq_lens_shape[0])
+ input_shape = tensor_shape.TensorShape([batch_size]).concatenate(
+ input_shape[1:])
+ seq_dim = op.get_attr("seq_dim")
+ if seq_dim >= input_shape.ndims:
+ raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
+ (seq_dim, input_shape.ndims))
+ return [input_shape]
+
+
+@ops.RegisterShape("Shape")
+def _ShapeShape(op):
+ """Shape function for the Shape op."""
+ input_shape = op.inputs[0].get_shape()
+ return [tensor_shape.vector(input_shape.ndims)]
+
+
+@ops.RegisterShape("Transpose")
+def _TransposeShape(op):
+ """Shape function for the Transpose op.
+
+ This op takes two inputs:
+
+ * input: a rank-N tensor of arbitrary shape.
+ * shuffle: a length-N vector.
+
+ Its output is the rank-N tensor computed by permuting the dimensions
+ of input according to shuffle.
+
+ Args:
+ op: A Transpose op.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the shapes of input and shuffle are incompatible.
+ IndexError: If shuffle contains an index that is >= the rank of input.
+ """
+ input_shape = op.inputs[0].get_shape()
+ transpose_shape = op.inputs[1].get_shape().merge_with(tensor_shape.vector(
+ input_shape.ndims))
+ transpose_vec = tensor_util.ConstantValue(op.inputs[1])
+ if transpose_vec is None:
+ return [tensor_shape.unknown_shape(ndims=transpose_shape[0].value)]
+ else:
+ return [tensor_shape.TensorShape([input_shape[i]
+ for i in transpose_vec.tolist()])]
+
+
+@ops.RegisterShape("Split")
+def _SplitShape(op):
+ """Shape function for the Split op."""
+ split_dim = tensor_util.ConstantValue(op.inputs[0])
+ num_split = len(op.outputs)
+ input_shape = op.inputs[1].get_shape()
+ if split_dim is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] * num_split
+ else:
+ split_dim = int(split_dim)
+ input_shape = input_shape.with_rank_at_least(split_dim + 1)
+ if not (input_shape[split_dim] % num_split).is_compatible_with(0):
+ raise ValueError(
+ "Number of ways to split should evenly divide the split "
+ "dimension but got split_dim %d (size = %d) and num_split %d" %
+ (split_dim, input_shape[split_dim].value, num_split))
+ prefix = input_shape[:split_dim]
+ size_in_split_dim = input_shape[split_dim] / num_split
+ suffix = input_shape[split_dim + 1:]
+ output_shape = prefix.concatenate(size_in_split_dim).concatenate(suffix)
+ return [output_shape] * num_split
+
+
+@ops.RegisterShape("Tile")
+def _TileShape(op):
+ """Shape function for the Tile op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor.
+ * multiples: A length-N vector, in which the i^th element contains
+ the factor by which `input` will be tiled in the i^th dimension.
+
+ It has one output, which has the same rank as input, and additional
+ elements according to the values in multiples
+
+ Args:
+ op: A Tile Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+  input_shape = op.inputs[0].get_shape().with_rank(
+      multiples_shape.num_elements())
+ multiples = tensor_util.ConstantValue(op.inputs[1])
+ if multiples is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ multiples = multiples.ravel()
+ for i, dim in enumerate(input_shape.dims):
+ output_dims.append(dim * multiples[i])
+ return [tensor_shape.TensorShape(output_dims)]
+
+
+@ops.RegisterShape("TileGrad")
+def _TileGradShape(op):
+ """Shape function for the TileGrad op."""
+ multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+  input_shape = op.inputs[0].get_shape().with_rank(
+      multiples_shape.num_elements())
+ multiples = tensor_util.ConstantValue(op.inputs[1])
+ if multiples is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ for i, dim in enumerate(input_shape.dims):
+ output_dims.append(dim / multiples[i])
+ return [tensor_shape.TensorShape(output_dims)]
+
+
+@ops.RegisterShape("Where")
+def _WhereShape(op):
+ """Shape function for the Where op."""
+ input_shape = op.inputs[0].get_shape()
+ return [tensor_shape.matrix(None, input_shape.ndims)]
+
+
+@ops.RegisterShape("ZerosLike")
+def _ZerosLikeShape(op):
+ """Shape function for the ZerosLike op."""
+ return [op.inputs[0].get_shape()]
+
+
+def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
+ """Computes the Levenshtein distance between sequences.
+
+ This operation takes variable-length sequences (`hypothesis` and `truth`),
+ each provided as a `SparseTensor`, and computes the Levenshtein distance.
+ You can normalize the edit distance by length of `truth` by setting
+ `normalize` to true.
+
+ For example, given the following input:
+
+ ```python
+ # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values:
+ # (0,0) = ["a"]
+ # (1,0) = ["b"]
+ hypothesis = tf.SparseTensor(
+ [[0, 0, 0],
+ [1, 0, 0]],
+ ["a", "b"]
+ (2, 1, 1))
+
+ # 'truth' is a tensor of shape `[2, 2]` with variable-length values:
+ # (0,0) = []
+ # (0,1) = ["a"]
+ # (1,0) = ["b", "c"]
+ # (1,1) = ["a"]
+ truth = tf.SparseTensor(
+ [[0, 1, 0],
+ [1, 0, 0],
+ [1, 0, 1],
+       [1, 1, 0]],
+ ["a", "b", "c", "a"],
+ (2, 2, 2))
+
+ normalize = True
+ ```
+
+ This operation would return the following:
+
+ ```python
+ # 'output' is a tensor of shape `[2, 2]` with edit distances normalized
+ # by 'truth' lengths.
+ output ==> [[inf, 1.0], # (0,0): no truth, (0,1): no hypothesis
+ [0.5, 1.0]] # (1,0): addition, (1,1): no hypothesis
+ ```
+
+ Args:
+ hypothesis: A `SparseTensor` containing hypothesis sequences.
+ truth: A `SparseTensor` containing truth sequences.
+ normalize: A `bool`. If `True`, normalizes the Levenshtein distance by
+      length of `truth`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A dense `Tensor` with rank `R - 1`, where R is the rank of the
+ `SparseTensor` inputs `hypothesis` and `truth`.
+
+ Raises:
+ TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
+ """
+ if not isinstance(hypothesis, ops.SparseTensor):
+ raise TypeError("Hypothesis must be a SparseTensor")
+ if not isinstance(truth, ops.SparseTensor):
+ raise TypeError("Truth must be a SparseTensor")
+
+ return gen_array_ops._edit_distance(hypothesis.indices,
+ hypothesis.values,
+ hypothesis.shape,
+ truth.indices,
+ truth.values,
+ truth.shape,
+ normalize=normalize,
+ name=name)
+
+
+@ops.RegisterShape("EditDistance")
+def _EditDistanceShape(op):
+ """Shape function for the EditDistance op."""
+ hypothesis_shape = tensor_util.ConstantValue(op.inputs[2])
+ truth_shape = tensor_util.ConstantValue(op.inputs[5])
+ if hypothesis_shape is not None and truth_shape is not None:
+ if len(hypothesis_shape) != len(truth_shape):
+ raise ValueError(
+ "Inconsistent ranks in hypothesis and truth. Saw shapes: %s and %s" %
+ (str(hypothesis_shape), str(truth_shape)))
+ return [tensor_shape.TensorShape(
+ [max(h, t) for h, t in zip(hypothesis_shape[:-1], truth_shape[:-1])])]
+
+ return [tensor_shape.unknown_shape()]
+
+
+# The remaining ops do not change the shape of their inputs.
+@ops.RegisterShape("Quantize")
+@ops.RegisterShape("Dequantize")
+def _QuantizeDequantizeShape(op):
+ unused_min_range = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
+ unused_max_range = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
+ return common_shapes.unchanged_shape(op)
diff --git a/tensorflow/python/ops/attention_ops.py b/tensorflow/python/ops/attention_ops.py
new file mode 100644
index 0000000000..4829bcd7cd
--- /dev/null
+++ b/tensorflow/python/ops/attention_ops.py
@@ -0,0 +1,34 @@
+"""Operations for implementing attention.
+"""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_attention_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_attention_ops import *
+
+
+# TODO(bsteiner): Implement the gradient function for extract_glimpse
+ops.NoGradient("ExtractGlimpse")
+
+
+@ops.RegisterShape("ExtractGlimpse")
+def _ExtractGlimpseShape(op):
+ """Shape function for ExtractGlimpse op."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ unused_size_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(2))
+ offsets_shape = op.inputs[2].get_shape().merge_with(
+ input_shape[:1].concatenate([2]))
+ offsets_shape = offsets_shape
+ size_value = tensor_util.ConstantValue(op.inputs[1])
+ if size_value is not None:
+ height = size_value[0]
+ width = size_value[1]
+ else:
+ height = None
+ width = None
+ return [tensor_shape.TensorShape(
+ [input_shape[0], height, width, input_shape[3]])]
diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py
new file mode 100644
index 0000000000..06857c0adc
--- /dev/null
+++ b/tensorflow/python/ops/candidate_sampling_ops.py
@@ -0,0 +1,365 @@
+"""Wrappers for primitive Neural Net (NN) Operations."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_candidate_sampling_ops
+from tensorflow.python.ops import math_ops
+
+
+def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ range_max, seed=None, name=None):
+ """Samples a set of classes using a uniform base distribution.
+
+ This operation randomly samples a tensor of sampled classes
+  (`sampled_candidates`) from the range of integers `[0, range_max)`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution for this operation is the uniform distribution
+  over the range of integers `[0, range_max)`.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._uniform_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max, seed=seed1,
+ seed2=seed2, name=name)
+
+
+def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ range_max, seed=None, name=None):
+ """Samples a set of classes using a log-uniform (Zipfian) base distribution.
+
+ This operation randomly samples a tensor of sampled classes
+  (`sampled_candidates`) from the range of integers `[0, range_max)`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution for this operation is an approximately log-uniform
+ or Zipfian distribution:
+
+ `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
+
+ This sampler is useful when the target classes approximately follow such
+ a distribution - for example, if the classes represent words in a lexicon
+ sorted in decreasing order of frequency. If your classes are not ordered by
+ decreasing frequency, do not use this op.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._log_uniform_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max, seed=seed1,
+ seed2=seed2, name=name)
+
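+# For intuition, the Zipfian base distribution above can be evaluated by hand.
+# With `range_max = 3` (classes 0, 1, 2):
+#
+#   P(0) = (log(2) - log(1)) / log(4) = 0.500
+#   P(1) = (log(3) - log(2)) / log(4) ~ 0.292
+#   P(2) = (log(4) - log(3)) / log(4) ~ 0.208
+#
+# The probabilities sum to 1 and decay with class id, so lower ids (the more
+# frequent words in a frequency-sorted vocabulary) are sampled more often.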
+
+def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
+ unique, range_max, seed=None, name=None):
+ """Samples a set of classes from a distribution learned during training.
+
+ This operation randomly samples a tensor of sampled classes
+ (`sampled_candidates`) from the range of integers `[0, range_max)`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution for this operation is constructed on the fly
+ during training. It is a unigram distribution over the target
+ classes seen so far during training. Every integer in `[0, range_max)`
+ begins with a weight of 1, and is incremented by 1 each time it is
+ seen as a target class. The base distribution is not saved to checkpoints,
+ so it is reset when the model is reloaded.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._learned_unigram_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max, seed=seed1,
+ seed2=seed2, name=name)
+
+
+def fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ range_max, vocab_file='', distortion=0.0,
+ num_reserved_ids=0, num_shards=1, shard=0,
+ unigrams=[], seed=None, name=None):
+ """Samples a set of classes using the provided (fixed) base distribution.
+
+ This operation randomly samples a tensor of sampled classes
+ (`sampled_candidates`) from the range of integers `[0, range_max)`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution is read from a file or passed in as an
+ in-memory array. There is also an option to skew the distribution by
+ applying a distortion power to the weights.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ vocab_file: Each valid line in this file (which should have a CSV-like
+ format) corresponds to a valid word ID. IDs are in sequential order,
+ starting from num_reserved_ids. The last entry in each line is expected
+ to be a value corresponding to the count or relative probability. Exactly
+ one of `vocab_file` and `unigrams` needs to be passed to this operation.
+ distortion: The distortion is used to skew the unigram probability
+ distribution. Each weight is first raised to the distortion's power
+ before adding to the internal unigram distribution. As a result,
+ `distortion = 1.0` gives regular unigram sampling (as defined by the vocab
+ file), and `distortion = 0.0` gives a uniform distribution.
+ num_reserved_ids: Optionally some reserved IDs can be added in the range
+ `[0, num_reserved_ids)` by the users. One use case is that a special
+ unknown word token is used as ID 0. These IDs will have a sampling
+ probability of 0.
+ num_shards: A sampler can be used to sample from a subset of the original
+ range in order to speed up the whole computation through parallelism. This
+ parameter (together with `shard`) indicates the number of partitions that
+ are being used in the overall computation.
+ shard: A sampler can be used to sample from a subset of the original range
+ in order to speed up the whole computation through parallelism. This
+ parameter (together with `num_shards`) indicates the particular partition
+ number of the operation, when partitioning is being used.
+ unigrams: A list of unigram counts or probabilities, one per ID in
+ sequential order. Exactly one of `vocab_file` and `unigrams` should be
+ passed to this operation.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._fixed_unigram_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max,
+ vocab_file=vocab_file, distortion=distortion,
+ num_reserved_ids=num_reserved_ids, num_shards=num_shards, shard=shard,
+ unigrams=unigrams, seed=seed1, seed2=seed2, name=name)
+
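+# A minimal usage sketch (illustrative only; `true_classes` is assumed to be an
+# `int64` Tensor of shape [1, 2] over a vocabulary of 4 classes with counts
+# [10, 5, 2, 1]):
+#
+#   sampled, true_expected, sampled_expected = fixed_unigram_candidate_sampler(
+#       true_classes, num_true=2, num_sampled=3, unique=True, range_max=4,
+#       distortion=0.75, unigrams=[10.0, 5.0, 2.0, 1.0])
+#   # Each count is raised to the power 0.75 before the distribution is
+#   # normalized, flattening it slightly relative to the raw counts.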
+
+def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ seed=None, name=None):
+ """Generate the set of all classes.
+
+ Deterministically generates and returns the set of all possible classes.
+ This is intended for testing purposes; in practice there is no need to use
+ it, since you might as well use full softmax or full logistic regression.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of possible classes.
+ unique: A `bool`. Ignored.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ This operation deterministically returns the entire range
+ `[0, num_sampled)`.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`. All returned values are 1.0.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`. All returned values are 1.0.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._all_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, seed=seed1, seed2=seed2,
+ name=name)
+
+
+def compute_accidental_hits(true_classes, sampled_candidates, num_true,
+ seed=None, name=None):
+ """Compute the ids of positions in sampled_candidates matching true_classes.
+
+ In Candidate Sampling, this operation facilitates virtually removing
+ sampled classes which happen to match target classes. This is done
+ in Sampled Softmax and Sampled Logistic.
+
+ See our [Candidate Sampling Algorithms
+ Reference](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+
+ We presuppose that the `sampled_candidates` are unique.
+
+ We call it an 'accidental hit' when one of the target classes
+ matches one of the sampled classes. This operation reports
+ accidental hits as triples `(index, id, weight)`, where `index`
+ represents the row number in `true_classes`, `id` represents the
+ position in `sampled_candidates`, and weight is `-FLOAT_MAX`.
+
+ The result of this op should be passed through a `sparse_to_dense`
+ operation, then added to the logits of the sampled classes. This
+ removes the contradictory effect of accidentally sampling the true
+ target classes as noise classes for the same example.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled_candidates output of CandidateSampler.
+ num_true: An `int`. The number of target classes per training example.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ indices: A `Tensor` of type `int32` and shape `[num_accidental_hits]`.
+ Values indicate rows in `true_classes`.
+ ids: A `Tensor` of type `int64` and shape `[num_accidental_hits]`.
+ Values indicate positions in `sampled_candidates`.
+ weights: A `Tensor` of type `float` and shape `[num_accidental_hits]`.
+ Each value is `-FLOAT_MAX`.
+
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._compute_accidental_hits(
+ true_classes, sampled_candidates, num_true, seed=seed1, seed2=seed2,
+ name=name)
+
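+# A rough sketch of the intended flow (illustrative only; the exact
+# `sparse_to_dense` call is elided):
+#
+#   indices, ids, weights = compute_accidental_hits(
+#       true_classes, sampled_candidates, num_true)
+#   # Scatter `weights` (each -FLOAT_MAX) into a dense
+#   # [batch_size, num_sampled] tensor at positions (indices, ids), then add
+#   # that tensor to the sampled logits so that accidentally sampled true
+#   # classes contribute essentially zero probability.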
+
+@ops.RegisterShape("AllCandidateSampler")
+@ops.RegisterShape("FixedUnigramCandidateSampler")
+@ops.RegisterShape("LearnedUnigramCandidateSampler")
+@ops.RegisterShape("LogUniformCandidateSampler")
+@ops.RegisterShape("ThreadUnsafeUnigramCandidateSampler")
+@ops.RegisterShape("UniformCandidateSampler")
+def _CandidateSamplerShape(op):
+ true_classes_shape = op.inputs[0].get_shape().with_rank(2)
+ batch_size = true_classes_shape[0]
+ num_sampled = op.get_attr("num_sampled")
+ num_true = op.get_attr("num_true")
+ return [tensor_shape.vector(num_sampled),
+ tensor_shape.matrix(batch_size, num_true),
+ tensor_shape.vector(num_sampled)]
+
+
+@ops.RegisterShape("ComputeAccidentalHits")
+def _ComputeAccidentalHitsShape(op):
+ num_true = op.get_attr("num_true")
+ # Validate that the input shape matches the attrs, even though it
+ # does not influence the shape of the output.
+ true_candidates_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.matrix(None, num_true))
+ output_shape = tensor_shape.vector(None)
+ return [output_shape] * 3
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
new file mode 100644
index 0000000000..08781932f9
--- /dev/null
+++ b/tensorflow/python/ops/clip_ops.py
@@ -0,0 +1,234 @@
+"""Operations for clipping (gradient, weight) tensors to min/max values."""
+
+import collections
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+
+
+def clip_by_value(t, clip_value_min, clip_value_max,
+ name=None):
+ """Clips tensor values to a specified min and max.
+
+ Given a tensor `t`, this operation returns a tensor of the same type and
+ shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
+ Any values less than `clip_value_min` are set to `clip_value_min`. Any values
+ greater than `clip_value_max` are set to `clip_value_max`.
+
+ Args:
+ t: A `Tensor`.
+ clip_value_min: A 0-D (scalar) `Tensor`. The minimum value to clip by.
+ clip_value_max: A 0-D (scalar) `Tensor`. The maximum value to clip by.
+ name: A name for the operation (optional).
+
+ Returns:
+ A clipped `Tensor`.
+ """
+ with ops.op_scope([t, clip_value_min, clip_value_max], name,
+ "clip_by_value") as name:
+ t = ops.convert_to_tensor(t, name="t")
+
+ # Go through list of tensors, for each value in each tensor clip
+ t_min = math_ops.minimum(
+ t, array_ops.fill(array_ops.shape(t), clip_value_max))
+ t_max = math_ops.maximum(
+ t_min, array_ops.fill(array_ops.shape(t), clip_value_min),
+ name=name)
+
+ return t_max
+
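+# Worked example (illustrative only): clipping to the range [-1, 5],
+#
+#   t = ops.convert_to_tensor([-4.0, 2.0, 6.0])
+#   clipped = clip_by_value(t, -1.0, 5.0)
+#   # `clipped` evaluates to [-1.0, 2.0, 5.0].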
+
+def clip_by_norm(t, clip_norm, name=None):
+ """Clips tensor values to a maximum L2-norm.
+
+ Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
+ normalizes `t` so that its L2-norm is less than or equal to `clip_norm`.
+ Specifically, if the L2-norm is already less than or equal to `clip_norm`,
+ then `t` is not modified. If the L2-norm is greater than `clip_norm`, then
+ this operation returns a tensor of the same type and shape as `t` with its
+ values set to:
+
+ `t * clip_norm / l2norm(t)`
+
+ In this case, the L2-norm of the output tensor is `clip_norm`.
+
+ This operation is typically used to clip gradients before applying them with
+ an optimizer.
+
+ Args:
+ t: A `Tensor`.
+ clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
+ name: A name for the operation (optional).
+
+ Returns:
+ A clipped `Tensor`.
+ """
+ with ops.op_scope([t, clip_norm], name, "clip_by_norm") as name:
+ t = ops.convert_to_tensor(t, name="t")
+
+ # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
+ l2norm_inv = math_ops.rsqrt(
+ math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+ tclip = array_ops.identity(t * clip_norm * math_ops.minimum(
+ l2norm_inv, constant_op.constant(1.0 / clip_norm)), name=name)
+
+ return tclip
+
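+# Worked example (illustrative only): for t = [3.0, 4.0] the L2-norm is 5.0, so
+# with clip_norm = 2.0 the result is t * 2.0 / 5.0 = [1.2, 1.6], whose L2-norm
+# is exactly 2.0. With clip_norm >= 5.0, t would pass through unchanged.
+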
+def global_norm(t_list, name=None):
+ """Computes the global norm of multiple tensors.
+
+ Given a tuple or list of tensors `t_list`, this operation returns the
+ global norm of the elements in all tensors in `t_list`. The global norm is
+ computed as:
+
+ `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`
+
+ Any entries in `t_list` that are `None` are ignored.
+
+ Args:
+ t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 0-D (scalar) `Tensor` of type `float`.
+
+ Raises:
+ TypeError: If `t_list` is not a sequence.
+ """
+ if (not isinstance(t_list, collections.Sequence)
+ or isinstance(t_list, basestring)):
+ raise TypeError("t_list should be a sequence")
+ t_list = list(t_list)
+ with ops.op_scope(t_list, name, "global_norm") as name:
+ values = [
+ ops.convert_to_tensor(
+ t.values if isinstance(t, ops.IndexedSlices) else t,
+ name="t_%d" % i)
+ if t is not None else t
+ for i, t in enumerate(t_list)]
+ squared_norms = array_ops.pack(
+ [math_ops.reduce_sum(v * v) for v in values if v])
+
+ norm = math_ops.sqrt(
+ math_ops.reduce_sum(squared_norms), name="global_norm")
+
+ return norm
+
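+# Worked example (illustrative only): for t_list = [[3.0, 4.0], [12.0]] the
+# per-tensor L2-norms are 5.0 and 12.0, so
+# global_norm = sqrt(5.0**2 + 12.0**2) = sqrt(169.0) = 13.0.
+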
+def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
+ """Clips values of multiple tensors by the ratio of the sum of their norms.
+
+ Given a tuple or list of tensors `t_list`, and a clipping ratio `clip_norm`,
+ this operation returns a list of clipped tensors `list_clipped`
+ and the global norm (`global_norm`) of all tensors in `t_list`. Optionally,
+ if you've already computed the global norm for `t_list`, you can specify
+ the global norm with `use_norm`.
+
+ To perform the clipping, the values t_list[i] are set to:
+
+ `t_list[i] * clip_norm / max(global_norm, clip_norm)`
+
+ where:
+
+ `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`
+
+ If `clip_norm > global_norm` then the entries in `t_list` remain as they are,
+ otherwise they're all shrunk by the global ratio.
+
+ Any entries of `t_list` that are `None` are ignored.
+
+ This is the correct way to perform gradient clipping (for example, see
+ R. Pascanu, T. Mikolov, and Y. Bengio, "On the difficulty of training
+ Recurrent Neural Networks". http://arxiv.org/abs/1211.5063)
+
+ However, it is slower than `clip_by_norm()` because all the parameters must be
+ ready before the clipping operation can be performed.
+
+ Args:
+ t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
+ clip_norm: A 0-D (scalar) `Tensor` > 0. The clipping ratio.
+ use_norm: A 0-D (scalar) `Tensor` of type `float` (optional). The global
+ norm to use. If not provided, `global_norm()` is used to compute the norm.
+ name: A name for the operation (optional).
+
+ Returns:
+ list_clipped: A list of `Tensors` of the same type as `t_list`.
+ global_norm: A 0-D (scalar) `Tensor` representing the global norm.
+
+ Raises:
+ TypeError: If `t_list` is not a sequence.
+ """
+ if (not isinstance(t_list, collections.Sequence)
+ or isinstance(t_list, basestring)):
+ raise TypeError("t_list should be a sequence")
+ t_list = list(t_list)
+ if use_norm is None:
+ use_norm = global_norm(t_list, name)
+
+ with ops.op_scope(t_list + [clip_norm], name, "clip_by_global_norm") as name:
+ # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
+ scale = clip_norm * math_ops.minimum(
+ 1.0 / use_norm, constant_op.constant(1.0 / clip_norm))
+
+ values = [
+ ops.convert_to_tensor(
+ t.values if isinstance(t, ops.IndexedSlices) else t,
+ name="t_%d" % i)
+ if t is not None else t
+ for i, t in enumerate(t_list)]
+
+ values_clipped = [
+ array_ops.identity(v * scale, name="%s_%d" % (name, i))
+ if v is not None else None
+ for i, v in enumerate(values)]
+
+ list_clipped = [
+ ops.IndexedSlices(c_v, t.indices)
+ if isinstance(t, ops.IndexedSlices)
+ else c_v
+ for (c_v, t) in zip(values_clipped, t_list)]
+
+ return list_clipped, use_norm
+
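+# Worked example (illustrative only): continuing the example above, with
+# t_list = [[3.0, 4.0], [12.0]] the global norm is 13.0. Using clip_norm = 6.5,
+# every entry is scaled by 6.5 / 13.0 = 0.5, giving [[1.5, 2.0], [6.0]], whose
+# new global norm is 6.5. With clip_norm >= 13.0 the tensors are unchanged.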
+
+def clip_by_average_norm(t, clip_norm, name=None):
+ """Clips tensor values to a maximum average L2-norm.
+
+ Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
+ normalizes `t` so that its average L2-norm is less than or equal to
+ `clip_norm`. Specifically, if the average L2-norm is already less than or
+ equal to `clip_norm`, then `t` is not modified. If the average L2-norm is
+ greater than `clip_norm`, then this operation returns a tensor of the same
+ type and shape as `t` with its values set to:
+
+ `t * clip_norm / l2norm_avg(t)`
+
+ In this case, the average L2-norm of the output tensor is `clip_norm`.
+
+ This operation is typically used to clip gradients before applying them with
+ an optimizer.
+
+ Args:
+ t: A `Tensor`.
+ clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
+ name: A name for the operation (optional).
+
+ Returns:
+ A clipped `Tensor`.
+ """
+ with ops.op_scope([t, clip_norm], name, "clip_by_average_norm") as name:
+ t = ops.convert_to_tensor(t, name="t")
+
+ # Calculate L2-norm per element, clip elements by ratio of clip_norm to
+ # L2-norm per element
+ n_element = math_ops.cast(array_ops.size(t), types.float32)
+ l2norm_inv = math_ops.rsqrt(
+ math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+ tclip = array_ops.identity(
+ t * clip_norm * math_ops.minimum(
+ l2norm_inv * n_element, constant_op.constant(1.0 / clip_norm)),
+ name=name)
+
+ return tclip
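+
+# Worked example (illustrative only): for t = [3.0, 4.0] the L2-norm is 5.0 and
+# there are 2 elements, so the average L2-norm is 2.5. With clip_norm = 1.0 the
+# result is t * 1.0 * 2 / 5.0 = [1.2, 1.6], whose average L2-norm is 1.0.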
diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py
new file mode 100644
index 0000000000..c41d1ff71d
--- /dev/null
+++ b/tensorflow/python/ops/common_shapes.py
@@ -0,0 +1,371 @@
+"""A library of common shape functions."""
+import math
+
+from tensorflow.python.framework import tensor_shape
+
+
+def scalar_shape(unused_op):
+ """Shape function for ops that output a scalar value."""
+ return [tensor_shape.scalar()]
+
+
+def unchanged_shape(op):
+ """Shape function for ops that output an tensor like their first input."""
+ return [op.inputs[0].get_shape()]
+
+
+def unchanged_shape_with_rank(rank):
+ """Returns a shape function for ops that constrain the rank of their input.
+
+ Args:
+ rank: The exact rank of the input and output.
+
+ Returns:
+ A shape function for ops that output a tensor of the same size as their
+ input, with a particular rank.
+ """
+ def _ShapeFunction(op):
+ return [op.inputs[0].get_shape().with_rank(rank)]
+ return _ShapeFunction
+
+
+def unchanged_shape_with_rank_at_least(rank):
+ """Returns a shape function for ops that constrain the rank of their input.
+
+ Args:
+ rank: A lower bound on the rank of the input and output.
+
+ Returns:
+ A shape function for ops that output a tensor of the same size as their
+ input, with rank at least `rank`.
+ """
+ def _ShapeFunction(op):
+ return [op.inputs[0].get_shape().with_rank_at_least(rank)]
+ return _ShapeFunction
+
+
+def unchanged_shape_with_rank_at_most(rank):
+ """Returns a shape function for ops that constrain the rank of their input.
+
+ Args:
+ rank: An upper bound on the rank of the input and output.
+
+ Returns:
+ A shape function for ops that output a tensor of the same size as their
+ input, with rank at most `rank`.
+ """
+ def _ShapeFunction(op):
+ return [op.inputs[0].get_shape().with_rank_at_most(rank)]
+ return _ShapeFunction
+
+
+def matmul_shape(op):
+ """Shape function for a MatMul op."""
+ a_shape = op.inputs[0].get_shape().with_rank(2)
+ transpose_a = op.get_attr("transpose_a")
+ b_shape = op.inputs[1].get_shape().with_rank(2)
+ transpose_b = op.get_attr("transpose_b")
+ output_rows = a_shape[1] if transpose_a else a_shape[0]
+ output_cols = b_shape[0] if transpose_b else b_shape[1]
+ inner_a = a_shape[0] if transpose_a else a_shape[1]
+ inner_b = b_shape[1] if transpose_b else b_shape[0]
+ inner_a.assert_is_compatible_with(inner_b)
+ return [tensor_shape.TensorShape([output_rows, output_cols])]
+
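+# For example (illustrative only): with `a` of shape [2, 3] and `b` of shape
+# [3, 5] and no transposes, the inner dimensions (3 and 3) are compatible and
+# the inferred output shape is [2, 5]. With transpose_a=True, `a` would instead
+# need shape [3, 2] to produce the same [2, 5] output.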
+
+def bias_add_shape(op):
+ """Shape function for a BiasAdd op."""
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+ bias_shape = op.inputs[1].get_shape().with_rank(1)
+ if input_shape.ndims is not None:
+ # Output has the same shape as input, and matches the length of
+ # bias in its last dimension.
+ output_shape = input_shape[0:-1].concatenate(
+ input_shape[-1].merge_with(bias_shape[0]))
+ else:
+ output_shape = tensor_shape.unknown_shape()
+ return [output_shape]
+
+
+def _Get2DOutputSize(input_height, input_width, filter_height, filter_width,
+ row_stride, col_stride, padding_type):
+ """Returns the number of rows and columns in a convolution/pooling output."""
+ input_height = tensor_shape.as_dimension(input_height)
+ input_width = tensor_shape.as_dimension(input_width)
+ filter_height = tensor_shape.as_dimension(filter_height)
+ filter_width = tensor_shape.as_dimension(filter_width)
+ row_stride = int(row_stride)
+ col_stride = int(col_stride)
+
+ if filter_height.value == 1 and filter_width.value == 1 and (
+ row_stride == 1 and col_stride == 1):
+ return input_height, input_width
+ else:
+ if filter_height > input_height or filter_width > input_width:
+ raise ValueError("filter must not be larger than the input: ",
+ "Filter: [", filter_height, "x", filter_width, "] ",
+ "Input: [", input_height, "x", input_width, "] ")
+ if row_stride > filter_height or col_stride > filter_width:
+ raise ValueError("stride must be less than or equal to filter size",
+ "stride: [", row_stride, "x", col_stride, "] ",
+ "filter: [", filter_height, "x", filter_width, "] ")
+
+ # Compute number of rows in the output, based on the padding.
+ if input_height.value is None or filter_height.value is None:
+ out_rows = None
+ elif padding_type == "VALID":
+ out_rows = int(
+ math.ceil((input_height.value - filter_height.value + 1.0)
+ / row_stride))
+ elif padding_type == "SAME":
+ out_rows = int(math.ceil(input_height.value * 1.0
+ / row_stride))
+ else:
+ raise ValueError("Invalid value for padding: %r" % padding_type)
+
+ # Compute number of columns in the output, based on the padding.
+ if input_width.value is None or filter_width.value is None:
+ out_cols = None
+ elif padding_type == "VALID":
+ out_cols = int(
+ math.ceil((input_width.value - filter_width.value + 1.0)
+ / col_stride))
+ elif padding_type == "SAME":
+ out_cols = int(math.ceil(input_width.value * 1.0 / col_stride))
+
+ return out_rows, out_cols
+
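+# Worked example (illustrative only): for a 5x5 input, a 3x3 filter, and
+# strides of 2 in both dimensions:
+#
+#   "VALID" padding: out = ceil((5 - 3 + 1) / 2.0) = ceil(1.5) = 2
+#   "SAME"  padding: out = ceil(5 / 2.0)           = ceil(2.5) = 3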
+
+def conv2d_shape(op):
+ """Shape function for a Conv2D op.
+
+ This op has two inputs:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
+ * filter, a 4D tensor with shape = [filter_rows, filter_cols,
+ depth_in, depth_out]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows and out_cols depend on the
+ value of the op's "padding" and "strides" attrs.
+
+ Args:
+ op: A Conv2D Operation.
+
+ Returns:
+ A list containing the Shape of the Conv2D output.
+
+ Raises:
+ ValueError: If the shapes of the input or filter are incompatible.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ filter_shape = op.inputs[1].get_shape().with_rank(4)
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+
+ filter_rows = filter_shape[0]
+ filter_cols = filter_shape[1]
+ depth_out = filter_shape[3]
+ # Check that the input depths are compatible.
+ input_shape[3].assert_is_compatible_with(filter_shape[2])
+
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+ if stride_b != 1 or stride_d != 1:
+ raise ValueError("Current implementation does not yet support "
+ "strides in the batch and depth dimensions.")
+ if stride_r != stride_c:
+ # TODO(shlens): Add support for this.
+ raise ValueError("Current implementation only supports equal length "
+ "strides in the row and column dimensions.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ stride = stride_r
+ padding = op.get_attr("padding")
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
+
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
+
+
+def separable_conv2d_shape(op):
+ """Shape function for a SeparableConv2D op.
+
+ This op has three inputs:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
+
+ * depthwise_filter, a 4D tensor with shape = [filter_rows,
+ filter_cols, depth_in, depth_multiplier]
+
+ * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in *
+ depth_multiplier, depth_out]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows and out_cols depend on the
+ value of the op's "padding" and "strides" attrs.
+
+ Args:
+ op: A SeparableConv2D Operation.
+
+ Returns:
+ A list containing the Shape of the SeparableConv2D output.
+
+ Raises:
+ ValueError: If the shapes of the input or filter are incompatible.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ depthwise_filter_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.TensorShape([None, None, input_shape[3], None]))
+ pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3]
+
+ pointwise_filter_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.TensorShape([1, 1, pointwise_depth_in, None]))
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+
+ filter_rows = depthwise_filter_shape[0]
+ filter_cols = depthwise_filter_shape[1]
+ depth_out = pointwise_filter_shape[3]
+
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+ if stride_b != 1 or stride_d != 1:
+ raise ValueError("Current implementation does not yet support "
+ "strides in the batch and depth dimensions.")
+ if stride_r != stride_c:
+ # TODO(shlens): Add support for this.
+ raise ValueError("Current implementation only supports equal length "
+ "strides in the row and column dimensions.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ stride = stride_r
+ padding = op.get_attr("padding")
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
+
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
+
+
+def avg_pool_shape(op):
+ """Shape function for an AvgPool op.
+
+ This op has one input:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows and out_cols depend on the
+ value of the op's "ksize", "strides", and "padding" attrs.
+
+ Args:
+ op: An AvgPool Operation.
+
+ Returns:
+ A single-element list containing the Shape of the AvgPool output.
+
+ Raises:
+ ValueError: If the shape of the input is invalid or incompatible with
+ the values of the attrs.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+ depth = input_shape[3]
+
+ if ksize_b != 1 or ksize_d != 1:
+ raise ValueError("Current implementation does not support pooling "
+ "in the batch and depth dimensions.")
+ if stride_b != 1 or stride_d != 1:
+ raise ValueError("Current implementation does not support strides "
+ "in the batch and depth dimensions.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ padding = op.get_attr("padding")
+
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
+
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])]
+
+
+def max_pool_shape(op):
+ """Shape function for a MaxPool op.
+
+ This op has one input:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows, out_cols, and depth_out depend
+ on the value of the op's "ksize", "strides", and "padding" attrs.
+
+ Args:
+ op: A MaxPool Operation.
+
+ Returns:
+ A single-element list containing the Shape of the MaxPool output.
+
+ Raises:
+ ValueError: If the shape of the input is invalid or incompatible with
+ the values of the attrs.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+ depth = input_shape[3]
+
+ if ksize_b != 1:
+ raise ValueError("Current implementation does not support pooling "
+ "in the batch dimension.")
+ if stride_b != 1:
+ raise ValueError("Current implementation does not support strides "
+ "in the batch dimension.")
+
+ if not ((ksize_r == 1 and ksize_c == 1) or ksize_d == 1):
+ raise ValueError("MaxPooling supports exactly one of pooling across depth "
+ "or pooling across width/height.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ if ksize_d == 1:
+ padding = op.get_attr("padding")
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])]
+ else:
+ if depth % ksize_d > 0:
+ raise ValueError("Depthwise max pooling requires the depth window "
+ "to evenly divide the input depth.")
+ if stride_d != ksize_d:
+ raise ValueError("Depthwise max pooling requires the depth window "
+ "to equal the depth stride.")
+ return [tensor_shape.TensorShape(
+ [batch_size, in_rows, in_cols, depth / ksize_d])]
+
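+# For example (illustrative only): depthwise max pooling with
+# ksize = [1, 1, 1, 4] and strides = [1, 1, 1, 4] over an input of shape
+# [8, 32, 32, 16] yields an output of shape [8, 32, 32, 4]; the input depth 16
+# is evenly divided by the depth window 4, which also equals the depth stride.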
+
+def no_outputs(unused_op):
+ """Shape function for use with ops that have no outputs."""
+ return []
+
+
+def unknown_shape(op):
+ """Shape function for use with ops whose output shapes are unknown."""
+ return [tensor_shape.unknown_shape() for _ in op.outputs]
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
new file mode 100644
index 0000000000..7d9044b689
--- /dev/null
+++ b/tensorflow/python/ops/constant_op.py
@@ -0,0 +1,189 @@
+"""## Constant Value Tensors
+
+TensorFlow provides several operations that you can use to generate constants.
+
+@@zeros
+@@zeros_like
+
+@@ones
+@@ones_like
+
+@@fill
+
+@@constant
+
+## Sequences
+
+@@linspace
+
+@@range
+
+## Random Tensors
+
+TensorFlow has several ops that create random tensors with different
+distributions. The random ops are stateful, and create new random values each
+time they are evaluated.
+
+The `seed` keyword argument in these functions acts in conjunction with
+the graph-level random seed. Changing either the graph-level seed using
+[`set_random_seed`](constant_op.md#set_random_seed) or the op-level seed
+will change the underlying seed of these operations. Setting neither graph-level
+nor op-level seed, results in a random seed for all operations.
+See [`set_random_seed`](constant_op.md#set_random_seed) for details on the
+interaction between operation-level and graph-level random seeds.
+
+### Examples:
+
+```python
+# Create a tensor of shape [2, 3] consisting of random normal values, with mean
+# -1 and standard deviation 4.
+norm = tf.random_normal([2, 3], mean=-1, stddev=4)
+
+# Shuffle the first dimension of a tensor
+c = tf.constant([[1, 2], [3, 4], [5, 6]])
+shuff = tf.random_shuffle(c)
+
+# Each time we run these ops, different results are generated
+sess = tf.Session()
+print sess.run(norm)
+print sess.run(norm)
+
+# Set an op-level seed to generate repeatable sequences across sessions.
+norm = tf.random_normal([2, 3], seed=1234)
+sess = tf.Session()
+print sess.run(norm)
+print sess.run(norm)
+```
+
+Another common use of random values is the initialization of variables. Also see
+the [Variables How To](../../how_tos/variables/index.md).
+
+```python
+# Use random uniform values in [0, 1) as the initializer for a variable of shape
+# [2, 3]. The default type is float32.
+var = tf.Variable(tf.random_uniform([2, 3]), name="var")
+init = tf.initialize_all_variables()
+
+sess = tf.Session()
+sess.run(init)
+print sess.run(var)
+```
+
+@@random_normal
+@@truncated_normal
+@@random_uniform
+@@random_shuffle
+@@set_random_seed
+
+"""
+"""Constant Operation.
+
+Has to be separate from array_ops to avoid a cyclic dependency.
+"""
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+
+
+def constant(value, dtype=None, shape=None, name="Const"):
+ """Creates a constant tensor.
+
+ The resulting tensor is populated with values of type `dtype`, as
+ specified by arguments `value` and (optionally) `shape` (see examples
+ below).
+
+ The argument `value` can be a constant value, or a list of values of type
+ `dtype`. If `value` is a list, then the length of the list must be less
+ than or equal to the number of elements implied by the `shape` argument (if
+ specified). In the case where the list length is less than the number of
+ elements specified by `shape`, the last element in the list will be used
+ to fill the remaining entries.
+
+ The argument `shape` is optional. If present, it specifies the dimensions
+ of the resulting tensor. If not present, then the tensor is a scalar (0-D)
+ if `value` is a scalar, or 1-D otherwise.
+
+ If the argument `dtype` is not specified, then the type is inferred from
+ the type of `value`.
+
+ For example:
+
+ ```python
+ # Constant 1-D Tensor populated with value list.
+ tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]
+
+ # Constant 2-D tensor populated with scalar value -1.
+ tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
+ [-1. -1. -1.]]
+ ```
+
+ Args:
+ value: A constant value (or list) of output type `dtype`.
+
+ dtype: The type of the elements of the resulting tensor.
+
+ shape: Optional dimensions of resulting tensor.
+
+ name: Optional name for the tensor.
+
+ Returns:
+ A Constant Tensor.
+ """
+ g = ops.get_default_graph()
+ tensor_value = attr_value_pb2.AttrValue()
+ tensor_value.tensor.CopyFrom(
+ tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
+ dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
+ const_tensor = g.create_op(
+ "Const", [], [dtype_value.type],
+ attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
+ return const_tensor
+
+
+@ops.RegisterShape("Const")
+def _ConstantShape(op):
+ return [tensor_shape.TensorShape(
+ [d.size for d in op.get_attr("value").tensor_shape.dim])]
+
+
+ops.register_tensor_conversion_function((list, tuple), constant, 100)
+ops.register_tensor_conversion_function(np.ndarray, constant, 100)
+ops.register_tensor_conversion_function(np.generic, constant, 100)
+ops.register_tensor_conversion_function(object, constant, 200)
+
+def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None):
+ if not s.is_fully_defined():
+ raise ValueError(
+ "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
+ if dtype is not None:
+ if dtype not in (types.int32, types.int64):
+ raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
+ else:
+ dtype = types.int32
+ if name is None:
+ name = "shape_as_tensor"
+ return constant(s.as_list(), dtype=dtype, name=name)
+
+ops.register_tensor_conversion_function(
+ tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
+
+def _dimension_tensor_conversion_function(d, dtype=None, name=None):
+ if d.value is None:
+ raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
+ if dtype is not None:
+ if dtype not in (types.int32, types.int64):
+ raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
+ else:
+ dtype = types.int32
+ if name is None:
+ name = "shape_as_tensor"
+ return constant(d.value, dtype=dtype, name=name)
+
+ops.register_tensor_conversion_function(
+ tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
new file mode 100644
index 0000000000..3a1a5b91c0
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -0,0 +1,100 @@
+"""Gradients for operators defined in control_flow_ops.py."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.control_flow_ops import *
+from tensorflow.python.ops.gen_control_flow_ops import *
+
+
+@ops.RegisterGradient("Switch")
+def _SwitchGrad(op, *grad):
+ op = GetRealOp(op)
+ ctxt = op._get_control_flow_context() # pylint: disable=protected-access
+ if isinstance(ctxt, WhileContext):
+ merge_op = ctxt.switch_map.get(op)
+ if merge_op:
+ merge_op._update_input(1, grad[1])
+ return None, None
+ else:
+ merge_op = merge(grad, name="b_switch")[0]
+ ctxt.switch_map[op] = merge_op.op
+ return merge_op, None
+ elif isinstance(ctxt, CondContext):
+ good_grad = grad[ctxt.branch]
+ zero_grad = grad[1 - ctxt.branch]
+ zero_grad = switch(zero_grad, ctxt.pred, name="grad_0")[1 - ctxt.branch]
+ return merge([good_grad, zero_grad], name="switch_grad")[0], None
+ else:
+ false_grad = switch(grad[0], op.inputs[1])[0]
+ true_grad = switch(grad[1], op.inputs[1])[1]
+ return merge([false_grad, true_grad])[0], None
+
+
+@ops.RegisterGradient("RefSwitch")
+def _RefSwitchGrad(op, *grad):
+ return _SwitchGrad(op, *grad)
+
+
+@ops.RegisterGradient("Merge")
+def _MergeGrad(op, grad, _):
+ op = GetRealOp(op)
+ input_op = op.inputs[0].op
+ # pylint: disable=protected-access
+ ctxt = input_op._get_control_flow_context()
+ # pylint: enable=protected-access
+ if isinstance(ctxt, WhileContext):
+ grad_ctxt = ctxt.grad_context
+ return switch(grad, grad_ctxt.pivot)
+ elif isinstance(ctxt, CondContext):
+ return switch(grad, ctxt.pred, name="merge_grad")
+ else:
+ num_inputs = len(op.inputs)
+ cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
+ return [switch(grad, cond[i])[1] for i in xrange(num_inputs)]
+
+
+@ops.RegisterGradient("Exit")
+def _ExitGrad(op, grad):
+ # pylint: disable=protected-access
+ forward_ctxt = op._get_control_flow_context()
+ # pylint: enable=protected-access
+ if not forward_ctxt.back_prop:
+ return None
+ grad_ctxt = forward_ctxt.grad_context
+ grad_ctxt.AddName(grad.name)
+ return enter(grad, grad_ctxt.name, is_constant=False,
+ parallel_iterations=forward_ctxt.parallel_iterations,
+ name="b_exit")
+
+
+@ops.RegisterGradient("NextIteration")
+def _NextIterationGrad(_, grad):
+ return next_iteration(grad)
+
+
+@ops.RegisterGradient("Enter")
+def _EnterGrad(op, grad):
+ op = GetRealOp(op)
+ # pylint: disable=protected-access
+ forward_ctxt = op._get_control_flow_context()
+ # pylint: enable=protected-access
+ grad_ctxt = forward_ctxt.grad_context
+ if grad_ctxt:
+ if op.get_attr("is_constant"):
+ # Add a gradient accumulator for every loop invariant.
+ result = grad_ctxt.AddBackPropAccumulateLoop(grad)
+ else:
+ result = exit(grad)
+ return result
+ else:
+ return grad
+
+
+@ops.RegisterGradient("RefEnter")
+def _RefEnterGrad(op, grad):
+ return _EnterGrad(op, grad)
+
+
+@ops.RegisterGradient("LoopCond")
+def _LoopCondGrad(_):
+ return None
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
new file mode 100644
index 0000000000..068e3b5553
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -0,0 +1,1561 @@
+"""## Control Flow Operations
+
+TensorFlow provides several operations and classes that you can use to control
+the execution of operations and add conditional dependencies to your graph.
+
+@@identity
+@@tuple
+@@group
+@@no_op
+@@count_up_to
+
+## Logical Operators
+
+TensorFlow provides several operations that you can use to add logical operators
+to your graph.
+
+@@logical_and
+@@logical_not
+@@logical_or
+@@logical_xor
+
+## Comparison Operators
+
+TensorFlow provides several operations that you can use to add comparison
+operators to your graph.
+
+@@equal
+@@not_equal
+@@less
+@@less_equal
+@@greater
+@@greater_equal
+@@select
+@@where
+
+## Debugging Operations
+
+TensorFlow provides several operations that you can use to validate values and
+debug your graph.
+
+@@is_finite
+@@is_inf
+@@is_nan
+@@verify_tensor_all_finite
+@@check_numerics
+@@add_check_numerics_ops
+@@Assert
+@@Print
+"""
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_control_flow_ops import *
+
+
+# We override the 'tuple' for a control flow op, so we keep python's
+# existing 'tuple' for later use in this module.
+_basetuple = tuple
+
+
+# pylint: disable=protected-access
+def _Identity(data, name=None):
+ """Return a tensor with the same shape and contents as the input tensor.
+
+ Args:
+ data: A Tensor.
+ name: A name for this operation (optional).
+
+ Returns:
+ A Tensor with the same type and value as the input Tensor.
+ """
+ if not data.dtype.is_ref_dtype:
+ return array_ops.identity(data, name=name)
+ else:
+ return gen_array_ops._ref_identity(data, name=name)
+
+
+def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
+ name=None):
+ """Creates or finds a child frame, and makes 'data' available to it.
+
+ The unique `frame_name` is used by the `Executor` to identify frames. If
+ `is_constant` is true, `output` is a constant in the child frame; otherwise
+ it may be changed in the child frame. At most `parallel_iterations` iterations
+ are run in parallel in the child frame.
+
+ Args:
+ data: The tensor to be made available to the child frame.
+ frame_name: The name of the child frame.
+ is_constant: If true, the output is constant within the child frame.
+ parallel_iterations: The number of iterations allowed to run in parallel.
+ name: A name for this operation (optional).
+
+ Returns:
+ The same tensor as 'data'.
+ """
+ if not data.dtype.is_ref_dtype:
+ return enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
+ else:
+ return ref_enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
+
+
+def exit(data, name=None):
+ """Exits the current frame to its parent frame.
+
+ Exit makes its input `data` available to the parent frame.
+
+ Args:
+ data: The tensor to be made available to the parent frame.
+ name: A name for this operation (optional).
+
+ Returns:
+ The same tensor as `data`.
+ """
+ return gen_control_flow_ops._exit(data, name)
+
+
+def switch(data, pred, name=None):
+ """Forwards `data` to an output determined by `pred`.
+
+ If `pred` is false, the `data` input is forwarded to the first output.
+ Otherwise, the data goes to the second output.
+
+ This op handles `Tensor`s and `IndexedSlices`.
+
+ Args:
+ data: The tensor to be forwarded to the appropriate output.
+ pred: A scalar that specifies which output port will receive data.
+ name: A name for this operation (optional).
+
+ Returns:
+ `(output_false, output_true)`: If `pred` is true, data will be forwarded to
+ `output_true`, otherwise it goes to `output_false`.
+ """
+ with ops.op_scope([data, pred], name, "Switch") as name:
+ data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
+ pred = ops.convert_to_tensor(pred, name="pred")
+ if isinstance(data, ops.Tensor):
+ return gen_control_flow_ops._switch(data, pred, name=name)
+ else:
+ val, ind, dense_shape = data.values, data.indices, data.dense_shape
+ val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
+ ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
+ if dense_shape:
+ dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
+ dense_shape, pred, name="dense_shape")
+ else:
+ dense_shape_f, dense_shape_t = None, None
+ return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
+ ops.IndexedSlices(val_t, ind_t, dense_shape_t))
+
+
+def merge(inputs, name=None):
+ """Returns the value of an available element of `inputs`.
+
+ This op tests each of the tensors in `inputs` in turn to determine if any of
+ them is available. If it finds an available tensor, it returns it and its
+ index in `inputs`.
+
+ It is an error if more than one tensor in `inputs` is available. If no tensor
+ in `inputs` is available, the returned tensor and index are not set.
+
+ This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
+ `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
+ before merging.
+
+ Args:
+ inputs: The input tensors, at most one of which is available.
+ name: A name for this operation (optional).
+
+ Returns:
+ A tuple containing the chosen input tensor and its index in `inputs`.
+
+ Raises:
+ ValueError: If inputs are IndexedSlices and some but not all have a
+ dense_shape property.
+ """
+ with ops.op_scope(inputs, name, "Merge") as name:
+ inputs = [ops.convert_to_tensor_or_indexed_slices(inp) for inp in inputs]
+ if all([isinstance(inp, ops.Tensor) for inp in inputs]):
+ return gen_control_flow_ops._merge(inputs, name=name)
+ else:
+ inputs = math_ops._as_indexed_slices_list(inputs)
+ values, _ = gen_control_flow_ops._merge([inp.values for inp in inputs],
+ name=name)
+ indices, chosen_index = gen_control_flow_ops._merge(
+ [inp.indices for inp in inputs], name="indices")
+ if any(inp.dense_shape for inp in inputs):
+ if not all(inp.dense_shape for inp in inputs):
+ raise ValueError("Either all merged IndexedSlices must have a "
+ "dense_shape, or none must have a dense_shape.")
+ dense_shape, _ = gen_control_flow_ops._merge(
+ [inp.dense_shape for inp in inputs], name="dense_shape")
+ else:
+ dense_shape = None
+ return ops.IndexedSlices(values, indices, dense_shape), chosen_index
+
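+# A rough sketch (illustrative only; `x` and `pred` are assumed to be tensors
+# defined elsewhere) of the cond-like pattern these two primitives enable:
+# only the taken branch receives a live tensor, and `merge` picks whichever
+# branch produced a value.
+#
+#   x_false, x_true = switch(x, pred)
+#   taken, index = merge([x_false, math_ops.square(x_true)])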
+
+def _SwitchRefOrTensor(data, pred, name="Switch"):
+ """Forwards `data` to an output determined by `pred`.
+
+ If `pred` is false, the `data` input is forwarded to the first output.
+ Otherwise, the data goes to the second output.
+
+ This op handles `Tensor`s and `IndexedSlices`.
+
+ Args:
+ data: The tensor to be forwarded to the appropriate output.
+ pred: A scalar that specifies which output port will receive data.
+ name: A name for this operation (optional).
+
+ Returns:
+ `(output_false, output_true)`: If `pred` is true, data will be forwarded to
+ `output_true`, otherwise it goes to `output_false`.
+
+ Raises:
+ TypeError: if data is not a Tensor or IndexedSlices
+ """
+ data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
+ if isinstance(data, ops.Tensor):
+ if not data.dtype.is_ref_dtype:
+ return switch(data, pred, name=name)
+ else:
+ return ref_switch(data, pred, name=name)
+ else:
+ return switch(data, pred, name=name)
+
+
+class ControlFlowOpInputs(object):
+ """An indirection to capture the input tensors needed in backprop."""
+
+ def __init__(self, op):
+ self._op = op
+ self._inputs = None
+
+ def __len__(self):
+ return len(self._op._inputs)
+
+ def __getitem__(self, index):
+ if self._inputs is None:
+ self._inputs = [None for _ in self._op.inputs]
+ if isinstance(index, int):
+ val = self._inputs[index]
+ if val is None:
+ f_val = self._op.inputs[index]
+ val = _GetRealValue(f_val)
+ self._inputs[index] = val
+ return val
+ elif isinstance(index, slice):
+ start, stop, step = index.indices(len(self))
+ vals = [self[i] for i in xrange(start, stop, step)]
+ return vals
+ else:
+ raise TypeError("index must be an integer or slice")
+
+
+class ControlFlowOpOutputs(object):
+ """An indirection to capture the output tensors needed in backprop."""
+
+ def __init__(self, op):
+ self._op = op
+ self._outputs = None
+
+ def __len__(self):
+ return len(self._op._outputs)
+
+ def __getitem__(self, index):
+ if self._outputs is None:
+ self._outputs = [None for _ in self._op.outputs]
+ if isinstance(index, int):
+ val = self._outputs[index]
+ if val is None:
+ f_val = self._op.outputs[index]
+ val = _GetRealValue(f_val)
+ self._outputs[index] = val
+ return val
+ elif isinstance(index, slice):
+ start, stop, step = index.indices(len(self))
+ vals = [self[i] for i in xrange(start, stop, step)]
+ return vals
+ else:
+ raise TypeError("index must be an integer or slice")
+
+
+class ControlFlowOpWrapper(object):
+ """A wrapper class for Operation."""
+
+ def __init__(self, op):
+ self._op = op
+ self._inputs = None
+ self._outputs = None
+
+ @property
+ def inputs(self):
+ if self._inputs is None:
+ self._inputs = ControlFlowOpInputs(self._op)
+ return self._inputs
+
+ @property
+ def outputs(self):
+ if self._outputs is None:
+ self._outputs = ControlFlowOpOutputs(self._op)
+ return self._outputs
+
+ @property
+ def op(self):
+ return self._op
+
+ @property
+ def name(self):
+ """Returns the name of this instance of op."""
+ return self._op.name
+
+ @property
+ def _id(self):
+ """Returns the unique id of this operation."""
+ return self._op._id
+
+ @property
+ def device(self):
+ """Returns the device of this operation.
+
+ Returns:
+ a string or None if the device was not set.
+ """
+ return self._op.device
+
+ @property
+ def output_types(self):
+ return self._op.output_types
+
+ @property
+ def input_types(self):
+ return self._op._input_types
+
+ @property
+ def type(self):
+ """Returns the type of the op."""
+ return self._op.type
+
+ @property
+ def graph(self):
+ """Returns the parent graph."""
+ return self._op.graph
+
+ def GetAttr(self, attr_name):
+ """Returns the value of attribute 'attr_name' of NodeDef."""
+ return self._op.get_attr(attr_name)
+
+ def _get_control_flow_context(self):
+ return self._op._get_control_flow_context()
+
+
+def GetRealOp(op):
+ while isinstance(op, ControlFlowOpWrapper):
+ op = op.op
+ return op
+
+
+def MakeWrapper(op):
+ """Make a wrapper for op if it is in a WhileContext."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ return ControlFlowOpWrapper(op)
+ return op
+
+
+def EnterGradWhileContext(op):
+ """Enter the WhileContext for gradient computation."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ grad_ctxt = forward_ctxt.CreateGradWhileContext()
+ grad_ctxt.Enter()
+
+
+def ExitGradWhileContext(op):
+ """Exit the WhileContext for gradient computation."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ assert forward_ctxt.grad_context
+ forward_ctxt.grad_context.Exit()
+
+
+def _GetRealValue(value):
+ """Get the real value.
+
+ If backprop "uses" a value produced by forward inference, an
+ accumulator is added in the forward loop to accumulate its values,
+ so we use the accumulated value, indexed by the backprop counter.
+
+ Args:
+ value: A tensor to be captured.
+
+ Returns:
+ The same tensor value from the saved history.
+ """
+ real_value = value
+ forward_ctxt = value.op._get_control_flow_context()
+ real_value = forward_ctxt.history_map.get(value.name)
+ assert value.op.type != "Variable"
+ if real_value is None:
+ if value.op.type == "Enter" and value.op.get_attr("is_constant"):
+ # Use the input of this Enter node
+ real_value = GetRealOp(value.op).inputs[0]
+ else:
+ # Accumulate the history of this value.
+ # NOTE(yuanbyu): Don't accumulate for constants. One approach is
+ # to deepcopy the constants for the grad while context.
+ history_value = forward_ctxt.AddForwardAccumulateLoop(value)
+
+ # The shapes of the whole history and a single event element.
+ forward_ctxt.grad_context.Exit()
+ elem_rank = array_ops.rank(history_value) - 1
+ elem_rank_vec = array_ops.expand_dims(elem_rank, 0)
+ elem_shape = array_ops.slice(array_ops.shape(history_value), [1],
+ elem_rank_vec)
+ slice_shape = array_ops.concat(0, [[1], elem_shape])
+ forward_ctxt.grad_context.Enter()
+
+ # The begin position of the slice at slice_index.
+ slice_index = forward_ctxt.grad_context.index
+ b1 = array_ops.zeros(elem_rank_vec, dtype=types.int32)
+ b = array_ops.concat(0, [array_ops.expand_dims(slice_index, 0), b1])
+
+ # The slice at slice_index.
+ # TODO(irving): Replace with gather once that's GPU accelerated
+ real_value = array_ops.squeeze(
+ array_ops.slice(history_value,
+ b,
+ slice_shape,
+ name="real"),
+ squeeze_dims=[0])
+ forward_ctxt.history_map[value.name] = real_value
+ return real_value
+
+
+def IsLoopSwitch(op):
+ """Returns true if `op` is the Switch for a While loop."""
+ if op.type == "Switch":
+ ctxt = op._get_control_flow_context()
+ return ctxt and isinstance(ctxt, WhileContext)
+ return False
+
+
+class ControlFlowContext(object):
+ """The base class for control flow context.
+
+ The usage pattern is a sequence of (Enter, Exit) followed by a final
+ ExitResult.
+ """
+
+ def AddName(self, name):
+ self._values.add(name)
+
+ # pylint: disable=protected-access
+ def Enter(self):
+ """Enter the current context."""
+ self._outer_context = ops.get_default_graph()._get_control_flow_context()
+ ops.get_default_graph()._set_control_flow_context(self)
+
+ def Exit(self):
+ """Exit the current context."""
+ ops.get_default_graph()._set_control_flow_context(self._outer_context)
+ # pylint: enable=protected-access
+
+ def ExitResult(self, result):
+ """Make a list of tensors available in the outer context."""
+ if self._outer_context is not None:
+ for x in result:
+ self._outer_context.AddName(x.name)
+
+ def GetWhileContext(self):
+ """Get the current while context."""
+ if self._outer_context is not None:
+ return self._outer_context.GetWhileContext()
+ return None
+
+ def AddToWhileContext(self, op):
+ """Add a control dependency to the containing WhileContext.
+
+ The added control dependency ensures that the outputs of this op
+ belong to the WhileContext.
+
+ Args:
+ op: An operation.
+ """
+ while_ctxt = self.GetWhileContext()
+ if while_ctxt is not None:
+ # pylint: disable=protected-access
+ op._add_control_input(while_ctxt.GetControlPivot().op)
+ # pylint: enable=protected-access
+
+
+class CondContext(ControlFlowContext):
+ """The context for the conditional construct."""
+
+ def __init__(self, pred, pivot, branch):
+ self._pred = pred
+ self._outer_context = None
+ self._pivot = pivot
+ self._branch = branch
+ self._values = set()
+ self._values.add(pred.name)
+ self._values.add(pivot.name)
+ self._external_values = {}
+
+ @property
+ def pred(self):
+ return self._pred
+
+ @property
+ def pivot(self):
+ return self._pivot
+
+ @property
+ def branch(self):
+ return self._branch
+
+ def AddValue(self, val):
+ """Add 'val' to the current context and its outer context recursively."""
+ result = val
+ if val.name not in self._values:
+ self._values.add(val.name)
+ if self._outer_context is not None:
+ result = self._outer_context.AddValue(val)
+ result = with_dependencies([self._pivot], result)
+ self._external_values[val.name] = result
+ return result
+
+ def AddOp(self, op):
+ """Add 'op' to the current context."""
+ if not op.inputs:
+ # Add this op to the enclosing while context
+ self.AddToWhileContext(op)
+ # pylint: disable=protected-access
+ op._add_control_input(self._pivot.op)
+ # pylint: enable=protected-access
+ for x in op.outputs:
+ self._values.add(x.name)
+ else:
+ for index in range(len(op.inputs)):
+ x = op.inputs[index]
+ if x.name not in self._values:
+ self._values.add(x.name)
+ # Add this value to the parent contexts up to the context that
+ # creates this value.
+ real_x = x
+ if self._outer_context is not None:
+ real_x = self._outer_context.AddValue(x)
+ real_x = _SwitchRefOrTensor(real_x, self._pred)[self._branch]
+ self._external_values[x.name] = real_x
+ x = self._external_values.get(x.name)
+ if x is not None:
+ op._update_input(index, x)
+ for x in op.outputs:
+ self._values.add(x.name)
+
+ def BuildCondBranch(self, fn):
+ """Add the subgraph defined by fn() to the graph."""
+ r = fn()
+ result = []
+ if r is not None:
+ if not isinstance(r, list) and not isinstance(r, _basetuple):
+ r = [r]
+ for v in r:
+ if isinstance(v, ops.Operation):
+ v = with_dependencies([v], self._pivot)
+ elif v.name not in self._values:
+ self._values.add(v.name)
+ if self._outer_context is not None:
+ v = self._outer_context.AddValue(v)
+ v = _SwitchRefOrTensor(v, self._pred)[self._branch]
+ else:
+ external_v = self._external_values.get(v.name)
+ if external_v is not None:
+ v = external_v
+ result.append(v)
+ return result
+
+
+def cond(pred, fn1, fn2, name=None):
+ """Return either 'fn1()' or 'fn2()' based on the boolean predicate 'pred'.
+
+ `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
+ the same number and type of outputs.
+
+ Args:
+ pred: A scalar determining whether to return the result of `fn1` or `fn2`.
+ fn1: The function to be performed if pred is true.
+    fn2: The function to be performed if pred is false.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ Tensors returned by the call to either `fn1` or `fn2`. If the functions
+ return a singleton list, the element is extracted from the list.
+
+ Raises:
+ TypeError: if `fn1` or `fn2` is not callable.
+ ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
+ return tensors of different types.
+
+ Example:
+ ```python
+ x = constant(2)
+ y = constant(5)
+ def f1(): return constant(17)
+ def f2(): return constant(23)
+ r = cond(math_ops.less(x, y), f1, f2)
+ # r is set to f1()
+ ```
+ """
+ with ops.op_scope([pred], name, "Cond") as name:
+ if not callable(fn1):
+ raise TypeError("fn1 must be callable.")
+ if not callable(fn2):
+ raise TypeError("fn2 must be callable.")
+
+ # Add the Switch to the graph.
+ p_2, p_1 = switch(pred, pred)
+ pivot_1 = array_ops.identity(p_1, name="switch_t")
+ pivot_2 = array_ops.identity(p_2, name="switch_f")
+ pred = array_ops.identity(pred, name="pred_id")
+
+ # Build the graph for the true branch in a new context.
+ context_t = CondContext(pred, pivot_1, 1)
+ context_t.Enter()
+ res_t = context_t.BuildCondBranch(fn1)
+ context_t.ExitResult(res_t)
+ context_t.Exit()
+
+ # Build the graph for the false branch in a new context.
+ context_f = CondContext(pred, pivot_2, 0)
+ context_f.Enter()
+ res_f = context_f.BuildCondBranch(fn2)
+    context_f.ExitResult(res_f)
+ context_f.Exit()
+
+ # Add the final merge to the graph.
+ if len(res_t) != len(res_f):
+ raise ValueError("fn1 and fn2 must return the same number of tensors.")
+ for x, y in zip(res_f, res_t):
+ assert ((isinstance(x, ops.IndexedSlices) and
+ isinstance(y, ops.IndexedSlices)) or
+ (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)))
+ val_x = x if isinstance(x, ops.Tensor) else x.values
+ val_y = y if isinstance(y, ops.Tensor) else y.values
+ if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
+ raise ValueError("Outputs of fn1 and fn2 must have the same type: "
+ "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
+ merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)]
+ return merges[0] if len(merges) == 1 else merges
+
+
+# TODO(yuanbyu): We should probably separate the notion of context so it
+# could be used not only for conditionals and loops but also subgraphs.
+class WhileContext(ControlFlowContext):
+ """The context for the loop construct."""
+
+ def __init__(self, parallel_iterations, back_prop, name):
+ self._name = ops.get_default_graph().unique_name(name)
+ self._parallel_iterations = parallel_iterations
+ self._back_prop = back_prop
+ self._outer_context = None
+ # We use this node to control constants created by the pred lambda.
+ self._pivot_for_pred = None
+ # We use this node to control constants created by the body lambda.
+ self._pivot_for_body = None
+ # The boolean tensor for loop termination condition. Used in code
+ # generation for gradient computation
+ self._pivot = None
+
+ # The tensors for the counters added by AddForwardCounterLoop or
+ # AddBackPropCounterLoop
+ self._index = None
+
+ # Information needed by backprop
+ self._grad_context = None
+ self._total_iterations = None
+ self._history_map = {}
+ self._switch_map = {}
+
+ # values considered to have been already seen in this context
+ self._values = set()
+
+ # values referenced by but external to this context
+ self._external_values = {}
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def parallel_iterations(self):
+ """The number of iterations allowed to run in parallel."""
+ return self._parallel_iterations
+
+ @property
+ def back_prop(self):
+ """True iff backprop is enabled for this While loop."""
+ return self._back_prop
+
+ @property
+ def pivot(self):
+ """The boolean tensor representing the loop termination condition."""
+ return self._pivot
+
+ @property
+ def index(self):
+ """The loop index representing the current iteration."""
+ return self._index
+
+ @property
+ def grad_context(self):
+ """The corresponding WhileContext for gradient."""
+ return self._grad_context
+
+ @property
+ def history_map(self):
+ """The map that records all the tensors needed for backprop."""
+ return self._history_map
+
+ @property
+ def switch_map(self):
+ """The map that records all the Switch ops in the While loop."""
+ return self._switch_map
+
+ @property
+ def total_iterations(self):
+ """The total number of iterations of the while loop."""
+ return self._total_iterations
+
+ def GetWhileContext(self):
+ return self
+
+ def GetControlPivot(self):
+ if self._pivot_for_body:
+ return self._pivot_for_body
+ return self._pivot_for_pred
+
+ def AddValue(self, val):
+ """Add 'val' to the current context and its outer context recursively."""
+ result = val
+ if val.name not in self._values:
+ self._values.add(val.name)
+ if self._outer_context is not None:
+ result = self._outer_context.AddValue(val)
+ # Create an Enter that makes 'result' known to this context.
+ enter = _Enter(result, self._name, is_constant=True,
+ parallel_iterations=self._parallel_iterations)
+ self._values.add(enter.name)
+ self._external_values[val.name] = enter
+ result = enter
+ else:
+ actual_val = self._external_values.get(val.name)
+ if actual_val is not None:
+ result = actual_val
+ return result
+
+ def AddOp(self, op):
+ """Adds 'op' to the current context."""
+ if not op.inputs:
+ if not op.control_inputs:
+ # Add a control edge from the control pivot to this op.
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot().op)
+ # pylint: enable=protected-access
+ else:
+ # Control edges must be in the same context.
+ for x in op.control_inputs:
+ assert x._get_control_flow_context() == self, (
+ "Control inputs must come from Operations in the same while "
+ "loop context (not an outer context).")
+ for x in op.outputs:
+ self._values.add(x.name)
+ else:
+ for index in range(len(op.inputs)):
+ x = op.inputs[index]
+ self.AddValue(x)
+ real_x = self._external_values.get(x.name)
+ if real_x is not None:
+ op._update_input(index, real_x)
+ # Add a control dependency to prevent loop invariants from
+ # enabling ops that should not be executed.
+ if real_x.op.type == "RefEnter" and real_x.op.get_attr("is_constant"):
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot().op)
+ # pylint: enable=protected-access
+ for x in op.outputs:
+ self._values.add(x.name)
+
+ def CreateGradWhileContext(self):
+ """Creates the WhileContext for backprop gradient computation."""
+ if self._grad_context is None:
+ cnt = self.AddForwardCounterLoop()
+ self._grad_context = WhileContext(self._parallel_iterations,
+ self._back_prop, self._name)
+ self._grad_context.AddBackPropCounterLoop(cnt)
+ return self._grad_context
+
+ def AddForwardCounterLoop(self):
+ """Adds a loop that counts the number of iterations.
+
+ This is added to the forward loop at the time when we start to
+ create the loop for backprop gradient computation.
+
+ The pseudocode is:
+ `n = 0; while (_pivot) { n++; }`
+
+ Returns:
+ The number of iterations taken by the forward loop.
+ """
+ n = constant_op.constant(0, name="f_count")
+ self.Enter()
+ self.AddName(n.name)
+ enter_n = _Enter(n, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_count")
+ merge_n = merge([enter_n, enter_n])[0]
+ switch_n = switch(merge_n, self._pivot)
+ self._index = switch_n[1]
+
+ add_n = math_ops.add(self._index, 1)
+ next_n = next_iteration(add_n)
+ merge_n.op._update_input(1, next_n)
+
+ self._total_iterations = exit(switch_n[0], name="f_count")
+ self.Exit()
+ return self._total_iterations
+
+ def AddForwardAccumulateLoop(self, value):
+ """Add an accumulation loop for each value needed in backprop.
+
+    This is added to the forward loop the first time a value produced by
+    the forward loop is used by the backprop gradient computation loop.
+
+ The pseudocode is:
+ ```
+ acc;
+ while (_pivot) {
+ if (index == 0) [value] else Concat(acc, [value]);
+ }
+ ```
+
+ Args:
+ value: The tensor that is accumulated.
+
+ Returns:
+ The accumulated history of value.
+
+ Raises:
+ ValueError: If the shape of "value" is not known statically.
+ """
+ if not value.get_shape().is_fully_defined():
+ raise ValueError("Must have known shape: %s" % value)
+ self._grad_context.Exit()
+ # TODO(irving): Now that acc starts out empty, most of the
+ # conditional logic can go away.
+ acc = constant_op.constant([],
+ value.dtype,
+ shape=[0] + value.get_shape().as_list(),
+ name="f_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_acc")
+ merge_acc = merge([enter_acc, enter_acc])[0]
+ switch_acc = switch(merge_acc, self._pivot)
+
+ # If index = 0 then [value] else Concat(acc, [value]).
+ cond = math_ops.greater(self._index, 0)
+ switch_add_acc = switch(switch_acc[1], cond)
+ expand_value = array_ops.expand_dims(value, 0)
+ true_branch = array_ops.concat(0, [switch_add_acc[1], expand_value])
+ false_branch = array_ops.identity(switch_add_acc[0])
+ false_branch = with_dependencies([false_branch], expand_value)
+ add_acc = merge([false_branch, true_branch])[0]
+
+ next_acc = next_iteration(add_acc)
+ merge_acc.op._update_input(1, next_acc)
+
+ exit_acc = exit(switch_acc[0], name="f_acc")
+ self.Exit()
+ self._grad_context.Enter()
+ return exit_acc
+
+ def AddForwardAccumulateCondLoop(self, value):
+ """Add an accumulation loop for each conditional switch.
+
+    This is added to the forward loop the first time a conditional switch
+    in the forward loop is used by the backprop gradient computation loop.
+
+ The pseudocode is:
+ ```
+ acc;
+ while (_pivot) {
+ Concat(acc, value);
+ }
+ ```
+
+ Args:
+ value: The boolean tensor that is accumulated.
+
+ Returns:
+ The accumulated history of value.
+ """
+ self._grad_context.Exit()
+ acc = constant_op.constant(False, name="f_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_acc")
+ merge_acc = merge([enter_acc, enter_acc])[0]
+ switch_acc = switch(merge_acc, self._pivot)
+    acc = array_ops.concat(0, [switch_acc[1], value])
+ next_acc = next_iteration(acc)
+ merge_acc.op._update_input(1, next_acc)
+
+ exit_acc = exit(switch_acc[0], name="f_acc")
+ self.Exit()
+ self._grad_context.Enter()
+ return exit_acc
+
+ def AddBackPropCounterLoop(self, count):
+ """Add the backprop loop that controls the iterations.
+
+ This is added to the backprop loop. It is used to control the loop
+ termination and the slice index.
+
+ The pseudocode is:
+ `n = count; while (n >= 1) { n--; }`
+
+ Args:
+ count: The number of iterations for backprop.
+
+ Returns:
+ always 0.
+ """
+ one = constant_op.constant(1, name="b_count")
+ self.Enter()
+ self.AddName(count.name)
+ enter_count = _Enter(count, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_count")
+ merge_count = merge([enter_count, enter_count])[0]
+ self._pivot_for_pred = merge_count
+
+ cond = math_ops.greater_equal(merge_count, one)
+ self._pivot = loop_cond(cond, name="b_count")
+ switch_count = switch(merge_count, self._pivot)
+
+ # Add next_iteration right after Switch to match the gradient function.
+ next_count = next_iteration(switch_count[1])
+ self._pivot_for_body = next_count
+ self._index = math_ops.sub(next_count, one)
+ merge_count.op._update_input(1, self._index)
+
+ exit_count = exit(switch_count[0], name="b_count")
+ self.Exit()
+ return exit_count
+
+ def AddBackPropAccumulateLoop(self, value):
+ """Add an accumulation loop for every loop invariant.
+
+ This is added to the backprop loop. It is used to accumulate partial
+ gradients for each loop iteration. Called when in the while context
+ for gradient.
+
+ The pseudocode is:
+ ```
+ acc = 0;
+ while (_pivot) {
+ acc += value;
+ }
+ ```
+
+ Args:
+ value: The partial gradient of an iteration for a loop invariant.
+
+ Returns:
+ The gradient for a loop invariant.
+ """
+ self.Exit()
+ acc = constant_op.constant(0, value.dtype, name="b_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_acc")
+ merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
+ switch_acc = switch(merge_acc, self._pivot)
+
+ next_acc = next_iteration(switch_acc[1])
+ add_acc = math_ops.add(next_acc, value)
+ merge_acc.op._update_input(1, add_acc)
+
+ exit_acc = exit(switch_acc[0], name="b_acc")
+ return exit_acc
+
+ def BuildLoop(self, pred, body, loop_vars):
+ """Add the loop termination condition and body to the graph."""
+
+ loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
+    # Let the context know the loop variables so the _Enter nodes below
+    # are added into the context correctly.
+ self._values = set([x.name for x in loop_vars])
+ if self._outer_context is not None:
+ real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
+ else:
+ real_vars = loop_vars
+ enter_vars = [_Enter(x, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations)
+ for x in real_vars]
+ self._values = set([x.name for x in enter_vars])
+
+ merge_vars = [merge([x, x])[0] for x in enter_vars]
+ self._pivot_for_pred = merge_vars[0]
+
+ # Build the graph for pred.
+ c = ops.convert_to_tensor(pred(*merge_vars))
+ self._pivot = loop_cond(c, name="LoopCond")
+ switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
+
+ # Build the graph for body.
+ vars_for_body = [_Identity(x[1]) for x in switch_vars]
+ self._pivot_for_body = vars_for_body[0]
+
+ body_result = body(*vars_for_body)
+ if not isinstance(body_result, (list, _basetuple)):
+ body_result = [body_result]
+ result = ops.convert_n_to_tensor_or_indexed_slices(body_result)
+ next_vars = [next_iteration(x) for x in result]
+
+ # Add the back edges to complete the loop.
+ assert len(merge_vars) == len(next_vars)
+ for x in zip(merge_vars, next_vars):
+ x[0].op._update_input(1, x[1])
+
+ # Add the exit ops.
+ exit_vars = [exit(x[0]) for x in switch_vars]
+
+ for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars):
+ if m_var.get_shape().is_compatible_with(n_var.get_shape()):
+ e_var.set_shape(m_var.get_shape().merge_with(n_var.get_shape()))
+
+ # Exit the loop.
+ self.ExitResult(exit_vars)
+ self.Exit()
+ return exit_vars[0] if len(exit_vars) == 1 else exit_vars
+
+
+def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
+ name=None):
+ """Repeat `body` while the condition `cond` is true.
+
+ `cond` is a function taking a list of tensors and returning a boolean scalar
+ tensor. `body` is a function taking a list of tensors and returning a list of
+ tensors of the same length and with the same types as the input. `loop_vars`
+ is a list of tensors that is passed to both `cond` and `body`.
+
+ While `cond` evaluates to true, `body` is executed.
+
+ Args:
+ cond: The termination condition of the loop.
+ body: A function that represents the loop body.
+ loop_vars: The list of variable input tensors.
+ parallel_iterations: The number of iterations allowed to run in parallel.
+ back_prop: Whether backprop is enabled for this while loop.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ The output tensors for the loop variables after the loop.
+
+ Raises:
+ TypeError: if `cond` or `body` is not callable.
+ ValueError: if `loop_var` is empty.
+
+ Example:
+ ```python
+ i = Constant(0)
+ c = lambda i: math_ops.less(i, 10)
+ b = lambda i: math_ops.add(i, 1)
+ r = While(c, b, [i])
+ ```
+ """
+ with ops.op_scope(loop_vars, name, "While") as name:
+ if not loop_vars:
+ raise ValueError("No loop variables provided")
+ if not callable(cond):
+ raise TypeError("cond must be callable.")
+ if not callable(body):
+ raise TypeError("body must be callable.")
+
+ context = WhileContext(parallel_iterations, back_prop, name)
+ context.Enter()
+ return context.BuildLoop(cond, body, loop_vars)
+
+
+def _AsTensorList(x, p):
+ """Return x as a list of Tensors or IndexedSlices.
+
+ For entries of `x` that are Operations, this returns an Identity of `p`
+ with a dependency on the operation.
+
+ Args:
+ x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
+ p: A Tensor to return for entries in `x` that are Operations.
+
+ Returns:
+ A list of Tensors or IndexedSlices.
+ """
+ if not isinstance(x, list) and not isinstance(x, _basetuple):
+ x = [x]
+
+ l = []
+ for v in x:
+ if isinstance(v, ops.Operation):
+ v = with_dependencies([v], p)
+ v = ops.convert_to_tensor_or_indexed_slices(v)
+ if isinstance(v, ops.Tensor):
+ l.append(array_ops.identity(v))
+ else:
+ l.append(ops.IndexedSlices(array_ops.identity(v.values),
+ array_ops.identity(v.indices)))
+ return l
+
+
+def _CheckResults(a, b):
+ assert len(a) == len(b), (
+ "Values returned by a() and b() must have the same length.")
+ for x, y in zip(a, b):
+ assert x.dtype == y.dtype, (
+ "Values returned by a() [%s] and b() [%s] must have "
+ "the same type: %s, %s." %
+ (x.name, y.name, x.dtype.name, y.dtype.name))
+
+
+def with_dependencies(dependencies, output_tensor, name=None):
+ """Produces the content of `output_tensor` only after `dependencies`.
+
+ In some cases, a user may want the output of an operation to be
+ consumed externally only after some other dependencies have run
+  first. This function returns `output_tensor`, but only after all
+  operations in `dependencies` have run. Note that this only gates the
+  returned tensor: there is no guarantee that the operation producing
+  `output_tensor` itself will run after `dependencies` have run.
+
+ See also `tuple` and `group`.
+
+ Args:
+ dependencies: A list of operations to run before this op finishes.
+ output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
+ name: (Optional) A name for this operation.
+
+ Returns:
+ Same as `output_tensor`.
+
+ Raises:
+ TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
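+
+  Example (a usage sketch; `update_op` and `value` are illustrative names,
+  not part of this module):
+
+  ```python
+  # `out` carries the value of `value`, but it only becomes available
+  # after `update_op` has run.
+  out = with_dependencies([update_op], value)
+  ```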
+ """
+ with ops.op_scope(dependencies + [output_tensor], name,
+ "control_dependency") as name:
+ with ops.device(output_tensor.device
+ or ops.get_default_graph().get_default_device()):
+ with ops.control_dependencies(dependencies):
+ output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor)
+ if isinstance(output_tensor, ops.Tensor):
+ return _Identity(output_tensor, name=name)
+ else:
+ return ops.IndexedSlices(_Identity(output_tensor.values, name=name),
+ output_tensor.indices,
+ output_tensor.dense_shape)
+
+
+def _GroupControlDeps(dev, deps, name=None):
+ with ops.control_dependencies(deps):
+ if dev is None:
+ return no_op(name=name)
+ else:
+ with ops.device(dev):
+ return no_op(name=name)
+
+
+# TODO(mdevin): Accept "inputs" as a list.
+def group(*inputs, **kwargs):
+ """Create an op that groups multiple operations.
+
+  When this op finishes, all ops in `inputs` have finished. This op has no
+ output.
+
+ See also `tuple` and `with_dependencies`.
+
+ Args:
+ *inputs: One or more tensors to group.
+ **kwargs: Optional parameters to pass when constructing the NodeDef.
+ name: A name for this operation (optional).
+
+ Returns:
+ An Operation that executes all its inputs.
+
+ Raises:
+ ValueError: If an unknown keyword argument is provided, or if there are
+ no inputs.
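+
+  Example (a usage sketch; `train_op` and `update_op` are illustrative
+  operation names, not part of this module):
+
+  ```python
+  # Running `step` runs both grouped ops; `step` itself has no outputs.
+  step = group(train_op, update_op)
+  ```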
+ """
+ name = kwargs.pop("name", None)
+ if kwargs:
+ raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
+ if not inputs:
+ # TODO(mdevin): Would make sense to return a NoOp.
+ raise ValueError("No inputs provided")
+ with ops.op_scope(inputs, name, "group_deps") as name:
+ # Sorts *inputs according to their devices.
+ ops_on_device = {} # device -> operations specified on the device.
+ for inp in inputs:
+ dev = inp.device
+ if dev in ops_on_device:
+ ops_on_device[dev].append(inp)
+ else:
+ ops_on_device[dev] = [inp]
+ if len(ops_on_device) == 1:
+ # 1-level tree. The root node is the returned NoOp node.
+ dev, deps = ops_on_device.items()[0]
+ return _GroupControlDeps(dev, deps, name=name)
+ # 2-level tree. The root node is the returned NoOp node.
+ # deps contains 1 NoOp node for each device.
+ deps = []
+ for dev in sorted(ops_on_device.iterkeys()):
+ deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
+ return _GroupControlDeps(None, deps, name=name)
+
+def tuple(tensors, name=None, control_inputs=None):
+ """Group tensors together.
+
+ This creates a tuple of tensors with the same values as the `tensors`
+ argument, except that the value of each tensor is only returned after the
+ values of all tensors have been computed.
+
+ `control_inputs` contains additional ops that have to finish before this op
+ finishes, but whose outputs are not returned.
+
+ This can be used as a "join" mechanism for parallel computations: all the
+ argument tensors can be computed in parallel, but the values of any tensor
+ returned by `tuple` are only available after all the parallel computations
+ are done.
+
+ See also `group` and `with_dependencies`.
+
+ Args:
+    tensors: A list of `Tensor`s or `IndexedSlices`, some of which may be `None`.
+ name: (optional) A name to use as a `name_scope` for the operation.
+ control_inputs: List of additional ops to finish before returning.
+
+ Returns:
+ Same as `tensors`.
+
+ Raises:
+ ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
+
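+  Example (a usage sketch; `a` and `b` are illustrative tensors):
+
+  ```python
+  # a_out and b_out carry the same values as a and b, but neither is
+  # available until both a and b have been computed.
+  a_out, b_out = tuple([a, b])
+  ```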
+ """
+ with ops.op_scope(tensors, name, "tuple") as name:
+ gating_ops = [t.op for t in tensors if t]
+ if control_inputs:
+ gating_ops += control_inputs
+    # Note: the ops are sorted so that the ordering in the serialized
+    # GraphDef (pbtxt) is deterministic.
+ gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops.
+ if not gating_ops:
+ raise ValueError("Must have at least one Tensor: %s" % tensors)
+ gate = group(*gating_ops)
+ tpl = []
+ for t in tensors:
+ if t:
+ tpl.append(with_dependencies([gate], t))
+ else:
+ tpl.append(None)
+ return tpl
+
+
+# TODO(yuanbyu): It would be nicer if we could have the distributed list
+# support that Derek has been proposing.
+# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
+def fold(fn, elems, elem_shape, name=None):
+ """The fold operator on slices of a tensor.
+
+ This fold operator applies the function `fn` to slices of `elems` on
+ dimension 0. The shape of the slices is specified by `elem_shape`. `elems`
+ must contain at least one slice (`shape(elems)[0] / elem_shape[0] > 0`).
+
+ Args:
+ fn: The function to be performed on each slice of the tensor.
+ elems: The tensor to whose slices we want to apply `fn`.
+ elem_shape: The shape definition for the slices.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ A tensor resulting from applying `fn` consecutively on each slice of
+ `elems`.
+
+ Raises:
+ TypeError: if `fn` is not callable.
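+
+  Example (a sketch that sums `[1, 3]` slices of a hypothetical `[4, 3]`
+  tensor `elems`):
+
+  ```python
+  # Applies the accumulator function to consecutive slices of `elems`.
+  total = fold(lambda acc, x: math_ops.add(acc, x), elems, [1, 3])
+  ```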
+ """
+ with ops.op_scope([elems], name, "Fold") as name:
+ if not callable(fn):
+ raise TypeError("fn must be callable.")
+
+ s0 = array_ops.shape(elems)[0]
+ d0 = elem_shape[0]
+ n = math_ops.div(s0, d0)
+ b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0),
+ dtype=types.int32)
+ # Initialize the output with slice 0
+ b = array_ops.concat(0, [[0], b1])
+ o = array_ops.slice(elems, b, elem_shape)
+ i = ops.convert_to_tensor(d0)
+
+ def Compute(i, o):
+ b = array_ops.concat(0, [array_ops.expand_dims(i, 0), b1])
+ x = array_ops.slice(elems, b, elem_shape)
+ o = fn(o, x)
+ i = math_ops.add(i, d0)
+ return [i, o]
+ r = While(lambda i, o: math_ops.less(i, n), Compute, [i, o])
+ return r[1]
+
+
+def case(pred_fn_pairs, default, exclusive=False, name="Case"):
+ """Create a Case operation.
+
+ The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
+ Each pair contains a boolean scalar tensor and a python callable that
+ creates the tensors to be returned if the boolean evaluates to True. `default`
+ is a callable generating a list of tensors. All the callables in
+ `pred_fn_pairs` as well as `default` should return the same number and types
+ of tensors.
+
+ If `exclusive==True`, all predicates are evaluated, and a logging operation
+ with an error is returned if more than one of the predicates evaluates to
+  True. If `exclusive==False`, execution stops at the first predicate which
+ evaluates to True, and the tensors generated by the corresponding function
+ are returned immediately. If none of the predicates evaluate to True, this
+ operation returns the tensors generated by `default`.
+
+ Example 1:
+ Pseudocode:
+ ```
+ if (x < y) return 17;
+ else return 23;
+ ```
+
+ Expressions:
+ ```
+ f1 = lambda: Constant(17)
+ f2 = lambda: Constant(23)
+ r = Case([(math_ops.less(x, y), f1)], default=f2)
+ ```
+
+ Example 2:
+ Pseudocode:
+ ```
+ if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
+ if (x < y) return 17;
+ else if (x > z) return 23;
+ else return -1;
+ ```
+
+ Expressions:
+ ```
+ def f1(): return Constant(17)
+ def f2(): return Constant(23)
+ def f3(): return Constant(-1)
+ r = Case({math_ops.less(x, y): f1, math_ops.greater(x, z): f2},
+ default=f3, exclusive=True)
+ ```
+
+ Args:
+ pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
+ callable which returns a list of tensors.
+ default: A callable that returns a list of tensors.
+ exclusive: True iff more than one predicate is allowed to evaluate to True.
+ name: A name for this operation (optional).
+
+ Returns:
+ The tensors returned by the first pair whose predicate evaluated to True, or
+ those returned by `default` if none does.
+
+ Raises:
+ TypeError: If `pred_fn_pairs` is not a list/dictionary.
+ TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
+ TypeError: If `fns[i]` is not callable for any i, or `default` is not
+ callable.
+ """
+ pfp = pred_fn_pairs # For readability
+ if not (isinstance(pfp, list) or isinstance(pfp, _basetuple)
+ or isinstance(pfp, dict)):
+ raise TypeError("fns must be a list, tuple, or dict")
+ if isinstance(pfp, dict):
+ pfp = pfp.items()
+ if not exclusive:
+ logging.warn("%s: Provided dictionary of predicate/fn pairs, but "
+ "exclusive=False. Order of conditional tests is "
+ "not guaranteed." % name)
+ for tup in pfp:
+ if not isinstance(tup, _basetuple) or len(tup) != 2:
+ raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
+ pred, fn = tup
+ if pred.dtype != types.bool:
+      raise TypeError("pred must be of type bool: %s" % pred.name)
+ if not callable(fn):
+ raise TypeError("fn for pred %s must be callable." % pred.name)
+ if not callable(default):
+ raise TypeError("default must be callable.")
+
+ preds, fns = map(list, zip(*pfp))
+ with ops.op_scope([[f() for f in fns] + preds + [default()]], name, "Case"):
+ if not preds:
+ return default()
+ not_preds = []
+ for i, p in enumerate(preds):
+ with ops.name_scope("not_%d" % i):
+ not_preds.append(math_ops.logical_not(p))
+ and_not_preds = [constant_op.constant(True, name="and_not_true")]
+ for i, notp in enumerate(not_preds[:-1]):
+ with ops.name_scope("and_not_%d" % i):
+ and_not_preds.append(math_ops.logical_and(and_not_preds[-1], notp))
+
+ # preds = [p1, p2, p3]
+ # fns = [f1, f2, f3]
+ # not_preds = [~p1, ~p2, ~p3]
+ # case_preds = [p1 & True,
+ # p2 & ~p1,
+ # p3 & ~p1 & ~ p2]
+ case_preds = []
+ for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds)):
+ with ops.name_scope("case_%d" % i):
+ case_preds.append(math_ops.logical_and(p, and_not_p_prev))
+
+ # case_sequence = [Cond(p3 & ..., f3, default),
+ # Cond(p2 & ..., f2, lambda: case_sequence[0]),
+ # ...
+ # Cond(p1 & True, f1, lambda: case_sequence[i-1])]
+ # and prev_case_seq will loop from case_sequence[0] to case_sequence[-1]
+ if exclusive:
+ # TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds))
+ preds_c = array_ops.concat(0, preds, name="preds_c")
+ num_true_conditions = math_ops.reduce_sum(
+ math_ops.cast(preds_c, types.int32), name="num_true_conds")
+ at_most_one_true_condition = math_ops.less(
+ num_true_conditions, constant_op.constant(2, name="two_true_conds"))
+
+ error_msg = [
+ ("More than one condition evaluated as True but "
+ "exclusive=True. Conditions: (%s), Values:"
+ % ", ".join([p.name for p in preds])),
+ preds_c]
+ with ops.control_dependencies([
+ logging_ops.Assert(condition=at_most_one_true_condition,
+ data=error_msg, summarize=len(preds))]):
+ prev_case_seq = default()
+ for i, (cp, fn) in enumerate(zip(case_preds, fns)[::-1]):
+ prev_case_seq = cond(cp, fn, lambda: prev_case_seq, name="If_%d" % i)
+ else:
+ prev_case_seq = default()
+ for i, (cp, fn) in enumerate(zip(case_preds, fns)[::-1]):
+ prev_case_seq = cond(cp, fn, lambda: prev_case_seq, name="If_%d" % i)
+
+ return prev_case_seq
+
+
+ops.RegisterShape("Enter")(common_shapes.unchanged_shape)
+ops.RegisterShape("Exit")(common_shapes.unknown_shape)
+ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape)
+ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape)
+ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs)
+ops.RegisterShape("NoOp")(common_shapes.no_outputs)
+
+
+@ops.RegisterShape("LoopCond")
+def _LoopCondShape(op):
+ """Shape function for the LoopCond op."""
+ return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
+
+
+@ops.RegisterShape("Merge")
+def _MergeShape(op):
+ """Shape function for the Merge op.
+
+ The Merge op takes many inputs of arbitrary shapes, and produces a
+ first output that is one of those inputs, and a second scalar
+ output.
+
+ This function conservatively assumes that if any of its inputs is
+ not fully defined, the output shape is unknown. If all of the inputs
+ have the exact same known shape, the output must have that shape.
+
+ Args:
+ op: A Merge Operation.
+
+ Returns:
+ A single-element list containing the Shape of the Merge op.
+
+ """
+ first_input_shape = op.inputs[0].get_shape()
+ if first_input_shape.is_fully_defined():
+ for input_ in op.inputs[1:]:
+ input_shape = input_.get_shape()
+ if (not input_shape.is_fully_defined()
+ or not input_shape.is_compatible_with(first_input_shape)):
+ return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
+ return [first_input_shape, tensor_shape.scalar()]
+ else:
+ return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
+
+
+@ops.RegisterShape("RefSelect")
+def _RefSelectShape(op):
+ """Shape function for the RefSelect op.
+
+ The RefSelect takes one scalar input and N inputs of arbitrary
+ shapes, and produces one output, which is one of those N inputs.
+
+ This function conservatively assumes that if any of the N inputs is
+ not fully defined, the output shape is unknown. If all of the N
+ inputs have the exact same known shape, the output must have that
+ shape.
+
+ Args:
+ op: A RefSelect Operation.
+
+ Returns:
+ A single-element list containing the Shape of the RefSelect op.
+ """
+ unused_shape = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ first_input_shape = op.inputs[1].get_shape()
+ if first_input_shape.is_fully_defined():
+ for input_ in op.inputs[2:]:
+ input_shape = input_.get_shape()
+ if (not input_shape.is_fully_defined()
+ or not input_shape.is_compatible_with(first_input_shape)):
+ return [tensor_shape.unknown_shape()]
+ return [first_input_shape]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("RefSwitch")
+@ops.RegisterShape("Switch")
+def _SwitchShape(op):
+ input_shape = op.inputs[0].get_shape()
+ unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
+ return [input_shape] * 2
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
new file mode 100644
index 0000000000..34b1ab0a25
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -0,0 +1,88 @@
+"""Tests for control_flow_ops.py."""
+import tensorflow.python.platform
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops as tf
+from tensorflow.python.platform import googletest
+
+
+class GroupTestCase(TensorFlowTestCase):
+
+ def _StripNode(self, nd):
+ snode = graph_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
+ if nd.device:
+ snode.device = nd.device
+ return snode
+
+ def _StripGraph(self, gd):
+ """Copy gd keeping only, node.name, node.op, node.input, and node.device."""
+ return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])
+
+ def testGroup_NoDevices(self):
+ with ops.Graph().as_default() as g:
+ a = tf.constant(0, name="a")
+ b = tf.constant(0, name="b")
+ c = tf.constant(0, name="c")
+ tf.group(a.op, b.op, c.op, name="root")
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "a" op: "Const"}
+ node { name: "b" op: "Const"}
+ node { name: "c" op: "Const"}
+ node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
+ """, self._StripGraph(gd))
+
+ def testGroup_OneDevice(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/task:0"):
+ a = tf.constant(0, name="a")
+ b = tf.constant(0, name="b")
+ tf.group(a.op, b.op, name="root")
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "a" op: "Const" device: "/task:0" }
+ node { name: "b" op: "Const" device: "/task:0" }
+ node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
+ """, self._StripGraph(gd))
+
+ def testGroup_MultiDevice(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/task:0"):
+ a = tf.constant(0, name="a")
+ b = tf.constant(0, name="b")
+ with g.device("/task:1"):
+ c = tf.constant(0, name="c")
+ d = tf.constant(0, name="d")
+ with g.device("/task:2"):
+ tf.group(a.op, b.op, c.op, d.op, name="root")
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "a" op: "Const" device: "/task:0"}
+ node { name: "b" op: "Const" device: "/task:0"}
+ node { name: "c" op: "Const" device: "/task:1"}
+ node { name: "d" op: "Const" device: "/task:1"}
+ node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
+ device: "/task:0" }
+ node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
+ device: "/task:1" }
+ node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
+ device: "/task:2" }
+ """, self._StripGraph(gd))
+
+
+class ShapeTestCase(TensorFlowTestCase):
+
+ def testShape(self):
+ with ops.Graph().as_default():
+ tensor = tf.constant([1.0, 2.0])
+ self.assertEquals([2], tensor.get_shape())
+ self.assertEquals([2],
+ control_flow_ops.with_dependencies(
+ [tf.constant(1.0)], tensor).get_shape())
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
new file mode 100644
index 0000000000..d2473490ce
--- /dev/null
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -0,0 +1,37 @@
+"""Gradients for operators defined in data_flow_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+@ops.RegisterGradient("DynamicStitch")
+def _DynamicStitchGrads(op, grad):
+ """Gradients for DynamicStitch."""
+
+ num_values = len(op.inputs) / 2
+ indices_grad = [None] * num_values
+
+ def AsInt32(x):
+ return (x if op.inputs[0].dtype == types.int32 else
+ math_ops.cast(x, types.int32))
+ inputs = [AsInt32(op.inputs[i]) for i in range(num_values)]
+ if isinstance(grad, ops.IndexedSlices):
+ output_shape = array_ops.shape(op.outputs[0])
+ output_rows = output_shape[0]
+ grad = math_ops.unsorted_segment_sum(grad.values, grad.indices, output_rows)
+ values_grad = [array_ops.gather(grad, inp) for inp in inputs]
+ return indices_grad + values_grad
+
+
+ops.NoGradient("Queue")
+ops.NoGradient("QueueEnqueue")
+ops.NoGradient("QueueEnqueueMany")
+ops.NoGradient("QueueDequeue")
+ops.NoGradient("QueueDequeueMany")
+ops.NoGradient("QueueClose")
+ops.NoGradient("QueueSize")
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
new file mode 100644
index 0000000000..5c8ab66297
--- /dev/null
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -0,0 +1,680 @@
+"""Data Flow Operations."""
+# pylint: disable=g-bad-name
+import re
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_data_flow_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_data_flow_ops import *
+
+
+def _as_type_list(dtypes):
+ """Convert dtypes to a list of types."""
+ assert dtypes is not None
+ if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
+ # We have a single type.
+ return [dtypes]
+ else:
+ # We have a list or tuple of types.
+ return list(dtypes)
+
+
+def _as_shape_list(shapes, dtypes):
+ """Convert shapes to a list of tuples of int (or None)."""
+ if shapes is None: return None
+ if isinstance(shapes, tensor_shape.TensorShape):
+ shapes = [shapes]
+ if not isinstance(shapes, (tuple, list)):
+ raise TypeError(
+ "shapes must be a TensorShape or a list or tuple of TensorShapes.")
+ if all(isinstance(shape, int) for shape in shapes):
+ # We have a single shape.
+ shapes = [shapes]
+ shapes = [tensor_shape.as_shape(shape) for shape in shapes]
+ if any(not shape.is_fully_defined() for shape in shapes):
+ raise ValueError("All shapes must be fully defined.")
+ return shapes
+
+
+# pylint: disable=protected-access
+class QueueBase(object):
+ """Base class for queue implementations.
+
+ A queue is a TensorFlow data structure that stores tensors across
+ multiple steps, and exposes operations that enqueue and dequeue
+ tensors.
+
+ Each queue element is a tuple of one or more tensors, where each
+ tuple component has a static dtype, and may have a static shape. The
+ queue implementations support versions of enqueue and dequeue that
+ handle single elements, versions that support enqueuing and
+ dequeuing a batch of elements at once.
+
+ See [`tf.FIFOQueue`](#FIFOQueue) and
+ [`tf.RandomShuffleQueue`](#RandomShuffleQueue) for concrete
+ implementations of this class, and instructions on how to create
+ them.
+
+ @@enqueue
+ @@enqueue_many
+
+ @@dequeue
+ @@dequeue_many
+
+ @@size
+
+ @@close
+
+ """
+
+ def __init__(self, dtypes, shapes, queue_ref):
+ """Constructs a queue object from a queue reference.
+
+ Args:
+ dtypes: A list of types. The length of dtypes must equal the number
+ of tensors in each element.
+ shapes: Constraints on the shapes of tensors in an element:
+ A list of shape tuples or None. This list is the same length
+        as dtypes. If the shape of any tensor in an element is constrained,
+ all must be; shapes can be None if the shapes should not be constrained.
+ queue_ref: The queue reference, i.e. the output of the queue op.
+ """
+ self._dtypes = dtypes
+ if shapes is not None:
+ self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
+ else:
+ self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
+ self._queue_ref = queue_ref
+ self._name = self._queue_ref.op.name.split("/")[-1]
+
+ @staticmethod
+ def from_list(index, queues):
+ """Create a queue using the queue reference from `queues[index]`.
+
+ Args:
+ index: An integer scalar tensor that determines the input that gets
+ selected.
+ queues: A list of `QueueBase` objects.
+
+ Returns:
+ A `QueueBase` object.
+
+ Raises:
+ TypeError: when `queues` is not a list of `QueueBase` objects,
+ or when the data types of `queues` are not all the same.
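+
+    Example (a usage sketch; `q0`, `q1`, and `index` are illustrative):
+
+    ```python
+    # Dequeues from q0 when index == 0 and from q1 when index == 1.
+    q = QueueBase.from_list(index, [q0, q1])
+    value = q.dequeue()
+    ```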
+ """
+ if ((not queues) or
+ (not isinstance(queues, list)) or
+ (not all([isinstance(x, QueueBase) for x in queues]))):
+ raise TypeError("A list of queues expected")
+
+ dtypes = queues[0].dtypes
+ if not all([dtypes == q.dtypes for q in queues[1:]]):
+ raise TypeError("Queues do not have matching component dtypes.")
+
+ queue_refs = [x.queue_ref for x in queues]
+ selected_queue = control_flow_ops.ref_select(index, queue_refs)
+ # TODO(josh11b): Unify the shapes of the queues too?
+ return QueueBase(dtypes=dtypes, shapes=None, queue_ref=selected_queue)
+
+ @property
+ def queue_ref(self):
+ """The underlying queue reference."""
+ return self._queue_ref
+
+ @property
+ def name(self):
+ """The name of the underlying queue."""
+ return self._queue_ref.op.name
+
+ @property
+ def dtypes(self):
+ """The list of dtypes for each component of a queue element."""
+ return self._dtypes
+
+ def enqueue(self, vals, name=None):
+ """Enqueues one element to this queue.
+
+ If the queue is full when this operation executes, it will block
+ until the element has been enqueued.
+
+ Args:
+ vals: The tuple of `Tensor` objects to be enqueued.
+ name: A name for the operation (optional).
+
+ Returns:
+ The operation that enqueues a new tuple of tensors to the queue.
+ """
+ if name is None:
+ name = "%s_enqueue" % self._name
+ ret = gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the Queue object.
+ for val, shape in zip(ret.inputs[1:], self._shapes):
+ val.get_shape().assert_is_compatible_with(shape)
+
+ return ret
+
+ def enqueue_many(self, vals, name=None):
+ """Enqueues zero or elements to this queue.
+
+ This operation slices each component tensor along the 0th dimension to
+ make multiple queue elements. All of the tensors in `vals` must have the
+ same size in the 0th dimension.
+
+ If the queue is full when this operation executes, it will block
+ until all of the elements have been enqueued.
+
+ Args:
+ vals: The tensor or tuple of tensors from which the queue elements
+ are taken.
+ name: A name for the operation (optional).
+
+ Returns:
+ The operation that enqueues a batch of tuples of tensors to the queue.
+ """
+ if name is None:
+ name = "%s_EnqueueMany" % self._name
+
+ ret = gen_data_flow_ops._queue_enqueue_many(
+ self._queue_ref, vals, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the `QueueBase` object.
+ batch_dim = ret.inputs[1].get_shape()[0]
+ for val, shape in zip(ret.inputs[1:], self._shapes):
+ batch_dim.merge_with(val.get_shape()[0])
+ val.get_shape()[1:].assert_is_compatible_with(shape)
+
+ return ret
+
+ def dequeue(self, name=None):
+ """Dequeues one element from this queue.
+
+ If the queue is empty when this operation executes, it will block
+ until there is an element to dequeue.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ The tuple of tensors that was dequeued.
+ """
+ if name is None:
+ name = "%s_Dequeue" % self._name
+ ret = gen_data_flow_ops._queue_dequeue(
+ self._queue_ref, self._dtypes, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the `QueueBase` object.
+ op = ret[0].op
+ for output, shape in zip(op.values(), self._shapes):
+ output.set_shape(shape)
+
+ return ret if len(ret) != 1 else ret[0]
+
+ def dequeue_many(self, n, name=None):
+ """Dequeues and concatenates `n` elements from this queue.
+
+ This operation concatenates queue-element component tensors along
+ the 0th dimension to make a single component tensor. All of the
+ components in the dequeued tuple will have size `n` in the 0th dimension.
+
+ If the queue contains fewer than `n` elements when this operation
+ executes, it will block until `n` elements have been dequeued.
+
+ Args:
+ n: A scalar `Tensor` containing the number of elements to dequeue.
+ name: A name for the operation (optional).
+
+ Returns:
+ The tuple of concatenated tensors that was dequeued.
+ """
+ if name is None:
+ name = "%s_DequeueMany" % self._name
+
+ ret = gen_data_flow_ops._queue_dequeue_many(
+ self._queue_ref, n, self._dtypes, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the Queue object.
+ op = ret[0].op
+ batch_dim = tensor_shape.Dimension(tensor_util.ConstantValue(op.inputs[1]))
+ for output, shape in zip(op.values(), self._shapes):
+ output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))
+
+ return ret if len(ret) != 1 else ret[0]
+
+ def close(self, cancel_pending_enqueues=False, name=None):
+ """Closes this queue.
+
+ This operation signals that no more elements will be enqueued in
+ the given queue. Subsequent `enqueue` and `enqueue_many`
+ operations will fail. Subsequent `dequeue` and `dequeue_many`
+ operations will continue to succeed if sufficient elements remain
+ in the queue. Subsequent `dequeue` and `dequeue_many` operations
+ that would block will fail immediately.
+
+ If `cancel_pending_enqueues` is `True`, all pending requests will also
+ be cancelled.
+
+ Args:
+ cancel_pending_enqueues: (Optional.) A boolean, defaulting to
+ `False` (described above).
+ name: A name for the operation (optional).
+
+ Returns:
+ The operation that closes the queue.
+ """
+ if name is None:
+ name = "%s_Close" % self._name
+ return gen_data_flow_ops._queue_close(
+ self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
+ name=name)
+
+ def size(self, name=None):
+ """Compute the number of elements in this queue.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar tensor containing the number of elements in this queue.
+ """
+ if name is None:
+ name = "%s_Size" % self._name
+ return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
+
+
+class RandomShuffleQueue(QueueBase):
+ """A queue implementation that dequeues elements in a random order.
+
+ See [`tf.QueueBase`](#QueueBase) for a description of the methods on
+ this class.
+
+ @@__init__
+ """
+
+ def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None,
+ seed=None, shared_name=None, name="random_shuffle_queue"):
+ """Create a queue that dequeues elements in a random order.
+
+ A `RandomShuffleQueue` has bounded capacity; supports multiple
+ concurrent producers and consumers; and provides exactly-once
+ delivery.
+
+ A `RandomShuffleQueue` holds a list of up to `capacity`
+ elements. Each element is a fixed-length tuple of tensors whose
+ dtypes are described by `dtypes`, and whose shapes are optionally
+ described by the `shapes` argument.
+
+ If the `shapes` argument is specified, each component of a queue
+ element must have the respective fixed shape. If it is
+ unspecified, different queue elements may have different shapes,
+ but the use of `dequeue_many` is disallowed.
+
+ The `min_after_dequeue` argument allows the caller to specify a
+ minimum number of elements that will remain in the queue after a
+ `dequeue` or `dequeue_many` operation completes, to ensure a
+ minimum level of mixing of elements. This invariant is maintained
+ by blocking those operations until sufficient elements have been
+ enqueued. The `min_after_dequeue` argument is ignored after the
+ queue has been closed.
+
+ Args:
+ capacity: An integer. The upper bound on the number of elements
+ that may be stored in this queue.
+ min_after_dequeue: An integer (described above).
+ dtypes: A list of `DType` objects. The length of `dtypes` must equal
+ the number of tensors in each queue element.
+ shapes: (Optional.) A list of fully-defined `TensorShape` objects,
+ with the same length as `dtypes` or `None`.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ shared_name: (Optional.) If non-empty, this queue will be shared under
+ the given name across multiple sessions.
+ name: Optional name for the queue operation.
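+
+    Example (a usage sketch; the capacities are illustrative):
+
+    ```python
+    # Keeps at least 10 elements buffered so dequeued elements are well mixed.
+    q = RandomShuffleQueue(capacity=100, min_after_dequeue=10,
+                           dtypes=[types.float32])
+    ```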
+ """
+ dtypes = _as_type_list(dtypes)
+ shapes = _as_shape_list(shapes, dtypes)
+ seed1, seed2 = random_seed.get_seed(seed)
+ queue_ref = gen_data_flow_ops._random_shuffle_queue(
+ component_types=dtypes, shapes=shapes, capacity=capacity,
+ min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2,
+ shared_name=shared_name, name=name)
+
+ super(RandomShuffleQueue, self).__init__(dtypes, shapes, queue_ref)
+
+
+class FIFOQueue(QueueBase):
+ """A queue implementation that dequeues elements in first-in-first out order.
+
+ See [`tf.QueueBase`](#QueueBase) for a description of the methods on
+ this class.
+
+ @@__init__
+ """
+
+ def __init__(self, capacity, dtypes, shapes=None, shared_name=None,
+ name="fifo_queue"):
+ """Creates a queue that dequeues elements in a first-in first-out order.
+
+ A `FIFOQueue` has bounded capacity; supports multiple concurrent
+ producers and consumers; and provides exactly-once delivery.
+
+ A `FIFOQueue` holds a list of up to `capacity` elements. Each
+ element is a fixed-length tuple of tensors whose dtypes are
+ described by `dtypes`, and whose shapes are optionally described
+ by the `shapes` argument.
+
+ If the `shapes` argument is specified, each component of a queue
+ element must have the respective fixed shape. If it is
+ unspecified, different queue elements may have different shapes,
+ but the use of `dequeue_many` is disallowed.
+
+ Args:
+ capacity: An integer. The upper bound on the number of elements
+ that may be stored in this queue.
+ dtypes: A list of `DType` objects. The length of `dtypes` must equal
+ the number of tensors in each queue element.
+ shapes: (Optional.) A list of fully-defined `TensorShape` objects,
+ with the same length as `dtypes` or `None`.
+ shared_name: (Optional.) If non-empty, this queue will be shared under
+ the given name across multiple sessions.
+ name: Optional name for the queue operation.
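+
+    Example (a usage sketch; `x` is an illustrative float32 tensor produced
+    elsewhere in the graph):
+
+    ```python
+    q = FIFOQueue(capacity=10, dtypes=[types.float32])
+    enqueue_op = q.enqueue([x])
+    value = q.dequeue()
+    ```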
+ """
+ dtypes = _as_type_list(dtypes)
+ shapes = _as_shape_list(shapes, dtypes)
+ queue_ref = gen_data_flow_ops._fifo_queue(
+ component_types=dtypes, shapes=shapes, capacity=capacity,
+ shared_name=shared_name, name=name)
+
+ super(FIFOQueue, self).__init__(dtypes, shapes, queue_ref)
+
+
+# TODO(josh11b): class BatchQueue(QueueBase):
+
+
+# pylint: disable=protected-access
+class LookupTableBase(object):
+ """Represents a lookup table that persists across different steps."""
+
+ def __init__(self, key_dtype, value_dtype, default_value, table_ref):
+ """Construct a table object from a table reference.
+
+ Args:
+ key_dtype: The key data type of the table.
+      value_dtype: The value data type of the table.
+ default_value: The scalar tensor to be used when a key is not present in
+ the table.
+ table_ref: The table reference, i.e. the output of the lookup table ops.
+ """
+ self._key_dtype = types.as_dtype(key_dtype)
+ self._value_dtype = types.as_dtype(value_dtype)
+ self._shapes = [tensor_shape.TensorShape([1])]
+ self._table_ref = table_ref
+ self._name = self._table_ref.op.name.split("/")[-1]
+ self._default_value = ops.convert_to_tensor(default_value,
+ dtype=self._value_dtype)
+ self._default_value.get_shape().merge_with(tensor_shape.scalar())
+
+ @property
+ def table_ref(self):
+ """Get the underlying table reference."""
+ return self._table_ref
+
+ @property
+ def key_dtype(self):
+ """The key dtype supported by the table."""
+ return self._key_dtype
+
+ @property
+ def value_dtype(self):
+ """The value dtype supported by the table."""
+ return self._value_dtype
+
+ @property
+ def name(self):
+ """The name of the table."""
+ return self._name
+
+ @property
+ def default_value(self):
+ """The default value of the table."""
+ return self._default_value
+
+ def size(self, name=None):
+ """Compute the number of elements in this table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar tensor containing the number of elements in this table.
+ """
+ if name is None:
+ name = "%s_Size" % self._name
+ return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
+
+ def lookup(self, keys, name=None):
+ """Returns the values for the given 'keys' tensor.
+
+    If an element of the `keys` tensor is not found in the table, the default_value
+ is used.
+
+ Args:
+ keys: The tensor for the keys.
+ name: Optional name for the op.
+
+ Returns:
+      A tensor of the table's value dtype containing the values looked up
+      for `keys`; keys that are missing map to `default_value`.
+
+ Raises:
+ TypeError: when 'keys' or 'default_value' doesn't match the table data
+ types.
+ """
+ if name is None:
+ name = "%s_lookup_table_find" % self._name
+
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
+ self._key_dtype, keys.dtype))
+
+ return gen_data_flow_ops._lookup_table_find(
+ self._table_ref, keys, self._default_value, name=name)
+
+ def initialize_from(self, keys, values, name=None):
+ """Initialize the lookup table with the provided keys and values tensors.
+
+ Construct an initializer object from keys and value tensors.
+
+ Args:
+ keys: The tensor for the keys.
+ values: The tensor for the values.
+ name: Optional name for the op.
+
+ Returns:
+ The operation that initializes a lookup table.
+
+ Raises:
+ TypeError: when the 'keys' and 'values' data type do not match the table
+ key and value data types.
+ """
+ if name is None:
+ name = "%s_initialize_table" % self.name
+ with ops.op_scope([keys, values], None, name):
+ keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys")
+ values = ops.convert_to_tensor(values, dtype=self.value_dtype,
+ name="values")
+
+ init_op = gen_data_flow_ops._initialize_table(
+ self.table_ref, keys, values, name=name)
+ ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
+ return init_op
+
+ def _check_table_dtypes(self, key_dtype, value_dtype):
+ """Check that the given key_dtype and value_dtype matches the table dtypes'.
+
+ Args:
+ key_dtype: The key data type to check.
+ value_dtype: The value data type to check.
+
+ Raises:
+ TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
+ types.
+ """
+ if key_dtype != self.key_dtype:
+ raise TypeError("Invalid key dtype, expected %s but got %s." % (
+ self.key_dtype, key_dtype))
+ if value_dtype != self.value_dtype:
+ raise TypeError("Invalid value dtype, expected %s but got %s." % (
+ self.value_dtype, value_dtype))
+
+
+class HashTable(LookupTableBase):
+ """A generic hash table implementation."""
+
+ def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
+ name="hash_table"):
+ """Create a generic hash table.
+
+    A table holds key-value pairs. The key and value types are
+    described by `key_dtype` and `value_dtype` respectively.
+
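+    For example, a minimal sketch (this assumes string keys are supported by
+    the underlying hash table op, that `constant_op.constant` is available
+    for building the query tensor, and that a `Session` named `sess` is
+    running):
+
+      table = HashTable(types.string, types.int64, default_value=-1)
+      init = table.initialize_from(["emerson", "lake"], [0, 1])
+      out = table.lookup(constant_op.constant(["lake", "palmer"]))
+      sess.run(init)
+      print(sess.run(out))  # ==> [1, -1]
+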
+ Args:
+ key_dtype: The key data type of the table.
+      value_dtype: The value data type of the table.
+ default_value: The scalar tensor to be used when a key is not present in
+ the table.
+ shared_name: Optional. If non-empty, this table will be shared under
+ the given name across multiple sessions.
+ name: Optional name for the hash table op.
+
+ Returns:
+      A table object that can be used to look up data.
+ """
+ table_ref = gen_data_flow_ops._hash_table(
+ shared_name=shared_name, key_dtype=key_dtype,
+ value_dtype=value_dtype, name=name)
+
+ super(HashTable, self).__init__(key_dtype, value_dtype, default_value,
+ table_ref)
+
+
+def initialize_all_tables(name="init_all_tables"):
+ """Returns an Op that initializes all tables of the default graph.
+
+ Returns:
+    An Op that initializes all tables. Note that if there are
+    no tables, the returned Op is a NoOp.
+ """
+ initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
+ if initializers:
+ return control_flow_ops.group(*initializers, name=name)
+ return control_flow_ops.no_op(name=name)
+
+
+ops.NoGradient("LookupTableFind")
+ops.NoGradient("LookupTableSize")
+ops.NoGradient("HashTable")
+ops.NoGradient("InitializeTable")
+
+
+ops.RegisterShape("QueueSize")(common_shapes.scalar_shape)
+ops.RegisterShape("Queue")(common_shapes.scalar_shape)
+ops.RegisterShape("FIFOQueue")(common_shapes.scalar_shape)
+ops.RegisterShape("RandomShuffleQueue")(common_shapes.scalar_shape)
+
+
+# NOTE(mrry): The following ops use higher-level information in the
+# Queue class to provide shape information.
+ops.RegisterShape("QueueDequeue")(common_shapes.unknown_shape)
+ops.RegisterShape("QueueDequeueMany")(common_shapes.unknown_shape)
+ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape)
+ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape)
+
+
+@ops.RegisterShape("QueueClose")
+def _ScalarToVoidShape(op):
+ """Shape function for ops that take a scalar and produce no outputs."""
+ unused_input_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return []
+
+
+@ops.RegisterShape("DynamicPartition")
+def _DynamicPartitionShape(op):
+ """Shape function for data_flow_ops.dynamic_partition."""
+ data_shape = op.inputs[0].get_shape()
+ partitions_shape = op.inputs[1].get_shape()
+ # If we don't know the rank of partitions, we don't know anything
+ mid = partitions_shape.ndims
+ if mid is None:
+ result_shape = tensor_shape.unknown_shape()
+ else:
+ # data_shape must start with partitions_shape
+ partitions_shape.assert_is_compatible_with(data_shape[:mid])
+ # The partition shape is dynamic in the 0th dimension, and matches
+ # data_shape in the remaining dimensions.
+ result_shape = tensor_shape.TensorShape([None]).concatenate(
+ data_shape[mid:])
+ return [result_shape] * op.get_attr("num_partitions")
+
+
+@ops.RegisterShape("DynamicStitch")
+def _DynamicStitchShape(op):
+ """Shape function for data_flow_ops.dynamic_stitch."""
+ num_partitions = op.get_attr("N")
+ indices_shapes = [t.get_shape() for t in op.inputs[0:num_partitions]]
+ data_shapes = [t.get_shape() for t in op.inputs[num_partitions:]]
+ output_shape = tensor_shape.unknown_shape()
+ extra_shape = tensor_shape.TensorShape(None)
+ for indices_shape, data_shape in zip(indices_shapes, data_shapes):
+ indices_ndims = indices_shape.ndims
+ if indices_ndims is not None:
+ # Assert that data_shape starts with indices_shape
+ indices_shape.merge_with(data_shape[:indices_ndims])
+ # The rest belongs to output
+ extra_shape = extra_shape.merge_with(data_shape[indices_ndims:])
+ return [tensor_shape.TensorShape([None]).concatenate(extra_shape)]
+
+
+@ops.RegisterShape("LookupTableFind")
+def _LookupTableFindShape(op):
+ """Shape function for data_flow_ops._lookup_table_find."""
+ unused_table_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ shape_in = op.inputs[1].get_shape()
+ return [shape_in]
+
+
+@ops.RegisterShape("LookupTableSize")
+def _LookupTableSizeShape(op):
+ """Shape function for data_flow_ops._lookup_table_find."""
+ unused_table_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("HashTable")
+def _HashTableShape(unused_op):
+ """Shape function for data_flow_ops._hash_table."""
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("InitializeTable")
+def _InitializeLookupTableShape(op):
+ """Shape function for data_flow_ops._initialize_table."""
+ unused_table_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ keys_shape = op.inputs[1].get_shape().with_rank(1)
+ unused_values_shape = op.inputs[2].get_shape().merge_with(keys_shape)
+ return []
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
new file mode 100644
index 0000000000..bc64593d23
--- /dev/null
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -0,0 +1,197 @@
+"""Operations for embeddings."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def embedding_lookup(params, ids, name=None):
+ """Return a tensor of embedding values by looking up "ids" in "params".
+
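+  For example, a minimal sketch of a sharded lookup (`constant_op.constant`
+  and a running `Session` named `sess` are assumed; the two shards below
+  hold the rows of a 4 x 1 embedding matrix partitioned by id % 2):
+
+    shard0 = constant_op.constant([[1.0], [3.0]])  # Rows for ids 0 and 2.
+    shard1 = constant_op.constant([[2.0], [4.0]])  # Rows for ids 1 and 3.
+    ids = constant_op.constant([0, 3, 1])
+    emb = embedding_lookup([shard0, shard1], ids)
+    print(sess.run(emb))  # ==> [[1.0], [4.0], [2.0]]
+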
+ Args:
+ params: List of tensors of the same shape. A single tensor is
+ treated as a singleton list.
+ ids: Tensor of integers containing the ids to be looked up in
+ 'params'. Let P be len(params). If P > 1, then the ids are
+ partitioned by id % P, and we do separate lookups in params[p]
+ for 0 <= p < P, and then stitch the results back together into
+ a single result tensor.
+ name: Optional name for the op.
+
+ Returns:
+ A tensor of shape ids.shape + params[0].shape[1:] containing the
+ values params[i % P][i] for each i in ids.
+
+ Raises:
+ ValueError: if some parameters are invalid.
+ """
+ if not isinstance(params, list):
+ params = [params]
+ with ops.op_scope(params + [ids], name, "embedding_lookup") as name:
+ if not params:
+ raise ValueError("Need at least one param")
+ np = len(params) # Number of partitions
+ params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
+ if np == 1:
+ with ops.device(params[0].device):
+ return array_ops.gather(params[0], ids, name=name)
+ else:
+ ids = ops.convert_to_tensor(ids, name="ids")
+ flat_ids = array_ops.reshape(ids, [-1])
+ original_indices = math_ops.range(0, array_ops.size(flat_ids))
+ # Compute flat_ids % partitions for each id
+ ids_mod_p = flat_ids % np
+ if ids_mod_p.dtype != types.int32:
+ ids_mod_p = math_ops.cast(ids_mod_p, types.int32)
+ # Partition single list of ids based on ids % np into np separate lists
+ plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
+ # Similarly, partition the original indices.
+ pindices = data_flow_ops.dynamic_partition(original_indices, ids_mod_p,
+ np)
+ # Do np separate lookups, finding embeddings for plist[p] in params[p]
+ partitioned_result = []
+ for p in range(np):
+ # TODO(agarwal): handle device allocations here and later in the
+ # colocate code.
+ gather_ids = plist[p] / np
+ with ops.device(params[p].device):
+ partitioned_result.append(array_ops.gather(params[p], gather_ids))
+ # Stitch these back together
+ ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
+ name=name)
+ # Reshape to reverse the flattening of ids.
+ # It's important that we compute params[0].shape on the right device
+ # to avoid data motion.
+ with ops.device(params[0].device):
+ params_shape = array_ops.shape(params[0])
+ ret = array_ops.reshape(ret, array_ops.concat(0, [
+ array_ops.shape(ids), array_ops.slice(params_shape, [1], [-1])]))
+ # output shape = ids.shape + params[*].shape[1:]
+ # Normally the reshape is sufficient, but setting shape explicitly
+ # teaches shape inference that params[1:].get_shape() matters.
+ element_shape = params[0].get_shape()[1:]
+ for p in params[1:]:
+ element_shape = element_shape.merge_with(p.get_shape()[1:])
+ ret.set_shape(ids.get_shape().concatenate(element_shape))
+ return ret
+
+
+# TODO(lif): Add support for higher-rank SparseTensors
+def embedding_lookup_sparse(params, sp_ids, sp_weights,
+ name=None,
+ combiner="mean"):
+ """Computes embeddings for the given ids and weights.
+
+ This op assumes that there is at least one id for each row in the dense tensor
+ represented by sp_ids (i.e. there are no rows with empty features), and that
+ all the indices of sp_ids are in canonical row-major order.
+
+ It also assumes that all id values lie in the range [0, p0), where p0
+ is the sum of the size of params along dimension 0.
+
+ Args:
+ params: A single tensor representing the complete embedding tensor,
+ or a list of P tensors all of same shape except for the first dimension,
+ representing sharded embedding tensors. In the latter case, the ids are
+ partitioned by id % P, and we do separate lookups in params[p] for
+ 0 <= p < P, and then stitch the results back together into a single
+ result tensor. The first dimension is allowed to vary as the vocab
+ size is not necessarily a multiple of P.
+ sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
+ where N is typically batch size and M is arbitrary.
+ sp_weights: either a SparseTensor of float / double weights, or None to
+ indicate all weights should be taken to be 1. If specified, sp_weights
+ must have exactly the same shape and indices as sp_ids.
+ name: Optional name for the op.
+ combiner: A string specifying the reduction op. Currently "mean" and "sum"
+ are supported.
+ "sum" computes the weighted sum of the embedding results for each row.
+ "mean" is the weighted sum divided by the total weight.
+
+ Returns:
+ A dense tensor representing the combined embeddings for the
+ sparse ids. For each row in the dense tensor represented by sp_ids, the op
+ looks up the embeddings for all ids in that row, multiplies them by the
+ corresponding weight, and combines these embeddings as specified.
+
+ In other words, if
+ shape(combined params) = [p0, p1, ..., pm]
+ and
+ shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]
+ then
+ shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].
+
+ For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
+
+ [0, 0]: id 1, weight 2.0
+ [0, 1]: id 3, weight 0.5
+ [1, 0]: id 0, weight 1.0
+ [2, 3]: id 1, weight 3.0
+
+ with combiner="mean", then the output will be a 3x20 matrix where
+ output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
+ output[1, :] = params[0, :] * 1.0
+ output[2, :] = params[1, :] * 3.0
+
+ Raises:
+ TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
+ None nor SparseTensor.
+ ValueError: If combiner is not one of {"mean", "sum"}.
+ """
+ if combiner not in ("mean", "sum"):
+ raise ValueError("combiner must be one of 'mean' or 'sum'")
+ if not isinstance(params, list):
+ params = [params]
+ if not isinstance(sp_ids, ops.SparseTensor):
+ raise TypeError("sp_ids must be SparseTensor")
+ ignore_weights = sp_weights is None
+ if not ignore_weights and not isinstance(sp_weights, ops.SparseTensor):
+ raise TypeError("sp_weights must be either None or SparseTensor")
+
+  with ops.op_scope(params + [sp_ids], name,
+                    "embedding_lookup_sparse") as name:
+ segment_ids = sp_ids.indices[:, 0]
+ if segment_ids.dtype != types.int32:
+ segment_ids = math_ops.cast(segment_ids, types.int32)
+
+ ids = sp_ids.values
+ if ignore_weights:
+ ids, idx = array_ops.unique(ids)
+ else:
+ idx = None
+
+ embeddings = embedding_lookup(params, ids)
+ if not ignore_weights:
+ weights = sp_weights.values
+ if weights.dtype != embeddings.dtype:
+ weights = math_ops.cast(weights, embeddings.dtype)
+
+ # Reshape weights to allow broadcast
+ ones = array_ops.fill(
+ array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
+ bcast_weights_shape = array_ops.concat(0, [
+ array_ops.shape(weights), ones])
+ weights = array_ops.reshape(weights, bcast_weights_shape)
+ embeddings *= weights
+
+ if combiner == "sum":
+ embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
+ elif combiner == "mean":
+ embeddings = math_ops.segment_sum(embeddings, segment_ids)
+ weight_sum = math_ops.segment_sum(weights, segment_ids)
+ embeddings = math_ops.div(embeddings, weight_sum, name=name)
+ else:
+ assert False, "Unrecognized combiner"
+ else:
+ assert idx is not None
+ if combiner == "sum":
+ embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
+ name=name)
+ elif combiner == "mean":
+ embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
+ name=name)
+ else:
+ assert False, "Unrecognized combiner"
+
+ return embeddings
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
new file mode 100644
index 0000000000..ffa7828c04
--- /dev/null
+++ b/tensorflow/python/ops/gradients.py
@@ -0,0 +1,661 @@
+"""Implements the graph generation for computation of gradients."""
+
+import collections
+import warnings
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+# pylint: disable=unused-import
+from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_grad
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import linalg_grad
+from tensorflow.python.ops import math_grad
+# pylint: enable=unused-import
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.platform import logging
+
+
+# Warn the user if we convert a sparse representation to dense with at
+# least this number of elements.
+_LARGE_SPARSE_NUM_ELEMENTS = 100000000
+
+
+def _IndexedSlicesToTensor(value, dtype=None, name=None):
+ """Converts an IndexedSlices object `value` to a Tensor.
+
+ NOTE(mrry): This function is potentially expensive.
+
+ Args:
+ value: An ops.IndexedSlices object.
+ dtype: The dtype of the Tensor to be returned.
+ name: Optional name to use for the returned Tensor.
+
+ Returns:
+ A dense Tensor representing the values in the given IndexedSlices.
+
+ Raises:
+    ValueError: If `dtype` is incompatible with the dtype of `value`, or if
+      `value` does not have a known `dense_shape`.
+ """
+ if dtype and not dtype.is_compatible_with(value.dtype):
+ raise ValueError(
+ "Tensor conversion requested dtype %s for IndexedSlices with dtype %s"
+ % (dtype.name, value.dtype.name))
+ if value.dense_shape is None:
+ raise ValueError(
+ "Tensor conversion requested for IndexedSlices without dense_shape: %s"
+ % str(value))
+ # TODO(mrry): Consider adding static shape information to
+ # IndexedSlices, to avoid using numpy here.
+ dense_shape_value = tensor_util.ConstantValue(value.dense_shape)
+ if dense_shape_value is not None:
+ num_elements = np.prod(dense_shape_value)
+ if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
+ warnings.warn(
+ "Converting sparse IndexedSlices to a dense Tensor with %d elements. "
+ "This may consume a large amount of memory." % num_elements)
+ else:
+ warnings.warn(
+ "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
+ "This may consume a large amount of memory.")
+ return math_ops.unsorted_segment_sum(
+ value.values, value.indices, value.dense_shape[0], name=name)
+
+
+ops.register_tensor_conversion_function(
+    ops.IndexedSlices, _IndexedSlicesToTensor)
+
+
+def _MarkReachedOps(from_ops, reached_ops):
+ """Mark all ops reached from "from_ops".
+
+ Args:
+ from_ops: list of Operations.
+ reached_ops: list of booleans, indexed by operation id.
+ """
+ queue = collections.deque()
+ queue.extend(from_ops)
+ while queue:
+ op = queue.popleft()
+ if not reached_ops[op._id]:
+ reached_ops[op._id] = True
+ for output in op.outputs:
+ queue.extend(output.consumers())
+
+
+def _GatherInputs(to_ops, reached_ops):
+ """List all inputs of to_ops that are in reached_ops.
+
+ Args:
+ to_ops: list of Operations.
+ reached_ops: list of booleans, indexed by operation id.
+
+ Returns:
+ The list of all inputs of to_ops that are in reached_ops.
+ That list includes all elements of to_ops.
+ """
+ inputs = []
+ queue = collections.deque()
+ queue.extend(to_ops)
+ while queue:
+ op = queue.popleft()
+ # We are interested in this op.
+ if reached_ops[op._id]:
+ inputs.append(op)
+ # Clear the boolean so we won't add the inputs again.
+ reached_ops[op._id] = False
+ for inp in op.inputs:
+ queue.append(inp.op)
+ return inputs
+
+
+def _GetGradsDevice(op, colocate_gradients_with_ops):
+ """Gets the device to which to assign gradients of "op".
+
+ Args:
+ op: an Operation.
+ colocate_gradients_with_ops: If True, try colocating gradients with the
+ corresponding op.
+
+ Returns:
+ A device string.
+ """
+ if colocate_gradients_with_ops and op.device:
+ return op.device
+ else:
+ return op.graph.get_default_device()
+
+
+def _PendingCount(graph, to_ops, from_ops):
+ """Initialize the pending count for ops between two lists of Operations.
+
+ 'pending_count[op._id]' indicates the number of backprop inputs
+ to this operation.
+
+ Args:
+ graph: a Graph.
+ to_ops: list of Operations.
+ from_ops: list of Operations.
+
+ Returns:
+ A tuple containing: (1) a list of integers indexed by operation id,
+ indicating the number of backprop inputs to this operation, and (2)
+ a boolean which is True if any of the ops in between from_ops and to_ops
+ contain control flow loops.
+ """
+ # Mark reachable ops from from_ops.
+ reached_ops = [False] * (graph._last_id + 1)
+ for op in to_ops:
+ reached_ops[op._id] = True
+ _MarkReachedOps(from_ops, reached_ops)
+
+ # Mark between ops.
+ between_ops = [False] * (graph._last_id + 1)
+ between_op_list = []
+ queue = collections.deque()
+ queue.extend(to_ops)
+ while queue:
+ op = queue.popleft()
+ # We are interested in this op.
+ if reached_ops[op._id]:
+ between_ops[op._id] = True
+ between_op_list.append(op)
+ # Clear the boolean so we won't add the inputs again.
+ reached_ops[op._id] = False
+ for inp in op.inputs:
+ queue.append(inp.op)
+
+ # Initialize pending count for between ops.
+ pending_count = [0] * (graph._last_id + 1)
+ has_control_flow = False
+ for op in between_op_list:
+ for x in op.inputs:
+ if between_ops[x.op._id]:
+ pending_count[x.op._id] += 1
+ for x in op.control_inputs:
+ if between_ops[x._id]:
+ pending_count[x._id] += 1
+ if op.type == "Exit":
+ has_control_flow = True
+ return pending_count, has_control_flow
+
+
+def _AsList(x):
+ return x if isinstance(x, (list, tuple)) else [x]
+
+
+def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
+ """Fill in default values for grad_ys.
+
+ Args:
+ grad_ys: List of gradients, can contain None.
+ ys: List of tensors.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+
+ Returns:
+ A list of gradients to use, without None.
+
+ Raises:
+ ValueError: If one of the grad_ys is invalid.
+ """
+ if len(grad_ys) != len(ys):
+ raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
+ grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
+ for i in xrange(len(grad_ys)):
+ grad_y = grad_ys[i]
+ y = ys[i]
+ if grad_y is None:
+ with ops.device(_GetGradsDevice(y.op, colocate_gradients_with_ops)):
+ grad_ys[i] = array_ops.fill(array_ops.shape(y),
+ constant_op.constant(1, dtype=y.dtype))
+ else:
+ if grad_y.dtype != y.dtype:
+ raise ValueError("Y and ys_grad must be of the same type, "
+ "not y: %s, ys_grad: %s " %
+ (types.as_dtype(y.dtype).name,
+ types.as_dtype(grad_y.dtype).name))
+ return grad_ys
+
+
+def _VerifyGeneratedGradients(grads, op):
+ """Verify that gradients are valid in number and type.
+
+ Args:
+ grads: List of generated gradients.
+    op: Operation for which the gradients were generated.
+
+ Raises:
+ ValueError: if the gradients are invalid.
+ """
+ if len(grads) != len(op.inputs):
+ raise ValueError("Num gradients %d generated for op %s do not match num "
+ "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
+ for i in xrange(len(grads)):
+ grad = grads[i]
+ inp = op.inputs[i]
+ if grad is not None:
+ if not grad.dtype.is_compatible_with(inp.dtype):
+ raise ValueError(
+ "Gradient type %s generated for op %s does "
+ "not match input type %s" %
+ (types.as_dtype(grad.dtype).name, op.node_def,
+ types.as_dtype(inp.dtype).name))
+
+
+def _StopOps(from_ops, pending_count):
+ """The set of ops that terminate the gradient computation.
+
+ This computes the frontier of the forward graph *before* which backprop
+ should stop. Operations in the returned set will not be differentiated.
+ This set is defined as the subset of `from_ops` containing ops that have
+ no predecessor in `from_ops`. `pending_count` is the result of
+ `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
+ iff pending_count[op._id] > 0.
+
+ Args:
+ from_ops: list of Operations.
+ pending_count: List of integers, indexed by operation id.
+
+ Returns:
+ The set of operations.
+ """
+ stop_ops = set()
+ for op in from_ops:
+ is_stop_op = True
+ for inp in op.inputs:
+ if pending_count[inp.op._id] > 0:
+ is_stop_op = False
+ break
+ if is_stop_op:
+ stop_ops.add(op._id)
+ return stop_ops
+
+
+def gradients(ys, xs, grad_ys=None, name="gradients",
+ colocate_gradients_with_ops=False,
+ gate_gradients=False,
+ aggregation_method=None):
+ """Constructs symbolic partial derivatives of `ys` w.r.t. x in `xs`.
+
+ `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
+ is a list of `Tensor`, holding the gradients received by the
+ `ys`. The list must be the same length as `ys`.
+
+ `gradients()` adds ops to the graph to output the partial
+ derivatives of `ys` with respect to `xs`. It returns a list of
+ `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
+ for y in `ys`.
+
+ `grad_ys` is a list of tensors of the same length as `ys` that holds
+ the initial gradients for each y in `ys`. When `grad_ys` is None,
+ we fill in a tensor of '1's of the shape of y for each y in `ys`. A
+  user can provide their own initial `grad_ys` to compute the
+ derivatives using a different initial gradient for each y (e.g., if
+ one wanted to weight the gradient differently for each value in
+ each y).
+
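+  For example, a minimal sketch (a running `Session` named `sess` is
+  assumed):
+
+    x = constant_op.constant(3.0)
+    y = x * x
+    dx = gradients(y, [x])[0]
+    print(sess.run(dx))  # ==> 6.0
+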
+ Args:
+ ys: A `Tensor` or list of tensors to be differentiated.
+ xs: A `Tensor` or list of tensors to be used for differentiation.
+ grad_ys: Optional. A `Tensor` or list of tensors the same size as
+ `ys` and holding the gradients computed for each y in `ys`.
+    name: Optional name to use for grouping all the gradient ops together;
+      defaults to 'gradients'.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ gate_gradients: If True, add a tuple around the gradients returned
+      for an operation.  This avoids some race conditions.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Accepted values are constants defined in the class `AggregationMethod`.
+
+ Returns:
+ A list of `sum(dy/dx)` for each x in `xs`.
+
+ Raises:
+ LookupError: if one of the operations between `x` and `y` does not
+ have a registered gradient function.
+ ValueError: if the arguments are invalid.
+
+ """
+ ys = _AsList(ys)
+ xs = _AsList(xs)
+ if grad_ys is None:
+ grad_ys = [None] * len(ys)
+ else:
+ grad_ys = _AsList(grad_ys)
+ with ops.op_scope(ys + xs + grad_ys, name, "gradients"):
+ ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
+ xs = ops.convert_n_to_tensor_or_indexed_slices(xs, name="x")
+ grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
+
+ # The approach we take here is as follows: Create a list of all ops in the
+ # subgraph between the ys and xs. Visit these ops in reverse order of ids
+ # to ensure that when we visit an op the gradients w.r.t its outputs have
+ # been collected. Then aggregate these gradients if needed, call the op's
+ # gradient function, and add the generated gradients to the gradients for
+ # its input.
+
+ # Initialize the pending count for ops in the connected subgraph from ys
+ # to the xs.
+ to_ops = [t.op for t in ys]
+ from_ops = [t.op for t in xs]
+ pending_count, has_control_flow = _PendingCount(
+ ops.get_default_graph(), to_ops, from_ops)
+
+ # Iterate over the collected ops.
+ #
+ # grads: op => list of gradients received on each output endpoint of the
+ # op. The gradients for each endpoint are initially collected as a list.
+ # When it is time to call the op's gradient function, for each endpoint we
+    # aggregate the list of received gradients into an Add() Operation if
+    # there is more than one.
+ grads = {}
+
+ # Add the initial gradients for the ys.
+ for y, grad_y in zip(ys, grad_ys):
+ _SetGrad(grads, y, grad_y)
+
+ # Initialize queue with to_ops.
+ queue = collections.deque()
+ # Add the ops in 'to_ops' into the queue.
+ to_ops_set = set()
+ for op in to_ops:
+ if op._id not in to_ops_set:
+ to_ops_set.add(op._id)
+ queue.append(op)
+    # The set of ops at which backprop should stop (see _StopOps).
+ stop_ops = _StopOps(from_ops, pending_count)
+ while queue:
+ # generate gradient subgraph for op.
+ op = queue.popleft()
+ with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
+ if has_control_flow:
+ control_flow_ops.EnterGradWhileContext(op)
+ out_grads = _AggregatedGrads(grads, op, has_control_flow,
+ aggregation_method)
+ grad_fn = None
+ if any(out_grads) and op._id not in stop_ops:
+ # A grad_fn must be defined, either as a function or as None
+ # for ops that do not have gradients.
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ raise LookupError(
+ "No gradient defined for operation '%s' (op type: %s)" %
+ (op.name, op.type))
+ if grad_fn and any(out_grads):
+ # NOTE: If _AggregatedGrads didn't compute a value for the i'th
+ # output, it means that the cost does not depend on output[i],
+ # therefore dC/doutput[i] is 0.
+ for i, out_grad in enumerate(out_grads):
+ if (not out_grad
+ and types.as_dtype(op.outputs[i].dtype).base_dtype in (
+ types.float32, types.float64)):
+ # Only floating-point outputs get a zero gradient. Gradient
+ # functions should ignore the gradient for other outputs.
+ out_grads[i] = array_ops.zeros_like(op.outputs[i])
+ with ops.name_scope(op.name + "_grad"):
+ # pylint: disable=protected-access
+ with ops.get_default_graph()._original_op(op):
+ # pylint: enable=protected-access
+ op_wrapper = op
+ if has_control_flow:
+ op_wrapper = control_flow_ops.MakeWrapper(op)
+ in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
+ _VerifyGeneratedGradients(in_grads, op)
+ if gate_gradients and len(in_grads) > 1:
+ in_grads = control_flow_ops.tuple(in_grads)
+ logging.vlog(1, "Gradient for '" + op.name + "'")
+ logging.vlog(1, " in --> %s",
+ ", ".join([x.name for x in out_grads if x]))
+ logging.vlog(1, " out --> %s",
+ ", ".join([x.name for x in in_grads if x]))
+ else:
+ # If no grad_fn is defined or none of out_grads is available,
+          # just propagate a list of None backwards.
+ in_grads = [None] * len(op.inputs)
+ for t_in, in_grad in zip(op.inputs, in_grads):
+ if in_grad:
+ _SetGrad(grads, t_in, in_grad)
+ if has_control_flow:
+ control_flow_ops.ExitGradWhileContext(op)
+
+ # update pending count for the inputs of op.
+ for x in op.inputs:
+ pending_count[x.op._id] -= 1
+ ready = (pending_count[x.op._id] == 0)
+ if has_control_flow and not ready:
+ ready = (pending_count[x.op._id] > 0 and
+ control_flow_ops.IsLoopSwitch(x.op))
+ if ready:
+ queue.append(x.op)
+ for x in op.control_inputs:
+ pending_count[x._id] -= 1
+        if pending_count[x._id] == 0:
+ queue.append(x)
+ return [_GetGrad(grads, x) for x in xs]
+
+
+def _SetGrad(grads, t, grad):
+ """Sets gradient "grad" in "grads" for tensor "t"."""
+ op = t.op
+ op_grads = grads.get(op)
+ if not op_grads:
+ op_grads = [[] for _ in xrange(len(op.outputs))]
+ grads[op] = op_grads
+ t_grads = op_grads[t.value_index]
+ if isinstance(t_grads, list):
+ t_grads.append(grad)
+ else:
+ assert op.type == "Switch"
+ op_grads[t.value_index] = grad
+
+
+def _GetGrad(grads, t):
+ """Gets gradient for tensor "t"."""
+ op = t.op
+ op_grads = grads.get(op)
+ if not op_grads: return None
+ t_grad = op_grads[t.value_index]
+ assert not isinstance(t_grad, list), (
+ "gradients list should have been aggregated by now.")
+ return t_grad
+
+
+def _GetGrads(grads, op):
+ """Gets all gradients for op."""
+ if op in grads:
+ return grads[op]
+ else:
+ return [[] for _ in xrange(len(op.outputs))]
+
+
+def _HandleNestedIndexedSlices(grad):
+ assert isinstance(grad, ops.IndexedSlices)
+ if isinstance(grad.values, ops.Tensor):
+ return grad
+ else:
+ assert isinstance(grad.values, ops.IndexedSlices)
+ g = _HandleNestedIndexedSlices(grad.values)
+ return ops.IndexedSlices(
+ g.values, array_ops.gather(grad.indices, g.indices), g.dense_shape)
+
+
+def _AccumulatorShape(inputs):
+ shape = tensor_shape.unknown_shape()
+ for i in inputs:
+ if isinstance(i, ops.Tensor):
+ shape = shape.merge_with(i.get_shape())
+ return shape
+
+
+class AggregationMethod(object):
+ """A class listing aggregation methods used to combine gradients.
+
+ Computing partial derivatives can require aggregating gradient
+ contributions. This class lists the various methods that can
+ be used to combine gradients in the graph:
+
+ * `ADD_N`: All of the gradient terms are summed as part of one
+ operation using the "AddN" op. It has the property that all
+ gradients must be ready before any aggregation is performed.
+ * `DEFAULT`: The system-chosen default aggregation method.
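+
+  A minimal sketch of selecting a method explicitly (this assumes tensors
+  `y` and `x` already built in the default graph):
+
+    dy_dx = gradients(y, [x], aggregation_method=AggregationMethod.ADD_N)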
+ """
+ ADD_N = 0
+ DEFAULT = ADD_N
+ # The following are experimental and may not be supported in future releases.
+ EXPERIMENTAL_TREE = 1
+ EXPERIMENTAL_ACCUMULATE_N = 2
+
+
+def _AggregatedGrads(grads, op, has_control_flow, aggregation_method=None):
+ """Get the aggregated gradients for op.
+
+ Args:
+ grads: The map of memoized gradients.
+ op: The op to get gradients for.
+ has_control_flow: True iff the graph contains control flow ops.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Accepted values are constants defined in the class `AggregationMethod`.
+
+ Returns:
+    A list of gradients, one per output of `op`.  If the gradient
+    for a particular output is a list, this function aggregates it
+ before returning.
+
+ Raises:
+ TypeError: if the incoming grads are not Tensors or IndexedSlices.
+ ValueError: if the arguments are invalid.
+
+ """
+ if aggregation_method is None:
+ aggregation_method = AggregationMethod.DEFAULT
+ if aggregation_method not in [AggregationMethod.ADD_N,
+ AggregationMethod.EXPERIMENTAL_TREE,
+ AggregationMethod.EXPERIMENTAL_ACCUMULATE_N]:
+ raise ValueError("Invalid aggregation_method specified.")
+ out_grads = _GetGrads(grads, op)
+ for i, out_grad in enumerate(out_grads):
+ if has_control_flow:
+ if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
+ assert op.type == "Switch"
+ continue
+ # Grads have to be Tensors or IndexedSlices
+ if not all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
+ for g in out_grad if g]):
+ raise TypeError("gradients have to be either all Tensors "
+ "or all IndexedSlices")
+ # Aggregate multiple gradients, and convert [] to None.
+ if out_grad:
+ if all([isinstance(g, ops.Tensor) for g in out_grad if g]):
+ tensor_shape = _AccumulatorShape(out_grad)
+ if len(out_grad) < 2:
+ used = "nop"
+ out_grads[i] = out_grad[0]
+ elif (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
+ and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
+ # The benefit of using AccumulateN is that its inputs can be combined
+ # in any order and this can allow the expression to be evaluated with
+ # a smaller memory footprint. When used with gpu_allocator_retry,
+ # it is possible to compute a sum of terms which are much larger than
+ # total GPU memory.
+ # AccumulateN can currently only be used if we know the shape for
+ # an accumulator variable. If this is not known, or if we only have
+ # 2 grads then we fall through to the "tree" case below.
+ used = "accumulate_n"
+ out_grads[i] = math_ops.accumulate_n(out_grad)
+ elif aggregation_method in [AggregationMethod.EXPERIMENTAL_TREE,
+ AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
+ ]:
+ # Aggregate all gradients by doing pairwise sums: this may
+ # reduce performance, but it can improve memory because the
+ # gradients can be released earlier.
+ #
+ # TODO(vrv): Consider replacing this with a version of
+ # tf.AddN() that eagerly frees its inputs as soon as they are
+ # ready, so the order of this tree does not become a problem.
+ used = "tree"
+ with ops.name_scope(op.name + "_gradient_sum"):
+ running_sum = out_grad[0]
+ for grad in out_grad[1:]:
+ running_sum = math_ops.add_n([running_sum, grad])
+ out_grads[i] = running_sum
+ else:
+ used = "add_n"
+ out_grads[i] = math_ops.add_n(out_grad)
+ logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
+ tensor_shape, used)
+ else:
+ out_grad = math_ops._as_indexed_slices_list([g for g in out_grad if g])
+ out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
+ # Form IndexedSlices out of the concatenated values and
+ # indices.
+ out_grads[i] = ops.IndexedSlices(
+ array_ops.concat(0, [x.values for x in out_grad]),
+ array_ops.concat(0, [x.indices for x in out_grad]),
+ out_grad[0].dense_shape)
+ else:
+ out_grads[i] = []
+ return out_grads
+
+
+# TODO(vrv): Make this available when we want to make it public.
+def _hessian_vector_product(ys, xs, v):
+ """Multiply the Hessian of `ys` wrt `xs` by `v`.
+
+ This is an efficient construction that uses a backprop-like approach
+ to compute the product between the Hessian and another vector. The
+ Hessian is usually too large to be explicitly computed or even
+ represented, but this method allows us to at least multiply by it
+ for the same big-O cost as backprop.
+
+ Implicit Hessian-vector products are the main practical, scalable way
+ of using second derivatives with neural networks. They allow us to
+ do things like construct Krylov subspaces and approximate conjugate
+ gradient descent.
+
+  Example: if `y` = `x`^T A `x`, then `hessian_vector_product(y,
+ x, v)` will return an expression that evaluates to the same values
+ as (A + A.T) `v`.
+
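+  A minimal sketch of that case for a 2x2 matrix, using this module's
+  `constant_op`, `math_ops`, and `array_ops` imports (evaluating `hess_v`
+  in a `Session` is assumed):
+
+    a = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+    x = constant_op.constant([[1.0], [1.0]])
+    v = constant_op.constant([[1.0], [0.0]])
+    y = math_ops.matmul(array_ops.transpose(x), math_ops.matmul(a, x))
+    hess_v = _hessian_vector_product(y, [x], [v])[0]
+    # hess_v evaluates to (a + a.T) v = [[2.0], [5.0]].
+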
+ Args:
+ ys: A scalar value, or a tensor or list of tensors to be summed to
+ yield a scalar.
+ xs: A list of tensors that we should construct the Hessian over.
+ v: A list of tensors, with the same shapes as xs, that we want to
+ multiply by the Hessian.
+
+ Returns:
+ A list of tensors (or if the list would be length 1, a single tensor)
+ containing the product between the Hessian and `v`.
+
+ Raises:
+ ValueError: `xs` and `v` have different length.
+
+ """
+
+ # Validate the input
+ length = len(xs)
+ if len(v) != length:
+ raise ValueError("xs and v must have the same length.")
+
+ # First backprop
+ grads = gradients(ys, xs)
+
+ assert len(grads) == length
+ elemwise_products = [math_ops.mul(grad_elem, array_ops.stop_gradient(v_elem))
+ for grad_elem, v_elem in zip(grads, v)
+ if grad_elem is not None]
+
+ # Second backprop
+ return gradients(elemwise_products, xs)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
new file mode 100644
index 0000000000..dac0ebbb60
--- /dev/null
+++ b/tensorflow/python/ops/gradients_test.py
@@ -0,0 +1,337 @@
+"""Tests for tensorflow.ops.gradients."""
+import warnings
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+# pylint: disable=unused-import
+from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_grad
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad
+from tensorflow.python.ops import state_grad
+# pylint: enable=unused-import
+from tensorflow.python.ops.constant_op import constant
+from tensorflow.python.ops.nn_ops import bias_add
+from tensorflow.python.platform import googletest
+
+
+def _OpsBetween(graph, to_ops, from_ops):
+ """Build the list of operations between two lists of Operations.
+
+ Args:
+ graph: a Graph.
+ to_ops: list of Operations.
+ from_ops: list of Operations.
+
+ Returns:
+ The list of operations between "from_ops" and "to_ops", sorted by
+ decreasing operation id. This list contains all elements of to_ops.
+
+ TODO(mdevin): Think about returning an empty list if from_ops are not
+ reachable from to_ops. Presently it returns to_ops in that case.
+ """
+ # List of booleans, indexed by operation id, indicating if
+ # an op is reached from the output of "input_ops".
+ reached_ops = [False] * (graph._last_id + 1)
+ # We only care to reach up to "output_ops" so we mark the
+ # output ops as reached to avoid recursing past them.
+ for op in to_ops:
+ reached_ops[op._id] = True
+ gradients._MarkReachedOps(from_ops, reached_ops)
+ between_ops = gradients._GatherInputs(to_ops, reached_ops)
+ between_ops.sort(lambda x, y: y._id - x._id)
+ return between_ops
+
+
+class GradientsTest(test_util.TensorFlowTestCase):
+
+ def _OpNames(self, op_list):
+ return ["%s/%d" % (str(op.name), op._id) for op in op_list]
+
+ def _assertOpListEqual(self, ops1, ops2):
+ self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))
+
+ def testOpsBetweenSimple(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ t3 = array_ops.pack([t1, t2])
+ # Full graph
+ self._assertOpListEqual([t3.op, t2.op, t1.op],
+ _OpsBetween(g, [t3.op], [t1.op, t2.op]))
+ # Only t1, t3.
+ self._assertOpListEqual([t3.op, t1.op],
+ _OpsBetween(g, [t3.op], [t1.op]))
+
+ def testOpsBetweenUnreachable(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ _ = array_ops.pack([t1, t2])
+ t4 = constant(1.0)
+ t5 = constant(2.0)
+ t6 = array_ops.pack([t4, t5])
+ # Elements of to_ops are always listed.
+ self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
+
+ def testOpsBetweenCut(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ t3 = array_ops.pack([t1, t2])
+ t4 = constant([1.0])
+ t5 = array_ops.concat(0, [t4, t3])
+ t6 = constant([2.0])
+ t7 = array_ops.concat(0, [t5, t6])
+ self._assertOpListEqual([t7.op, t5.op, t4.op],
+ _OpsBetween(g, [t7.op], [t4.op]))
+
+ def testOpsBetweenCycle(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ t3 = array_ops.pack([t1, t2])
+ t4 = array_ops.concat(0, [t3, t3, t3])
+ t5 = constant([1.0])
+ t6 = array_ops.concat(0, [t4, t5])
+ t7 = array_ops.concat(0, [t6, t3])
+ self._assertOpListEqual([t6.op, t4.op, t3.op],
+ _OpsBetween(g, [t6.op], [t3.op]))
+ self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
+ _OpsBetween(g, [t7.op], [t1.op, t5.op]))
+ self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
+ _OpsBetween(g, [t6.op], [t2.op, t5.op]))
+
+ def testGradients(self):
+ with ops.Graph().as_default():
+ inp = constant(1.0, shape=[32, 100], name="in")
+ w = constant(1.0, shape=[100, 10], name="w")
+ b = constant(1.0, shape=[10], name="b")
+ xw = math_ops.matmul(inp, w, name="xw")
+ h = bias_add(xw, b, name="h")
+ w_grad = gradients.gradients(h, w)[0]
+ self.assertEquals("MatMul", w_grad.op.type)
+ self.assertEquals(w_grad.op._original_op, xw.op)
+ self.assertTrue(w_grad.op.get_attr("transpose_a"))
+ self.assertFalse(w_grad.op.get_attr("transpose_b"))
+
+ def testUnusedOutput(self):
+ with ops.Graph().as_default():
+ w = constant(1.0, shape=[2, 2])
+ x = constant(1.0, shape=[2, 2])
+ wx = math_ops.matmul(w, x)
+ split_wx = array_ops.split(0, 2, wx)
+ c = math_ops.reduce_sum(split_wx[1])
+ gw = gradients.gradients(c, [w])[0]
+ self.assertEquals("MatMul", gw.op.type)
+
+ def testColocateGradients(self):
+ with ops.Graph().as_default() as g:
+ w = constant(1.0, shape=[1, 1])
+ x = constant(1.0, shape=[1, 2])
+ with g.device("/gpu:0"):
+ wx = math_ops.matmul(w, x)
+ gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
+ self.assertEquals("/gpu:0", gw.device)
+
+ def testColocateGradientsWithAggregation(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/gpu:1"):
+ w = constant(1.0, shape=[1, 1])
+ x = constant(1.0, shape=[1, 2])
+ y = constant(1.0, shape=[1, 2])
+ wx = math_ops.matmul(w, x)
+ wy = math_ops.matmul(w, y)
+ with g.device("/gpu:0"):
+ z = wx + wy
+ gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
+ self.assertEquals("/gpu:1", gw1.device)
+ gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
+ self.assertEquals(None, gw2.device)
+
+ def testBoundaryStop(self):
+ # Test that we don't differentiate 'x'. The gradient function for 'x' is
+ # set explicitly to None so we will get an exception if the gradient code
+ # tries to differentiate 'x'.
+ with ops.Graph().as_default() as g:
+ c = constant(1.0)
+ x = array_ops.identity(c)
+ y = x + 1.0
+ z = y + 1
+ grads = gradients.gradients(z, [x])
+ self.assertTrue(all([x for x in grads]))
+
+ def testBoundaryContinue(self):
+ # Test that we differentiate both 'x' and 'y' correctly when x is a
+ # predecessor of y.
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y * 3.0
+ grads = gradients.gradients(z, [x, y])
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(6.0, grads[0].eval())
+
+ def testAggregationMethodAccumulateN(self):
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z,
+ [x, y],
+ aggregation_method=
+ gradients.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ def testAggregationMethodAddN(self):
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z,
+ [x, y],
+ aggregation_method=gradients.AggregationMethod.ADD_N)
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ def testAggregationMethodTree(self):
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z,
+ [x, y],
+ aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ def testNoGradientForStringOutputs(self):
+ with ops.Graph().as_default() as g:
+ @ops.RegisterGradient("TestOp")
+ def _TestOpGrad(op, float_grad, string_grad):
+ """Gradient function for TestOp."""
+ self.assertEquals(float_grad.dtype, types.float32)
+ self.assertFalse(string_grad)
+ return float_grad
+ ops.RegisterShape("TestOp")(None)
+
+ c = constant(1.0)
+ x, y = g.create_op("TestOp", [c], [types.float32, types.string]).outputs
+ z = x * 2.0
+ w = z * 3.0
+ grads = gradients.gradients(z, [c])
+ self.assertTrue(isinstance(grads[0], ops.Tensor))
+
+
+class StopGradientTest(test_util.TensorFlowTestCase):
+
+ def testStopGradient(self):
+ with ops.Graph().as_default():
+ inp = constant(1.0, shape=[100, 32], name="in")
+ out = array_ops.stop_gradient(inp)
+ igrad = gradients.gradients(out, inp)[0]
+ assert igrad is None
+
+
+class HessianVectorProductTest(test_util.TensorFlowTestCase):
+
+ def testHessianVectorProduct(self):
+ # Manually compute the Hessian explicitly for a low-dimensional problem
+ # and check that HessianVectorProduct matches multiplication by the
+ # explicit Hessian.
+ # Specifically, the Hessian of f(x) = x^T A x is
+ # H = A + A^T.
+ # We expect HessianVectorProduct(f(x), x, v) to be H v.
+ m = 4
+ rng = np.random.RandomState([1, 2, 3])
+ mat_value = rng.randn(m, m).astype("float32")
+ v_value = rng.randn(m, 1).astype("float32")
+ x_value = rng.randn(m, 1).astype("float32")
+ hess_value = mat_value + mat_value.T
+ hess_v_value = np.dot(hess_value, v_value)
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ mat = constant_op.constant(mat_value)
+ v = constant_op.constant(v_value)
+ x = constant_op.constant(x_value)
+ mat_x = math_ops.matmul(mat, x, name="Ax")
+ x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
+ hess_v = gradients._hessian_vector_product(x_mat_x, [x], [v])[0]
+ hess_v_actual = hess_v.eval()
+ self.assertAllClose(hess_v_value, hess_v_actual)
+
+
+class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
+
+ def testIndexedSlicesToTensor(self):
+ with self.test_session():
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
+ c_dense = math_ops.mul(c_sparse, 1.0)
+ self.assertAllClose(np_val, c_dense.eval())
+
+ def testInt64Indices(self):
+ with self.test_session():
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ c_sparse = ops.IndexedSlices(
+ c_sparse.values, math_ops.cast(c_sparse.indices, types.int64),
+ c_sparse.dense_shape)
+ self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
+ c_dense = math_ops.mul(c_sparse, 1.0)
+ self.assertAllClose(np_val, c_dense.eval())
+
+ def testWarnings(self):
+ # Smaller than the threshold: no warning.
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
+ array_ops.placeholder(types.int32),
+ constant([4, 4, 4, 4]))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.mul(c_sparse, 1.0)
+ self.assertEqual(0, len(w))
+
+ # Greater than or equal to the threshold: warning.
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
+ array_ops.placeholder(types.int32),
+ constant([100, 100, 100, 100]))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.mul(c_sparse, 1.0)
+ self.assertEqual(1, len(w))
+ self.assertTrue(
+ "with 100000000 elements. This may consume a large amount of memory."
+ in str(w[0].message))
+
+ # Unknown dense shape: warning.
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
+ array_ops.placeholder(types.int32),
+ array_ops.placeholder(types.int32))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.mul(c_sparse, 1.0)
+ self.assertEqual(1, len(w))
+ self.assertTrue(
+ "of unknown shape. This may consume a large amount of memory."
+ in str(w[0].message))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
new file mode 100644
index 0000000000..1b4f4aef22
--- /dev/null
+++ b/tensorflow/python/ops/image_ops.py
@@ -0,0 +1,786 @@
+"""## Encoding and Decoding.
+
+TensorFlow provides Ops to decode and encode JPEG and PNG formats. Encoded
+images are represented by scalar string Tensors, decoded images by 3-D uint8
+tensors of shape `[height, width, channels]`.
+
+The encode and decode Ops apply to one image at a time. Their input and output
+are all of variable size. If you need fixed size images, pass the output of
+the decode Ops to one of the cropping and resizing Ops.
+
+Note: The PNG encode and decode Ops support RGBA, but the conversion Ops
+presently only support RGB, HSV, and grayscale.
+
+@@decode_jpeg
+@@encode_jpeg
+
+@@decode_png
+@@encode_png
+
+## Resizing.
+
+The resizing Ops accept input images as tensors of several types. They always
+output resized images as float32 tensors.
+
+The convenience function [resize_images()](#resize_images) supports both 4-D
+and 3-D tensors as input and output. 4-D tensors are for batches of images,
+3-D tensors for individual images.
+
+Other resizing Ops only support 3-D individual images as input:
+[resize_area](#resize_area), [resize_bicubic](#resize_bicubic),
+[resize_bilinear](#resize_bilinear),
+[resize_nearest_neighbor](#resize_nearest_neighbor).
+
+Example:
+
+```python
+# Decode a JPG image and resize it to 299 by 299.
+image = tf.image.decode_jpeg(...)
+resized_image = tf.image.resize_bilinear(image, [299, 299])
+```
+
+<i>Maybe refer to the Queue examples that show how to add images to a Queue
+after resizing them to a fixed size, and how to dequeue batches of resized
+images from the Queue.</i>
+
+@@resize_images
+
+@@resize_area
+@@resize_bicubic
+@@resize_bilinear
+@@resize_nearest_neighbor
+
+
+## Cropping.
+
+@@resize_image_with_crop_or_pad
+
+@@pad_to_bounding_box
+@@crop_to_bounding_box
+@@random_crop
+@@extract_glimpse
+
+## Flipping and Transposing.
+
+@@flip_up_down
+@@random_flip_up_down
+
+@@flip_left_right
+@@random_flip_left_right
+
+@@transpose_image
+
+## Image Adjustments.
+
+TensorFlow provides functions to adjust images in various ways: brightness,
+contrast, hue, and saturation. Each adjustment can be done with predefined
+parameters or with random parameters picked from predefined intervals. Random
+adjustments are often useful to expand a training set and reduce overfitting.
+
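+Example (a sketch; the `max_delta`, `lower`, and `upper` argument names are
+assumed here):
+
+```python
+# Randomly perturb brightness and contrast as simple data augmentation.
+image = tf.image.random_brightness(image, max_delta=0.3)
+image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
+```
+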
+@@adjust_brightness
+@@random_brightness
+
+@@adjust_contrast
+@@random_contrast
+
+@@per_image_whitening
+"""
+import math
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_image_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_image_ops import *
+from tensorflow.python.ops.gen_attention_ops import *
+# pylint: enable=wildcard-import
+
+ops.NoGradient('ResizeBilinear')
+ops.NoGradient('RandomCrop')
+
+
+def _ImageDimensions(images):
+ """Returns the dimensions of an image tensor.
+
+ Args:
+ images: 4-D Tensor of shape [batch, height, width, channels]
+
+ Returns:
+ list of integers [batch, height, width, channels]
+ """
+ # A simple abstraction to provide names for each dimension. This abstraction
+ # should make it simpler to switch dimensions in the future (e.g. if we ever
+ # want to switch height and width.)
+ return images.get_shape().as_list()
+
+
+def _Check3DImage(image):
+ """Assert that we are working with properly shaped image.
+
+ Args:
+ image: 3-D Tensor of shape [height, width, channels]
+
+ Raises:
+ ValueError: if image.shape is not a [3] vector.
+ """
+ if not image.get_shape().is_fully_defined():
+ raise ValueError('\'image\' must be fully defined.')
+ if image.get_shape().ndims != 3:
+ raise ValueError('\'image\' must be three-dimensional.')
+ if not all(x > 0 for x in image.get_shape()):
+ raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
+ image.get_shape())
+
+
+def _CheckAtLeast3DImage(image):
+ """Assert that we are working with properly shaped image.
+
+ Args:
+ image: >= 3-D Tensor of size [*, height, width, depth]
+
+ Raises:
+ ValueError: if image.shape is not a [>= 3] vector.
+ """
+ if not image.get_shape().is_fully_defined():
+ raise ValueError('\'image\' must be fully defined.')
+ if image.get_shape().ndims < 3:
+ raise ValueError('\'image\' must be at least three-dimensional.')
+ if not all(x > 0 for x in image.get_shape()):
+ raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
+ image.get_shape())
+
+
+def random_flip_up_down(image, seed=None):
+ """Randomly flips an image vertically (upside down).
+
+ With a 1 in 2 chance, outputs the contents of `image` flipped along the first
+  dimension, which is `height`. Otherwise, outputs the image as-is.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+    ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
+ mirror = math_ops.less(array_ops.pack([uniform_random, 1.0, 1.0]), 0.5)
+ return array_ops.reverse(image, mirror)
+
+
+def random_flip_left_right(image, seed=None):
+ """Randomly flip an image horizontally (left to right).
+
+ With a 1 in 2 chance, outputs the contents of `image` flipped along the
+  second dimension, which is `width`. Otherwise, outputs the image as-is.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+    ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
+ mirror = math_ops.less(array_ops.pack([1.0, uniform_random, 1.0]), 0.5)
+ return array_ops.reverse(image, mirror)
+
+
+def flip_left_right(image):
+ """Flip an image horizontally (left to right).
+
+ Outputs the contents of `image` flipped along the second dimension, which is
+ `width`.
+
+ See also `reverse()`.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+    ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ return array_ops.reverse(image, [False, True, False])
+
+
+def flip_up_down(image):
+ """Flip an image horizontally (upside down).
+
+ Outputs the contents of `image` flipped along the first dimension, which is
+ `height`.
+
+ See also `reverse()`.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ return array_ops.reverse(image, [True, False, False])
+
+
+def transpose_image(image):
+ """Transpose an image by swapping the first and second dimension.
+
+ See also `transpose()`.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`
+
+ Returns:
+ A 3-D tensor of shape `[width, height, channels]`
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ return array_ops.transpose(image, [1, 0, 2], name='transpose_image')
+
+
+def pad_to_bounding_box(image, offset_height, offset_width, target_height,
+ target_width):
+ """Pad `image` with zeros to the specified `height` and `width`.
+
+ Adds `offset_height` rows of zeros on top, `offset_width` columns of
+ zeros on the left, and then pads the image on the bottom and right
+ with zeros until it has dimensions `target_height`, `target_width`.
+
+ This op does nothing if `offset_*` is zero and the image already has size
+ `target_height` by `target_width`.
+
+ Args:
+ image: 3-D tensor with shape `[height, width, channels]`
+ offset_height: Number of rows of zeros to add on top.
+ offset_width: Number of columns of zeros to add on the left.
+ target_height: Height of output image.
+ target_width: Width of output image.
+
+ Returns:
+ 3-D tensor of shape `[target_height, target_width, channels]`
+
+ Raises:
+ ValueError: If the shape of `image` is incompatible with the `offset_*` or
+ `target_*` arguments
+ """
+ _Check3DImage(image)
+ height, width, depth = _ImageDimensions(image)
+
+ if target_width < width:
+ raise ValueError('target_width must be >= width')
+ if target_height < height:
+ raise ValueError('target_height must be >= height')
+
+ after_padding_width = target_width - offset_width - width
+ after_padding_height = target_height - offset_height - height
+
+ if after_padding_width < 0:
+ raise ValueError('target_width not possible given '
+ 'offset_width and image width')
+ if after_padding_height < 0:
+ raise ValueError('target_height not possible given '
+ 'offset_height and image height')
+
+ # Do not pad on the depth dimensions.
+ if (offset_width or offset_height or after_padding_width or
+ after_padding_height):
+ paddings = [[offset_height, after_padding_height],
+ [offset_width, after_padding_width], [0, 0]]
+ padded = array_ops.pad(image, paddings)
+ padded.set_shape([target_height, target_width, depth])
+ else:
+ padded = image
+
+ return padded
+
+
+def crop_to_bounding_box(image, offset_height, offset_width, target_height,
+ target_width):
+ """Crops an image to a specified bounding box.
+
+ This op cuts a rectangular part out of `image`. The top-left corner of the
+ returned image is at `offset_height, offset_width` in `image`, and its
+ lower-right corner is at
+ `offset_height + target_height, offset_width + target_width`.
+
+ Args:
+ image: 3-D tensor with shape `[height, width, channels]`
+ offset_height: Vertical coordinate of the top-left corner of the result in
+ the input.
+ offset_width: Horizontal coordinate of the top-left corner of the result in
+ the input.
+ target_height: Height of the result.
+ target_width: Width of the result.
+
+ Returns:
+ 3-D tensor of image with shape `[target_height, target_width, channels]`
+
+ Raises:
+ ValueError: If the shape of `image` is incompatible with the `offset_*` or
+ `target_*` arguments
+ """
+ _Check3DImage(image)
+ height, width, _ = _ImageDimensions(image)
+
+ if offset_width < 0:
+ raise ValueError('offset_width must be >= 0.')
+ if offset_height < 0:
+ raise ValueError('offset_height must be >= 0.')
+
+ if width < (target_width + offset_width):
+ raise ValueError('width must be >= target + offset.')
+ if height < (target_height + offset_height):
+ raise ValueError('height must be >= target + offset.')
+
+ cropped = array_ops.slice(image, [offset_height, offset_width, 0],
+ [target_height, target_width, -1])
+
+ return cropped
+
+
+def resize_image_with_crop_or_pad(image, target_height, target_width):
+ """Crops and/or pads an image to a target width and height.
+
+ Resizes an image to a target width and height by either centrally
+ cropping the image or padding it evenly with zeros.
+
+ If `width` or `height` is greater than the specified `target_width` or
+ `target_height` respectively, this op centrally crops along that dimension.
+ If `width` or `height` is smaller than the specified `target_width` or
+ `target_height` respectively, this op centrally pads with 0 along that
+ dimension.
+
+ Args:
+ image: 3-D tensor of shape [height, width, channels]
+ target_height: Target height.
+ target_width: Target width.
+
+ Raises:
+ ValueError: if `target_height` or `target_width` are zero or negative.
+
+ Returns:
+ Cropped and/or padded image of shape
+ `[target_height, target_width, channels]`
+ """
+ _Check3DImage(image)
+ original_height, original_width, _ = _ImageDimensions(image)
+
+ if target_width <= 0:
+ raise ValueError('target_width must be > 0.')
+ if target_height <= 0:
+ raise ValueError('target_height must be > 0.')
+
+ offset_crop_width = 0
+ offset_pad_width = 0
+ if target_width < original_width:
+ offset_crop_width = int((original_width - target_width) / 2)
+ elif target_width > original_width:
+ offset_pad_width = int((target_width - original_width) / 2)
+
+ offset_crop_height = 0
+ offset_pad_height = 0
+ if target_height < original_height:
+ offset_crop_height = int((original_height - target_height) / 2)
+ elif target_height > original_height:
+ offset_pad_height = int((target_height - original_height) / 2)
+
+ # Maybe crop if needed.
+ cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width,
+ min(target_height, original_height),
+ min(target_width, original_width))
+
+ # Maybe pad if needed.
+ resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width,
+ target_height, target_width)
+
+ if resized.get_shape().ndims is None:
+ raise ValueError('resized contains no shape.')
+ if not resized.get_shape()[0].is_compatible_with(target_height):
+ raise ValueError('resized height is not correct.')
+ if not resized.get_shape()[1].is_compatible_with(target_width):
+ raise ValueError('resized width is not correct.')
+ return resized
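+
+
+# Illustrative sketch (hypothetical helper, not part of the public API): for a
+# 4x6 input and a 5x5 target, the width is centrally cropped from 6 to 5 and
+# the height is centrally padded with zeros from 4 to 5.
+def _example_resize_image_with_crop_or_pad():
+  """Returns a 5x5x1 tensor built from a constant 4x6x1 image."""
+  image = constant_op.constant(1.0, shape=[4, 6, 1])
+  return resize_image_with_crop_or_pad(image, target_height=5, target_width=5)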
+
+
+class ResizeMethod(object):
+ BILINEAR = 0
+ NEAREST_NEIGHBOR = 1
+ BICUBIC = 2
+ AREA = 3
+
+
+def resize_images(images, new_height, new_width, method=ResizeMethod.BILINEAR):
+ """Resize `images` to `new_width`, `new_height` using the specified `method`.
+
+ Resized images will be distorted if their original aspect ratio is not
+ the same as `new_width`, `new_height`. To avoid distortions see
+ [resize_image_with_crop_or_pad](#resize_image_with_crop_or_pad).
+
+ `method` can be one of:
+
+ * <b>ResizeMethod.BILINEAR</b>: [Bilinear interpolation.]
+ (https://en.wikipedia.org/wiki/Bilinear_interpolation)
+ * <b>ResizeMethod.NEAREST_NEIGHBOR</b>: [Nearest neighbor interpolation.]
+ (https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation)
+ * <b>ResizeMethod.BICUBIC</b>: [Bicubic interpolation.]
+ (https://en.wikipedia.org/wiki/Bicubic_interpolation)
+ * <b>ResizeMethod.AREA</b>: Area interpolation.
+
+ Args:
+ images: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ new_height: integer.
+ new_width: integer.
+ method: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`.
+
+ Raises:
+ ValueError: if the shape of `images` is incompatible with the
+ shape arguments to this function
+ ValueError: if an unsupported resize method is specified.
+
+ Returns:
+ If `images` was 4-D, a 4-D float Tensor of shape
+ `[batch, new_height, new_width, channels]`.
+ If `images` was 3-D, a 3-D float Tensor of shape
+ `[new_height, new_width, channels]`.
+ """
+ if images.get_shape().ndims is None:
+ raise ValueError('\'images\' contains no shape.')
+ # TODO(shlens): Migrate this functionality to the underlying Op's.
+ is_batch = True
+ if len(images.get_shape()) == 3:
+ is_batch = False
+ images = array_ops.expand_dims(images, 0)
+
+ _, height, width, depth = _ImageDimensions(images)
+
+ if width == new_width and height == new_height:
+   if not is_batch:
+     # Undo the expand_dims above so a 3-D input yields a 3-D output.
+     images = array_ops.reshape(images, [new_height, new_width, depth])
+   return images
+
+ if method == ResizeMethod.BILINEAR:
+ images = gen_image_ops.resize_bilinear(images, [new_height, new_width])
+ elif method == ResizeMethod.NEAREST_NEIGHBOR:
+ images = gen_image_ops.resize_nearest_neighbor(images, [new_height,
+ new_width])
+ elif method == ResizeMethod.BICUBIC:
+ images = gen_image_ops.resize_bicubic(images, [new_height, new_width])
+ elif method == ResizeMethod.AREA:
+ images = gen_image_ops.resize_area(images, [new_height, new_width])
+ else:
+ raise ValueError('Resize method is not implemented.')
+
+ if not is_batch:
+ images = array_ops.reshape(images, [new_height, new_width, depth])
+ return images
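+
+
+# Illustrative sketch (hypothetical helper, not part of the public API): a 3-D
+# input stays 3-D after resizing; the interpolation method is selectable.
+def _example_resize_images(image, new_height, new_width):
+  """Resizes a single `[height, width, channels]` image with area sampling."""
+  return resize_images(image, new_height, new_width,
+                       method=ResizeMethod.AREA)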
+
+
+def per_image_whitening(image):
+ """Linearly scales `image` to have zero mean and unit norm.
+
+ This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
+ of all values in `image`, and
+ `adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))`.
+
+ `stddev` is the standard deviation of all values in `image`. It is capped
+ away from zero to protect against division by 0 when handling uniform images.
+
+ Note that this implementation is limited:
+ * It only whitens based on the statistics of an individual image.
+ * It does not take into account the covariance structure.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`.
+
+ Returns:
+ The whitened image with same shape as `image`.
+
+ Raises:
+ ValueError: if the shape of 'image' is incompatible with this function.
+ """
+ _Check3DImage(image)
+ height, width, depth = _ImageDimensions(image)
+ num_pixels = height * width * depth
+
+ image = math_ops.cast(image, dtype=types.float32)
+ image_mean = math_ops.reduce_mean(image)
+
+ variance = (math_ops.reduce_mean(math_ops.square(image)) -
+ math_ops.square(image_mean))
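+ # Var(x) = E[x^2] - (E[x])^2, computed in one pass over the image.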
+ stddev = math_ops.sqrt(variance)
+
+ # Apply a minimum normalization that protects us against uniform images.
+ min_stddev = constant_op.constant(1.0 / math.sqrt(num_pixels))
+ pixel_value_scale = math_ops.maximum(stddev, min_stddev)
+ pixel_value_offset = image_mean
+
+ image = math_ops.sub(image, pixel_value_offset)
+ image = math_ops.div(image, pixel_value_scale)
+ return image
+
+
+def random_brightness(image, max_delta, seed=None):
+ """Adjust the brightness of images by a random factor.
+
+ Equivalent to `adjust_brightness()` using a `delta` randomly picked in the
+ interval `[-max_delta, max_delta)`.
+
+ Note that `delta` is picked as a float. Because the brightness-adjusted
+ result is rounded before casting back for integer-type images, integer images
+ may see changes anywhere in the closed range `[-max_delta, max_delta]`.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`.
+ max_delta: float, must be non-negative.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ 3-D tensor of images of shape `[height, width, channels]`
+
+ Raises:
+ ValueError: if max_delta is negative.
+ """
+ _Check3DImage(image)
+
+ if max_delta < 0:
+ raise ValueError('max_delta must be non-negative.')
+
+ delta = random_ops.random_uniform([], -max_delta, max_delta, seed=seed)
+ return adjust_brightness(image, delta)
+
+
+def random_contrast(image, lower, upper, seed=None):
+ """Adjust the contrase of an image by a random factor.
+
+ Equivalent to `adjust_contrast()` but uses a `contrast_factor` randomly
+ picked in the interval `[lower, upper]`.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`.
+ lower: float. Lower bound for the random contrast factor.
+ upper: float. Upper bound for the random contrast factor.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ 3-D tensor of shape `[height, width, channels]`.
+
+ Raises:
+ ValueError: if `upper <= lower` or if `lower < 0`.
+ """
+ _Check3DImage(image)
+
+ if upper <= lower:
+ raise ValueError('upper must be > lower.')
+
+ if lower < 0:
+ raise ValueError('lower must be non-negative.')
+
+ # Generate a float in [lower, upper].
+ contrast_factor = random_ops.random_uniform([], lower, upper, seed=seed)
+ return adjust_contrast(image, contrast_factor)
+
+
+def adjust_brightness(image, delta, min_value=None, max_value=None):
+ """Adjust the brightness of RGB or Grayscale images.
+
+ The value `delta` is added to all components of the tensor `image`. `image`
+ and `delta` are cast to `float` before adding, and the resulting values are
+ clamped to `[min_value, max_value]`. Finally, the result is cast back to
+ `images.dtype`.
+
+ If `min_value` or `max_value` are not given, they are set to the minimum and
+ maximum allowed values for `image.dtype` respectively.
+
+ Args:
+ image: A tensor.
+ delta: A scalar. Amount to add to the pixel values.
+ min_value: Minimum value for output.
+ max_value: Maximum value for output.
+
+ Returns:
+ A tensor of the same shape and type as `image`.
+ """
+ if min_value is None:
+ min_value = image.dtype.min
+ if max_value is None:
+ max_value = image.dtype.max
+
+ with ops.op_scope([image, delta, min_value, max_value], None,
+ 'adjust_brightness') as name:
+ adjusted = math_ops.add(
+ math_ops.cast(image, types.float32),
+ math_ops.cast(delta, types.float32),
+ name=name)
+ if image.dtype.is_integer:
+ rounded = math_ops.round(adjusted)
+ else:
+ rounded = adjusted
+ clipped = clip_ops.clip_by_value(rounded, float(min_value),
+ float(max_value))
+ output = math_ops.cast(clipped, image.dtype)
+ return output
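+
+
+# Illustrative sketch (hypothetical helper, not part of the public API): for a
+# uint8 image the default clamp range is [0, 255], so a large delta saturates
+# instead of wrapping around.
+def _example_brighten_uint8(image):
+  """Adds 100 to every channel of a uint8 `image`, clipping results at 255."""
+  return adjust_brightness(image, delta=100.0)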
+
+
+def adjust_contrast(images, contrast_factor, min_value=None, max_value=None):
+ """Adjust contrast of RGB or grayscale images.
+
+ `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
+ interpreted as `[height, width, channels]`. The other dimensions only
+ represent a collection of images, such as `[batch, height, width, channels]`.
+
+ Contrast is adjusted independently for each channel of each image.
+
+ For each channel, this Op first computes the mean of the image pixels in the
+ channel and then adjusts each component `x` of each pixel to
+ `(x - mean) * contrast_factor + mean`.
+
+ The adjusted values are then clipped to fit in the `[min_value, max_value]`
+ interval. If `min_value` or `max_value` is not given, it is replaced with the
+ minimum and maximum values for the data type of `images` respectively.
+
+ The contrast-adjusted image is always computed as `float`, and it is
+ cast back to its original type after clipping.
+
+ Args:
+ images: Images to adjust. At least 3-D.
+ contrast_factor: A float multiplier for adjusting contrast.
+ min_value: Minimum value for clipping the adjusted pixels.
+ max_value: Maximum value for clipping the adjusted pixels.
+
+ Returns:
+ The contrast-adjusted image or images.
+
+ Raises:
+ ValueError: if the arguments are invalid.
+ """
+ _CheckAtLeast3DImage(images)
+
+ # If these are None, the min/max should be a nop, but still prevent overflows
+ # from the cast back to images.dtype at the end of adjust_contrast.
+ if min_value is None:
+ min_value = images.dtype.min
+ if max_value is None:
+ max_value = images.dtype.max
+
+ with ops.op_scope(
+ [images, contrast_factor, min_value,
+ max_value], None, 'adjust_contrast') as name:
+ adjusted = gen_image_ops.adjust_contrast(images,
+ contrast_factor=contrast_factor,
+ min_value=min_value,
+ max_value=max_value,
+ name=name)
+ if images.dtype.is_integer:
+ return math_ops.cast(math_ops.round(adjusted), images.dtype)
+ else:
+ return math_ops.cast(adjusted, images.dtype)
+
+
+ops.RegisterShape('AdjustContrast')(
+ common_shapes.unchanged_shape_with_rank_at_least(3))
+
+
+@ops.RegisterShape('ResizeBilinear')
+@ops.RegisterShape('ResizeNearestNeighbor')
+@ops.RegisterShape('ResizeBicubic')
+@ops.RegisterShape('ResizeArea')
+def _ResizeShape(op):
+ """Shape function for the resize_bilinear and resize_nearest_neighbor ops."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ size = tensor_util.ConstantValue(op.inputs[1])
+ if size is not None:
+ height = size[0]
+ width = size[1]
+ else:
+ height = None
+ width = None
+ return [tensor_shape.TensorShape(
+ [input_shape[0], height, width, input_shape[3]])]
+
+
+@ops.RegisterShape('DecodeJpeg')
+@ops.RegisterShape('DecodePng')
+def _ImageDecodeShape(op):
+ """Shape function for image decoding ops."""
+ unused_input_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ channels = op.get_attr('channels') or None
+ return [tensor_shape.TensorShape([None, None, channels])]
+
+
+@ops.RegisterShape('EncodeJpeg')
+@ops.RegisterShape('EncodePng')
+def _ImageEncodeShape(op):
+ """Shape function for image encoding ops."""
+ unused_input_shape = op.inputs[0].get_shape().with_rank(3)
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape('RandomCrop')
+def _RandomCropShape(op):
+ """Shape function for the random_crop op."""
+ input_shape = op.inputs[0].get_shape().with_rank(3)
+ unused_size_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(2))
+ size = tensor_util.ConstantValue(op.inputs[1])
+ if size is not None:
+ height = size[0]
+ width = size[1]
+ else:
+ height = None
+ width = None
+ channels = input_shape[2]
+ return [tensor_shape.TensorShape([height, width, channels])]
+
+
+def random_crop(image, size, seed=None, name=None):
+ """Randomly crops `image` to size `[target_height, target_width]`.
+
+ The offset of the output within `image` is uniformly random. `image` always
+ fully contains the result.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`
+ size: 1-D tensor with two elements, specifying target `[height, width]`
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for this operation (optional).
+
+ Returns:
+ A cropped 3-D tensor of shape `[target_height, target_width, channels]`.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_image_ops.random_crop(image, size, seed=seed1, seed2=seed2,
+ name=name)
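+
+
+# Illustrative sketch (hypothetical helper, not part of the public API): the
+# ops above are typically chained into a per-image training-time distortion
+# step. The crop size and distortion bounds here are arbitrary example values.
+def _example_distort_for_training(image):
+  """Randomly crops, flips, color-distorts and whitens a 3-D image."""
+  crop_size = constant_op.constant([224, 224], dtype=types.int64)
+  distorted = random_crop(image, crop_size)
+  distorted = random_flip_left_right(distorted)
+  distorted = random_brightness(distorted, max_delta=63)
+  distorted = random_contrast(distorted, lower=0.2, upper=1.8)
+  # Scale to zero mean and unit (adjusted) variance.
+  return per_image_whitening(distorted)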
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
new file mode 100644
index 0000000000..2c51299198
--- /dev/null
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -0,0 +1,771 @@
+"""Tests for tensorflow.ops.image_ops."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.platform import googletest
+
+
+class FlipTest(test_util.TensorFlowTestCase):
+
+ def testIdempotentLeftRight(self):
+ x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testLeftRight(self):
+ x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
+ y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testIdempotentUpDown(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testUpDown(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+ y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testIdempotentTranspose(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(image_ops.transpose_image(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testTranspose(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+ y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+
+class RandomFlipTest(test_util.TensorFlowTestCase):
+
+ def testRandomLeftRight(self):
+ x_np = np.array([0, 1], dtype=np.uint8).reshape([1, 2, 1])
+ num_iterations = 500
+
+ hist = [0, 0]
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.random_flip_left_right(x_tf)
+ for _ in xrange(num_iterations):
+ y_np = y.eval().flatten()[0]
+ hist[y_np] += 1
+
+ # Ensure that each entry is observed within 4 standard deviations.
+ four_stddev = 4.0 * np.sqrt(num_iterations / 2.0)
+ self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev)
+
+ def testRandomUpDown(self):
+ x_np = np.array([0, 1], dtype=np.uint8).reshape([2, 1, 1])
+ num_iterations = 500
+
+ hist = [0, 0]
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.random_flip_up_down(x_tf)
+ for _ in xrange(num_iterations):
+ y_np = y.eval().flatten()[0]
+ hist[y_np] += 1
+
+ # Ensure that each entry is observed within 4 standard deviations.
+ four_stddev = 4.0 * np.sqrt(num_iterations / 2.0)
+ self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev)
+
+
+class AdjustContrastTest(test_util.TensorFlowTestCase):
+
+ def _testContrast(self, x_np, y_np, contrast_factor, min_value, max_value):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.adjust_contrast(x,
+ contrast_factor,
+ min_value=min_value,
+ max_value=max_value)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testDoubleContrastUint8(self):
+ x_shape = [1, 2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [0, 0, 0, 63, 169, 255, 29, 0, 255, 135, 255, 0]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=2.0,
+ min_value=None,
+ max_value=None)
+
+ def testDoubleContrastFloat(self):
+ x_shape = [1, 2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.float).reshape(x_shape)
+
+ y_data = [0, 0, 0, 62.75, 169.25, 255, 28.75, 0, 255, 134.75, 255, 0]
+ y_np = np.array(y_data, dtype=np.float).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=2.0,
+ min_value=0,
+ max_value=255)
+
+ def testHalfContrastUint8(self):
+ x_shape = [1, 2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [23, 53, 66, 50, 118, 172, 41, 54, 176, 68, 178, 60]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=0.5,
+ min_value=None,
+ max_value=None)
+
+ def testBatchDoubleContrast(self):
+ x_shape = [2, 1, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [0, 0, 0, 81, 200, 255, 11, 0, 255, 117, 255, 0]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=2.0,
+ min_value=None,
+ max_value=None)
+
+
+class AdjustBrightnessTest(test_util.TensorFlowTestCase):
+
+ def _testBrightness(self, x_np, y_np, delta, min_value, max_value):
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.adjust_brightness(x,
+ delta,
+ min_value=min_value,
+ max_value=max_value)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testPositiveDeltaUint8(self):
+ x_shape = [2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [10, 15, 23, 64, 145, 236, 47, 18, 244, 100, 255, 11]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testBrightness(x_np, y_np, delta=10.0, min_value=None, max_value=None)
+
+ def testPositiveDeltaFloat(self):
+ x_shape = [2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.float32).reshape(x_shape)
+
+ y_data = [10, 15, 23, 64, 145, 236, 47, 18, 244, 100, 265, 11]
+ y_np = np.array(y_data, dtype=np.float32).reshape(x_shape)
+
+ self._testBrightness(x_np, y_np, delta=10.0, min_value=None, max_value=None)
+
+ def testNegativeDelta(self):
+ x_shape = [2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [5, 5, 5, 44, 125, 216, 27, 5, 224, 80, 245, 5]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testBrightness(x_np, y_np, delta=-10.0, min_value=5, max_value=None)
+
+
+class RandomCropTest(test_util.TensorFlowTestCase):
+
+ def testNoOp(self):
+ # No random cropping is performed since the target width and height
+ # match the image dimensions.
+ height = 4
+ width = 5
+ x_shape = [height, width, 3]
+ x_np = np.arange(0, np.prod(x_shape), dtype=np.int32).reshape(x_shape)
+ target_shape_np = np.array([height, width], dtype=np.int64)
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ target_shape = constant_op.constant(target_shape_np, shape=[2])
+ y = image_ops.random_crop(x, target_shape)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testRandomization(self):
+ # Run 1x1 crop num_samples times in an image and ensure that one finds each
+ # pixel 1/num_pixels of the time.
+ num_samples = 1000
+ height = 5
+ width = 4
+
+ num_pixels = height * width
+ data = np.arange(num_pixels).reshape([height, width, 1])
+ x_np = np.array(data).astype(np.int32)
+
+ target_shape_np = np.array([1, 1], dtype=np.int64)
+
+ y = []
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ target_shape = constant_op.constant(target_shape_np, shape=[2])
+ y_tf = image_ops.random_crop(x, target_shape)
+ for _ in xrange(num_samples):
+ y_np = y_tf.eval()
+ self.assertAllEqual(y_np.shape, [1, 1, 1])
+ y.extend(y_np.flatten())
+
+ # Calculate the mean and 4 * standard deviation.
+ mean = [num_samples / num_pixels] * num_pixels
+ four_stddev = 4.0 * np.sqrt(mean)
+
+ # Ensure that each entry is observed in 1/num_pixels of the samples
+ # within 4 standard deviations.
+ counts = np.bincount(y)
+ self.assertAllClose(counts, mean, atol=four_stddev)
+
+
+class PerImageWhiteningTest(test_util.TensorFlowTestCase):
+
+ def _NumpyPerImageWhitening(self, x):
+ num_pixels = np.prod(x.shape)
+ x2 = np.square(x).astype(np.float32)
+ mn = np.mean(x)
+ vr = np.mean(x2) - (mn * mn)
+ stddev = max(math.sqrt(vr), 1.0 / math.sqrt(num_pixels))
+
+ y = x.astype(np.float32)
+ y -= mn
+ y /= stddev
+ return y
+
+ def testBasic(self):
+ x_shape = [13, 9, 3]
+ x_np = np.arange(0, np.prod(x_shape), dtype=np.int32).reshape(x_shape)
+ y_np = self._NumpyPerImageWhitening(x_np)
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.per_image_whitening(x)
+ y_tf = y.eval()
+ self.assertAllClose(y_tf, y_np, atol=1e-4)
+
+
+class CropToBoundingBoxTest(test_util.TensorFlowTestCase):
+
+ def testNoOp(self):
+ x_shape = [13, 9, 3]
+ x_np = np.ones(x_shape, dtype=np.float32)
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ target_height = x_shape[0]
+ target_width = x_shape[1]
+ y = image_ops.crop_to_bounding_box(x, 0, 0, target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testCropping(self):
+ x_np = np.arange(0, 30, dtype=np.int32).reshape([6, 5, 1])
+
+ offset_height = 1
+ after_height = 2
+
+ offset_width = 0
+ after_width = 3
+
+ target_height = x_np.shape[0] - offset_height - after_height
+ target_width = x_np.shape[1] - offset_width - after_width
+
+ y_np = x_np[offset_height:offset_height + target_height,
+ offset_width:offset_width + target_width, :]
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.crop_to_bounding_box(x, offset_height, offset_width,
+ target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf.flatten(), y_np.flatten())
+
+
+class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
+
+ def testNoOp(self):
+ x_shape = [13, 9, 3]
+ x_np = np.ones(x_shape, dtype=np.float32)
+
+ target_height = x_shape[0]
+ target_width = x_shape[1]
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.pad_to_bounding_box(x, 0, 0, target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testPadding(self):
+ x_shape = [3, 4, 1]
+ x_np = np.ones(x_shape, dtype=np.float32)
+
+ offset_height = 2
+ after_height = 3
+
+ offset_width = 1
+ after_width = 4
+
+ target_height = x_shape[0] + offset_height + after_height
+ target_width = x_shape[1] + offset_width + after_width
+
+ # Note the paddings are along height, width and depth.
+ paddings = ((offset_height, after_height),
+ (offset_width, after_width),
+ (0, 0))
+
+ y_np = np.pad(x_np, paddings, 'constant')
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.pad_to_bounding_box(x, offset_height, offset_width,
+ target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+
+class ResizeImagesTest(test_util.TensorFlowTestCase):
+
+ OPTIONS = [image_ops.ResizeMethod.BILINEAR,
+ image_ops.ResizeMethod.NEAREST_NEIGHBOR,
+ image_ops.ResizeMethod.BICUBIC,
+ image_ops.ResizeMethod.AREA]
+
+ def testNoOp(self):
+ img_shape = [1, 6, 4, 1]
+ data = [128, 128, 64, 64,
+ 128, 128, 64, 64,
+ 64, 64, 128, 128,
+ 64, 64, 128, 128,
+ 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 6
+ target_width = 4
+
+ for opt in self.OPTIONS:
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ resized = y.eval()
+ self.assertAllClose(resized, img_np, atol=1e-5)
+
+ def testResizeDown(self):
+
+ data = [128, 128, 64, 64,
+ 128, 128, 64, 64,
+ 64, 64, 128, 128,
+ 64, 64, 128, 128,
+ 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ expected_data = [128, 64,
+ 64, 128,
+ 50, 100]
+ target_height = 3
+ target_width = 2
+
+ # Test out 3-D and 4-D image shapes.
+ img_shapes = [[1, 6, 4, 1], [6, 4, 1]]
+ target_shapes = [[1, target_height, target_width, 1],
+ [target_height, target_width, 1]]
+
+ for target_shape, img_shape in zip(target_shapes, img_shapes):
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ for opt in self.OPTIONS:
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ expected = np.array(expected_data).reshape(target_shape)
+ resized = y.eval()
+ self.assertAllClose(resized, expected, atol=1e-5)
+
+ def testResizeUp(self):
+ img_shape = [1, 3, 2, 1]
+ data = [128, 64,
+ 64, 128,
+ 50, 100]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 6
+ target_width = 4
+ expected_data = {}
+ expected_data[image_ops.ResizeMethod.BILINEAR] = [
+ 128.0, 96.0, 64.0, 64.0,
+ 96.0, 96.0, 96.0, 96.0,
+ 64.0, 96.0, 128.0, 128.0,
+ 57.0, 85.5, 114.0, 114.0,
+ 50.0, 75.0, 100.0, 100.0,
+ 50.0, 75.0, 100.0, 100.0]
+ expected_data[image_ops.ResizeMethod.NEAREST_NEIGHBOR] = [
+ 128.0, 128.0, 64.0, 64.0,
+ 128.0, 128.0, 64.0, 64.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 50.0, 50.0, 100.0, 100.0,
+ 50.0, 50.0, 100.0, 100.0]
+ expected_data[image_ops.ResizeMethod.AREA] = [
+ 128.0, 128.0, 64.0, 64.0,
+ 128.0, 128.0, 64.0, 64.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 50.0, 50.0, 100.0, 100.0,
+ 50.0, 50.0, 100.0, 100.0]
+
+ for opt in [
+ image_ops.ResizeMethod.BILINEAR,
+ image_ops.ResizeMethod.NEAREST_NEIGHBOR,
+ image_ops.ResizeMethod.AREA]:
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ resized = y.eval()
+ expected = np.array(expected_data[opt]).reshape(
+ [1, target_height, target_width, 1])
+ self.assertAllClose(resized, expected, atol=1e-05)
+
+ def testResizeUpBicubic(self):
+ img_shape = [1, 6, 6, 1]
+ data = [128, 128, 64, 64, 128, 128, 64, 64,
+ 64, 64, 128, 128, 64, 64, 128, 128,
+ 50, 50, 100, 100, 50, 50, 100, 100,
+ 50, 50, 100, 100, 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 8
+ target_width = 8
+ expected_data = [128, 135, 96, 55, 64, 114, 134, 128,
+ 78, 81, 68, 52, 57, 118, 144, 136,
+ 55, 49, 79, 109, 103, 89, 83, 84,
+ 74, 70, 95, 122, 115, 69, 49, 55,
+ 100, 105, 75, 43, 50, 89, 105, 100,
+ 57, 54, 74, 96, 91, 65, 55, 58,
+ 70, 69, 75, 81, 80, 72, 69, 70,
+ 105, 112, 75, 36, 45, 92, 111, 105]
+
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width,
+ image_ops.ResizeMethod.BICUBIC)
+ resized = y.eval()
+ expected = np.array(expected_data).reshape(
+ [1, target_height, target_width, 1])
+ self.assertAllClose(resized, expected, atol=1)
+
+ def testResizeDownArea(self):
+ img_shape = [1, 6, 6, 1]
+ data = [128, 64, 32, 16, 8, 4,
+ 4, 8, 16, 32, 64, 128,
+ 128, 64, 32, 16, 8, 4,
+ 5, 10, 15, 20, 25, 30,
+ 30, 25, 20, 15, 10, 5,
+ 5, 10, 15, 20, 25, 30]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 4
+ target_width = 4
+ expected_data = [73, 33, 23, 39,
+ 73, 33, 23, 39,
+ 14, 16, 19, 21,
+ 14, 16, 19, 21]
+
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width,
+ image_ops.ResizeMethod.AREA)
+ expected = np.array(expected_data).reshape(
+ [1, target_height, target_width, 1])
+ resized = y.eval()
+ self.assertAllClose(resized, expected, atol=1)
+
+
+class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
+
+ def _ResizeImageWithCropOrPad(self, original, original_shape,
+ expected, expected_shape):
+ x_np = np.array(original, dtype=np.uint8).reshape(original_shape)
+ y_np = np.array(expected).reshape(expected_shape)
+
+ target_height = expected_shape[0]
+ target_width = expected_shape[1]
+
+ with self.test_session():
+ image = constant_op.constant(x_np, shape=original_shape)
+ y = image_ops.resize_image_with_crop_or_pad(image,
+ target_height,
+ target_width)
+ resized = y.eval()
+ self.assertAllClose(resized, y_np, atol=1e-5)
+
+ def testBasic(self):
+ # Basic no-op.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ original, [2, 4, 1])
+
+ def testPad(self):
+ # Pad even along col.
+ original = [1, 2, 3, 4, 5, 6, 7, 8]
+ expected = [0, 1, 2, 3, 4, 0,
+ 0, 5, 6, 7, 8, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [2, 6, 1])
+ # Pad odd along col.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 1, 2, 3, 4, 0, 0,
+ 0, 5, 6, 7, 8, 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [2, 7, 1])
+
+ # Pad even along row.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [4, 4, 1])
+ # Pad odd along row.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [5, 4, 1])
+
+ def testCrop(self):
+ # Crop even along col.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [2, 3,
+ 6, 7]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [2, 2, 1])
+ # Crop odd along col.
+
+ original = [1, 2, 3, 4, 5, 6,
+ 7, 8, 9, 10, 11, 12]
+ expected = [2, 3, 4,
+ 8, 9, 10]
+ self._ResizeImageWithCropOrPad(original, [2, 6, 1],
+ expected, [2, 3, 1])
+
+ # Crop even along row.
+ original = [1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8]
+ expected = [3, 4,
+ 5, 6]
+ self._ResizeImageWithCropOrPad(original, [4, 2, 1],
+ expected, [2, 2, 1])
+
+ # Crop odd along row.
+ original = [1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8,
+ 9, 10,
+ 11, 12,
+ 13, 14,
+ 15, 16]
+ expected = [3, 4,
+ 5, 6,
+ 7, 8,
+ 9, 10,
+ 11, 12]
+ self._ResizeImageWithCropOrPad(original, [8, 2, 1],
+ expected, [5, 2, 1])
+
+ def testCropAndPad(self):
+ # Pad along row but crop along col.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 0,
+ 2, 3,
+ 6, 7,
+ 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [4, 2, 1])
+
+ # Crop along row but pad along col.
+ original = [1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8]
+ expected = [0, 3, 4, 0,
+ 0, 5, 6, 0]
+ self._ResizeImageWithCropOrPad(original, [4, 2, 1],
+ expected, [2, 4, 1])
+
+
+def _SimpleColorRamp():
+ """Build a simple color ramp RGB image."""
+ w, h = 256, 200
+ i = np.arange(h)[:, None]
+ j = np.arange(w)
+ image = np.empty((h, w, 3), dtype=np.uint8)
+ image[:, :, 0] = i
+ image[:, :, 1] = j
+ image[:, :, 2] = (i + j) >> 1
+ return image
+
+
+class JpegTest(test_util.TensorFlowTestCase):
+
+ # TODO(irving): Add self.assertAverageLess or similar to test_util
+ def averageError(self, image0, image1):
+ self.assertEqual(image0.shape, image1.shape)
+ image0 = image0.astype(int) # Avoid overflow
+ return np.abs(image0 - image1).sum() / float(np.prod(image0.shape))
+
+ def testExisting(self):
+ # Read a real jpeg and verify shape
+ path = ('tensorflow/core/lib/jpeg/testdata/'
+ 'jpeg_merge_test1.jpg')
+ with self.test_session() as sess:
+ jpeg0 = io_ops.read_file(path)
+ image0 = image_ops.decode_jpeg(jpeg0)
+ image1 = image_ops.decode_jpeg(image_ops.encode_jpeg(image0))
+ jpeg0, image0, image1 = sess.run([jpeg0, image0, image1])
+ self.assertEqual(len(jpeg0), 3771)
+ self.assertEqual(image0.shape, (256, 128, 3))
+ self.assertLess(self.averageError(image0, image1), 0.8)
+
+ def testSynthetic(self):
+ with self.test_session() as sess:
+ # Encode it, then decode it, then encode it
+ image0 = constant_op.constant(_SimpleColorRamp())
+ jpeg0 = image_ops.encode_jpeg(image0)
+ image1 = image_ops.decode_jpeg(jpeg0)
+ image2 = image_ops.decode_jpeg(image_ops.encode_jpeg(image1))
+ jpeg0, image0, image1, image2 = sess.run([jpeg0, image0, image1, image2])
+
+ # The decoded-encoded image should be similar to the input
+ self.assertLess(self.averageError(image0, image1), 0.6)
+
+ # We should be very close to a fixpoint
+ self.assertLess(self.averageError(image1, image2), 0.02)
+
+ # Smooth ramps compress well (input size is 153600)
+ self.assertGreaterEqual(len(jpeg0), 5000)
+ self.assertLessEqual(len(jpeg0), 6000)
+
+ def testShape(self):
+ with self.test_session() as sess:
+ jpeg = constant_op.constant('nonsense')
+ for channels in 0, 1, 3:
+ image = image_ops.decode_jpeg(jpeg, channels=channels)
+ self.assertEqual(image.get_shape().as_list(),
+ [None, None, channels or None])
+
+
+class PngTest(test_util.TensorFlowTestCase):
+
+ def testExisting(self):
+ # Read some real PNGs, converting to different channel numbers
+ prefix = 'tensorflow/core/lib/png/testdata/'
+ inputs = (1, 'lena_gray.png'), (4, 'lena_rgba.png')
+ for channels_in, filename in inputs:
+ for channels in 0, 1, 3, 4:
+ with self.test_session() as sess:
+ png0 = io_ops.read_file(prefix + filename)
+ image0 = image_ops.decode_png(png0, channels=channels)
+ png0, image0 = sess.run([png0, image0])
+ self.assertEqual(image0.shape, (26, 51, channels or channels_in))
+ if channels == channels_in:
+ image1 = image_ops.decode_png(image_ops.encode_png(image0))
+ self.assertAllEqual(image0, image1.eval())
+
+ def testSynthetic(self):
+ with self.test_session() as sess:
+ # Encode it, then decode it
+ image0 = constant_op.constant(_SimpleColorRamp())
+ png0 = image_ops.encode_png(image0, compression=7)
+ image1 = image_ops.decode_png(png0)
+ png0, image0, image1 = sess.run([png0, image0, image1])
+
+ # PNG is lossless
+ self.assertAllEqual(image0, image1)
+
+ # Smooth ramps compress well, but not too well
+ self.assertGreaterEqual(len(png0), 400)
+ self.assertLessEqual(len(png0), 750)
+
+ def testShape(self):
+ with self.test_session() as sess:
+ png = constant_op.constant('nonsense')
+ for channels in 0, 1, 3:
+ image = image_ops.decode_png(png, channels=channels)
+ self.assertEqual(image.get_shape().as_list(),
+ [None, None, channels or None])
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
new file mode 100644
index 0000000000..09c8801e0e
--- /dev/null
+++ b/tensorflow/python/ops/init_ops.py
@@ -0,0 +1,181 @@
+"""Operations often used for initializing tensors."""
+
+import math
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+
+
+# TODO(mrry): PEP8 these.
+def constant_initializer(value=0.0):
+ """Returns an initializer that generates Tensors with a single value.
+
+ Args:
+ value: A Python scalar. All elements of the initialized variable
+ will be set to this value.
+
+ Returns:
+ An initializer that generates Tensors with a single value.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return constant_op.constant(value, dtype=dtype, shape=shape)
+ return _initializer
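+
+# Illustrative sketch (hypothetical, not part of the public API): each factory
+# in this module returns a closure that is later called with a shape (and an
+# optional dtype) to produce the initial value Tensor.
+def _example_use_initializer():
+  """Builds a 2x3 float32 Tensor filled with 0.5 via constant_initializer."""
+  init = constant_initializer(0.5)
+  return init([2, 3], dtype=types.float32)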
+
+def random_uniform_initializer(minval=0.0, maxval=1.0, seed=None):
+ """Returns an initializer that generates Tensors with a uniform distribution.
+
+ Args:
+ minval: a python scalar or a scalar tensor. lower bound of the range
+ of random values to generate.
+ maxval: a python scalar or a scalar tensor. upper bound of the range
+ of random values to generate.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates Tensors with a uniform distribution.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return random_ops.random_uniform(shape, minval, maxval, dtype, seed=seed)
+ return _initializer
+
+def random_normal_initializer(mean=0.0, stddev=1.0, seed=None):
+ """Returns an initializer that generates Tensors with a normal distribution.
+
+ Args:
+ mean: a python scalar or a scalar tensor. Mean of the random values
+ to generate.
+ stddev: a python scalar or a scalar tensor. Standard deviation of the
+ random values to generate.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates Tensors with a normal distribution.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return random_ops.random_normal(shape, mean, stddev, dtype, seed=seed)
+ return _initializer
+
+def truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None):
+ """Returns an initializer that generates a truncated normal distribution.
+
+ These values are similar to values from a random_normal_initializer
+ except that values more than two standard deviations from the mean
+ are discarded and re-drawn. This is the recommended initializer for
+ neural network weights and filters.
+
+ Args:
+ mean: a python scalar or a scalar tensor. Mean of the random values
+ to generate.
+ stddev: a python scalar or a scalar tensor. Standard deviation of the
+ random values to generate.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates Tensors with a truncated normal
+ distribution.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return random_ops.truncated_normal(shape, mean, stddev, dtype, seed=seed)
+ return _initializer
+
+def uniform_unit_scaling_initializer(factor=1.0, seed=None):
+ """Returns an initializer that generates tensors without scaling variance.
+
+ When initializing a deep network, it is in principle advantageous to keep
+ the scale of the input variance constant, so it does not explode or diminish
+ by the time it reaches the final layer. If the input is `x` and the operation
+ is `x * W`, and we want to initialize `W` uniformly at random, we need to
+ pick `W` from
+
+ [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]
+
+ to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
+ A similar calculation for convolutional networks gives an analogous result
+ with `dim` equal to the product of the first 3 dimensions. When
+ nonlinearities are present, we need to multiply this by a constant `factor`.
+ See <https://arxiv.org/pdf/1412.6558v3.pdf> for deeper motivation, experiments
+ and the calculation of constants. In section 2.3 there, the constants were
+ numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
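+
+ For example (illustrative numbers), a fully connected layer with
+ `shape = [100, 50]` and `factor = 1.0` has `dim = 100`, so weights are drawn
+ uniformly from `[-sqrt(3/100), sqrt(3/100)]`, roughly `[-0.173, 0.173]`.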
+
+ Args:
+ factor: Float. A multiplicative factor by which the values will be scaled.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates tensors with unit variance.
+ """
+ def _initializer(shape, dtype=types.float32):
+ input_size = 1.0
+ # Estimating the input size perfectly is not possible, but we try.
+ # The estimate, obtained by multiplying all dimensions but the last one,
+ # is the right thing for matrix multiply and convolutions (see above).
+ for dim in shape[:-1]:
+ input_size *= float(dim)
+ max_val = math.sqrt(float(3) / float(input_size)) * factor
+ return random_ops.random_uniform(shape, -max_val, max_val,
+ dtype, seed=seed)
+ return _initializer
+
+# TODO(vrv): Unhide when we are ready to expose this publicly.
+def _random_walk(shape, nonlinearity, dtype=types.float32, seed=None,
+ name="random_walk"):
+ """Create a random tensor such that backprop neither vanishes nor explodes.
+
+ Args:
+ shape: a python array of int or a 1-d tensor. Sizes of the Tensor.
+ nonlinearity: the Python function that implements the nonlinearity
+ in the TensorFlow graph.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: string. Optional name for the op.
+
+ Returns:
+ A Tensor of the specified sizes filled with random values.
+ """
+ assert len(shape) == 2, "Random Walk initialization only supports 2D tensors."
+ num_inputs = shape[0]
+ if nonlinearity == math_ops.tanh:
+ # No real formula for this case yet, but this works well for many
+ # layer widths.
+ rwg = 1.13
+ elif nonlinearity == array_ops.identity:
+ rwg = math.exp(1.0 / float(2.0 * num_inputs))
+ elif nonlinearity == nn_ops.relu:
+ rwg = math.sqrt(2.0) * math.exp(1.2 / float(max(num_inputs, 6) - 2.4))
+ else:
+ assert False, "Unsupported nonlinearity for Random Walk initialization."
+
+ mean = 0.0
+ stddev = rwg / math.sqrt(float(num_inputs))
+
+ return random_ops.random_normal(shape, mean=mean, stddev=stddev, dtype=dtype,
+ seed=seed, name=name)
+
+
+# TODO(vrv): Unhide when we are ready to expose this publicly.
+class _RandomWalkInitializer(object):
+ """An Initializer that generates a tensor for Random Walk Initialization."""
+
+ def __init__(self, nonlinearity, seed=None):
+ """Construct a RandomWalkInitializer.
+
+ Args:
+ nonlinearity: the python tensorflow function that computes a nonlinearity
+ in the graph, typically after a Wx+b type operation.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ """
+ self._nonlinearity = nonlinearity
+ self._seed = seed
+
+ def __call__(self, shape, dtype=types.float32):
+ """Generate a tensor used to initialize a variable."""
+ return _random_walk(shape, self._nonlinearity, dtype, seed=self._seed)
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
new file mode 100644
index 0000000000..9eb3bdfae4
--- /dev/null
+++ b/tensorflow/python/ops/io_ops.py
@@ -0,0 +1,541 @@
+"""## Placeholders
+
+TensorFlow provides a placeholder operation that must be fed with data
+on execution. For more info, see the section on [Feeding
+data](../../how_tos/reading_data/index.md#feeding).
+
+@@placeholder
+
+## Readers
+
+TensorFlow provides a set of Reader classes for reading data formats.
+For more information on inputs and readers, see [Reading
+data](../../how_tos/reading_data/index.md).
+
+@@ReaderBase
+@@TextLineReader
+@@WholeFileReader
+@@IdentityReader
+@@TFRecordReader
+@@FixedLengthRecordReader
+
+## Converting
+
+TensorFlow provides several operations that you can use to convert various data
+formats into tensors.
+
+@@decode_csv
+@@decode_raw
+@@parse_example
+@@parse_single_example
+
+## Queues
+
+TensorFlow provides several implementations of 'Queues', which are
+structures within the TensorFlow computation graph to stage pipelines
+of tensors together. The following describe the basic Queue interface
+and some implementations. To see an example use, see [Threading and
+Queues](../../how_tos/threading_and_queues/index.md).
+
+@@QueueBase
+@@FIFOQueue
+@@RandomShuffleQueue
+
+## Dealing with the filesystem
+
+@@matching_files
+@@read_file
+
+## Input pipeline
+
+TensorFlow functions for setting up an input-prefetching pipeline.
+Please see the [reading data how-to](../../how_tos/reading_data.md)
+for context.
+
+### Beginning of an input pipeline
+
+The "producer" functions add a queue to the graph and a corresponding
+`QueueRunner` for running the subgraph that fills that queue.
+
+@@match_filenames_once
+@@limit_epochs
+@@range_input_producer
+@@slice_input_producer
+@@string_input_producer
+
+### Batching at the end of an input pipeline
+
+These functions add a queue to the graph to assemble a batch of examples, with
+possible shuffling. They also add a `QueueRunner` for running the subgraph
+that fills that queue.
+
+Use [batch](#batch) or [batch_join](#batch_join) for batching examples that have
+already been well shuffled. Use [shuffle_batch](#shuffle_batch) or
+[shuffle_batch_join](#shuffle_batch_join) for examples that
+would benefit from additional shuffling.
+
+Use [batch](#batch) or [shuffle_batch](#shuffle_batch) if you want a
+single thread producing examples to batch, or if you have a
+single subgraph producing examples but you want to run it in N threads
+(where you increase N until it can keep the queue full). Use
+[batch_join](#batch_join) or [shuffle_batch_join](#shuffle_batch_join)
+if you have N different subgraphs producing examples to batch and you
+want them run by N threads.
+
+@@batch
+@@batch_join
+@@shuffle_batch
+@@shuffle_batch_join
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_io_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_io_ops import *
+# pylint: enable=wildcard-import
+
+
+# pylint: disable=protected-access
+def _save(filename, tensor_names, tensors, tensor_slices=None, name="save"):
+ """Save a list of tensors to a file with given names.
+
+ Example usage without slice info:
+ Save("/foo/bar", ["w", "b"], [w, b])
+
+ Example usage with slices:
+ Save("/foo/bar", ["w", "w"], [slice0, slice1],
+ tensor_slices=["4 10 0,2:-", "4 10 2,2:-"])
+
+ Args:
+ filename: the file name of the sstable.
+ tensor_names: a list of strings.
+ tensors: the list of tensors to be saved.
+ tensor_slices: Optional list of strings to specify the shape and slices of
+ a larger virtual tensor that each tensor is a part of. If not specified
+ each tensor is saved as a full slice.
+ name: string. Optional name for the op.
+
+ Requires:
+ The length of tensors should match the size of tensor_names and of
+ tensor_slices.
+
+ Returns:
+ An Operation that saves the tensors.
+ """
+ if tensor_slices is None:
+ return gen_io_ops._save(filename, tensor_names, tensors, name=name)
+ else:
+ return gen_io_ops._save_slices(filename, tensor_names, tensor_slices,
+ tensors, name=name)
+
+
+def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
+ name="restore_slice", preferred_shard=-1):
+ """Restore a tensor slice from a set of files with a given pattern.
+
+ Example usage:
+ RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)
+
+ Args:
+ file_pattern: the file pattern used to match a set of checkpoint files.
+ tensor_name: the name of the tensor to restore.
+ shape_and_slice: the shape-and-slice spec of the slice.
+ tensor_type: the type of the tensor to restore.
+ name: string. Optional name for the op.
+ preferred_shard: Int. Optional shard to open first in the checkpoint file.
+
+ Returns:
+ A tensor of type "tensor_type".
+ """
+ base_type = types.as_dtype(tensor_type).base_dtype
+ return gen_io_ops._restore_slice(
+ file_pattern, tensor_name, shape_and_slice, base_type,
+ preferred_shard, name=name)
+
+
+@ops.RegisterShape("Restore")
+def _RestoreShape(op):
+ """Shape function for Restore op."""
+ # Validate input shapes.
+ unused_file_pattern = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_tensor_name = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("RestoreSlice")
+def _RestoreSliceShape(op):
+ """Shape function for RestoreSlice op."""
+ # Validate input shapes.
+ unused_file_pattern = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_tensor_name = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_shape_and_slice_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.scalar())
+ # TODO(mrry): Attempt to parse the shape_and_slice value and use it
+ # to form the shape of the output.
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("Save")
+def _SaveShape(op):
+ """Shape function for Save op."""
+ # Validate input shapes.
+ unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ data_count = len(op.inputs) - 2
+ unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(data_count))
+ return []
+
+
+@ops.RegisterShape("SaveSlices")
+def _SaveSlicesShape(op):
+ """Shape function for SaveSlices op."""
+ # Validate input shapes.
+ unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ data_count = len(op.inputs) - 3
+ unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(data_count))
+ unused_shapes_and_slices_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.vector(data_count))
+ # TODO(mrry): Attempt to parse the shapes_and_slices values and use
+ # them to constrain the shape of the remaining inputs.
+ return []
+
+
+@ops.RegisterShape("ShardedFilename")
+def _ShardedFilenameShape(op):
+ """Shape function for ShardedFilename op."""
+ # Validate input shapes.
+ unused_basename_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_shard_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_num_shards_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("ShardedFilespec")
+def _ShardedFilespecShape(op):
+ """Shape function for ShardedFilespec op."""
+ # Validate input shapes.
+ unused_basename_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_num_shards_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+class ReaderBase(object):
+ """Base class for different Reader types, that produce a record every step.
+
+ Conceptually, Readers convert string 'work units' into records (key,
+ value pairs). Typically the 'work units' are filenames and the
+ records are extracted from the contents of those files. We want a
+ single record produced per step, but a work unit can correspond to
+ many records.
+
+  Therefore we introduce some decoupling using a queue. The queue
+  contains the work units and the Reader dequeues from the queue when
+  it is asked to produce a record (via Read()) but has already finished
+  the last work unit. A short usage sketch follows the class definition
+  below.
+ """
+
+ def __init__(self, reader_ref, supports_serialize=False):
+ """Creates a new ReaderBase.
+
+ Args:
+ reader_ref: The operation that implements the reader.
+ supports_serialize: True if the reader implementation can
+ serialize its state.
+ """
+ self._reader_ref = reader_ref
+ self._supports_serialize = supports_serialize
+
+ @property
+ def reader_ref(self):
+ """Op that implements the reader."""
+ return self._reader_ref
+
+ def read(self, queue, name=None):
+ """Returns the next record (key, value pair) produced by a reader.
+
+ Will dequeue a work unit from queue if necessary (e.g. when the
+ Reader needs to start reading from a new file since it has
+ finished with the previous file).
+
+ Args:
+ queue: A Queue or a mutable string Tensor representing a handle
+ to a Queue, with string work items.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tuple of Tensors (key, value).
+ key: A string scalar Tensor.
+ value: A string scalar Tensor.
+ """
+ if isinstance(queue, ops.Tensor):
+ queue_ref = queue
+ else:
+ queue_ref = queue.queue_ref
+ return gen_io_ops._reader_read(self._reader_ref, queue_ref, name=name)
+
+ def num_records_produced(self, name=None):
+ """Returns the number of records this reader has produced.
+
+ This is the same as the number of Read executions that have
+ succeeded.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ An int64 Tensor.
+
+ """
+ return gen_io_ops._reader_num_records_produced(self._reader_ref, name=name)
+
+ def num_work_units_completed(self, name=None):
+ """Returns the number of work units this reader has finished processing.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ An int64 Tensor.
+ """
+ return gen_io_ops._reader_num_work_units_completed(self._reader_ref,
+ name=name)
+
+ def serialize_state(self, name=None):
+ """Produce a string tensor that encodes the state of a reader.
+
+ Not all Readers support being serialized, so this can produce an
+ Unimplemented error.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A string Tensor.
+ """
+ return gen_io_ops._reader_serialize_state(self._reader_ref, name=name)
+
+ def restore_state(self, state, name=None):
+ """Restore a reader to a previously saved state.
+
+ Not all Readers support being restored, so this can produce an
+ Unimplemented error.
+
+ Args:
+ state: A string Tensor.
+ Result of a SerializeState of a Reader with matching type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+ """
+ return gen_io_ops._reader_restore_state(self._reader_ref, state, name=name)
+
+ @property
+ def supports_serialize(self):
+ """Whether the Reader implementation can serialize its state."""
+ return self._supports_serialize
+
+ def reset(self, name=None):
+ """Restore a reader to its initial clean state.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+ """
+ return gen_io_ops._reader_reset(self._reader_ref, name=name)
+
+
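+# A minimal, illustrative sketch of how a Reader is typically driven (the
+# filename-queue helper named here is an assumption, not defined in this
+# module):
+#
+#   filename_queue = tf.train.string_input_producer(["file0.txt", "file1.txt"])
+#   reader = tf.TextLineReader()
+#   key, value = reader.read(filename_queue)
+#   # Each evaluation of (key, value) yields one line from one of the files.
+
+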
+ops.NoGradient("ReaderRead")
+ops.NoGradient("ReaderNumRecordsProduced")
+ops.NoGradient("ReaderNumWorkUnitsCompleted")
+ops.NoGradient("ReaderSerializeState")
+ops.NoGradient("ReaderRestoreState")
+ops.NoGradient("ReaderReset")
+
+
+class WholeFileReader(ReaderBase):
+ """A Reader that outputs the entire contents of a file as a value.
+
+ To use, enqueue filenames in a Queue. The output of Read will
+ be a filename (key) and the contents of that file (value).
+
+ See ReaderBase for supported methods.
+ """
+
+ def __init__(self, name=None):
+ """Create a WholeFileReader.
+
+ Args:
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._whole_file_reader(name=name)
+ super(WholeFileReader, self).__init__(rr, supports_serialize=True)
+
+
+ops.NoGradient("WholeFileReader")
+
+
+class TextLineReader(ReaderBase):
+ """A Reader that outputs the lines of a file delimited by newlines.
+
+ Newlines are stripped from the output.
+ See ReaderBase for supported methods.
+ """
+ # TODO(josh11b): Support serializing and restoring state.
+
+ def __init__(self, skip_header_lines=None, name=None):
+ """Create a TextLineReader.
+
+ Args:
+ skip_header_lines: An optional int. Defaults to 0. Number of lines
+ to skip from the beginning of every file.
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._text_line_reader(skip_header_lines=skip_header_lines,
+ name=name)
+ super(TextLineReader, self).__init__(rr)
+
+
+ops.NoGradient("TextLineReader")
+
+
+class FixedLengthRecordReader(ReaderBase):
+ """A Reader that outputs fixed-length records from a file.
+
+ See ReaderBase for supported methods.
+ """
+ # TODO(josh11b): Support serializing and restoring state.
+
+ def __init__(self, record_bytes, header_bytes=None, footer_bytes=None,
+ name=None):
+ """Create a FixedLengthRecordReader.
+
+ Args:
+ record_bytes: An int.
+ header_bytes: An optional int. Defaults to 0.
+ footer_bytes: An optional int. Defaults to 0.
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._fixed_length_record_reader(
+ record_bytes=record_bytes, header_bytes=header_bytes,
+ footer_bytes=footer_bytes, name=name)
+ super(FixedLengthRecordReader, self).__init__(rr)
+
+
+ops.NoGradient("FixedLengthRecordReader")
+
+
+class TFRecordReader(ReaderBase):
+ """A Reader that outputs the records from a TFRecords file.
+
+ See ReaderBase for supported methods.
+ """
+ # TODO(josh11b): Support serializing and restoring state.
+
+ def __init__(self, name=None):
+ """Create a TFRecordReader.
+
+ Args:
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._tf_record_reader(name=name)
+ super(TFRecordReader, self).__init__(rr)
+
+
+ops.NoGradient("TFRecordReader")
+
+
+class IdentityReader(ReaderBase):
+ """A Reader that outputs the queued work as both the key and value.
+
+ To use, enqueue strings in a Queue. Read will take the front
+ work string and output (work, work).
+
+ See ReaderBase for supported methods.
+ """
+
+ def __init__(self, name=None):
+ """Create a IdentityReader.
+
+ Args:
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._identity_reader(name=name)
+ super(IdentityReader, self).__init__(rr, supports_serialize=True)
+
+
+ops.NoGradient("IdentityReader")
+
+
+ops.RegisterShape("FixedLengthRecordReader")(common_shapes.scalar_shape)
+ops.RegisterShape("IdentityReader")(common_shapes.scalar_shape)
+ops.RegisterShape("TextLineReader")(common_shapes.scalar_shape)
+ops.RegisterShape("WholeFileReader")(common_shapes.scalar_shape)
+ops.RegisterShape("TFRecordReader")(common_shapes.scalar_shape)
+
+
+@ops.RegisterShape("ReaderNumRecordsProduced")
+@ops.RegisterShape("ReaderNumWorkUnitsCompleted")
+@ops.RegisterShape("ReaderSerializeState")
+def _ReaderScalarShape(op):
+ """Shape function for ops that transform a reader to a scalar."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("ReaderRead")
+def _ReaderReadShape(op):
+ """Shape function for the ReaderBase.Read op."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_queue_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar(), tensor_shape.scalar()]
+
+
+@ops.RegisterShape("ReaderReset")
+def _ReaderResetShape(op):
+ """Shape function for the ReaderBase.Reset op."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return []
+
+
+@ops.RegisterShape("ReaderRestoreState")
+def _ReaderRestoreStateShape(op):
+ """Shape function for the ReaderBase.Restore op."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_state_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return []
+
+
+@ops.RegisterShape("ReadFile")
+def _ReadFileShape(op):
+ """Shape function for the ReadFile op."""
+ return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
+
+
+@ops.RegisterShape("MatchingFiles")
+def _MatchingFilesShape(op):
+ """Shape function for the MatchingFiles op."""
+  unused_pattern_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.unknown_shape(ndims=1)]
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
new file mode 100644
index 0000000000..893618c9dd
--- /dev/null
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -0,0 +1,25 @@
+"""Gradients for operators defined in linalg_ops.py."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+@ops.RegisterGradient("MatrixInverse")
+def _MatrixInverseGrad(op, grad):
+ """Gradient for MatrixInverse."""
+ ainv = op.outputs[0]
+ return -math_ops.matmul(
+ ainv,
+ math_ops.matmul(grad, ainv, transpose_b=True),
+ transpose_a=True)
+
+@ops.RegisterGradient("BatchMatrixInverse")
+def _BatchMatrixInverseGrad(op, grad):
+ """Gradient for BatchMatrixInverse."""
+ ainv = op.outputs[0]
+ return -math_ops.batch_matmul(
+ ainv,
+ math_ops.batch_matmul(grad, ainv, adj_y=True),
+ adj_x=True)
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
new file mode 100644
index 0000000000..76fd83fb3d
--- /dev/null
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -0,0 +1,62 @@
+"""Operations for linear algebra."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_linalg_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_linalg_ops import *
+# pylint: enable=wildcard-import
+
+
+@ops.RegisterShape("Cholesky")
+def _CholeskyShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchCholesky")
+def _BatchCholeskyShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ return [input_shape]
+
+
+@ops.RegisterShape("MatrixDeterminant")
+def _MatrixDeterminantShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ if input_shape.ndims is not None:
+ return [tensor_shape.scalar()]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("BatchMatrixDeterminant")
+def _BatchMatrixDeterminantShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ if input_shape.ndims is not None:
+ return [input_shape[:-2]]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("MatrixInverse")
+def _MatrixInverseShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchMatrixInverse")
+def _BatchMatrixInverseShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ return [input_shape]
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
new file mode 100644
index 0000000000..0fad4a2dde
--- /dev/null
+++ b/tensorflow/python/ops/logging_ops.py
@@ -0,0 +1,58 @@
+"""Logging Operations."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_logging_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_logging_ops import *
+# pylint: enable=wildcard-import
+
+
+# `assert` is a keyword and `print` is a statement in Python, so these ops
+# are exposed under capitalized names.
+def Assert(condition, data, summarize=None, name=None):
+ """Asserts that the given condition is true.
+
+ If `condition` evaluates to false, print the list of tensors in `data`.
+ `summarize` determines how many entries of the tensors to print.
+
+ Args:
+ condition: The condition to evaluate.
+ data: The tensors to print out when condition is false.
+ summarize: Print this many entries of each tensor.
+ name: A name for this operation (optional).
+ """
+ return gen_logging_ops._assert(condition, data, summarize, name)
+
+
+def Print(input_, data, message=None, first_n=None, summarize=None,
+ name=None):
+ """Prints a list of tensors.
+
+ This is an identity op with the side effect of printing `data` when
+ evaluating.
+
+ Args:
+ input_: A tensor passed through this op.
+ data: A list of tensors to print out when op is evaluated.
+    message: A string, prefix for the printed message.
+ first_n: Only log `first_n` number of times. Negative numbers log always;
+ this is the default.
+ summarize: Only print this many entries of each tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ Same tensor as `input_`.
+ """
+ return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+
+
+@ops.RegisterGradient("Print")
+def _PrintGrad(op, *grad):
+ return list(grad) + [None] * (len(op.inputs) - 1)
+
+
+# NOTE(mrry): Assert and Print produce an empty output, which is
+# presumably never read.
+ops.RegisterShape("Assert")(common_shapes.unknown_shape)
+ops.RegisterShape("Print")(common_shapes.unknown_shape)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
new file mode 100644
index 0000000000..cb808ff5b8
--- /dev/null
+++ b/tensorflow/python/ops/math_grad.py
@@ -0,0 +1,506 @@
+"""Gradients for operators defined in math_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+
+
+def _ReductionGradAssist(op):
+ """Reduction grads have much in common, so factor the commonality out."""
+ inp = op.inputs[0] # Example:
+ input_shape = array_ops.shape(inp) # [2, 3, 5, 7]
+ input_rank = array_ops.rank(inp) # 4
+ indices = op.inputs[1] # [1, 2]
+ indices_shape = array_ops.shape(indices) # [2]
+ new_output_shape = data_flow_ops.dynamic_stitch( # [2, 1, 1, 7]
+ [math_ops.range(0, input_rank), # [0, 1, 2, 3]
+ indices], # [1, 2]
+ [input_shape, # [2, 3, 5, 7]
+ array_ops.fill(indices_shape, 1)]) # [1, 1]
+ return inp, new_output_shape, input_shape
+
+
+@ops.RegisterGradient("Sum")
+def _SumGrad(op, grad):
+ """Gradient for Sum."""
+ _, new_output_shape, input_shape = _ReductionGradAssist(op)
+ tile_scaling = input_shape / new_output_shape
+ grad = array_ops.reshape(grad, new_output_shape)
+ return [array_ops.tile(grad, tile_scaling), None]
+
+
+def _MinOrMaxGrad(op, grad):
+ """Gradient for Max or Max. Amazingly it's precisely the same code."""
+ inp, new_output_shape, _ = _ReductionGradAssist(op)
+ y = op.outputs[0]
+ y = array_ops.reshape(y, new_output_shape)
+ grad = array_ops.reshape(grad, new_output_shape)
+ indicators = math_ops.cast(math_ops.equal(y, inp), grad.dtype)
+ return [indicators * grad, None]
+
+
+@ops.RegisterGradient("Max")
+def _MaxGrad(op, grad):
+ """Gradient for Max."""
+ return _MinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("Min")
+def _MinGrad(op, grad):
+ return _MinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("Mean")
+def _MeanGrad(op, grad):
+ """Gradient for Mean."""
+ sum_grad = _SumGrad(op, grad)[0]
+ input_shape = array_ops.shape(op.inputs[0])
+ output_shape = array_ops.shape(op.outputs[0])
+ factor = (math_ops.reduce_prod(input_shape) /
+ math_ops.reduce_prod(output_shape))
+ return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
+
+
+@ops.RegisterGradient("Prod")
+def _ProdGrad(op, grad):
+ """Gradient for Prod."""
+ # TODO(kearnes): this gives NaNs for 0s in the input tensor
+ _, new_output_shape, input_shape = _ReductionGradAssist(op)
+ tile_scaling = input_shape / new_output_shape
+ grad = array_ops.reshape(grad * op.outputs[0], new_output_shape)
+ grad = math_ops.div(array_ops.tile(grad, tile_scaling), op.inputs[0])
+ return grad, None
+
+
+@ops.RegisterGradient("SegmentSum")
+def _SegmentSumGrad(op, grad):
+ """Gradient for SegmentSum."""
+ return array_ops.gather(grad, op.inputs[1]), None
+
+
+@ops.RegisterGradient("SegmentMean")
+def _SegmentMeanGrad(op, grad):
+ """Gradient for SegmentMean."""
+ input_rank = array_ops.rank(op.inputs[0])
+ ones_shape = array_ops.concat(
+ 0, [array_ops.shape(op.inputs[1]),
+ array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)])
+ ones = array_ops.fill(ones_shape,
+ constant_op.constant(1, dtype=grad.dtype))
+ scaled_grad = grad * math_ops.inv(math_ops.segment_sum(ones, op.inputs[1]))
+ return array_ops.gather(scaled_grad, op.inputs[1]), None
+
+
+@ops.RegisterGradient("SparseSegmentSum")
+def _SparseSegmentSumGrad(op, grad):
+ """Gradient for SparseSegmentSum."""
+ input_rows = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.unsorted_segment_sum(
+ array_ops.gather(grad, op.inputs[2]),
+ op.inputs[1], input_rows), None, None)
+
+
+@ops.RegisterGradient("SparseSegmentMean")
+def _SparseSegmentMeanGrad(op, grad):
+ """Gradient for SparseSegmentMean."""
+ dim0 = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.sparse_segment_mean_grad(grad,
+ op.inputs[1],
+ op.inputs[2],
+ dim0),
+ None, None)
+
+
+@ops.RegisterGradient("SegmentMin")
+def _SegmentMinGrad(op, grad):
+ """Gradient for SegmentMin."""
+ zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype)
+ gathered_grads = array_ops.gather(grad, op.inputs[1])
+ gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
+ return math_ops.select(math_ops.greater(op.inputs[0], gathered_outputs),
+ zeros,
+ gathered_grads), None
+
+
+@ops.RegisterGradient("SegmentMax")
+def _SegmentMaxGrad(op, grad):
+ """Gradient for SegmentMax."""
+ zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype)
+ gathered_grads = array_ops.gather(grad, op.inputs[1])
+ gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
+ return math_ops.select(math_ops.less(op.inputs[0], gathered_outputs),
+ zeros,
+ gathered_grads), None
+
+
+@ops.RegisterGradient("UnsortedSegmentSum")
+def _UnsortedSegmentSumGrad(op, grad):
+ """Gradient for SegmentSum."""
+ return array_ops.gather(grad, op.inputs[1]), None, None
+
+
+@ops.RegisterGradient("Abs")
+def _AbsGrad(op, grad):
+ x = op.inputs[0]
+ return grad * math_ops.sign(x)
+
+
+@ops.RegisterGradient("Neg")
+def _NegGrad(_, grad):
+ """Returns -grad."""
+ return - grad
+
+
+@ops.RegisterGradient("Inv")
+def _InvGrad(op, grad):
+ """Returns -grad * (1 / x^2)."""
+ y = op.outputs[0] # y = 1 / x
+ return grad * (- math_ops.square(y))
+
+
+@ops.RegisterGradient("Square")
+def _SquareGrad(op, grad):
+ x = op.inputs[0]
+ return grad * (2.0 * x)
+
+
+@ops.RegisterGradient("Sqrt")
+def _SqrtGrad(op, grad):
+ y = op.outputs[0] # y = x^(1/2)
+ return grad * (.5 * math_ops.inv(y))
+
+
+@ops.RegisterGradient("Rsqrt")
+def _RsqrtGrad(op, grad):
+ x = op.inputs[0]
+ y = op.outputs[0] # y = x^(-1/2)
+ return grad * ((-0.5) * math_ops.inv(x) * y)
+
+
+@ops.RegisterGradient("Exp")
+def _ExpGrad(op, grad):
+ """Returns grad * exp(x)."""
+ y = op.outputs[0] # y = e^x
+ return grad * y
+
+
+@ops.RegisterGradient("Log")
+def _LogGrad(op, grad):
+ """Returns grad * (1/x)."""
+ x = op.inputs[0]
+ return grad * math_ops.inv(x)
+
+
+@ops.RegisterGradient("Tanh")
+def _TanhGrad(op, grad):
+ """Returns grad * (1 - tanh(x) * tanh(x))."""
+ y = op.outputs[0] # y = tanh(x)
+ return grad * (1 - math_ops.square(y))
+
+
+@ops.RegisterGradient("Sigmoid")
+def _SigmoidGrad(op, grad):
+ """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
+ y = op.outputs[0] # y = sigmoid(x)
+ return grad * (y * (1 - y))
+
+
+@ops.RegisterGradient("Sign")
+def _SignGrad(op, _):
+ """Returns 0."""
+ x = op.inputs[0]
+ return array_ops.zeros(array_ops.shape(x), dtype=x.dtype)
+
+
+@ops.RegisterGradient("Sin")
+def _SinGrad(op, grad):
+ """Returns grad * cos(x)."""
+ x = op.inputs[0]
+ return grad * math_ops.cos(x)
+
+
+@ops.RegisterGradient("Cos")
+def _CosGrad(op, grad):
+ """Returns grad * -sin(x)."""
+ x = op.inputs[0]
+ return -grad * math_ops.sin(x)
+
+
+@ops.RegisterGradient("AddN")
+def _AddNGrad(op, grad):
+ """Copies the gradient to all inputs."""
+ # Not broadcasting.
+ return [grad] * len(op.inputs)
+
+
+@ops.RegisterGradient("Add")
+def _AddGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(grad, ry), sy))
+
+
+@ops.RegisterGradient("Sub")
+def _SubGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
+ array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
+
+
+@ops.RegisterGradient("Mul")
+def _MulGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ if x.dtype.base_dtype == types.complex64:
+ return (array_ops.reshape(math_ops.reduce_sum(grad * math_ops.conj(y), rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(math_ops.conj(x) * grad, ry), sy))
+ else:
+ return (array_ops.reshape(math_ops.reduce_sum(grad * y, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(x * grad, ry), sy))
+
+
+@ops.RegisterGradient("Div")
+def _DivGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(grad *
+ (-x / math_ops.square(y)), ry), sy))
+
+
+@ops.RegisterGradient("Pow")
+def _PowGrad(op, grad):
+ """Returns grad * (y*x^(y-1), z*log(x))."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ z = op.outputs[0]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ gx = array_ops.reshape(math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx),
+ sx)
+ gy = array_ops.reshape(math_ops.reduce_sum(grad * z * math_ops.log(x), ry), sy)
+ return gx, gy
+
+
+def _MaximumMinimumGrad(op, grad, selector_op):
+ """Factor out the code for the gradient of Maximum or Minimum."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ gdtype = grad.dtype
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ gradshape = array_ops.shape(grad)
+ zeros = array_ops.zeros(gradshape, gdtype)
+ xmask = selector_op(x, y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ xgrad = math_ops.select(xmask, grad, zeros)
+ ygrad = math_ops.select(math_ops.logical_not(xmask), grad, zeros)
+ gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
+ gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
+ return (gx, gy)
+
+
+@ops.RegisterGradient("Maximum")
+def _MaximumGrad(op, grad):
+ """Returns grad*(x > y, x <= y) with type of grad."""
+ return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
+
+
+@ops.RegisterGradient("Minimum")
+def _MinimumGrad(op, grad):
+ """Returns grad*(x < y, x >= y) with type of grad."""
+ return _MaximumMinimumGrad(op, grad, math_ops.less_equal)
+
+
+# Logical operations have no gradients.
+ops.NoGradient("Less")
+ops.NoGradient("LessEqual")
+ops.NoGradient("Greater")
+ops.NoGradient("GreaterEqual")
+ops.NoGradient("Equal")
+ops.NoGradient("NotEqual")
+ops.NoGradient("LogicalAnd")
+ops.NoGradient("LogicalOr")
+ops.NoGradient("LogicalNot")
+
+
+@ops.RegisterGradient("Select")
+def _SelectGrad(op, grad):
+ c = op.inputs[0]
+ x = op.inputs[1]
+ zeros = array_ops.zeros(array_ops.shape(c), dtype=x.dtype)
+ return (None, math_ops.select(c, grad, zeros),
+ math_ops.select(c, zeros, grad))
+
+
+@ops.RegisterGradient("MatMul")
+def _MatMulGrad(op, grad):
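+  """Gradient for MatMul."""
+  # For C = A * B (no transposes): dL/dA = grad * B^T and dL/dB = A^T * grad.
+  # The remaining branches apply the same identities with A and/or B
+  # transposed.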
+ t_a = op.get_attr("transpose_a")
+ t_b = op.get_attr("transpose_b")
+ if not t_a and not t_b:
+ return (math_ops.matmul(grad, op.inputs[1], transpose_b=True),
+ math_ops.matmul(op.inputs[0], grad, transpose_a=True))
+ elif not t_a and t_b:
+ return (math_ops.matmul(grad, op.inputs[1]),
+ math_ops.matmul(grad, op.inputs[0], transpose_a=True))
+ elif t_a and not t_b:
+ return (math_ops.matmul(op.inputs[1], grad, transpose_b=True),
+ math_ops.matmul(op.inputs[0], grad))
+ elif t_a and t_b:
+ return (math_ops.matmul(op.inputs[1], grad, transpose_a=True,
+ transpose_b=True),
+ math_ops.matmul(grad, op.inputs[0], transpose_a=True,
+ transpose_b=True))
+
+
+@ops.RegisterGradient("SparseMatMul")
+def _SparseMatMulGrad(op, grad):
+ """Gradient for SparseMatMul."""
+
+ t_a = op.get_attr("transpose_a")
+ t_b = op.get_attr("transpose_b")
+ is_sparse = {
+ op.inputs[0]: op.get_attr("a_is_sparse"),
+ op.inputs[1]: op.get_attr("b_is_sparse"),
+ # Use heuristic to figure out if grad might be sparse
+ grad: (grad.op.type == "ReluGrad")
+ }
+ def _SparseMatMul(t1, t2, transpose_a=False, transpose_b=False):
+ """Helper function to create SparseMatMul op."""
+
+ assert t1 in is_sparse and t2 in is_sparse
+ t1_sparse = is_sparse[t1]
+ t2_sparse = is_sparse[t2]
+ if not t1_sparse and not t2_sparse:
+ return math_ops.matmul(t1, t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b)
+ transpose_out = False
+ if not t1_sparse:
+ transpose_out = True
+ t1, t2 = t2, t1
+ t1_sparse, t2_sparse = t2_sparse, t1_sparse
+ assert t1_sparse
+ transpose_a, transpose_b = not transpose_b, not transpose_a
+
+ if transpose_b:
+ t2 = array_ops.transpose(t2)
+ transpose_b = False
+ m = math_ops.matmul(t1, t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ a_is_sparse=t1_sparse,
+ b_is_sparse=t2_sparse)
+ if transpose_out:
+ m = array_ops.transpose(m)
+ return m
+
+ if not t_a and not t_b:
+ return (_SparseMatMul(grad, op.inputs[1], transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad, transpose_a=True))
+ elif not t_a and t_b:
+ return (_SparseMatMul(grad, op.inputs[1]),
+ _SparseMatMul(grad, op.inputs[0], transpose_a=True))
+ elif t_a and not t_b:
+ return (_SparseMatMul(op.inputs[1], grad, transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad))
+ elif t_a and t_b:
+ return (_SparseMatMul(op.inputs[1], grad,
+ transpose_a=True, transpose_b=True),
+ _SparseMatMul(grad, op.inputs[0],
+ transpose_a=True, transpose_b=True))
+
+
+@ops.RegisterGradient("Floor")
+def _FloorGrad(_, grad):
+ return grad
+
+
+@ops.RegisterGradient("BatchMatMul")
+def _BatchMatMul(op, grad):
+ """Returns the gradient of x and y given the gradient of x * y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ adj_x = op.get_attr("adj_x")
+ adj_y = op.get_attr("adj_y")
+
+ if not adj_x:
+ if not adj_y:
+ grad_x = math_ops.batch_matmul(grad, y, False, True)
+ grad_y = math_ops.batch_matmul(x, grad, True, False)
+ else:
+ grad_x = math_ops.batch_matmul(grad, y, False, False)
+ grad_y = math_ops.batch_matmul(grad, x, True, False)
+ else:
+ if not adj_y:
+ grad_x = math_ops.batch_matmul(y, grad, False, True)
+ grad_y = math_ops.batch_matmul(x, grad, False, False)
+ else:
+ grad_x = math_ops.batch_matmul(y, grad, True, True)
+ grad_y = math_ops.batch_matmul(grad, x, True, True)
+
+ return grad_x, grad_y
+
+
+ops.NoGradient("Range")
+ops.NoGradient("LinSpace")
+
+
+@ops.RegisterGradient("Complex")
+def _ComplexGrad(_, grad):
+ """Returns the real and imaginary components of 'grad', respectively."""
+ return math_ops.real(grad), math_ops.imag(grad)
+
+
+@ops.RegisterGradient("Real")
+def _RealGrad(_, grad):
+ """Returns 'grad' as the real part and set the imaginary part 0."""
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ return math_ops.complex(grad, zero)
+
+
+@ops.RegisterGradient("Imag")
+def _ImagGrad(_, grad):
+ """Returns 'grad' as the imaginary part and set the real part 0."""
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ return math_ops.complex(zero, grad)
+
+
+@ops.RegisterGradient("Conj")
+def _ConjGrad(_, grad):
+ """Returns the complex conjugate of grad."""
+ return math_ops.conj(grad)
+
+
+@ops.RegisterGradient("Cast")
+def _CastGrad(op, grad):
+ t = [types.float32, types.float64, types.bfloat16]
+ src_type = op.inputs[0].dtype.base_dtype
+ dst_type = grad.dtype.base_dtype
+ if src_type in t and dst_type in t:
+ return math_ops.cast(grad, src_type)
+ else:
+ return None
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
new file mode 100644
index 0000000000..d96320e96e
--- /dev/null
+++ b/tensorflow/python/ops/math_ops.py
@@ -0,0 +1,1201 @@
+"""## Arithmetic Operators
+
+TensorFlow provides several operations that you can use to add basic arithmetic
+operators to your graph.
+
+@@add
+@@sub
+@@mul
+@@div
+@@mod
+
+## Basic Math Functions
+
+TensorFlow provides several operations that you can use to add basic
+mathematical functions to your graph.
+
+@@add_n
+@@abs
+@@neg
+@@sign
+@@inv
+@@square
+@@round
+@@sqrt
+@@rsqrt
+@@pow
+@@exp
+@@log
+@@ceil
+@@floor
+@@maximum
+@@minimum
+@@cos
+@@sin
+
+## Matrix Math Functions
+
+TensorFlow provides several operations that you can use to add basic
+mathematical functions for matrices to your graph.
+
+@@diag
+@@transpose
+
+@@matmul
+@@batch_matmul
+
+@@matrix_determinant
+@@batch_matrix_determinant
+
+@@matrix_inverse
+@@batch_matrix_inverse
+
+@@cholesky
+@@batch_cholesky
+
+## Complex Number Functions
+
+TensorFlow provides several operations that you can use to add complex number
+functions to your graph.
+
+@@complex
+@@complex_abs
+@@conj
+@@imag
+@@real
+
+## Reduction
+
+TensorFlow provides several operations that you can use to perform
+common math computations that reduce various dimensions of a tensor.
+
+@@reduce_sum
+@@reduce_prod
+@@reduce_min
+@@reduce_max
+@@reduce_mean
+@@reduce_all
+@@reduce_any
+
+@@accumulate_n
+
+## Segmentation
+
+TensorFlow provides several operations that you can use to perform common
+math computations on tensor segments.
+Here a segmentation is a partitioning of a tensor along
+the first dimension, i.e. it defines a mapping from the first dimension onto
+`segment_ids`. The `segment_ids` tensor should be the size of
+the first dimension, `d0`, with consecutive IDs in the range `0` to `k`,
+where `k<d0`.
+In particular, a segmentation of a matrix tensor is a mapping of rows to
+segments.
+
+For example:
+
+```python
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+tf.segment_sum(c, tf.constant([0, 0, 1]))
+ ==> [[0 0 0 0]
+ [5 6 7 8]]
+```
+
+@@segment_sum
+@@segment_prod
+@@segment_min
+@@segment_max
+@@segment_mean
+
+@@unsorted_segment_sum
+
+@@sparse_segment_sum
+@@sparse_segment_mean
+
+
+## Sequence Comparison and Indexing
+
+TensorFlow provides several operations that you can use to add sequence
+comparison and index extraction to your graph. You can use these operations to
+determine sequence differences and determine the indexes of specific values in
+a tensor.
+
+@@argmin
+@@argmax
+
+@@listdiff
+@@where
+@@unique
+
+@@edit_distance
+
+@@invert_permutation
+"""
+import itertools
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import gen_state_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_math_ops import *
+
+
+# Aliases for some automatically-generated names.
+argmax = gen_math_ops.arg_max
+argmin = gen_math_ops.arg_min
+linspace = gen_math_ops.lin_space
+
+
+# pylint: disable=anomalous-backslash-in-string,protected-access
+def abs(x, name=None):
+ """Computes the absolute value of a tensor.
+
+ Given a tensor of real numbers `x`, this operation returns a tensor
+ containing the absolute value of each element in `x`. For example, if x is
+ an input element and y is an output element, this operation computes
+ \\\\(y = |x|\\\\).
+
+ See [`tf.complex_abs()`](#tf_complex_abs) to compute the absolute value of a complex
+ number.
+
+ Args:
+ x: A `Tensor` of type `float`, `double`, `int32`, or `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` the same size and type as `x` with absolute values.
+ """
+ with ops.op_scope([x], name, "Abs") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ if x.dtype == types.complex64:
+ return gen_math_ops.complex_abs(x, name=name)
+ return gen_math_ops._abs(x, name=name)
+
+
+
+def pow(x, y, name=None):
+ """Computes the power of one value to another.
+
+ Given a tensor `x` and a tensor `y`, this operation computes \\\\(x^y\\\\) for
+ corresponding elements in `x` and `y`. For example:
+
+ ```
+  # tensor 'x' is [[2, 2], [3, 3]]
+ # tensor 'y' is [[8, 16], [2, 3]]
+ tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+ ```
+
+ Args:
+ x: A `Tensor` of type `float`, `double`, `int32`, `complex64`, or `int64`.
+ y: A `Tensor` of type `float`, `double`, `int32`, `complex64`, or `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`.
+ """
+ with ops.op_scope([x], name, "Pow") as name:
+ return gen_math_ops._pow(x, y, name=name)
+
+
+def complex(real, imag, name=None):
+ """Converts two real numbers to a complex number.
+
+ Given a tensor `real` representing the real part of a complex number, and a
+ tensor `imag` representing the imaginary part of a complex number, this
+ operation computes complex numbers elementwise of the form \\\\(a + bj\\\\),
+ where *a* represents the `real` part and *b* represents the `imag` part.
+
+ The input tensors `real` and `imag` must be the same shape.
+
+ For example:
+
+ ```
+ # tensor 'real' is [2.25, 3.25]
+ # tensor `imag` is [4.75, 5.75]
+  tf.complex(real, imag) ==> [2.25 + 4.75j, 3.25 + 5.75j]
+ ```
+
+ Args:
+ real: A `Tensor` of type `float`.
+ imag: A `Tensor` of type `float`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `complex64`.
+ """
+ with ops.op_scope([real, imag], name, "Complex") as name:
+ return gen_math_ops._complex(real, imag, name=name)
+
+
+def round(x, name=None):
+ """Rounds the values of a tensor to the nearest integer, element-wise.
+
+ For example:
+
+ ```python
+ # 'a' is [0.9, 2.5, 2.3, -4.4]
+ tf.round(a) ==> [ 1.0, 3.0, 2.0, -4.0 ]
+ ```
+
+ Args:
+ x: A `Tensor` of type `float` or `double`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of same shape and type as `x`.
+ """
+ x = ops.convert_to_tensor(x, name="x")
+ if x.dtype.is_integer:
+ return x
+ else:
+ return floor(x + 0.5, name=name)
+
+
+def cast(x, dtype, name=None):
+ """Casts a tensor to a new type.
+
+ The operation casts `x` (in case of `Tensor`) or `x.values`
+ (in case of `SparseTensor`) to `dtype`.
+
+ For example:
+
+ ```python
+  # tensor `a` is [1.8, 2.2], dtype=tf.float32
+ tf.cast(a, tf.int32) ==> [1, 2] # dtype=tf.int32
+ ```
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ dtype: The destination type.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `dtype`.
+ """
+ with ops.op_scope([x], name, "Cast") as name:
+ if isinstance(x, ops.SparseTensor):
+ values_cast = cast(x.values, dtype, name=name)
+ return ops.SparseTensor(x.indices, values_cast, x.shape)
+ else:
+ # TODO(mdevin): Handle what Josh said.
+ #
+ # Could return ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
+ # allows some conversions that cast() can't do, e.g. casting numbers to
+ # strings.
+ x = ops.convert_to_tensor(x, name="x")
+ if x.dtype.base_dtype == dtype:
+ return x
+ return gen_math_ops.cast(x, dtype, name=name)
+
+
+def to_float(x, name="ToFloat"):
+ """Casts a tensor to type `float32`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `float32`.
+ """
+ return cast(x, types.float32, name=name)
+
+
+def to_double(x, name="ToDouble"):
+ """Casts a tensor to type `float64`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `float64`.
+ """
+ return cast(x, types.float64, name=name)
+
+
+def to_int32(x, name="ToInt32"):
+ """Casts a tensor to type `int32`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `int32`.
+ """
+ return cast(x, types.int32, name=name)
+
+
+def to_int64(x, name="ToInt64"):
+ """Casts a tensor to type `int64`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `int64`.
+ """
+ return cast(x, types.int64, name=name)
+
+
+def to_bfloat16(x, name="ToBFloat16"):
+ """Casts a tensor to type `bfloat16`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `bfloat16`.
+ """
+ return cast(x, types.bfloat16, name=name)
+
+
+ops.Tensor._override_operator("__neg__", neg)
+ops.Tensor._override_operator("__abs__", abs)
+# __invert__ corresponds to the ~ operator. Here we follow the numpy convention
+# ~ marks an elementwise bit-wise inverse. This is only implemented for boolean
+# tensors and will throw a TypeError if used on nonboolean arrays
+ops.Tensor._override_operator("__invert__", logical_not)
+
+
+def _OverrideBinaryOperatorHelper(func, op_name):
+ """Register operators with different tensor and scalar versions.
+
+ Args:
+ func: the operator
+ op_name: name of the operator being overridden
+ """
+
+ def binary_op_wrapper(x, y):
+ with ops.op_scope([x, y], None, op_name) as name:
+ assert isinstance(x, ops.Tensor)
+ y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
+ return func(x, y, name=name)
+
+ ops.Tensor._override_operator("__%s__" % op_name, binary_op_wrapper)
+ del binary_op_wrapper
+
+ def r_binary_op_wrapper(y, x):
+ with ops.op_scope([x, y], None, op_name) as name:
+ assert isinstance(y, ops.Tensor)
+ x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
+ return func(x, y, name=name)
+
+ ops.Tensor._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
+ del r_binary_op_wrapper
+
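+
+# With the operator overrides registered below, infix expressions on tensors
+# build the corresponding ops. A rough sketch of what `x + 1` expands to,
+# assuming `x` is a Tensor:
+#
+#   x.__add__(1)
+#     -> binary_op_wrapper(x, 1)
+#     -> add(x, ops.convert_to_tensor(1, dtype=x.dtype.base_dtype, name="y"))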
+
+_OverrideBinaryOperatorHelper(add, "add")
+_OverrideBinaryOperatorHelper(sub, "sub")
+_OverrideBinaryOperatorHelper(mul, "mul")
+_OverrideBinaryOperatorHelper(div, "div")
+_OverrideBinaryOperatorHelper(mod, "mod")
+
+
+def logical_xor(x, y, name="LogicalXor"):
+ """x ^ y = (x | y) & ~(x & y)."""
+ # TODO(alemi) Make this a cwise op if people end up relying on it.
+ return logical_and(logical_or(x, y), logical_not(logical_and(x, y)),
+ name=name)
+
+_OverrideBinaryOperatorHelper(logical_and, "and")
+_OverrideBinaryOperatorHelper(logical_or, "or")
+_OverrideBinaryOperatorHelper(logical_xor, "xor")
+
+ops.Tensor._override_operator("__lt__", less)
+ops.Tensor._override_operator("__le__", less_equal)
+ops.Tensor._override_operator("__gt__", greater)
+ops.Tensor._override_operator("__ge__", greater_equal)
+
+
+def range(start, limit, delta=1, name="range"):
+ """Creates a sequence of integers.
+
+ This operation creates a sequence of integers that begins at `start` and
+ extends by increments of `delta` up to but not including `limit`.
+
+ For example:
+
+ ```
+ # 'start' is 3
+ # 'limit' is 18
+ # 'delta' is 3
+ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+ ```
+
+ Args:
+ start: A 0-D (scalar) of type `int32`. First entry in sequence.
+ limit: A 0-D (scalar) of type `int32`. Upper limit of sequence,
+ exclusive.
+ delta: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
+ Number that increments `start`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A 1-D `int32` `Tensor`.
+ """
+ return gen_math_ops._range(start, limit, delta, name=name)
+
+
+@ops.RegisterShape("Range")
+def _RangeShape(op):
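+  """Shape function for the Range op."""
+  # The output length is ceil((limit - start) / delta), e.g. start=3,
+  # limit=18, delta=3 gives (18 - 3 + 3 - 1) / 3 = 5 elements.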
+ start_value = tensor_util.ConstantValue(op.inputs[0])
+ limit_value = tensor_util.ConstantValue(op.inputs[1])
+ delta_value = tensor_util.ConstantValue(op.inputs[2])
+ if start_value is None or limit_value is None or delta_value is None:
+ return [tensor_shape.vector(None)]
+ else:
+ return [tensor_shape.vector(
+ (limit_value - start_value + delta_value - 1) / delta_value)]
+
+
+# Reduction operations
+def _ReductionDims(x, reduction_indices):
+ """Returns range(0, rank(x)) if reduction_indices is None."""
+ if reduction_indices is not None:
+ return reduction_indices
+ else:
+ return range(0, array_ops.rank(x))
+
+
+def reduce_sum(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the sum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[1, 1, 1],
+ # [1, 1, 1]]
+ tf.reduce_sum(x) ==> 6
+ tf.reduce_sum(x, 0) ==> [2, 2, 2]
+ tf.reduce_sum(x, 1) ==> [3, 3]
+ tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
+ tf.reduce_sum(x, [0, 1]) ==> 6
+ ```
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._sum(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_mean(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the mean of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[1., 1.],
+ # [2., 2.]]
+ tf.reduce_mean(x) ==> 1.5
+ tf.reduce_mean(x, 0) ==> [1.5, 1.5]
+ tf.reduce_mean(x, 1) ==> [1., 2.]
+ ```
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._mean(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_prod(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the product of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._prod(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_min(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the minimum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
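+  For example:
+
+  ```python
+  # 'x' is [[1., 2.],
+  #         [3., 4.]]
+  tf.reduce_min(x) ==> 1.
+  tf.reduce_min(x, 0) ==> [1., 2.]
+  tf.reduce_min(x, 1) ==> [1., 3.]
+  ```
+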
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._min(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_max(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the maximum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
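+  For example:
+
+  ```python
+  # 'x' is [[1., 2.],
+  #         [3., 4.]]
+  tf.reduce_max(x) ==> 4.
+  tf.reduce_max(x, 0) ==> [3., 4.]
+  tf.reduce_max(x, 1) ==> [2., 4.]
+  ```
+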
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._max(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_all(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the "logical and" of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[True, True],
+ # [False, False]]
+ tf.reduce_all(x) ==> False
+ tf.reduce_all(x, 0) ==> [False, False]
+ tf.reduce_all(x, 1) ==> [True, False]
+ ```
+
+ Args:
+ input_tensor: The boolean tensor to reduce.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._all(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_any(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the "logical or" of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[True, True],
+ # [False, False]]
+ tf.reduce_any(x) ==> True
+ tf.reduce_any(x, 0) ==> [True, True]
+ tf.reduce_any(x, 1) ==> [True, False]
+ ```
+
+ Args:
+ input_tensor: The boolean tensor to reduce.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._any(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def matmul(a, b,
+ transpose_a=False, transpose_b=False,
+ a_is_sparse=False, b_is_sparse=False,
+ name=None):
+ """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
+
+ The inputs must be two-dimensional matrices, with matching inner dimensions,
+ possibly after transposition.
+
+ Both matrices must be of the same type. The supported types are:
+ `float`, `double`, `int32`, `complex64`.
+
+ Either matrix can be transposed on the fly by setting the corresponding flag
+ to `True`. This is `False` by default.
+
+ If one or both of the matrices contain a lot of zeros, a more efficient
+ multiplication algorithm can be used by setting the corresponding
+ `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
+
+ For example:
+
+ ```python
+ # 2-D tensor `a`
+ a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3]) => [[1. 2. 3.]
+ [4. 5. 6.]]
+ # 2-D tensor `b`
+ b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2]) => [[7. 8.]
+ [9. 10.]
+ [11. 12.]]
+ c = tf.matmul(a, b) => [[58 64]
+ [139 154]]
+ ```
+
+ Args:
+ a: `Tensor` of type `float`, `double`, `int32` or `complex64`.
+ b: `Tensor` with same type as `a`.
+ transpose_a: If `True`, `a` is transposed before multiplication.
+ transpose_b: If `True`, `b` is transposed before multiplication.
+ a_is_sparse: If `True`, `a` is treated as a sparse matrix.
+ b_is_sparse: If `True`, `b` is treated as a sparse matrix.
+ name: Name for the operation (optional).
+
+ Returns:
+ A `Tensor` of the same type as `a`.
+ """
+ with ops.op_scope([a, b], name, "MatMul") as name:
+ a = ops.convert_to_tensor(a, name="a")
+ b = ops.convert_to_tensor(b, name="b")
+ if a.dtype == types.float32 and (a_is_sparse or b_is_sparse):
+ return sparse_matmul(a, b,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ a_is_sparse=a_is_sparse,
+ b_is_sparse=b_is_sparse,
+ name=name)
+ else:
+ return gen_math_ops._mat_mul(a, b,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ name=name)
+
+sparse_matmul = gen_math_ops._sparse_mat_mul
+batch_matmul = gen_math_ops._batch_mat_mul
+
+ops.RegisterShape("MatMul")(common_shapes.matmul_shape)
+ops.RegisterShape("SparseMatMul")(common_shapes.matmul_shape)
+
+
+def _as_indexed_slices(x):
+ """Convert 'x' to IndexedSlices.
+
+ Convert a dense Tensor to a block-sparse IndexedSlices.
+
+ Args:
+ x: Either a Tensor object, or an IndexedSlices object.
+
+ Returns:
+ An IndexedSlices object.
+
+ Raises:
+ TypeError: If 'x' is not a Tensor or an IndexedSlices object.
+ """
+ # TODO(mdevin): op_scope
+ if not isinstance(x, (ops.Tensor, ops.IndexedSlices)):
+ raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
+ if isinstance(x, ops.IndexedSlices):
+ return x
+ x_shape = array_ops.shape(x)
+ return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)
+
+
+def _as_indexed_slices_list(inputs):
+ """Convert all elements of 'inputs' to IndexedSlices.
+
+ Additionally, homogenize the types of all the indices to
+ either int32 or int64.
+
+ Args:
+ inputs: List containing either Tensor or IndexedSlices objects.
+
+ Returns:
+ A list of IndexedSlices objects.
+
+ Raises:
+ TypeError: If 'inputs' is not a list or a tuple.
+ """
+ if not isinstance(inputs, (list, tuple)):
+ raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
+ outputs = [_as_indexed_slices(i) for i in inputs]
+ with_int32_index = [o.indices for o in outputs
+ if o.indices.dtype == types.int32]
+ if not with_int32_index or len(with_int32_index) == len(outputs):
+ return outputs
+ casted_outputs = []
+ for o in outputs:
+ if o.indices.dtype == types.int32:
+ casted_outputs.append(
+ ops.IndexedSlices(o.values, cast(o.indices, types.int64),
+ o.dense_shape))
+ else:
+ casted_outputs.append(o)
+ return casted_outputs
+
+
+def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
+ """Returns the element-wise sum of a list of tensors.
+
+ Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
+ otherwise, these are inferred.
+
+ For example:
+
+ ```python
+  # tensor 'a' is [[1, 2], [3, 4]]
+ # tensor `b` is [[5, 0], [0, 6]]
+ tf.accumulate_n([a, b, a]) ==> [[7, 4], [6, 14]]
+
+ # Explicitly pass shape and type
+ tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
+ ==> [[7, 4], [6, 14]]
+ ```
+
+ Args:
+ inputs: A list of `Tensor` objects, each with same shape and type.
+ shape: Shape of elements of `inputs`.
+ tensor_dtype: The type of `inputs`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of same shape and type as the elements of `inputs`.
+
+ Raises:
+ ValueError: If `inputs` don't all have same shape and dtype or the shape
+ cannot be inferred.
+ """
+ if tensor_dtype is None:
+ if not inputs or not isinstance(inputs, (list, tuple)):
+ raise ValueError("inputs must be a list of at least one Tensor with the "
+ "same dtype and shape")
+ inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
+ if not all(isinstance(x, ops.Tensor) for x in inputs):
+ raise ValueError("inputs must be a list of at least one Tensor with the "
+ "same dtype and shape")
+ if not all(x.dtype == inputs[0].dtype for x in inputs):
+ raise ValueError("inputs must be a list of at least one Tensor with the "
+ "same dtype and shape")
+ tensor_dtype = inputs[0].dtype
+ if shape is not None:
+ shape = tensor_shape.as_shape(shape)
+ else:
+ shape = tensor_shape.unknown_shape()
+ for input_tensor in inputs:
+ if isinstance(input_tensor, ops.Tensor):
+ shape = shape.merge_with(input_tensor.get_shape())
+ if not shape.is_fully_defined():
+ # TODO(pbar): Make a version of assign_add that accepts an uninitialized
+ # lvalue, and takes its shape from that? This would allow accumulate_n to
+ # work in all situations that add_n currently works.
+ raise ValueError("Cannot infer the shape of the accumulator for "
+ "accumulate_n. Pass the shape argument, or set the shape "
+ "of at least one of the inputs.")
+ with ops.op_scope(inputs, name, "AccumulateN") as name:
+ var = gen_state_ops._temporary_variable(shape=shape, dtype=tensor_dtype)
+ var_name = var.op.name
+ var = state_ops.assign(var, array_ops.zeros_like(inputs[0]))
+ update_ops = []
+ for input_tensor in inputs:
+ op = state_ops.assign_add(var, input_tensor, use_locking=True)
+ update_ops.append(op)
+ with ops.control_dependencies(update_ops):
+ return gen_state_ops._destroy_temporary_variable(var,
+ var_name=var_name,
+ name=name)
+
+
+@ops.RegisterShape("BatchMatMul")
+def _BatchMatMulShape(op):
+ """Shape function for BatchMatMul op."""
+ a_shape = op.inputs[0].get_shape()
+ adj_a = op.get_attr("adj_x")
+ b_shape = op.inputs[1].get_shape()
+ adj_b = op.get_attr("adj_y")
+ if not a_shape.is_fully_defined() or not b_shape.is_fully_defined():
+ return [tensor_shape.unknown_shape()]
+ batch_dims = a_shape[:-2].merge_with(b_shape[:-2])
+ output_rows = a_shape[-1] if adj_a else a_shape[-2]
+ output_cols = b_shape[-2] if adj_b else b_shape[-1]
+ inner_a = a_shape[-2] if adj_a else a_shape[-1]
+ inner_b = b_shape[-1] if adj_b else b_shape[-2]
+ inner_a.assert_is_compatible_with(inner_b)
+ return [batch_dims.concatenate([output_rows, output_cols])]
+
+
+def sigmoid(x, name=None):
+ """Computes sigmoid of `x` element-wise.
+
+ Specifically, `y = 1 / (1 + exp(-x))`.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A Tensor with the same type as `x` if `x.dtype != qint32`,
+    otherwise the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Sigmoid") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._sigmoid(x, name=name)
+
+
+def tanh(x, name=None):
+ """Computes hyperbolic tangent of `x` element-wise.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A Tensor with the same type as `x` if `x.dtype != qint32`, otherwise
+    the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Tanh") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._tanh(x, name=name)
+
+
+ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
+ops.RegisterShape("Ceil")(common_shapes.unchanged_shape)
+ops.RegisterShape("Conj")(common_shapes.unchanged_shape)
+ops.RegisterShape("Cos")(common_shapes.unchanged_shape)
+ops.RegisterShape("Exp")(common_shapes.unchanged_shape)
+ops.RegisterShape("Floor")(common_shapes.unchanged_shape)
+ops.RegisterShape("Imag")(common_shapes.unchanged_shape)
+ops.RegisterShape("Inv")(common_shapes.unchanged_shape)
+ops.RegisterShape("IsFinite")(common_shapes.unchanged_shape)
+ops.RegisterShape("IsInf")(common_shapes.unchanged_shape)
+ops.RegisterShape("IsNan")(common_shapes.unchanged_shape)
+ops.RegisterShape("Log")(common_shapes.unchanged_shape)
+ops.RegisterShape("LogicalNot")(common_shapes.unchanged_shape)
+ops.RegisterShape("Neg")(common_shapes.unchanged_shape)
+ops.RegisterShape("Real")(common_shapes.unchanged_shape)
+ops.RegisterShape("Rsqrt")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sign")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sin")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sqrt")(common_shapes.unchanged_shape)
+ops.RegisterShape("Square")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sigmoid")(common_shapes.unchanged_shape)
+ops.RegisterShape("Tanh")(common_shapes.unchanged_shape)
+ops.RegisterShape("Cast")(common_shapes.unchanged_shape)
+ops.RegisterShape("ComplexAbs")(common_shapes.unchanged_shape)
+
+
+@ops.RegisterShape("Add")
+@ops.RegisterShape("Complex")
+@ops.RegisterShape("Div")
+@ops.RegisterShape("Equal")
+@ops.RegisterShape("Greater")
+@ops.RegisterShape("GreaterEqual")
+@ops.RegisterShape("Less")
+@ops.RegisterShape("LessEqual")
+@ops.RegisterShape("LogicalAnd")
+@ops.RegisterShape("LogicalOr")
+@ops.RegisterShape("Maximum")
+@ops.RegisterShape("Minimum")
+@ops.RegisterShape("Mod")
+@ops.RegisterShape("Mul")
+@ops.RegisterShape("NotEqual")
+@ops.RegisterShape("Pow")
+@ops.RegisterShape("Sub")
+def _BroadcastShape(op):
+ """Common shape function for binary operators that broadcast their inputs."""
+ shape_x = op.inputs[0].get_shape()
+ shape_y = op.inputs[1].get_shape()
+ if shape_x.ndims is None or shape_y.ndims is None:
+ return [tensor_shape.unknown_shape()]
+
+ # To compute the broadcasted dimensions, we zip together shape_x and shape_y,
+ # and pad with 1 to make them the same length.
+ broadcasted_dims = reversed(list(itertools.izip_longest(
+ reversed(shape_x.dims), reversed(shape_y.dims),
+ fillvalue=tensor_shape.Dimension(1))))
+ # Next we combine the dimensions according to the numpy broadcasting rules.
+ # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ return_dims = []
+ for (dim_x, dim_y) in broadcasted_dims:
+ if dim_x.value is None or dim_y.value is None:
+ # One or both dimensions is unknown. If either dimension is greater than
+ # 1, we assume that the program is correct, and the other dimension will
+ # be broadcast to match it.
+ # TODO(mrry): If we eliminate the shape checks in C++, we must still
+ # assert that the unknown dim is either 1 or the same as the known dim.
+ if dim_x.value is not None and dim_x.value > 1:
+ return_dims.append(dim_x)
+ elif dim_y.value is not None and dim_y.value > 1:
+ return_dims.append(dim_y)
+ else:
+ return_dims.append(None)
+ elif dim_x.value == 1:
+ # We will broadcast dim_x to dim_y.
+ return_dims.append(dim_y)
+ elif dim_y.value == 1:
+ # We will broadcast dim_y to dim_x.
+ return_dims.append(dim_x)
+ elif dim_x.value == dim_y.value:
+ # The dimensions are compatible, so output is the same size in that
+ # dimension.
+ return_dims.append(dim_x.merge_with(dim_y))
+ else:
+ raise ValueError("Incompatible shapes for broadcasting: %s and %s"
+ % (shape_x, shape_y))
+ return [tensor_shape.TensorShape(return_dims)]
+
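+# A small pure-Python sketch of the same broadcasting rule for fully known
+# (integer) dimensions; illustrative only, with unknown-dimension handling
+# omitted:
+#
+#   def _broadcast_known_dims(shape_x, shape_y):
+#     dims = []
+#     for dx, dy in itertools.izip_longest(
+#         reversed(shape_x), reversed(shape_y), fillvalue=1):
+#       if dx == 1:
+#         dims.append(dy)
+#       elif dy == 1 or dx == dy:
+#         dims.append(dx)
+#       else:
+#         raise ValueError("Incompatible dimensions: %d and %d" % (dx, dy))
+#     return list(reversed(dims))
+#
+#   _broadcast_known_dims([2, 3, 1], [3, 4])  # ==> [2, 3, 4]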
+
+@ops.RegisterShape("AddN")
+def _AddNShape(op):
+ merged_shape = tensor_shape.unknown_shape()
+ for input_ in op.inputs:
+ merged_shape = merged_shape.merge_with(input_.get_shape())
+ return [merged_shape]
+
+
+@ops.RegisterShape("Select")
+def _SelectShape(op):
+ # All three inputs must have the same shape.
+ return [op.inputs[0].get_shape()
+ .merge_with(op.inputs[1].get_shape())
+ .merge_with(op.inputs[2].get_shape())]
+
+
+@ops.RegisterShape("ArgMax")
+@ops.RegisterShape("ArgMin")
+def _ArgOpShape(op):
+ """Common shape function for arg-reduction ops."""
+ dimension_shape = op.inputs[1].get_shape()
+ dimension_shape.assert_is_compatible_with(tensor_shape.scalar())
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims is None:
+ return [tensor_shape.unknown_shape()]
+ elif input_shape.ndims <= 1:
+ return [tensor_shape.scalar()]
+
+ dimension = tensor_util.ConstantValue(op.inputs[1])
+ if dimension is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims - 1)]
+ elif 0 <= dimension and dimension < input_shape.ndims:
+ returned_shape = []
+ for i, dim in enumerate(input_shape.dims):
+ if i != dimension:
+ returned_shape.append(dim)
+ return [tensor_shape.TensorShape(returned_shape)]
+ else:
+ raise ValueError(
+ "dimension (%d) must be in the range [0, %d), where %d is the number "
+ "of dimensions in the input"
+ % (dimension, input_shape.ndims, input_shape.ndims))
+
+
+@ops.RegisterShape("All")
+@ops.RegisterShape("Any")
+@ops.RegisterShape("Max")
+@ops.RegisterShape("Mean")
+@ops.RegisterShape("Min")
+@ops.RegisterShape("Prod")
+@ops.RegisterShape("Sum")
+def _ReductionShape(op):
+ """Common shape function for reduction ops."""
+ input_shape = op.inputs[0].get_shape()
+ reduction_indices = tensor_util.ConstantValue(op.inputs[1])
+ keep_dims = op.get_attr("keep_dims")
+ if reduction_indices is None or input_shape.ndims is None:
+ if keep_dims:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+ # Turn reduction_indices from scalar to vector if necessary
+ reduction_indices = np.ravel(reduction_indices)
+
+ for reduction_index in reduction_indices:
+ if reduction_index < 0 or reduction_index >= input_shape.ndims:
+ raise ValueError("Invalid reduction dimension %d for input with %d "
+ "dimensions" % (reduction_index, input_shape.ndims))
+
+ returned_dims = []
+ if keep_dims:
+ for i, dim in enumerate(input_shape.dims):
+ if i in reduction_indices:
+ returned_dims.append(1)
+ else:
+ returned_dims.append(dim)
+ else:
+ for i, dim in enumerate(input_shape.dims):
+ if i not in reduction_indices:
+ returned_dims.append(dim)
+ return [tensor_shape.TensorShape(returned_dims)]
+
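+# For example (illustrative), reducing an input of shape [2, 3, 5] over
+# reduction_indices = [1] with the loop above yields:
+#
+#   keep_dims=False  ==> returned_dims = [2, 5]
+#   keep_dims=True   ==> returned_dims = [2, 1, 5]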
+
+@ops.RegisterShape("SegmentMax")
+@ops.RegisterShape("SegmentMean")
+@ops.RegisterShape("SegmentMin")
+@ops.RegisterShape("SegmentProd")
+@ops.RegisterShape("SegmentSum")
+def _SegmentReductionShape(op):
+ """Common shape function for segment reduction ops."""
+ data_shape = op.inputs[0].get_shape()
+ segment_ids_shape = op.inputs[1].get_shape()
+ segment_ids_shape.assert_has_rank(1)
+ return [tensor_shape.TensorShape([None]).concatenate(data_shape[1:])]
+
+
+@ops.RegisterShape("SparseSegmentMean")
+@ops.RegisterShape("SparseSegmentSum")
+def _SparseSegmentReductionShape(op):
+ """Common shape function for sparse segment reduction ops."""
+ data_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape()
+ indices_shape.assert_has_rank(1)
+ segment_ids_shape = op.inputs[2].get_shape()
+ segment_ids_shape.assert_has_rank(1)
+ indices_shape.assert_is_compatible_with(segment_ids_shape)
+ return [tensor_shape.TensorShape([None]).concatenate(data_shape[1:])]
+
+
+@ops.RegisterShape("SparseSegmentMeanGrad")
+def _SparseSegmentMeanGradShape(op):
+ """Shape function for the SparseSegmentMeanGrad op."""
+ input_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape().with_rank(1)
+ unused_segment_ids_shape = op.inputs[2].get_shape().merge_with(indices_shape)
+ unused_output_dim0_shape = op.inputs[3].get_shape().merge_with(
+ tensor_shape.scalar())
+ output_dim0 = tensor_util.ConstantValue(op.inputs[3])
+ if output_dim0 is not None:
+ dim0 = output_dim0[0]
+ else:
+ dim0 = None
+ return [tensor_shape.TensorShape([dim0]).concatenate(input_shape[1:])]
+
+
+@ops.RegisterShape("UnsortedSegmentSum")
+def _UnsortedSegmentSumShape(op):
+ """Shape function for UnsortedSegmentSum."""
+ data_shape = op.inputs[0].get_shape()
+ segment_ids_shape = op.inputs[1].get_shape()
+ mid = segment_ids_shape.ndims
+ if mid is None:
+ return [tensor_shape.unknown_shape()]
+ else:
+ num_segments = tensor_util.ConstantValue(op.inputs[2])
+ return [tensor_shape.TensorShape([num_segments]).concatenate(
+ data_shape[mid:])]
+
+
+@ops.RegisterShape("LinSpace")
+def _LinspaceShape(op):
+ num = tensor_util.ConstantValue(op.inputs[2])
+ return [tensor_shape.vector(num)]
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
new file mode 100644
index 0000000000..86ea04f54d
--- /dev/null
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -0,0 +1,68 @@
+"""Tests for tensorflow.ops.math_ops."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+exp = math.exp
+log = math.log
+
+class ReduceTest(test_util.TensorFlowTestCase):
+
+ def testReduceAllDims(self):
+ x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+ with self.test_session():
+ y_tf = math_ops.reduce_sum(x).eval()
+ self.assertEqual(y_tf, 21)
+
+class RoundTest(test_util.TensorFlowTestCase):
+
+ def testRounding(self):
+ x = [0.49, 0.7, -0.3, -0.8]
+ for dtype in [np.float32, np.double]:
+ x_np = np.array(x, dtype=dtype)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y_tf = math_ops.round(x_tf)
+ y_tf_np = y_tf.eval()
+ y_np = np.round(x_np)
+ self.assertAllClose(y_tf_np, y_np, atol=1e-2)
+
+
+class ModTest(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ x = [0.5, 0.7, 0.3]
+ for dtype in [np.float32, np.double]:
+ # Test scalar and vector versions.
+ for denom in [x[0], [x[0]] * 3]:
+ x_np = np.array(x, dtype=dtype)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y_tf = math_ops.mod(x_tf, denom)
+ y_tf_np = y_tf.eval()
+ y_np = np.fmod(x_np, denom)
+ self.assertAllClose(y_tf_np, y_np, atol=1e-2)
+
+ def testFixed(self):
+ x = [5, 10, 23]
+ for dtype in [np.int32, np.int64]:
+ # Test scalar and vector versions.
+ for denom in [x[0], x]:
+ x_np = np.array(x, dtype=dtype)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y_tf = math_ops.mod(x_tf, denom)
+ y_tf_np = y_tf.eval()
+ y_np = np.mod(x_np, denom)
+ self.assertAllClose(y_tf_np, y_np)
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
new file mode 100644
index 0000000000..7a4dc25e8b
--- /dev/null
+++ b/tensorflow/python/ops/nn.py
@@ -0,0 +1,816 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order
+"""## Activation Functions
+
+The activation ops provide different types of nonlinearities for use in
+neural networks. These include smooth nonlinearities (`sigmoid`,
+`tanh`, and `softplus`), continuous but not everywhere differentiable
+functions (`relu`, `relu6`, and `relu_x`), and random regularization
+(`dropout`).
+
+All activation ops apply componentwise, and produce a tensor of the same
+shape as the input tensor.
+
+@@relu
+@@relu6
+@@softplus
+@@dropout
+@@bias_add
+@@sigmoid
+@@tanh
+
+## Convolution
+
+The convolution ops sweep a 2-D filter over a batch of images, applying the
+filter to each window of each image of the appropriate size. The different
+ops trade off between generic vs. specific filters:
+
+* `conv2d`: Arbitrary filters that can mix channels together.
+* `depthwise_conv2d`: Filters that operate on each channel independently.
+* `separable_conv2d`: A depthwise spatial filter followed by a pointwise filter.
+
+Note that although these ops are called "convolution", they are strictly
+speaking "cross-correlation" since the filter is combined with an input window
+without reversing the filter. For details, see [the properties of
+cross-correlation](https://en.wikipedia.org/wiki/Cross-correlation#Properties).
+
+The filter is applied to image patches of the same size as the filter and
+strided according to the `strides` argument. `strides = [1, 1, 1, 1]` applies
+the filter to a patch at every offset, `strides = [1, 2, 2, 1]` applies the
+filter to every other image patch in each dimension, etc.
+
+Ignoring channels for the moment, the spatial semantics of the convolution ops
+are as follows. If the 4-D `input` has shape
+`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
+`[filter_height, filter_width, ...]`, then
+
+ output.shape = [batch,
+ (in_height - filter_height + 1) / strides[1],
+ (in_width - filter_width + 1) / strides[2],
+ ...]
+
+ output[b, i, j, :] =
+ sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
+ filter[di, dj, ...]
+
+Since `input` is 4-D, each `input[b, i, j, :]` is a vector. For `conv2d`, these
+vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new
+vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]`
+is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
+concatenated.
+
+In the formula for `output.shape`, the rounding direction depends on padding:
+
+* `padding = 'VALID'`: Round up (only full size windows are considered).
+* `padding = 'SAME'`: Partial windows are included; the input is zero-padded so
+  that the output has `ceil(in_height / strides[1])` rows (and similarly for
+  the width).
+
+@@conv2d
+@@depthwise_conv2d
+@@separable_conv2d
+
+## Pooling
+
+The pooling ops sweep a rectangular window over the input tensor, computing a
+reduction operation for each window (average, max, or max with argmax). Each
+pooling op uses rectangular windows of size `ksize` separated by offset
+`strides`. For example, if `strides` is all ones every window is used; if
+`strides` is all twos every other window is used in each dimension; etc.
+
+In detail, the output is
+
+ output[i] = reduce(value[strides * i:strides * i + ksize])
+
+for each tuple of indices `i`. The output shape is
+
+ output.shape = (value.shape - ksize + 1) / strides
+
+where the rounding direction depends on padding:
+
+* `padding = 'VALID'`: Round up (only full size windows are considered).
+* `padding = 'SAME'`: Partial windows at the edges are included, so the output
+  has `ceil(value.shape / strides)` entries in each spatial dimension.
+
+@@avg_pool
+@@max_pool
+@@max_pool_with_argmax
+
+## Normalization
+
+Normalization is useful to prevent neurons from saturating when inputs may
+have varying scale, and to aid generalization.
+
+@@l2_normalize
+@@local_response_normalization
+@@moments
+
+## Losses
+
+The loss ops measure error between two tensors, or between a tensor and zero.
+These can be used for measuring accuracy of a network in a regression task
+or for regularization purposes (weight decay).
+
+@@l2_loss
+
+## Classification
+
+TensorFlow provides several operations that help you perform classification.
+
+@@sigmoid_cross_entropy_with_logits
+@@softmax
+@@softmax_cross_entropy_with_logits
+
+## Embeddings
+
+TensorFlow provides several operations that help you compute embeddings.
+
+@@embedding_lookup
+@@embedding_lookup_sparse
+
+## Evaluation
+
+The evaluation ops are useful for measuring the performance of a network.
+Since they are nondifferentiable, they are typically used at evaluation time.
+
+@@top_k
+@@in_top_k
+
+## Candidate Sampling
+
+Do you want to train a multiclass or multilabel model with thousands
+or millions of output classes (for example, a language model with a
+large vocabulary)? Training with a full Softmax is slow in this case,
+since all of the classes are evaluated for every training example.
+Candidate Sampling training algorithms can speed up your step times by
+only considering a small randomly-chosen subset of contrastive classes
+(called candidates) for each batch of training examples.
+
+See our [Candidate Sampling Algorithms Reference]
+(http://www.tensorflow.org/extras/candidate_sampling.pdf)
+
+### Sampled Loss Functions
+
+TensorFlow provides the following sampled loss functions for faster training.
+
+@@nce_loss
+@@sampled_softmax_loss
+
+### Candidate Samplers
+
+TensorFlow provides the following samplers for randomly sampling candidate
+classes when using one of the sampled loss functions above.
+
+@@uniform_candidate_sampler
+@@log_uniform_candidate_sampler
+@@learned_unigram_candidate_sampler
+@@fixed_unigram_candidate_sampler
+
+### Miscellaneous candidate sampling utilities
+
+@@compute_accidental_hits
+
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import candidate_sampling_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import numerics
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops.math_ops import sigmoid
+from tensorflow.python.ops.math_ops import tanh
+
+# Bring more nn-associated functionality into this package.
+from tensorflow.python.ops.nn_ops import *
+from tensorflow.python.ops.candidate_sampling_ops import *
+from tensorflow.python.ops.embedding_ops import *
+
+
+def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
+ """Computes sigmoid cross entropy given `logits`.
+
+ Measures the probability error in discrete classification tasks in which each
+ class is independent and not mutually exclusive. For instance, one could
+ perform multilabel classification where a picture can contain both an elephant
+ and a dog at the same time.
+
+ For brevity, let `x = logits`, `z = targets`. The logistic loss is
+
+ x - x * z + log(1 + exp(-x))
+
+ To ensure stability and avoid overflow, the implementation uses
+
+ max(x, 0) - x * z + log(1 + exp(-abs(x)))
+
+ `logits` and `targets` must have the same type and shape.
+
+ Args:
+ logits: A `Tensor` of type `float32` or `float64`.
+ targets: A `Tensor` of the same type and shape as `logits`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of the same shape as `logits` with the componentwise
+ logistic losses.
+ """
+ with ops.op_scope([logits, targets], name, "logistic_loss") as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ targets = ops.convert_to_tensor(targets, name="targets")
+ # The logistic loss formula from above is
+ # x - x * z + log(1 + exp(-x))
+ # For x < 0, a more numerically stable formula is
+ # -x * z + log(1 + exp(x))
+ # To avoid branching, we use the combined version
+ # max(x, 0) - x * z + log(1 + exp(-abs(x)))
+ return math_ops.add(nn_ops.relu(logits) - logits * targets,
+ math_ops.log(1 + math_ops.exp(-math_ops.abs(logits))),
+ name=name)
+
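+# A NumPy sketch (illustrative only) checking that the numerically stable form
+# above agrees with the naive logistic loss for moderate logits:
+#
+#   import numpy as np
+#   x = np.array([-3.0, -0.5, 0.0, 2.0])   # logits
+#   z = np.array([0.0, 1.0, 1.0, 0.0])     # targets
+#   naive = x - x * z + np.log(1 + np.exp(-x))
+#   stable = np.maximum(x, 0) - x * z + np.log(1 + np.exp(-np.abs(x)))
+#   np.allclose(naive, stable)  # ==> True; `stable` also avoids overflow
+#                               # when x is a large negative number.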
+
+def xw_plus_b(x, weights, biases, name=None):
+ """Computes matmul(x, weights) + biases.
+
+ Args:
+ x: a 2D tensor. Dimensions typically: batch, in_units
+ weights: a 2D tensor. Dimensions typically: in_units, out_units
+ biases: a 1D tensor. Dimensions: out_units
+    name: A name for the operation (optional). If not specified
+      "xw_plus_b" is used.
+
+ Returns:
+ A 2-D Tensor computing matmul(x, weights) + biases.
+ Dimensions typically: batch, out_units.
+ """
+ with ops.op_scope([x, weights, biases], name, "xw_plus_b") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ weights = ops.convert_to_tensor(weights, name="weights")
+ biases = ops.convert_to_tensor(biases, name="biases")
+ mm = math_ops.matmul(x, weights)
+ return nn_ops.bias_add(mm, biases, name=name)
+
+
+def relu_layer(x, weights, biases, name=None):
+ """Computes Relu(x * weight + biases).
+
+ Args:
+ x: a 2D tensor. Dimensions typically: batch, in_units
+ weights: a 2D tensor. Dimensions typically: in_units, out_units
+ biases: a 1D tensor. Dimensions: out_units
+    name: A name for the operation (optional). If not specified
+      "relu_layer" is used.
+
+ Returns:
+ A 2-D Tensor computing relu(matmul(x, weights) + biases).
+ Dimensions typically: batch, out_units.
+ """
+ with ops.op_scope([x, weights, biases], name, "relu_layer") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ weights = ops.convert_to_tensor(weights, name="weights")
+ biases = ops.convert_to_tensor(biases, name="biases")
+ xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
+ return nn_ops.relu(xw_plus_b, name=name)
+
+
+def l2_normalize(x, dim, epsilon=1e-12, name=None):
+ """Normalizes along dimension `dim` using an L2 norm.
+
+ For a 1-D tensor with `dim = 0`, computes
+
+ output = x / sqrt(max(sum(x**2), epsilon))
+
+ For `x` with more dimensions, independently normalizes each 1-D slice along
+ dimension `dim`.
+
+ Args:
+ x: A `Tensor`.
+ dim: Dimension along which to normalize.
+ epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
+ divisor if `norm < sqrt(epsilon)`.
+ name: A name for this operation (optional).
+
+ Returns:
+ A `Tensor` with the same shape as `x`.
+ """
+ with ops.op_scope([x], name, "l2_normalize") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ square_sum = math_ops.reduce_sum(math_ops.square(x), [dim], keep_dims=True)
+ x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
+ return math_ops.mul(x, x_inv_norm, name=name)
+
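+# A NumPy sketch of the same computation for a single vector (illustrative):
+#
+#   import numpy as np
+#   x = np.array([3.0, 4.0])
+#   x / np.sqrt(np.maximum(np.sum(x ** 2), 1e-12))  # ==> [0.6, 0.8]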
+
+def zero_fraction(value, name=None):
+ """Returns the fraction of zeros in `value`.
+
+ If `value` is empty, the result is `nan`.
+
+ This is useful in summaries to measure and report sparsity. For example,
+
+      z = tf.nn.relu(...)
+      summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z))
+
+ Args:
+ value: A tensor of numeric type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The fraction of zeros in `value`, with type `float32`.
+ """
+ with ops.op_scope([value], name, "zero_fraction"):
+ value = ops.convert_to_tensor(value, name="value")
+ zero = constant_op.constant(0, dtype=value.dtype, name="zero")
+ return math_ops.reduce_mean(math_ops.cast(math_ops.equal(value, zero),
+ types.float32))
+
+
+def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
+ """Computes dropout.
+
+ With probability `keep_prob`, outputs the input element scaled up by
+ `1 / keep_prob`, otherwise outputs `0`. The scaling is so that the expected
+ sum is unchanged.
+
+ By default, each element is kept or dropped independently. If `noise_shape`
+ is specified, it must be
+ [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+ to the shape of `x`, and only dimensions with `noise_shape[i] == x.shape[i]`
+ will make independent decisions. For example, if `x.shape = [b, x, y, c]` and
+ `noise_shape = [b, 1, 1, c]`, each batch and channel component will be
+ kept independently and each row and column will be kept or not kept together.
+
+ Args:
+ x: A tensor.
+ keep_prob: Float probability that each element is kept.
+ noise_shape: Shape for randomly generated keep/drop flags.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for this operation (optional).
+
+ Returns:
+ A Tensor of the same shape of `x`.
+
+ Raises:
+ ValueError: If `keep_prob` is not in `(0, 1]`.
+ """
+ if not (0 < keep_prob <= 1):
+ raise ValueError("Expected keep_prob in (0, 1], got %g" % keep_prob)
+ with ops.op_scope([x], name, "dropout") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ noise_shape = noise_shape or array_ops.shape(x)
+ # uniform [keep_prob, 1.0 + keep_prob)
+ random_tensor = keep_prob
+ random_tensor += random_ops.random_uniform(
+ noise_shape, seed=seed, dtype=x.dtype)
+ # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
+ binary_tensor = math_ops.floor(random_tensor)
+ return x * (1.0 / keep_prob) * binary_tensor
+
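+# Why the 1 / keep_prob scaling keeps the expected value unchanged
+# (a one-line sanity check, illustrative only):
+#
+#   E[output] = keep_prob * (x / keep_prob) + (1 - keep_prob) * 0 = x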
+
+def depthwise_conv2d(input, filter, strides, padding, name=None):
+ """Depthwise 2-D convolution.
+
+ Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
+ and a filter tensor of shape
+ `[filter_height, filter_width, in_channels, channel_multiplier]`
+ containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
+ applies a different filter to each input channel (expanding from 1 channel
+ to `channel_multiplier` channels for each), then concatenates the results
+ together. The output has `in_channels * channel_multiplier` channels.
+
+ In detail,
+
+ output[b, i, j, k * channel_multiplier + q] =
+ sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
+ filter[di, dj, k, q]
+
+ Must have `strides[0] = strides[3] = 1`. For the most common case of the
+ same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
+ Args:
+ input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+ filter: 4-D with shape
+ `[filter_height, filter_width, in_channels, channel_multiplier]`.
+ strides: 1-D of size 4. The stride of the sliding window for each
+ dimension of `input`.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: A name for this operation (optional).
+
+ Returns:
+ A 4-D `Tensor` of shape
+ `[batch, out_height, out_width, in_channels * channel_multiplier].`
+ """
+ with ops.op_scope([input, filter], name, "depthwise") as name:
+ input = ops.convert_to_tensor(input, name="tensor_in")
+ filter = ops.convert_to_tensor(filter, name="filter_in")
+    # A static shape is needed to compute the number of per-channel convolutions.
+ if filter.get_shape().ndims is not None:
+ assert len(filter.get_shape()) == 4
+ in_channels = filter.get_shape()[2]
+ # Sanity checks, if shape information is available for the inputs.
+ if input.get_shape().ndims is not None:
+ assert len(input.get_shape()) == 4
+ assert input.get_shape()[3] == in_channels, (
+ "Mismatched input depth %d and number of depthwise filters %d." % (
+ input.get_shape()[3].value, in_channels))
+ else:
+      assert input.get_shape().ndims is not None, (
+          "Either the input or the filter tensor must provide static shape "
+          "information.")
+ assert input.get_shape().ndims == 4
+ in_channels = input.get_shape()[3]
+
+ if in_channels == 1:
+ return nn_ops.conv2d(input, filter, strides, padding, name=name)
+ else:
+ # Create one separate convolution per channel.
+ convs = []
+ for channel in xrange(in_channels):
+ with ops.name_scope("depth%d" % channel) as channel_scope:
+ t_in = array_ops.slice(input, [0, 0, 0, channel], [-1, -1, -1, 1],
+ name="slice_inputs")
+ f_in = array_ops.slice(filter, [0, 0, channel, 0], [-1, -1, 1, -1],
+ name="slice_params")
+ convs.append(nn_ops.conv2d(t_in, f_in,
+ strides, padding, name=channel_scope))
+ # Concatenate the per-channel convolutions along the channel dimension.
+ return array_ops.concat(3, convs, name=name)
+
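+# Shape sketch (illustrative; `images` and `filters` are hypothetical tensors):
+# with `images` of shape [8, 32, 32, 3] and `filters` of shape [3, 3, 3, 2]
+# (channel_multiplier = 2), unit strides and 'SAME' padding,
+#
+#   depthwise_conv2d(images, filters, [1, 1, 1, 1], 'SAME')
+#
+# runs one conv2d per input channel and concatenates the results, giving an
+# output of shape [8, 32, 32, 6] (3 * 2 output channels).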
+
+def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
+ padding,
+ name=None):
+ """2-D convolution with separable filters.
+
+ Performs a depthwise convolution that acts separately on channels followed by
+ a pointwise convolution that mixes channels. Note that this is separability
+ between dimensions `[1, 2]` and `3`, not spatial separability between
+ dimensions `1` and `2`.
+
+ In detail,
+
+      output[b, i, j, k] = sum_{di, dj, q, r}
+ input[b, strides[1] * i + di, strides[2] * j + dj, q] *
+ depthwise_filter[di, dj, q, r] *
+ pointwise_filter[0, 0, q * channel_multiplier + r, k]
+
+ `strides` controls the strides for the depthwise convolution only, since
+ the pointwise convolution has implicit strides of `[1, 1, 1, 1]`. Must have
+ `strides[0] = strides[3] = 1`. For the most common case of the same
+ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
+ Args:
+ input: 4-D `Tensor` with shape `[batch, in_height, in_width, in_channels]`.
+ depthwise_filter: 4-D `Tensor` with shape
+ `[filter_height, filter_width, in_channels, channel_multiplier]`.
+ Contains `in_channels` convolutional filters of depth 1.
+ pointwise_filter: 4-D `Tensor` with shape
+ `[1, 1, channel_multiplier * in_channels, out_channels]`. Pointwise
+ filter to mix channels after `depthwise_filter` has convolved spatially.
+ strides: 1-D of size 4. The strides for the depthwise convolution for
+ each dimension of `input`.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: A name for this operation (optional).
+
+ Returns:
+ A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`.
+ """
+ with ops.op_scope([input, depthwise_filter, pointwise_filter],
+ name, "separable_conv2d") as name:
+ input = ops.convert_to_tensor(input, name="tensor_in")
+ depthwise_filter = ops.convert_to_tensor(depthwise_filter,
+ name="depthwise_filter")
+ pointwise_filter = ops.convert_to_tensor(pointwise_filter,
+ name="pointwise_filter")
+
+ if pointwise_filter.get_shape().ndims is not None:
+ assert len(pointwise_filter.get_shape()) == 4
+ assert pointwise_filter.get_shape()[0] == 1
+ assert pointwise_filter.get_shape()[1] == 1
+ if depthwise_filter.get_shape().ndims and input.get_shape().ndims:
+ channel_multiplier = depthwise_filter.get_shape()[3]
+ in_channels = input.get_shape()[3]
+ out_channels = pointwise_filter.get_shape()[3]
+      # This would mean the separable convolution is over-parameterized.
+ assert channel_multiplier * in_channels < out_channels
+    # The layout of the ops in the graph is expected to be as follows:
+    #   separable_conv2d // Conv2D op corresponding to the pointwise conv.
+    #   separable_conv2d/depthwise // Concat op for the depthwise outputs.
+ # separable_conv2d/depthwise/depth0 // Conv2D op for depth 0
+ # separable_conv2d/depthwise/depth1 // Conv2D op for depth 1
+ # separable_conv2d/depthwise/depth2 // Conv2D op for depth 2
+ depthwise = depthwise_conv2d(input, depthwise_filter, strides,
+ padding, name="depthwise")
+ return nn_ops.conv2d(depthwise, pointwise_filter, [1, 1, 1, 1],
+ padding="VALID", name=name)
+
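+# Shape sketch (illustrative): with an input of shape [8, 32, 32, 3], a
+# depthwise filter of shape [3, 3, 3, 2] and a pointwise filter of shape
+# [1, 1, 6, 16], the depthwise stage produces [8, 32, 32, 6] (with 'SAME'
+# padding) and the 1x1 pointwise conv2d mixes those 6 channels into a
+# [8, 32, 32, 16] output.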
+
+def moments(x, axes, name=None):
+ """Calculate the mean and variance of `x`.
+
+ The mean and variance are calculated by aggregating the contents of `x`
+ across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
+ and variance of a vector.
+
+ For so-called "global normalization" needed for convolutional filters pass
+ `axes=[0, 1, 2]` (batch, height, width). For batch normalization pass
+ `axes=[0]` (batch).
+
+ Args:
+ x: A `Tensor`.
+ axes: array of ints. Axes along which to compute mean and
+ variance.
+ name: Name used to scope the operations that compute the moments.
+
+ Returns:
+ Two `Tensors`: `mean` and `variance`.
+ """
+ with ops.op_scope([x, axes], name, "moments"):
+ x = ops.convert_to_tensor(x, name="x")
+ divisor = 1.0
+ for d in xrange(len(x.get_shape())):
+ if d in axes:
+ divisor *= x.get_shape()[d].value
+ divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
+ axes = constant_op.constant(axes, name="axes")
+ # Note: We do not use Mean here because it is very slow on GPU.
+ # Note 2: The expression below is potentially more stable.
+ # It is however a bit slower and stability doesn't appear to be an issue.
+ # mean = math_ops.reduce_sum(math_ops.mul(x, divisor), axes, name="mean")
+ # var = math_ops.reduce_sum(math_ops.mul(math_ops.square(x - mean),
+ # divisor), axes,
+ # name="variance")
+ mean = math_ops.mul(math_ops.reduce_sum(x, axes), divisor, name="mean")
+ var = math_ops.mul(math_ops.reduce_sum(math_ops.square(x - mean), axes),
+ divisor, name="variance")
+ return mean, var
+
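+# A NumPy sketch of what moments(x, [0]) computes for a 1-D input
+# (illustrative only; note the variance uses the 1/N population form):
+#
+#   import numpy as np
+#   x = np.array([1.0, 2.0, 3.0, 4.0])
+#   x.mean(), x.var()  # ==> (2.5, 1.25)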
+
+def _sum_rows(x):
+ """Returns a vector summing up each row of the matrix x."""
+ # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
+ # a matrix. The gradient of _sum_rows(x) is more efficient than
+ # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
+ # we use _sum_rows(x) in the nce_loss() computation since the loss
+ # is mostly used for training.
+ cols = array_ops.shape(x)[1]
+ ones_shape = array_ops.pack([cols, 1])
+ ones = array_ops.ones(ones_shape, x.dtype)
+ return array_ops.reshape(math_ops.matmul(x, ones), [-1])
+
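+# Equivalence sketch (illustrative): multiplying a matrix by a column of ones
+# sums each row, so _sum_rows(x) matches reduce_sum(x, 1):
+#
+#   import numpy as np
+#   x = np.array([[1.0, 2.0], [3.0, 4.0]])
+#   np.dot(x, np.ones((2, 1))).reshape(-1)  # ==> [3., 7.], same as x.sum(1)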
+
+def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
+ num_classes, num_true=1,
+ sampled_values=None,
+ subtract_log_q=True,
+ remove_accidental_hits=False,
+ name=None):
+ """Helper function for nce_loss and sampled_softmax_loss functions.
+
+ Computes sampled output training logits and labels suitable for implementing
+ e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
+ sampled_softmax_loss).
+
+ Note: In the case where num_true > 1, we assign to each target class
+ the target probability 1 / num_true so that the target probabilities
+ sum to 1 per-example.
+
+ Args:
+ weights: tensor of label embeddings with shape = [num_classes, dim]
+ biases: tensor of num_classes label biases
+ inputs: tensor with shape = [batch_size, dim] corresponding to forward
+ activations of the input network
+ labels: int tensor with shape [batch_size, num_true]
+ num_sampled: number of label classes to sample per batch
+ num_classes: number of possible label classes in the data (e.g. vocab size)
+ num_true: number of target classes per example (default: 1)
+ sampled_values: a tuple of (sampled_candidates, true_expected_count,
+ sampled_expected_count) returned by a *CandidateSampler function to use
+ (if None, we default to LogUniformCandidateSampler)
+ subtract_log_q: subtract the log expected count of the labels in the sample
+ to get the logits of the true labels (default: True)
+ Turn off for Negative Sampling.
+ remove_accidental_hits: whether to remove "accidental hits" where a sampled
+ label equals the true labels (bool, default: False)
+ name: name for this op
+
+ Returns:
+ out_logits, out_labels: tensors with shape [batch_size, num_true +
+ num_sampled] for passing to either SigmoidCrossEntropyWithLogits (NCE)
+ or SoftmaxCrossEntropyWithLogits (sampled softmax).
+
+ """
+
+ with ops.op_scope(
+ [weights, biases, inputs, labels], name, "compute_sampled_logits"):
+ if labels.dtype != types.int64:
+ labels = math_ops.cast(labels, types.int64)
+ labels_flat = array_ops.reshape(labels, [-1])
+
+ # Sample the negative labels.
+ # sampled shape: num_sampled vector
+ # true_expected_count shape = [batch_size, 1]
+ # sampled_expected_count shape = num_sampled vector
+ if sampled_values is None:
+ sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
+ true_classes=labels,
+ num_true=num_true,
+ num_sampled=num_sampled,
+ unique=True,
+ range_max=num_classes)
+ # NOTE: pylint cannot tell that 'sampled_values' is a sequence
+ # pylint: disable=unpacking-non-sequence
+ sampled, true_expected_count, sampled_expected_count = sampled_values
+ # pylint: enable=unpacking-non-sequence
+
+ # weights shape is [num_classes, dim]
+ # labels_flat is a [batch_size * num_true] vector
+ # true_w shape is [batch_size * num_true, dim]
+ # true_b is a [batch_size * num_true] vector
+ true_w = embedding_ops.embedding_lookup(weights, labels_flat)
+ true_b = embedding_ops.embedding_lookup(biases, labels_flat)
+
+ # inputs shape is [batch_size, dim]
+ # true_w shape is [batch_size * num_true, dim]
+ # row_wise_dots is [batch_size, num_true, dim]
+ dim = array_ops.shape(true_w)[1:2]
+ new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
+ row_wise_dots = math_ops.mul(
+ array_ops.expand_dims(inputs, 1),
+ array_ops.reshape(true_w, new_true_w_shape))
+ # We want the row-wise dot plus biases which yields a
+ # [batch_size, num_true] tensor of true_logits.
+ dots_as_matrix = array_ops.reshape(row_wise_dots,
+ array_ops.concat(0, [[-1], dim]))
+ true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
+ true_b = array_ops.reshape(true_b, [-1, num_true])
+ true_logits += true_b
+
+ # Lookup weights and biases for sampled labels.
+ # sampled is a num_sampled int vector
+ # sampled_w shape is [num_sampled, dim]
+ # sampled_b is a num_sampled float vector
+ sampled_w = embedding_ops.embedding_lookup(weights, sampled)
+ sampled_b = embedding_ops.embedding_lookup(biases, sampled)
+
+ # inputs has shape [batch_size, dim]
+ # sampled_w has shape [num_sampled, dim]
+ # sampled_b has shape [num_sampled]
+ # Apply X*W'+B, which yields [batch_size, num_sampled]
+ sampled_logits = math_ops.matmul(inputs,
+ sampled_w,
+ transpose_b=True) + sampled_b
+
+ if remove_accidental_hits:
+ acc_hits = candidate_sampling_ops.compute_accidental_hits(
+ labels, sampled, num_true=num_true)
+ acc_indices, acc_ids, acc_weights = acc_hits
+
+ # This is how SparseToDense expects the indices.
+ acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
+ acc_ids_2d_int32 = array_ops.reshape(math_ops.cast(
+ acc_ids, types.int32), [-1, 1])
+ sparse_indices = array_ops.concat(
+ 1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
+ # Create sampled_logits_shape = [batch_size, num_sampled]
+ sampled_logits_shape = array_ops.concat(
+ 0,
+ [array_ops.shape(labels)[:1], array_ops.expand_dims(num_sampled, 0)])
+ sampled_logits += sparse_ops.sparse_to_dense(
+ sparse_indices, sampled_logits_shape, acc_weights, 0.0)
+
+ if subtract_log_q:
+ # Subtract log of Q(l), prior probability that l appears in sampled.
+ true_logits -= math_ops.log(true_expected_count)
+ sampled_logits -= math_ops.log(sampled_expected_count)
+
+ # Construct output logits and labels. The true labels/logits start at col 0.
+ out_logits = array_ops.concat(1, [true_logits, sampled_logits])
+ # true_logits is a float tensor, ones_like(true_logits) is a float tensor
+ # of ones. We then divide by num_true to ensure the per-example labels sum
+ # to 1.0, i.e. form a proper probability distribution.
+ out_labels = array_ops.concat(
+ 1, [array_ops.ones_like(true_logits) / num_true,
+ array_ops.zeros_like(sampled_logits)])
+
+ return out_logits, out_labels
+
+
+def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
+ num_true=1,
+ sampled_values=None,
+ remove_accidental_hits=False,
+ name="nce_loss"):
+ """Computes and returns the noise-contrastive estimation training loss.
+
+ See [Noise-contrastive estimation: A new estimation principle for
+ unnormalized statistical models]
+ (http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
+ Also see our [Candidate Sampling Algorithms Reference]
+ (http://www.tensorflow.org/extras/candidate_sampling.pdf)
+
+ Note: In the case where num_true > 1, we assign to each target class
+ the target probability 1 / num_true so that the target probabilities
+ sum to 1 per-example.
+
+ Note: It would be useful to allow a variable number of target classes per
+ example. We hope to provide this functionality in a future release.
+ For now, if you have a variable number of target classes, you can pad them
+ out to a constant number by either repeating them or by padding
+ with an otherwise unused class.
+
+ Args:
+ weights: A `Tensor` of shape [num_classes, dim]. The class embeddings.
+ biases: A `Tensor` of shape [num_classes]. The class biases.
+ inputs: A `Tensor` of shape [batch_size, dim]. The forward
+ activations of the input network.
+ labels: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_classes: An `int`. The number of possible classes.
+ num_true: An `int`. The number of target classes per training example.
+ sampled_values: a tuple of `(sampled_candidates, true_expected_count,
+ sampled_expected_count)` returned by a *_candidate_sampler function.
+ (if None, we default to LogUniformCandidateSampler)
+ remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
+ where a sampled class equals one of the target classes. If set to
+ `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
+ learning to generate log-odds instead of log probabilities. See
+ our [Candidate Sampling Algorithms Reference]
+ (http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ Default is False.
+ name: A name for the operation (optional).
+
+ Returns:
+ A batch_size 1-D tensor of per-example NCE losses.
+ """
+ logits, labels = _compute_sampled_logits(
+ weights, biases, inputs, labels, num_sampled, num_classes,
+ num_true=num_true,
+ sampled_values=sampled_values,
+ subtract_log_q=True,
+ remove_accidental_hits=remove_accidental_hits,
+ name=name)
+ sampled_losses = sigmoid_cross_entropy_with_logits(logits,
+ labels,
+ name="sampled_losses")
+ # sampled_losses is batch_size x {true_loss, sampled_losses...}
+ # We sum out true and sampled losses.
+ return _sum_rows(sampled_losses)
+
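+# Usage sketch (illustrative; `embeddings`, `nce_w`, `nce_b`, `train_inputs`
+# and `train_labels` are hypothetical tensors from a word2vec-style model,
+# with `train_labels` of shape [batch_size, 1] and dtype int64, and the
+# `tf.nn` aliases assumed to be exported by this module):
+#
+#   hidden = tf.nn.embedding_lookup(embeddings, train_inputs)
+#   loss = tf.reduce_mean(tf.nn.nce_loss(nce_w, nce_b, hidden, train_labels,
+#                                        num_sampled=64, num_classes=50000))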
+
+def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
+ num_classes, num_true=1,
+ sampled_values=None,
+ remove_accidental_hits=True,
+ name="sampled_softmax_loss"):
+ """Computes and returns the sampled softmax training loss.
+
+ This is a faster way to train a softmax classifier over a huge number of
+ classes.
+
+ This operation is for training only. It is generally an underestimate of
+ the full softmax loss.
+
+ At inference time, you can compute full softmax probabilities with the
+ expression `tf.nn.softmax(tf.matmul(inputs, weights) + biases)`.
+
+ See our [Candidate Sampling Algorithms Reference]
+ (http://www.tensorflow.org/extras/candidate_sampling.pdf)
+
+ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
+
+ Args:
+ weights: A `Tensor` of shape [num_classes, dim]. The class embeddings.
+ biases: A `Tensor` of shape [num_classes]. The class biases.
+ inputs: A `Tensor` of shape [batch_size, dim]. The forward
+ activations of the input network.
+ labels: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes. Note that this format differs from
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_classes: An `int`. The number of possible classes.
+ num_true: An `int`. The number of target classes per training example.
+ sampled_values: a tuple of `(sampled_candidates, true_expected_count,
+ sampled_expected_count)` returned by a *_candidate_sampler function.
+ (if None, we default to LogUniformCandidateSampler)
+    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
+      where a sampled class equals one of the target classes. Default is
+      True.
+ name: A name for the operation (optional).
+
+ Returns:
+ A batch_size 1-D tensor of per-example sampled softmax losses.
+
+ """
+ logits, labels = _compute_sampled_logits(
+ weights, biases, inputs, labels, num_sampled, num_classes,
+ num_true=num_true,
+ sampled_values=sampled_values,
+ subtract_log_q=True,
+ remove_accidental_hits=remove_accidental_hits,
+ name=name)
+ sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
+ # sampled_losses is a batch_size vector.
+ return sampled_losses
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
new file mode 100644
index 0000000000..0cf867d217
--- /dev/null
+++ b/tensorflow/python/ops/nn_grad.py
@@ -0,0 +1,229 @@
+"""Gradients for operators defined in nn_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import gen_nn_ops
+
+
+@ops.RegisterGradient("Conv2DBackpropInput")
+def _DeConv2DGrad(op, grad):
+ """The derivatives for deconvolution.
+
+ Args:
+ op: the Deconvolution op.
+ grad: the tensor representing the gradient w.r.t. the output
+
+ Returns:
+ the gradients w.r.t. the input and the filter
+ """
+ return [None,
+ nn_ops.conv2d_backprop_filter(grad,
+ array_ops.shape(op.inputs[1]),
+ op.inputs[2],
+ op.get_attr("strides"),
+ op.get_attr("padding")),
+ nn_ops.conv2d(grad,
+ op.inputs[1],
+ op.get_attr("strides"),
+ op.get_attr("padding"))]
+
+
+@ops.RegisterGradient("Softmax")
+def _SoftmaxGrad(op, grad_softmax):
+ """The derivative of the softmax nonlinearity.
+
+  We assume that probs is of shape [batch_size, dim].
+  The Jacobian is dsoftmax / dx = diag(softmax) - softmax * softmax'.
+  This matrix is diagonal minus a rank one matrix, so it is easy to implement
+  as follows:
+
+ grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax
+
+ Args:
+ op: the Softmax op.
+ grad_softmax: the tensor representing the gradient w.r.t. the
+ softmax output.
+
+ Returns:
+ gradient w.r.t the input to the softmax
+
+ """
+ # TODO(ilyasu): assert that the tensor has two dimensions at
+ # graph-construction time? Alternatively: do different things
+ # depending on the dimensionality of the input tensors.
+ softmax = op.outputs[0]
+ grad_x = ((grad_softmax -
+ array_ops.reshape(math_ops.reduce_sum(grad_softmax * softmax, [1]),
+ [-1, 1]))
+ * softmax)
+ return grad_x
+
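+# A NumPy sketch (illustrative only) of the formula above for a single row:
+#
+#   import numpy as np
+#   s = np.array([0.2, 0.3, 0.5])     # a softmax output row
+#   g = np.array([1.0, 0.0, 0.0])     # incoming gradient w.r.t. the softmax
+#   (g - np.sum(g * s)) * s           # ==> [0.16, -0.06, -0.1]
+#   # Same result as the full Jacobian product:
+#   (np.diag(s) - np.outer(s, s)).dot(g)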
+
+@ops.RegisterGradient("BiasAdd")
+def _BiasAddGrad(unused_bias_op, received_grad):
+ """Return the gradients for the 2 inputs of bias_op.
+
+ The first input of unused_bias_op is the tensor t, and its gradient is
+ just the gradient the unused_bias_op received.
+
+  The second input of unused_bias_op is the bias vector, which has fewer
+  dimensions than "received_grad". Its gradient is the received gradient
+  summed over every dimension except the last (the bias dimension).
+
+ Args:
+ unused_bias_op: The BiasOp for which we need to generate gradients.
+ received_grad: Tensor. The gradients passed to the BiasOp.
+
+ Returns:
+ Two tensors, the first one for the "tensor" input of the BiasOp,
+ the second one for the "bias" input of the BiasOp.
+ """
+ reduction_dim_tensor = math_ops.range(0, array_ops.rank(received_grad) - 1)
+ return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor))
+
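+# For example (illustrative), if received_grad has shape [2, 3, 4], the bias
+# gradient reduces over dimensions [0, 1] and keeps only the trailing
+# dimension, matching the 1-D bias:
+#
+#   import numpy as np
+#   np.ones((2, 3, 4)).sum(axis=(0, 1))  # ==> array of shape (4,), all 6.0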
+
+def _VerifyTensor(t, name, msg):
+ """Assert that the tensor does not contain any NaN's.
+
+ Args:
+ t: Tensor
+ name: name
+ msg: message to log
+ Returns:
+ Tensor, but verified
+ """
+ with ops.name_scope(name):
+ with ops.device(t.device or ops.get_default_graph().get_default_device()):
+ verify_input = array_ops.check_numerics(t, message=msg)
+ out = control_flow_ops.with_dependencies([verify_input], t)
+ return out
+
+
+@ops.RegisterGradient("Relu")
+def _ReluGrad(op, grad):
+ t = _VerifyTensor(op.inputs[0], op.name, "ReluGrad input is not finite.")
+ return gen_nn_ops._relu_grad(grad, t)
+
+
+@ops.RegisterGradient("Relu6")
+def _Relu6Grad(op, grad):
+ return gen_nn_ops._relu6_grad(grad, op.inputs[0])
+
+
+@ops.RegisterGradient("Softplus")
+def _SoftplusGrad(op, grad):
+ return gen_nn_ops._softplus_grad(grad, op.inputs[0])
+
+
+@ops.RegisterGradient("ReluGrad")
+def _ReluGradGrad(op, grad):
+ x = op.inputs[1]
+ return (gen_nn_ops._relu_grad(grad, x),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+
+
+def _BroadcastMul(vec, mat):
+ """Multiply after broadcasting vec to match dimensions of mat.
+
+ Args:
+ vec: A 1-D tensor of dimension [D0]
+ mat: A 2-D tensor of dimension [D0, D1]
+
+ Returns:
+ A tensor of dimension [D0, D1], the result of vec * mat
+ """
+ # Reshape vec to [D0, 1]
+ vec = array_ops.expand_dims(vec, -1)
+ return vec * mat
+
+
+@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
+def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
+ # grad_0 is the backprop for cost, and we multiply it with the gradients
+ # (which is output[1])
+ # There is no gradient for the labels
+ return _BroadcastMul(grad_0, op.outputs[1]), None
+
+
+@ops.RegisterGradient("Conv2D")
+def _Conv2DGrad(op, grad):
+ return [nn_ops.conv2d_backprop_input(array_ops.shape(op.inputs[0]),
+ op.inputs[1],
+ grad,
+ op.get_attr("strides"),
+ op.get_attr("padding")),
+ nn_ops.conv2d_backprop_filter(op.inputs[0],
+ array_ops.shape(op.inputs[1]),
+ grad,
+ op.get_attr("strides"),
+ op.get_attr("padding"))]
+
+
+@ops.RegisterGradient("LRN")
+def _LRNGrad(op, grad):
+ depth_radius = op.get_attr("depth_radius")
+ bias = op.get_attr("bias")
+ alpha = op.get_attr("alpha")
+ beta = op.get_attr("beta")
+ return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0],
+ depth_radius, bias, alpha, beta)]
+
+
+@ops.RegisterGradient("AvgPool")
+def _AvgPoolGrad(op, grad):
+ return gen_nn_ops._avg_pool_grad(array_ops.shape(op.inputs[0]), grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"))
+
+
+@ops.RegisterGradient("MaxPool")
+def _MaxPoolGrad(op, grad):
+ return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"))
+
+
+@ops.RegisterGradient("BatchNormWithGlobalNormalization")
+def _BatchNormWithGlobalNormalizationGrad(op, grad):
+ """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.
+
+ We do not backprop anything for the mean and var intentionally as they are
+ not being trained with backprop in the operation.
+
+ Args:
+ op: The BatchNormOp for which we need to generate gradients.
+ grad: Tensor. The gradients passed to the BatchNormOp.
+
+ Returns:
+ dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
+ dm: Backprop for mean, which is
+ sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
+ dv: Backprop for variance, which is
+ sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
+ db: Backprop for beta, which is grad reduced in all except the
+ last dimension.
+ dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
+ """
+ dx, dm, dv, db, dg = gen_nn_ops._batch_norm_with_global_normalization_grad(
+ op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
+ op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization"))
+ return dx, dm, dv, db, dg
+
+
+@ops.RegisterGradient("L2Loss")
+def _L2LossGrad(op, grad):
+ """Return the gradients for L2Loss.
+
+ Args:
+ op: The L2LossOp for which we need to generate gradients.
+ grad: Tensor containing a single number.
+
+ Returns:
+ The gradient, which is (x * grad).
+ """
+ return op.inputs[0] * grad
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
new file mode 100644
index 0000000000..0ffe95de2b
--- /dev/null
+++ b/tensorflow/python/ops/nn_ops.py
@@ -0,0 +1,365 @@
+"""Wrappers for primitive Neural Net (NN) Operations."""
+
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_nn_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_nn_ops import *
+
+
+# Aliases for some automatically-generated names.
+local_response_normalization = gen_nn_ops.lrn
+
+
+def deconv2d(value, filter, output_shape, strides, padding="SAME",
+ name=None):
+ """The transpose of `conv2d`.
+
+ This used to be called "deconvolution", but it is actually the transpose
+ (gradient) of `conv2d`, not an actual deconvolution.
+
+ Args:
+ value: A 4-D `Tensor` of type `float` and shape
+ `[batch, height, width, in_channels]`.
+ filter: A 4-D `Tensor` with the same type as `value` and shape
+ `[height, width, output_channels, in_channels]`. `filter`'s
+ `in_channels` dimension must match that of `value`.
+ output_shape: A 1-D `Tensor` representing the output shape of the
+ deconvolution op.
+ strides: A list of ints. The stride of the sliding window for each
+ dimension of the input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: Optional name for the returned tensor.
+
+ Returns:
+ A `Tensor` with the same type as `value`.
+
+ Raises:
+ ValueError: If input/output depth does not match `filter`'s shape, or if
+ padding is other than `'VALID'` or `'SAME'`.
+ """
+ with ops.op_scope([value, filter, output_shape], name, "DeConv2D") as name:
+ value = ops.convert_to_tensor(value, name="value")
+ filter = ops.convert_to_tensor(filter, name="filter")
+ if not value.get_shape()[3].is_compatible_with(filter.get_shape()[3]):
+ raise ValueError(
+ "input channels does not match filter's input channels, "
+ "{} != {}".format(value.get_shape()[3], filter.get_shape()[3]))
+
+ output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
+ if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)):
+ raise ValueError("output_shape must have shape (4,), got {}"
+ .format(output_shape_.get_shape()))
+
+ if isinstance(output_shape, (list, np.ndarray)):
+ # output_shape's shape should be == [4] if reached this point.
+ if not filter.get_shape()[2].is_compatible_with(output_shape[3]):
+ raise ValueError(
+ "output_shape does not match filter's output channels, "
+ "{} != {}".format(output_shape[3], filter.get_shape()[2]))
+
+ if padding != "VALID" and padding != "SAME":
+ raise ValueError("padding must be either VALID or SAME:"
+ " {}".format(padding))
+
+ return gen_nn_ops.conv2d_backprop_input(input_sizes=output_shape_,
+ filter=filter,
+ out_backprop=value,
+ strides=strides,
+ padding=padding,
+ name=name)
+
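+# Shape sketch (illustrative): deconv2d is the gradient of conv2d with respect
+# to its input, so `output_shape` plays the role of that input's shape. With
+# `value` of shape [8, 16, 16, 32], `filter` of shape [3, 3, 64, 32], strides
+# [1, 2, 2, 1] and 'SAME' padding, output_shape = [8, 32, 32, 64] recovers a
+# tensor with twice the spatial size and 64 channels.
+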
+# pylint: disable=protected-access
+def bias_add(value, bias, name=None):
+ """Adds `bias` to `value`.
+
+ This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
+ Broadcasting is supported, so `value` may have any number of dimensions.
+ Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
+ case where both types are quantized.
+
+ Args:
+ value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
+ `int16`, `int8`, or `complex64`.
+ bias: A 1-D `Tensor` with size matching the last dimension of `value`.
+ Must be the same type as `value` unless `value` is a quantized type,
+ in which case a different quantized type may be used.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with the same type as `value`.
+ """
+ with ops.op_scope([value, bias], name, "BiasAdd") as name:
+ value = ops.convert_to_tensor(value, name="input")
+ bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
+ return gen_nn_ops._bias_add(value, bias, name=name)
+
+
+ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
+
+
+
+def relu6(features, name=None):
+ """Computes Rectified Linear 6: `min(max(features, 0), 6)`.
+
+ Args:
+ features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
+ `int16`, or `int8`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with the same type as `features`.
+ """
+ with ops.op_scope([features], name, "Relu6") as name:
+ features = ops.convert_to_tensor(features, name="features")
+ return gen_nn_ops._relu6(features, name=name)
+
+
+def softmax_cross_entropy_with_logits(logits, labels, name=None):
+ """Computes softmax cross entropy between `logits` and `labels`.
+
+ Measures the probability error in discrete classification tasks in which the
+ classes are mutually exclusive (each entry is in exactly one class). For
+ example, each CIFAR-10 image is labeled with one and only one label: an image
+ can be a dog or a truck, but not both.
+
+ **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+ on `logits` internally for efficiency. Do not call this op with the
+ output of `softmax`, as it will produce incorrect results.
+
+ `logits` and `labels` must have the same shape `[batch_size, num_classes]`
+ and the same dtype (either `float32` or `float64`).
+
+ Args:
+ logits: Unscaled log probabilities.
+ labels: Each row `labels[i]` must be a valid probability distribution.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+ softmax cross entropy loss.
+ """
+ # The second output tensor contains the gradients. We use it in
+ # _CrossEntropyGrad() in nn_grad but not here.
+ cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
+ logits, labels, name=name)
+ return cost
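+
+# A minimal usage sketch (hypothetical values). The logits are raw scores; do
+# not apply softmax() to them first:
+#
+#   logits = [[2.0, 1.0, 0.1]]        # shape [batch_size, num_classes]
+#   labels = [[1.0, 0.0, 0.0]]        # each row is a probability distribution
+#   loss = softmax_cross_entropy_with_logits(logits, labels)  # shape [batch_size]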
+
+
+@ops.RegisterShape("SoftmaxCrossEntropyWithLogits")
+def _SoftmaxCrossEntropyWithLogitsShape(op):
+ """Shape function for SoftmaxCrossEntropyWithLogits op."""
+ logits_shape = op.inputs[0].get_shape()
+ labels_shape = op.inputs[1].get_shape()
+ input_shape = logits_shape.merge_with(labels_shape).with_rank(2)
+ batch_size = input_shape[0]
+ return [tensor_shape.vector(batch_size.value), input_shape]
+
+
+def avg_pool(value, ksize, strides, padding, name=None):
+ """Performs the average pooling on the input.
+
+ Each entry in `output` is the mean of the corresponding size `ksize`
+ window in `value`.
+
+ Args:
+ value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
+ `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
+ ksize: A list of ints that has length >= 4.
+ The size of the window for each dimension of the input tensor.
+ strides: A list of ints that has length >= 4.
+ The stride of the sliding window for each dimension of the
+ input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: Optional name for the operation.
+
+ Returns:
+ A `Tensor` with the same type as `value`. The average pooled output tensor.
+ """
+ with ops.op_scope([value], name, "AvgPool") as name:
+ value = ops.convert_to_tensor(value, name="input")
+ return gen_nn_ops._avg_pool(value, ksize=ksize, strides=strides,
+ padding=padding,
+ name=name)
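+
+# A minimal usage sketch for avg_pool (hypothetical shapes). A 2x2 window with
+# stride 2 and "VALID" padding halves the spatial dimensions:
+#
+#   value: a float32 tensor of shape [1, 4, 4, 1]
+#   y = avg_pool(value, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
+#   # y has shape [1, 2, 2, 1]; each output is the mean of a 2x2 window.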
+
+
+def max_pool(value, ksize, strides, padding, name=None):
+ """Performs the max pooling on the input.
+
+ Args:
+ value: A 4-D `Tensor` with shape `[batch, height, width, channels]` and
+      type `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
+ ksize: A list of ints that has length >= 4. The size of the window for
+ each dimension of the input tensor.
+ strides: A list of ints that has length >= 4. The stride of the sliding
+ window for each dimension of the input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: Optional name for the operation.
+
+ Returns:
+ A `Tensor` with the same type as `value`. The max pooled output tensor.
+ """
+ with ops.op_scope([value], name, "MaxPool") as name:
+ value = ops.convert_to_tensor(value, name="input")
+ return gen_nn_ops._max_pool(value, ksize=ksize, strides=strides,
+ padding=padding,
+ name=name)
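+
+# A minimal usage sketch for max_pool, analogous to avg_pool above but taking
+# the maximum over each window (hypothetical shapes):
+#
+#   value: a float32 tensor of shape [1, 4, 4, 1]
+#   y = max_pool(value, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
+#   # y has shape [1, 2, 2, 1] with "SAME" padding and stride 2.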
+
+
+ops.RegisterShape("Relu")(common_shapes.unchanged_shape)
+ops.RegisterShape("Relu6")(common_shapes.unchanged_shape)
+ops.RegisterShape("Softplus")(common_shapes.unchanged_shape)
+
+
+@ops.RegisterShape("ReluGrad")
+@ops.RegisterShape("Relu6Grad")
+@ops.RegisterShape("SoftplusGrad")
+def _BinaryElementwiseShape(op):
+ """Returns same shape as both inputs to op.
+
+ Args:
+ op: Input operation.
+
+ Returns:
+ Shape of both inputs to `op`.
+ """
+ return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
+
+
+ops.RegisterShape("L2Loss")(common_shapes.scalar_shape)
+
+
+ops.RegisterShape("LRN")(common_shapes.unchanged_shape_with_rank(4))
+
+
+@ops.RegisterShape("LRNGrad")
+def _LRNGradShape(op):
+ """Shape function for LRNGrad op."""
+ in_grads_shape = op.inputs[0].get_shape().with_rank(4)
+ in_image_shape = op.inputs[1].get_shape().with_rank(4)
+ out_image_shape = op.inputs[2].get_shape().with_rank(4)
+ return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]
+
+
+ops.RegisterShape("Softmax")(
+ common_shapes.unchanged_shape_with_rank(2))
+
+
+@ops.RegisterShape("InTopK")
+def _InTopKShape(op):
+ """Shape function for InTopK op."""
+ predictions_shape = op.inputs[0].get_shape().with_rank(2)
+ targets_shape = op.inputs[1].get_shape().with_rank(1)
+ batch_size = predictions_shape[0].merge_with(targets_shape[0])
+ return [tensor_shape.vector(batch_size.value)]
+
+
+@ops.RegisterShape("TopK")
+def _TopKShape(op):
+ """Shape function for TopK op."""
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ k = op.get_attr("k")
+ num_rows = input_shape[0]
+ num_cols = input_shape[1]
+ if num_cols.value is not None and num_cols.value < k:
+ raise ValueError("input must have at least k (%d) columns" % k)
+ return [tensor_shape.TensorShape([num_rows, k]),
+ tensor_shape.TensorShape([num_rows, k])]
+
+
+@ops.RegisterShape("BatchNormWithGlobalNormalization")
+def _BatchNormShape(op):
+ """Shape function for BatchNormWithGlobalNormalization op."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ mean_shape = op.inputs[1].get_shape().with_rank(1)
+ var_shape = op.inputs[2].get_shape().with_rank(1)
+ beta_shape = op.inputs[3].get_shape().with_rank(1)
+ gamma_shape = op.inputs[4].get_shape().with_rank(1)
+ mean_shape[0].merge_with(input_shape[3])
+ var_shape[0].merge_with(input_shape[3])
+ beta_shape[0].merge_with(input_shape[3])
+ gamma_shape[0].merge_with(input_shape[3])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchNormWithGlobalNormalizationGrad")
+def _BatchNormGradShape(op):
+ """Shape function for BatchNormWithGlobalNormalizationGrad op."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ mean_shape = op.inputs[1].get_shape().with_rank(1)
+ var_shape = op.inputs[2].get_shape().with_rank(1)
+ beta_shape = op.inputs[3].get_shape().with_rank(1)
+ out_backprop_shape = op.inputs[4].get_shape().with_rank(4)
+ input_shape = input_shape.merge_with(out_backprop_shape)
+ vector_dim = input_shape[3]
+ vector_dim = vector_dim.merge_with(mean_shape[0])
+ vector_dim = vector_dim.merge_with(var_shape[0])
+ vector_dim = vector_dim.merge_with(beta_shape[0])
+ return [input_shape] + ([tensor_shape.vector(vector_dim)] * 4)
+
+
+ops.RegisterShape("Conv2D")(common_shapes.conv2d_shape)
+ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape)
+ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape)
+
+
+@ops.RegisterShape("MaxPoolWithArgmax")
+def _MaxPoolWithArgMaxShape(op):
+ """Shape function for MaxPoolWithArgmax op."""
+ return common_shapes.max_pool_shape(op) * 2
+
+
+@ops.RegisterShape("AvgPoolGrad")
+def _AvgPoolGradShape(op):
+ """Shape function for the AvgPoolGrad op."""
+ orig_input_shape = tensor_util.ConstantValue(op.inputs[0])
+ if orig_input_shape is not None:
+ return [tensor_shape.TensorShape(orig_input_shape.tolist())]
+ else:
+ # NOTE(mrry): We could in principle work out the shape from the
+ # gradients and the attrs, but if we do not know orig_input_shape
+ # statically, then we are unlikely to know the shape of the
+ # gradients either.
+ return [tensor_shape.unknown_shape(ndims=4)]
+
+
+@ops.RegisterShape("Conv2DBackpropFilter")
+def _Conv2DBackpropFilterShape(op):
+ """Shape function for the Conv2DBackpropFilter op."""
+ filter_shape = tensor_util.ConstantValue(op.inputs[1])
+ if filter_shape is not None:
+ return [tensor_shape.TensorShape(filter_shape.tolist())]
+ else:
+ # NOTE(mrry): We could in principle work out the shape from the
+ # gradients and the attrs, but if we do not know filter_shape
+ # statically, then we are unlikely to know the shape of the
+ # gradients either.
+ return [tensor_shape.unknown_shape(ndims=4)]
+
+
+@ops.RegisterShape("Conv2DBackpropInput")
+def _Conv2DBackpropInputShape(op):
+ """Shape function for the Conv2DBackpropInput op."""
+ input_shape = tensor_util.ConstantValue(op.inputs[0])
+ if input_shape is not None:
+ return [tensor_shape.TensorShape(input_shape.tolist())]
+ else:
+ # NOTE(mrry): We could in principle work out the shape from the
+ # gradients and the attrs, but if we do not know input_shape
+ # statically, then we are unlikely to know the shape of the
+ # gradients either.
+ return [tensor_shape.unknown_shape(ndims=4)]
+
+
+@ops.RegisterShape("MaxPoolGrad")
+@ops.RegisterShape("MaxPoolGradWithArgmax")
+def _MaxPoolGradShape(op):
+ """Shape function for the MaxPoolGrad op."""
+ orig_input_shape = op.inputs[0].get_shape().with_rank(4)
+ return [orig_input_shape]
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
new file mode 100644
index 0000000000..11ce56e359
--- /dev/null
+++ b/tensorflow/python/ops/nn_test.py
@@ -0,0 +1,882 @@
+"""Tests for tensorflow.ops.nn."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.kernel_tests import gradient_checker as gc
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_grad
+from tensorflow.python.platform import googletest
+
+exp = math.exp
+log = math.log
+
+
+class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase):
+
+ def _SigmoidCrossEntropyWithLogits(self, logits, targets):
+ assert len(logits) == len(targets)
+ pred = [1 / (1 + exp(-x)) for x in logits]
+ eps = 0.0001
+ pred = [min(max(p, eps), 1 - eps) for p in pred]
+ return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)]
+
+ def _Inputs(self, x=None, y=None, dtype=types.float64, sizes=None):
+ x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
+ y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
+ assert len(x) == len(y)
+ sizes = sizes if sizes else [len(x)]
+ logits = constant_op.constant(x, shape=sizes, dtype=dtype, name="logits")
+ targets = constant_op.constant(y, shape=sizes, dtype=dtype, name="targets")
+ losses = np.array(self._SigmoidCrossEntropyWithLogits(x, y)).reshape(*sizes)
+ return logits, targets, losses
+
+ def testConstructionNamed(self):
+ with self.test_session():
+ logits, targets, _ = self._Inputs()
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets,
+ name="mylogistic")
+ self.assertEqual("mylogistic", loss.op.name)
+
+ def testLogisticOutput(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ logits, targets, losses = self._Inputs(dtype=types.float32)
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ np_loss = np.array(losses).astype(np.float32)
+ tf_loss = loss.eval()
+ self.assertAllClose(np_loss, tf_loss, atol=0.001)
+
+ def testLogisticOutputMultiDim(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ logits, targets, losses = self._Inputs(dtype=types.float32,
+ sizes=[2, 2, 2])
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ np_loss = np.array(losses).astype(np.float32)
+ tf_loss = loss.eval()
+ self.assertAllClose(np_loss, tf_loss, atol=0.001)
+
+ def testGradient(self):
+ sizes = [4, 2]
+ with self.test_session():
+ logits, targets, _ = self._Inputs(sizes=sizes)
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ err = gc.ComputeGradientError(logits, sizes, loss, sizes)
+ print "logistic loss gradient err = ", err
+ self.assertLess(err, 1e-7)
+
+
+class ZeroFractionTest(test_util.TensorFlowTestCase):
+
+ def _ZeroFraction(self, x):
+ assert x.shape
+ total_elements = float(np.prod(x.shape))
+ nonzeros = float(np.count_nonzero(x.flatten()))
+ return 1.0 - (nonzeros / total_elements)
+
+ def testZeroFraction(self):
+ x_shape = [5, 17]
+ x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
+ y_np = self._ZeroFraction(x_np)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np)
+ x_tf.set_shape(x_shape)
+ y_tf = nn.zero_fraction(x_tf)
+ y_tf_np = y_tf.eval()
+ eps = 1e-8
+ self.assertAllClose(y_tf_np, y_np, eps)
+
+ def testZeroFractionEmpty(self):
+ with self.test_session():
+ x = np.zeros(0)
+ y = nn.zero_fraction(x).eval()
+ self.assertTrue(np.isnan(y))
+
+
+class SoftmaxTest(test_util.TensorFlowTestCase):
+
+ def _softmax(self, x):
+ assert len(x.shape) == 2
+ m = x.max(1)[:, np.newaxis]
+ u = np.exp(x - m)
+ z = u.sum(1)[:, np.newaxis]
+ return u / z
+
+ def testSoftmax(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float32)
+ y_np = self._softmax(x_np)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn.softmax(x_tf)
+ y_tf_np = y_tf.eval()
+ eps = 1e-3
+ self.assertAllClose(y_tf_np, y_np, eps)
+
+ def testGradient(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float64)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn.softmax(x_tf)
+ err = gc.ComputeGradientError(x_tf, x_shape, y_tf, x_shape)
+ eps = 1e-8
+ self.assertLess(err, eps)
+
+
+class DeConv2DTest(test_util.TensorFlowTestCase):
+
+ def testDeConv2DSingleStride(self):
+ with self.test_session():
+ strides = [1, 1, 1, 1]
+
+ # Input, output: [batch, height, width, depth]
+ x_shape = [2, 6, 4, 3]
+ y_shape = [2, 6, 4, 2]
+
+ # Filter: [kernel_height, kernel_width, output_depth, input_depth]
+ f_shape = [3, 3, 2, 3]
+
+ x = constant_op.constant(1.0, shape=x_shape, name="x",
+ dtype=types.float32)
+ f = constant_op.constant(1.0, shape=f_shape, name="filter",
+ dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ value = output.eval()
+
+ # We count the number of cells being added at the locations in the output.
+ # At the center, #cells=kernel_height * kernel_width
+ # At the corners, #cells=ceil(kernel_height/2) * ceil(kernel_width/2)
+ # At the borders, #cells=ceil(kernel_height/2)*kernel_width or
+ # kernel_height * ceil(kernel_width/2)
+
+ for n in xrange(x_shape[0]):
+ for k in xrange(f_shape[2]):
+ for w in xrange(y_shape[2]):
+ for h in xrange(y_shape[1]):
+ target = 4 * 3.0
+ h_in = h > 0 and h < y_shape[1] - 1
+ w_in = w > 0 and w < y_shape[2] - 1
+ if h_in and w_in:
+ target += 5 * 3.0
+ elif h_in or w_in:
+ target += 2 * 3.0
+ self.assertAllClose(target, value[n, h, w, k])
+
+ def testDeConv2DSame(self):
+ with self.test_session():
+ strides = [1, 2, 2, 1]
+
+ # Input, output: [batch, height, width, depth]
+ x_shape = [2, 6, 4, 3]
+ y_shape = [2, 12, 8, 2]
+
+ # Filter: [kernel_height, kernel_width, output_depth, input_depth]
+ f_shape = [3, 3, 2, 3]
+
+ x = constant_op.constant(1.0, shape=x_shape, name="x",
+ dtype=types.float32)
+ f = constant_op.constant(1.0, shape=f_shape, name="filter",
+ dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ value = output.eval()
+
+ for n in xrange(x_shape[0]):
+ for k in xrange(f_shape[2]):
+ for w in xrange(y_shape[2]):
+ for h in xrange(y_shape[1]):
+ target = 3.0
+ # We add a case for locations divisible by the stride.
+ h_in = h % strides[1] == 0 and h > 0 and h < y_shape[1] - 1
+ w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
+ if h_in and w_in:
+ target += 9.0
+ elif h_in or w_in:
+ target += 3.0
+ self.assertAllClose(target, value[n, h, w, k])
+
+ def testDeConv2DValid(self):
+ with self.test_session():
+ strides = [1, 2, 2, 1]
+
+ # Input, output: [batch, height, width, depth]
+ x_shape = [2, 6, 4, 3]
+ y_shape = [2, 13, 9, 2]
+
+ # Filter: [kernel_height, kernel_width, output_depth, input_depth]
+ f_shape = [3, 3, 2, 3]
+
+ x = constant_op.constant(1.0, shape=x_shape, name="x",
+ dtype=types.float32)
+ f = constant_op.constant(1.0, shape=f_shape, name="filter",
+ dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID")
+ value = output.eval()
+
+ cache_values = np.zeros(y_shape, dtype=np.float32)
+
+ # The amount of padding added
+ pad = 1
+
+ for n in xrange(x_shape[0]):
+ for k in xrange(f_shape[2]):
+ for w in xrange(pad, y_shape[2] - pad):
+ for h in xrange(pad, y_shape[1] - pad):
+ target = 3.0
+ # We add a case for locations divisible by the stride.
+ h_in = h % strides[
+ 1] == 0 and h > pad and h < y_shape[1] - 1 - pad
+ w_in = w % strides[
+ 2] == 0 and w > pad and w < y_shape[2] - 1 - pad
+ if h_in and w_in:
+ target += 9.0
+ elif h_in or w_in:
+ target += 3.0
+ cache_values[n, h, w, k] = target
+
+ # copy values in the border
+ cache_values[n, :, 0, k] = cache_values[n, :, 1, k]
+ cache_values[n, :, -1, k] = cache_values[n, :, -2, k]
+ cache_values[n, 0, :, k] = cache_values[n, 1, :, k]
+ cache_values[n, -1, :, k] = cache_values[n, -2, :, k]
+
+ self.assertAllClose(cache_values, value)
+
+ def testGradient(self):
+ x_shape = [2, 6, 4, 3]
+ f_shape = [3, 3, 2, 3]
+ y_shape = [2, 12, 8, 2]
+ strides = [1, 2, 2, 1]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ f_val = np.random.random_sample(f_shape).astype(np.float64)
+ with self.test_session():
+ x = constant_op.constant(x_val, name="x", dtype=types.float32)
+ f = constant_op.constant(f_val, name="f", dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ err = gc.ComputeGradientError([x, f], [x_shape, f_shape], output, y_shape)
+ print "DeConv gradient err = %g " % err
+ err_tolerance = 0.0005
+ self.assertLess(err, err_tolerance)
+
+
+class L2LossTest(test_util.TensorFlowTestCase):
+
+ def testL2Loss(self):
+ with self.test_session():
+ x = constant_op.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x")
+ l2loss = nn.l2_loss(x)
+ value = l2loss.eval()
+ self.assertAllClose(7.0, value)
+
+ def testGradient(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ with self.test_session():
+ x = constant_op.constant(x_val, name="x")
+ output = nn.l2_loss(x)
+ err = gc.ComputeGradientError(x, x_shape, output, [1])
+ print "L2Loss gradient err = %g " % err
+ err_tolerance = 1e-11
+ self.assertLess(err, err_tolerance)
+
+
+class L2NormalizeTest(test_util.TensorFlowTestCase):
+
+ def _l2Normalize(self, x, dim):
+ norm = np.apply_along_axis(np.linalg.norm, dim, x)
+ return x / np.expand_dims(norm, dim)
+
+ def testL2Normalize(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float32)
+ for dim in range(len(x_shape)):
+ y_np = self._l2Normalize(x_np, dim)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, name="x")
+ y_tf = nn.l2_normalize(x_tf, dim)
+ self.assertAllClose(y_np, y_tf.eval())
+
+ def testL2NormalizeGradient(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float64)
+ for dim in range(len(x_shape)):
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, name="x")
+ y_tf = nn.l2_normalize(x_tf, dim)
+ err = gc.ComputeGradientError(x_tf, x_shape, y_tf, x_shape)
+ print "L2Normalize gradient err = %g " % err
+ self.assertLess(err, 1e-4)
+
+
+class DropoutTest(test_util.TensorFlowTestCase):
+
+ def testDropout(self):
+    # Runs dropout on a 0-1 tensor 10 times, sums the number of ones, and
+    # validates that it produces approximately the right number of ones over a
+    # large number of samples, based on the keep probability.
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ dropout = nn.dropout(t, keep_prob)
+ final_count = 0
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = dropout.eval()
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print rel_error
+ self.assertTrue(rel_error < 0.15)
+
+ def testShapedDropout(self):
+    # Runs dropout on a 0-1 tensor 10 times, sums the number of ones, and
+    # validates that it produces approximately the right number of ones over a
+    # large number of samples, based on the keep probability. This time the
+    # noise is shaped.
+ x_dim = 40 * 30
+ y_dim = 3
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ final_count = 0
+ for _ in xrange(0, num_iter):
+ value = dropout.eval()
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print rel_error
+ self.assertTrue(rel_error < 0.15)
+
+ def testShapedDropoutCorrelation(self):
+ # Runs a shaped dropout and tests that the correlations are correct.
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = dropout.eval()
+          # Verifies that each y column has only one type of activation.
+ for i in xrange(x_dim):
+ sorted_value = np.unique(np.sort(value[i, :]))
+ self.assertEqual(sorted_value.size, 1)
+
+ def testShapedDropoutShapeError(self):
+ # Runs shaped dropout and verifies an error is thrown on misshapen noise.
+ x_dim = 40
+ y_dim = 30
+ keep_prob = 0.5
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
+ # test that broadcasting proceeds
+ _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
+ _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])
+
+
+class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
+
+ def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
+ scale_after_normalization):
+ y = (x - m) / np.sqrt(v + epsilon)
+ y = y * gamma if scale_after_normalization else y
+ y += beta
+ return y
+
+ def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
+ scale_after_normalization):
+ y = (x - m) * math_ops.rsqrt(v + epsilon)
+ if scale_after_normalization:
+ y = gamma * y
+ y += beta
+ return y
+
+ def testBatchNorm(self):
+ x_shape = [3, 5, 4, 2]
+ param_shape = [2]
+ x_val = np.random.random_sample(x_shape).astype(np.float32)
+ m_val = np.random.random_sample(param_shape).astype(np.float32)
+ v_val = np.random.random_sample(param_shape).astype(np.float32)
+ beta_val = np.random.random_sample(param_shape).astype(np.float32)
+ gamma_val = np.random.random_sample(param_shape).astype(np.float32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ x = constant_op.constant(x_val, name="x")
+ m = constant_op.constant(m_val, name="m")
+ v = constant_op.constant(v_val, name="v")
+ beta = constant_op.constant(beta_val, name="beta")
+ gamma = constant_op.constant(gamma_val, name="gamma")
+ epsilon = 0.001
+ for scale_after_normalization in [True, False]:
+ bn = nn.batch_norm_with_global_normalization(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ on = self._opsBatchNorm(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ np_batch_norm = self._npBatchNorm(
+ x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+ scale_after_normalization)
+ tf_batch_norm, ops_batch_norm = sess.run([bn, on])
+ self.assertAllClose(np_batch_norm, tf_batch_norm, atol=0.000001)
+ self.assertAllClose(np_batch_norm, ops_batch_norm, atol=0.000001)
+ self.assertAllClose(tf_batch_norm, ops_batch_norm, atol=0.000001)
+
+ def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
+ err_tolerance=1e-11):
+ x_shape = [3, 5, 4, 5]
+ param_shape = [5]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ m_val = np.random.random_sample(param_shape).astype(np.float64)
+ v_val = np.random.random_sample(param_shape).astype(np.float64)
+ beta_val = np.random.random_sample(param_shape).astype(np.float64)
+ gamma_val = np.random.random_sample(param_shape).astype(np.float64)
+ with self.test_session():
+ x = constant_op.constant(x_val, name="x")
+ m = constant_op.constant(m_val, name="m")
+ v = constant_op.constant(v_val, name="v")
+ beta = constant_op.constant(beta_val, name="beta")
+ gamma = constant_op.constant(gamma_val, name="gamma")
+ epsilon = 0.001
+ # If scale_after_normalization is False, backprop for gamma
+ # will be 0. gamma is unchanged.
+ output = nn.batch_norm_with_global_normalization(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ all_params = [x, m, v, beta, gamma]
+ all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
+ err = gc.ComputeGradientError(all_params[param_index],
+ all_shapes[param_index], output, x_shape)
+ print "Batch normalization %s gradient %s scale err = " % (
+ tag, "with" if scale_after_normalization else "without"
+ ), err
+ self.assertLess(err, err_tolerance)
+
+ def testBatchNormInputGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(0, "x", scale_after_normalization)
+
+ def testBatchNormMeanGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(1, "mean", scale_after_normalization)
+
+ def testBatchNormVarianceGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(2, "variance", scale_after_normalization,
+ err_tolerance=1e-03)
+
+ def testBatchNormBetaGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(3, "beta", scale_after_normalization)
+
+ def testBatchNormGammaGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(4, "gamma", scale_after_normalization)
+
+ def testBatchNormGradImpl(self):
+ x_shape = [7, 5, 4, 6]
+ param_shape = [6]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float32)
+ m_val = np.random.random_sample(param_shape).astype(np.float32)
+ v_val = np.random.random_sample(param_shape).astype(np.float32)
+ beta_val = np.random.random_sample(param_shape).astype(np.float32)
+ gamma_val = np.random.random_sample(param_shape).astype(np.float32)
+ backprop_val = np.random.random_sample(x_shape).astype(np.float32)
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ x = constant_op.constant(x_val, name="x")
+ m = constant_op.constant(m_val, name="m")
+ v = constant_op.constant(v_val, name="v")
+ beta = constant_op.constant(beta_val, name="beta")
+ gamma = constant_op.constant(gamma_val, name="gamma")
+ backprop = constant_op.constant(backprop_val, name="backprop")
+ epsilon = 0.001
+ for scale_after_normalization in [True, False]:
+ dx, dm, dv, db, dg = (
+ gen_nn_ops._batch_norm_with_global_normalization_grad(
+ x, m, v, gamma, backprop, epsilon, scale_after_normalization))
+ on = self._opsBatchNorm(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ odx, odm, odv, odb, odg = gradients.gradients(
+ [on], [x, m, v, beta, gamma], [backprop])
+ if scale_after_normalization:
+ all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
+ to_check = ["dx", "dm", "dv", "db", "dg"]
+ else:
+ all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
+ to_check = ["dx", "dm", "dv", "db"]
+ for i, n in enumerate(to_check):
+ print n
+ self.assertAllClose(
+ all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
+
+
+class MomentsTest(test_util.TensorFlowTestCase):
+
+ def RunMomentTest(self, shape, global_norm):
+ with self.test_session():
+      # shape = [batch, height, width, depth]
+ assert len(shape) == 4
+
+ x_numpy = np.random.normal(size=shape).astype(np.float32)
+ x = constant_op.constant(x_numpy)
+ x.set_shape(shape)
+ axes = [0, 1, 2] if global_norm else [0]
+ mean, var = nn.moments(x, axes)
+
+ num_elements = np.prod([shape[i] for i in axes])
+
+ ax = (0, 1, 2) if global_norm else (0)
+ expected_mean = np.sum(x_numpy, axis=ax) / num_elements
+ expected_mean_squared = np.multiply(expected_mean, expected_mean)
+ expected_x_squared = np.sum(
+ np.multiply(x_numpy, x_numpy), axis=ax) / num_elements
+ expected_variance = expected_x_squared - expected_mean_squared
+
+ # Check that the moments are correct.
+ self.assertAllClose(expected_mean, mean.eval())
+ self.assertAllClose(expected_variance, var.eval())
+
+ def testBasic(self):
+ self.RunMomentTest(shape=[2, 3, 5, 4], global_norm=False)
+
+ def testGlobalNormalization(self):
+ self.RunMomentTest(shape=[2, 3, 5, 4], global_norm=True)
+
+ def _testGlobalGradient(self, from_y="mean"):
+ with self.test_session():
+ x_shape = [3, 5, 4, 2]
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ x = constant_op.constant(x_val)
+ x.set_shape(x_shape)
+
+ axes = [0, 1, 2]
+ y_shape = [2] # Depth of x
+ out_mean, out_var = nn.moments(x, axes)
+ if from_y == "mean":
+ y = out_mean
+ elif from_y == "var":
+ y = out_var
+ err = gc.ComputeGradientError(x, x_shape, y, y_shape)
+ print "Moments %s gradient err = %g" % (from_y, err)
+ self.assertLess(err, 1e-11)
+
+ def testMeanGlobalGradient(self):
+ self._testGlobalGradient(from_y="mean")
+
+ def testVarGlobalGradient(self):
+ self._testGlobalGradient(from_y="var")
+
+
+class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._num_classes = 5
+ self._dim = 10
+ self._batch_size = 3
+
+ def _GenerateTestInputs(self):
+ np.random.seed(0)
+ weights = np.random.randn(self._num_classes, self._dim).astype(np.float32)
+ biases = np.random.randn(self._num_classes).astype(np.float32)
+ hidden_acts = np.random.randn(self._batch_size, self._dim).astype(
+ np.float32)
+
+ return weights, biases, hidden_acts
+
+ def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b,
+ hidden_acts,
+ num_true=1,
+ true_expected=None,
+ sampled_expected=None):
+
+ batch_size, dim = hidden_acts.shape
+ true_logits = np.sum(
+ hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape(
+ (batch_size, num_true, dim)),
+ axis=2)
+ true_b = true_b.reshape((batch_size, num_true))
+ true_logits += true_b
+ sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b
+
+ if true_expected is not None:
+ true_logits -= np.log(true_expected)
+ if sampled_expected is not None:
+ sampled_logits -= np.log(sampled_expected[np.newaxis, :])
+
+ out_logits = np.concatenate([true_logits, sampled_logits], axis=1)
+ out_labels = np.hstack((np.ones_like(true_logits) / num_true,
+ np.zeros_like(sampled_logits)))
+
+ return out_logits, out_labels
+
+ def _ComputeSampledLogitsTF(self, weights, biases, hidden_acts, labels,
+ num_sampled, num_classes, num_true, sampled_vals,
+ subtract_log_q, remove_accidental_hits,
+ name="sampled_loss_TF"):
+ # Should be called from within a `with test_session():` block
+ weights_tf = constant_op.constant(weights)
+ biases_tf = constant_op.constant(biases)
+ hidden_acts_tf = constant_op.constant(hidden_acts,
+ shape=(self._batch_size, self._dim))
+ labels_tf = constant_op.constant(labels, dtype=types.int64,
+ shape=(self._batch_size, num_true))
+
+ pred_logits_tf, pred_labels_tf = nn._compute_sampled_logits(
+ weights_tf, biases_tf, hidden_acts_tf, labels_tf, num_sampled,
+ num_classes, num_true, sampled_vals,
+ subtract_log_q=subtract_log_q,
+ remove_accidental_hits=remove_accidental_hits,
+ name=name)
+ return pred_logits_tf, pred_labels_tf
+
+ def testComputeSampledLogitsShapes(self):
+ # We just check that the shapes of the returned values are correct.
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = sampled_exp = [1., 1., 1., 1.]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+
+ with self.test_session() as sess:
+ for num_true_test in range(1, 5):
+ labels = np.random.randint(low=0, high=self._num_classes,
+ size=self._batch_size * num_true_test)
+ true_w, true_b = weights[labels], biases[labels]
+
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ num_true=num_true_test)
+
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ remove_accidental_hits=True,
+ subtract_log_q=False)
+
+ logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
+ self.assertEqual(logits_np.shape, logits_tf_val.shape)
+ self.assertEqual(labels_np.shape, labels_tf_val.shape)
+
+ def testComputeSampledLogitsValues(self):
+ # Here we check the actual numerics.
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ eps = 1e-3
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
+ true_exp.fill(0.5)
+ sampled_exp = np.empty([num_sampled], dtype=np.float32)
+ sampled_exp.fill(0.5)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+
+ with self.test_session() as sess:
+ for num_true_test in range(1, 5):
+ # Generate test data for this run
+ labels = np.random.randint(low=0, high=self._num_classes,
+ size=self._batch_size * num_true_test)
+ true_w, true_b = weights[labels], biases[labels]
+
+ # Test 1: Without accidental hit removal or subtract_log_q
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ num_true=num_true_test)
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ name="sampled_loss_test1_num_true%d" % num_true_test)
+
+ logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
+ self.assertAllClose(logits_np, logits_tf_val, eps)
+ self.assertAllClose(labels_np, labels_tf_val, eps)
+
+ # Test 2: With accidental hit removal, no subtract_log_q
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=True,
+ name="sampled_loss_test2_num_true%d" % num_true_test)
+
+ # Test that the exponentiated logits of accidental hits are near 0.
+ # First we need to find the hits in this random test run:
+ labels_reshape = labels.reshape((self._batch_size, num_true_test))
+ logits_tf_np = logits_tf.eval()
+ for row in xrange(self._batch_size):
+ row_labels = labels_reshape[row, :]
+ for col in xrange(num_sampled):
+ if sampled[col] in row_labels:
+ # We need to add the num_true_test offset into logits_*
+ self.assertNear(
+ np.exp(logits_tf_np[row, col + num_true_test]), 0., eps)
+
+ # Test 3: With subtract_log_q, no accidental hit removal
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ num_true=num_true_test,
+ true_expected=true_exp,
+ sampled_expected=sampled_exp)
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ subtract_log_q=True,
+ remove_accidental_hits=False,
+ name="sampled_loss_test3_num_true%d" % num_true_test)
+
+ logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
+ self.assertAllClose(logits_np, logits_tf_val, eps)
+ self.assertAllClose(labels_np, labels_tf_val, eps)
+
+ def testNCELoss(self):
+ # A simple test to verify the numerics.
+
+ def _SigmoidCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ pred = 1. / (1. + np.exp(-logits))
+ eps = 0.0001
+ pred = np.minimum(np.maximum(pred, eps), 1 - eps)
+ return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)
+
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ labels = [0, 1, 2]
+ true_w, true_b = weights[labels], biases[labels]
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
+ true_exp.fill(0.5)
+ sampled_exp = np.empty([num_sampled], dtype=np.float32)
+ sampled_exp.fill(0.5)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+
+ with self.test_session():
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ true_expected=true_exp,
+ sampled_expected=sampled_exp)
+ nce_loss_np = np.sum(
+ _SigmoidCrossEntropyWithLogits(logits_np, labels_np), 1)
+
+ labels_tf = constant_op.constant(labels, shape=(self._batch_size, 1))
+ weights_tf = constant_op.constant(weights)
+ biases_tf = constant_op.constant(biases)
+ inputs_tf = constant_op.constant(hidden_acts)
+
+ nce_loss_tf = nn.nce_loss(
+ weights_tf, biases_tf, inputs_tf, labels_tf,
+ num_sampled=1,
+ num_classes=self._num_classes,
+ num_true=1,
+ sampled_values=test_sampled_vals)
+
+ self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)
+
+ def testSampledSoftmaxLoss(self):
+ # A simple test to verify the numerics.
+
+ def _SoftmaxCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ stable_exp_logits = np.exp(logits - np.amax(
+ logits, axis=1, keepdims=True))
+ pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
+ return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
+
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ labels = [0, 1, 2]
+ true_w, true_b = weights[labels], biases[labels]
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = np.full([self._batch_size, 1], fill_value=0.5, dtype=np.float32)
+ sampled_exp = np.full([num_sampled], fill_value=0.5, dtype=np.float32)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+
+ with self.test_session():
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ true_expected=true_exp,
+ sampled_expected=sampled_exp)
+ sampled_softmax_loss_np = _SoftmaxCrossEntropyWithLogits(logits_np,
+ labels_np)
+
+ labels_tf = constant_op.constant(labels, shape=(self._batch_size, 1))
+ weights_tf = constant_op.constant(weights)
+ biases_tf = constant_op.constant(biases)
+ inputs_tf = constant_op.constant(hidden_acts)
+
+ sampled_softmax_loss_tf = nn.sampled_softmax_loss(
+ weights_tf, biases_tf, inputs_tf, labels_tf,
+ num_sampled=1,
+ num_classes=self._num_classes,
+ num_true=1,
+ sampled_values=test_sampled_vals,
+ remove_accidental_hits=False)
+
+ self.assertAllClose(
+ sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
new file mode 100644
index 0000000000..93f5d5db20
--- /dev/null
+++ b/tensorflow/python/ops/numerics.py
@@ -0,0 +1,50 @@
+"""Connects all float and double tensors to CheckNumericsOp."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+
+
+def verify_tensor_all_finite(t, msg, name=None):
+ """Assert that the tensor does not contain any NaN's or Inf's.
+
+ Args:
+ t: Tensor to check.
+ msg: Message to log on failure.
+ name: A name for this operation (optional).
+
+ Returns:
+ Same tensor as `t`.
+ """
+ with ops.op_scope([t], name, "VerifyFinite") as name:
+ t = ops.convert_to_tensor(t, name="t")
+ with ops.device(t.device or t.graph.get_default_device()):
+ verify_input = array_ops.check_numerics(t, message=msg)
+ out = control_flow_ops.with_dependencies([verify_input], t)
+ return out
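+
+# A minimal usage sketch for verify_tensor_all_finite (assuming an active
+# default session):
+#
+#   t = ops.convert_to_tensor([1.0, 2.0, 3.0])
+#   checked = verify_tensor_all_finite(t, "t contains NaN or Inf")
+#   # checked.eval() returns the same values as t, but evaluation fails with
+#   # an InvalidArgument error if t contains a NaN or Inf.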
+
+
+def add_check_numerics_ops():
+ """Connect a check_numerics to every floating point tensor.
+
+ `check_numerics` operations themselves are added for each `float` or `double`
+ tensor in the graph. For all ops in the graph, the `check_numerics` op for
+ all of its (`float` or `double`) inputs is guaranteed to run before the
+ `check_numerics` op on any of its outputs.
+
+ Returns:
+ A `group` op depending on all `check_numerics` ops added.
+ """
+ check_op = []
+ # This code relies on the ordering of ops in get_operations().
+  # The producer of a tensor always comes before that tensor's consumer in
+  # this list. This is true because get_operations() returns ops in the order
+  # added, and an op can only be added after its inputs are added.
+ for op in ops.get_default_graph().get_operations():
+ for output in op.outputs:
+ if output.dtype in [types.float32, types.float64]:
+ message = op.name + ":" + str(output.value_index)
+ with ops.control_dependencies(check_op):
+ check_op = [array_ops.check_numerics(output, message=message)]
+ return control_flow_ops.group(*check_op)
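+
+# A minimal usage sketch (`train_op` and the session are hypothetical):
+#
+#   loss = ...                        # build the model first
+#   check = add_check_numerics_ops()  # must be called after the graph is built
+#   sess.run([check, train_op])       # fails as soon as any float tensor
+#                                     # produces a NaN or Inf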
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
new file mode 100644
index 0000000000..5947b6df89
--- /dev/null
+++ b/tensorflow/python/ops/op_def_library.py
@@ -0,0 +1,640 @@
+"""Class to hold a library of OpDefs and use it to create Brain operations."""
+
+import numbers
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.core.framework import tensor_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types as types_lib
+from tensorflow.python.ops import constant_op
+from tensorflow.python.platform import logging
+
+
+def _Attr(op_def, name):
+ for attr in op_def.attr:
+ if attr.name == name:
+ return attr
+ raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
+ (op_def.name, name))
+
+
+def _AttrValue(attr_protos, name):
+ if name in attr_protos:
+ return attr_protos[name]
+ raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
+ (name, attr_protos))
+
+
+def _SatisfiesTypeConstraint(dtype, attr_def):
+ if attr_def.HasField("allowed_values"):
+ allowed_list = attr_def.allowed_values.list.type
+ if dtype not in allowed_list:
+ raise TypeError(
+ "DataType %s for attr '%s' not in list of allowed values: %s" %
+ (types_lib.as_dtype(dtype).name, attr_def.name,
+ ", ".join(types_lib.as_dtype(x).name for x in allowed_list)))
+
+
+def _IsListParameter(arg):
+ if arg.number_attr:
+ return True
+ elif arg.type_list_attr:
+ return True
+ return False
+
+
+def _NumTypeFields(arg):
+ num = 0
+ if arg.type != types_pb2.DT_INVALID: num += 1
+ if arg.type_attr: num += 1
+ if arg.type_list_attr: num += 1
+ return num
+
+
+def _IsListValue(v):
+ return isinstance(v, (list, tuple))
+
+
+def _Flatten(l):
+ """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
+ # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
+ l_of_l = [x if _IsListValue(x) else [x] for x in l]
+ # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
+ return [item for sublist in l_of_l for item in sublist]
+
+
+def _Restructure(l, structure):
+ """Returns the elements of list l structured according to the given structure.
+
+ A structure is represented by a list whose elements are either
+ `None` or a non-negative integer. `None` corresponds to a single
+ element in the output list, and an integer N corresponds to a nested
+ list of length N.
+
+ The function returns a data structure whose shape is given by
+ `structure`, and whose elements are taken from `l`. If `structure`
+ is a singleton, the function returns the single data structure
+ implied by the 0th element of `structure`. For example:
+
+ _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
+ -> ["foo", ["bar", "baz"], "qux"]
+
+ _Restructure(["foo"], [None]) -> "foo"
+
+ _Restructure(["foo"], [1]) -> ["foo"]
+
+ _Restructure([], [0]) -> []
+
+ Args:
+ l: A list.
+ structure: A list whose elements are either `None` or a non-negative
+ integer.
+
+ Returns:
+ The elements of `l`, restructured according to `structure`. If
+ `structure` is a list of length 1, this function returns the
+ single data structure implied by `structure[0]`.
+
+ """
+ result = []
+ current_index = 0
+ for element in structure:
+ if element is None:
+ result.append(l[current_index])
+ current_index += 1
+ else:
+ result.append(l[current_index:current_index+element])
+ current_index += element
+
+ if len(result) == 1:
+ return result[0]
+ else:
+ return tuple(result)
+
+
+def _MakeFloat(v, arg_name):
+ if not isinstance(v, numbers.Real):
+ raise TypeError("Expected float for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return float(v)
+
+
+def _MakeInt(v, arg_name):
+ if isinstance(v, basestring):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ try:
+ return int(v)
+ except (ValueError, TypeError):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+
+
+def _MakeStr(v, arg_name):
+ if not isinstance(v, basestring):
+ raise TypeError("Expected string for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return str(v) # Convert unicode strings to bytes.
+
+
+def _MakeBool(v, arg_name):
+ if not isinstance(v, bool):
+ raise TypeError("Expected bool for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return v
+
+
+def _MakeType(v, attr_def):
+ try:
+ v = types_lib.as_dtype(v)
+ except TypeError:
+ raise TypeError("Expected DataType for argument '%s' not %s." %
+ (attr_def.name, repr(v)))
+ i = v.as_datatype_enum
+ _SatisfiesTypeConstraint(i, attr_def)
+ return i
+
+
+def _MakeShape(v, arg_name):
+ """Convert v into a TensorShapeProto."""
+ # Args:
+ # v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
+ # arg_name: String, for error messages.
+
+ # Returns:
+ # A TensorShapeProto.
+ if isinstance(v, tensor_shape_pb2.TensorShapeProto):
+ for d in v.dim:
+ if d.name:
+ logging.warning("Warning: TensorShapeProto with a named dimension: %s",
+ str(v))
+ break
+ return v
+ s = tensor_shape.as_shape(v)
+ ret = tensor_shape_pb2.TensorShapeProto()
+ for i in s.as_dimension_list():
+ ret.dim.add(size = i)
+ return ret
+
+
+def _MakeTensor(v, arg_name):
+ """Ensure v is a TensorProto."""
+ if isinstance(v, tensor_pb2.TensorProto):
+ return v
+ raise TypeError(
+ "Don't know how to convert %s to a TensorProto for argument '%s'" %
+ (repr(v), arg_name))
+
+
+class _OpInfo(object):
+ """All per-Op state we would like to precompute/validate."""
+
+ def __init__(self, op_def):
+ self.op_def = op_def
+ # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
+ # here, instead of these checks.
+ for arg in list(op_def.input_arg) + list(op_def.output_arg):
+ num_type_fields = _NumTypeFields(arg)
+ if num_type_fields != 1:
+ raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
+ (arg.name, op_def.name, num_type_fields))
+ if arg.type_attr:
+ attr_type = _Attr(op_def, arg.type_attr).type
+ if attr_type != "type":
+ raise TypeError("Attr '%s' of '%s' used as a type_attr "
+ "but has type %s" %
+ (arg.type_attr, op_def.name, attr_type))
+ if arg.type_list_attr:
+ attr_type = _Attr(op_def, arg.type_list_attr).type
+ if attr_type != "list(type)":
+ raise TypeError(
+ "Attr '%s' of '%s' used as a type_list_attr but has type %s" %
+            (arg.type_list_attr, op_def.name, attr_type))
+ if arg.number_attr:
+ attr_type = _Attr(op_def, arg.number_attr).type
+ if attr_type != "int":
+ raise TypeError(
+ "Attr '%s' of '%s' used as a number_attr but has type %s" %
+ (arg.number_attr, op_def.name, attr_type))
+
+
+class OpDefLibrary(object):
+ """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""
+
+ def __init__(self):
+ self._ops = {}
+
+ def add_op(self, op_def):
+ """Register an OpDef. May call apply_op with the name afterwards."""
+ if not isinstance(op_def, op_def_pb2.OpDef):
+ raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
+ (op_def, type(op_def)))
+ if not op_def.name:
+ raise ValueError("%s missing name." % op_def)
+ if op_def.name in self._ops:
+ raise RuntimeError("Op name %s registered twice." % op_def.name)
+ self._ops[op_def.name] = _OpInfo(op_def)
+
+ def add_op_list(self, op_list):
+ """Register the OpDefs from an OpList."""
+ if not isinstance(op_list, op_def_pb2.OpList):
+ raise TypeError("%s is %s, not an op_def_pb2.OpList" %
+ (op_list, type(op_list)))
+ for op_def in op_list.op:
+ self.add_op(op_def)
+
+ def apply_op(self, op_type_name, g=None, name=None, **keywords):
+ # pylint: disable=g-doc-args
+ """Add a node invoking a registered Op to a graph.
+
+ Config proto extensions must be provided via the 'ext' keyword argument.
+ Example usage:
+ # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
+ # will convert to a Tensor.
+ op_def_library.apply_op("op", input1=input1, input2=input2)
+ # If none of the inputs are Tensors and your session doesn't have a
+ # default graph, you will have to specify the graph.
+ op_def_library.apply_op("op", input1=input1, g=g)
+ # Can specify a node name.
+ op_def_library.apply_op("op", input1=input1, name="node_name")
+ # Must use keyword arguments, with the names specified in the OpDef.
+ op_def_library.apply_op("op", input_name=input, attr_name=attr)
+
+ All attrs must either be inferred from an input or specified.
+ (If inferred, the attr must not be specified.) If an attr has a default
+ value specified in the Op's OpDef, then you may pass None as the value
+ of that attr to get the default.
+
+ Args:
+ op_type_name: string. Must match the name field of a registered Op.
+ g: The graph context (optional)
+ name: string. Optional name of the created op.
+ **keywords: input Tensor and attr arguments specified by name,
+ and optional parameters to pass when constructing the Operation.
+
+ Returns:
+ The Tensor(s) representing the output of the operation, or the Operation
+ itself if there are no outputs.
+
+ Raises:
+ RuntimeError: On some errors.
+ TypeError: On some errors.
+ ValueError: On some errors.
+ """
+ op_info = self._ops.get(op_type_name, None)
+ if op_info is None:
+ raise RuntimeError("Unrecognized Op name " + op_type_name)
+ op_def = op_info.op_def
+
+ # Determine the graph context.
+ try:
+ # Need to flatten all the arguments into a list.
+ # pylint: disable=protected-access
+ g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
+      # pylint: enable=protected-access
+ except AssertionError as e:
+ raise RuntimeError(
+ "Need to specify g=graph to Op '%s' (could not determine graph due "
+ "to: %s)" % (op_type_name, e.message))
+
+ # Default name if not specified.
+ if name is None:
+ name = op_type_name
+
+ # Requires that op_def has passed validation (using the C++
+ # ValidateOpDef() from ../framework/op_def_util.h).
+ attrs = {}
+ inputs = []
+ input_types = []
+ with g.as_default(), ops.name_scope(name) as scope:
+
+ # Perform input type inference
+ inferred_from = {}
+ for input_arg in op_def.input_arg:
+ input_name = input_arg.name
+ if input_name in keywords:
+ values = keywords.pop(input_name)
+ elif input_name + "_" in keywords:
+ # Handle the case where the name is a keyword or built-in
+ # for Python so we use the name + _ instead.
+ input_name += "_"
+ values = keywords.pop(input_name)
+ else:
+ raise TypeError("No argument for input " + input_name)
+
+ # Goals:
+ # * Convert values to Tensors if it contains constants.
+ # * Verify that values is a list if that matches the input_arg's
+ # type.
+ # * If the input_arg's type is determined by attrs, either set
+ # those attrs and validate those attr values are legal (if
+ # they have not yet been set) or validate the input matches
+ # the type indicated by the attrs (if they have already been
+ # inferred via an earlier input).
+ # * If the input_arg has an explicit type, make sure the input
+ # conforms.
+
+ if _IsListParameter(input_arg):
+ if not _IsListValue(values):
+ raise TypeError(
+ "Expected list for '%s' argument to '%s' Op, not %s." %
+ (input_name, op_type_name, values))
+ # In cases where we expect all elements of the list to have the
+ # same dtype, try to cast non-Tensor elements to that type.
+ dtype = None
+ if input_arg.type != types_pb2.DT_INVALID:
+ dtype = input_arg.type
+ elif input_arg.number_attr:
+ if input_arg.type_attr in attrs:
+ dtype = attrs[input_arg.type_attr]
+ else:
+ for t in values:
+ if isinstance(t, ops.Tensor):
+ dtype = t.dtype
+ break
+
+ try:
+ values = ops.convert_n_to_tensor_or_indexed_slices(
+ values, name=input_arg.name,
+ dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
+ except (TypeError, ValueError):
+ assert dtype is not None, "Should not fail if dtype is None"
+ assert input_arg.number_attr, "Should be number_attr case"
+ # What types does the conversion function think values have?
+ values = ops.convert_n_to_tensor_or_indexed_slices(values)
+ observed = ", ".join(v.dtype.base_dtype.name for v in values)
+
+ prefix = (
+ "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
+ (input_name, op_type_name, observed))
+ if input_arg.type != types_pb2.DT_INVALID:
+ raise TypeError("%s that do not match expected type %s." %
+ (prefix, types_lib.as_dtype(dtype).name))
+ elif input_arg.type_attr in attrs:
+ raise TypeError("%s that do not match type %s inferred from "
+ "earlier arguments." %
+ (prefix, types_lib.as_dtype(dtype).name))
+ else:
+ raise TypeError("%s that don't all match." % prefix)
+
+ types = [x.dtype for x in values]
+ inputs.extend(values)
+ else:
+ # In cases where we have an expected type, try to convert non-Tensor
+ # arguments to that type.
+ dtype = None
+ if input_arg.type != types_pb2.DT_INVALID:
+ dtype = input_arg.type
+ elif input_arg.type_attr in attrs:
+ dtype = attrs[input_arg.type_attr]
+
+ try:
+ values = ops.convert_to_tensor(
+ values, name=input_arg.name, dtype=dtype)
+ except ValueError:
+ # What type does convert_to_tensor think it has?
+ observed = ops.convert_to_tensor(values).dtype.name
+ prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
+ (input_name, op_type_name, observed))
+ if input_arg.type != types_pb2.DT_INVALID:
+ raise TypeError("%s expected type of %s." %
+ (prefix, types_lib.as_dtype(input_arg.type).name))
+ else:
+ raise TypeError(
+ "%s type %s of argument '%s'." %
+ (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name,
+ inferred_from[input_arg.type_attr]))
+
+ types = [values.dtype]
+ inputs.append(values)
+ base_types = [x.base_dtype for x in types]
+
+ if input_arg.number_attr:
+ # <number-attr> * <type> or <number-attr> * <type-attr>
+ if input_arg.number_attr in attrs:
+ if len(values) != attrs[input_arg.number_attr]:
+ raise ValueError(
+ "List argument '%s' to '%s' Op with length %d must match "
+ "length %d of argument '%s'." %
+ (input_name, op_type_name, len(values),
+ attrs[input_arg.number_attr],
+ inferred_from[input_arg.number_attr]))
+ else:
+ attrs[input_arg.number_attr] = len(values)
+ inferred_from[input_arg.number_attr] = input_name
+ num_attr = _Attr(op_def, input_arg.number_attr)
+ if num_attr.has_minimum and len(values) < num_attr.minimum:
+ raise ValueError(
+ "List argument '%s' to '%s' Op with length %d shorter "
+ "than minimum length %d." %
+ (input_name, op_type_name, len(values), num_attr.minimum))
+ # All tensors must have the same base type.
+ if any([bt != base_types[0] for bt in base_types]):
+ raise TypeError(
+ "All tensors passed to '%s' of '%s' Op "
+ "must have the same type." %
+ (input_name, op_type_name))
+ if input_arg.type != types_pb2.DT_INVALID:
+ # <number-attr> * <type> case
+ if base_types and base_types[0] != input_arg.type:
+ assert False, "Unreachable"
+ elif input_arg.type_attr in attrs:
+ # <number-attr> * <type-attr> case, where <type-attr> already
+ # has an inferred value.
+ if base_types and base_types[0] != attrs[input_arg.type_attr]:
+ assert False, "Unreachable"
+ else:
+ # <number-attr> * <type-attr> case, where we are now setting
+ # the <type-attr> based on this input
+ if not base_types:
+ raise TypeError(
+ "Don't know how to infer type variable from empty input "
+ "list passed to input '%s' of '%s' Op." %
+ (input_name, op_type_name))
+ attrs[input_arg.type_attr] = base_types[0]
+ inferred_from[input_arg.type_attr] = input_name
+ type_attr = _Attr(op_def, input_arg.type_attr)
+ _SatisfiesTypeConstraint(base_types[0], type_attr)
+ elif input_arg.type_attr:
+ # <type-attr>
+ attr_value = base_types[0]
+ if input_arg.type_attr in attrs:
+ if attrs[input_arg.type_attr] != attr_value:
+ assert False, "Unreachable"
+ else:
+ for base_type in base_types:
+ _SatisfiesTypeConstraint(base_type,
+ _Attr(op_def, input_arg.type_attr))
+ attrs[input_arg.type_attr] = attr_value
+ inferred_from[input_arg.type_attr] = input_name
+ elif input_arg.type_list_attr:
+ # <type-list-attr>
+ attr_value = base_types
+ if input_arg.type_list_attr in attrs:
+ if attrs[input_arg.type_list_attr] != attr_value:
+ raise TypeError(
+ "Input '%s' of '%s' Op has type list of %s that does not "
+ "match type list %s of argument '%s'." %
+ (input_name, op_type_name,
+ ", ".join(types_lib.as_dtype(x).name for x in attr_value),
+ ", ".join(types_lib.as_dtype(x).name
+ for x in attrs[input_arg.type_list_attr]),
+ inferred_from[input_arg.type_list_attr]))
+ else:
+ for base_type in base_types:
+ _SatisfiesTypeConstraint(base_type,
+ _Attr(op_def, input_arg.type_list_attr))
+ attrs[input_arg.type_list_attr] = attr_value
+ inferred_from[input_arg.type_list_attr] = input_name
+ else:
+ # single Tensor with specified type
+ if base_types[0] != input_arg.type:
+ assert False, "Unreachable"
+
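+      # Ref inputs (e.g. variables) must be fed ref-typed tensors; for
+      # non-ref inputs only the base dtypes are recorded, so ref tensors
+      # are accepted there as well.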
+ if input_arg.is_ref:
+ if not all(x.is_ref_dtype for x in types):
+ raise TypeError(
+ "Input '%s' of '%s' Op requires l-value input" %
+ (input_name, op_type_name))
+ input_types.extend(types)
+ else:
+ input_types.extend(base_types)
+
+ # Process remaining attrs
+ for attr in op_def.attr:
+ # Skip attrs that have already had their values inferred
+ if attr.name in attrs:
+ if attr.name in keywords:
+ raise TypeError(
+ "Should not specify value for inferred attr '%s'." % attr.name)
+ continue
+ if attr.name in keywords:
+ attrs[attr.name] = keywords.pop(attr.name)
+ elif attr.name + "_" in keywords:
+ # Attrs whose names match Python keywords have an extra '_'
+ # appended, so we must check for that as well.
+ attrs[attr.name] = keywords.pop(attr.name + "_")
+ else:
+ raise TypeError("No argument for attr " + attr.name)
+
+ # Convert attr values to AttrValue protos.
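+    # Note: an explicit None for an attr that has a default falls back to
+    # that default value.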
+ attr_protos = {}
+ for attr_def in op_def.attr:
+ key = attr_def.name
+ value = attrs[key]
+ attr_value = attr_value_pb2.AttrValue()
+ if attr_def.HasField("default_value") and value is None:
+ attr_value.CopyFrom(attr_def.default_value)
+ attr_protos[key] = attr_value
+ continue
+ if attr_def.type.startswith("list("):
+ if not _IsListValue(value):
+ raise TypeError("Expected list for attr " + key)
+ if attr_def.has_minimum:
+ if len(value) < attr_def.minimum:
+ raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
+ "less than minimum %d." %
+ (key, op_type_name, len(value),
+ attr_def.minimum))
+ if attr_def.type == "string":
+ attr_value.s = _MakeStr(value, key)
+ if attr_def.HasField("allowed_values"):
+ if attr_value.s not in attr_def.allowed_values.list.s:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
+ (key, op_type_name, attr_value.s,
+ '", "'.join(attr_def.allowed_values.list.s)))
+ elif attr_def.type == "list(string)":
+ attr_value.list.s.extend([_MakeStr(x, key) for x in value])
+ if attr_def.HasField("allowed_values"):
+ for x in attr_value.list.s:
+ if x not in attr_def.allowed_values.list.s:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
+ (key, op_type_name, x,
+ '", "'.join(attr_def.allowed_values.list.s)))
+ elif attr_def.type == "int":
+ attr_value.i = _MakeInt(value, key)
+ if attr_def.has_minimum:
+ if attr_value.i < attr_def.minimum:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed %d less than minimum %d." %
+ (key, op_type_name, attr_value.i, attr_def.minimum))
+ elif attr_def.type == "list(int)":
+ attr_value.list.i.extend([_MakeInt(x, key) for x in value])
+ elif attr_def.type == "float":
+ attr_value.f = _MakeFloat(value, key)
+ elif attr_def.type == "list(float)":
+ attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
+ elif attr_def.type == "bool":
+ attr_value.b = _MakeBool(value, key)
+ elif attr_def.type == "list(bool)":
+ attr_value.list.b.extend([_MakeBool(x, key) for x in value])
+ elif attr_def.type == "type":
+ attr_value.type = _MakeType(value, attr_def)
+ elif attr_def.type == "list(type)":
+ attr_value.list.type.extend(
+ [_MakeType(x, attr_def) for x in value])
+ elif attr_def.type == "shape":
+ attr_value.shape.CopyFrom(_MakeShape(value, key))
+ elif attr_def.type == "list(shape)":
+ attr_value.list.shape.extend(
+ [_MakeShape(x, key) for x in value])
+ elif attr_def.type == "tensor":
+ attr_value.tensor.CopyFrom(_MakeTensor(value, key))
+ elif attr_def.type == "list(tensor)":
+ attr_value.list.tensor.extend(
+ [_MakeTensor(x, key) for x in value])
+ else:
+ raise TypeError("Unrecognized Attr type " + attr_def.type)
+
+ attr_protos[key] = attr_value
+ del attrs # attrs is no longer authoritative, use attr_protos instead
+
+ # Determine output types (possibly using attrs)
+ output_types = []
+ output_structure = []
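+    # output_structure records, for each output arg, either an int (the
+    # length of a list-valued output) or None (a single tensor); it drives
+    # the regrouping of the op's flat output list below.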
+ for arg in op_def.output_arg:
+ types = []
+ if arg.number_attr:
+ n = _AttrValue(attr_protos, arg.number_attr).i
+ if arg.type_attr:
+ types = [_AttrValue(attr_protos, arg.type_attr).type] * n
+ else:
+ types = [arg.type] * n
+ output_structure.append(n)
+ elif arg.type_attr:
+ t = _AttrValue(attr_protos, arg.type_attr)
+ types = [t.type]
+ output_structure.append(None)
+ elif arg.type_list_attr:
+ t = _AttrValue(attr_protos, arg.type_list_attr)
+ types = t.list.type
+ output_structure.append(len(t.list.type))
+ else:
+ types = [arg.type]
+ output_structure.append(None)
+ if arg.is_ref:
+ types = [types_lib.as_dtype(x).as_ref for x in types]
+ output_types.extend(types)
+
+ if keywords:
+ raise TypeError("apply_op() got unexpected keyword arguments: " +
+ ", ".join(sorted(keywords.keys())))
+
+ # Add Op to graph
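+    # If the op has output args, regroup the flat list of output tensors per
+    # argument using output_structure; ops with no outputs return the
+    # Operation itself.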
+ if output_structure:
+ op = g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)
+ outputs = op.outputs
+ return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
+ output_structure)
+ else:
+ return g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)
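+# Illustrative usage sketch (comment only, not part of the library): an
+# OpDefLibrary is populated with OpDef protos and then builds graph nodes
+# via apply_op; op and argument names below are examples, e.g.
+#   lib = OpDefLibrary()
+#   lib.add_op(op_def)                 # op_def: an op_def_pb2.OpDef
+#   out = lib.apply_op("Simple", a=3)  # returns the op's output Tensor(s)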
diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py
new file mode 100644
index 0000000000..72de4586a3
--- /dev/null
+++ b/tensorflow/python/ops/op_def_library_test.py
@@ -0,0 +1,1402 @@
+"""Tests for tensorflow.python.ops.op_def_library."""
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops.op_def_library import OpDefLibrary
+from tensorflow.python.platform import googletest
+
+
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+ops.RegisterShape("Attr")(None)
+ops.RegisterShape("AttrBool")(None)
+ops.RegisterShape("AttrBoolList")(None)
+ops.RegisterShape("AttrDefault")(None)
+ops.RegisterShape("AttrEmptyListDefault")(None)
+ops.RegisterShape("AttrEnum")(None)
+ops.RegisterShape("AttrEnumList")(None)
+ops.RegisterShape("AttrFloat")(None)
+ops.RegisterShape("AttrListDefault")(None)
+ops.RegisterShape("AttrListMin")(None)
+ops.RegisterShape("AttrMin")(None)
+ops.RegisterShape("AttrShape")(None)
+ops.RegisterShape("AttrShapeList")(None)
+ops.RegisterShape("Binary")(None)
+ops.RegisterShape("ComplexStruct")(None)
+ops.RegisterShape("InPolymorphicTwice")(None)
+ops.RegisterShape("MixedStruct")(None)
+ops.RegisterShape("NInPolymorphicTwice")(None)
+ops.RegisterShape("NInTwice")(None)
+ops.RegisterShape("NInTwoTypeVariables")(None)
+ops.RegisterShape("NIntsIn")(None)
+ops.RegisterShape("NIntsOut")(None)
+ops.RegisterShape("NIntsOutDefault")(None)
+ops.RegisterShape("NPolymorphicIn")(None)
+ops.RegisterShape("NPolymorphicOut")(None)
+ops.RegisterShape("NPolymorphicOutDefault")(None)
+ops.RegisterShape("NPolymorphicRestrictIn")(None)
+ops.RegisterShape("NPolymorphicRestrictOut")(None)
+ops.RegisterShape("OutT")(None)
+ops.RegisterShape("OutTypeList")(None)
+ops.RegisterShape("OutTypeListRestrict")(None)
+ops.RegisterShape("Polymorphic")(None)
+ops.RegisterShape("PolymorphicDefaultOut")(None)
+ops.RegisterShape("PolymorphicOut")(None)
+ops.RegisterShape("RefIn")(None)
+ops.RegisterShape("RefOut")(None)
+ops.RegisterShape("ReservedAttr")(None)
+ops.RegisterShape("ReservedInput")(None)
+ops.RegisterShape("Restrict")(None)
+ops.RegisterShape("Simple")(None)
+ops.RegisterShape("SimpleStruct")(None)
+ops.RegisterShape("TypeList")(None)
+ops.RegisterShape("TypeListRestrict")(None)
+ops.RegisterShape("TypeListTwice")(None)
+
+
+class OpDefLibraryTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._lib = OpDefLibrary()
+ self._g = ops.Graph()
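+    # Enter the graph's as_default() context for the duration of each test;
+    # tearDown exits it, so ops built by the tests land in self._g.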
+ self._default_graph_controller = self._g.as_default()
+ self._default_graph_controller.__enter__()
+ self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
+ "output_arg { name: 'out' type: DT_FLOAT }")
+ self._add_op("name: 'OutT' output_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ def tearDown(self):
+ self._default_graph_controller.__exit__(None, None, None)
+
+ def _add_op(self, ascii):
+ op_def = op_def_pb2.OpDef()
+ text_format.Merge(ascii, op_def)
+ self._lib.add_op(op_def)
+
+ def Tensor(self, t, name="in"):
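+    # Test helper: returns a tensor of dtype `t` by applying the dummy
+    # 'OutT' op registered in setUp.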
+ return self._lib.apply_op("OutT", T=t, name=name)
+
+ def testNoRegisteredOpFails(self):
+ with self.assertRaises(RuntimeError) as cm:
+ self._lib.apply_op("unknown", g=self._g)
+ self.assertEqual(cm.exception.message, "Unrecognized Op name unknown")
+
+ def testAddOpValidation(self):
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'MissingTypeAttr' "
+ "input_arg { name: 'a' type_attr: 'T' } ")
+ self.assertEqual(cm.exception.message,
+ "Inconsistent OpDef for 'MissingTypeAttr', "
+ "missing attr 'T'")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'BadTypeAttr' "
+ "output_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'int' }")
+ self.assertEqual(
+ cm.exception.message,
+ "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'MissingNumberAttr' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ")
+ self.assertEqual(cm.exception.message,
+ "Inconsistent OpDef for 'MissingNumberAttr', "
+ "missing attr 'N'")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'BadNumberAttr' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'type' }")
+ self.assertEqual(
+ cm.exception.message,
+ "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'TwoTypesA' "
+ "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'TwoTypesA' must have one type field not 2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'TwoTypesB' "
+ "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'TwoTypesB' must have one type field not 2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'ThreeTypes' "
+ "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' "
+ "type_list_attr: 'U' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'U' type: 'list(type)' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'ThreeTypes' must have one type field not 3")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'NoTypes' output_arg { name: 'a' } ")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'NoTypes' must have one type field not 0")
+
+ def testSimple(self):
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(types.float32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'Simple' op: 'Simple' input: 'Simple/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=4)
+ self.assertProtoEquals("""
+ name: 'Simple_1' op: 'Simple' input: 'Simple_1/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=5, name="named")
+ self.assertProtoEquals("""
+ name: 'named' op: 'Simple' input: 'named/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=[[1, 2, 3], [4, 5, 6]], name="two_d")
+ self.assertProtoEquals("""
+ name: 'two_d' op: 'Simple' input: 'two_d/a'
+ """, out.op.node_def)
+
+ def testSimpleFailures(self):
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a="Bad string")
+ self.assertEqual(cm.exception.message,
+ "Expected int32, got 'Bad string' instead.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=self.Tensor(types.string))
+ self.assertEqual(cm.exception.message,
+ "Input 'a' of 'Simple' Op has type string "
+ "that does not match expected type of int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=6, extra="bogus")
+ self.assertEqual(cm.exception.message,
+ "apply_op() got unexpected keyword arguments: extra")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus")
+ self.assertEqual(cm.exception.message,
+ "apply_op() got unexpected keyword arguments: extra1, "
+ "extra2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple")
+ self.assertEqual(cm.exception.message, "No argument for input a")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", wrong=7)
+ self.assertEqual(cm.exception.message, "No argument for input a")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=[self.Tensor(types.int32)])
+ self.assertStartsWith(cm.exception.message, "Expected int32, got")
+
+ def testReservedInput(self):
+ self._add_op("name: 'ReservedInput' "
+ "input_arg { name: 'input' type: DT_INT32 } ")
+ op = self._lib.apply_op("ReservedInput", input_=7, name="x")
+ self.assertProtoEquals("""
+ name: 'x' op: 'ReservedInput' input: 'x/input'
+ """, op.node_def)
+
+ def testPolymorphic(self):
+ self._add_op("name: 'Polymorphic' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("Polymorphic", a=7, name="p")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'Polymorphic' input: 'p/a'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Polymorphic", a="s", name="q")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'Polymorphic' input: 'q/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'r' op: 'Polymorphic' input: 'r/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Polymorphic", a="s", T=types.string)
+ self.assertEqual(cm.exception.message,
+ "Should not specify value for inferred attr 'T'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Polymorphic", a=[self.Tensor(types.bool)])
+ self.assertEqual(cm.exception.message,
+ "List of Tensors when single Tensor expected")
+
+ def testPolymorphicOut(self):
+ self._add_op("name: 'PolymorphicOut' "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("PolymorphicOut", T=types.int32, name="p")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'PolymorphicOut'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("PolymorphicOut", T=types.bool, name="q")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'PolymorphicOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("PolymorphicOut")
+ self.assertEqual(cm.exception.message,
+ "No argument for attr T")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("PolymorphicOut", T=None)
+ self.assertEqual(cm.exception.message,
+ "Expected DataType for argument 'T' not None.")
+
+ def testPolymorphicDefaultOut(self):
+ self._add_op("name: 'PolymorphicDefaultOut' "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' "
+ " default_value { type: DT_STRING } }")
+
+ out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'PolymorphicDefaultOut'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("PolymorphicDefaultOut", T=types.bool,
+ name="q")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'PolymorphicDefaultOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ def testBinary(self):
+ self._add_op("name: 'Binary' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "input_arg { name: 'b' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("Binary", a=8, b=9, name="b")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'b' op: 'Binary' input: 'b/a' input: 'b/b'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Binary", a="left", b="right", name="c")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'c' op: 'Binary' input: 'c/a' input: 'c/b'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Binary", a="left", b=12)
+ self.assertEqual(cm.exception.message,
+ "Expected string, got 12 instead.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Binary", a=self.Tensor(types.string),
+ b=self.Tensor(types.int32))
+ self.assertEqual(cm.exception.message,
+ "Input 'b' of 'Binary' Op has type int32 "
+ "that does not match type string of argument 'a'.")
+
+ def testRestrict(self):
+ self._add_op("name: 'Restrict' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ out = self._lib.apply_op("Restrict", a="foo", name="g")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'g' op: 'Restrict' input: 'g/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Restrict", a=True, name="h")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'h' op: 'Restrict' input: 'h/a'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Restrict", a=17)
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: "
+ "string, bool")
+
+ def testTypeList(self):
+ self._add_op("name: 'TypeList' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ op = self._lib.apply_op("TypeList", a=["foo"], name="z")
+ self.assertProtoEquals("""
+ name: 'z' op: 'TypeList' input: 'z/a_0'
+ attr { key: 'T' value { list { type: DT_STRING } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeList", a=[True, 12], name="y")
+ self.assertProtoEquals("""
+ name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1'
+ attr { key: 'T' value { list { type: DT_BOOL type: DT_INT32 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeList", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeList", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' "
+ "argument to 'TypeList' Op, not ")
+
+ def testTypeListTwice(self):
+ self._add_op("name: 'TypeListTwice' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "input_arg { name: 'b' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ op = self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", False],
+ name="z")
+ self.assertProtoEquals("""
+ name: 'z' op: 'TypeListTwice'
+ input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1'
+ attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeListTwice", a=[], b=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6])
+ self.assertEqual(cm.exception.message,
+ "Input 'b' of 'TypeListTwice' Op has type list of "
+ "string, int32 that does not match type list "
+ "string, bool of argument 'a'.")
+
+ def testOutTypeList(self):
+ self._add_op("name: 'OutTypeList' "
+ "output_arg { name: 'out' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ out, = self._lib.apply_op("OutTypeList", T=[types.float32], name="x")
+ self.assertEquals(types.float32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'x' op: 'OutTypeList'
+ attr { key: 'T' value { list { type: DT_FLOAT } } }
+ """, out.op.node_def)
+
+ out1, out2 = self._lib.apply_op("OutTypeList",
+ T=[types.int32, types.bool],
+ name="w")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'w' op: 'OutTypeList'
+ attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } }
+ """, out1.op.node_def)
+
+ out = self._lib.apply_op("OutTypeList", T=[], name="empty")
+ self.assertEqual([], out)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("OutTypeList", T=types.int32)
+ self.assertEqual(cm.exception.message, "Expected list for attr T")
+
+ def testTypeListRestrict(self):
+ self._add_op("name: 'TypeListRestrict' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ op = self._lib.apply_op("TypeListRestrict", a=["foo", False], name="v")
+ self.assertProtoEquals("""
+ name: 'v' op: 'TypeListRestrict' input: 'v/a_0' input: 'v/a_1'
+ attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeListRestrict", a=[True, 12])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testOutTypeListRestrict(self):
+ self._add_op("name: 'OutTypeListRestrict' "
+ "output_arg { name: 'out' type_list_attr: 't' } "
+ "attr { name: 't' type: 'list(type)' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ out1, out2 = self._lib.apply_op("OutTypeListRestrict",
+ t=[types.bool, types.string],
+ name="u")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.string, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'OutTypeListRestrict'
+ attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } }
+ """, out1.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("OutTypeListRestrict",
+ t=[types.string, types.int32])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 't' "
+ "not in list of allowed values: string, bool")
+
+ def testAttr(self):
+ self._add_op("name: 'Attr' attr { name: 'a' type: 'int' }")
+ op = self._lib.apply_op("Attr", a=12, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("Attr", a=tensor_shape.Dimension(13), name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a="bad")
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not 'bad'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a=[12])
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not [12].")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a=None)
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not None.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr")
+ self.assertEqual(cm.exception.message, "No argument for attr a")
+
+ def testAttrFloat(self):
+ self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }")
+
+ op = self._lib.apply_op("AttrFloat", a=1.2, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrFloat", a=12, name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrFloat", a="bad")
+ self.assertEqual(cm.exception.message,
+ "Expected float for argument 'a' not 'bad'.")
+
+ def testAttrBool(self):
+ self._add_op("name: 'AttrBool' attr { name: 'a' type: 'bool' }")
+
+ op = self._lib.apply_op("AttrBool", a=True, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrBool", a=False, name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=0)
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 0.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=1)
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 1.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=[])
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not [].")
+
+ def testAttrBoolList(self):
+ self._add_op("name: 'AttrBoolList' attr { name: 'a' type: 'list(bool)' }")
+
+ op = self._lib.apply_op("AttrBoolList", a=[True, False, True], name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrBoolList'
+ attr { key: 'a' value { list { b: true b: false b:true } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrBoolList", a=[], name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrBoolList' attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBoolList", a=[0])
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 0.")
+
+ def testAttrMin(self):
+ self._add_op("name: 'AttrMin' attr { name: 'a' type: 'int' "
+ "has_minimum: true minimum: 5 }")
+ op = self._lib.apply_op("AttrMin", a=12, name="s")
+ self.assertProtoEquals("""
+ name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrMin", a=2)
+ self.assertEqual(cm.exception.message,
+ "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.")
+
+ def testAttrListMin(self):
+ self._add_op("name: 'AttrListMin' attr { name: 'a' type: 'list(int)' "
+ "has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("AttrListMin", a=[1, 2], name="r")
+ self.assertProtoEquals("""
+ name: 'r' op: 'AttrListMin'
+ attr { key: 'a' value { list { i: 1 i: 2 } } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrListMin", a=[17])
+ self.assertEqual(cm.exception.message,
+ "Attr 'a' of 'AttrListMin' Op "
+ "passed list of length 1 less than minimum 2.")
+
+ def testAttrEnum(self):
+ self._add_op("name: 'AttrEnum' "
+ "attr { name: 'a' type: 'string' "
+ " allowed_values { list { s: 'apples' s: 'oranges' } } }")
+
+ op = self._lib.apply_op("AttrEnum", a="oranges", name="e")
+ self.assertProtoEquals("""
+ name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrEnum", a="invalid")
+ self.assertEqual(cm.exception.message,
+ 'Attr \'a\' of \'AttrEnum\' Op '
+ 'passed string \'invalid\' not in: '
+ '"apples", "oranges".')
+
+ def testAttrEnumList(self):
+ self._add_op("name: 'AttrEnumList' "
+ "attr { name: 'a' type: 'list(string)' "
+ " allowed_values { list { s: 'apples' s: 'oranges' } } }")
+
+ op = self._lib.apply_op("AttrEnumList", a=["oranges", "apples"], name="f")
+ self.assertProtoEquals("""
+ name: 'f' op: 'AttrEnumList'
+ attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"])
+ self.assertEqual(cm.exception.message,
+ 'Attr \'a\' of \'AttrEnumList\' Op '
+ 'passed string \'invalid\' not '
+ 'in: "apples", "oranges".')
+
+ def testAttrShape(self):
+ self._add_op("name: 'AttrShape' attr { name: 'a' type: 'shape' }")
+
+ op = self._lib.apply_op("AttrShape", a=[5], name="s1")
+ self.assertProtoEquals("""
+ name: 's1' op: 'AttrShape'
+ attr { key: 'a' value { shape { dim { size: 5 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShape", a=(4, 3, 2), name="s2")
+ self.assertProtoEquals("""
+ name: 's2' op: 'AttrShape'
+ attr { key: 'a' value {
+ shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op(
+ "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3")
+ self.assertProtoEquals("""
+ name: 's3' op: 'AttrShape'
+ attr { key: 'a' value {
+ shape { dim { size: 3 } dim { size: 2 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShape", a=[], name="s4")
+ self.assertProtoEquals("""
+ name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } }
+ """, op.node_def)
+
+ shape = tensor_shape_pb2.TensorShapeProto()
+ shape.dim.add().size = 6
+ shape.dim.add().size = 3
+ op = self._lib.apply_op("AttrShape", a=shape, name="s5")
+ self.assertProtoEquals("""
+ name: 's5' op: 'AttrShape'
+ attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } }
+ """, op.node_def)
+
+ # TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes.
+ # with self.assertRaises(TypeError) as cm:
+ # self._lib.apply_op("AttrShape", a=5)
+ # self.assertEqual(cm.exception.message,
+ # "Don't know how to convert 5 to a TensorShapeProto for "
+ # "argument 'a'")
+
+    with self.assertRaises(ValueError):
+ self._lib.apply_op("AttrShape", a="ABC")
+
+ def testAttrShapeList(self):
+ self._add_op("name: 'AttrShapeList' attr { name: 'a' type: 'list(shape)' }")
+
+ op = self._lib.apply_op("AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl")
+ self.assertProtoEquals("""
+ name: 'sl' op: 'AttrShapeList'
+ attr { key: 'a' value { list {
+ shape { dim { size: 3 } dim { size: 2 } }
+ shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShapeList", a=[], name="esl")
+ self.assertProtoEquals("""
+ name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testAttrDefault(self):
+ self._add_op("name: 'AttrDefault' "
+ "attr { name: 'a' type: 'string' "
+ " default_value { s: 'banana' } }")
+
+ op = self._lib.apply_op("AttrDefault", a=None, name="d")
+ self.assertProtoEquals("""
+ name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrDefault", a="kiwi", name="c")
+ self.assertProtoEquals("""
+ name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } }
+ """, op.node_def)
+
+ def testAttrListDefault(self):
+ self._add_op("name: 'AttrListDefault' "
+ "attr { name: 'a' type: 'list(int)' "
+ " default_value { list { i: 5 i: 15 } } }")
+
+ op = self._lib.apply_op("AttrListDefault", a=None, name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'AttrListDefault'
+ attr { key: 'a' value { list { i: 5 i: 15 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrListDefault", a=[3], name="a")
+ self.assertProtoEquals("""
+ name: 'a' op: 'AttrListDefault'
+ attr { key: 'a' value { list { i: 3 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrListDefault", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'AttrListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testAttrEmptyListDefault(self):
+ self._add_op("name: 'AttrEmptyListDefault' "
+ "attr { name: 'a' type: 'list(float)' "
+ " default_value { list { } } }")
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=None, name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=[3], name="a")
+ self.assertProtoEquals("""
+ name: 'a' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { f: 3 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testReservedAttr(self):
+ self._add_op("name: 'ReservedAttr' "
+ "attr { name: 'range' type: 'int' } ")
+ op = self._lib.apply_op("ReservedAttr", range_=7, name="x")
+ self.assertProtoEquals("""
+ name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }
+ """, op.node_def)
+
+ def testNIntsIn(self):
+ self._add_op("name: 'NIntsIn' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NIntsIn", a=[1, 2], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1'
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NIntsIn'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
+ attr { key: 'N' value { i: 5 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=["foo", "bar"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
+ "[string, string] that do not match expected type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(types.string),
+ self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have "
+ "types [string, string] that do not match expected type "
+ "int32.")
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NIntsIn", a=[99])
+ self.assertEqual(cm.exception.message,
+ "List argument 'a' to 'NIntsIn' Op "
+ "with length 1 shorter than "
+ "minimum length 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[38, "bar"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
+ "[int32, string] that do not match expected type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(types.int32),
+ self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op "
+ "have types [int32, string] that do not match expected "
+ "type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' argument "
+ "to 'NIntsIn' Op, not ")
+
+ def testNPolymorphicIn(self):
+ self._add_op("name: 'NPolymorphicIn' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NPolymorphicIn", a=[1, 2], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NPolymorphicIn'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 5 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn", a=["foo", "bar"], name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn",
+ a=[1, self.Tensor(types.float32, name="x")],
+ name="q")
+ self.assertProtoEquals("""
+ name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=[99])
+ self.assertEqual(cm.exception.message,
+ "List argument 'a' to 'NPolymorphicIn' Op with length 1 "
+ "shorter than minimum length 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=[38, "bar"])
+ self.assertEqual(cm.exception.message,
+ "All tensors passed to 'a' of 'NPolymorphicIn' "
+ "Op must have the same type.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn",
+ a=[38, self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
+ "have types [int32, string] that don't all match.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn",
+ a=["abcd", self.Tensor(types.int32)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
+ "have types [string, int32] that don't all match.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' argument "
+ "to 'NPolymorphicIn' Op, not ")
+
+ def testNPolymorphicRestrictIn(self):
+ self._add_op("name: 'NPolymorphicRestrictIn' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' allowed_values { "
+ " list { type: DT_STRING type: DT_BOOL } } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NPolymorphicRestrictIn", a=["foo", "bar"],
+ name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicRestrictIn", a=[False, True, False],
+ name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'NPolymorphicRestrictIn'
+ input: 'b/a_0' input: 'b/a_1' input: 'b/a_2'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testNInTwice(self):
+ self._add_op("name: 'NInTwice' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "input_arg { name: 'b' type: DT_STRING number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInTwice", a=[1, 2], b=["one", "two"], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInTwice'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwice", a=[], b=[], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInTwice' Op "
+ "with length 1 must match "
+ "length 3 of argument 'a'.")
+
+ def testNInPolymorphicTwice(self):
+ self._add_op("name: 'NInPolymorphicTwice' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInPolymorphicTwice'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInPolymorphicTwice' Op "
+ "with length 1 "
+ "must match length 3 of argument 'a'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'NInPolymorphicTwice' "
+ "Op have types [string, string] that do not match type "
+ "int32 inferred from earlier arguments.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice",
+ a=[self.Tensor(types.int32)],
+ b=[self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of "
+ "'NInPolymorphicTwice' Op have types [string] that do not "
+ "match type int32 inferred from earlier arguments.")
+
+ def testNInTwoTypeVariables(self):
+ self._add_op("name: 'NInTwoTypeVariables' "
+ "input_arg { name: 'a' type_attr: 'S' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'S' type: 'type' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[True, False],
+ name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInTwoTypeVariables'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NInTwoTypeVariables'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwoTypeVariables",
+ a=[self.Tensor(types.int32, name="q")],
+ b=[self.Tensor(types.string, name="r")],
+ name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 1 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInTwoTypeVariables' Op "
+ "with length 1 "
+ "must match length 3 of argument 'a'.")
+
+ def testInPolymorphicTwice(self):
+ self._add_op("name: 'InPolymorphicTwice' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'M' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 } "
+ "attr { name: 'M' type: 'int' has_minimum: true minimum: 0 } ")
+
+ op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'InPolymorphicTwice'
+ input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 1 } }
+ attr { key: 'M' value { i: 3 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 1 } }
+ attr { key: 'M' value { i: 0 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
+ self.assertEqual(cm.exception.message,
+ "Don't know how to infer type variable from empty input "
+ "list passed to input 'a' of 'InPolymorphicTwice' Op.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op "
+ "have types [string, string] that do not match type int32 "
+ "inferred from earlier arguments.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice",
+ a=[self.Tensor(types.int32)],
+ b=[self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'InPolymorphicTwice' "
+ "Op have types [string] that do not match type int32 "
+ "inferred from earlier arguments.")
+
+ def testNIntsOut(self):
+ self._add_op("name: 'NIntsOut' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3, out4, out5 = self._lib.apply_op(
+ "NIntsOut", N=5, name="o")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertEquals(types.int32, out4.dtype)
+ self.assertEquals(types.int32, out5.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } }
+ """, out5.op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NIntsOut", N=1)
+ self.assertEqual(cm.exception.message,
+ "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsOut", N=[3])
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'N' not [3].")
+
+ def testNIntsOutDefault(self):
+ self._add_op("name: 'NIntsOutDefault' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2"
+ " default_value { i:3 } }")
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NIntsOutDefault", N=None, name="z")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } }
+ """, out2.op.node_def)
+
+ def testNPolymorphicOut(self):
+ self._add_op("name: 'NPolymorphicOut' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2 = self._lib.apply_op("NPolymorphicOut", N=2,
+ T=types.int32, name="n")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'n' op: 'NPolymorphicOut'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOut", T=types.string, N=3, name="o")
+ self.assertEquals(types.string, out1.dtype)
+ self.assertEquals(types.string, out2.dtype)
+ self.assertEquals(types.string, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'NPolymorphicOut'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 3 } }
+ """, out3.op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NPolymorphicOut", N=1, T=types.string)
+ self.assertEqual(cm.exception.message,
+ "Attr 'N' of 'NPolymorphicOut' Op "
+ "passed 1 less than minimum 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicOut", N=3, T=[types.string])
+ self.assertEqual(
+ cm.exception.message,
+ "Expected DataType for argument 'T' not [tf.string].")
+
+ def testNPolymorphicOutDefault(self):
+ self._add_op("name: 'NPolymorphicOutDefault' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type'"
+ " default_value { type: DT_BOOL } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 "
+ " default_value { i: 2 } }")
+
+ out1, out2 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=None, T=None, name="r")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'r' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=3, T=None, name="s")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertEquals(types.bool, out3.dtype)
+ self.assertProtoEquals("""
+ name: 's' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ out1, out2 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=None, T=types.int32, name="t")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 't' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=3, T=types.int32, name="u")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ def testNPolymorphicRestrictOut(self):
+ self._add_op("name: 'NPolymorphicRestrictOut' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' allowed_values { "
+ " list { type: DT_STRING type: DT_BOOL } } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicRestrictOut", N=3, T=types.bool, name="u")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertEquals(types.bool, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'NPolymorphicRestrictOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=types.int32)
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testRef(self):
+ self._add_op("name: 'RefIn' "
+ "input_arg { name: 'a' type_attr: 'T' is_ref: true } "
+ "attr { name: 'T' type: 'type' } ")
+ self._add_op("name: 'RefOut' "
+ "output_arg { name: 'a' type_attr: 'T' is_ref: true } "
+ "attr { name: 'T' type: 'type' } ")
+
+ out = self._lib.apply_op("RefOut", T=types.bool, name="o")
+ self.assertEquals(types.bool_ref, out.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'RefOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ op = self._lib.apply_op("RefIn", a=out, name="i")
+ self.assertProtoEquals("""
+ name: 'i' op: 'RefIn' input: 'o'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, op.node_def)
+
+ # Can pass ref to non-ref input.
+ out = self._lib.apply_op("RefOut", T=types.int32, name="r")
+ out = self._lib.apply_op("Simple", a=out, name="s")
+ self.assertProtoEquals("""
+ name: 's' op: 'Simple' input: 'r'
+ """, out.op.node_def)
+
+ # Can't pass non-ref to ref input.
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("RefIn", a=2)
+ self.assertEqual(cm.exception.message,
+ "Input 'a' of 'RefIn' Op requires l-value input")
+
+ def testSpecifyDevice(self):
+ with self._g.device("ADevice"):
+ self._lib.apply_op("Simple", a=3)
+ # We look at the whole graph here to make sure the Const op is also given
+ # the specified device.
+ graph_def = self._g.as_graph_def()
+ self.assertEqual(len(graph_def.node), 2)
+ for node in graph_def.node:
+ self.assertEqual(node.device, "ADevice")
+
+ def testStructuredOutputSingleList(self):
+ self._add_op("name: 'SimpleStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "attr { name: 'n_a' type: 'int' }")
+ for n_a in [0, 1, 3]:
+ a = self._lib.apply_op("SimpleStruct", n_a=n_a)
+ self.assertTrue(isinstance(a, list))
+ self.assertEqual(n_a, len(a))
+
+ def testStructuredOutputListAndSingle(self):
+ self._add_op("name: 'MixedStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "output_arg { name: 'b' type: DT_FLOAT } "
+ "attr { name: 'n_a' type: 'int' }")
+ for n_a in [0, 1, 3]:
+ a, b = self._lib.apply_op("MixedStruct", n_a=n_a)
+ self.assertTrue(isinstance(a, list))
+ self.assertEqual(n_a, len(a))
+ self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertTrue(isinstance(b, ops.Tensor))
+ self.assertEqual(types.float32, b.dtype)
+
+ def testStructuredOutputMultipleLists(self):
+ self._add_op("name: 'ComplexStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "output_arg { name: 'b' type: DT_INT64 number_attr: 'n_b' } "
+ "output_arg { name: 'c' type_list_attr: 't_c' } "
+ "attr { name: 'n_a' type: 'int' } "
+ "attr { name: 'n_b' type: 'int' } "
+ "attr { name: 't_c' type: 'list(type)' }")
+ for n_a in [0, 1, 3]:
+ for n_b in [0, 1, 3]:
+ for t_c in [[],
+ [types.int32],
+ [types.int32, types.float32]]:
+ a, b, c = self._lib.apply_op("ComplexStruct",
+ n_a=n_a, n_b=n_b, t_c=t_c)
+
+ self.assertEqual(n_a, len(a))
+ self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertEqual(n_b, len(b))
+ self.assertTrue(all(x.dtype == types.int64 for x in b))
+ self.assertEqual(t_c, [x.dtype for x in c])
+
+
+class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._lib = OpDefLibrary()
+ self._g = ops.Graph()
+ self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
+ "output_arg { name: 'out' type: DT_FLOAT }")
+ self._add_op("name: 'Binary' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "input_arg { name: 'b' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ def _add_op(self, ascii):
+ op_def = op_def_pb2.OpDef()
+ text_format.Merge(ascii, op_def)
+ self._lib.add_op(op_def)
+
+ def testNoGraph(self):
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(out.graph, ops.get_default_graph())
+
+ def testDefaultGraph(self):
+ with self._g.as_default():
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(out.graph, self._g)
+
+ def testIgnoreDefaultGraphWithGraphArgument(self):
+ default_g = ops.Graph()
+ with default_g.as_default():
+ out = self._lib.apply_op("Simple", a=3, g=self._g)
+ self.assertEquals(ops.get_default_graph(), default_g)
+ self.assertEquals(out.graph, self._g)
+
+ def testDifferentGraphFails(self):
+ a = self._lib.apply_op("Simple", a=3, g=self._g)
+ other_g = ops.Graph()
+ b = self._lib.apply_op("Simple", a=4, g=other_g)
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("Binary", a=a, b=b)
+ self.assertTrue("must be from the same graph" in cm.exception.message)
+
+ def testDifferentGraphFailsWithGraphArgument(self):
+ other_g = ops.Graph()
+ a = self._lib.apply_op("Simple", a=3, g=other_g)
+ b = self._lib.apply_op("Simple", a=4, g=other_g)
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("Binary", a=a, b=b, g=self._g)
+ self.assertTrue(
+ "not from the passed-in graph" in cm.exception.message)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
new file mode 100644
index 0000000000..dc954a3776
--- /dev/null
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -0,0 +1,390 @@
+"""Parsing Ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_parsing_ops import *
+
+
+ops.NoGradient("DecodeRaw")
+ops.NoGradient("StringToNumber")
+
+
+# pylint: disable=protected-access
+def parse_example(serialized,
+ names=None,
+ sparse_keys=None,
+ sparse_types=None,
+ dense_keys=None,
+ dense_types=None,
+ dense_defaults=None,
+ dense_shapes=None,
+ name="ParseExample"):
+ """Parse Example protos.
+
+ Args:
+ serialized: string vector, a batch of binary serialized Example protos.
+ names: A string vector, the names of the serialized protos.
+ "names" may contain, e.g., table key (descriptive) names for the
+      corresponding serialized protos. These are purely for debugging
+      purposes, and the presence of values here has no effect on the output.
+ "names" may be an empty vector, if no names are available.
+ If non-empty, this vector must be the same length as "serialized".
+ sparse_keys: A string list of keys in the Examples' features.
+ These keys are associated with sparse values.
+ sparse_types: A list of DTypes.
+ This list's length must match that of sparse_keys. Currently
+ parse_example supports tf.float32 (FloatList), tf.int64 (Int64List),
+ and tf.string (BytesList).
+ dense_keys: A string list of keys in the Examples' features.
+ These keys are associated with dense values.
+ dense_types: A list of DTypes.
+ This list's length must match that of dense_keys. Currently
+ parse_example supports tf.float32 (FloatList), tf.int64 (Int64List),
+ and tf.string (BytesList).
+ dense_defaults: A dict of {key:Tensor} (some may be missing).
+ The keys of the dict must match the dense_keys of the feature.
+ If a key is not present in this dictionary, the corresponding dense
+ Feature is required in all elements of serialized.
+ dense_shapes: A list of tuples.
+ Entries provide the shape of data in each dense Feature in features.
+ The length of dense_shapes must be the same as the length of dense_keys.
+ The number of elements in the Feature corresponding to dense_key[j]
+ must always have np.prod(dense_shapes[j]) entries.
+    If dense_shapes[j] == (D0, D1, ..., DN) then the shape of the output
+ Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN):
+ The dense outputs are just the inputs row-stacked by batch.
+ name: (Optional) Name of Op in the graph.
+
+ Returns:
+ A dictionary mapping keys to Tensors and SparseTensors.
+
+ The key dense_keys[j] is mapped to a tensor of type dense_types[j] and
+  of shape (serialized.size(),) + dense_shapes[j] (i.e., the dense outputs are
+  the inputs, reshaped in row-major format and then row-stacked by batch).
+
+ The key sparse_keys[j] is mapped to a SparseTensor of type sparse_types[j].
+ The SparseTensor represents a ragged matrix. Its indices are [batch, index]
+ where "batch" is is the batch entry the value is from, and "index" is the
+ value's index in the list of values associated with that feature
+ and example. For example, if one expects a tf.float32 sparse feature "ft"
+ and three serialized examples are provided:
+
+ serialized = [
+ features:
+ { feature: [ key: { "ft" value: float_list: { value: [1.0, 2.0] } } ] },
+ features:
+ { feature: [] },
+ features:
+ { feature: [ key: { "ft" value: float_list: { value: [3.0] } } ] }
+ ]
+
+ then the output will look like:
+
+ {"ft": SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
+ values=[1.0, 2.0, 3.0],
+ shape=(3, 2)) }
+
+ Raises:
+ ValueError: If sparse and dense keys intersect, or input lengths do not
+ match up for sparse_* (similarly for dense_*).
+ TypeError: If an input is malformed.
+
+ Example input, format, and output: Just Sparse Inputs
+  =====================================================
+
+ Given two brain.Example input protos:
+
+ serialized: // serialized versions of the protos below
+ [features: {
+ feature: { key: "kw" value: { bytes_list: { value: [ "knit", "big" ] } } }
+ feature: { key: "gps" value: { float_list: { value: [] } } }
+ },
+ features: {
+ feature: { key: "kw" value: { bytes_list: { value: [ "emmy" ] } } }
+ feature: { key: "dank" value: { int64_list: { value: [ 42 ] } } }
+ feature: { key: "gps" value: { } }
+ }]
+ names: ["input0", "input1"],
+ sparse_keys: ["kw", "dank", "gps"]
+ sparse_types: [DT_STRING, DT_INT64, DT_FLOAT]
+
+ Then the expected output is a dictionary:
+ {
+ "kw": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+ values=["knit", "big", "emmy"]
+ shape=[2, 2]),
+ "dank": SparseTensor(
+ indices=[[1, 0]],
+ values=[42],
+ shape=[2, 1]),
+ "gps": SparseTensor(
+ indices=[],
+ values=[],
+ shape=[2, 0]),
+ }
+
+
+ Example input, format, and output: Dense Inputs (without defaults)
+ ==================================================================
+
+ Given two brain.Example input protos:
+
+ serialized: // serialized versions of the protos below
+ [features: {
+ feature: { key: "age" value: { int64_list: { value: [ 0 ] } } }
+ feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
+ },
+ features: {
+ feature: { key: "age" value: { int64_list: { value: [] } } }
+ feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
+ }]
+ names: ["input0", "input1"],
+ dense_keys: np.array(["age", "gender"])
+ dense_types: [tf.int64, tf.string]
+ dense_defaults: {
+ "age": -1 # defaults to -1 if missing
+ # "gender" has no specified default so it's required
+ }
+    dense_shapes: [(1,), (1,)]  # age, gender
+
+ Then the expected output is a dictionary:
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ }
+
+
+ Example input, format, and output: Dense Inputs (with defaults)
+ ===============================================================
+
+ Given two brain.Example input protos:
+
+ serialized: // serialized versions of the protos below
+ [features: {
+ feature: { key: "weight" value: { float_list: { value: [ 1.0 ] } } }
+ },
+ features: {
+ feature: { key: "label" value: { float_list: { value: [ -1.0, 0.0 ] } } }
+ }]
+ names: ["input0", "input1"],
+ dense_keys: np.array(["label", "weight"])
+ dense_defaults: {
+ "label": [1.0, 2.0], # float (default: vector)
+ "weight": 5.0 # float (default: scalar, 5.0)
+ }
+    dense_shapes: [(2,), (1,)]  # label, weight
+
+ Then the expected output is a dictionary:
+ {
+ "label": [[1.0, 2.0], [-1.0, 0.0]],
+ "weight": [[1.0], [5.0]],
+ }
+ """
+ names = [] if names is None else names
+ dense_defaults = {} if dense_defaults is None else dense_defaults
+ sparse_keys = [] if sparse_keys is None else sparse_keys
+ sparse_types = [] if sparse_types is None else sparse_types
+ dense_keys = [] if dense_keys is None else dense_keys
+ dense_types = [] if dense_types is None else dense_types
+  dense_shapes = (
+      [[]] * len(dense_keys) if dense_shapes is None else dense_shapes)
+
+ num_dense = len(dense_keys)
+ num_sparse = len(sparse_keys)
+
+ if len(dense_shapes) != num_dense:
+ raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
+ % (len(dense_shapes), num_dense))
+ if len(dense_types) != num_dense:
+ raise ValueError("len(dense_types) != len(num_dense): %d vs. %d"
+ % (len(dense_types), num_dense))
+ if len(sparse_types) != num_sparse:
+ raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
+ % (len(sparse_types), num_sparse))
+ if num_dense + num_sparse == 0:
+ raise ValueError("Must provide at least one sparse key or dense key")
+ if not set(dense_keys).isdisjoint(set(sparse_keys)):
+ raise ValueError(
+ "Dense and sparse keys must not intersect; intersection: %s" %
+ set(dense_keys).intersection(set(sparse_keys)))
+
+ dense_defaults_vec = []
+ for i, key in enumerate(dense_keys):
+ default_value = dense_defaults.get(key)
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key)
+ default_value = array_ops.reshape(default_value, dense_shapes[i])
+
+ dense_defaults_vec.append(default_value)
+
+ dense_shapes = [tensor_util.MakeTensorShapeProto(shape)
+ if isinstance(shape, (list, tuple)) else shape
+ for shape in dense_shapes]
+
+ outputs = gen_parsing_ops._parse_example(
+ serialized=serialized,
+ names=names,
+ dense_defaults=dense_defaults_vec,
+ sparse_keys=sparse_keys,
+ sparse_types=sparse_types,
+ dense_keys=dense_keys,
+ dense_shapes=dense_shapes,
+ name=name)
+
+ (sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs
+
+ sparse_tensors = [ops.SparseTensor(ix, val, shape) for (ix, val, shape)
+ in zip(sparse_indices, sparse_values, sparse_shapes)]
+
+ return dict(
+ zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
+
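+
+# Illustrative usage sketch, not part of the original module: the key names,
+# dtypes, and default value below are assumptions chosen only to show how the
+# parse_example arguments fit together for one sparse and one dense feature.
+def _parse_example_usage_sketch(serialized):
+  from tensorflow.python.framework import types  # local import for the sketch
+  return parse_example(
+      serialized,
+      sparse_keys=["kw"],
+      sparse_types=[types.string],
+      dense_keys=["age"],
+      dense_types=[types.int64],
+      dense_defaults={"age": -1},
+      dense_shapes=[(1,)])
+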
+
+def parse_single_example(serialized, # pylint: disable=invalid-name
+ names=None,
+ sparse_keys=None,
+ sparse_types=None,
+ dense_keys=None,
+ dense_types=None,
+ dense_defaults=None,
+ dense_shapes=None,
+ name="ParseSingleExample"):
+ """Identical to parse_example but for scalar serialized and names.
+
+ Args:
+ serialized: A scalar string, a single serialized Example.
+ See parse_example documentation for more details.
+ names: (Optional) A scalar string, the associated name.
+ See parse_example documentation for more details.
+ sparse_keys: See parse_example documentation for more details.
+ sparse_types: See parse_example documentation for more details.
+ dense_keys: See parse_example documentation for more details.
+ dense_types: See parse_example documentation for more details.
+ dense_defaults: See parse_example documentation for more details.
+ dense_shapes: See parse_example documentation for more details.
+ name: Optional op name.
+
+ Returns:
+ A dictionary mapping keys to Tensors and SparseTensors.
+
+ For dense tensors, the Tensor is identical to the output of parse_example,
+ except it is one less dimension (the first, batch, dimension is removed).
+
+ For SparseTensors:
+ The first (batch) column of the indices matrix is removed
+ (it is now a column vector).
+ The values vector is unchanged.
+ The first (batch_size) entry of the shape vector is removed
+ (it is now a single element vector).
+
+ Raises:
+ ValueError: if "scalar" or "names" have known shapes, and are not scalars.
+ """
+ with ops.op_scope([serialized], name, "parse_single_example"):
+ serialized = ops.convert_to_tensor(serialized)
+ serialized_shape = serialized.get_shape()
+ if serialized_shape.ndims is not None:
+ if serialized_shape.ndims != 0:
+ raise ValueError("Input serialized must be a scalar")
+ else:
+ serialized = control_flow_ops.with_dependencies(
+ [logging_ops.Assert(
+ math_ops.equal(array_ops.rank(serialized), 0),
+ ["Input serialized must be a scalar"],
+ name="SerializedIsScalar")],
+ serialized,
+ name="SerializedDependencies")
+ serialized = array_ops.expand_dims(serialized, 0)
+ if names is not None:
+ names = ops.convert_to_tensor(names)
+ names_shape = names.get_shape()
+ if names_shape.ndims is not None:
+ if names_shape.ndims != 0:
+ raise ValueError("Input names must be a scalar")
+ else:
+ names = control_flow_ops.with_dependencies(
+ [logging_ops.Assert(
+ math_ops.equal(array_ops.rank(names), 0),
+ ["Input names must be a scalar"],
+ name="NamesIsScalar")],
+ names,
+ name="NamesDependencies")
+ names = array_ops.expand_dims(names, 0)
+
+ outputs = parse_example(serialized,
+ names=names,
+ sparse_keys=sparse_keys,
+ sparse_types=sparse_types,
+ dense_keys=dense_keys,
+ dense_types=dense_types,
+ dense_defaults=dense_defaults,
+ dense_shapes=dense_shapes,
+ name=name)
+ if dense_keys is not None:
+ for d in dense_keys:
+ outputs[d] = array_ops.squeeze(outputs[d], [0], name="Squeeze_%s" % d)
+ if sparse_keys is not None:
+ for s in sparse_keys:
+ outputs[s] = ops.SparseTensor(
+ array_ops.slice(outputs[s].indices,
+ [0, 1], [-1, -1], name="Slice_Indices_%s" % s),
+ outputs[s].values,
+ array_ops.slice(outputs[s].shape,
+ [1], [-1], name="Squeeze_Shape_%s" % s))
+ return outputs
+
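+
+# Illustrative sketch, not part of the original module; the same assumed keys
+# as the parse_example sketch above, applied to a single scalar serialized
+# proto.
+def _parse_single_example_usage_sketch(serialized):
+  from tensorflow.python.framework import types  # local import for the sketch
+  return parse_single_example(
+      serialized,
+      sparse_keys=["kw"],
+      sparse_types=[types.string],
+      dense_keys=["age"],
+      dense_types=[types.int64],
+      dense_defaults={"age": -1},
+      dense_shapes=[(1,)])
+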
+
+@ops.RegisterShape("ParseExample")
+def _ParseExampleShape(op):
+ """Shape function for the ParseExample op."""
+ input_shape = op.inputs[0].get_shape().with_rank(1)
+ num_sparse = op.get_attr("Nsparse")
+ num_dense = op.get_attr("Ndense")
+ dense_shapes = op.get_attr("dense_shapes")
+ sparse_index_shapes = [
+ tensor_shape.matrix(None, 2) for _ in range(num_sparse)]
+ sparse_value_shapes = [tensor_shape.vector(None) for _ in range(num_sparse)]
+ sparse_shape_shapes = [tensor_shape.vector(2) for _ in range(num_sparse)]
+ assert num_dense == len(dense_shapes)
+ dense_shapes = [
+ input_shape.concatenate((d.size for d in dense_shape.dim))
+ for dense_shape in dense_shapes]
+ return (sparse_index_shapes + sparse_value_shapes + sparse_shape_shapes +
+ dense_shapes)
+
+
+ops.RegisterShape("StringToNumber")(
+ common_shapes.unchanged_shape)
+
+
+@ops.RegisterShape("DecodeRaw")
+def _DecodeRawShape(op):
+ """Shape function for the DecodeRaw op."""
+ # NOTE(mrry): Last dimension is data-dependent.
+ return [op.inputs[0].get_shape().concatenate([None])]
+
+
+@ops.RegisterShape("DecodeCSV")
+def _DecodeCSVShape(op):
+ """Shape function for the DecodeCSV op."""
+ input_shape = op.inputs[0].get_shape()
+  # Optionally check that all of the other inputs are scalar or empty.
+ for default_input in op.inputs[1:]:
+ default_input_shape = default_input.get_shape().with_rank(1)
+ if default_input_shape[0] > 1:
+ raise ValueError(
+ "Shape of a default must be a length-0 or length-1 vector.")
+ return [input_shape] * len(op.outputs)
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
new file mode 100644
index 0000000000..6bd8dd9e3d
--- /dev/null
+++ b/tensorflow/python/ops/random_ops.py
@@ -0,0 +1,181 @@
+"""Operations for generating random numbers."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_random_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_random_ops import *
+# pylint: enable=wildcard-import
+
+
+def _ShapeTensor(shape):
+ """Convert to an int32 or int64 tensor, defaulting to int32 if empty."""
+ if isinstance(shape, (tuple, list)) and not shape:
+ dtype = types.int32
+ else:
+ dtype = None
+ return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
+
+# pylint: disable=protected-access
+def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
+ seed=None, name=None):
+ """Outputs random values from a normal distribution.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
+ distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random normal values.
+ """
+ with ops.op_scope([shape, mean, stddev], name, "random_normal") as name:
+ shape_tensor = _ShapeTensor(shape)
+ mean_tensor = ops.convert_to_tensor(
+ mean, dtype=dtype, name="mean")
+ stddev_tensor = ops.convert_to_tensor(
+ stddev, dtype=dtype, name="stddev")
+ seed1, seed2 = random_seed.get_seed(seed)
+ rnd = gen_random_ops._random_standard_normal(shape_tensor, dtype,
+ seed=seed1,
+ seed2=seed2)
+ mul = rnd * stddev_tensor
+ value = math_ops.add(mul, mean_tensor, name=name)
+ return value
+
+
+ops.NoGradient("RandomStandardNormal")
+
+
+def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
+ seed=None, name=None):
+ """Outputs random values from a truncated normal distribution.
+
+ The generated values follow a normal distribution with specified mean and
+ standard deviation, except that values whose magnitude is more than 2 standard
+ deviations from the mean are dropped and re-picked.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
+ truncated normal distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the truncated normal distribution.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random truncated normal values.
+ """
+ with ops.op_scope([shape, mean, stddev], name, "truncated_normal") as name:
+ shape_tensor = _ShapeTensor(shape)
+ mean_tensor = ops.convert_to_tensor(
+ mean, dtype=dtype, name="mean")
+ stddev_tensor = ops.convert_to_tensor(
+ stddev, dtype=dtype, name="stddev")
+ seed1, seed2 = random_seed.get_seed(seed)
+ rnd = gen_random_ops._truncated_normal(shape_tensor, dtype,
+ seed=seed1,
+ seed2=seed2)
+ mul = rnd * stddev_tensor
+ value = math_ops.add(mul, mean_tensor, name=name)
+ return value
+
+
+ops.NoGradient("TruncatedNormal")
+
+
+def random_uniform(shape, minval=0.0, maxval=1.0,
+ dtype=types.float32, seed=None,
+ name=None):
+ """Outputs random values from a uniform distribution.
+
+ The generated values follow a uniform distribution in the range
+ `[minval, maxval)`. The lower bound `minval` is included in the range, while
+ the upper bound `maxval` is excluded.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
+ range of random values to generate.
+ maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on
+ the range of random values to generate.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random uniform values.
+ """
+ with ops.op_scope([shape, minval, maxval], name, "random_uniform") as name:
+ shape_tensor = _ShapeTensor(shape)
+ min_tensor = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+ range_tensor = ops.convert_to_tensor(
+ maxval - minval, dtype=dtype, name="range")
+ seed1, seed2 = random_seed.get_seed(seed)
+ rnd = gen_random_ops._random_uniform(shape_tensor, dtype,
+ seed=seed1,
+ seed2=seed2)
+ mul = rnd * range_tensor
+ value = math_ops.add(mul, min_tensor, name=name)
+ return value
+
+
+def random_shuffle(value, seed=None, name=None):
+ """Randomly shuffles a tensor along its first dimension.
+
+ The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
+ to one and only one `output[i]`. For example, a mapping that might occur for a
+ 3x2 tensor is:
+
+ ```python
+ [[1, 2], [[5, 6],
+ [3, 4], ==> [1, 2],
+ [5, 6]] [3, 4]]
+ ```
+
+ Args:
+ value: A Tensor to be shuffled.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of same shape and type as `value`, shuffled along its first
+ dimension.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_random_ops._random_shuffle(value, seed=seed1, seed2=seed2,
+ name=name)
+
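+
+# Illustrative sketch, not part of the original module; the shapes and seeds
+# below are arbitrary assumptions, chosen only to show how the sampling
+# helpers above compose into one graph.
+def _random_ops_usage_sketch():
+  normal = random_normal([2, 3], mean=1.0, stddev=2.0, seed=1)
+  uniform = random_uniform([2, 3], minval=0.0, maxval=10.0, seed=2)
+  # Shuffle the rows of the normal draw along its first dimension.
+  return random_shuffle(normal, seed=3), uniform
+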
+
+ops.NoGradient("RandomUniform")
+
+
+@ops.RegisterShape("TruncatedNormal")
+@ops.RegisterShape("RandomStandardNormal")
+@ops.RegisterShape("RandomUniform")
+def _RandomShape(op):
+ shape_val = tensor_util.ConstantValue(op.inputs[0])
+ if shape_val is not None:
+ return [tensor_shape.TensorShape(shape_val.tolist())]
+ else:
+ shape_shape = op.inputs[0].get_shape().with_rank_at_most(1)
+ return [tensor_shape.unknown_shape(ndims=shape_shape.num_elements())]
+
+
+ops.RegisterShape("RandomShuffle")(common_shapes.unchanged_shape)
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
new file mode 100644
index 0000000000..3685b671b7
--- /dev/null
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -0,0 +1,12 @@
+"""Gradients for operators defined in sparse_ops.py."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import sparse_ops
+
+
+ops.NoGradient("SparseToDense")
+
+
+ops.NoGradient("SparseConcat")
+
+
+ops.NoGradient("SparseReorder")
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
new file mode 100644
index 0000000000..c0dca6156d
--- /dev/null
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -0,0 +1,458 @@
+"""## Sparse Tensor Representation.
+
+Tensorflow supports a `SparseTensor` representation for data that is sparse
+in multiple dimensions. Contrast this representation with `IndexedSlices`,
+which is efficient for representing tensors that are sparse in their first
+dimension, and dense along all other dimensions.
+
+@@SparseTensor
+@@SparseTensorValue
+
+## Sparse to Dense Conversion.
+
+@@sparse_to_dense
+@@sparse_tensor_to_dense
+@@sparse_to_indicator
+
+## Manipulation.
+
+@@sparse_concat
+@@sparse_reorder
+@@sparse_retain
+@@sparse_fill_empty_rows
+"""
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_sparse_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_sparse_ops import *
+# pylint: enable=wildcard-import
+# pylint: disable=protected-access
+
+
+def sparse_concat(concat_dim, sp_inputs, name=None):
+ """Concatenates a list of `SparseTensor` along the specified dimension.
+
+ Concatenation is with respect to the dense versions of each sparse input.
+  It is assumed that each input is a `SparseTensor` whose elements are ordered
+ along increasing dimension number.
+
+ All inputs' shapes must match, except for the concat dimension. The
+ `indices`, `values`, and `shapes` lists must have the same length.
+
+ The output shape is identical to the inputs', except along the concat
+ dimension, where it is the sum of the inputs' sizes along that dimension.
+
+  The output elements will be re-sorted to preserve the sort order along
+ increasing dimension number.
+
+ This op runs in `O(M log M)` time, where `M` is the total number of non-empty
+ values across all inputs. This is due to the need for an internal sort in
+ order to concatenate efficiently across an arbitrary dimension.
+
+ For example, if `concat_dim = 1` and the inputs are
+
+ sp_inputs[0]: shape = [2, 3]
+ [0, 2]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+ sp_inputs[1]: shape = [2, 4]
+ [0, 1]: "d"
+ [0, 2]: "e"
+
+ then the output will be
+
+ shape = [2, 7]
+ [0, 2]: "a"
+ [0, 4]: "d"
+ [0, 5]: "e"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+ Graphically this is equivalent to doing
+
+ [ a] concat [ d e ] = [ a d e ]
+ [b c ] [ ] [b c ]
+
+ Args:
+ concat_dim: Dimension to concatenate along.
+ sp_inputs: List of `SparseTensor` to concatenate.
+ name: A name prefix for the returned tensors (optional).
+
+ Returns:
+ A `SparseTensor` with the concatenated output.
+
+ Raises:
+ TypeError: If `sp_inputs` is not a list of `SparseTensor`.
+ """
+ if not isinstance(sp_inputs, list):
+ raise TypeError("Inputs must be a list")
+ if not all(isinstance(sp_input, ops.SparseTensor) for sp_input in sp_inputs):
+ raise TypeError("All inputs must be SparseTensors")
+
+ if len(sp_inputs) == 1: # Degenerate case of one tensor.
+ return sp_inputs[0]
+
+ inds = [sp_input.indices for sp_input in sp_inputs]
+ vals = [sp_input.values for sp_input in sp_inputs]
+ shapes = [sp_input.shape for sp_input in sp_inputs]
+
+ output_ind, output_val, output_shape = (
+ gen_sparse_ops._sparse_concat(
+ inds,
+ vals,
+ shapes,
+ concat_dim,
+ name=name))
+
+ return ops.SparseTensor(output_ind, output_val, output_shape)
+
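+
+# Illustrative sketch, not part of the original module; the two inputs below
+# are assumed values mirroring the docstring example, concatenated along the
+# column dimension.
+def _sparse_concat_usage_sketch():
+  a = ops.SparseTensor(
+      constant_op.constant([[0, 2], [1, 0], [1, 1]], types.int64),
+      constant_op.constant(["a", "b", "c"], types.string),
+      constant_op.constant([2, 3], types.int64))
+  b = ops.SparseTensor(
+      constant_op.constant([[0, 1], [0, 2]], types.int64),
+      constant_op.constant(["d", "e"], types.string),
+      constant_op.constant([2, 4], types.int64))
+  return sparse_concat(1, [a, b])
+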
+
+@ops.RegisterShape("SparseConcat")
+def _SparseConcatShape(op):
+ """Shape function for SparseConcat op."""
+ num_inputs = int(op.get_attr("N"))
+
+ # TF flattens and concatenates all list inputs, so reconstruct the lists here.
+ ind_shapes = [ind.get_shape().with_rank(2) for ind in op.inputs[0:num_inputs]]
+ val_shapes = [val.get_shape().with_rank(1)
+ for val in op.inputs[num_inputs:2 * num_inputs]]
+ shape_shapes = [shape.get_shape().with_rank(1)
+ for shape in op.inputs[2 * num_inputs:]]
+
+ output_ind_rows = tensor_shape.Dimension(0)
+ output_ind_cols = tensor_shape.Dimension(None)
+ output_val_elems = tensor_shape.Dimension(0)
+ output_shape_shape = tensor_shape.TensorShape(None)
+
+ for i in range(num_inputs):
+ num_elems_i = ind_shapes[i][0].merge_with(val_shapes[i][0])
+ output_ind_rows += num_elems_i
+ output_ind_cols = output_ind_cols.merge_with(ind_shapes[i][1])
+ output_val_elems += num_elems_i
+ output_shape_shape = output_shape_shape.merge_with(shape_shapes[i])
+
+ output_ind_shape = tensor_shape.matrix(output_ind_rows, output_ind_cols)
+ output_val_shape = tensor_shape.vector(output_val_elems)
+
+ return [output_ind_shape, output_val_shape, output_shape_shape]
+
+
+def sparse_reorder(sp_input, name=None):
+ """Reorders a `SparseTensor` into the canonical, row-major ordering.
+
+ Note that by convention, all sparse ops preserve the canonical ordering
+ along increasing dimension number. The only time ordering can be violated
+ is during manual manipulation of the indices and values to add entries.
+
+ Reordering does not affect the shape of the `SparseTensor`.
+
+ For example, if sp_input has shape `[4, 5]` and `indices` / `values`:
+
+ [0, 3]: b
+ [0, 1]: a
+ [3, 1]: d
+ [2, 0]: c
+
+ then the output will be a `SparseTensor` of shape `[4, 5]` and
+ `indices` / `values`:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+ [3, 1]: d
+
+ Args:
+ sp_input: The input `SparseTensor`.
+ name: A name prefix for the returned tensors (optional)
+
+ Returns:
+ A `SparseTensor` with the same shape and non-empty values, but in
+ canonical ordering.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ reordered_ind, reordered_val = (
+ gen_sparse_ops._sparse_reorder(
+ sp_input.indices,
+ sp_input.values,
+ sp_input.shape,
+ name=name))
+
+ return ops.SparseTensor(
+ reordered_ind, reordered_val, array_ops.identity(sp_input.shape))
+
+
+@ops.RegisterShape("SparseReorder")
+def _SparseReorderShape(op):
+ """Shape function for SparseReorder op."""
+ input_indices_shape = op.inputs[0].get_shape().with_rank(2)
+ input_values_shape = op.inputs[1].get_shape().with_rank(1)
+ unused_shape_shape = op.inputs[2].get_shape().with_rank(1)
+
+ return [input_indices_shape, input_values_shape]
+
+
+@ops.RegisterShape("SparseToDense")
+def _SparseToDenseShape(op):
+ input_shape = tensor_util.ConstantValue(op.inputs[1])
+ if input_shape is not None:
+ if np.ndim(input_shape) > 1:
+ raise ValueError("Input shape should be a vector")
+ return [tensor_shape.TensorShape(input_shape.tolist())]
+ else:
+ input_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ return [tensor_shape.unknown_shape(ndims=input_shape_shape.num_elements())]
+
+
+def sparse_tensor_to_dense(sp_input, default_value, name=None):
+ """Converts a `SparseTensor` into a dense tensor.
+
+ This op is a convenience wrapper around `sparse_to_dense` for `SparseTensor`s.
+
+ For example, if `sp_input` has shape `[3, 5]` and non-empty string values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+
+ and `default_value` is `x`, then the output will be a dense `[3, 5]`
+ string tensor with values:
+
+ [[x a x b x]
+ [x x x x x]
+ [c x x x x]]
+
+ Args:
+ sp_input: The input `SparseTensor`.
+ default_value: Scalar value to set for indices not specified in
+ `sp_input`.
+ name: A name prefix for the returned tensors (optional).
+
+ Returns:
+ A dense tensor with shape `sp_input.shape` and values specified by
+ the non-empty values in `sp_input`. Indices not in `sp_input` are assigned
+ `default_value`.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ return gen_sparse_ops.sparse_to_dense(
+ sp_input.indices,
+ sp_input.shape,
+ sp_input.values,
+ default_value,
+ name=name)
+
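+
+# Illustrative sketch, not part of the original module; densifies a small
+# assumed 2x3 string SparseTensor with "x" as the fill value, mirroring the
+# docstring example above.
+def _sparse_tensor_to_dense_usage_sketch():
+  sp = ops.SparseTensor(
+      constant_op.constant([[0, 1], [1, 0]], types.int64),
+      constant_op.constant(["a", "b"], types.string),
+      constant_op.constant([2, 3], types.int64))
+  return sparse_tensor_to_dense(sp, "x")
+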
+
+def sparse_to_indicator(sp_input, vocab_size, name=None):
+ """Converts a `SparseTensor` of ids into a dense bool indicator tensor.
+
+ The last dimension of `sp_input` is discarded and replaced with the values of
+ `sp_input`. If `sp_input.shape = [D0, D1, ..., Dn, K]`, then
+ `output.shape = [D0, D1, ..., Dn, vocab_size]`, where
+
+ output[d_0, d_1, ..., d_n, sp_input[d_0, d_1, ..., d_n, k]] = True
+
+ and False elsewhere in `output`.
+
+ For example, if `sp_input.shape = [2, 3, 4]` with non-empty values:
+
+ [0, 0, 0]: 0
+ [0, 1, 0]: 10
+ [1, 0, 3]: 103
+ [1, 1, 2]: 112
+ [1, 1, 3]: 113
+ [1, 2, 1]: 121
+
+ and `vocab_size = 200`, then the output will be a `[2, 3, 200]` dense bool
+ tensor with False everywhere except at positions
+
+ (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 112), (1, 1, 113), (1, 2, 121).
+
+ This op is useful for converting `SparseTensor`s into dense formats for
+ compatibility with ops that expect dense tensors.
+
+ The input `SparseTensor` must be in row-major order.
+
+ Args:
+ sp_input: A `SparseTensor` of type `int32` or `int64`.
+ vocab_size: The new size of the last dimension, with
+ `all(0 <= sp_input.values < vocab_size)`.
+ name: A name prefix for the returned tensors (optional)
+
+ Returns:
+ A dense bool indicator tensor representing the indices with specified value.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ with ops.op_scope([sp_input], name, "SparseToIndicator") as name:
+ indices_shape = array_ops.shape(sp_input.indices)
+ num_entries = indices_shape[0]
+ rank = indices_shape[1]
+
+ ids = sp_input.values
+ if ids.dtype != types.int64:
+ ids = math_ops.cast(ids, types.int64)
+
+    # Slice off the last dimension of indices, then tack on the ids.
+ indices_columns_to_preserve = array_ops.slice(
+ sp_input.indices, [0, 0], array_ops.pack([-1, rank - 1]))
+ new_indices = array_ops.concat(
+ 1, [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])])
+
+ new_values = array_ops.fill(array_ops.expand_dims(num_entries, 0), True)
+ new_shape = array_ops.concat(
+ 0, [array_ops.slice(sp_input.shape, [0],
+ array_ops.expand_dims(rank - 1, 0)), [vocab_size]])
+
+ sp_new = ops.SparseTensor(new_indices, new_values, new_shape)
+
+ return sparse_tensor_to_dense(sp_new, False, name=name)
+
+
+def sparse_retain(sp_input, to_retain):
+ """Retains specified non-empty values within a `SparseTensor`.
+
+ For example, if `sp_input` has shape `[4, 5]` and 4 non-empty string values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+ [3, 1]: d
+
+ and `to_retain = [True, False, False, True]`, then the output will
+ be a `SparseTensor` of shape `[4, 5]` with 2 non-empty values:
+
+ [0, 1]: a
+ [3, 1]: d
+
+ Args:
+ sp_input: The input `SparseTensor` with `N` non-empty elements.
+ to_retain: A bool vector of length `N` with `M` true values.
+
+ Returns:
+ A `SparseTensor` with the same shape as the input and `M` non-empty
+ elements corresponding to the true positions in `to_retain`.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ to_retain = ops.convert_to_tensor(to_retain)
+
+ # Shape checking, if shape is known at graph construction time
+ retain_shape = to_retain.get_shape()
+ retain_shape.assert_has_rank(1)
+ sp_input.values.get_shape()[0].merge_with(retain_shape[0])
+
+ where_true = array_ops.reshape(array_ops.where(to_retain), [-1])
+ new_indices = array_ops.gather(sp_input.indices, where_true)
+ new_values = array_ops.gather(sp_input.values, where_true)
+ return ops.SparseTensor(
+ new_indices, new_values, array_ops.identity(sp_input.shape))
+
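+
+# Illustrative sketch, not part of the original module; the input values and
+# retain mask are assumptions showing how sparse_retain drops entries whose
+# mask entry is False.
+def _sparse_retain_usage_sketch():
+  sp = ops.SparseTensor(
+      constant_op.constant([[0, 0], [0, 2], [1, 1]], types.int64),
+      constant_op.constant([1, 2, 3], types.int32),
+      constant_op.constant([2, 3], types.int64))
+  # Keep the first and last non-empty entries, drop the middle one.
+  return sparse_retain(sp, [True, False, True])
+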
+
+def sparse_fill_empty_rows(sp_input, default_value, name=None):
+ """Fills empty rows in the input 2-D `SparseTensor` with a default value.
+
+ This op adds entries with the specified `default_value` at index
+ `[row, 0]` for any row in the input that does not already have a value.
+
+ For example, suppose `sp_input` has shape `[5, 6]` and non-empty values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+ [3, 1]: d
+
+ Rows 1 and 4 are empty, so the output will be of shape `[5, 6]` with values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [1, 0]: default_value
+ [2, 0]: c
+ [3, 1]: d
+ [4, 0]: default_value
+
+ Note that the input may have empty columns at the end, with no effect on
+ this op.
+
+ The output `SparseTensor` will be in row-major order and will have the
+ same shape as the input.
+
+ This op also returns an indicator vector such that
+
+ empty_row_indicator[i] = True iff row i was an empty row.
+
+ Args:
+ sp_input: A `SparseTensor` with shape `[N, M]`.
+ default_value: The value to fill for empty rows, with the same type as
+      `sp_input`.
+ name: A name prefix for the returned tensors (optional)
+
+ Returns:
+ sp_ordered_output: A `SparseTensor` with shape `[N, M]`, and with all empty
+ rows filled in with `default_value`.
+ empty_row_indicator: A bool vector of length `N` indicating whether each
+ input row was empty.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ with ops.op_scope([sp_input], name, "SparseFillEmptyRows"):
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=sp_input.values.dtype)
+
+ num_rows = math_ops.cast(sp_input.shape[0], types.int32)
+ all_row_indices = math_ops.cast(
+ math_ops.range(0, num_rows, 1), types.int64)
+ empty_row_indices, _ = array_ops.list_diff(
+ all_row_indices, sp_input.indices[:, 0])
+ empty_row_indicator = gen_sparse_ops.sparse_to_dense(
+ empty_row_indices, array_ops.expand_dims(sp_input.shape[0], -1), True,
+ False)
+
+ empty_row_indices_as_column = array_ops.reshape(empty_row_indices, [-1, 1])
+ additional_indices = array_ops.concat(
+ 1,
+ [empty_row_indices_as_column,
+ array_ops.zeros_like(empty_row_indices_as_column)])
+ additional_values = array_ops.fill(array_ops.shape(empty_row_indices),
+ default_value)
+
+ all_indices_unordered = array_ops.concat(
+ 0, [sp_input.indices, additional_indices])
+ all_values_unordered = array_ops.concat(
+ 0, [sp_input.values, additional_values])
+ sp_unordered_output = ops.SparseTensor(
+ all_indices_unordered, all_values_unordered, sp_input.shape)
+ sp_ordered_output = sparse_reorder(sp_unordered_output)
+
+ return sp_ordered_output, empty_row_indicator
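+
+
+# Illustrative sketch, not part of the original module; the input mirrors the
+# docstring example above and fills rows 1 and 4 with the assumed default -1.
+def _sparse_fill_empty_rows_usage_sketch():
+  sp = ops.SparseTensor(
+      constant_op.constant([[0, 1], [0, 3], [2, 0], [3, 1]], types.int64),
+      constant_op.constant([1, 2, 3, 4], types.int32),
+      constant_op.constant([5, 6], types.int64))
+  return sparse_fill_empty_rows(sp, -1)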
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
new file mode 100644
index 0000000000..07a5e6c6da
--- /dev/null
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -0,0 +1,212 @@
+"""Tests for Python ops defined in sparse_ops."""
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import googletest
+
+
+class SparseToIndicatorTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self, dtype):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, dtype),
+ constant_op.constant(shape, types.int64))
+
+ def _SparseTensor_2x3x4(self, dtype):
+ ind = np.array([
+ [0, 0, 1],
+ [0, 1, 0], [0, 1, 2],
+ [1, 0, 3],
+ [1, 1, 1], [1, 1, 3],
+ [1, 2, 2]])
+ val = np.array([1, 10, 12, 103, 111, 113, 122])
+ shape = np.array([2, 3, 4])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, dtype),
+ constant_op.constant(shape, types.int64))
+
+ def testInt32(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6(types.int32)
+ output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
+
+ expected_output = np.zeros((5, 50), dtype=np.bool)
+ expected_trues = ((0, 0), (1, 10), (1, 13), (1, 14), (3, 32), (3, 33))
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+ def testInt64(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6(types.int64)
+ output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
+
+ expected_output = np.zeros((5, 50), dtype=np.bool)
+ expected_trues = [(0, 0), (1, 10), (1, 13), (1, 14), (3, 32), (3, 33)]
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+ def testHigherRank(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_2x3x4(types.int64)
+ output = sparse_ops.sparse_to_indicator(sp_input, 200).eval()
+
+ expected_output = np.zeros((2, 3, 200), dtype=np.bool)
+ expected_trues = [(0, 0, 1), (0, 1, 10), (0, 1, 12),
+ (1, 0, 103), (1, 1, 111), (1, 1, 113), (1, 2, 122)]
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+
+class SparseRetainTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.int32),
+ constant_op.constant(shape, types.int64))
+
+ def testBasic(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.array([1, 0, 0, 1, 1, 0], dtype=np.bool)
+ sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
+
+ output = sess.run(sp_output)
+
+ self.assertAllEqual(output.indices, [[0, 0], [1, 4], [3, 2]])
+ self.assertAllEqual(output.values, [0, 14, 32])
+ self.assertAllEqual(output.shape, [5, 6])
+
+ def testRetainNone(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.zeros((6,), dtype=np.bool)
+ sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
+
+ output = sess.run(sp_output)
+
+ self.assertAllEqual(output.indices, np.array([]).reshape((0, 2)))
+ self.assertAllEqual(output.values, [])
+ self.assertAllEqual(output.shape, [5, 6])
+
+ def testMismatchedRetainShape(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.array([1, 0, 0, 1, 0], dtype=np.bool)
+ with self.assertRaises(ValueError):
+ sparse_ops.sparse_retain(sp_input, to_retain)
+
+
+class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.int32),
+ constant_op.constant(shape, types.int64))
+
+ def _SparseTensor_String5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array(["a", "b", "c", "d", "e", "f"])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.string),
+ constant_op.constant(shape, types.int64))
+
+ def _SparseTensor_2x6(self):
+ ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4]])
+ val = np.array([0, 10, 13, 14])
+ shape = np.array([2, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.int32),
+ constant_op.constant(shape, types.int64))
+
+ def testFillNumber(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(
+ output.indices,
+ [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14, -1, 32, 33, -1])
+ self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(empty_row_indicator_out,
+ np.array([0, 0, 1, 0, 1]).astype(np.bool))
+
+ def testFillString(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_String5x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, ""))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(
+ output.indices,
+ [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
+ self.assertAllEqual(output.values, ["a", "b", "c", "d", "", "e", "f", ""])
+ self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(empty_row_indicator_out,
+ np.array([0, 0, 1, 0, 1]).astype(np.bool))
+
+ def testNoEmptyRows(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_2x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(output.indices, [[0, 0], [1, 0], [1, 3], [1, 4]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14])
+ self.assertAllEqual(output.shape, [2, 6])
+ self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
new file mode 100644
index 0000000000..beef8e75b5
--- /dev/null
+++ b/tensorflow/python/ops/standard_ops.py
@@ -0,0 +1,41 @@
+# pylint: disable=wildcard-import,unused-import
+"""Import names of Tensor Flow standard Ops."""
+
+# Imports the following modules so that @RegisterGradient get executed.
+from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import data_flow_grad
+from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import state_grad
+
+from tensorflow.python.ops.array_ops import *
+from tensorflow.python.ops.clip_ops import *
+# TODO(vrv): Switch to import * once we're okay with exposing the module.
+from tensorflow.python.ops.control_flow_ops import group
+from tensorflow.python.ops.control_flow_ops import no_op
+from tensorflow.python.ops.control_flow_ops import tuple
+from tensorflow.python.ops.data_flow_ops import *
+from tensorflow.python.ops.gradients import *
+from tensorflow.python.ops.init_ops import *
+from tensorflow.python.ops.io_ops import *
+from tensorflow.python.ops.linalg_ops import *
+from tensorflow.python.ops.logging_ops import *
+from tensorflow.python.ops.math_ops import *
+from tensorflow.python.ops.numerics import *
+from tensorflow.python.ops.parsing_ops import *
+from tensorflow.python.ops.random_ops import *
+from tensorflow.python.ops.sparse_ops import *
+from tensorflow.python.ops.state_ops import assign
+from tensorflow.python.ops.state_ops import assign_add
+from tensorflow.python.ops.state_ops import assign_sub
+from tensorflow.python.ops.state_ops import count_up_to
+from tensorflow.python.ops.state_ops import scatter_add
+from tensorflow.python.ops.state_ops import scatter_sub
+from tensorflow.python.ops.state_ops import scatter_update
+from tensorflow.python.ops.string_ops import *
+from tensorflow.python.ops.summary_ops import histogram_summary
+from tensorflow.python.ops.summary_ops import image_summary
+from tensorflow.python.ops.summary_ops import merge_all_summaries
+from tensorflow.python.ops.summary_ops import merge_summary
+from tensorflow.python.ops.summary_ops import scalar_summary
+from tensorflow.python.ops.variable_scope import *
+from tensorflow.python.ops.variables import *
diff --git a/tensorflow/python/ops/state_grad.py b/tensorflow/python/ops/state_grad.py
new file mode 100644
index 0000000000..d9b084693c
--- /dev/null
+++ b/tensorflow/python/ops/state_grad.py
@@ -0,0 +1,18 @@
+"""Gradients for operators defined in state_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import state_ops
+
+ops.NoGradient("Assign")
+
+
+ops.NoGradient("AssignAdd")
+
+
+ops.NoGradient("AssignSub")
+
+
+ops.NoGradient("ScatterAdd")
+
+
+ops.NoGradient("ScatterSub")
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
new file mode 100644
index 0000000000..1c8f38b94c
--- /dev/null
+++ b/tensorflow/python/ops/state_ops.py
@@ -0,0 +1,189 @@
+"""## Variables
+
+@@Variable
+
+## Variable helper functions
+
+TensorFlow provides a set of functions to help manage the set of variables
+collected in the graph.
+
+@@all_variables
+@@trainable_variables
+
+@@initialize_all_variables
+@@initialize_variables
+@@assert_variables_initialized
+
+## Saving and Restoring Variables.
+
+@@Saver
+
+@@latest_checkpoint
+
+@@get_checkpoint_state
+@@update_checkpoint_state
+
+## Sharing Variables
+
+TensorFlow provides several classes and operations that you can use to
+create variables contingent on certain conditions.
+
+@@get_variable
+@@get_variable_scope
+@@variable_scope
+
+@@constant_initializer
+@@random_normal_initializer
+@@truncated_normal_initializer
+@@random_uniform_initializer
+@@uniform_unit_scaling_initializer
+@@zeros_initializer
+
+## Sparse Variable Updates
+
+The sparse update ops modify a subset of the entries in a dense `Variable`,
+either overwriting the entries or adding / subtracting a delta. These are
+useful for training embedding models and similar lookup-based networks, since
+only a small subset of embedding vectors change in any given step.
+
+Since a sparse update of a large tensor may be generated automatically during
+gradient computation (as in the gradient of [`tf.gather`](array_ops.md#gather)),
+an [`IndexedSlices`](#IndexedSlices) class is provided that encapsulates a set
+of sparse indices and values. `IndexedSlices` objects are detected and handled
+automatically by the optimizers in most cases.
+
+@@scatter_update
+@@scatter_add
+@@scatter_sub
+@@sparse_mask
+@@IndexedSlices
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_state_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_state_ops import *
+
+
+# pylint: disable=protected-access
+def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
+ shared_name=""):
+ """Create a variable Operation.
+
+ See also variables.Variable.
+
+ Args:
+ shape: The shape of the tensor managed by this variable
+ dtype: The underlying type of the tensor values.
+ name: optional name to use for the variable op.
+ set_shape: If True, set the shape property of the returned Tensor to
+ the shape argument.
+ container: An optional string. Defaults to "".
+ If non-empty, this variable is placed in the given container.
+ Otherwise, a default container is used.
+ shared_name: An optional string. Defaults to "".
+ If non-empty, this variable is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+
+ Returns:
+ A variable tensor.
+ """
+ ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name,
+ container=container, shared_name=shared_name)
+ # TODO(mrry): Move this to where it is used, so we can get rid of this op
+ # wrapper?
+ if set_shape:
+ ret.set_shape(shape)
+ return ret
+
+
+# NOTE(mrry): Shapes are conditionally set in the Python wrapper.
+ops.RegisterShape("Variable")(common_shapes.unknown_shape)
+
+
+@ops.RegisterShape("TemporaryVariable")
+def _TemporaryVariableShape(op):
+ """Shape function for the TemporaryVariable op."""
+ shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))
+ return [tensor_shape.TensorShape(shape)]
+
+
+@ops.RegisterShape("DestroyTemporaryVariable")
+def _DestroyTemporaryVariableShape(op):
+ """Shape function for the DestroyTemporaryVariable op."""
+ return [op.inputs[0].get_shape()]
+
+
+def init_variable(v, init, name="init"):
+ """Initializes variable with "init".
+
+ This op does the following:
+ if init is a Tensor, v = init
+ if callable(init): v = init(VariableShape(v), v.dtype)
+
+ Args:
+ v: Variable to initialize
+ init: Tensor to assign to v,
+ Or an object convertible to Tensor e.g. nparray,
+ Or an Initializer that generates a tensor given the shape and type of v.
+ An "Initializer" is a callable that returns a tensor that "v" should be
+ set to. It will be called as init(shape, dtype).
+ name: Optional name for the op.
+
+ Returns:
+ The operation that initializes v.
+ """
+ with ops.op_scope([v, init], None, v.op.name + "/"):
+ with ops.name_scope(name) as scope:
+ with ops.device(v.device or ops.get_default_graph().get_default_device()):
+ if callable(init):
+ assert v.get_shape().is_fully_defined(), "Variable shape unknown."
+ # TODO(mrry): Convert to v.shape when the property and
+ # accessor are reconciled (and all initializers support
+ # tf.TensorShape objects).
+ value = init(v.get_shape().as_list(), v.dtype.base_dtype)
+ value = ops.convert_to_tensor(value, name="value")
+ return assign(v, value, name=scope)
+ else:
+ init = ops.convert_to_tensor(init, name="init")
+ return assign(v, init, name=scope)
+
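+
+# Illustrative sketch, not part of the original module; the shape, dtype, and
+# initial values below are assumptions showing variable_op and init_variable
+# used together.
+def _init_variable_usage_sketch():
+  from tensorflow.python.framework import types  # local import for the sketch
+  v = variable_op([3], types.float32)
+  # init_variable converts the Python list to a tensor and emits an Assign op.
+  return init_variable(v, [1.0, 2.0, 3.0])
+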
+
+@ops.RegisterShape("Assign")
+def _AssignShape(op):
+ """Shape function for the Assign op."""
+ if op.get_attr("validate_shape"):
+ # NOTE(mrry): Return a known shape here. This makes it awkward to
+ # chain a validated-shape assignment and a reshaping assignment,
+ # but that is a sufficiently niche case that supporting it does
+ # not seem worthwhile.
+ return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
+ return [op.inputs[1].get_shape()]
+
+
+@ops.RegisterShape("AssignAdd")
+@ops.RegisterShape("AssignSub")
+def _AssignUpdateShape(op):
+ """Shape function for the AssignAdd and AssignSub dense update ops."""
+ return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
+
+
+@ops.RegisterShape("CountUpTo")
+def _CountUpToShape(op):
+ """Shape function for the CountUpTo op."""
+ return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
+
+
+@ops.RegisterShape("ScatterAdd")
+@ops.RegisterShape("ScatterSub")
+@ops.RegisterShape("ScatterUpdate")
+def _ScatterUpdateShape(op):
+ """Shape function for the sparse update ops."""
+ var_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape()
+ unused_updates_shape = op.inputs[2].get_shape().merge_with(
+ indices_shape.concatenate(var_shape[1:]))
+ return [var_shape]
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
new file mode 100644
index 0000000000..8181fe9a2a
--- /dev/null
+++ b/tensorflow/python/ops/string_ops.py
@@ -0,0 +1,12 @@
+"""String Ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_string_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_string_ops import *
+
+ops.NoGradient("StringToHashBucket")
+
+ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py
new file mode 100644
index 0000000000..d65fd1ea7c
--- /dev/null
+++ b/tensorflow/python/ops/summary_ops.py
@@ -0,0 +1,177 @@
+"""Summary Operations."""
+# pylint: disable=wildcard-import,protected-access
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_summary_ops
+from tensorflow.python.ops.gen_summary_ops import *
+
+
+def _Collect(val, collections, default_collections):
+ if collections is None:
+ collections = default_collections
+ for key in collections:
+ ops.add_to_collection(key, val)
+
+
+def histogram_summary(tag, values, collections=None, name=None):
+ """Outputs a `Summary` protocol buffer with a histogram.
+
+ The generated
+ [`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+ has one summary value containing a histogram for `values`.
+
+ This op reports an `OutOfRange` error if any value is not finite.
+
+ Args:
+ tag: A `string` `Tensor`. 0-D. Tag to use for the summary value.
+ values: A `float32` `Tensor`. Any shape. Values to use to build the
+ histogram.
+ collections: Optional list of graph collections keys. The new summary op is
+ added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer.
+ """
+ with ops.op_scope([tag, values], name, "HistogramSummary") as scope:
+ val = gen_summary_ops._histogram_summary(
+ tag=tag, values=values, name=scope)
+ _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
+ return val
+
+
+def image_summary(tag, tensor, max_images=None, collections=None, name=None):
+ """Outputs a `Summary` protocol buffer with images.
+
+ The summary has up to `max_images` summary values containing images. The
+ images are built from `tensor` which must be 4-D with shape `[batch_size,
+ height, width, channels]` and where `channels` can be:
+
+ * 1: `tensor` is interpreted as Grayscale.
+ * 3: `tensor` is interpreted as RGB.
+ * 4: `tensor` is interpreted as RGBA.
+
+ The images have the same number of channels as the input tensor. Their values
+ are normalized, one image at a time, to fit in the range `[0, 255]`. The
+ op uses two different normalization algorithms:
+
+ * If the input values are all positive, they are rescaled so the largest one
+ is 255.
+
+ * If any input value is negative, the values are shifted so input value 0.0
+ is at 127. They are then rescaled so that either the smallest value is 0,
+ or the largest one is 255.
+
+ The `tag` argument is a scalar `Tensor` of type `string`. It is used to
+ build the `tag` of the summary values:
+
+ * If `max_images` is 1, the summary value tag is '*tag*/image'.
+ * If `max_images` is greater than 1, the summary value tags are
+ generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
+
+ Args:
+ tag: A scalar `Tensor` of type `string`. Used to build the `tag`
+ of the summary values.
+ tensor: A 4-D `float32` `Tensor` of shape `[batch_size, height, width,
+ channels]` where `channels` is 1, 3, or 4.
+ max_images: Max number of batch elements to generate images for.
+    collections: Optional list of graph collections keys. The new summary op
+      is added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer.
+ """
+ with ops.op_scope([tag, tensor], name, "ImageSummary") as scope:
+ val = gen_summary_ops._image_summary(
+ tag=tag, tensor=tensor, max_images=max_images, name=scope)
+ _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
+ return val
+
+
+def merge_summary(inputs, collections=None, name=None):
+ """Merges summaries.
+
+ This op creates a
+ [`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+ protocol buffer that contains the union of all the values in the input
+ summaries.
+
+ When the Op is run, it reports an `InvalidArgument` error if multiple values
+ in the summaries to merge use the same tag.
+
+ Args:
+ inputs: A list of `string` `Tensor` objects containing serialized `Summary`
+ protocol buffers.
+ collections: Optional list of graph collections keys. The new summary op is
+ added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer resulting from the merging.
+ """
+ with ops.op_scope(inputs, name, "MergeSummary") as scope:
+    val = gen_summary_ops._merge_summary(inputs=inputs, name=scope)
+ _Collect(val, collections, [])
+ return val
+
+
+def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
+ """Merges all summaries collected in the default graph.
+
+ Args:
+ key: `GraphKey` used to collect the summaries. Defaults to
+ `GraphKeys.SUMMARIES`.
+
+ Returns:
+ If no summaries were collected, returns None. Otherwise returns a scalar
+    `Tensor` of type `string` containing the serialized `Summary` protocol
+ buffer resulting from the merging.
+ """
+ summary_ops = ops.get_collection(key)
+ if not summary_ops:
+ return None
+ else:
+ return merge_summary(summary_ops)
+
+
+def scalar_summary(tags, values, collections=None, name=None):
+ """Outputs a `Summary` protocol buffer with scalar values.
+
+ The input `tags` and `values` must have the same shape. The generated
+ summary has a summary value for each tag-value pair in `tags` and `values`.
+
+ Args:
+ tags: A 1-D `string` `Tensor`. Tags for the summaries.
+ values: A 1-D `float32` or `float64` Tensor. Values for the summaries.
+ collections: Optional list of graph collections keys. The new summary op is
+ added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer.
+ """
+ with ops.op_scope([tags, values], name, "ScalarSummary") as scope:
+ val = gen_summary_ops._scalar_summary(tags=tags, values=values, name=scope)
+ _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
+ return val
+
+
+ops.NoGradient("HistogramAccumulatorSummary")
+ops.NoGradient("HistogramSummary")
+ops.NoGradient("ImageSummary")
+ops.NoGradient("MergeSummary")
+ops.NoGradient("ScalarSummary")
+
+
+@ops.RegisterShape("HistogramAccumulatorSummary")
+@ops.RegisterShape("HistogramSummary")
+@ops.RegisterShape("ImageSummary")
+@ops.RegisterShape("MergeSummary")
+@ops.RegisterShape("ScalarSummary")
+def _ScalarShape(unused_op):
+ return [tensor_shape.scalar()]
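+
+
+# Illustrative sketch (not part of the original change): how the summary ops
+# above are typically combined. The argument tensors are assumed to be built
+# by the caller; the helper name is hypothetical.
+def _example_build_summaries(loss, activations):
+  """Builds per-tensor summaries and merges every summary in the graph."""
+  scalar_summary("loss", loss)  # One scalar value per step.
+  histogram_summary("activations", activations)  # Value distribution per step.
+  # merge_all_summaries() collects every op added to GraphKeys.SUMMARIES.
+  return merge_all_summaries()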
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
new file mode 100644
index 0000000000..c9c2cac0a5
--- /dev/null
+++ b/tensorflow/python/ops/variable_scope.py
@@ -0,0 +1,333 @@
+"""A class to store named variables and a scope operator to manage sharing."""
+
+import contextlib
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import logging
+
+
+class _VariableStore(object):
+ """Variable store that carries a number of named Variables.
+
+  New variable names and new variables can be created; variables are
+  initialized with the initializer passed to `get_variable`, or with a
+  default `UniformUnitScalingInitializer` if none is given.
+
+  Attributes:
+    vars: a dictionary with string names (same as passed in get_variable)
+      as keys and the corresponding TensorFlow Variables as values.
+ """
+
+ def __init__(self):
+ """Create a variable store."""
+ self._vars = {} # A dictionary of the stored TensorFlow variables.
+
+ def get_variable(self, name, shape=None, dtype=types.float32,
+ initializer=None, reuse=None, trainable=True,
+ collections=None):
+ """Gets an existing variable with these parameters or create a new one.
+
+ If a variable with the given name is already stored, we return the stored
+ variable. Otherwise, we create a new one.
+
+ Set `reuse` to `True` when you only want to reuse existing Variables.
+ Set `reuse` to `False` when you only want to create new Variables.
+ If `reuse` is `None` (the default), both new and existing variables are
+ returned.
+
+    If `initializer` is `None` (the default), a new
+    `UniformUnitScalingInitializer` is used.
+
+ Args:
+ name: the name of the new or existing variable.
+ shape: shape of the new or existing variable.
+ dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
+ initializer: initializer for the variable.
+ reuse: a Boolean or `None`. Controls reuse or creation of variables.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
+ collections: List of graph collections keys to add the Variable to.
+ Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
+
+ Returns:
+ The created or existing variable.
+
+ Raises:
+ ValueError: when creating a new variable and shape is not declared,
+ when reusing a variable and specifying a conflicting shape,
+ or when violating reuse during variable creation.
+ """
+ should_check = reuse is not None
+ dtype = types.as_dtype(dtype)
+ shape = tensor_shape.as_shape(shape)
+ if name in self._vars:
+ # Here we handle the case when returning an existing variable.
+ if should_check and not reuse:
+ raise ValueError("Over-sharing: Variable %s already exists, disallowed."
+ " Did you mean to set reuse=True in VarScope?" % name)
+ found_var = self._vars[name]
+ if not shape.is_compatible_with(found_var.get_shape()):
+ raise ValueError("Trying to share variable %s, but specified shape %s"
+ " and found shape %s." % (name, str(shape),
+ str(found_var.get_shape())))
+ if not dtype.is_compatible_with(found_var.dtype):
+ dtype_str = dtype.name
+ found_type_str = found_var.dtype.name
+ raise ValueError("Trying to share variable %s, but specified dtype %s"
+ " and found dtype %s." % (name, str(dtype_str),
+ str(found_type_str)))
+ return found_var
+
+ # The code below handles only the case of creating a new variable.
+ if should_check and reuse:
+ raise ValueError("Under-sharing: Variable %s does not exist, disallowed."
+ " Did you mean to set reuse=None in VarScope?" % name)
+ if not shape.is_fully_defined():
+ raise ValueError("Shape of a new variable (%s) must be fully defined, "
+ "but instead was %s." % (name, shape))
+ if initializer is None:
+ initializer = init_ops.uniform_unit_scaling_initializer()
+ with ops.name_scope(name + "/Initializer/"):
+ init_val = initializer(shape.as_list(), dtype=dtype)
+ v = variables.Variable(init_val, name=name, trainable=trainable,
+ collections=collections)
+ self._vars[name] = v
+ logging.info("Created variable %s with shape %s and init %s", v.name,
+ format(shape), str(initializer))
+ return v
+
+
+class _VariableScope(object):
+ """Variable scope object to carry defaults to provide to get_variable.
+
+ Many of the arguments we need for get_variable in a variable store are most
+ easily handled with a context. This object is used for the defaults.
+
+ Attributes:
+ name: name of the current scope, used as prefix in get_variable.
+ initializer: default initializer passed to get_variable.
+ reuse: Boolean or None, setting the reuse in get_variable.
+ """
+
+ def __init__(self, reuse, name="", initializer=None):
+ self._name = name
+ self._initializer = initializer
+ self._reuse = reuse
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def reuse(self):
+ return self._reuse
+
+ @property
+ def initializer(self):
+ return self._initializer
+
+ def reuse_variables(self):
+ """Reuse variables in this scope."""
+ self._reuse = True
+
+ def set_initializer(self, initializer):
+ """Set initializer for this scope."""
+ self._initializer = initializer
+
+ def get_variable(self, var_store, name, shape=None, dtype=types.float32,
+ initializer=None, trainable=True, collections=None):
+ """Gets an existing variable with this name or create a new one."""
+ if initializer is None and self._initializer:
+ initializer = self._initializer
+ full_name = self.name + "/" + name if self.name else name
+ # Variable names only depend on variable_scope (full_name here),
+ # not name_scope, so we reset it below for the time of variable creation.
+ with ops.name_scope(None):
+ return var_store.get_variable(full_name, shape, dtype, initializer,
+ self.reuse, trainable, collections)
+
+
+_VARSTORE_KEY = ("__variable_store",)
+_VARSCOPE_KEY = ("__varscope",)
+
+
+def get_variable_scope():
+ """Returns the current variable scope."""
+ scope = ops.get_collection(_VARSCOPE_KEY)
+ if scope: # This collection has at most 1 element, the default scope at [0].
+ return scope[0]
+ scope = _VariableScope(False)
+ ops.add_to_collection(_VARSCOPE_KEY, scope)
+ return scope
+
+
+def _get_default_variable_store():
+ store = ops.get_collection(_VARSTORE_KEY)
+ if store:
+ return store[0]
+ store = _VariableStore()
+ ops.add_to_collection(_VARSTORE_KEY, store)
+ return store
+
+
+def get_variable(name, shape=None, dtype=types.float32, initializer=None,
+ trainable=True, collections=None):
+ """Gets an existing variable with these parameters or create a new one.
+
+ This function prefixes the name with the current variable scope
+ and performs reuse checks. See the
+ [Variable Scope How To](../../how_tos/variable_scope/index.md)
+ for an extensive description of how reusing works. Here is a basic example:
+
+  ```python
+  with tf.variable_scope("foo"):
+    v = tf.get_variable("v", [1])  # v.name == "foo/v:0"
+    w = tf.get_variable("w", [1])  # w.name == "foo/w:0"
+  with tf.variable_scope("foo", reuse=True):
+    v1 = tf.get_variable("v")  # The same as v above.
+  ```
+
+  If `initializer` is `None` (the default), the default initializer set in
+  the current variable scope is used. If that one is `None` too, a
+  `UniformUnitScalingInitializer` will be used.
+
+ Args:
+ name: the name of the new or existing variable.
+ shape: shape of the new or existing variable.
+ dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
+ initializer: initializer for the variable if one is created.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
+ collections: List of graph collections keys to add the Variable to.
+ Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
+
+ Returns:
+ The created or existing variable.
+
+ Raises:
+ ValueError: when creating a new variable and shape is not declared,
+ or when violating reuse during variable creation. Reuse is set inside
+ `variable_scope`.
+ """
+ return get_variable_scope().get_variable(_get_default_variable_store(), name,
+ shape, dtype, initializer,
+ trainable, collections)
+
+
+@contextlib.contextmanager
+def variable_scope(name_or_scope, reuse=None, initializer=None):
+ """Returns a context for variable scope.
+
+  Variable scope allows you to create new variables and to share already
+  created ones while providing checks against accidental creation or sharing.
+  For details, see the
+  [Variable Scope How To](../../how_tos/variable_scope/index.md); here we
+  present only a few basic examples.
+
+ Simple example of how to create a new variable:
+
+ ```python
+ with tf.variable_scope("foo"):
+ with tf.variable_scope("bar"):
+ v = tf.get_variable("v", [1])
+ assert v.name == "foo/bar/v:0"
+ ```
+
+ Basic example of sharing a variable:
+
+ ```python
+ with tf.variable_scope("foo"):
+ v = get_variable("v", [1])
+ with tf.variable_scope("foo", reuse=True):
+ v1 = tf.get_variable("v", [1])
+ assert v1 == v
+ ```
+
+ Sharing a variable by capturing a scope and setting reuse:
+
+ ```python
+ with tf.variable_scope("foo") as scope.
+ v = get_variable("v", [1])
+ scope.reuse_variables()
+ v1 = tf.get_variable("v", [1])
+ assert v1 == v
+ ```
+
+ To prevent accidental sharing of variables, we raise an exception when
+ getting an existing variable in a non-reusing scope.
+
+ ```python
+ with tf.variable_scope("foo") as scope.
+ v = get_variable("v", [1])
+ v1 = tf.get_variable("v", [1])
+ # Raises ValueError("... v already exists ...").
+ ```
+
+ Similarly, we raise an exception when trying to get a variable that
+ does not exist in reuse mode.
+
+ ```python
+ with tf.variable_scope("foo", reuse=True):
+ v = get_variable("v", [1])
+ # Raises ValueError("... v does not exists ...").
+ ```
+
+ Note that the `reuse` flag is inherited: if we open a reusing scope,
+ then all its sub-scopes become reusing as well.
+
+ Args:
+ name_or_scope: `string` or `VariableScope`: the scope to open.
+ reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as
+ well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
+ initializer: default initializer for variables within this scope.
+
+ Yields:
+    A scope that can be captured and reused.
+
+ Raises:
+ ValueError: when trying to reuse within a create scope, or create within
+ a reuse scope, or if reuse is not `None` or `True`.
+ TypeError: when the types of some arguments are not appropriate.
+ """
+ if not isinstance(name_or_scope, (_VariableScope, basestring)):
+ raise TypeError("VariableScope: name_scope must be a string or "
+ "VariableScope.")
+ if reuse not in [None, True]:
+ raise ValueError("VariableScope reuse parameter must be True or None.")
+ if not reuse and isinstance(name_or_scope, (_VariableScope)):
+ logging.info("Passing VariableScope to a non-reusing scope, intended?")
+ if reuse and isinstance(name_or_scope, (basestring)):
+ logging.info("Re-using string-named scope, consider capturing as object.")
+ get_variable_scope() # Ensure that a default exists, then get a pointer.
+ default_varscope = ops.get_collection(_VARSCOPE_KEY)
+ try:
+ old = default_varscope[0]
+ reuse = reuse or old.reuse # Re-using is inherited by sub-scopes.
+ if isinstance(name_or_scope, _VariableScope):
+ # Handler for the case when we jump to a shared scope.
+ # In this case, we leave the current name_scope unchanged.
+ # We create a new VariableScope (default_varscope[0]) that contains
+ # a copy of the provided shared scope, possibly with changed reuse
+ # and initializer, if the user requested this.
+ default_varscope[0] = _VariableScope(reuse, name_or_scope.name,
+ name_or_scope.initializer)
+ if initializer:
+ default_varscope[0].set_initializer(initializer)
+ yield default_varscope[0]
+ else:
+ # Handler for the case when we just prolong current variable scope.
+ # In this case we prolong the current name_scope and create a new
+ # VariableScope with name extended by the provided one, and inherited
+ # reuse and initializer (except if the user provided values to set).
+ with ops.name_scope(name_or_scope):
+ new_name = old.name + "/" + name_or_scope if old.name else name_or_scope
+ default_varscope[0] = _VariableScope(reuse, name=new_name,
+ initializer=old.initializer)
+ if initializer:
+ default_varscope[0].set_initializer(initializer)
+ yield default_varscope[0]
+ finally:
+ default_varscope[0] = old
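+
+
+# Illustrative sketch (not part of the original change): a common pattern
+# enabled by variable_scope/get_variable -- a "layer" function that creates
+# its weights on the first call and reuses them when the scope is reused.
+# Names and shapes are hypothetical.
+def _example_linear_weights(shape):
+  """Creates (or reuses) a weight matrix under the current variable scope."""
+  return get_variable("weights", shape=shape,
+                      initializer=init_ops.uniform_unit_scaling_initializer())
+
+# Intended usage (illustrative):
+#
+#   with variable_scope("layer1"):
+#     w = _example_linear_weights([784, 10])        # creates "layer1/weights"
+#   with variable_scope("layer1", reuse=True):
+#     w_again = _example_linear_weights([784, 10])  # returns the same variable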
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
new file mode 100644
index 0000000000..dafd3b8bdc
--- /dev/null
+++ b/tensorflow/python/ops/variables.py
@@ -0,0 +1,569 @@
+"""Variable class."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+
+
+class Variable(object):
+ """See the [Variables How To](../../how_tos/variables/index.md) for a high
+ level overview.
+
+ A variable maintains state in the graph across calls to `run()`. You add a
+ variable to the graph by constructing an instance of the class `Variable`.
+
+ The `Variable()` constructor requires an initial value for the variable,
+ which can be a `Tensor` of any type and shape. The initial value defines the
+ type and shape of the variable. After construction, the type and shape of
+ the variable are fixed. The value can be changed using one of the assign
+ methods.
+
+ If you want to change the shape of a variable later you have to use an
+ `assign` Op with `validate_shape=False`.
+
+ Just like any `Tensor`, variables created with `Variable()` can be used as
+ inputs for other Ops in the graph. Additionally, all the operators
+ overloaded for the `Tensor` class are carried over to variables, so you can
+ also add nodes to the graph by just doing arithmetic on variables.
+
+ ```python
+ import tensorflow as tf
+
+ # Create a variable.
+ w = tf.Variable(<initial-value>, name=<optional-name>)
+
+ # Use the variable in the graph like any Tensor.
+ y = tf.matmul(w, ...another variable or tensor...)
+
+ # The overloaded operators are available too.
+ z = tf.sigmoid(w + b)
+
+ # Assign a new value to the variable with `assign()` or a related method.
+ w.assign(w + 1.0)
+ w.assign_add(1.0)
+ ```
+
+ When you launch the graph, variables have to be explicitly initialized before
+ you can run Ops that use their value. You can initialize a variable by
+ running its *initializer op*, restoring the variable from a save file, or
+ simply running an `assign` Op that assigns a value to the variable. In fact,
+ the variable *initializer op* is just an `assign` Op that assigns the
+ variable's initial value to the variable itself.
+
+ ```python
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the variable initializer.
+ sess.run(w.initializer)
+ # ...you now can run ops that use the value of 'w'...
+ ```
+
+ The most common initialization pattern is to use the convenience function
+ `initialize_all_variables()` to add an Op to the graph that initializes
+ all the variables. You then run that Op after launching the graph.
+
+ ```python
+ # Add an Op to initialize all variables.
+ init_op = tf.initialize_all_variables()
+
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the Op that initializes all variables.
+ sess.run(init_op)
+ # ...you can now run any Op that uses variable values...
+ ```
+
+ If you need to create a variable with an initial value dependent on another
+ variable, use the other variable's `initialized_value()`. This ensures that
+ variables are initialized in the right order.
+
+ All variables are automatically collected in the graph where they are
+ created. By default, the constructor adds the new variable to the graph
+ collection `GraphKeys.VARIABLES`. The convenience function
+ `all_variables()` returns the contents of that collection.
+
+ When building a machine learning model it is often convenient to distinguish
+  between variables holding the trainable model parameters and other variables
+ such as a `global step` variable used to count training steps. To make this
+ easier, the variable constructor supports a `trainable=<bool>` parameter. If
+ `True`, the new variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
+ `trainable_variables()` returns the contents of this collection. The
+ various `Optimizer` classes use this collection as the default list of
+ variables to optimize.
+
+
+ Creating a variable.
+
+ @@__init__
+ @@initialized_value
+
+ Changing a variable value.
+
+ @@assign
+ @@assign_add
+ @@assign_sub
+ @@scatter_sub
+ @@count_up_to
+
+ @@eval
+
+ Properties.
+
+ @@name
+ @@dtype
+ @@get_shape
+ @@device
+ @@initializer
+ @@graph
+ @@op
+ """
+
+ def __init__(self, initial_value, trainable=True, collections=None,
+ validate_shape=True, name=None):
+ """Creates a new variable with value `initial_value`.
+
+ The new variable is added to the graph collections listed in `collections`,
+ which defaults to `[GraphKeys.VARIABLES]`.
+
+ If `trainable` is `True` the variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`.
+
+ This constructor creates both a `variable` Op and an `assign` Op to set the
+ variable to its initial value.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
+ The initial value for the Variable. Must have a shape specified unless
+ `validate_shape` is set to False.
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+
+ Returns:
+ A Variable.
+
+ Raises:
+ ValueError: If the initial value does not have a shape and
+ `validate_shape` is `True`.
+ """
+ if collections is None:
+ collections = [ops.GraphKeys.VARIABLES]
+ if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
+ # pylint: disable=g-no-augmented-assignment
+ #
+ # Pylint wants us to write collections += [...TRAINABLE_VARIABLES] which
+ # is not the same (it modifies the list in place.) Here, we only want to
+ # modify the value of the variable, not the list.
+ collections = collections + [ops.GraphKeys.TRAINABLE_VARIABLES]
+ # pylint: enable=g-no-augmented-assignment
+ with ops.op_scope([initial_value], name, "Variable") as name:
+ self._initial_value = ops.convert_to_tensor(initial_value,
+ name="initial_value")
+ if not self._initial_value.get_shape().is_fully_defined():
+ if validate_shape:
+ raise ValueError(
+ "initial_value must have a shape specified: %s"
+ % self._initial_value)
+ self._variable = state_ops.variable_op(
+ [], self._initial_value.dtype.base_dtype, set_shape=False,
+ name=name)
+ with ops.device(self._variable.device):
+ self._initializer_op = state_ops.assign(
+ self._variable, self._initial_value, validate_shape=False).op
+ else:
+ self._variable = state_ops.variable_op(
+ self._initial_value.get_shape(),
+ self._initial_value.dtype.base_dtype,
+ name=name)
+ with ops.device(self._variable.device):
+ self._initializer_op = state_ops.assign(
+ self._variable, self._initial_value).op
+ for key in collections:
+ ops.add_to_collection(key, self)
+ self._save_slice_info = None
+
+ def _as_graph_element(self):
+ """Conversion function for Graph.as_graph_element()."""
+ return self._variable
+
+ def _AsTensor(self):
+ """Conversion function for ops.convert_to_tensor()."""
+ return self._variable
+
+ def eval(self, session=None):
+ """In a session, computes and returns the value of this variable.
+
+    This is not a graph construction method; it does not add ops to the graph.
+
+ This convenience method requires a session where the graph containing this
+ variable has been launched. If no session is passed, the default session is
+ used. See the [Session class](../client.md#Session) for more information on
+ launching a graph and on sessions.
+
+ ```python
+ v = tf.Variable([1, 2])
+ init = tf.initialize_all_variables()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ # Usage passing the session explicitly.
+ print v.eval(sess)
+ # Usage with the default session. The 'with' block
+ # above makes 'sess' the default session.
+ print v.eval()
+ ```
+
+ Args:
+ session: The session to use to evaluate this variable. If
+ none, the default session is used.
+
+ Returns:
+ A numpy `ndarray` with a copy of the value of this variable.
+ """
+ return self._variable.eval(session=session)
+
+ def initialized_value(self):
+ """Returns the value of the initialized variable.
+
+ You should use this instead of the variable itself to initialize another
+ variable with a value that depends on the value of this variable.
+
+ ```python
+ # Initialize 'v' with a random tensor.
+ v = tf.Variable(tf.truncated_normal([10, 40]))
+ # Use `initialized_value` to guarantee that `v` has been
+ # initialized before its value is used to initialize `w`.
+ # The random values are picked only once.
+ w = tf.Variable(v.initialized_value() * 2.0)
+ ```
+
+ Returns:
+ A `Tensor` holding the value of this variable after its initializer
+ has run.
+ """
+ return control_flow_ops.with_dependencies(
+ [self._initializer_op], self._variable)
+
+ def assign(self, value, use_locking=False):
+ """Assigns a new value to the variable.
+
+ This is essentially a shortcut for `assign(self, value)`.
+
+ Args:
+ value: A `Tensor`. The new value for this variable.
+ use_locking: If `True`, use locking during the assignment.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the assignment has completed.
+ """
+ return state_ops.assign(self._variable, value, use_locking=use_locking)
+
+ def assign_add(self, delta, use_locking=False):
+ """Adds a value to this variable.
+
+ This is essentially a shortcut for `assign_add(self, delta)`.
+
+ Args:
+ delta: A `Tensor`. The value to add to this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the addition has completed.
+ """
+ return state_ops.assign_add(self._variable, delta, use_locking=use_locking)
+
+ def assign_sub(self, delta, use_locking=False):
+ """Subtracts a value from this variable.
+
+ This is essentially a shortcut for `assign_sub(self, delta)`.
+
+ Args:
+ delta: A `Tensor`. The value to subtract from this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the subtraction has completed.
+ """
+ return state_ops.assign_sub(self._variable, delta, use_locking=use_locking)
+
+ def scatter_sub(self, sparse_delta, use_locking=False):
+ """Subtracts `IndexedSlices` from this variable.
+
+ This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
+ sparse_delta.values)`.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be subtracted from this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
+ return state_ops.scatter_sub(self._variable,
+ sparse_delta.indices,
+ sparse_delta.values,
+ use_locking=use_locking)
+
+ def count_up_to(self, limit):
+ """Increments this variable until it reaches `limit`.
+
+ When that Op is run it tries to increment the variable by `1`. If
+ incrementing the variable would bring it above `limit` then the Op raises
+ the exception `OutOfRangeError`.
+
+ If no error is raised, the Op outputs the value of the variable before
+ the increment.
+
+ This is essentially a shortcut for `count_up_to(self, limit)`.
+
+ Args:
+ limit: value at which incrementing the variable raises an error.
+
+ Returns:
+ A `Tensor` that will hold the variable value before the increment. If no
+ other Op modifies this variable, the values produced will all be
+ distinct.
+ """
+ return state_ops.count_up_to(self._variable, limit=limit)
+
+ # Conversion to tensor.
+ @staticmethod
+ def _TensorConversionFunction(v, dtype=None, name=None):
+ """Utility function for converting a Variable to a Tensor."""
+ _ = name
+ ret = v._AsTensor() # pylint: disable=protected-access
+ if dtype and not dtype.is_compatible_with(v.dtype):
+ raise ValueError(
+ "Incompatible type conversion requested to type '%s' for variable "
+ "of type '%s'" % (dtype.name, v.dtype.name))
+ return ret
+
+ # Operator overloading.
+ #
+ # To carry over all overloaded operators from ops.Tensor to Variable, we
+ # register the _RunOp() static method as the implementation of all operators.
+ # That function dynamically discovers the overloaded operator in ops.Tensor
+ # and invokes it after converting the Variable to a tensor.
+ @staticmethod
+ def _OverloadAllOperators():
+ """Register overloads for all operators."""
+ for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
+ Variable._OverloadOperator(operator)
+
+ @staticmethod
+ def _OverloadOperator(operator):
+ """Register _RunOp as the implementation of 'operator'.
+
+ Args:
+ operator: string. The operator name.
+ """
+ if operator in ["__invert__", "__neg__", "__abs__"]:
+ setattr(Variable, operator, lambda a: Variable._RunOp(operator, a, None))
+ else:
+ setattr(Variable, operator, lambda a, b: Variable._RunOp(operator, a, b))
+
+ @staticmethod
+ def _RunOp(operator, a, b):
+ """Run the operator 'op' for 'a'.
+
+ Args:
+ operator: string. The operator name.
+ a: A Variable.
+ b: Second argument to the operator. None if unary.
+ Returns:
+ The result of the operator.
+ """
+ # pylint: disable=protected-access
+ if b is not None:
+ return getattr(ops.Tensor, operator)(a._AsTensor(), b)
+ else:
+ return getattr(ops.Tensor, operator)(a._AsTensor())
+ # pylint: enable=protected-access
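+
+  # Illustrative note (not part of the original change): once
+  # _OverloadAllOperators() runs at module load time, an expression such as
+  #
+  #   total = v + w   # v, w are Variables
+  #
+  # dispatches to Variable._RunOp("__add__", v, w), which converts v to its
+  # underlying Tensor and reuses ops.Tensor.__add__.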
+
+ @property
+ def name(self):
+ """The name of this variable."""
+ return self._variable.name
+
+ @property
+ def initializer(self):
+ """The initializer operation for this variable."""
+ return self._initializer_op
+
+ @property
+ def device(self):
+ """The device of this variable."""
+ return self._variable.device
+
+ @property
+ def dtype(self):
+ """The `DType` of this variable."""
+ return self._variable.dtype
+
+ @property
+ def op(self):
+ """The `Operation` of this variable."""
+ return self._variable.op
+
+ @property
+ def graph(self):
+ """The `Graph` of this variable."""
+ return self._variable.graph
+
+ def get_shape(self):
+ """The `TensorShape` of this variable.
+
+ Returns:
+ A `TensorShape`.
+ """
+ return self._variable.get_shape()
+
+ # Experimental support for saving variables as slices of a larger variable.
+ class SaveSliceInfo(object):
+ """Information on how to save this Variable as a slice."""
+
+ def __init__(self, name, spec):
+ """Create a SliceInfo.
+
+ Args:
+ name: Name of the larger Tensor that this variable is a slice of.
+ spec: Slice specification for the saver.
+ """
+ self.name = name
+ self.spec = spec
+
+ def _set_save_slice_info(self, save_slice_info):
+ """Sets the slice info for this Variable.
+
+ Args:
+      save_slice_info: A Variable.SaveSliceInfo object.
+ """
+ self._save_slice_info = save_slice_info
+
+
+def all_variables():
+ """Returns all variables collected in the graph.
+
+ The `Variable()` constructor automatically adds new variables to the graph
+ collection `GraphKeys.VARIABLES`. This convenience function returns the
+ contents of that collection.
+
+ Returns:
+ A list of `Variable` objects.
+ """
+ return ops.get_collection(ops.GraphKeys.VARIABLES)
+
+
+def trainable_variables():
+ """Returns all variables created with `trainable=True`.
+
+ When passed `trainable=True`, the `Variable()` constructor automatically
+ adds new variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
+ contents of that collection.
+
+ Returns:
+ A list of Variable objects.
+ """
+ return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+
+
+def initialize_variables(var_list, name="init"):
+ """Returns an Op that initializes a list of variables.
+
+ After you launch the graph in a session, you can run the returned Op to
+ initialize all the variables in `var_list`. This Op runs all the
+ initializers of the variables in `var_list` in parallel.
+
+ Calling `initialize_variables()` is equivalent to passing the list of
+ initializers to `Group()`.
+
+ If `var_list` is empty, however, the function still returns an Op that can
+ be run. That Op just has no effect.
+
+ Args:
+ var_list: List of `Variable` objects to initialize.
+ name: Optional name for the returned operation.
+
+ Returns:
+    An Op that runs the initializers of all the specified variables.
+ """
+ if var_list:
+ return control_flow_ops.group(
+ *[v.initializer for v in var_list], name=name)
+ return control_flow_ops.no_op(name=name)
+
+
+def initialize_all_variables():
+ """Returns an Op that initializes all variables.
+
+  This is just a shortcut for `initialize_variables(all_variables())`.
+
+ Returns:
+ An Op that initializes all variables in the graph.
+ """
+ return initialize_variables(all_variables())
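+
+
+# Illustrative note (not part of the original change): initialize_variables()
+# can initialize just a subset of variables, while initialize_all_variables()
+# covers the whole VARIABLES collection. The names below are hypothetical.
+#
+#   w = tf.Variable(tf.zeros([10]), name="w")
+#   b = tf.Variable(tf.zeros([10]), name="b")
+#   init_w = tf.initialize_variables([w], name="init_w")
+#   init_all = tf.initialize_all_variables()
+#   sess.run(init_w)    # Only w's initializer runs.
+#   sess.run(init_all)  # Both initializers run.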
+
+
+def assert_variables_initialized(var_list=None):
+ """Returns an Op to check if variables are initialized.
+
+ When run, the returned Op will raise the exception `FailedPreconditionError`
+ if any of the variables has not yet been initialized.
+
+ Note: This function is implemented by trying to fetch the values of the
+ variables. If one of the variables is not initialized a message may be
+ logged by the C++ runtime. This is expected.
+
+ Args:
+ var_list: List of `Variable` objects to check. Defaults to the
+      value of `all_variables()`.
+
+ Returns:
+ An Op, or None if there are no variables.
+ """
+ if var_list is None:
+ var_list = all_variables()
+ # Backwards compatibility for old-style variables. TODO(mdevin): remove.
+ if not var_list:
+ var_list = []
+ for op in ops.get_default_graph().get_operations():
+ if op.type in ["Variable", "AutoReloadVariable"]:
+ var_list.append(op.outputs[0])
+ if not var_list:
+ return None
+ else:
+ ranks = []
+ for var in var_list:
+ with ops.device(var.device):
+ ranks.append(array_ops.rank(var))
+ if len(ranks) == 1:
+ return ranks[0]
+ else:
+ return array_ops.pack(ranks)
+
+
+# pylint: disable=protected-access
+ops.register_tensor_conversion_function(Variable,
+ Variable._TensorConversionFunction)
+Variable._OverloadAllOperators()
+# pylint: enable=protected-access
diff --git a/tensorflow/python/platform/__init__.py b/tensorflow/python/platform/__init__.py
new file mode 100644
index 0000000000..b545bac907
--- /dev/null
+++ b/tensorflow/python/platform/__init__.py
@@ -0,0 +1,6 @@
+"""Setup system-specific platform environment for TensorFlow."""
+import control_imports
+if control_imports.USE_OSS:
+ from tensorflow.python.platform.default._init import *
+else:
+ from tensorflow.python.platform.google._init import *
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py
new file mode 100644
index 0000000000..3d51bc74b2
--- /dev/null
+++ b/tensorflow/python/platform/app.py
@@ -0,0 +1,13 @@
+"""Switch between depending on pyglib.app or an OSS replacement."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import tensorflow.python.platform
+import control_imports
+if control_imports.USE_OSS and control_imports.OSS_APP:
+ from tensorflow.python.platform.default._app import *
+else:
+ from tensorflow.python.platform.google._app import *
+
+# Import 'flags' into this module
+from tensorflow.python.platform import flags
diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i
new file mode 100644
index 0000000000..85fa3968a1
--- /dev/null
+++ b/tensorflow/python/platform/base.i
@@ -0,0 +1,176 @@
+// Helper macros and typemaps for use in Tensorflow swig files.
+//
+%{
+ #include <memory>
+ #include "tensorflow/core/platform/port.h"
+ using tensorflow::uint64;
+ using tensorflow::string;
+
+ template<class T>
+ bool _PyObjAs(PyObject *pystr, T* cstr) {
+ T::undefined; // You need to define specialization _PyObjAs<T>
+ }
+
+ template<class T>
+ PyObject *_PyObjFrom(const T& c) {
+ T::undefined; // You need to define specialization _PyObjFrom<T>
+ }
+
+#ifdef HAS_GLOBAL_STRING
+ template<>
+ bool _PyObjAs(PyObject *pystr, ::string* cstr) {
+ char *buf;
+ Py_ssize_t len;
+#if PY_VERSION_HEX >= 0x03030000
+ if (PyUnicode_Check(pystr)) {
+ buf = PyUnicode_AsUTF8AndSize(pystr, &len);
+ if (!buf) return false;
+ } else // NOLINT
+#endif
+ if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
+ if (cstr) cstr->assign(buf, len);
+ return true;
+ }
+#endif
+ template<>
+ bool _PyObjAs(PyObject *pystr, std::string* cstr) {
+ char *buf;
+ Py_ssize_t len;
+#if PY_VERSION_HEX >= 0x03030000
+ if (PyUnicode_Check(pystr)) {
+ buf = PyUnicode_AsUTF8AndSize(pystr, &len);
+ if (!buf) return false;
+ } else // NOLINT
+#endif
+ if (PyBytes_AsStringAndSize(pystr, &buf, &len) == -1) return false;
+ if (cstr) cstr->assign(buf, len);
+ return true;
+ }
+#ifdef HAS_GLOBAL_STRING
+ template<>
+ PyObject* _PyObjFrom(const ::string& c) {
+ return PyString_FromStringAndSize(c.data(), c.size());
+ }
+#endif
+ template<>
+ PyObject* _PyObjFrom(const std::string& c) {
+ return PyString_FromStringAndSize(c.data(), c.size());
+ }
+
+ PyObject* _SwigString_FromString(const string& s) {
+ return PyUnicode_FromStringAndSize(s.data(), s.size());
+ }
+%}
+
+%typemap(in) string {
+ if (!_PyObjAs<string>($input, &$1)) return NULL;
+}
+
+%typemap(in) const string& (string temp) {
+ if (!_PyObjAs<string>($input, &temp)) return NULL;
+ $1 = &temp;
+}
+
+%typemap(out) string {
+ $result = PyString_FromStringAndSize($1.data(), $1.size());
+}
+
+%typemap(out) const string& {
+ $result = PyString_FromStringAndSize($1->data(), $1->size());
+}
+
+%typemap(in, numinputs = 0) string* OUTPUT (string temp) {
+ $1 = &temp;
+}
+
+%typemap(argout) string * OUTPUT {
+ PyObject *str = PyString_FromStringAndSize($1->data(), $1->length());
+ if (!str) SWIG_fail;
+ %append_output(str);
+}
+
+%typemap(argout) string* INOUT = string* OUTPUT;
+
+%typemap(varout) string {
+ $result = PyString_FromStringAndSize($1.data(), $1.size());
+}
+
+%define _LIST_OUTPUT_TYPEMAP(type, py_converter)
+ %typemap(in) std::vector<type>(std::vector<type> temp) {
+ if (!vector_input_helper($input, &temp, _PyObjAs<type>)) {
+ if (!PyErr_Occurred())
+ PyErr_SetString(PyExc_TypeError, "sequence(type) expected");
+ return NULL;
+ }
+ $1 = temp;
+}
+%typemap(in) const std::vector<type>& (std::vector<type> temp),
+ const std::vector<type>* (std::vector<type> temp) {
+ if (!vector_input_helper($input, &temp, _PyObjAs<type>)) {
+ if (!PyErr_Occurred())
+ PyErr_SetString(PyExc_TypeError, "sequence(type) expected");
+ return NULL;
+ }
+ $1 = &temp;
+}
+%typemap(in,numinputs=0)
+std::vector<type>* OUTPUT (std::vector<type> temp),
+ hash_set<type>* OUTPUT (hash_set<type> temp),
+ set<type>* OUTPUT (set<type> temp) {
+ $1 = &temp;
+}
+%typemap(argout) std::vector<type>* OUTPUT, set<type>* OUTPUT, hash_set<type>* OUTPUT {
+ %append_output(list_output_helper($1, &py_converter));
+}
+%typemap(out) std::vector<type> {
+ $result = vector_output_helper(&$1, &py_converter);
+}
+%typemap(out) std::vector<type>*, const std::vector<type>& {
+ $result = vector_output_helper($1, &py_converter);
+}
+%enddef
+
+_LIST_OUTPUT_TYPEMAP(string, _SwigString_FromString);
+_LIST_OUTPUT_TYPEMAP(unsigned long long, PyLong_FromUnsignedLongLong);
+
+%typemap(in) uint64 {
+ // TODO(gps): Check if another implementation
+ // from hosting/images/util/image-hosting-utils.swig is better. May be not.
+%#if PY_MAJOR_VERSION < 3
+ if (PyInt_Check($input)) {
+ $1 = static_cast<uint64>(PyInt_AsLong($input));
+ } else
+%#endif
+ if (PyLong_Check($input)) {
+ $1 = static_cast<uint64>(PyLong_AsUnsignedLongLong($input));
+ } else {
+ PyErr_SetString(PyExc_TypeError,
+ "int or long value expected for argument \"$1_name\"");
+ }
+ // TODO(mrovner): Make consistent use of SWIG_fail vs. return NULL.
+ if (PyErr_Occurred()) return NULL;
+}
+
+%define _COPY_TYPEMAPS(oldtype, newtype)
+ typedef oldtype newtype;
+%apply oldtype * OUTPUT { newtype * OUTPUT };
+%apply oldtype & OUTPUT { newtype & OUTPUT };
+%apply oldtype * INPUT { newtype * INPUT };
+%apply oldtype & INPUT { newtype & INPUT };
+%apply oldtype * INOUT { newtype * INOUT };
+%apply oldtype & INOUT { newtype & INOUT };
+%apply std::vector<oldtype> * OUTPUT { std::vector<newtype> * OUTPUT };
+%enddef
+
+_COPY_TYPEMAPS(unsigned long long, uint64);
+
+// SWIG macros for explicit API declaration.
+// Usage:
+//
+// %ignoreall
+// %unignore SomeName; // namespace / class / method
+// %include "somelib.h"
+// %unignoreall // mandatory closing "bracket"
+%define %ignoreall %ignore ""; %enddef
+%define %unignore %rename("%s") %enddef
+%define %unignoreall %rename("%s") ""; %enddef
diff --git a/tensorflow/python/platform/control_imports.py b/tensorflow/python/platform/control_imports.py
new file mode 100644
index 0000000000..713caf3f4f
--- /dev/null
+++ b/tensorflow/python/platform/control_imports.py
@@ -0,0 +1,13 @@
+"""Switch between Google or open source dependencies."""
+# Switch between Google and OSS dependencies
+USE_OSS = True
+
+# Per-dependency switches determining whether each dependency is ready
+# to be replaced by its OSS equivalence.
+# TODO(danmane,mrry,opensource): Flip these switches, then remove them
+OSS_APP = True
+OSS_FLAGS = True
+OSS_GFILE = True
+OSS_GOOGLETEST = True
+OSS_LOGGING = True
+OSS_PARAMETERIZED = True
diff --git a/tensorflow/python/platform/default/__init__.py b/tensorflow/python/platform/default/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/platform/default/__init__.py
diff --git a/tensorflow/python/platform/default/_app.py b/tensorflow/python/platform/default/_app.py
new file mode 100644
index 0000000000..5917d00ce3
--- /dev/null
+++ b/tensorflow/python/platform/default/_app.py
@@ -0,0 +1,11 @@
+"""Generic entry point script."""
+import sys
+
+from tensorflow.python.platform import flags
+
+
+def run():
+ f = flags.FLAGS
+ f._parse_flags()
+ main = sys.modules['__main__'].main
+ sys.exit(main(sys.argv))
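+
+
+# Illustrative sketch (not part of the original change): the calling pattern
+# run() expects. `main` must be defined in the __main__ module and accept
+# argv; the flag name is hypothetical.
+#
+#   from tensorflow.python.platform import app, flags
+#
+#   flags.DEFINE_integer("steps", 100, "Number of steps to run.")
+#
+#   def main(argv):
+#     print flags.FLAGS.steps
+#
+#   if __name__ == "__main__":
+#     app.run()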
diff --git a/tensorflow/python/platform/default/_flags.py b/tensorflow/python/platform/default/_flags.py
new file mode 100644
index 0000000000..ceccda6e5c
--- /dev/null
+++ b/tensorflow/python/platform/default/_flags.py
@@ -0,0 +1,92 @@
+"""Implementation of the flags interface."""
+import tensorflow.python.platform
+
+import argparse
+
+_global_parser = argparse.ArgumentParser()
+
+class _FlagValues(object):
+
+ def __init__(self):
+ """Global container and accessor for flags and their values."""
+ self.__dict__['__flags'] = {}
+ self.__dict__['__parsed'] = False
+
+ def _parse_flags(self):
+ result = _global_parser.parse_args()
+ for flag_name, val in vars(result).items():
+ self.__dict__['__flags'][flag_name] = val
+ self.__dict__['__parsed'] = True
+
+ def __getattr__(self, name):
+ """Retrieves the 'value' attribute of the flag --name."""
+ if not self.__dict__['__parsed']:
+ self._parse_flags()
+ if name not in self.__dict__['__flags']:
+ raise AttributeError(name)
+ return self.__dict__['__flags'][name]
+
+ def __setattr__(self, name, value):
+ """Sets the 'value' attribute of the flag --name."""
+ if not self.__dict__['__parsed']:
+ self._parse_flags()
+ self.__dict__['__flags'][name] = value
+
+
+def _define_helper(flag_name, default_value, docstring, flagtype):
+ """Registers 'flag_name' with 'default_value' and 'docstring'."""
+ _global_parser.add_argument("--" + flag_name,
+ default=default_value,
+ help=docstring,
+ type=flagtype)
+
+
+# Provides the global object that can be used to access flags.
+FLAGS = _FlagValues()
+
+
+def DEFINE_string(flag_name, default_value, docstring):
+ """Defines a flag of type 'string'.
+
+ Args:
+ flag_name: The name of the flag as a string.
+ default_value: The default value the flag should take as a string.
+ docstring: A helpful message explaining the use of the flag.
+ """
+ _define_helper(flag_name, default_value, docstring, str)
+
+
+def DEFINE_integer(flag_name, default_value, docstring):
+ """Defines a flag of type 'int'.
+
+ Args:
+ flag_name: The name of the flag as a string.
+ default_value: The default value the flag should take as an int.
+ docstring: A helpful message explaining the use of the flag.
+ """
+ _define_helper(flag_name, default_value, docstring, int)
+
+
+def DEFINE_boolean(flag_name, default_value, docstring):
+ """Defines a flag of type 'boolean'.
+
+ Args:
+ flag_name: The name of the flag as a string.
+ default_value: The default value the flag should take as a boolean.
+ docstring: A helpful message explaining the use of the flag.
+ """
+ _define_helper(flag_name, default_value, docstring, bool)
+ _global_parser.add_argument('--no' + flag_name,
+ action='store_false',
+ dest=flag_name)
+
+
+def DEFINE_float(flag_name, default_value, docstring):
+ """Defines a flag of type 'float'.
+
+ Args:
+ flag_name: The name of the flag as a string.
+ default_value: The default value the flag should take as a float.
+ docstring: A helpful message explaining the use of the flag.
+ """
+ _define_helper(flag_name, default_value, docstring, float)
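+
+
+# Illustrative sketch (not part of the original change): flags defined with
+# the helpers above are parsed lazily on first attribute access. The flag
+# names are hypothetical.
+#
+#   DEFINE_string("data_dir", "/tmp/data", "Where the input data lives.")
+#   DEFINE_boolean("debug", False, "Enable verbose output.")
+#
+#   print FLAGS.data_dir  # Triggers argparse parsing on first use.
+#   print FLAGS.debug     # Can be turned off on the command line via --nodebug.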
diff --git a/tensorflow/python/platform/default/_gfile.py b/tensorflow/python/platform/default/_gfile.py
new file mode 100644
index 0000000000..cfd25bdf90
--- /dev/null
+++ b/tensorflow/python/platform/default/_gfile.py
@@ -0,0 +1,404 @@
+"""File processing utilities."""
+
+import errno
+import functools
+import glob as _glob
+import os
+import shutil
+import threading
+
+
+class FileError(IOError):
+ """An error occurred while reading or writing a file."""
+
+
+class GOSError(OSError):
+ """An error occurred while finding a file or in handling pathnames."""
+
+
+class _GFileBase(object):
+ """Base I/O wrapper class. Similar semantics to Python's file object."""
+
+ # pylint: disable=protected-access
+ def _error_wrapper(fn):
+ """Decorator wrapping GFileBase class method errors."""
+ @functools.wraps(fn) # Preserve methods' __doc__
+ def wrap(self, *args, **kwargs):
+ try:
+ return fn(self, *args, **kwargs)
+ except ValueError, e:
+ # Sometimes a ValueError is raised, e.g., a read() on a closed file.
+ raise FileError(errno.EIO, e.message, self._name)
+ except IOError, e:
+ e.filename = self._name
+ raise FileError(e)
+ except OSError, e:
+ raise GOSError(e)
+ return wrap
+
+ def _synchronized(fn):
+ """Synchronizes file I/O for methods in GFileBase."""
+ @functools.wraps(fn)
+ def sync(self, *args, **kwargs):
+ # Sometimes a GFileBase method is called before the instance
+ # has been properly initialized. Check that _locker is available.
+ if hasattr(self, '_locker'): self._locker.lock()
+ try:
+ return fn(self, *args, **kwargs)
+ finally:
+ if hasattr(self, '_locker'): self._locker.unlock()
+ return sync
+ # pylint: enable=protected-access
+
+ @_error_wrapper
+ def __init__(self, name, mode, locker):
+ """Create the GFileBase object with the given filename, mode, and locker.
+
+ Args:
+ name: string, the filename.
+ mode: string, the mode to open the file with (e.g. "r", "w", "a+").
+ locker: the thread locking object (e.g. _PythonLocker) for controlling
+ thread access to the I/O methods of this class.
+ """
+ self._name = name
+ self._mode = mode
+ self._locker = locker
+ self._fp = open(name, mode)
+
+ def __enter__(self):
+ """Make GFileBase usable with "with" statement."""
+ return self
+
+ def __exit__(self, unused_type, unused_value, unused_traceback):
+ """Make GFileBase usable with "with" statement."""
+ self.close()
+
+ @_error_wrapper
+ @_synchronized
+ def __del__(self):
+ # __del__ is sometimes called before initialization, in which
+ # case the object is not fully constructed. Check for this here
+ # before trying to close the file handle.
+ if hasattr(self, '_fp'): self._fp.close()
+
+ @_error_wrapper
+ @_synchronized
+ def flush(self):
+ """Flush the underlying file handle."""
+ return self._fp.flush()
+
+ @property
+ @_error_wrapper
+ @_synchronized
+ def closed(self):
+ """Returns "True" if the file handle is closed. Otherwise False."""
+ return self._fp.closed
+
+ @_error_wrapper
+ @_synchronized
+ def write(self, data):
+ """Write data to the underlying file handle.
+
+ Args:
+ data: The string to write to the file handle.
+ """
+ self._fp.write(data)
+
+ @_error_wrapper
+ @_synchronized
+ def writelines(self, seq):
+ """Write a sequence of strings to the underlying file handle."""
+ self._fp.writelines(seq)
+
+ @_error_wrapper
+ @_synchronized
+ def tell(self):
+ """Return the location from the underlying file handle.
+
+ Returns:
+ An integer location (which can be used in e.g., seek).
+ """
+ return self._fp.tell()
+
+ @_error_wrapper
+ @_synchronized
+ def seek(self, offset, whence=0):
+ """Seek to offset (conditioned on whence) in the underlying file handle.
+
+ Args:
+ offset: int, the offset within the file to seek to.
+ whence: 0, 1, or 2. See python's seek() documentation for details.
+ """
+ self._fp.seek(offset, whence)
+
+ @_error_wrapper
+ @_synchronized
+  def truncate(self, new_size=None):
+    """Truncate the underlying file handle to new_size.
+
+    Args:
+      new_size: Size after truncation. If None, the file handle is truncated
+        at the current position (the default behavior of file.truncate()).
+    """
+    # file.truncate() does not accept None, so only pass new_size through
+    # when it was given.
+    if new_size is None:
+      self._fp.truncate()
+    else:
+      self._fp.truncate(new_size)
+
+ @_error_wrapper
+ @_synchronized
+ def readline(self, max_length=-1):
+ """Read a single line (up to max_length) from the underlying file handle.
+
+ Args:
+      max_length: The maximum number of characters to read.
+
+ Returns:
+ A string, including any newline at the end, or empty string if at EOF.
+ """
+ return self._fp.readline(max_length)
+
+ @_error_wrapper
+ @_synchronized
+ def readlines(self, sizehint=None):
+ """Read lines from the underlying file handle.
+
+ Args:
+ sizehint: See the python file.readlines() documentation.
+
+ Returns:
+ A list of strings from the underlying file handle.
+ """
+ if sizehint is not None:
+ return self._fp.readlines(sizehint)
+ else:
+ return self._fp.readlines()
+
+ def __iter__(self):
+ """Enable line iteration on the underlying handle (not synchronized)."""
+ return self
+
+ # Not synchronized
+ @_error_wrapper
+ def next(self):
+ """Enable line iteration on the underlying handle (not synchronized).
+
+    Returns:
+      The next line from the underlying file handle.
+
+ Example:
+ # read a file's lines by consuming the iterator with a list
+ with open("filename", "r") as fp: lines = list(fp)
+ """
+ return self._fp.next()
+
+ @_error_wrapper
+ @_synchronized
+ def Size(self): # pylint: disable=invalid-name
+ """Get byte size of the file from the underlying file handle."""
+ cur = self.tell()
+ try:
+ self.seek(0, 2)
+ size = self.tell()
+ finally:
+ self.seek(cur)
+ return size
+
+ @_error_wrapper
+ @_synchronized
+ def read(self, n=-1):
+ """Read n bytes from the underlying file handle.
+
+ Args:
+ n: Number of bytes to read (if negative, read to end of file handle.)
+
+ Returns:
+ A string of the bytes read, up to the end of file.
+ """
+ return self._fp.read(n)
+
+ @_error_wrapper
+ @_synchronized
+ def close(self):
+ """Close the underlying file handle."""
+ self._fp.close()
+
+ # Declare wrappers as staticmethods at the end so that we can
+ # use them as decorators.
+ _error_wrapper = staticmethod(_error_wrapper)
+ _synchronized = staticmethod(_synchronized)
+
+
+class GFile(_GFileBase):
+ """File I/O wrappers with thread locking."""
+
+ def __init__(self, name, mode='r'):
+ super(GFile, self).__init__(name, mode, _Pythonlocker())
+
+
+class FastGFile(_GFileBase):
+ """File I/O wrappers without thread locking."""
+
+ def __init__(self, name, mode='r'):
+ super(FastGFile, self).__init__(name, mode, _Nulllocker())
+
+
+# locker classes. Note that locks must be reentrant, so that multiple
+# lock() calls by the owning thread will not block.
+class _Pythonlocker(object):
+ """A locking strategy that uses standard locks from the thread module."""
+
+ def __init__(self):
+ self._lock = threading.RLock()
+
+ def lock(self):
+ self._lock.acquire()
+
+ def unlock(self):
+ self._lock.release()
+
+
+class _Nulllocker(object):
+ """A locking strategy where lock() and unlock() methods are no-ops."""
+
+ def lock(self):
+ pass
+
+ def unlock(self):
+ pass
+
+
+def _func_error_wrapper(fn):
+ """Decorator wrapping function errors."""
+ @functools.wraps(fn) # Preserve methods' __doc__
+ def wrap(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except ValueError, e:
+ raise FileError(errno.EIO, e.message)
+ except IOError, e:
+ raise FileError(e)
+ except OSError, e:
+ raise GOSError(e)
+ return wrap
+
+
+@_func_error_wrapper
+def Exists(path): # pylint: disable=invalid-name
+ """Retruns True iff "path" exists (as a dir, file, non-broken symlink)."""
+ return os.path.exists(path)
+
+
+@_func_error_wrapper
+def IsDirectory(path): # pylint: disable=invalid-name
+ """Return True iff "path" exists and is a directory."""
+ return os.path.isdir(path)
+
+
+@_func_error_wrapper
+def Glob(glob): # pylint: disable=invalid-name
+ """Return a list of filenames matching the glob "glob"."""
+ return _glob.glob(glob)
+
+
+@_func_error_wrapper
+def MkDir(path, mode=0755): # pylint: disable=invalid-name
+ """Create the directory "path" with the given mode.
+
+ Args:
+ path: The directory path
+ mode: The file mode for the directory
+
+ Returns:
+ None
+
+ Raises:
+ GOSError: if the path already exists
+ """
+ os.mkdir(path, mode)
+
+
+@_func_error_wrapper
+def MakeDirs(path, mode=0755): # pylint: disable=invalid-name
+ """Recursively create the directory "path" with the given mode.
+
+ Args:
+ path: The directory path
+ mode: The file mode for the created directories
+
+ Returns:
+ None
+
+ Raises:
+ GOSError: if the path already exists
+ """
+ os.makedirs(path, mode)
+
+
+@_func_error_wrapper
+def RmDir(directory): # pylint: disable=invalid-name
+ """Removes the directory "directory" iff the directory is empty.
+
+ Args:
+ directory: The directory to remove.
+
+ Raises:
+ GOSError: If the directory does not exist or is not empty.
+ """
+ os.rmdir(directory)
+
+
+@_func_error_wrapper
+def Remove(path): # pylint: disable=invalid-name
+ """Delete the (non-directory) file "path".
+
+ Args:
+ path: The file to remove.
+
+ Raises:
+ GOSError: If "path" does not exist, is a directory, or cannot be deleted.
+ """
+ os.remove(path)
+
+
+@_func_error_wrapper
+def DeleteRecursively(path): # pylint: disable=invalid-name
+ """Delete the file or directory "path" recursively.
+
+ Args:
+ path: The path to remove (may be a non-empty directory).
+
+ Raises:
+ GOSError: If the path does not exist or cannot be deleted.
+ """
+ if IsDirectory(path):
+ shutil.rmtree(path)
+ else:
+ Remove(path)
+
+
+@_func_error_wrapper
+def ListDirectory(directory, return_dotfiles=False): # pylint: disable=invalid-name
+ """Returns a list of files in dir.
+
+ As with the standard os.listdir(), the filenames in the returned list will be
+ the basenames of the files in dir (not absolute paths). To get a list of
+ absolute paths of files in a directory, a client could do:
+ file_list = gfile.ListDir(my_dir)
+ file_list = [os.path.join(my_dir, f) for f in file_list]
+ (assuming that my_dir itself specified an absolute path to a directory).
+
+ Args:
+ directory: the directory to list
+ return_dotfiles: if True, dotfiles will be returned as well. Even if
+ this arg is True, '.' and '..' will not be returned.
+
+ Returns:
+ ['list', 'of', 'files']. The entries '.' and '..' are never returned.
+ Other entries starting with a dot will only be returned if return_dotfiles
+ is True.
+ Raises:
+ GOSError: if there is an error retrieving the directory listing.
+ """
+ files = os.listdir(directory)
+ if not return_dotfiles:
+ files = [f for f in files if not f.startswith('.')]
+ return files
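+
+
+# Illustrative combined use of the helpers above (paths are hypothetical):
+#   if not Exists('/tmp/my_dir'):
+#     MakeDirs('/tmp/my_dir')
+#   for name in ListDirectory('/tmp/my_dir'):
+#     Remove(os.path.join('/tmp/my_dir', name))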
diff --git a/tensorflow/python/platform/default/_googletest.py b/tensorflow/python/platform/default/_googletest.py
new file mode 100644
index 0000000000..d2686565a0
--- /dev/null
+++ b/tensorflow/python/platform/default/_googletest.py
@@ -0,0 +1,68 @@
+"""Imports unittest as a replacement for testing.pybase.googletest."""
+import inspect
+import itertools
+import os
+import sys
+import tempfile
+
+# pylint: disable=wildcard-import
+from unittest import *
+
+
+unittest_main = main
+
+
+# pylint: disable=invalid-name
+# pylint: disable=undefined-variable
+def main(*args, **kwargs):
+ """Delegate to unittest.main after redefining testLoader."""
+ if 'TEST_SHARD_STATUS_FILE' in os.environ:
+ try:
+ f = None
+ try:
+ f = open(os.environ['TEST_SHARD_STATUS_FILE'], 'w')
+ f.write('')
+ except IOError:
+ sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.'
+ % os.environ['TEST_SHARD_STATUS_FILE'])
+ sys.exit(1)
+ finally:
+ if f is not None: f.close()
+
+ if ('TEST_TOTAL_SHARDS' not in os.environ or
+ 'TEST_SHARD_INDEX' not in os.environ):
+ return unittest_main(*args, **kwargs)
+
+ total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
+ shard_index = int(os.environ['TEST_SHARD_INDEX'])
+ base_loader = TestLoader()
+
+ delegate_get_names = base_loader.getTestCaseNames
+ bucket_iterator = itertools.cycle(range(total_shards))
+
+ def getShardedTestCaseNames(testCaseClass):
+ filtered_names = []
+ for testcase in sorted(delegate_get_names(testCaseClass)):
+ bucket = bucket_iterator.next()
+ if bucket == shard_index:
+ filtered_names.append(testcase)
+ return filtered_names
+
+ # Override getTestCaseNames
+ base_loader.getTestCaseNames = getShardedTestCaseNames
+
+ kwargs['testLoader'] = base_loader
+ unittest_main(*args, **kwargs)
+
+
+def GetTempDir():
+ first_frame = inspect.stack()[-1][0]
+ temp_dir = os.path.join(
+ tempfile.gettempdir(), os.path.basename(inspect.getfile(first_frame)))
+  if temp_dir.endswith('.py'):
+    temp_dir = temp_dir[:-len('.py')]
+ if not os.path.isdir(temp_dir):
+ os.mkdir(temp_dir, 0755)
+ return temp_dir
+
+
+def StatefulSessionAvailable():
+ return False
diff --git a/tensorflow/python/platform/default/_init.py b/tensorflow/python/platform/default/_init.py
new file mode 100644
index 0000000000..916d598856
--- /dev/null
+++ b/tensorflow/python/platform/default/_init.py
@@ -0,0 +1 @@
+# Nothing to do for default platform
diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py
new file mode 100644
index 0000000000..2e289b1abe
--- /dev/null
+++ b/tensorflow/python/platform/default/_logging.py
@@ -0,0 +1,182 @@
+"""Logging utilities."""
+# pylint: disable=unused-import
+# pylint: disable=g-bad-import-order
+# pylint: disable=invalid-name
+import os
+import sys
+import time
+import thread
+from logging import getLogger
+from logging import log
+from logging import debug
+from logging import error
+from logging import fatal
+from logging import info
+from logging import warn
+from logging import warning
+from logging import DEBUG
+from logging import ERROR
+from logging import FATAL
+from logging import INFO
+from logging import WARN
+
+# Controls which methods from pyglib.logging are available within the project
+# Do not add methods here without also adding to platform/google/_logging.py
+__all__ = ['log', 'debug', 'error', 'fatal', 'info', 'warn', 'warning',
+ 'DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN',
+ 'flush', 'log_every_n', 'log_first_n', 'vlog',
+ 'TaskLevelStatusMessage', 'get_verbosity', 'set_verbosity']
+
+warning = warn
+
+_level_names = {
+ FATAL: 'FATAL',
+ ERROR: 'ERROR',
+ WARN: 'WARN',
+ INFO: 'INFO',
+ DEBUG: 'DEBUG',
+}
+
+# Mask to convert integer thread ids to unsigned quantities for logging
+# purposes
+_THREAD_ID_MASK = 2 * sys.maxint + 1
+
+_log_prefix = None # later set to google2_log_prefix
+
+# Counter to keep track of number of log entries per token.
+_log_counter_per_token = {}
+
+
+def TaskLevelStatusMessage(msg):
+ error(msg)
+
+
+def flush():
+ raise NotImplementedError()
+
+
+# Code below is taken from pyglib/logging
+def vlog(level, msg, *args, **kwargs):
+ log(level, msg, *args, **kwargs)
+
+
+def _GetNextLogCountPerToken(token):
+ """Wrapper for _log_counter_per_token.
+
+ Args:
+ token: The token for which to look up the count.
+
+ Returns:
+ The number of times this function has been called with
+ *token* as an argument (starting at 0)
+ """
+ global _log_counter_per_token # pylint: disable=global-variable-not-assigned
+ _log_counter_per_token[token] = 1 + _log_counter_per_token.get(token, -1)
+ return _log_counter_per_token[token]
+
+
+def log_every_n(level, msg, n, *args):
+ """Log 'msg % args' at level 'level' once per 'n' times.
+
+ Logs the 1st call, (N+1)st call, (2N+1)st call, etc.
+ Not threadsafe.
+
+ Args:
+ level: The level at which to log.
+ msg: The message to be logged.
+    n: Log once every n calls; i.e. on the 1st, (n+1)st, (2n+1)st, ... call.
+ *args: The args to be substituted into the msg.
+ """
+ count = _GetNextLogCountPerToken(_GetFileAndLine())
+ log_if(level, msg, not (count % n), *args)
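+# Illustrative use of log_every_n (levels imported above):
+#   for i in range(100):
+#     log_every_n(INFO, 'Processed %d items', 10, i)  # logs i = 0, 10, ...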
+
+
+def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name
+ """Log 'msg % args' at level 'level' only first 'n' times.
+
+ Not threadsafe.
+
+ Args:
+ level: The level at which to log.
+ msg: The message to be logged.
+    n: The maximal number of times the message is logged.
+ *args: The args to be substituted into the msg.
+ """
+ count = _GetNextLogCountPerToken(_GetFileAndLine())
+ log_if(level, msg, count < n, *args)
+
+
+def log_if(level, msg, condition, *args):
+ """Log 'msg % args' at level 'level' only if condition is fulfilled."""
+ if condition:
+ vlog(level, msg, *args)
+
+
+def _GetFileAndLine():
+ """Returns (filename, linenumber) for the stack frame."""
+ # Use sys._getframe(). This avoids creating a traceback object.
+ # pylint: disable=protected-access
+ f = sys._getframe()
+ # pylint: enable=protected-access
+ our_file = f.f_code.co_filename
+ f = f.f_back
+ while f:
+ code = f.f_code
+ if code.co_filename != our_file:
+ return (code.co_filename, f.f_lineno)
+ f = f.f_back
+ return ('<unknown>', 0)
+
+
+def google2_log_prefix(level, timestamp=None, file_and_line=None):
+ """Assemble a logline prefix using the google2 format."""
+ # pylint: disable=global-variable-not-assigned
+ global _level_names
+ # pylint: enable=global-variable-not-assigned
+
+ # Record current time
+ now = timestamp or time.time()
+ now_tuple = time.localtime(now)
+ now_microsecond = int(1e6 * (now % 1.0))
+
+ (filename, line) = file_and_line or _GetFileAndLine()
+ basename = os.path.basename(filename)
+
+ # Severity string
+ severity = 'I'
+ if level in _level_names:
+ severity = _level_names[level][0]
+
+ s = '%c%02d%02d %02d:%02d:%02d.%06d %5d %s:%d] ' % (
+ severity,
+ now_tuple[1], # month
+ now_tuple[2], # day
+ now_tuple[3], # hour
+ now_tuple[4], # min
+ now_tuple[5], # sec
+ now_microsecond,
+ _get_thread_id(),
+ basename,
+ line)
+
+ return s
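+
+# For reference, the prefix assembled above looks like (values illustrative):
+#   'I1106 16:27:58.000123  4567 some_module.py:42] '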
+
+
+def get_verbosity():
+ """Return how much logging output will be produced."""
+ return getLogger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity):
+ """Sets the threshold for what messages will be logged."""
+ getLogger().setLevel(verbosity)
+
+
+def _get_thread_id():
+ """Get id of current thread, suitable for logging as an unsigned quantity."""
+ thread_id = thread.get_ident()
+ return thread_id & _THREAD_ID_MASK
+
+
+_log_prefix = google2_log_prefix
diff --git a/tensorflow/python/platform/default/_parameterized.py b/tensorflow/python/platform/default/_parameterized.py
new file mode 100644
index 0000000000..5d141568ed
--- /dev/null
+++ b/tensorflow/python/platform/default/_parameterized.py
@@ -0,0 +1,2 @@
+"""Extension to unittest to run parameterized tests."""
+raise ImportError("Not implemented yet.")
diff --git a/tensorflow/python/platform/default/_resource_loader.py b/tensorflow/python/platform/default/_resource_loader.py
new file mode 100644
index 0000000000..69f425072f
--- /dev/null
+++ b/tensorflow/python/platform/default/_resource_loader.py
@@ -0,0 +1,26 @@
+"""Read a file and return its contents."""
+
+import os.path
+
+from tensorflow.python.platform import logging
+
+
+def load_resource(path):
+ """Load the resource at given path, where path is relative to tensorflow/.
+
+ Args:
+ path: a string resource path relative to tensorflow/.
+
+ Returns:
+ The contents of that resource.
+
+ Raises:
+ IOError: If the path is not found, or the resource can't be opened.
+ """
+ path = os.path.join('tensorflow', path)
+ path = os.path.abspath(path)
+ try:
+ with open(path, 'rb') as f:
+ return f.read()
+  except IOError as e:
+    logging.warning('IOError %s on path %s' % (e, path))
+    raise
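+
+
+# Illustrative call (the resource path below is hypothetical):
+#   data = load_resource('python/framework/some_data.txt')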
diff --git a/tensorflow/python/platform/default/_status_bar.py b/tensorflow/python/platform/default/_status_bar.py
new file mode 100644
index 0000000000..2953908724
--- /dev/null
+++ b/tensorflow/python/platform/default/_status_bar.py
@@ -0,0 +1,5 @@
+"""A no-op implementation of status bar functions."""
+
+
+def SetupStatusBarInsideGoogle(unused_link_text, unused_port):
+ pass
diff --git a/tensorflow/python/platform/default/flags_test.py b/tensorflow/python/platform/default/flags_test.py
new file mode 100644
index 0000000000..1b15ca138a
--- /dev/null
+++ b/tensorflow/python/platform/default/flags_test.py
@@ -0,0 +1,53 @@
+"""Tests for our flags implementation."""
+import sys
+
+from tensorflow.python.platform.default import _googletest as googletest
+
+from tensorflow.python.platform.default import _flags as flags
+
+
+flags.DEFINE_string("string_foo", "default_val", "HelpString")
+flags.DEFINE_boolean("bool_foo", True, "HelpString")
+flags.DEFINE_integer("int_foo", 42, "HelpString")
+flags.DEFINE_float("float_foo", 42.0, "HelpString")
+
+FLAGS = flags.FLAGS
+
+class FlagsTest(googletest.TestCase):
+
+ def testString(self):
+ res = FLAGS.string_foo
+ self.assertEqual(res, "default_val")
+ FLAGS.string_foo = "bar"
+ self.assertEqual("bar", FLAGS.string_foo)
+
+ def testBool(self):
+ res = FLAGS.bool_foo
+ self.assertTrue(res)
+ FLAGS.bool_foo = False
+ self.assertFalse(FLAGS.bool_foo)
+
+ def testNoBool(self):
+ FLAGS.bool_foo = True
+ try:
+ sys.argv.append("--nobool_foo")
+ FLAGS._parse_flags()
+ self.assertFalse(FLAGS.bool_foo)
+ finally:
+ sys.argv.pop()
+
+ def testInt(self):
+ res = FLAGS.int_foo
+ self.assertEquals(res, 42)
+ FLAGS.int_foo = -1
+ self.assertEqual(-1, FLAGS.int_foo)
+
+ def testFloat(self):
+ res = FLAGS.float_foo
+ self.assertEquals(42.0, res)
+ FLAGS.float_foo = -1.0
+ self.assertEqual(-1.0, FLAGS.float_foo)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/platform/default/gfile_test.py b/tensorflow/python/platform/default/gfile_test.py
new file mode 100644
index 0000000000..9eec952e95
--- /dev/null
+++ b/tensorflow/python/platform/default/gfile_test.py
@@ -0,0 +1,147 @@
+import os
+import shutil
+
+from tensorflow.python.platform.default import _gfile as gfile
+from tensorflow.python.platform.default import _googletest as googletest
+from tensorflow.python.platform.default import _logging as logging
+
+
+class _BaseTest(object):
+
+ @property
+ def tmp(self):
+ return self._tmp_dir
+
+ def setUp(self):
+ self._orig_dir = os.getcwd()
+ self._tmp_dir = googletest.GetTempDir() + "/"
+ try:
+ os.makedirs(self._tmp_dir)
+ except OSError:
+ pass # Directory already exists
+
+ def tearDown(self):
+ try:
+ shutil.rmtree(self._tmp_dir)
+ except OSError:
+ logging.warn("[%s] Post-test directory cleanup failed: %s"
+ % (self, self._tmp_dir))
+
+
+class _GFileBaseTest(_BaseTest):
+
+ @property
+ def gfile(self):
+ raise NotImplementedError("Do not use _GFileBaseTest directly.")
+
+ def testWith(self):
+ with self.gfile(self.tmp + "test_with", "w") as fh:
+ fh.write("hi")
+ with self.gfile(self.tmp + "test_with", "r") as fh:
+ self.assertEquals(fh.read(), "hi")
+
+ def testSizeAndTellAndSeek(self):
+ with self.gfile(self.tmp + "test_tell", "w") as fh:
+ fh.write("".join(["0"] * 1000))
+ with self.gfile(self.tmp + "test_tell", "r") as fh:
+ self.assertEqual(1000, fh.Size())
+ self.assertEqual(0, fh.tell())
+ fh.seek(0, 2)
+ self.assertEqual(1000, fh.tell())
+ fh.seek(0)
+ self.assertEqual(0, fh.tell())
+
+ def testReadAndWritelines(self):
+ with self.gfile(self.tmp + "test_writelines", "w") as fh:
+ fh.writelines(["%d\n" % d for d in range(10)])
+ with self.gfile(self.tmp + "test_writelines", "r") as fh:
+ self.assertEqual(["%d\n" % x for x in range(10)], fh.readlines())
+
+ def testWriteAndTruncate(self):
+ with self.gfile(self.tmp + "test_truncate", "w") as fh:
+ fh.write("ababab")
+ with self.gfile(self.tmp + "test_truncate", "a+") as fh:
+ fh.seek(0, 2)
+ fh.write("hjhjhj")
+ with self.gfile(self.tmp + "test_truncate", "a+") as fh:
+ self.assertEqual(fh.Size(), 12)
+ fh.truncate(6)
+ with self.gfile(self.tmp + "test_truncate", "r") as fh:
+ self.assertEqual(fh.read(), "ababab")
+
+ def testErrors(self):
+ self.assertRaises(
+ gfile.FileError, lambda: self.gfile(self.tmp + "doesnt_exist", "r"))
+ with self.gfile(self.tmp + "test_error", "w") as fh:
+ self.assertRaises(gfile.FileError, lambda: fh.seek(-1))
+ # test_error now exists, we can read from it:
+ with self.gfile(self.tmp + "test_error", "r") as fh:
+ self.assertRaises(gfile.FileError, lambda: fh.write("ack"))
+ fh = self.gfile(self.tmp + "test_error", "w")
+ self.assertFalse(fh.closed)
+ fh.close()
+ self.assertTrue(fh.closed)
+ self.assertRaises(gfile.FileError, lambda: fh.write("ack"))
+
+ def testIteration(self):
+ with self.gfile(self.tmp + "test_iter", "w") as fh:
+ fh.writelines(["a\n", "b\n", "c\n"])
+ with self.gfile(self.tmp + "test_iter", "r") as fh:
+ lines = list(fh)
+ self.assertEqual(["a\n", "b\n", "c\n"], lines)
+
+
+class GFileTest(_GFileBaseTest, googletest.TestCase):
+
+ @property
+ def gfile(self):
+ return gfile.GFile
+
+
+class FastGFileTest(_GFileBaseTest, googletest.TestCase):
+
+ @property
+ def gfile(self):
+ return gfile.FastGFile
+
+
+class FunctionTests(_BaseTest, googletest.TestCase):
+
+ def testExists(self):
+ self.assertFalse(gfile.Exists(self.tmp + "test_exists"))
+ with gfile.GFile(self.tmp + "test_exists", "w"):
+ pass
+ self.assertTrue(gfile.Exists(self.tmp + "test_exists"))
+
+ def testMkDirsGlobAndRmDirs(self):
+ self.assertFalse(gfile.Exists(self.tmp + "test_dir"))
+ gfile.MkDir(self.tmp + "test_dir")
+ self.assertTrue(gfile.Exists(self.tmp + "test_dir"))
+ gfile.RmDir(self.tmp + "test_dir")
+ self.assertFalse(gfile.Exists(self.tmp + "test_dir"))
+ gfile.MakeDirs(self.tmp + "test_dir/blah0")
+ gfile.MakeDirs(self.tmp + "test_dir/blah1")
+ self.assertEqual([self.tmp + "test_dir/blah0", self.tmp + "test_dir/blah1"],
+ sorted(gfile.Glob(self.tmp + "test_dir/*")))
+ gfile.DeleteRecursively(self.tmp + "test_dir")
+ self.assertFalse(gfile.Exists(self.tmp + "test_dir"))
+
+ def testErrors(self):
+ self.assertRaises(
+ gfile.GOSError, lambda: gfile.RmDir(self.tmp + "dir_doesnt_exist"))
+ self.assertRaises(
+ gfile.GOSError, lambda: gfile.Remove(self.tmp + "file_doesnt_exist"))
+ gfile.MkDir(self.tmp + "error_dir")
+ with gfile.GFile(self.tmp + "error_dir/file", "w"):
+ pass # Create file
+ self.assertRaises(
+ gfile.GOSError, lambda: gfile.Remove(self.tmp + "error_dir"))
+ self.assertRaises(
+ gfile.GOSError, lambda: gfile.RmDir(self.tmp + "error_dir"))
+ self.assertTrue(gfile.Exists(self.tmp + "error_dir"))
+ gfile.DeleteRecursively(self.tmp + "error_dir")
+ self.assertFalse(gfile.Exists(self.tmp + "error_dir"))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/platform/default/logging_test.py b/tensorflow/python/platform/default/logging_test.py
new file mode 100644
index 0000000000..fd492bc384
--- /dev/null
+++ b/tensorflow/python/platform/default/logging_test.py
@@ -0,0 +1,13 @@
+from tensorflow.python.platform.default import _googletest as googletest
+from tensorflow.python.platform.default import _logging as logging
+
+
+class EventLoaderTest(googletest.TestCase):
+
+ def test_log(self):
+ # Just check that logging works without raising an exception.
+ logging.error("test log message")
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py
new file mode 100644
index 0000000000..d5b12d26df
--- /dev/null
+++ b/tensorflow/python/platform/flags.py
@@ -0,0 +1,10 @@
+"""Switch between depending on pyglib.flags or open-source gflags."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import tensorflow.python.platform
+import control_imports
+if control_imports.USE_OSS and control_imports.OSS_FLAGS:
+ from tensorflow.python.platform.default._flags import *
+else:
+ from tensorflow.python.platform.google._flags import *
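+
+# Typical client usage (illustrative; mirrors flags_test.py):
+#   from tensorflow.python.platform import flags
+#   FLAGS = flags.FLAGS
+#   flags.DEFINE_string('data_dir', '/tmp/data', 'Training data dir.')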
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
new file mode 100644
index 0000000000..fc28811821
--- /dev/null
+++ b/tensorflow/python/platform/gfile.py
@@ -0,0 +1,10 @@
+"""Switch between depending on pyglib.gfile or an OSS replacement."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import tensorflow.python.platform
+import control_imports
+if control_imports.USE_OSS and control_imports.OSS_GFILE:
+ from tensorflow.python.platform.default._gfile import *
+else:
+ from tensorflow.python.platform.google._gfile import *
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
new file mode 100644
index 0000000000..ca22ec6e6b
--- /dev/null
+++ b/tensorflow/python/platform/googletest.py
@@ -0,0 +1,10 @@
+"""Switch between depending on googletest or unittest."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import tensorflow.python.platform
+import control_imports
+if control_imports.USE_OSS and control_imports.OSS_GOOGLETEST:
+ from tensorflow.python.platform.default._googletest import *
+else:
+ from tensorflow.python.platform.google._googletest import *
diff --git a/tensorflow/python/platform/logging.py b/tensorflow/python/platform/logging.py
new file mode 100644
index 0000000000..b6d2e53dd4
--- /dev/null
+++ b/tensorflow/python/platform/logging.py
@@ -0,0 +1,10 @@
+"""Switch between depending on pyglib.logging or regular logging."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import tensorflow.python.platform
+import control_imports
+if control_imports.USE_OSS and control_imports.OSS_LOGGING:
+ from tensorflow.python.platform.default._logging import *
+else:
+ from tensorflow.python.platform.google._logging import *
diff --git a/tensorflow/python/platform/numpy.i b/tensorflow/python/platform/numpy.i
new file mode 100644
index 0000000000..217acd5bff
--- /dev/null
+++ b/tensorflow/python/platform/numpy.i
@@ -0,0 +1,3085 @@
+/* -*- C -*- (not really, but good for syntax highlighting) */
+#ifdef SWIGPYTHON
+
+%{
+#ifndef SWIG_FILE_WITH_INIT
+#define NO_IMPORT_ARRAY
+#endif
+#include "stdio.h"
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+#include <numpy/arrayobject.h>
+%}
+
+/**********************************************************************/
+
+%fragment("NumPy_Backward_Compatibility", "header")
+{
+%#if NPY_API_VERSION < 0x00000007
+%#define NPY_ARRAY_DEFAULT NPY_DEFAULT
+%#define NPY_ARRAY_FARRAY NPY_FARRAY
+%#define NPY_FORTRANORDER NPY_FORTRAN
+%#endif
+}
+
+/**********************************************************************/
+
+/* The following code originally appeared in
+ * enthought/kiva/agg/src/numeric.i written by Eric Jones. It was
+ * translated from C++ to C by John Hunter. Bill Spotz has modified
+ * it to fix some minor bugs, upgrade from Numeric to numpy (all
+ * versions), add some comments and functionality, and convert from
+ * direct code insertion to SWIG fragments.
+ */
+
+%fragment("NumPy_Macros", "header")
+{
+/* Macros to extract array attributes.
+ */
+%#if NPY_API_VERSION < 0x00000007
+%#define is_array(a) ((a) && PyArray_Check((PyArrayObject*)a))
+%#define array_type(a) (int)(PyArray_TYPE((PyArrayObject*)a))
+%#define array_numdims(a) (((PyArrayObject*)a)->nd)
+%#define array_dimensions(a) (((PyArrayObject*)a)->dimensions)
+%#define array_size(a,i) (((PyArrayObject*)a)->dimensions[i])
+%#define array_strides(a) (((PyArrayObject*)a)->strides)
+%#define array_stride(a,i) (((PyArrayObject*)a)->strides[i])
+%#define array_data(a) (((PyArrayObject*)a)->data)
+%#define array_descr(a) (((PyArrayObject*)a)->descr)
+%#define array_flags(a) (((PyArrayObject*)a)->flags)
+%#define array_enableflags(a,f) (((PyArrayObject*)a)->flags) = f
+%#else
+%#define is_array(a) ((a) && PyArray_Check(a))
+%#define array_type(a) PyArray_TYPE((PyArrayObject*)a)
+%#define array_numdims(a) PyArray_NDIM((PyArrayObject*)a)
+%#define array_dimensions(a) PyArray_DIMS((PyArrayObject*)a)
+%#define array_strides(a) PyArray_STRIDES((PyArrayObject*)a)
+%#define array_stride(a,i) PyArray_STRIDE((PyArrayObject*)a,i)
+%#define array_size(a,i) PyArray_DIM((PyArrayObject*)a,i)
+%#define array_data(a) PyArray_DATA((PyArrayObject*)a)
+%#define array_descr(a) PyArray_DESCR((PyArrayObject*)a)
+%#define array_flags(a) PyArray_FLAGS((PyArrayObject*)a)
+%#define array_enableflags(a,f) PyArray_ENABLEFLAGS((PyArrayObject*)a,f)
+%#endif
+%#define array_is_contiguous(a) (PyArray_ISCONTIGUOUS((PyArrayObject*)a))
+%#define array_is_native(a) (PyArray_ISNOTSWAPPED((PyArrayObject*)a))
+%#define array_is_fortran(a) (PyArray_ISFORTRAN((PyArrayObject*)a))
+}
+
+/**********************************************************************/
+
+%fragment("NumPy_Utilities",
+ "header")
+{
+ /* Given a PyObject, return a string describing its type.
+ */
+ const char* pytype_string(PyObject* py_obj)
+ {
+ if (py_obj == NULL ) return "C NULL value";
+ if (py_obj == Py_None ) return "Python None" ;
+ if (PyCallable_Check(py_obj)) return "callable" ;
+ if (PyString_Check( py_obj)) return "string" ;
+ if (PyInt_Check( py_obj)) return "int" ;
+ if (PyFloat_Check( py_obj)) return "float" ;
+ if (PyDict_Check( py_obj)) return "dict" ;
+ if (PyList_Check( py_obj)) return "list" ;
+ if (PyTuple_Check( py_obj)) return "tuple" ;
+%#if PY_MAJOR_VERSION < 3
+ if (PyFile_Check( py_obj)) return "file" ;
+ if (PyModule_Check( py_obj)) return "module" ;
+ if (PyInstance_Check(py_obj)) return "instance" ;
+%#endif
+
+ return "unkown type";
+ }
+
+ /* Given a NumPy typecode, return a string describing the type.
+ */
+ const char* typecode_string(int typecode)
+ {
+ static const char* type_names[25] = {"bool",
+ "byte",
+ "unsigned byte",
+ "short",
+ "unsigned short",
+ "int",
+ "unsigned int",
+ "long",
+ "unsigned long",
+ "long long",
+ "unsigned long long",
+ "float",
+ "double",
+ "long double",
+ "complex float",
+ "complex double",
+ "complex long double",
+ "object",
+ "string",
+ "unicode",
+ "void",
+ "ntypes",
+ "notype",
+ "char",
+ "unknown"};
+ return typecode < 24 ? type_names[typecode] : type_names[24];
+ }
+
+ /* Make sure input has correct numpy type. This now just calls
+ PyArray_EquivTypenums().
+ */
+ int type_match(int actual_type,
+ int desired_type)
+ {
+ return PyArray_EquivTypenums(actual_type, desired_type);
+ }
+
+%#ifdef SWIGPY_USE_CAPSULE
+ void free_cap(PyObject * cap)
+ {
+ void* array = (void*) PyCapsule_GetPointer(cap,SWIGPY_CAPSULE_NAME);
+ if (array != NULL) free(array);
+ }
+%#endif
+
+
+}
+
+/**********************************************************************/
+
+%fragment("NumPy_Object_to_Array",
+ "header",
+ fragment="NumPy_Backward_Compatibility",
+ fragment="NumPy_Macros",
+ fragment="NumPy_Utilities")
+{
+ /* Given a PyObject pointer, cast it to a PyArrayObject pointer if
+ * legal. If not, set the python error string appropriately and
+ * return NULL.
+ */
+ PyArrayObject* obj_to_array_no_conversion(PyObject* input,
+ int typecode)
+ {
+ PyArrayObject* ary = NULL;
+ if (is_array(input) && (typecode == NPY_NOTYPE ||
+ PyArray_EquivTypenums(array_type(input), typecode)))
+ {
+ ary = (PyArrayObject*) input;
+ }
+ else if is_array(input)
+ {
+ const char* desired_type = typecode_string(typecode);
+ const char* actual_type = typecode_string(array_type(input));
+ PyErr_Format(PyExc_TypeError,
+ "Array of type '%s' required. Array of type '%s' given",
+ desired_type, actual_type);
+ ary = NULL;
+ }
+ else
+ {
+ const char* desired_type = typecode_string(typecode);
+ const char* actual_type = pytype_string(input);
+ PyErr_Format(PyExc_TypeError,
+ "Array of type '%s' required. A '%s' was given",
+ desired_type,
+ actual_type);
+ ary = NULL;
+ }
+ return ary;
+ }
+
+ /* Convert the given PyObject to a NumPy array with the given
+ * typecode. On success, return a valid PyArrayObject* with the
+ * correct type. On failure, the python error string will be set and
+ * the routine returns NULL.
+ */
+ PyArrayObject* obj_to_array_allow_conversion(PyObject* input,
+ int typecode,
+ int* is_new_object)
+ {
+ PyArrayObject* ary = NULL;
+ PyObject* py_obj;
+ if (is_array(input) && (typecode == NPY_NOTYPE ||
+ PyArray_EquivTypenums(array_type(input),typecode)))
+ {
+ ary = (PyArrayObject*) input;
+ *is_new_object = 0;
+ }
+ else
+ {
+ py_obj = PyArray_FROMANY(input, typecode, 0, 0, NPY_ARRAY_DEFAULT);
+ /* If NULL, PyArray_FromObject will have set python error value.*/
+ ary = (PyArrayObject*) py_obj;
+ *is_new_object = 1;
+ }
+ return ary;
+ }
+
+ /* Given a PyArrayObject, check to see if it is contiguous. If so,
+ * return the input pointer and flag it as not a new object. If it is
+ * not contiguous, create a new PyArrayObject using the original data,
+ * flag it as a new object and return the pointer.
+ */
+ PyArrayObject* make_contiguous(PyArrayObject* ary,
+ int* is_new_object,
+ int min_dims,
+ int max_dims)
+ {
+ PyArrayObject* result;
+ if (array_is_contiguous(ary))
+ {
+ result = ary;
+ *is_new_object = 0;
+ }
+ else
+ {
+ result = (PyArrayObject*) PyArray_ContiguousFromObject((PyObject*)ary,
+ array_type(ary),
+ min_dims,
+ max_dims);
+ *is_new_object = 1;
+ }
+ return result;
+ }
+
+ /* Given a PyArrayObject, check to see if it is Fortran-contiguous.
+   * If so, return the input pointer and flag it as not a new object.
+   * If it is not Fortran-contiguous, create a new
+ * PyArrayObject using the original data, flag it as a new object
+ * and return the pointer.
+ */
+ PyArrayObject* make_fortran(PyArrayObject* ary,
+ int* is_new_object)
+ {
+ PyArrayObject* result;
+ if (array_is_fortran(ary))
+ {
+ result = ary;
+ *is_new_object = 0;
+ }
+ else
+ {
+ Py_INCREF(array_descr(ary));
+ result = (PyArrayObject*) PyArray_FromArray(ary,
+ array_descr(ary),
+ NPY_FORTRANORDER);
+ *is_new_object = 1;
+ }
+ return result;
+ }
+
+ /* Convert a given PyObject to a contiguous PyArrayObject of the
+ * specified type. If the input object is not a contiguous
+ * PyArrayObject, a new one will be created and the new object flag
+ * will be set.
+ */
+ PyArrayObject* obj_to_array_contiguous_allow_conversion(PyObject* input,
+ int typecode,
+ int* is_new_object)
+ {
+ int is_new1 = 0;
+ int is_new2 = 0;
+ PyArrayObject* ary2;
+ PyArrayObject* ary1 = obj_to_array_allow_conversion(input,
+ typecode,
+ &is_new1);
+ if (ary1)
+ {
+ ary2 = make_contiguous(ary1, &is_new2, 0, 0);
+ if ( is_new1 && is_new2)
+ {
+ Py_DECREF(ary1);
+ }
+ ary1 = ary2;
+ }
+ *is_new_object = is_new1 || is_new2;
+ return ary1;
+ }
+
+ /* Convert a given PyObject to a Fortran-ordered PyArrayObject of the
+ * specified type. If the input object is not a Fortran-ordered
+ * PyArrayObject, a new one will be created and the new object flag
+ * will be set.
+ */
+ PyArrayObject* obj_to_array_fortran_allow_conversion(PyObject* input,
+ int typecode,
+ int* is_new_object)
+ {
+ int is_new1 = 0;
+ int is_new2 = 0;
+ PyArrayObject* ary2;
+ PyArrayObject* ary1 = obj_to_array_allow_conversion(input,
+ typecode,
+ &is_new1);
+ if (ary1)
+ {
+ ary2 = make_fortran(ary1, &is_new2);
+ if (is_new1 && is_new2)
+ {
+ Py_DECREF(ary1);
+ }
+ ary1 = ary2;
+ }
+ *is_new_object = is_new1 || is_new2;
+ return ary1;
+ }
+} /* end fragment */
+
+/**********************************************************************/
+
+%fragment("NumPy_Array_Requirements",
+ "header",
+ fragment="NumPy_Backward_Compatibility",
+ fragment="NumPy_Macros")
+{
+ /* Test whether a python object is contiguous. If array is
+ * contiguous, return 1. Otherwise, set the python error string and
+ * return 0.
+ */
+ int require_contiguous(PyArrayObject* ary)
+ {
+ int contiguous = 1;
+ if (!array_is_contiguous(ary))
+ {
+ PyErr_SetString(PyExc_TypeError,
+ "Array must be contiguous. A non-contiguous array was given");
+ contiguous = 0;
+ }
+ return contiguous;
+ }
+
+ /* Require that a numpy array is not byte-swapped. If the array is
+ * not byte-swapped, return 1. Otherwise, set the python error string
+ * and return 0.
+ */
+ int require_native(PyArrayObject* ary)
+ {
+ int native = 1;
+ if (!array_is_native(ary))
+ {
+ PyErr_SetString(PyExc_TypeError,
+ "Array must have native byteorder. "
+ "A byte-swapped array was given");
+ native = 0;
+ }
+ return native;
+ }
+
+ /* Require the given PyArrayObject to have a specified number of
+ * dimensions. If the array has the specified number of dimensions,
+ * return 1. Otherwise, set the python error string and return 0.
+ */
+ int require_dimensions(PyArrayObject* ary,
+ int exact_dimensions)
+ {
+ int success = 1;
+ if (array_numdims(ary) != exact_dimensions)
+ {
+ PyErr_Format(PyExc_TypeError,
+ "Array must have %d dimensions. Given array has %d dimensions",
+ exact_dimensions,
+ array_numdims(ary));
+ success = 0;
+ }
+ return success;
+ }
+
+ /* Require the given PyArrayObject to have one of a list of specified
+ * number of dimensions. If the array has one of the specified number
+ * of dimensions, return 1. Otherwise, set the python error string
+ * and return 0.
+ */
+ int require_dimensions_n(PyArrayObject* ary,
+ int* exact_dimensions,
+ int n)
+ {
+ int success = 0;
+ int i;
+ char dims_str[255] = "";
+ char s[255];
+ for (i = 0; i < n && !success; i++)
+ {
+ if (array_numdims(ary) == exact_dimensions[i])
+ {
+ success = 1;
+ }
+ }
+ if (!success)
+ {
+ for (i = 0; i < n-1; i++)
+ {
+ sprintf(s, "%d, ", exact_dimensions[i]);
+ strcat(dims_str,s);
+ }
+ sprintf(s, " or %d", exact_dimensions[n-1]);
+ strcat(dims_str,s);
+ PyErr_Format(PyExc_TypeError,
+ "Array must have %s dimensions. Given array has %d dimensions",
+ dims_str,
+ array_numdims(ary));
+ }
+ return success;
+ }
+
+ /* Require the given PyArrayObject to have a specified shape. If the
+ * array has the specified shape, return 1. Otherwise, set the python
+ * error string and return 0.
+ */
+ int require_size(PyArrayObject* ary,
+ npy_intp* size,
+ int n)
+ {
+ int i;
+ int success = 1;
+ int len;
+ char desired_dims[255] = "[";
+ char s[255];
+ char actual_dims[255] = "[";
+ for(i=0; i < n;i++)
+ {
+ if (size[i] != -1 && size[i] != array_size(ary,i))
+ {
+ success = 0;
+ }
+ }
+ if (!success)
+ {
+ for (i = 0; i < n; i++)
+ {
+ if (size[i] == -1)
+ {
+ sprintf(s, "*,");
+ }
+ else
+ {
+ sprintf(s, "%ld,", (long int)size[i]);
+ }
+ strcat(desired_dims,s);
+ }
+ len = strlen(desired_dims);
+ desired_dims[len-1] = ']';
+ for (i = 0; i < n; i++)
+ {
+ sprintf(s, "%ld,", (long int)array_size(ary,i));
+ strcat(actual_dims,s);
+ }
+ len = strlen(actual_dims);
+ actual_dims[len-1] = ']';
+ PyErr_Format(PyExc_TypeError,
+ "Array must have shape of %s. Given array has shape of %s",
+ desired_dims,
+ actual_dims);
+ }
+ return success;
+ }
+
+  /* Require the given PyArrayObject to be Fortran ordered. If the
+   * PyArrayObject is already Fortran ordered, do nothing. Else,
+ * set the Fortran ordering flag and recompute the strides.
+ */
+ int require_fortran(PyArrayObject* ary)
+ {
+ int success = 1;
+ int nd = array_numdims(ary);
+ int i;
+ npy_intp * strides = array_strides(ary);
+ if (array_is_fortran(ary)) return success;
+ /* Set the Fortran ordered flag */
+ array_enableflags(ary,NPY_ARRAY_FARRAY);
+ /* Recompute the strides */
+ strides[0] = strides[nd-1];
+ for (i=1; i < nd; ++i)
+ strides[i] = strides[i-1] * array_size(ary,i-1);
+ return success;
+ }
+}
+
+/* Combine all NumPy fragments into one for convenience */
+%fragment("NumPy_Fragments",
+ "header",
+ fragment="NumPy_Backward_Compatibility",
+ fragment="NumPy_Macros",
+ fragment="NumPy_Utilities",
+ fragment="NumPy_Object_to_Array",
+ fragment="NumPy_Array_Requirements")
+{
+}
+
+/* End John Hunter translation (with modifications by Bill Spotz)
+ */
+
+/* %numpy_typemaps() macro
+ *
+ * This macro defines a family of 74 typemaps that allow C arguments
+ * of the form
+ *
+ * 1. (DATA_TYPE IN_ARRAY1[ANY])
+ * 2. (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1)
+ * 3. (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1)
+ *
+ * 4. (DATA_TYPE IN_ARRAY2[ANY][ANY])
+ * 5. (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ * 6. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2)
+ * 7. (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ * 8. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2)
+ *
+ * 9. (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY])
+ * 10. (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ * 11. (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ * 12. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3)
+ * 13. (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ * 14. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3)
+ *
+ * 15. (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY])
+ * 16. (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ * 17. (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ * 18. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4)
+ * 19. (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ * 20. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4)
+ *
+ * 21. (DATA_TYPE INPLACE_ARRAY1[ANY])
+ * 22. (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1)
+ * 23. (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1)
+ *
+ * 24. (DATA_TYPE INPLACE_ARRAY2[ANY][ANY])
+ * 25. (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ * 26. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2)
+ * 27. (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ * 28. (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2)
+ *
+ * 29. (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY])
+ * 30. (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ * 31. (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ * 32. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_ARRAY3)
+ * 33. (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ * 34. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_FARRAY3)
+ *
+ * 35. (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY])
+ * 36. (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ * 37. (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ * 38. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_ARRAY4)
+ * 39. (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ * 40. (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_FARRAY4)
+ *
+ * 41. (DATA_TYPE ARGOUT_ARRAY1[ANY])
+ * 42. (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1)
+ * 43. (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1)
+ *
+ * 44. (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY])
+ *
+ * 45. (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY])
+ *
+ * 46. (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY])
+ *
+ * 47. (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1)
+ * 48. (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEW_ARRAY1)
+ *
+ * 49. (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ * 50. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_ARRAY2)
+ * 51. (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ * 52. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_FARRAY2)
+ *
+ * 53. (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+ * 54. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_ARRAY3)
+ * 55. (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+ * 56. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_FARRAY3)
+ *
+ * 57. (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ * 58. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_ARRAY4)
+ * 59. (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ * 60. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_FARRAY4)
+ *
+ * 61. (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1)
+ * 62. (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEWM_ARRAY1)
+ *
+ * 63. (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ * 64. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_ARRAY2)
+ * 65. (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ * 66. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_FARRAY2)
+ *
+ * 67. (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+ * 68. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_ARRAY3)
+ * 69. (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+ * 70. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_FARRAY3)
+ *
+ * 71. (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ * 72. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_ARRAY4)
+ * 73. (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ * 74. (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_FARRAY4)
+ *
+ * where "DATA_TYPE" is any type supported by the NumPy module, and
+ * "DIM_TYPE" is any int-like type suitable for specifying dimensions.
+ * The difference between "ARRAY" typemaps and "FARRAY" typemaps is
+ * that the "FARRAY" typemaps expect Fortran ordering of
+ * multidimensional arrays. In python, the dimensions will not need
+ * to be specified (except for the "DATA_TYPE* ARGOUT_ARRAY1"
+ * typemaps). The IN_ARRAYs can be a numpy array or any sequence that
+ * can be converted to a numpy array of the specified type. The
+ * INPLACE_ARRAYs must be numpy arrays of the appropriate type. The
+ * ARGOUT_ARRAYs will be returned as new numpy arrays of the
+ * appropriate type.
+ *
+ * These typemaps can be applied to existing functions using the
+ * %apply directive. For example:
+ *
+ * %apply (double* IN_ARRAY1, int DIM1) {(double* series, int length)};
+ * double prod(double* series, int length);
+ *
+ * %apply (int DIM1, int DIM2, double* INPLACE_ARRAY2)
+ * {(int rows, int cols, double* matrix )};
+ * void floor(int rows, int cols, double* matrix, double f);
+ *
+ * %apply (double IN_ARRAY3[ANY][ANY][ANY])
+ * {(double tensor[2][2][2] )};
+ * %apply (double ARGOUT_ARRAY3[ANY][ANY][ANY])
+ * {(double low[2][2][2] )};
+ * %apply (double ARGOUT_ARRAY3[ANY][ANY][ANY])
+ * {(double upp[2][2][2] )};
+ * void luSplit(double tensor[2][2][2],
+ * double low[2][2][2],
+ * double upp[2][2][2] );
+ *
+ * or directly with
+ *
+ * double prod(double* IN_ARRAY1, int DIM1);
+ *
+ * void floor(int DIM1, int DIM2, double* INPLACE_ARRAY2, double f);
+ *
+ * void luSplit(double IN_ARRAY3[ANY][ANY][ANY],
+ * double ARGOUT_ARRAY3[ANY][ANY][ANY],
+ * double ARGOUT_ARRAY3[ANY][ANY][ANY]);
+ */
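+
+/* Note: the %numpy_typemaps macro below is normally instantiated once per
+ * element type, e.g. (illustrative):
+ *
+ *     %numpy_typemaps(double, NPY_DOUBLE, int)
+ *
+ * which generates the full suite of typemaps above for that type.
+ */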
+
+%define %numpy_typemaps(DATA_TYPE, DATA_TYPECODE, DIM_TYPE)
+
+/************************/
+/* Input Array Typemaps */
+/************************/
+
+/* Typemap suite for (DATA_TYPE IN_ARRAY1[ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE IN_ARRAY1[ANY])
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE IN_ARRAY1[ANY])
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[1] = { $1_dim0 };
+ array = obj_to_array_contiguous_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 1) ||
+ !require_size(array, size, 1)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(freearg)
+ (DATA_TYPE IN_ARRAY1[ANY])
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[1] = { -1 };
+ array = obj_to_array_contiguous_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 1) ||
+ !require_size(array, size, 1)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+}
+%typemap(freearg)
+ (DATA_TYPE* IN_ARRAY1, DIM_TYPE DIM1)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[1] = {-1};
+ array = obj_to_array_contiguous_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 1) ||
+ !require_size(array, size, 1)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DATA_TYPE*) array_data(array);
+}
+%typemap(freearg)
+ (DIM_TYPE DIM1, DATA_TYPE* IN_ARRAY1)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE IN_ARRAY2[ANY][ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE IN_ARRAY2[ANY][ANY])
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE IN_ARRAY2[ANY][ANY])
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[2] = { $1_dim0, $1_dim1 };
+ array = obj_to_array_contiguous_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 2) ||
+ !require_size(array, size, 2)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(freearg)
+ (DATA_TYPE IN_ARRAY2[ANY][ANY])
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[2] = { -1, -1 };
+ array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 2) ||
+ !require_size(array, size, 2)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+}
+%typemap(freearg)
+ (DATA_TYPE* IN_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[2] = { -1, -1 };
+ array = obj_to_array_contiguous_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 2) ||
+ !require_size(array, size, 2)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DATA_TYPE*) array_data(array);
+}
+%typemap(freearg)
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_ARRAY2)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[2] = { -1, -1 };
+ array = obj_to_array_fortran_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 2) ||
+ !require_size(array, size, 2) || !require_fortran(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+}
+%typemap(freearg)
+ (DATA_TYPE* IN_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[2] = { -1, -1 };
+ array = obj_to_array_fortran_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 2) ||
+ !require_size(array, size, 2) || !require_fortran(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DATA_TYPE*) array_data(array);
+}
+%typemap(freearg)
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* IN_FARRAY2)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY])
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY])
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[3] = { $1_dim0, $1_dim1, $1_dim2 };
+ array = obj_to_array_contiguous_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 3) ||
+ !require_size(array, size, 3)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(freearg)
+ (DATA_TYPE IN_ARRAY3[ANY][ANY][ANY])
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[3] = { -1, -1, -1 };
+ array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 3) ||
+ !require_size(array, size, 3)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+}
+%typemap(freearg)
+ (DATA_TYPE* IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ /* for now, only concerned with lists */
+ $1 = PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL, int* is_new_object_array=NULL)
+{
+ npy_intp size[2] = { -1, -1 };
+ PyArrayObject* temp_array;
+ Py_ssize_t i;
+ int is_new_object;
+
+ /* length of the list */
+ $2 = PyList_Size($input);
+
+ /* the arrays */
+ array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *));
+ object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *));
+ is_new_object_array = (int *)calloc($2,sizeof(int));
+
+ if (array == NULL || object_array == NULL || is_new_object_array == NULL)
+ {
+ SWIG_fail;
+ }
+
+ for (i=0; i<$2; i++)
+ {
+ temp_array = obj_to_array_contiguous_allow_conversion(PySequence_GetItem($input,i), DATA_TYPECODE, &is_new_object);
+
+ /* the new array must be stored so that it can be destroyed in freearg */
+ object_array[i] = temp_array;
+ is_new_object_array[i] = is_new_object;
+
+ if (!temp_array || !require_dimensions(temp_array, 2)) SWIG_fail;
+
+ /* store the size of the first array in the list, then use that for comparison. */
+ if (i == 0)
+ {
+ size[0] = array_size(temp_array,0);
+ size[1] = array_size(temp_array,1);
+ }
+
+ if (!require_size(temp_array, size, 2)) SWIG_fail;
+
+ array[i] = (DATA_TYPE*) array_data(temp_array);
+ }
+
+ $1 = (DATA_TYPE**) array;
+ $3 = (DIM_TYPE) size[0];
+ $4 = (DIM_TYPE) size[1];
+}
+%typemap(freearg)
+ (DATA_TYPE** IN_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ Py_ssize_t i;
+
+ if (array$argnum!=NULL) free(array$argnum);
+
+ /*freeing the individual arrays if needed */
+ if (object_array$argnum!=NULL)
+ {
+ if (is_new_object_array$argnum!=NULL)
+ {
+ for (i=0; i<$2; i++)
+ {
+ if (object_array$argnum[i] != NULL && is_new_object_array$argnum[i])
+ { Py_DECREF(object_array$argnum[i]); }
+ }
+ free(is_new_object_array$argnum);
+ }
+ free(object_array$argnum);
+ }
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3,
+ * DATA_TYPE* IN_ARRAY3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[3] = { -1, -1, -1 };
+ array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 3) ||
+ !require_size(array, size, 3)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DATA_TYPE*) array_data(array);
+}
+%typemap(freearg)
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_ARRAY3)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[3] = { -1, -1, -1 };
+ array = obj_to_array_fortran_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 3) ||
+      !require_size(array, size, 3) || !require_fortran(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+}
+%typemap(freearg)
+ (DATA_TYPE* IN_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3,
+ * DATA_TYPE* IN_FARRAY3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[3] = { -1, -1, -1 };
+ array = obj_to_array_fortran_allow_conversion($input,
+ DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 3) ||
+ !require_size(array, size, 3) || !require_fortran(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DATA_TYPE*) array_data(array);
+}
+%typemap(freearg)
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* IN_FARRAY3)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY])
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY])
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[4] = { $1_dim0, $1_dim1, $1_dim2 , $1_dim3};
+ array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 4) ||
+ !require_size(array, size, 4)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(freearg)
+ (DATA_TYPE IN_ARRAY4[ANY][ANY][ANY][ANY])
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3, DIM_TYPE DIM4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[4] = { -1, -1, -1, -1 };
+ array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 4) ||
+ !require_size(array, size, 4)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+ $5 = (DIM_TYPE) array_size(array,3);
+}
+%typemap(freearg)
+ (DATA_TYPE* IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3, DIM_TYPE DIM4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ /* for now, only concerned with lists */
+ $1 = PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL, int* is_new_object_array=NULL)
+{
+ npy_intp size[3] = { -1, -1, -1 };
+ PyArrayObject* temp_array;
+ Py_ssize_t i;
+ int is_new_object;
+
+ /* length of the list */
+ $2 = PyList_Size($input);
+
+ /* the arrays */
+ array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *));
+ object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *));
+ is_new_object_array = (int *)calloc($2,sizeof(int));
+
+ if (array == NULL || object_array == NULL || is_new_object_array == NULL)
+ {
+ SWIG_fail;
+ }
+
+ for (i=0; i<$2; i++)
+ {
+ temp_array = obj_to_array_contiguous_allow_conversion(PySequence_GetItem($input,i), DATA_TYPECODE, &is_new_object);
+
+ /* the new array must be stored so that it can be destroyed in freearg */
+ object_array[i] = temp_array;
+ is_new_object_array[i] = is_new_object;
+
+ if (!temp_array || !require_dimensions(temp_array, 3)) SWIG_fail;
+
+ /* store the size of the first array in the list, then use that for comparison. */
+ if (i == 0)
+ {
+ size[0] = array_size(temp_array,0);
+ size[1] = array_size(temp_array,1);
+ size[2] = array_size(temp_array,2);
+ }
+
+ if (!require_size(temp_array, size, 3)) SWIG_fail;
+
+ array[i] = (DATA_TYPE*) array_data(temp_array);
+ }
+
+ $1 = (DATA_TYPE**) array;
+ $3 = (DIM_TYPE) size[0];
+ $4 = (DIM_TYPE) size[1];
+ $5 = (DIM_TYPE) size[2];
+}
+%typemap(freearg)
+ (DATA_TYPE** IN_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ Py_ssize_t i;
+
+ if (array$argnum!=NULL) free(array$argnum);
+
+ /*freeing the individual arrays if needed */
+ if (object_array$argnum!=NULL)
+ {
+ if (is_new_object_array$argnum!=NULL)
+ {
+ for (i=0; i<$2; i++)
+ {
+ if (object_array$argnum[i] != NULL && is_new_object_array$argnum[i])
+ { Py_DECREF(object_array$argnum[i]); }
+ }
+ free(is_new_object_array$argnum);
+ }
+ free(object_array$argnum);
+ }
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4,
+ * DATA_TYPE* IN_ARRAY4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[4] = { -1, -1, -1 , -1};
+ array = obj_to_array_contiguous_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 4) ||
+ !require_size(array, size, 4)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DIM_TYPE) array_size(array,3);
+ $5 = (DATA_TYPE*) array_data(array);
+}
+%typemap(freearg)
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_ARRAY4)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3, DIM_TYPE DIM4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[4] = { -1, -1, -1, -1 };
+ array = obj_to_array_fortran_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 4) ||
+      !require_size(array, size, 4) || !require_fortran(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+ $5 = (DIM_TYPE) array_size(array,3);
+}
+%typemap(freearg)
+ (DATA_TYPE* IN_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4,
+ * DATA_TYPE* IN_FARRAY4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4)
+{
+ $1 = is_array($input) || PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4)
+ (PyArrayObject* array=NULL, int is_new_object=0)
+{
+ npy_intp size[4] = { -1, -1, -1 , -1 };
+ array = obj_to_array_fortran_allow_conversion($input, DATA_TYPECODE,
+ &is_new_object);
+ if (!array || !require_dimensions(array, 4) ||
+ !require_size(array, size, 4) || !require_fortran(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DIM_TYPE) array_size(array,3);
+ $5 = (DATA_TYPE*) array_data(array);
+}
+%typemap(freearg)
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* IN_FARRAY4)
+{
+ if (is_new_object$argnum && array$argnum)
+ { Py_DECREF(array$argnum); }
+}
+
+/***************************/
+/* In-Place Array Typemaps */
+/***************************/
+
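+/* The in-place typemaps below wrap C functions that read and modify a
+ * caller-supplied NumPy buffer without copying it.  As a hedged usage
+ * sketch for an interface file (the function name `scale_in_place` is
+ * hypothetical, not part of this file):
+ *
+ *     %apply (double* INPLACE_ARRAY1, int DIM1) {(double* x, int n)};
+ *     void scale_in_place(double* x, int n);
+ *
+ * The wrapped function then accepts a contiguous numpy.float64 array from
+ * Python and modifies it in place; no new array is returned.
+ */
+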
+/* Typemap suite for (DATA_TYPE INPLACE_ARRAY1[ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE INPLACE_ARRAY1[ANY])
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE INPLACE_ARRAY1[ANY])
+ (PyArrayObject* array=NULL)
+{
+ npy_intp size[1] = { $1_dim0 };
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,1) || !require_size(array, size, 1) ||
+ !require_contiguous(array) || !require_native(array)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* INPLACE_ARRAY1, DIM_TYPE DIM1)
+ (PyArrayObject* array=NULL, int i=1)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,1) || !require_contiguous(array)
+ || !require_native(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = 1;
+ for (i=0; i < array_numdims(array); ++i) $2 *= array_size(array,i);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DATA_TYPE* INPLACE_ARRAY1)
+ (PyArrayObject* array=NULL, int i=0)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,1) || !require_contiguous(array)
+ || !require_native(array)) SWIG_fail;
+ $1 = 1;
+ for (i=0; i < array_numdims(array); ++i) $1 *= array_size(array,i);
+ $2 = (DATA_TYPE*) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE INPLACE_ARRAY2[ANY][ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE INPLACE_ARRAY2[ANY][ANY])
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE INPLACE_ARRAY2[ANY][ANY])
+ (PyArrayObject* array=NULL)
+{
+ npy_intp size[2] = { $1_dim0, $1_dim1 };
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,2) || !require_size(array, size, 2) ||
+ !require_contiguous(array) || !require_native(array)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* INPLACE_ARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,2) || !require_contiguous(array)
+ || !require_native(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_ARRAY2)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,2) || !require_contiguous(array) ||
+ !require_native(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DATA_TYPE*) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* INPLACE_FARRAY2, DIM_TYPE DIM1, DIM_TYPE DIM2)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,2) || !require_contiguous(array)
+ || !require_native(array) || !require_fortran(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DATA_TYPE* INPLACE_FARRAY2)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,2) || !require_contiguous(array) ||
+ !require_native(array) || !require_fortran(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DATA_TYPE*) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY])
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE INPLACE_ARRAY3[ANY][ANY][ANY])
+ (PyArrayObject* array=NULL)
+{
+ npy_intp size[3] = { $1_dim0, $1_dim1, $1_dim2 };
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,3) || !require_size(array, size, 3) ||
+ !require_contiguous(array) || !require_native(array)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,3) || !require_contiguous(array) ||
+ !require_native(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+}
+
+/* Typemap suite for (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ $1 = PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL)
+{
+ npy_intp size[2] = { -1, -1 };
+ PyArrayObject* temp_array;
+ Py_ssize_t i;
+
+ /* length of the list */
+ $2 = PyList_Size($input);
+
+ /* the arrays */
+ array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *));
+ object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *));
+
+ if (array == NULL || object_array == NULL)
+ {
+ SWIG_fail;
+ }
+
+ for (i=0; i<$2; i++)
+ {
+ temp_array = obj_to_array_no_conversion(PySequence_GetItem($input,i), DATA_TYPECODE);
+
+ /* the new array must be stored so that it can be destroyed in freearg */
+ object_array[i] = temp_array;
+
+ if ( !temp_array || !require_dimensions(temp_array, 2) ||
+ !require_contiguous(temp_array) ||
+ !require_native(temp_array) ||
+ !PyArray_EquivTypenums(array_type(temp_array), DATA_TYPECODE)
+ ) SWIG_fail;
+
+ /* store the size of the first array in the list, then use that for comparison. */
+ if (i == 0)
+ {
+ size[0] = array_size(temp_array,0);
+ size[1] = array_size(temp_array,1);
+ }
+
+ if (!require_size(temp_array, size, 2)) SWIG_fail;
+
+ array[i] = (DATA_TYPE*) array_data(temp_array);
+ }
+
+ $1 = (DATA_TYPE**) array;
+ $3 = (DIM_TYPE) size[0];
+ $4 = (DIM_TYPE) size[1];
+}
+%typemap(freearg)
+ (DATA_TYPE** INPLACE_ARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ if (array$argnum!=NULL) free(array$argnum);
+ if (object_array$argnum!=NULL) free(object_array$argnum);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3,
+ * DATA_TYPE* INPLACE_ARRAY3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_ARRAY3)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_ARRAY3)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,3) || !require_contiguous(array)
+ || !require_native(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DATA_TYPE*) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* INPLACE_FARRAY3, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,3) || !require_contiguous(array) ||
+ !require_native(array) || !require_fortran(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3,
+ * DATA_TYPE* INPLACE_FARRAY3)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_FARRAY3)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DATA_TYPE* INPLACE_FARRAY3)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,3) || !require_contiguous(array)
+ || !require_native(array) || !require_fortran(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DATA_TYPE*) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY])
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY])
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE INPLACE_ARRAY4[ANY][ANY][ANY][ANY])
+ (PyArrayObject* array=NULL)
+{
+ npy_intp size[4] = { $1_dim0, $1_dim1, $1_dim2 , $1_dim3 };
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,4) || !require_size(array, size, 4) ||
+ !require_contiguous(array) || !require_native(array)) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3, DIM_TYPE DIM4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,4) || !require_contiguous(array) ||
+ !require_native(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+ $5 = (DIM_TYPE) array_size(array,3);
+}
+
+/* Typemap suite for (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3, DIM_TYPE DIM4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ $1 = PySequence_Check($input);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ (DATA_TYPE** array=NULL, PyArrayObject** object_array=NULL)
+{
+ npy_intp size[3] = { -1, -1, -1 };
+ PyArrayObject* temp_array;
+ Py_ssize_t i;
+
+ /* length of the list */
+ $2 = PyList_Size($input);
+
+ /* the arrays */
+ array = (DATA_TYPE **)malloc($2*sizeof(DATA_TYPE *));
+ object_array = (PyArrayObject **)calloc($2,sizeof(PyArrayObject *));
+
+ if (array == NULL || object_array == NULL)
+ {
+ SWIG_fail;
+ }
+
+ for (i=0; i<$2; i++)
+ {
+ temp_array = obj_to_array_no_conversion(PySequence_GetItem($input,i), DATA_TYPECODE);
+
+ /* the new array must be stored so that it can be destroyed in freearg */
+ object_array[i] = temp_array;
+
+ if ( !temp_array || !require_dimensions(temp_array, 3) ||
+ !require_contiguous(temp_array) ||
+ !require_native(temp_array) ||
+ !PyArray_EquivTypenums(array_type(temp_array), DATA_TYPECODE)
+ ) SWIG_fail;
+
+ /* store the size of the first array in the list, then use that for comparison. */
+ if (i == 0)
+ {
+ size[0] = array_size(temp_array,0);
+ size[1] = array_size(temp_array,1);
+ size[2] = array_size(temp_array,2);
+ }
+
+ if (!require_size(temp_array, size, 3)) SWIG_fail;
+
+ array[i] = (DATA_TYPE*) array_data(temp_array);
+ }
+
+ $1 = (DATA_TYPE**) array;
+ $3 = (DIM_TYPE) size[0];
+ $4 = (DIM_TYPE) size[1];
+ $5 = (DIM_TYPE) size[2];
+}
+%typemap(freearg)
+ (DATA_TYPE** INPLACE_ARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ if (array$argnum!=NULL) free(array$argnum);
+ if (object_array$argnum!=NULL) free(object_array$argnum);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4,
+ * DATA_TYPE* INPLACE_ARRAY4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_ARRAY4)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_ARRAY4)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,4) || !require_contiguous(array)
+ || !require_native(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DIM_TYPE) array_size(array,3);
+ $5 = (DATA_TYPE*) array_data(array);
+}
+
+/* Typemap suite for (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2,
+ * DIM_TYPE DIM3, DIM_TYPE DIM4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* INPLACE_FARRAY4, DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,4) || !require_contiguous(array) ||
+ !require_native(array) || !require_fortran(array)) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+ $2 = (DIM_TYPE) array_size(array,0);
+ $3 = (DIM_TYPE) array_size(array,1);
+ $4 = (DIM_TYPE) array_size(array,2);
+ $5 = (DIM_TYPE) array_size(array,3);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4,
+ * DATA_TYPE* INPLACE_FARRAY4)
+ */
+%typecheck(SWIG_TYPECHECK_DOUBLE_ARRAY,
+ fragment="NumPy_Macros")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_FARRAY4)
+{
+ $1 = is_array($input) && PyArray_EquivTypenums(array_type($input),
+ DATA_TYPECODE);
+}
+%typemap(in,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DIM_TYPE DIM2, DIM_TYPE DIM3, DIM_TYPE DIM4, DATA_TYPE* INPLACE_FARRAY4)
+ (PyArrayObject* array=NULL)
+{
+ array = obj_to_array_no_conversion($input, DATA_TYPECODE);
+ if (!array || !require_dimensions(array,4) || !require_contiguous(array)
+ || !require_native(array) || !require_fortran(array)) SWIG_fail;
+ $1 = (DIM_TYPE) array_size(array,0);
+ $2 = (DIM_TYPE) array_size(array,1);
+ $3 = (DIM_TYPE) array_size(array,2);
+ $4 = (DIM_TYPE) array_size(array,3);
+ $5 = (DATA_TYPE*) array_data(array);
+}
+
+/*************************/
+/* Argout Array Typemaps */
+/*************************/
+
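+/* The argout typemaps below allocate a new NumPy array inside the wrapper,
+ * pass its buffer to the C function as an output argument, and append the
+ * array to the Python return value.  A hedged usage sketch (the function
+ * name `compute_range` is hypothetical):
+ *
+ *     %apply (double* ARGOUT_ARRAY1, int DIM1) {(double* result, int n)};
+ *     void compute_range(double* result, int n);
+ *
+ * From Python this would be called as `result = module.compute_range(n)`,
+ * where `n` gives the length of the freshly allocated output array.
+ */
+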
+/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY1[ANY])
+ */
+%typemap(in,numinputs=0,
+ fragment="NumPy_Backward_Compatibility,NumPy_Macros")
+ (DATA_TYPE ARGOUT_ARRAY1[ANY])
+ (PyObject* array = NULL)
+{
+ npy_intp dims[1] = { $1_dim0 };
+ array = PyArray_SimpleNew(1, dims, DATA_TYPECODE);
+ if (!array) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(argout)
+ (DATA_TYPE ARGOUT_ARRAY1[ANY])
+{
+ $result = SWIG_Python_AppendOutput($result,(PyObject*)array$argnum);
+}
+
+/* Typemap suite for (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1)
+ */
+%typemap(in,numinputs=1,
+ fragment="NumPy_Fragments")
+ (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1)
+ (PyObject* array = NULL)
+{
+ npy_intp dims[1];
+ if (!PyInt_Check($input))
+ {
+ const char* typestring = pytype_string($input);
+ PyErr_Format(PyExc_TypeError,
+ "Int dimension expected. '%s' given.",
+ typestring);
+ SWIG_fail;
+ }
+ $2 = (DIM_TYPE) PyInt_AsLong($input);
+ dims[0] = (npy_intp) $2;
+ array = PyArray_SimpleNew(1, dims, DATA_TYPECODE);
+ if (!array) SWIG_fail;
+ $1 = (DATA_TYPE*) array_data(array);
+}
+%typemap(argout)
+ (DATA_TYPE* ARGOUT_ARRAY1, DIM_TYPE DIM1)
+{
+ $result = SWIG_Python_AppendOutput($result,(PyObject*)array$argnum);
+}
+
+/* Typemap suite for (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1)
+ */
+%typemap(in,numinputs=1,
+ fragment="NumPy_Fragments")
+ (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1)
+ (PyObject* array = NULL)
+{
+ npy_intp dims[1];
+ if (!PyInt_Check($input))
+ {
+ const char* typestring = pytype_string($input);
+ PyErr_Format(PyExc_TypeError,
+ "Int dimension expected. '%s' given.",
+ typestring);
+ SWIG_fail;
+ }
+ $1 = (DIM_TYPE) PyInt_AsLong($input);
+ dims[0] = (npy_intp) $1;
+ array = PyArray_SimpleNew(1, dims, DATA_TYPECODE);
+ if (!array) SWIG_fail;
+ $2 = (DATA_TYPE*) array_data(array);
+}
+%typemap(argout)
+ (DIM_TYPE DIM1, DATA_TYPE* ARGOUT_ARRAY1)
+{
+ $result = SWIG_Python_AppendOutput($result,(PyObject*)array$argnum);
+}
+
+/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY])
+ */
+%typemap(in,numinputs=0,
+ fragment="NumPy_Backward_Compatibility,NumPy_Macros")
+ (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY])
+ (PyObject* array = NULL)
+{
+ npy_intp dims[2] = { $1_dim0, $1_dim1 };
+ array = PyArray_SimpleNew(2, dims, DATA_TYPECODE);
+ if (!array) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(argout)
+ (DATA_TYPE ARGOUT_ARRAY2[ANY][ANY])
+{
+ $result = SWIG_Python_AppendOutput($result,(PyObject*)array$argnum);
+}
+
+/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY])
+ */
+%typemap(in,numinputs=0,
+ fragment="NumPy_Backward_Compatibility,NumPy_Macros")
+ (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY])
+ (PyObject* array = NULL)
+{
+ npy_intp dims[3] = { $1_dim0, $1_dim1, $1_dim2 };
+ array = PyArray_SimpleNew(3, dims, DATA_TYPECODE);
+ if (!array) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(argout)
+ (DATA_TYPE ARGOUT_ARRAY3[ANY][ANY][ANY])
+{
+ $result = SWIG_Python_AppendOutput($result,(PyObject*)array$argnum);
+}
+
+/* Typemap suite for (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY])
+ */
+%typemap(in,numinputs=0,
+ fragment="NumPy_Backward_Compatibility,NumPy_Macros")
+ (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY])
+ (PyObject* array = NULL)
+{
+ npy_intp dims[4] = { $1_dim0, $1_dim1, $1_dim2, $1_dim3 };
+ array = PyArray_SimpleNew(4, dims, DATA_TYPECODE);
+ if (!array) SWIG_fail;
+ $1 = ($1_ltype) array_data(array);
+}
+%typemap(argout)
+ (DATA_TYPE ARGOUT_ARRAY4[ANY][ANY][ANY][ANY])
+{
+ $result = SWIG_Python_AppendOutput($result,(PyObject*)array$argnum);
+}
+
+/*****************************/
+/* Argoutview Array Typemaps */
+/*****************************/
+
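+/* The argoutview typemaps below neither allocate nor copy: they wrap a
+ * pointer owned by the C/C++ layer in an array created with
+ * PyArray_SimpleNewFromData, so the returned array is only a view and the
+ * underlying buffer must outlive it.  A hedged usage sketch (the accessor
+ * name `get_internal_buffer` is hypothetical):
+ *
+ *     %apply (double** ARGOUTVIEW_ARRAY1, int* DIM1) {(double** data, int* n)};
+ *     void get_internal_buffer(double** data, int* n);
+ */
+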
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DATA_TYPE** ARGOUTVIEW_ARRAY1, DIM_TYPE* DIM1)
+{
+ npy_intp dims[1] = { *$2 };
+ PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEW_ARRAY1)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DATA_TYPE** ARGOUTVIEW_ARRAY1)
+ (DIM_TYPE dim_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim_temp;
+ $2 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEW_ARRAY1)
+{
+ npy_intp dims[1] = { *$1 };
+ PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$2));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DATA_TYPE** ARGOUTVIEW_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+{
+ npy_intp dims[2] = { *$2, *$3 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_ARRAY2)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEW_ARRAY2)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_ARRAY2)
+{
+ npy_intp dims[2] = { *$1, *$2 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements")
+ (DATA_TYPE** ARGOUTVIEW_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+{
+ npy_intp dims[2] = { *$2, *$3 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_FARRAY2)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEW_FARRAY2)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEW_FARRAY2)
+{
+ npy_intp dims[2] = { *$1, *$2 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DATA_TYPE** ARGOUTVIEW_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+{
+ npy_intp dims[3] = { *$2, *$3, *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3,
+ DATA_TYPE** ARGOUTVIEW_ARRAY3)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_ARRAY3)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL)
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_ARRAY3)
+{
+ npy_intp dims[3] = { *$1, *$2, *$3 };
+ PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements")
+ (DATA_TYPE** ARGOUTVIEW_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+{
+ npy_intp dims[3] = { *$2, *$3, *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3,
+ DATA_TYPE** ARGOUTVIEW_FARRAY3)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DATA_TYPE** ARGOUTVIEW_FARRAY3)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEW_FARRAY3)
+{
+ npy_intp dims[3] = { *$1, *$2, *$3 };
+ PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+ $5 = &dim4_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DATA_TYPE** ARGOUTVIEW_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+{
+ npy_intp dims[4] = { *$2, *$3, *$4 , *$5 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4,
+ DATA_TYPE** ARGOUTVIEW_ARRAY4)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEW_ARRAY4)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &dim4_temp;
+ $5 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_ARRAY4)
+{
+ npy_intp dims[4] = { *$1, *$2, *$3 , *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+ $5 = &dim4_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements")
+ (DATA_TYPE** ARGOUTVIEW_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+{
+ npy_intp dims[4] = { *$2, *$3, *$4 , *$5 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4,
+ DATA_TYPE** ARGOUTVIEW_FARRAY4)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEW_FARRAY4)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &dim4_temp;
+ $5 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEW_FARRAY4)
+{
+ npy_intp dims[4] = { *$1, *$2, *$3 , *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/*************************************/
+/* Managed Argoutview Array Typemaps */
+/*************************************/
+
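+/* The managed argoutview typemaps below behave like the argoutview ones,
+ * but additionally install a capsule (or a CObject on very old Pythons) as
+ * the base object of the returned array, so the buffer is released with
+ * free() when the last NumPy reference dies.  The buffer must therefore
+ * come from malloc()/calloc().  A hedged usage sketch (the function name
+ * `make_sequence` is hypothetical):
+ *
+ *     %apply (double** ARGOUTVIEWM_ARRAY1, int* DIM1) {(double** data, int* n)};
+ *     void make_sequence(double** data, int* n);   // sets *data = malloc(...)
+ */
+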
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY1, DIM_TYPE* DIM1)
+{
+ npy_intp dims[1] = { *$2 };
+ PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEWM_ARRAY1)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DATA_TYPE** ARGOUTVIEWM_ARRAY1)
+ (DIM_TYPE dim_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim_temp;
+ $2 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DATA_TYPE** ARGOUTVIEWM_ARRAY1)
+{
+ npy_intp dims[1] = { *$1 };
+ PyObject* obj = PyArray_SimpleNewFromData(1, dims, DATA_TYPECODE, (void*)(*$2));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$2), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$2), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+{
+ npy_intp dims[2] = { *$2, *$3 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_ARRAY2)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEWM_ARRAY2)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_ARRAY2)
+{
+ npy_intp dims[2] = { *$1, *$2 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$3), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$3), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_FARRAY2, DIM_TYPE* DIM1, DIM_TYPE* DIM2)
+{
+ npy_intp dims[2] = { *$2, *$3 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_FARRAY2)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DATA_TYPE** ARGOUTVIEWM_FARRAY2)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DATA_TYPE** ARGOUTVIEWM_FARRAY2)
+{
+ npy_intp dims[2] = { *$1, *$2 };
+ PyObject* obj = PyArray_SimpleNewFromData(2, dims, DATA_TYPECODE, (void*)(*$3));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$3), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$3), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+{
+ npy_intp dims[3] = { *$2, *$3, *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3,
+ DATA_TYPE** ARGOUTVIEWM_ARRAY3)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DATA_TYPE** ARGOUTVIEWM_ARRAY3)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_ARRAY3)
+{
+ npy_intp dims[3] = { *$1, *$2, *$3 };
+  PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$4), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$4), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_FARRAY3, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3)
+{
+ npy_intp dims[3] = { *$2, *$3, *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3,
+ DATA_TYPE** ARGOUTVIEWM_FARRAY3)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DATA_TYPE** ARGOUTVIEWM_FARRAY3)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DATA_TYPE** ARGOUTVIEWM_FARRAY3)
+{
+ npy_intp dims[3] = { *$1, *$2, *$3 };
+ PyObject* obj = PyArray_SimpleNewFromData(3, dims, DATA_TYPECODE, (void*)(*$4));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$4), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$4), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+ $5 = &dim4_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+{
+ npy_intp dims[4] = { *$2, *$3, *$4 , *$5 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4,
+ DATA_TYPE** ARGOUTVIEWM_ARRAY4)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEWM_ARRAY4)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &dim4_temp;
+ $5 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_ARRAY4)
+{
+ npy_intp dims[4] = { *$1, *$2, *$3 , *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$5), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$5), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+ $5 = &dim4_temp;
+}
+%typemap(argout,
+         fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+  (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+{
+ npy_intp dims[4] = { *$2, *$3, *$4 , *$5 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4,
+ DATA_TYPE** ARGOUTVIEWM_FARRAY4)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEWM_FARRAY4)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &dim4_temp;
+ $5 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_FARRAY4)
+{
+ npy_intp dims[4] = { *$1, *$2, *$3 , *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$5), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$5), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+ $5 = &dim4_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_ARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+{
+ npy_intp dims[4] = { *$2, *$3, *$4 , *$5 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4,
+ DATA_TYPE** ARGOUTVIEWM_ARRAY4)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEWM_ARRAY4)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &dim4_temp;
+ $5 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_ARRAY4)
+{
+ npy_intp dims[4] = { *$1, *$2, *$3 , *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+    PyObject* cap = PyCapsule_New((void*)(*$5), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+    PyObject* cap = PyCObject_FromVoidPtr((void*)(*$5), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2,
+ DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+ */
+%typemap(in,numinputs=0)
+ (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 )
+ (DATA_TYPE* data_temp = NULL , DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp)
+{
+ $1 = &data_temp;
+ $2 = &dim1_temp;
+ $3 = &dim2_temp;
+ $4 = &dim3_temp;
+ $5 = &dim4_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+ (DATA_TYPE** ARGOUTVIEWM_FARRAY4, DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4)
+{
+ npy_intp dims[4] = { *$2, *$3, *$4 , *$5 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$1));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$1), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$1), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+/* Typemap suite for (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4,
+ DATA_TYPE** ARGOUTVIEWM_FARRAY4)
+ */
+%typemap(in,numinputs=0)
+ (DIM_TYPE* DIM1 , DIM_TYPE* DIM2 , DIM_TYPE* DIM3 , DIM_TYPE* DIM4 , DATA_TYPE** ARGOUTVIEWM_FARRAY4)
+ (DIM_TYPE dim1_temp, DIM_TYPE dim2_temp, DIM_TYPE dim3_temp, DIM_TYPE dim4_temp, DATA_TYPE* data_temp = NULL )
+{
+ $1 = &dim1_temp;
+ $2 = &dim2_temp;
+ $3 = &dim3_temp;
+ $4 = &dim4_temp;
+ $5 = &data_temp;
+}
+%typemap(argout,
+ fragment="NumPy_Backward_Compatibility,NumPy_Array_Requirements,NumPy_Utilities")
+ (DIM_TYPE* DIM1, DIM_TYPE* DIM2, DIM_TYPE* DIM3, DIM_TYPE* DIM4, DATA_TYPE** ARGOUTVIEWM_FARRAY4)
+{
+ npy_intp dims[4] = { *$1, *$2, *$3 , *$4 };
+ PyObject* obj = PyArray_SimpleNewFromData(4, dims, DATA_TYPECODE, (void*)(*$5));
+ PyArrayObject* array = (PyArrayObject*) obj;
+
+ if (!array || !require_fortran(array)) SWIG_fail;
+
+%#ifdef SWIGPY_USE_CAPSULE
+ PyObject* cap = PyCapsule_New((void*)(*$5), SWIGPY_CAPSULE_NAME, free_cap);
+%#else
+ PyObject* cap = PyCObject_FromVoidPtr((void*)(*$5), free);
+%#endif
+
+%#if NPY_API_VERSION < 0x00000007
+ PyArray_BASE(array) = cap;
+%#else
+ PyArray_SetBaseObject(array,cap);
+%#endif
+
+ $result = SWIG_Python_AppendOutput($result,obj);
+}
+
+%enddef /* %numpy_typemaps() macro */
+/* *************************************************************** */
+
+/* Concrete instances of the %numpy_typemaps() macro: Each invocation
+ * below applies all of the typemaps above to the specified data type.
+ */
+%numpy_typemaps(signed char , NPY_BYTE , int)
+%numpy_typemaps(unsigned char , NPY_UBYTE , int)
+%numpy_typemaps(short , NPY_SHORT , int)
+%numpy_typemaps(unsigned short , NPY_USHORT , int)
+%numpy_typemaps(int , NPY_INT , int)
+%numpy_typemaps(unsigned int , NPY_UINT , int)
+%numpy_typemaps(long , NPY_LONG , int)
+%numpy_typemaps(unsigned long , NPY_ULONG , int)
+%numpy_typemaps(long long , NPY_LONGLONG , int)
+%numpy_typemaps(unsigned long long, NPY_ULONGLONG, int)
+%numpy_typemaps(float , NPY_FLOAT , int)
+%numpy_typemaps(double , NPY_DOUBLE , int)
+
+/* ***************************************************************
+ * The following macro expansion does not work, because C++ bool is 4
+ * bytes and NPY_BOOL is 1 byte
+ *
+ * %numpy_typemaps(bool, NPY_BOOL, int)
+ */
+
+/* ***************************************************************
+ * On my Mac, I get the following warning for this macro expansion:
+ * 'swig/python detected a memory leak of type 'long double *', no destructor found.'
+ *
+ * %numpy_typemaps(long double, NPY_LONGDOUBLE, int)
+ */
+
+/* ***************************************************************
+ * Swig complains about a syntax error for the following macro
+ * expansions:
+ *
+ * %numpy_typemaps(complex float, NPY_CFLOAT , int)
+ *
+ * %numpy_typemaps(complex double, NPY_CDOUBLE, int)
+ *
+ * %numpy_typemaps(complex long double, NPY_CLONGDOUBLE, int)
+ */
+
+#endif /* SWIGPYTHON */
diff --git a/tensorflow/python/platform/parameterized.py b/tensorflow/python/platform/parameterized.py
new file mode 100644
index 0000000000..cf01512bc1
--- /dev/null
+++ b/tensorflow/python/platform/parameterized.py
@@ -0,0 +1,10 @@
+"""Switch between depending on pyglib.gfile or an OSS replacement."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import tensorflow.python.platform
+import control_imports
+if control_imports.USE_OSS and control_imports.OSS_PARAMETERIZED:
+ from tensorflow.python.platform.default._parameterized import *
+else:
+ from tensorflow.python.platform.google._parameterized import *
diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py
new file mode 100644
index 0000000000..a0e6546c28
--- /dev/null
+++ b/tensorflow/python/platform/resource_loader.py
@@ -0,0 +1,10 @@
+"""Load a file resource and return the contents."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import control_imports
+import tensorflow.python.platform
+if control_imports.USE_OSS:
+ from tensorflow.python.platform.default._resource_loader import *
+else:
+ from tensorflow.python.platform.google._resource_loader import *
diff --git a/tensorflow/python/platform/status_bar.py b/tensorflow/python/platform/status_bar.py
new file mode 100644
index 0000000000..720b9d82c0
--- /dev/null
+++ b/tensorflow/python/platform/status_bar.py
@@ -0,0 +1,10 @@
+"""Switch between an internal status bar and a no-op version."""
+# pylint: disable=unused-import
+# pylint: disable=g-import-not-at-top
+# pylint: disable=wildcard-import
+import tensorflow.python.platform
+import control_imports
+if control_imports.USE_OSS:
+ from tensorflow.python.platform.default._status_bar import *
+else:
+ from tensorflow.python.platform.google._status_bar import *
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
new file mode 100644
index 0000000000..7d46f9cbc2
--- /dev/null
+++ b/tensorflow/python/platform/test.py
@@ -0,0 +1,6 @@
+from tensorflow.python.platform.googletest import GetTempDir
+from tensorflow.python.platform.googletest import main
+from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
+from tensorflow.python.framework.test_util import IsGoogleCudaEnabled as IsBuiltWithCuda
+
+get_temp_dir = GetTempDir
diff --git a/tensorflow/python/summary/README.md b/tensorflow/python/summary/README.md
new file mode 100644
index 0000000000..8a5fea0d9a
--- /dev/null
+++ b/tensorflow/python/summary/README.md
@@ -0,0 +1,15 @@
+# TensorFlow Event Processing
+
+This folder contains classes useful for analyzing and visualizing TensorFlow
+events files. The code is primarily being developed to support TensorBoard,
+but it can be used by anyone who wishes to analyze or visualize TensorFlow
+events files.
+
+If you wish to load TensorFlow events, you should use an EventAccumulator
+(to load from a single events file) or an EventMultiplexer (to load from
+multiple events files).
+
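+As a rough sketch (the file path and tag name below are hypothetical), reading
+scalar data back out of a single events file looks like:
+
+```python
+from tensorflow.python.summary import event_accumulator
+
+acc = event_accumulator.EventAccumulator('/tmp/run1/events.out.tfevents.1001')
+acc.Reload()                 # activate the accumulator and load pending events
+print(acc.Tags())            # map from tag type (e.g. 'scalars') to tag names
+print(acc.Scalars('loss'))   # list of ScalarEvent(wall_time, step, value)
+```
+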
+The API around these tools has not solidified, and we may make backwards-
+incompatible changes without warning.
+
+If you have questions or requests, please contact danmane@google.com
diff --git a/tensorflow/python/summary/__init__.py b/tensorflow/python/summary/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/summary/__init__.py
diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py
new file mode 100644
index 0000000000..ae067d94fe
--- /dev/null
+++ b/tensorflow/python/summary/event_accumulator.py
@@ -0,0 +1,433 @@
+"""Takes a generator of values, and accumulates them for a frontend."""
+
+import collections
+import threading
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+from tensorflow.python.summary.impl import directory_watcher
+from tensorflow.python.summary.impl import event_file_loader
+from tensorflow.python.summary.impl import reservoir
+
+namedtuple = collections.namedtuple
+ScalarEvent = namedtuple('ScalarEvent',
+ ['wall_time', 'step', 'value'])
+
+CompressedHistogramEvent = namedtuple('CompressedHistogramEvent',
+ ['wall_time', 'step',
+ 'compressed_histogram_values'])
+
+CompressedHistogramValue = namedtuple('CompressedHistogramValue',
+ ['basis_point', 'value'])
+
+HistogramEvent = namedtuple('HistogramEvent',
+ ['wall_time', 'step', 'histogram_value'])
+
+HistogramValue = namedtuple('HistogramValue',
+ ['min', 'max', 'num', 'sum', 'sum_squares',
+ 'bucket_limit', 'bucket'])
+
+ImageEvent = namedtuple('ImageEvent',
+ ['wall_time', 'step', 'encoded_image_string',
+ 'width', 'height'])
+
+## The tagTypes below are just arbitrary strings chosen to pass the type
+## information of the tag from the backend to the frontend
+COMPRESSED_HISTOGRAMS = 'compressedHistograms'
+HISTOGRAMS = 'histograms'
+IMAGES = 'images'
+SCALARS = 'scalars'
+GRAPH = 'graph'
+
+## normal CDF for std_devs: (-Inf, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, Inf)
+## naturally gives bands around median of width 1 std dev, 2 std dev, 3 std dev,
+## and then the long tail.
+NORMAL_HISTOGRAM_BPS = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000)
+
+DEFAULT_SIZE_GUIDANCE = {
+ COMPRESSED_HISTOGRAMS: 500,
+ IMAGES: 4,
+ SCALARS: 10000,
+ HISTOGRAMS: 1,
+}
+
+STORE_EVERYTHING_SIZE_GUIDANCE = {
+ COMPRESSED_HISTOGRAMS: 0,
+ IMAGES: 0,
+ SCALARS: 0,
+ HISTOGRAMS: 0,
+}
+
+
+def IsTensorFlowEventsFile(path):
+ """Check the path name to see if it is probably a TF Events file."""
+ return 'tfevents' in path
+
+
+class EventAccumulator(object):
+ """An `EventAccumulator` takes an event generator, and accumulates the values.
+
+ The `EventAccumulator` is intended to provide a convenient Python interface
+ for loading Event data written during a TensorFlow run. TensorFlow writes out
+ `Event` protobuf objects, which have a timestamp and step number, and often
+ contain a `Summary`. Summaries can have different kinds of data like an image,
+ a scalar value, or a histogram. The Summaries also have a tag, which we use to
+ organize logically related data. The `EventAccumulator` supports retrieving
+ the `Event` and `Summary` data by its tag.
+
+ Calling `Tags()` gets a map from `tagType` (e.g. `'images'`,
+ `'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those
+ data types. Then, various functional endpoints (e.g.
+ `Accumulator.Scalars(tag)`) allow for the retrieval of all data
+ associated with that tag.
+
+ Before usage, the `EventAccumulator` must be activated via `Reload()` or
+ `AutoUpdate(interval)`.
+
+ If activated via `Reload()`, it loads synchronously, so calls to `Values` or
+ `Tags` will block until all outstanding events are processed. Afterwards,
+ `Reload()` may be called again to load any new data.
+
+ If activated via `AutoUpdate(interval)`, it loads asynchronously, so calls to
+ `Values` or `Tags` will immediately return a valid subset of the outstanding
+ event data. It reloads new data every `interval` seconds.
+
+ Histograms and images are very large, so storing all of them is not
+ recommended.
+
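+ A minimal usage sketch (the events path below is hypothetical):
+
+ ```python
+ acc = EventAccumulator('/tmp/my_run')
+ acc.Reload()                 # load everything written so far, synchronously
+ print(acc.Tags()[SCALARS])   # tags that have scalar data
+ acc.AutoUpdate(interval=30)  # or: keep reloading in the background
+ ```
+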
+ @@Reload
+ @@AutoUpdate
+ @@Tags
+ @@Scalars
+ @@Graph
+ @@Histograms
+ @@CompressedHistograms
+ @@Images
+ """
+
+ def __init__(self, path, size_guidance=DEFAULT_SIZE_GUIDANCE,
+ compression_bps=NORMAL_HISTOGRAM_BPS):
+ """Construct the `EventAccumulator`.
+
+ Args:
+ path: A file path to a directory containing tf events files, or a single
+ tf events file. The accumulator will load events from this path.
+ size_guidance: Information on how much data the EventAccumulator should
+ store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much
+ so as to avoid OOMing the client. The size_guidance should be a map
+ from a `tagType` string to an integer representing the number of
+ items to keep per tag for items of that `tagType`. If the size is 0,
+ all events are stored.
+ compression_bps: Information on how the `EventAccumulator` should compress
+ histogram data for the `CompressedHistograms` tag (for details see
+ `ProcessCompressedHistogram`).
+ """
+ sizes = {}
+ for key in DEFAULT_SIZE_GUIDANCE:
+ if key in size_guidance:
+ sizes[key] = size_guidance[key]
+ else:
+ sizes[key] = DEFAULT_SIZE_GUIDANCE[key]
+
+ self._scalars = reservoir.Reservoir(size=sizes[SCALARS])
+ self._graph = None
+ self._histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS])
+ self._compressed_histograms = reservoir.Reservoir(
+ size=sizes[COMPRESSED_HISTOGRAMS])
+ self._images = reservoir.Reservoir(size=sizes[IMAGES])
+ self._generator_mutex = threading.Lock()
+ self._generator = _GeneratorFromPath(path)
+ self._is_autoupdating = False
+ self._activated = False
+ self._compression_bps = compression_bps
+
+ def Reload(self):
+ """Loads all events added since the last call to `Reload`.
+
+ If `Reload` was never called, loads all events in the file.
+ Calling `Reload` activates the `EventAccumulator`.
+
+ Returns:
+ The `EventAccumulator`.
+ """
+ self._activated = True
+ with self._generator_mutex:
+ for event in self._generator.Load():
+ if event.HasField('graph_def'):
+ if self._graph is not None:
+ logging.warn(('Found more than one graph event per run. '
+ 'Overwriting the graph with the newest event.'))
+ self._graph = event.graph_def
+ elif event.HasField('summary'):
+ for value in event.summary.value:
+ if value.HasField('simple_value'):
+ self._ProcessScalar(value.tag, event.wall_time, event.step,
+ value.simple_value)
+ elif value.HasField('histo'):
+ self._ProcessHistogram(value.tag, event.wall_time, event.step,
+ value.histo)
+ self._ProcessCompressedHistogram(value.tag, event.wall_time,
+ event.step, value.histo)
+ elif value.HasField('image'):
+ self._ProcessImage(value.tag, event.wall_time, event.step,
+ value.image)
+ return self
+
+ def AutoUpdate(self, interval=60):
+ """Asynchronously load all events, and periodically reload.
+
+ Calling this function is not thread safe.
+ Calling this function activates the `EventAccumulator`.
+
+ Args:
+ interval: how many seconds after each successful reload to load new events
+ (default 60)
+
+ Returns:
+ The `EventAccumulator`.
+ """
+ if self._is_autoupdating:
+ return
+ self._is_autoupdating = True
+ self._activated = True
+ def Update():
+ self.Reload()
+ logging.info('EventAccumulator update triggered')
+ t = threading.Timer(interval, Update)
+ t.daemon = True
+ t.start()
+ # Asynchronously start the update process, so that the accumulator can
+ # immediately serve data, even if there is a very large event file to parse
+ t = threading.Timer(0, Update)
+ t.daemon = True
+ t.start()
+ return self
+
+ def Tags(self):
+ """Return all tags found in the value stream.
+
+ Raises:
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ A `{tagType: ['list', 'of', 'tags']}` dictionary.
+ """
+ self._VerifyActivated()
+ return {IMAGES: self._images.Keys(),
+ HISTOGRAMS: self._histograms.Keys(),
+ SCALARS: self._scalars.Keys(),
+ COMPRESSED_HISTOGRAMS: self._compressed_histograms.Keys(),
+ GRAPH: self._graph is not None}
+
+ def Scalars(self, tag):
+ """Given a summary tag, return all associated `ScalarEvent`s.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `ScalarEvent`s.
+ """
+ self._VerifyActivated()
+ return self._scalars.Items(tag)
+
+ def Graph(self):
+ """Return the graph definition, if there is one.
+
+ Raises:
+ ValueError: If there is no graph for this run.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ The `graph_def` proto.
+ """
+ self._VerifyActivated()
+ if self._graph is None:
+ raise ValueError('There is no graph in this EventAccumulator')
+ return self._graph
+
+ def Histograms(self, tag):
+ """Given a summary tag, return all associated histograms.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `HistogramEvent`s.
+ """
+ self._VerifyActivated()
+ return self._histograms.Items(tag)
+
+ def CompressedHistograms(self, tag):
+ """Given a summary tag, return all associated compressed histograms.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `CompressedHistogramEvent`s.
+ """
+ self._VerifyActivated()
+ return self._compressed_histograms.Items(tag)
+
+ def Images(self, tag):
+ """Given a summary tag, return all associated images.
+
+ Args:
+ tag: A string tag associated with the events.
+
+ Raises:
+ KeyError: If the tag is not found.
+ RuntimeError: If the `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `ImageEvent`s.
+ """
+ self._VerifyActivated()
+ return self._images.Items(tag)
+
+ def _VerifyActivated(self):
+ if not self._activated:
+ raise RuntimeError('Accumulator must be activated before it may be used.')
+
+ def _ProcessScalar(self, tag, wall_time, step, scalar):
+ """Processes a simple value by adding it to accumulated state."""
+ sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar)
+ self._scalars.AddItem(tag, sv)
+
+ def _ProcessHistogram(self, tag, wall_time, step, histo):
+ """Processes a histogram by adding it to accumulated state."""
+ histogram_value = HistogramValue(
+ min=histo.min,
+ max=histo.max,
+ num=histo.num,
+ sum=histo.sum,
+ sum_squares=histo.sum_squares,
+ # convert from proto repeated to list
+ bucket_limit=list(histo.bucket_limit),
+ bucket=list(histo.bucket),
+ )
+ histogram_event = HistogramEvent(
+ wall_time=wall_time,
+ step=step,
+ histogram_value=histogram_value,
+ )
+ self._histograms.AddItem(tag, histogram_event)
+
+ def _Remap(self, x, x0, x1, y0, y1):
+ """Linearly map from [x0, x1] onto [y0, y1]."""
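+ # For example, _Remap(2500, 0, 10000, 1.0, 2.0) == 1.25.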
+ return y0 + (x - x0) * float(y1 - y0)/(x1 - x0)
+
+ def _Percentile(self, compression_bps, bucket_limit, cumsum_weights,
+ histo_min, histo_max, histo_num):
+ """Linearly interpolates a histogram weight for a particular basis point.
+
+ Uses clamping methods on `histo_min` and `histo_max` to produce tight
+ linear estimates of the histogram weight at a particular basis point.
+
+ Args:
+ compression_bps: The desired basis point at which to estimate the weight
+ bucket_limit: An array of the RHS histogram bucket limits
+ cumsum_weights: A cumulative sum of the fraction of weights in each
+ histogram bucket, represented in basis points.
+ histo_min: The minimum weight observed in the weight histogram
+ histo_max: The maximum weight observed in the weight histogram
+ histo_num: The number of items in the weight histogram
+
+ Returns:
+ A linearly interpolated value of the histogram weight estimate.
+ """
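+ # Worked example (mirrors the unit tests): with bucket_limit=[1, 2, 3, 4], all
+ # of the weight in the first bucket (cumsum_weights=[10000, 10000, 10000,
+ # 10000]) and histo_max below the first bucket limit, the estimate at 2500 bps
+ # is _Remap(2500, 0, 10000, histo_min, histo_max), i.e. one quarter of the way
+ # from histo_min to histo_max.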
+ if histo_num == 0: return 0
+
+ for i, cumsum in enumerate(cumsum_weights):
+ if cumsum >= compression_bps:
+ cumsum_prev = cumsum_weights[i-1] if i > 0 else 0
+ # Prevent cumsum = 0, cumsum_prev = 0, lerp divide by zero.
+ if cumsum == cumsum_prev: continue
+
+ # Calculate the lower bound of interpolation
+ lhs = bucket_limit[i-1] if (i > 0 and cumsum_prev > 0) else histo_min
+ lhs = max(lhs, histo_min)
+
+ # Calculate the upper bound of interpolation
+ rhs = bucket_limit[i]
+ rhs = min(rhs, histo_max)
+
+ weight = self._Remap(compression_bps, cumsum_prev, cumsum, lhs, rhs)
+ return weight
+
+ ## We have not exceeded cumsum, so return the max observed.
+ return histo_max
+
+ def _ProcessCompressedHistogram(self, tag, wall_time, step, histo):
+ """Processes a histogram by adding a compression to accumulated state.
+
+ Adds a compressed histogram by linearly interpolating histogram buckets to
+ represent the histogram weight at multiple compression points. Uses
+ self._compression_bps (passed to EventAccumulator constructor) as the
+ compression points (represented in basis points, 1/100ths of a percent).
+
+ Args:
+ tag: A string name of the tag for which histograms are retrieved.
+ wall_time: Time in seconds since epoch
+ step: Number of steps that have passed
+ histo: The proto2 histogram object
+ """
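+ # For example (values taken from the unit tests): a histogram with min=1,
+ # max=2, bucket=[0, 3, 0] and bucket_limit=[1, 2, 3], compressed at basis
+ # points (0, 2500, 5000, 7500, 10000), yields (1.0, 1.25, 1.5, 1.75, 2.0);
+ # all of the weight sits in the second bucket, so each basis point is
+ # linearly interpolated between the histogram min and max.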
+ def _CumulativeSum(arr):
+ return [sum(arr[:i+1]) for i in range(len(arr))]
+
+ # Convert from proto repeated field into a Python list.
+ bucket = list(histo.bucket)
+ bucket_limit = list(histo.bucket_limit)
+
+ bucket_total = sum(bucket)
+ fraction_weights = [float(10000*x)/bucket_total for x in bucket]
+ cumsum_weights = _CumulativeSum(fraction_weights)
+
+ percentiles = [
+ self._Percentile(bps, bucket_limit, cumsum_weights, histo.min,
+ histo.max, histo.num) for bps in self._compression_bps
+ ]
+
+ compressed_histogram_values = [CompressedHistogramValue(
+ basis_point=bps,
+ value=value) for bps, value in zip(self._compression_bps, percentiles)]
+ histogram_event = CompressedHistogramEvent(
+ wall_time=wall_time,
+ step=step,
+ compressed_histogram_values=compressed_histogram_values)
+
+ self._compressed_histograms.AddItem(tag, histogram_event)
+
+ def _ProcessImage(self, tag, wall_time, step, image):
+ """Processes an image by adding it to accumulated state."""
+ event = ImageEvent(
+ wall_time=wall_time,
+ step=step,
+ encoded_image_string=image.encoded_image_string,
+ width=image.width,
+ height=image.height
+ )
+ self._images.AddItem(tag, event)
+
+
+def _GeneratorFromPath(path):
+ """Create an event generator for file or directory at given path string."""
+ loader_factory = event_file_loader.EventFileLoader
+ if gfile.IsDirectory(path):
+ return directory_watcher.DirectoryWatcher(path, loader_factory,
+ IsTensorFlowEventsFile)
+ else:
+ return loader_factory(path)
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
new file mode 100644
index 0000000000..c8de80ccba
--- /dev/null
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -0,0 +1,422 @@
+import os
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.summary import event_accumulator as ea
+
+
+class _EventGenerator(object):
+
+ def __init__(self):
+ self.items = []
+
+ def Load(self):
+ while self.items:
+ yield self.items.pop(0)
+
+ def AddScalar(self, tag, wall_time=0, step=0, value=0):
+ event = tf.Event(
+ wall_time=wall_time, step=step,
+ summary=tf.Summary(
+ value=[tf.Summary.Value(tag=tag, simple_value=value)]
+ )
+ )
+ self.AddEvent(event)
+
+ def AddHistogram(self, tag, wall_time=0, step=0, hmin=1, hmax=2, hnum=3,
+ hsum=4, hsum_squares=5, hbucket_limit=None, hbucket=None):
+ histo = tf.HistogramProto(min=hmin, max=hmax, num=hnum, sum=hsum,
+ sum_squares=hsum_squares,
+ bucket_limit=hbucket_limit,
+ bucket=hbucket)
+ event = tf.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]))
+ self.AddEvent(event)
+
+ def AddImage(self, tag, wall_time=0, step=0, encoded_image_string='imgstr',
+ width=150, height=100):
+ image = tf.Summary.Image(encoded_image_string=encoded_image_string,
+ width=width, height=height)
+ event = tf.Event(
+ wall_time=wall_time,
+ step=step,
+ summary=tf.Summary(
+ value=[tf.Summary.Value(tag=tag, image=image)]))
+ self.AddEvent(event)
+
+ def AddEvent(self, event):
+ self.items.append(event)
+
+
+class EventAccumulatorTest(tf.test.TestCase):
+
+ def assertTagsEqual(self, tags1, tags2):
+ # Make sure the two dictionaries have the same keys.
+ self.assertItemsEqual(tags1, tags2)
+ # Additionally, make sure each key in the dictionary maps to the same value.
+ for key in tags1:
+ if isinstance(tags1[key], list):
+ # We don't care about the order of the values in lists, so we only
+ # assert that the two lists contain the same items.
+ self.assertItemsEqual(tags1[key], tags2[key])
+ else:
+ # Make sure the values are equal.
+ self.assertEqual(tags1[key], tags2[key])
+
+
+class MockingEventAccumulatorTest(EventAccumulatorTest):
+
+ def setUp(self):
+ super(MockingEventAccumulatorTest, self).setUp()
+ self.empty = {ea.IMAGES: [],
+ ea.SCALARS: [],
+ ea.HISTOGRAMS: [],
+ ea.COMPRESSED_HISTOGRAMS: [],
+ ea.GRAPH: False}
+ self._real_constructor = ea.EventAccumulator
+ self._real_generator = ea._GeneratorFromPath
+ def _FakeAccumulatorConstructor(generator, *args, **kwargs):
+ ea._GeneratorFromPath = lambda x: generator
+ return self._real_constructor(generator, *args, **kwargs)
+ ea.EventAccumulator = _FakeAccumulatorConstructor
+
+ def tearDown(self):
+ ea.EventAccumulator = self._real_constructor
+ ea._GeneratorFromPath = self._real_generator
+
+ def testEmptyAccumulator(self):
+ gen = _EventGenerator()
+ x = ea.EventAccumulator(gen)
+ x.Reload()
+ self.assertEqual(x.Tags(), self.empty)
+
+ def testTags(self):
+ gen = _EventGenerator()
+ gen.AddScalar('sv1')
+ gen.AddScalar('sv2')
+ gen.AddHistogram('hst1')
+ gen.AddHistogram('hst2')
+ gen.AddImage('im1')
+ gen.AddImage('im2')
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ self.assertTagsEqual(
+ acc.Tags(), {
+ ea.IMAGES: ['im1', 'im2'],
+ ea.SCALARS: ['sv1', 'sv2'],
+ ea.HISTOGRAMS: ['hst1', 'hst2'],
+ ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
+ ea.GRAPH: False})
+
+ def testReload(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ self.assertEqual(acc.Tags(), self.empty)
+ gen.AddScalar('sv1')
+ gen.AddScalar('sv2')
+ gen.AddHistogram('hst1')
+ gen.AddHistogram('hst2')
+ gen.AddImage('im1')
+ gen.AddImage('im2')
+ self.assertEqual(acc.Tags(), self.empty)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {
+ ea.IMAGES: ['im1', 'im2'],
+ ea.SCALARS: ['sv1', 'sv2'],
+ ea.HISTOGRAMS: ['hst1', 'hst2'],
+ ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
+ ea.GRAPH: False})
+
+ def testScalars(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ sv1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
+ sv2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
+ gen.AddScalar('sv1', wall_time=1, step=10, value=32)
+ gen.AddScalar('sv2', wall_time=2, step=12, value=64)
+ acc.Reload()
+ self.assertEqual(acc.Scalars('sv1'), [sv1])
+ self.assertEqual(acc.Scalars('sv2'), [sv2])
+
+ def testHistograms(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+
+ val1 = ea.HistogramValue(min=1, max=2, num=3, sum=4, sum_squares=5,
+ bucket_limit=[1, 2, 3], bucket=[0, 3, 0])
+ val2 = ea.HistogramValue(min=-2, max=3, num=4, sum=5, sum_squares=6,
+ bucket_limit=[2, 3, 4], bucket=[1, 3, 0])
+
+ hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1)
+ hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2)
+ gen.AddHistogram('hst1', wall_time=1, step=10, hmin=1, hmax=2, hnum=3,
+ hsum=4, hsum_squares=5, hbucket_limit=[1, 2, 3],
+ hbucket=[0, 3, 0])
+ gen.AddHistogram('hst2', wall_time=2, step=12, hmin=-2, hmax=3, hnum=4,
+ hsum=5, hsum_squares=6, hbucket_limit=[2, 3, 4],
+ hbucket=[1, 3, 0])
+ acc.Reload()
+ self.assertEqual(acc.Histograms('hst1'), [hst1])
+ self.assertEqual(acc.Histograms('hst2'), [hst2])
+
+ def testCompressedHistograms(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000))
+
+ gen.AddHistogram('hst1', wall_time=1, step=10, hmin=1, hmax=2, hnum=3,
+ hsum=4, hsum_squares=5, hbucket_limit=[1, 2, 3],
+ hbucket=[0, 3, 0])
+ gen.AddHistogram('hst2', wall_time=2, step=12, hmin=-2, hmax=3, hnum=4,
+ hsum=5, hsum_squares=6, hbucket_limit=[2, 3, 4],
+ hbucket=[1, 3, 0])
+ acc.Reload()
+
+ # Create the expected values after compressing hst1
+ expected_vals1 = [ea.CompressedHistogramValue(bp, val) for bp, val in [(
+ 0, 1.0), (2500, 1.25), (5000, 1.5), (7500, 1.75), (10000, 2.0)]]
+ expected_cmphst1 = ea.CompressedHistogramEvent(
+ wall_time=1,
+ step=10,
+ compressed_histogram_values=expected_vals1)
+ self.assertEqual(acc.CompressedHistograms('hst1'), [expected_cmphst1])
+
+ # Create the expected values after compressing hst2
+ expected_vals2 = [
+ ea.CompressedHistogramValue(bp, val)
+ for bp, val in [(0, -2), (2500, 2), (5000, 2 + float(1) / 3), (
+ 7500, 2 + float(2) / 3), (10000, 3)]
+ ]
+ expected_cmphst2 = ea.CompressedHistogramEvent(
+ wall_time=2,
+ step=12,
+ compressed_histogram_values=expected_vals2)
+ self.assertEqual(acc.CompressedHistograms('hst2'), [expected_cmphst2])
+
+ def testPercentile(self):
+
+ def AssertExpectedForBps(bps, expected):
+ output = acc._Percentile(
+ bps, bucket_limit, cumsum_weights, histo_min, histo_max, histo_num)
+ self.assertAlmostEqual(expected, output)
+
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+
+ bucket_limit = [1, 2, 3, 4]
+ histo_num = 100
+
+ ## All weights in the first bucket
+ cumsum_weights = [10000, 10000, 10000, 10000]
+ histo_min = -1
+ histo_max = .9
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## All weights in second bucket
+ cumsum_weights = [0, 10000, 10000, 10000]
+ histo_min = 1.1
+ histo_max = 1.8
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## All weights in the last bucket
+ cumsum_weights = [0, 0, 0, 10000]
+ histo_min = 3.1
+ histo_max = 3.6
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 10000, histo_min, histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## Weights distributed between two buckets
+ cumsum_weights = [0, 4000, 10000, 10000]
+ histo_min = 1.1
+ histo_max = 2.9
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 4000, histo_min,
+ bucket_limit[1]))
+ AssertExpectedForBps(5000, acc._Remap(5000, 4000, 10000, bucket_limit[1],
+ histo_max))
+ AssertExpectedForBps(7500, acc._Remap(7500, 4000, 10000, bucket_limit[1],
+ histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## Weights distributed between all buckets
+ cumsum_weights = [1000, 4000, 8000, 10000]
+ histo_min = -1
+ histo_max = 3.9
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 1000, 4000, bucket_limit[0],
+ bucket_limit[1]))
+ AssertExpectedForBps(5000, acc._Remap(5000, 4000, 8000, bucket_limit[1],
+ bucket_limit[2]))
+ AssertExpectedForBps(7500, acc._Remap(7500, 4000, 8000, bucket_limit[1],
+ bucket_limit[2]))
+ AssertExpectedForBps(9000, acc._Remap(9000, 8000, 10000, bucket_limit[2],
+ histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ ## Most weight in first bucket
+ cumsum_weights = [9000, 10000, 10000, 10000]
+ histo_min = -1
+ histo_max = 1.1
+ AssertExpectedForBps(0, histo_min)
+ AssertExpectedForBps(2500, acc._Remap(2500, 0, 9000, histo_min,
+ bucket_limit[0]))
+ AssertExpectedForBps(5000, acc._Remap(5000, 0, 9000, histo_min,
+ bucket_limit[0]))
+ AssertExpectedForBps(7500, acc._Remap(7500, 0, 9000, histo_min,
+ bucket_limit[0]))
+ AssertExpectedForBps(9500, acc._Remap(9500, 9000, 10000, bucket_limit[0],
+ histo_max))
+ AssertExpectedForBps(10000, histo_max)
+
+ def testImages(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ im1 = ea.ImageEvent(wall_time=1, step=10, encoded_image_string='big',
+ width=400, height=300)
+ im2 = ea.ImageEvent(wall_time=2, step=12, encoded_image_string='small',
+ width=40, height=30)
+ gen.AddImage('im1', wall_time=1, step=10, encoded_image_string='big',
+ width=400, height=300)
+ gen.AddImage('im2', wall_time=2, step=12, encoded_image_string='small',
+ width=40, height=30)
+ acc.Reload()
+ self.assertEqual(acc.Images('im1'), [im1])
+ self.assertEqual(acc.Images('im2'), [im2])
+
+ def testActivation(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ self.assertFalse(acc._activated)
+ with self.assertRaises(RuntimeError):
+ acc.Tags()
+ with self.assertRaises(RuntimeError):
+ acc.Scalars('sv1')
+ acc.Reload()
+ self.assertTrue(acc._activated)
+ acc._activated = False
+
+ def testKeyError(self):
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ acc.Reload()
+ with self.assertRaises(KeyError):
+ acc.Scalars('sv1')
+ with self.assertRaises(KeyError):
+ acc.Scalars('hst1')
+ with self.assertRaises(KeyError):
+ acc.Scalars('im1')
+ with self.assertRaises(KeyError):
+ acc.Histograms('sv1')
+ with self.assertRaises(KeyError):
+ acc.Histograms('im1')
+ with self.assertRaises(KeyError):
+ acc.Images('sv1')
+ with self.assertRaises(KeyError):
+ acc.Images('hst1')
+
+ def testNonValueEvents(self):
+ """Tests that non-value events in the generator don't cause early exits."""
+ gen = _EventGenerator()
+ acc = ea.EventAccumulator(gen)
+ gen.AddScalar('sv1', wall_time=1, step=10, value=20)
+ gen.AddEvent(tf.Event(
+ wall_time=2, step=20, file_version='notsv2'))
+ gen.AddScalar('sv3', wall_time=3, step=100, value=1)
+ gen.AddHistogram('hst1')
+ gen.AddImage('im1')
+
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {
+ ea.IMAGES: ['im1'],
+ ea.SCALARS: ['sv1', 'sv3'],
+ ea.HISTOGRAMS: ['hst1'],
+ ea.COMPRESSED_HISTOGRAMS: ['hst1'],
+ ea.GRAPH: False})
+
+
+class RealisticEventAccumulatorTest(EventAccumulatorTest):
+
+ def setUp(self):
+ super(RealisticEventAccumulatorTest, self).setUp()
+
+ def testScalarsRealistically(self):
+ """Test accumulator by writing values and then reading them."""
+ def FakeScalarSummary(tag, value):
+ value = tf.Summary.Value(tag=tag, simple_value=value)
+ summary = tf.Summary(value=[value])
+ return summary
+
+ directory = os.path.join(self.get_temp_dir(), 'values_dir')
+ if gfile.IsDirectory(directory):
+ gfile.DeleteRecursively(directory)
+ gfile.MkDir(directory)
+
+ writer = tf.train.SummaryWriter(directory, max_queue=100)
+ graph_def = tf.GraphDef(node=[tf.NodeDef(name='A', op='Mul')])
+ # Add a graph to the summary writer.
+ writer.add_graph(graph_def)
+
+ # Write a bunch of events using the writer
+ for i in xrange(30):
+ summ_id = FakeScalarSummary('id', i)
+ summ_sq = FakeScalarSummary('sq', i*i)
+ writer.add_summary(summ_id, i*5)
+ writer.add_summary(summ_sq, i*5)
+ writer.flush()
+
+ # Verify that we can load those events properly
+ acc = ea.EventAccumulator(directory)
+ acc.Reload()
+ self.assertTagsEqual(acc.Tags(), {
+ ea.IMAGES: [],
+ ea.SCALARS: ['id', 'sq'],
+ ea.HISTOGRAMS: [],
+ ea.COMPRESSED_HISTOGRAMS: [],
+ ea.GRAPH: True})
+ id_events = acc.Scalars('id')
+ sq_events = acc.Scalars('sq')
+ self.assertEqual(30, len(id_events))
+ self.assertEqual(30, len(sq_events))
+ for i in xrange(30):
+ self.assertEqual(i*5, id_events[i].step)
+ self.assertEqual(i*5, sq_events[i].step)
+ self.assertEqual(i, id_events[i].value)
+ self.assertEqual(i*i, sq_events[i].value)
+
+ # Write a few more events to test incremental reloading
+ for i in xrange(30, 40):
+ summ_id = FakeScalarSummary('id', i)
+ summ_sq = FakeScalarSummary('sq', i*i)
+ writer.add_summary(summ_id, i*5)
+ writer.add_summary(summ_sq, i*5)
+ writer.flush()
+
+ # Verify we can now see all of the data
+ acc.Reload()
+ self.assertEqual(40, len(id_events))
+ self.assertEqual(40, len(sq_events))
+ for i in xrange(40):
+ self.assertEqual(i*5, id_events[i].step)
+ self.assertEqual(i*5, sq_events[i].step)
+ self.assertEqual(i, id_events[i].value)
+ self.assertEqual(i*i, sq_events[i].value)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/summary/event_multiplexer.py b/tensorflow/python/summary/event_multiplexer.py
new file mode 100644
index 0000000000..9966d76b21
--- /dev/null
+++ b/tensorflow/python/summary/event_multiplexer.py
@@ -0,0 +1,346 @@
+"""Provides an interface for working with multiple event files."""
+
+import os
+import threading
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+from tensorflow.python.summary import event_accumulator
+
+
+class EventMultiplexer(object):
+ """An `EventMultiplexer` manages access to multiple `EventAccumulator`s.
+
+ Each `EventAccumulator` is associated with a `run`, which is a self-contained
+ TensorFlow execution. The `EventMultiplexer` provides methods for extracting
+ information about events from multiple `run`s.
+
+ Example usage for loading specific runs from files:
+
+ ```python
+ x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'})
+ x.Reload()
+ ```
+
+ Example usage for loading a directory where each subdirectory is a run
+
+ ```python
+ (eg:) /parent/directory/path/
+ /parent/directory/path/run1/
+ /parent/directory/path/run1/events.out.tfevents.1001
+ /parent/directory/path/run1/events.out.tfevents.1002
+
+ /parent/directory/path/run2/
+ /parent/directory/path/run2/events.out.tfevents.9232
+
+ /parent/directory/path/run3/
+ /parent/directory/path/run3/events.out.tfevents.9232
+ x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path')
+ (which is equivalent to:)
+ x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2': ...})
+ ```
+
+ If you would like to watch `/parent/directory/path`, waiting for it to be
+ created (if necessary) and then periodically picking up new runs, use
+ `AutoloadingMultiplexer`.
+
+ @@__init__
+ @@AddRun
+ @@AddRunsFromDirectory
+ @@Reload
+ @@AutoUpdate
+ @@Runs
+ @@Scalars
+ @@Graph
+ @@Histograms
+ @@CompressedHistograms
+ @@Images
+ """
+
+ def __init__(self, run_path_map=None,
+ size_guidance=event_accumulator.DEFAULT_SIZE_GUIDANCE):
+ """Constructor for the `EventMultiplexer`.
+
+ Args:
+ run_path_map: Dict `{run: path}` which specifies the
+ name of a run, and the path to find the associated events. If it is
+ None, then the EventMultiplexer initializes without any runs.
+ size_guidance: A dictionary mapping from `tagType` to the number of items
+ to store for each tag of that type. See
+ `event_accumulator.EventAccumulator` for details.
+ """
+ self._accumulators_mutex = threading.Lock()
+ self._accumulators = {}
+ self._paths = {}
+ self._reload_called = False
+ self._autoupdate_called = False
+ self._autoupdate_interval = None
+ self._size_guidance = size_guidance
+ if run_path_map is not None:
+ for (run, path) in run_path_map.iteritems():
+ self.AddRun(path, run)
+
+ def AddRun(self, path, name=None):
+ """Add a run to the multiplexer.
+
+ If the name is not specified, it is the same as the path.
+
+ If a run by that name exists, and we are already watching the right path,
+ do nothing. If we are watching a different path, replace the event
+ accumulator.
+
+ If `AutoUpdate` or `Reload` have been called, it will `AutoUpdate` or
+ `Reload` the newly created accumulators. This maintains the invariant that
+ once the Multiplexer was activated, all of its accumulators are active.
+
+ Args:
+ path: Path to the event files (or event directory) for given run.
+ name: Name of the run to add. If not provided, is set to path.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+ if name is None or name == '':
+ name = path
+ accumulator = None
+ with self._accumulators_mutex:
+ if name not in self._accumulators or self._paths[name] != path:
+ if name in self._paths and self._paths[name] != path:
+ # TODO(danmane) - Make it impossible to overwrite an old path with
+ # a new path (just give the new path a distinct name)
+ logging.warning('Conflict for name %s: old path %s, new path %s' %
+ (name, self._paths[name], path))
+ logging.info('Constructing EventAccumulator for %s', path)
+ accumulator = event_accumulator.EventAccumulator(path,
+ self._size_guidance)
+ self._accumulators[name] = accumulator
+ self._paths[name] = path
+ if accumulator:
+ if self._reload_called:
+ accumulator.Reload()
+ if self._autoupdate_called:
+ accumulator.AutoUpdate(self._autoupdate_interval)
+ return self
+
+ def AddRunsFromDirectory(self, path, name=None):
+ """Load runs from a directory, assuming each subdirectory is a run.
+
+ If path doesn't exist, no-op. This ensures that it is safe to call
+ `AddRunsFromDirectory` multiple times, even before the directory is made.
+
+ If the directory contains TensorFlow event files, it is itself treated as a
+ run.
+
+ If the `EventMultiplexer` is already loaded or autoupdating, this will cause
+ the newly created accumulators to also `Reload()` or `AutoUpdate()`.
+
+ Args:
+ path: A string path to a directory to load runs from.
+ name: Optionally, what name to apply to the runs. If name is provided
+ and the directory contains run subdirectories, the name of each subrun
+ is the concatenation of the parent name and the subdirectory name. If
+ name is provided and the directory contains event files, then a run
+ is added called "name" and with the events from the path.
+
+ Raises:
+ ValueError: If the path exists and isn't a directory.
+
+ Returns:
+ The `EventMultiplexer`.
+ """
+ if not gfile.Exists(path):
+ return # Maybe it hasn't been created yet, fail silently to retry later
+ if not gfile.IsDirectory(path):
+ raise ValueError('Path exists and is not a directory, %s' % path)
+ paths = gfile.ListDirectory(path)
+ is_directory = lambda x: gfile.IsDirectory(os.path.join(path, x))
+ subdirectories = filter(is_directory, paths)
+ for s in subdirectories:
+ if name:
+ subname = '/'.join([name, s])
+ else:
+ subname = s
+ self.AddRun(os.path.join(path, s), subname)
+
+ if filter(event_accumulator.IsTensorFlowEventsFile, paths):
+ directory_name = os.path.split(path)[1]
+ logging.info('Directory %s has event files; loading' % directory_name)
+ if name:
+ dname = name
+ else:
+ dname = directory_name
+ self.AddRun(path, dname)
+ return self
+
+ def Reload(self):
+ """Call `Reload` on every `EventAccumulator`."""
+ self._reload_called = True
+ with self._accumulators_mutex:
+ loaders = self._accumulators.values()
+
+ for l in loaders:
+ l.Reload()
+ return self
+
+ def AutoUpdate(self, interval=60):
+ """Call `AutoUpdate(interval)` on every `EventAccumulator`."""
+ self._autoupdate_interval = interval
+ self._autoupdate_called = True
+ with self._accumulators_mutex:
+ loaders = self._accumulators.values()
+ for l in loaders:
+ l.AutoUpdate(interval)
+ return self
+
+ def Scalars(self, run, tag):
+ """Retrieve the scalar events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `event_accumulator.ScalarEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Scalars(tag)
+
+ def Graph(self, run):
+ """Retrieve the graphs associated with the provided run.
+
+ Args:
+ run: A string name of a run to load the graph for.
+
+ Raises:
+ KeyError: If the run is not found.
+ ValueError: If the run does not have an associated graph.
+ RuntimeError: If the run's EventAccumulator has not been activated.
+
+ Returns:
+ The `graph_def` protobuf data structure.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Graph()
+
+ def Histograms(self, run, tag):
+ """Retrieve the histogram events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `event_accumulator.HistogramEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Histograms(tag)
+
+ def CompressedHistograms(self, run, tag):
+ """Retrieve the compressed histogram events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's EventAccumulator has not been activated.
+
+ Returns:
+ An array of `event_accumulator.CompressedHistogramEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.CompressedHistograms(tag)
+
+ def Images(self, run, tag):
+ """Retrieve the image events associated with a run and tag.
+
+ Args:
+ run: A string name of the run for which values are retrieved.
+ tag: A string name of the tag for which values are retrieved.
+
+ Raises:
+ KeyError: If the run is not found, or the tag is not available for
+ the given run.
+ RuntimeError: If the run's `EventAccumulator` has not been activated.
+
+ Returns:
+ An array of `event_accumulator.ImageEvents`.
+ """
+ accumulator = self._GetAccumulator(run)
+ return accumulator.Images(tag)
+
+ def Runs(self):
+ """Return all the run names in the `EventMultiplexer`.
+
+ Returns:
+ ```
+ {runName: { images: [tag1, tag2, tag3],
+ scalarValues: [tagA, tagB, tagC],
+ histograms: [tagX, tagY, tagZ],
+ compressedHistograms: [tagX, tagY, tagZ],
+ graph: true}}
+ ```
+ """
+ with self._accumulators_mutex:
+ # To avoid nested locks, we construct a copy of the run-accumulator map
+ items = list(self._accumulators.iteritems())
+ return {
+ run_name: accumulator.Tags()
+ for run_name, accumulator in items
+ }
+
+ def _GetAccumulator(self, run):
+ with self._accumulators_mutex:
+ return self._accumulators[run]
+
+
+def AutoloadingMultiplexer(path_to_run, interval_secs=60,
+ size_guidance=event_accumulator.DEFAULT_SIZE_GUIDANCE):
+ """Create an `EventMultiplexer` that automatically loads runs in directories.
+
+ Args:
+ path_to_run: Dict `{path: name}` which specifies the path to a directory,
+ and its name (or `None`). The path may contain tfevents files (in which
+ case they are loaded, with name as the name of the run) and subdirectories
+ containing tfevents files (in which case each subdirectory is added as a
+ run, named `'name/subdirectory'`).
+
+ interval_secs: How often to poll the directory for new runs.
+ size_guidance: How much data to store for each tag of various types - see
+ `event_accumulator.EventAccumulator`.
+
+ Returns:
+ The multiplexer which will automatically load from the directories.
+
+ Raises:
+ ValueError: if `path_to_run` is `None`
+ TypeError: if `path_to_run` is not a dict
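+
+ A rough usage sketch (the log directory below is hypothetical):
+
+ ```python
+ multiplexer = AutoloadingMultiplexer({'/tmp/logdir': None}, interval_secs=60)
+ # runs appearing under /tmp/logdir are added automatically every 60 seconds
+ ```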
+ """
+ multiplexer = EventMultiplexer(size_guidance=size_guidance)
+ if path_to_run is None:
+ raise ValueError('Cannot construct an autoloading multiplexer without runs.')
+ if not isinstance(path_to_run, dict):
+ raise TypeError('path_to_run should be a dict, was %s' % path_to_run)
+ def Load():
+ for (path, name) in path_to_run.iteritems():
+ logging.info('Checking for new runs in %s', path)
+ multiplexer.AddRunsFromDirectory(path, name)
+ t = threading.Timer(interval_secs, Load)
+ t.daemon = True
+ t.start()
+ t = threading.Timer(0, Load)
+ t.daemon = True
+ t.start()
+ return multiplexer
diff --git a/tensorflow/python/summary/event_multiplexer_test.py b/tensorflow/python/summary/event_multiplexer_test.py
new file mode 100644
index 0000000000..35a8aed266
--- /dev/null
+++ b/tensorflow/python/summary/event_multiplexer_test.py
@@ -0,0 +1,244 @@
+import os
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary import event_accumulator
+from tensorflow.python.summary import event_multiplexer
+
+
+class _FakeAccumulator(object):
+
+ def __init__(self, path):
+ self._path = path
+ self.autoupdate_called = False
+ self.autoupdate_interval = None
+ self.reload_called = False
+
+ def Tags(self):
+ return {event_accumulator.IMAGES: ['im1', 'im2'],
+ event_accumulator.HISTOGRAMS: ['hst1', 'hst2'],
+ event_accumulator.COMPRESSED_HISTOGRAMS: ['cmphst1', 'cmphst2'],
+ event_accumulator.SCALARS: ['sv1', 'sv2']}
+
+ def Scalars(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.SCALARS]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def Histograms(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.HISTOGRAMS]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def CompressedHistograms(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.COMPRESSED_HISTOGRAMS]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def Images(self, tag_name):
+ if tag_name not in self.Tags()[event_accumulator.IMAGES]:
+ raise KeyError
+ return ['%s/%s' % (self._path, tag_name)]
+
+ def AutoUpdate(self, interval):
+ self.autoupdate_called = True
+ self.autoupdate_interval = interval
+
+ def Reload(self):
+ self.reload_called = True
+
+
+def _GetFakeAccumulator(path, size_guidance): # pylint: disable=unused-argument
+ return _FakeAccumulator(path)
+
+
+class EventMultiplexerTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ super(EventMultiplexerTest, self).setUp()
+ event_accumulator.EventAccumulator = _GetFakeAccumulator
+
+ def testEmptyLoader(self):
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+
+ def testRunNamesRespected(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertItemsEqual(x.Runs().keys(), ['run1', 'run2'])
+ self.assertEqual(x._GetAccumulator('run1')._path, 'path1')
+ self.assertEqual(x._GetAccumulator('run2')._path, 'path2')
+
+ def testReload(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertFalse(x._GetAccumulator('run1').reload_called)
+ self.assertFalse(x._GetAccumulator('run2').reload_called)
+ x.Reload()
+ self.assertTrue(x._GetAccumulator('run1').reload_called)
+ self.assertTrue(x._GetAccumulator('run2').reload_called)
+
+ def testAutoUpdate(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ x.AutoUpdate(5)
+ self.assertTrue(x._GetAccumulator('run1').autoupdate_called)
+ self.assertEqual(x._GetAccumulator('run1').autoupdate_interval, 5)
+ self.assertTrue(x._GetAccumulator('run2').autoupdate_called)
+ self.assertEqual(x._GetAccumulator('run2').autoupdate_interval, 5)
+
+ def testScalars(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+
+ run1_actual = x.Scalars('run1', 'sv1')
+ run1_expected = ['path1/sv1']
+
+ self.assertEqual(run1_expected, run1_actual)
+
+ def testExceptions(self):
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ with self.assertRaises(KeyError):
+ x.Scalars('sv1', 'xxx')
+
+ def testInitialization(self):
+ x = event_multiplexer.EventMultiplexer()
+ self.assertEqual(x.Runs(), {})
+ x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'})
+ self.assertItemsEqual(x.Runs(), ['run1', 'run2'])
+ self.assertEqual(x._GetAccumulator('run1')._path, 'path1')
+ self.assertEqual(x._GetAccumulator('run2')._path, 'path2')
+
+ def testAddRunsFromDirectory(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ fakedir = join(tmpdir, 'fake_accumulator_directory')
+ realdir = join(tmpdir, 'real_accumulator_directory')
+ self.assertEqual(x.Runs(), {})
+ x.AddRunsFromDirectory(fakedir)
+ self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect')
+
+ if gfile.IsDirectory(realdir):
+ gfile.DeleteRecursively(realdir)
+ gfile.MkDir(realdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect')
+
+ path1 = join(realdir, 'path1')
+ gfile.MkDir(path1)
+ x.AddRunsFromDirectory(realdir)
+ self.assertEqual(x.Runs().keys(), ['path1'], 'loaded run: path1')
+ loader1 = x._GetAccumulator('path1')
+ self.assertEqual(loader1._path, path1, 'has the correct path')
+
+ path2 = join(realdir, 'path2')
+ gfile.MkDir(path2)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs().keys(), ['path1', 'path2'])
+ self.assertEqual(x._GetAccumulator('path1'), loader1,
+ 'loader1 not regenerated')
+ loader2 = x._GetAccumulator('path2')
+
+ path2_2 = join(path2, 'path2')
+ gfile.MkDir(path2_2)
+ x.AddRunsFromDirectory(path2)
+ self.assertItemsEqual(x.Runs().keys(), ['path1', 'path2'])
+ self.assertNotEqual(loader2, x._GetAccumulator('path2'),
+ 'loader2 regenerated')
+ self.assertEqual(x._GetAccumulator('path2')._path, path2_2,
+ 'loader2 path correct')
+
+ def testAddRunsFromDirectoryThatContainsEvents(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, 'event_containing_directory')
+
+ if gfile.IsDirectory(realdir):
+ gfile.DeleteRecursively(realdir)
+ gfile.MkDir(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'):
+ pass
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ['event_containing_directory'])
+
+ subdir = join(realdir, 'subdir')
+ gfile.MkDir(subdir)
+ x.AddRunsFromDirectory(realdir)
+ self.assertItemsEqual(x.Runs(), ['event_containing_directory', 'subdir'])
+
+ def testAddRunsFromDirectoryWithRunNames(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+ join = os.path.join
+ realdir = join(tmpdir, 'event_containing_directory')
+
+ if gfile.IsDirectory(realdir):
+ gfile.DeleteRecursively(realdir)
+ gfile.MkDir(realdir)
+
+ self.assertEqual(x.Runs(), {})
+
+ with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'):
+ pass
+ x.AddRunsFromDirectory(realdir, 'foo')
+ self.assertItemsEqual(x.Runs(), ['foo'])
+
+ subdir = join(realdir, 'subdir')
+ gfile.MkDir(subdir)
+ x.AddRunsFromDirectory(realdir, 'foo')
+ self.assertItemsEqual(x.Runs(), ['foo', 'foo/subdir'])
+
+ def testAddRunsFromDirectoryThrowsException(self):
+ x = event_multiplexer.EventMultiplexer()
+ tmpdir = self.get_temp_dir()
+
+ filepath = os.path.join(tmpdir, 'bad_file')
+ with gfile.GFile(filepath, 'w'):
+ pass
+
+ with self.assertRaises(ValueError):
+ x.AddRunsFromDirectory(filepath)
+
+ def testAddRun(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.AddRun('run1_path', 'run1')
+ run1 = x._GetAccumulator('run1')
+ self.assertEqual(x.Runs().keys(), ['run1'])
+ self.assertEqual(run1._path, 'run1_path')
+
+ x.AddRun('run1_path', 'run1')
+ self.assertEqual(run1, x._GetAccumulator('run1'), 'loader not recreated')
+
+ x.AddRun('run2_path', 'run1')
+ new_run1 = x._GetAccumulator('run1')
+ self.assertEqual(new_run1._path, 'run2_path')
+ self.assertNotEqual(run1, new_run1)
+
+ x.AddRun('runName3')
+ self.assertItemsEqual(x.Runs().keys(), ['run1', 'runName3'])
+ self.assertEqual(x._GetAccumulator('runName3')._path, 'runName3')
+
+ def testAddRunMaintainsLoading(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.Reload()
+ x.AddRun('run1')
+ x.AddRun('run2')
+ self.assertTrue(x._GetAccumulator('run1').reload_called)
+ self.assertTrue(x._GetAccumulator('run2').reload_called)
+
+ def testAddRunMaintainsAutoUpdate(self):
+ x = event_multiplexer.EventMultiplexer()
+ x.AutoUpdate(5)
+ x.AddRun('run1')
+ x.AddRun('run2')
+ self.assertTrue(x._GetAccumulator('run1').autoupdate_called)
+ self.assertTrue(x._GetAccumulator('run2').autoupdate_called)
+ self.assertEqual(x._GetAccumulator('run1').autoupdate_interval, 5)
+ self.assertEqual(x._GetAccumulator('run2').autoupdate_interval, 5)
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/__init__.py b/tensorflow/python/summary/impl/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/summary/impl/__init__.py
diff --git a/tensorflow/python/summary/impl/directory_watcher.py b/tensorflow/python/summary/impl/directory_watcher.py
new file mode 100644
index 0000000000..830e538cb6
--- /dev/null
+++ b/tensorflow/python/summary/impl/directory_watcher.py
@@ -0,0 +1,115 @@
+"""Contains the implementation for the DirectoryWatcher class."""
+import os
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+
+
+class DirectoryWatcher(object):
+ """A DirectoryWatcher wraps a loader to load from a directory.
+
+ A loader reads a file on disk and produces some kind of values as an
+ iterator. A DirectoryWatcher takes a directory (to which one file at a time is
+ being written) and a loader factory, and watches all the files in it at once.
+
+ This class is *only* valid under the assumption that files are never removed
+ and the only file ever changed is whichever one is lexicographically last.
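+
+ A minimal usage sketch (the directory is illustrative; this mirrors how
+ event_accumulator wires it up):
+
+   watcher = DirectoryWatcher('/tmp/logdir', EventFileLoader,
+                              IsTensorFlowEventsFile)
+   for event in watcher.Load():
+     ...  # every event is yielded exactly once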
+ """
+
+ def __init__(self, directory, loader_factory, path_filter=lambda x: True):
+ """Constructs a new DirectoryWatcher.
+
+ Args:
+ directory: The directory to watch. The directory doesn't have to exist.
+ loader_factory: A factory for creating loaders. The factory should take a
+ file path and return an object that has a Load method returning an
+ iterator that will yield all events that have not been yielded yet.
+ path_filter: Only files whose full path matches this predicate will be
+ loaded. If not specified, all files are loaded.
+
+ Raises:
+ ValueError: If directory or loader_factory is None.
+ """
+ if directory is None:
+ raise ValueError('A directory is required')
+ if loader_factory is None:
+ raise ValueError('A loader factory is required')
+ self._directory = directory
+ self._loader_factory = loader_factory
+ self._loader = None
+ self._path = None
+ self._path_filter = path_filter
+
+ def Load(self):
+ """Loads new values from disk.
+
+ The watcher will load from one file at a time; as soon as that file stops
+ yielding events, it will move on to the next file. We assume that old files
+ are never modified after a newer file has been written. As a result, Load()
+ can be called multiple times in a row without losing events that have not
+ been yielded yet. In other words, we guarantee that every event will be
+ yielded exactly once.
+
+ Yields:
+ All values that were written to disk that have not been yielded yet.
+ """
+
+    # If we don't have a loader yet, create one for the first file.
+ if not self._loader:
+ self._InitializeLoader()
+
+ while True:
+ # Yield all the new events in the file we're currently loading from.
+ for event in self._loader.Load():
+ yield event
+
+ next_path = self._GetNextPath()
+ if not next_path:
+ logging.info('No more files in %s', self._directory)
+ # Current file is empty and there are no new files, so we're done.
+ return
+
+ # There's a new file, so check to make sure there weren't any events
+ # written between when we finished reading the current file and when we
+ # checked for the new one. The sequence of events might look something
+ # like this:
+ #
+ # 1. Event #1 written to file #1.
+ # 2. We check for events and yield event #1 from file #1
+ # 3. We check for events and see that there are no more events in file #1.
+ # 4. Event #2 is written to file #1.
+ # 5. Event #3 is written to file #2.
+ # 6. We check for a new file and see that file #2 exists.
+ #
+ # Without this loop, we would miss event #2. We're also guaranteed by the
+ # loader contract that no more events will be written to file #1 after
+ # events start being written to file #2, so we don't have to worry about
+ # that.
+ for event in self._loader.Load():
+ yield event
+
+ logging.info('Directory watcher for %s advancing to file %s',
+ self._directory, next_path)
+
+ # Advance to the next file and start over.
+ self._SetPath(next_path)
+
+ def _InitializeLoader(self):
+ path = self._GetNextPath()
+ if path:
+ self._SetPath(path)
+ else:
+ raise StopIteration
+
+ def _SetPath(self, path):
+ self._path = path
+ self._loader = self._loader_factory(path)
+
+ def _GetNextPath(self):
+ """Returns the path of the next file to use or None if no file exists."""
+ sorted_paths = [os.path.join(self._directory, path)
+ for path in sorted(gfile.ListDirectory(self._directory))]
+ # We filter here so the filter gets the full directory name.
+ filtered_paths = (path for path in sorted_paths
+ if self._path_filter(path) and path > self._path)
+ return next(filtered_paths, None)
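
A minimal sketch of wiring the watcher to a loader factory, assuming the `EventFileLoader` defined later in this change as the factory; the directory path and the path filter are illustrative.

```python
from tensorflow.python.summary.impl import directory_watcher
from tensorflow.python.summary.impl import event_file_loader

watcher = directory_watcher.DirectoryWatcher(
    '/tmp/logs/run1', event_file_loader.EventFileLoader,
    path_filter=lambda path: 'tfevents' in path)

# Each call to Load() yields only the events written since the last call,
# advancing to lexicographically later files once the current one goes quiet.
for event in watcher.Load():
  print event.wall_time
```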
diff --git a/tensorflow/python/summary/impl/directory_watcher_test.py b/tensorflow/python/summary/impl/directory_watcher_test.py
new file mode 100644
index 0000000000..a22e3f2922
--- /dev/null
+++ b/tensorflow/python/summary/impl/directory_watcher_test.py
@@ -0,0 +1,102 @@
+"""Tests for directory_watcher."""
+
+import os
+import shutil
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import directory_watcher
+
+
+class _ByteLoader(object):
+ """A loader that loads individual bytes from a file."""
+
+ def __init__(self, path):
+ self._f = open(path)
+
+ def Load(self):
+ while True:
+ byte = self._f.read(1)
+ if byte:
+ yield byte
+ else:
+ return
+
+
+class DirectoryWatcherTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ # Put everything in a directory so it's easier to delete.
+ self._directory = os.path.join(self.get_temp_dir(), 'monitor_dir')
+ os.mkdir(self._directory)
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory, _ByteLoader)
+
+ def tearDown(self):
+ shutil.rmtree(self._directory)
+
+ def _WriteToFile(self, filename, data):
+ path = os.path.join(self._directory, filename)
+ with open(path, 'a') as f:
+ f.write(data)
+
+ def assertWatcherYields(self, values):
+ self.assertEqual(list(self._watcher.Load()), values)
+
+ def testRaisesWithBadArguments(self):
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher(None, lambda x: [])
+ with self.assertRaises(ValueError):
+ directory_watcher.DirectoryWatcher('asdf', None)
+
+ def testEmptyDirectory(self):
+ self.assertWatcherYields([])
+
+ def testSingleWrite(self):
+ self._WriteToFile('a', 'abc')
+ self.assertWatcherYields(['a', 'b', 'c'])
+
+ def testMultipleWrites(self):
+ self._WriteToFile('a', 'abc')
+ self.assertWatcherYields(['a', 'b', 'c'])
+ self._WriteToFile('a', 'xyz')
+ self.assertWatcherYields(['x', 'y', 'z'])
+
+ def testMultipleLoads(self):
+ self._WriteToFile('a', 'a')
+ self._watcher.Load()
+ self._watcher.Load()
+ self.assertWatcherYields(['a'])
+
+ def testMultipleFilesAtOnce(self):
+ self._WriteToFile('b', 'b')
+ self._WriteToFile('a', 'a')
+ self.assertWatcherYields(['a', 'b'])
+
+ def testFinishesLoadingFileWhenSwitchingToNewFile(self):
+ self._WriteToFile('a', 'a')
+ # Empty the iterator.
+ self.assertEquals(['a'], list(self._watcher.Load()))
+ self._WriteToFile('a', 'b')
+ self._WriteToFile('b', 'c')
+ # The watcher should finish its current file before starting a new one.
+ self.assertWatcherYields(['b', 'c'])
+
+ def testIntermediateEmptyFiles(self):
+ self._WriteToFile('a', 'a')
+ self._WriteToFile('b', '')
+ self._WriteToFile('c', 'c')
+ self.assertWatcherYields(['a', 'c'])
+
+ def testFileFilter(self):
+ self._watcher = directory_watcher.DirectoryWatcher(
+ self._directory, _ByteLoader,
+ path_filter=lambda path: 'do_not_watch_me' not in path)
+
+ self._WriteToFile('a', 'a')
+ self._WriteToFile('do_not_watch_me', 'b')
+ self._WriteToFile('c', 'c')
+ self.assertWatcherYields(['a', 'c'])
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
new file mode 100644
index 0000000000..0571bc84cb
--- /dev/null
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -0,0 +1,49 @@
+"""Functionality for loading events from a record file."""
+
+from tensorflow.core.util import event_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.platform import app
+from tensorflow.python.platform import logging
+
+
+class EventFileLoader(object):
+ """An EventLoader is an iterator that yields Event protos."""
+
+ def __init__(self, file_path):
+ if file_path is None:
+ raise ValueError('A file path is required')
+ logging.debug('Opening a record reader pointing at %s', file_path)
+ self._reader = pywrap_tensorflow.PyRecordReader_New(file_path, 0)
+ # Store it for logging purposes.
+ self._file_path = file_path
+ if not self._reader:
+ raise IOError('Failed to open a record reader pointing to %s' % file_path)
+
+ def Load(self):
+ """Loads all new values from disk.
+
+ Calling Load multiple times in a row will not 'drop' events as long as the
+ return value is not iterated over.
+
+ Yields:
+ All values that were written to disk that have not been yielded yet.
+ """
+ while self._reader.GetNext():
+ logging.debug('Got an event from %s', self._file_path)
+ event = event_pb2.Event()
+ event.ParseFromString(self._reader.record())
+ yield event
+ logging.debug('No more events in %s', self._file_path)
+
+
+def main(argv):
+ if len(argv) != 2:
+ print 'Usage: event_file_loader <path-to-the-recordio-file>'
+ return 1
+ loader = EventFileLoader(argv[1])
+ for event in loader.Load():
+ print event
+
+
+if __name__ == '__main__':
+ app.run()
diff --git a/tensorflow/python/summary/impl/event_file_loader_test.py b/tensorflow/python/summary/impl/event_file_loader_test.py
new file mode 100644
index 0000000000..1dc29d85d5
--- /dev/null
+++ b/tensorflow/python/summary/impl/event_file_loader_test.py
@@ -0,0 +1,59 @@
+"""Tests for event_file_loader."""
+
+import os
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import event_file_loader
+
+
+class EventFileLoaderTest(test_util.TensorFlowTestCase):
+ # A record containing a simple event.
+ RECORD = ('\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu'
+ '\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d')
+
+ def _WriteToFile(self, filename, data):
+ path = os.path.join(self.get_temp_dir(), filename)
+ with open(path, 'ab') as f:
+ f.write(data)
+
+ def _LoaderForTestFile(self, filename):
+ return event_file_loader.EventFileLoader(
+ os.path.join(self.get_temp_dir(), filename))
+
+ def testEmptyEventFile(self):
+ self._WriteToFile('empty_event_file', '')
+ loader = self._LoaderForTestFile('empty_event_file')
+ self.assertEquals(len(list(loader.Load())), 0)
+
+ def testSingleWrite(self):
+ self._WriteToFile('single_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('single_event_file')
+ events = list(loader.Load())
+ self.assertEquals(len(events), 1)
+ self.assertEquals(events[0].wall_time, 1440183447.0)
+ self.assertEquals(len(list(loader.Load())), 0)
+
+ def testMultipleWrites(self):
+ self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('staggered_event_file')
+ self.assertEquals(len(list(loader.Load())), 1)
+ self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD)
+ self.assertEquals(len(list(loader.Load())), 1)
+
+ def testMultipleLoads(self):
+ self._WriteToFile('multiple_loads_event_file', EventFileLoaderTest.RECORD)
+ loader = self._LoaderForTestFile('multiple_loads_event_file')
+ loader.Load()
+ loader.Load()
+ self.assertEquals(len(list(loader.Load())), 1)
+
+ def testMultipleWritesAtOnce(self):
+ self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
+ self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD)
+    loader = self._LoaderForTestFile('multiple_event_file')
+ self.assertEquals(len(list(loader.Load())), 2)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/summary/impl/reservoir.py b/tensorflow/python/summary/impl/reservoir.py
new file mode 100644
index 0000000000..2c9b294841
--- /dev/null
+++ b/tensorflow/python/summary/impl/reservoir.py
@@ -0,0 +1,164 @@
+"""A key-value[] store that implements reservoir sampling on the values."""
+
+import collections
+import random
+import threading
+
+
+class Reservoir(object):
+ """A map-to-arrays container, with deterministic Reservoir Sampling.
+
+ Items are added with an associated key. Items may be retrieved by key, and
+ a list of keys can also be retrieved. If size is not zero, then it dictates
+ the maximum number of items that will be stored with each key. Once there are
+ more items for a given key, they are replaced via reservoir sampling, such
+ that each item has an equal probability of being included in the sample.
+
+ Deterministic means that for any given seed and bucket size, the sequence of
+ values that are kept for any given tag will always be the same, and that this
+ is independent of any insertions on other tags. That is:
+
+ >>> separate_reservoir = reservoir.Reservoir(10)
+ >>> interleaved_reservoir = reservoir.Reservoir(10)
+ >>> for i in xrange(100):
+ >>> separate_reservoir.AddItem('key1', i)
+ >>> for i in xrange(100):
+ >>> separate_reservoir.AddItem('key2', i)
+ >>> for i in xrange(100):
+ >>> interleaved_reservoir.AddItem('key1', i)
+ >>> interleaved_reservoir.AddItem('key2', i)
+
+ separate_reservoir and interleaved_reservoir will be in identical states.
+
+ See: https://en.wikipedia.org/wiki/Reservoir_sampling
+
+ Adding items has amortized O(1) runtime.
+
+ """
+
+ def __init__(self, size, seed=0):
+ """Creates a new reservoir.
+
+ Args:
+ size: The number of values to keep in the reservoir for each tag. If 0,
+ all values will be kept.
+ seed: The seed of the random number generator to use when sampling.
+ Different values for |seed| will produce different samples from the same
+ input items.
+
+ Raises:
+ ValueError: If size is negative or not an integer.
+ """
+ if size < 0 or size != round(size):
+      raise ValueError('size must be a nonnegative integer, was %s' % size)
+ self._buckets = collections.defaultdict(
+ lambda: _ReservoirBucket(size, random.Random(seed)))
+    # _mutex guards the keys - creating new keys, retrieving by key, etc.
+ # the internal items are guarded by the ReservoirBuckets' internal mutexes
+ self._mutex = threading.Lock()
+
+ def Keys(self):
+ """Return all the keys in the reservoir.
+
+ Returns:
+ ['list', 'of', 'keys'] in the Reservoir.
+ """
+ with self._mutex:
+ return self._buckets.keys()
+
+ def Items(self, key):
+ """Return items associated with given key.
+
+ Args:
+ key: The key for which we are finding associated items.
+
+ Raises:
+      KeyError: If the key is not found in the reservoir.
+
+ Returns:
+ [list, of, items] associated with that key.
+ """
+ with self._mutex:
+ if key not in self._buckets:
+ raise KeyError('Key %s was not found in Reservoir' % key)
+ bucket = self._buckets[key]
+ return bucket.Items()
+
+ def AddItem(self, key, item):
+ """Add a new item to the Reservoir with the given tag.
+
+ The new item is guaranteed to be kept in the Reservoir. One other item might
+ be replaced.
+
+ Args:
+ key: The key to store the item under.
+ item: The item to add to the reservoir.
+ """
+ with self._mutex:
+ bucket = self._buckets[key]
+ bucket.AddItem(item)
+
+
+class _ReservoirBucket(object):
+ """A container for items from a stream, that implements reservoir sampling.
+
+ It always stores the most recent item as its final item.
+ """
+
+ def __init__(self, _max_size, _random=None):
+ """Create the _ReservoirBucket.
+
+ Args:
+ _max_size: The maximum size the reservoir bucket may grow to. If size is
+ zero, the bucket has unbounded size.
+ _random: The random number generator to use. If not specified, defaults to
+ random.Random(0).
+
+ Raises:
+ ValueError: if the size is not a nonnegative integer.
+ """
+ if _max_size < 0 or _max_size != round(_max_size):
+      raise ValueError('_max_size must be a nonnegative int, was %s' % _max_size)
+ self.items = []
+ # This mutex protects the internal items, ensuring that calls to Items and
+ # AddItem are thread-safe
+ self._mutex = threading.Lock()
+ self._max_size = _max_size
+ self._count = 0
+ if _random is not None:
+ self._random = _random
+ else:
+ self._random = random.Random(0)
+
+ def AddItem(self, item):
+ """Add an item to the ReservoirBucket, replacing an old item if necessary.
+
+ The new item is guaranteed to be added to the bucket, and to be the last
+ element in the bucket. If the bucket has reached capacity, then an old item
+ will be replaced. With probability (_max_size/_count) a random item in the
+ bucket will be popped out and the new item will be appended to the end. With
+ probability (1 - _max_size/_count) the last item in the bucket will be
+ replaced.
+
+    Since the O(n) replacements occur with O(1/_count) likelihood, the amortized
+ runtime is O(1).
+
+ Args:
+ item: The item to add to the bucket.
+ """
+ with self._mutex:
+ if len(self.items) < self._max_size or self._max_size == 0:
+ self.items.append(item)
+ else:
+ r = self._random.randint(0, self._count)
+ if r < self._max_size:
+ self.items.pop(r)
+ self.items.append(item)
+ else:
+ self.items[-1] = item
+ self._count += 1
+
+ def Items(self):
+ """Get all the items in the bucket."""
+ with self._mutex:
+ return self.items
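
A short usage sketch of the container above; the tag name and item values are arbitrary. It illustrates the two guarantees documented in AddItem: the bucket never exceeds its size, and the most recently added item is always retained as the last element.

```python
from tensorflow.python.summary.impl import reservoir

r = reservoir.Reservoir(size=10, seed=0)
for step in xrange(1000):
  r.AddItem('loss', step)

items = r.Items('loss')
print len(items)   # 10: the bucket never grows beyond its size.
print items[-1]    # 999: the newest item is always kept, at the end.
print r.Keys()     # ['loss']
```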
diff --git a/tensorflow/python/summary/impl/reservoir_test.py b/tensorflow/python/summary/impl/reservoir_test.py
new file mode 100644
index 0000000000..46cbde5940
--- /dev/null
+++ b/tensorflow/python/summary/impl/reservoir_test.py
@@ -0,0 +1,178 @@
+import tensorflow.python.platform
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.summary.impl import reservoir
+
+
+class ReservoirTest(googletest.TestCase):
+
+ def testEmptyReservoir(self):
+ r = reservoir.Reservoir(1)
+ self.assertFalse(r.Keys())
+
+ def testRespectsSize(self):
+ r = reservoir.Reservoir(42)
+ self.assertEqual(r._buckets['meaning of life']._max_size, 42)
+
+ def testItemsAndKeys(self):
+ r = reservoir.Reservoir(42)
+ r.AddItem('foo', 4)
+ r.AddItem('bar', 9)
+ r.AddItem('foo', 19)
+ self.assertItemsEqual(r.Keys(), ['foo', 'bar'])
+ self.assertEqual(r.Items('foo'), [4, 19])
+ self.assertEqual(r.Items('bar'), [9])
+
+ def testExceptions(self):
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(-1)
+ with self.assertRaises(ValueError):
+ reservoir.Reservoir(13.3)
+
+ r = reservoir.Reservoir(12)
+ with self.assertRaises(KeyError):
+ r.Items('missing key')
+
+ def testDeterminism(self):
+ """Tests that the reservoir is deterministic."""
+ key = 'key'
+ r1 = reservoir.Reservoir(10)
+ r2 = reservoir.Reservoir(10)
+ for i in xrange(100):
+ r1.AddItem('key', i)
+ r2.AddItem('key', i)
+
+ self.assertEqual(r1.Items(key), r2.Items(key))
+
+ def testBucketDeterminism(self):
+ """Tests that reservoirs are deterministic at a bucket level.
+
+    This means that only the order in which elements are added within a bucket
+    matters.
+ """
+ separate_reservoir = reservoir.Reservoir(10)
+ interleaved_reservoir = reservoir.Reservoir(10)
+ for i in xrange(100):
+ separate_reservoir.AddItem('key1', i)
+ for i in xrange(100):
+ separate_reservoir.AddItem('key2', i)
+ for i in xrange(100):
+ interleaved_reservoir.AddItem('key1', i)
+ interleaved_reservoir.AddItem('key2', i)
+
+ for key in ['key1', 'key2']:
+ self.assertEqual(separate_reservoir.Items(key),
+ interleaved_reservoir.Items(key))
+
+ def testUsesSeed(self):
+ """Tests that reservoirs with different seeds keep different samples."""
+ key = 'key'
+ r1 = reservoir.Reservoir(10, seed=0)
+ r2 = reservoir.Reservoir(10, seed=1)
+ for i in xrange(100):
+ r1.AddItem('key', i)
+ r2.AddItem('key', i)
+ self.assertNotEqual(r1.Items(key), r2.Items(key))
+
+
+class ReservoirBucketTest(googletest.TestCase):
+
+ def testEmptyBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ self.assertFalse(b.Items())
+
+ def testFillToSize(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(100):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), range(100))
+
+ def testDoesntOverfill(self):
+ b = reservoir._ReservoirBucket(10)
+ for i in xrange(1000):
+ b.AddItem(i)
+ self.assertEqual(len(b.Items()), 10)
+
+ def testMaintainsOrder(self):
+ b = reservoir._ReservoirBucket(100)
+ for i in xrange(10000):
+ b.AddItem(i)
+ items = b.Items()
+ prev = None
+ for item in items:
+ self.assertTrue(item > prev)
+ prev = item
+
+ def testKeepsLatestItem(self):
+ b = reservoir._ReservoirBucket(5)
+ for i in xrange(100):
+ b.AddItem(i)
+ last = b.Items()[-1]
+ self.assertEqual(last, i)
+
+ def testSizeOneBucket(self):
+ b = reservoir._ReservoirBucket(1)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), [i])
+
+ def testSizeZeroBucket(self):
+ b = reservoir._ReservoirBucket(0)
+ for i in xrange(20):
+ b.AddItem(i)
+ self.assertEqual(b.Items(), range(i+1))
+
+ def testSizeRequirement(self):
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(-1)
+ with self.assertRaises(ValueError):
+ reservoir._ReservoirBucket(10.3)
+
+
+class ReservoirBucketStatisticalDistributionTest(googletest.TestCase):
+
+ def setUp(self):
+ self.total = 1000000
+ self.samples = 10000
+ self.n_buckets = 100
+ self.total_per_bucket = self.total / self.n_buckets
+ self.assertEqual(self.total % self.n_buckets, 0, 'total must be evenly '
+ 'divisible by the number of buckets')
+ self.assertTrue(self.total > self.samples, 'need to have more items '
+ 'than samples')
+
+ def AssertBinomialQuantity(self, measured):
+ p = 1.0 * self.n_buckets / self.samples
+ mean = p * self.samples
+ variance = p * (1 - p) * self.samples
+ error = measured - mean
+ # Given that the buckets were actually binomially distributed, this
+ # fails with probability ~2E-9
+ passed = error * error <= 36.0 * variance
+ self.assertTrue(passed, 'found a bucket with measured %d '
+ 'too far from expected %d' % (measured, mean))
+
+ def testBucketReservoirSamplingViaStatisticalProperties(self):
+    # 'Buckets' here is unrelated to _ReservoirBucket; it is the number of bins
+    # we group samples into when testing the shape of the distribution.
+ b = reservoir._ReservoirBucket(_max_size=self.samples)
+ # add one extra item because we always keep the most recent item, which
+ # would skew the distribution; we can just slice it off the end instead.
+ for i in xrange(self.total + 1):
+ b.AddItem(i)
+
+ divbins = [0] * self.n_buckets
+ modbins = [0] * self.n_buckets
+ # Slice off the last item when we iterate.
+ for item in b.Items()[0:-1]:
+ divbins[item / self.total_per_bucket] += 1
+ modbins[item % self.n_buckets] += 1
+
+ for bucket_index in xrange(self.n_buckets):
+ divbin = divbins[bucket_index]
+ modbin = modbins[bucket_index]
+ self.AssertBinomialQuantity(divbin)
+ self.AssertBinomialQuantity(modbin)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
new file mode 100644
index 0000000000..d26f12a89c
--- /dev/null
+++ b/tensorflow/python/tensorflow.i
@@ -0,0 +1,14 @@
+/* SWIG wrapper for all of TensorFlow native functionality.
+ * The includes are intentionally not alphabetically sorted, as the order of
+ * includes follows dependency order */
+
+%include "tensorflow/python/util/port.i"
+
+%include "tensorflow/python/lib/core/status.i"
+%include "tensorflow/python/lib/core/status_helper.i"
+
+%include "tensorflow/python/lib/io/py_record_reader.i"
+%include "tensorflow/python/lib/io/py_record_writer.i"
+%include "tensorflow/python/client/events_writer.i"
+
+%include "tensorflow/python/client/tf_session.i"
diff --git a/tensorflow/python/training/__init__.py b/tensorflow/python/training/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/training/__init__.py
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
new file mode 100644
index 0000000000..41cf2e00f4
--- /dev/null
+++ b/tensorflow/python/training/adagrad.py
@@ -0,0 +1,58 @@
+"""Adagrad for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class AdagradOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the Adagrad algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate, initial_accumulator_value=0.1,
+ use_locking=False, name="Adagrad"):
+ """Construct a new Adagrad optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or a floating point value. The learning rate.
+ initial_accumulator_value: A floating point value.
+ Starting value for the accumulators, must be positive.
+ use_locking: If `True` use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Adagrad".
+
+ Raises:
+ ValueError: If the initial_accumulator_value is invalid.
+ """
+ if initial_accumulator_value <= 0.0:
+ raise ValueError("initial_accumulator_value must be positive: %s" %
+ initial_accumulator_value)
+ super(AdagradOptimizer, self).__init__(use_locking, name)
+ self._learning_rate = learning_rate
+ self._initial_accumulator_value = initial_accumulator_value
+    # Created in _prepare().
+ self._learning_rate_tensor = None
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ val = constant_op.constant(self._initial_accumulator_value,
+ shape=v.get_shape())
+ self._get_or_make_slot(v, val, "accumulator", self._name)
+
+ def _prepare(self):
+ self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+ name="learning_rate")
+
+ def _apply_dense(self, grad, var):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.apply_adagrad(
+ var, acc, self._learning_rate_tensor, grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.sparse_apply_adagrad(
+ var, acc, self._learning_rate_tensor, grad.values, grad.indices,
+ use_locking=self._use_locking)
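
As a sanity check on the update that `_apply_dense` delegates to `training_ops.apply_adagrad`, the same arithmetic can be written in a few lines of NumPy. This sketch assumes the standard Adagrad rule (accumulate squared gradients, scale the step by the inverse square root of the accumulator) and reuses the constants from `testBasic` in the test file that follows.

```python
import numpy as np

def adagrad_update_numpy(var, accum, grad, lr=3.0):
  # Assumed reference semantics for one dense Adagrad step.
  accum_t = accum + grad * grad
  var_t = var - lr * grad / np.sqrt(accum_t)
  return var_t, accum_t

var, accum = np.array([1.0, 2.0]), np.array([0.1, 0.1])
grad = np.array([0.1, 0.1])
for _ in range(3):
  var, accum = adagrad_update_numpy(var, accum, grad)
print var  # ~[-1.6026, -0.6026], matching testBasic below.
```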
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
new file mode 100644
index 0000000000..ee83791eb5
--- /dev/null
+++ b/tensorflow/python/training/adagrad_test.py
@@ -0,0 +1,144 @@
+"""Functional tests for aggregate operations."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class AdagradOptimizerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ ada_opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllClose(np.array([-1.6026098728179932, -0.6026098728179932]),
+ var0.eval())
+ self.assertAllClose(np.array([2.715679168701172, 3.715679168701172]),
+ var1.eval())
+
+ def testFloat64(self):
+ with self.test_session():
+ opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+
+ # compute_gradients.
+ values = [1.0, 3.0]
+ good_vars = [tf.Variable([v]) for v in values]
+ bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+ opt.compute_gradients, bad_loss, good_vars)
+ bad_vars = [
+ tf.Variable(np.array([v], np.float64), name="bad_var")
+ for v in values]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+ bad_vars)
+ opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+ # apply_gradients.
+ bad_grads = [
+ tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+ tf.constant([0.01])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+ opt.apply_gradients, zip(bad_grads, good_vars))
+ good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.apply_gradients, zip(good_grads, bad_vars))
+ opt.apply_gradients(zip(good_grads, good_vars))
+
+ def testSparseBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([[1.0], [2.0]])
+ var1 = tf.Variable([[3.0], [4.0]])
+ grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
+ tf.constant([0]),
+ tf.constant([2, 1]))
+ grads1 = tf.IndexedSlices(tf.constant([0.01], shape=[1, 1]),
+ tf.constant([1]),
+ tf.constant([2, 1]))
+ ada_opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+      # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllClose([[-1.6026098728179932], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [3.715679168701172]], var1.eval())
+
+ def testSparseStability(self):
+ with self.test_session():
+ shape = [1, 6]
+ var0 = tf.Variable([[0.00872496, -0.106952, 0.110467, 0.226505,
+ -0.0147257, -0.0105945]])
+ grads0 = tf.IndexedSlices(
+ tf.constant(
+ [[-5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
+ -8.4877e-05, -9.48906e-05]],
+ shape=shape),
+ tf.constant([0]),
+ tf.constant(shape))
+ ada_opt = tf.train.AdagradOptimizer(1.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ init = tf.initialize_all_variables()
+ for _ in range(100):
+ init.run()
+ ada_update.run()
+ self.assertAllClose([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], slot0.eval())
+ self.assertAllClose(
+ [[0.00891194, -0.10712013, 0.11047515, 0.22636929,
+ -0.0144573, -0.01029443]], var0.eval())
+
+ def testSharing(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ ada_opt = tf.train.AdagradOptimizer(3.0)
+ # Apply the optimizer twice. Both applications will use the same accums.
+ ada_update1 = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ ada_update2 = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ tf.initialize_all_variables().run()
+
+ # Fetch params to validate initial values.
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Mix the first and the second adagrad for 3 steps.
+ ada_update1.run()
+ ada_update2.run()
+ ada_update1.run()
+ # Validate updated params (the same as with only 1 Adagrad).
+ self.assertAllClose(np.array([-1.6026098728179932, -0.6026098728179932]),
+ var0.eval())
+ self.assertAllClose(np.array([2.715679168701172, 3.715679168701172]),
+ var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
new file mode 100644
index 0000000000..266430bb13
--- /dev/null
+++ b/tensorflow/python/training/adam.py
@@ -0,0 +1,142 @@
+"""Adam for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class AdamOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the Adam algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
+ use_locking=False, name="Adam"):
+ """Construct a new Adam optimizer.
+
+ Implementation is based on: http://arxiv.org/pdf/1412.6980v7.pdf
+
+ Initialization:
+
+ ```
+ m_0 <- 0 (Initialize initial 1st moment vector)
+ v_0 <- 0 (Initialize initial 2nd moment vector)
+ t <- 0 (Initialize timestep)
+ ```
+
+ The update rule for `variable` with gradient `g` uses an optimization
+    described at the end of section 2 of the paper:
+
+ ```
+ t <- t + 1
+ lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
+
+ m_t <- beta1 * m_{t-1} + (1 - beta1) * g
+ v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
+ variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
+ ```
+
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1.
+
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ beta1: A float value or a constant float tensor.
+ The exponential decay rate for the 1st moment estimates.
+ beta2: A float value or a constant float tensor.
+        The exponential decay rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability.
+      use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+ """
+ super(AdamOptimizer, self).__init__(use_locking, name)
+ self._lr = learning_rate
+ self._beta1 = beta1
+ self._beta2 = beta2
+ self._epsilon = epsilon
+
+ # Tensor versions of the constructor arguments, created in _prepare().
+ self._lr_t = None
+ self._beta1_t = None
+ self._beta2_t = None
+ self._epsilon_t = None
+
+ # Variables to accumulate the powers of the beta parameters.
+ # Created in _create_slots when we know the variables to optimize.
+ self._beta1_power = None
+ self._beta2_power = None
+
+ # Created in SparseApply if needed.
+ self._updated_lr = None
+
+ def _get_beta_accumulators(self):
+ return self._beta1_power, self._beta2_power
+
+ def _create_slots(self, var_list):
+ # Create the beta1 and beta2 accumulators on the same device as the first
+ # variable.
+ if self._beta1_power is None:
+ with ops.device(var_list[0].device):
+ self._beta1_power = variables.Variable(self._beta1, name="beta1_power")
+ self._beta2_power = variables.Variable(self._beta2, name="beta2_power")
+ # Create slots for the first and second moments.
+ for v in var_list:
+ self._zeros_slot(v, "m", self._name)
+ self._zeros_slot(v, "v", self._name)
+
+ def _prepare(self):
+ self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
+ self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
+ self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
+ self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+
+ def _apply_dense(self, grad, var):
+ m = self.get_slot(var, "m")
+ v = self.get_slot(var, "v")
+ return training_ops.apply_adam(
+ var, m, v, self._beta1_power, self._beta2_power,
+ self._lr_t, self._beta1_t, self._beta2_t,
+ self._epsilon_t, grad, use_locking=self._use_locking).op
+
+ def _apply_sparse(self, grad, var):
+ lr = (self._lr_t *
+ math_ops.sqrt(1 - self._beta2_power)
+ / (1 - self._beta1_power))
+ # m_t = beta1 * m + (1 - beta1) * g_t
+ m = self.get_slot(var, "m")
+ m_scaled_g_values = grad.values * (1 - self._beta1_t)
+ m_t = state_ops.assign(m, m * self._beta1_t,
+ use_locking=self._use_locking)
+ m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values,
+ use_locking=self._use_locking)
+ # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+ v = self.get_slot(var, "v")
+ v_scaled_g_values = (grad.values * grad.values) * (1 - self._beta2_t)
+ v_t = state_ops.assign(v, v * self._beta2_t, use_locking=self._use_locking)
+ v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values,
+ use_locking=self._use_locking)
+ v_sqrt = math_ops.sqrt(v_t)
+ var_update = state_ops.assign_sub(var,
+ lr * m_t / (v_sqrt + self._epsilon_t),
+ use_locking=self._use_locking)
+ return control_flow_ops.group(*[var_update, m_t, v_t])
+
+ def _finish(self, update_ops, name_scope):
+ # Update the power accumulators.
+ with ops.control_dependencies(update_ops):
+ with ops.device(self._beta1_power.device):
+ update_beta1 = self._beta1_power.assign(
+ self._beta1_power * self._beta1_t,
+ use_locking=self._use_locking)
+ update_beta2 = self._beta2_power.assign(
+ self._beta2_power * self._beta2_t,
+ use_locking=self._use_locking)
+ return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
+ name=name_scope)
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
new file mode 100644
index 0000000000..f92728d0c7
--- /dev/null
+++ b/tensorflow/python/training/adam_test.py
@@ -0,0 +1,174 @@
+"""Tests for Adam."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+def adam_update_numpy(param, g_t, t, m, v, alpha=0.001, beta1=0.9, beta2=0.999,
+ epsilon=1e-8):
+ alpha_t = alpha * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ return param_t, m_t, v_t
+
+
+class AdamOptimizerTest(tf.test.TestCase):
+
+ def testSparse(self):
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=np.float32)
+ grads0_np = np.array([0.1, 0.1], dtype=np.float32)
+ var1_np = np.array([3.0, 4.0], dtype=np.float32)
+ grads1_np = np.array([0.01, 0.01], dtype=np.float32)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+ tf.constant(grads0_np_indices),
+ tf.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+ tf.constant(grads1_np_indices),
+ tf.constant([2]))
+ opt = tf.train.AdamOptimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllClose(0.9 ** t, beta1_power.eval())
+ self.assertAllClose(0.999 ** t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ def testBasic(self):
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=np.float32)
+ grads0_np = np.array([0.1, 0.1], dtype=np.float32)
+ var1_np = np.array([3.0, 4.0], dtype=np.float32)
+ grads1_np = np.array([0.01, 0.01], dtype=np.float32)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0 = tf.constant(grads0_np)
+ grads1 = tf.constant(grads1_np)
+ opt = tf.train.AdamOptimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllClose(0.9 ** t, beta1_power.eval())
+ self.assertAllClose(0.999 ** t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ def testFloat64(self):
+ with self.test_session():
+ opt = tf.train.AdamOptimizer()
+
+ # compute_gradients.
+ values = [1.0, 3.0]
+ good_vars = [tf.Variable([v]) for v in values]
+ bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+ opt.compute_gradients, bad_loss, good_vars)
+ bad_vars = [
+ tf.Variable(np.array([v], np.float64), name="bad_var")
+ for v in values]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+ bad_vars)
+ opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+ # apply_gradients.
+ bad_grads = [
+ tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+ tf.constant([0.01])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+ opt.apply_gradients, zip(bad_grads, good_vars))
+ good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.apply_gradients, zip(good_grads, bad_vars))
+ opt.apply_gradients(zip(good_grads, good_vars))
+
+ def testSharing(self):
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=np.float32)
+ grads0_np = np.array([0.1, 0.1], dtype=np.float32)
+ var1_np = np.array([3.0, 4.0], dtype=np.float32)
+ grads1_np = np.array([0.01, 0.01], dtype=np.float32)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0 = tf.constant(grads0_np)
+ grads1 = tf.constant(grads1_np)
+ opt = tf.train.AdamOptimizer()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllClose(0.9 ** t, beta1_power.eval())
+ self.assertAllClose(0.999 ** t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/checkpoint_state.proto b/tensorflow/python/training/checkpoint_state.proto
new file mode 100644
index 0000000000..1f521341f1
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_state.proto
@@ -0,0 +1,18 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// Protocol buffer representing the checkpoint state.
+//
+// TODO(mdevin): Add other attributes as needed.
+message CheckpointState {
+ // Path to the most-recent model checkpoint.
+ string model_checkpoint_path = 1;
+
+ // Paths to all not-yet-deleted model checkpoints, sorted from oldest to
+ // newest.
+ // Note that the value of model_checkpoint_path should be the last item in
+ // this list.
+ repeated string all_model_checkpoint_paths = 2;
+}
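
For reference, a sketch of populating this message from Python, assuming the proto is compiled into a `checkpoint_state_pb2` module next to it; the module name and the checkpoint filenames are hypothetical.

```python
# Assumes a generated checkpoint_state_pb2 module; the module name and the
# checkpoint filenames below are illustrative, not part of this change.
from tensorflow.python.training import checkpoint_state_pb2

state = checkpoint_state_pb2.CheckpointState()
state.all_model_checkpoint_paths.append('model.ckpt-1000')
state.all_model_checkpoint_paths.append('model.ckpt-2000')
# The most recent checkpoint is also the last entry of the repeated field.
state.model_checkpoint_path = state.all_model_checkpoint_paths[-1]
print state
```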
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
new file mode 100644
index 0000000000..f090e6d222
--- /dev/null
+++ b/tensorflow/python/training/coordinator.py
@@ -0,0 +1,186 @@
+"""Coordinator to help multiple threads stop when requested."""
+import sys
+import threading
+import time
+
+from tensorflow.python.platform import logging
+
+
+class Coordinator(object):
+ """A coordinator for threads.
+
+ This class implements a simple mechanism to coordinate the termination of a
+ set of threads.
+
+ #### Usage:
+
+ ```python
+ # Create a coordinator.
+ coord = Coordinator()
+ # Start a number of threads, passing the coordinator to each of them.
+ ...start thread 1...(coord, ...)
+ ...start thread N...(coord, ...)
+ # Wait for all the threads to terminate.
+ coord.join(threads)
+ ```
+
+ Any of the threads can call `coord.request_stop()` to ask for all the threads
+ to stop. To cooperate with the requests, each thread must check for
+ `coord.should_stop()` on a regular basis. `coord.should_stop()` returns
+ `True` as soon as `coord.request_stop()` has been called.
+
+ A typical thread running with a Coordinator will do something like:
+
+ ```python
+ while not coord.should_stop():
+ ...do some work...
+ ```
+
+ #### Exception handling:
+
+ A thread can report an exception to the Coordinator as part of the
+ `should_stop()` call. The exception will be re-raised from the
+ `coord.join()` call.
+
+ Thread code:
+
+ ```python
+ try:
+ while not coord.should_stop():
+ ...do some work...
+ except Exception, e:
+ coord.request_stop(e)
+ ```
+
+ Main code:
+
+ ```python
+ try:
+ ...
+ coord = Coordinator()
+ # Start a number of threads, passing the coordinator to each of them.
+ ...start thread 1...(coord, ...)
+ ...start thread N...(coord, ...)
+ # Wait for all the threads to terminate.
+ coord.join(threads)
+ except Exception, e:
+ ...exception that was passed to coord.request_stop()
+ ```
+
+ #### Grace period for stopping:
+
+ After a thread has called `coord.request_stop()` the other threads have a
+  fixed amount of time to stop; this is called the 'stop grace period' and
+  defaults to 2 minutes. If any of the threads is still alive after the grace
+  period expires, `coord.join()` raises a RuntimeError reporting the laggards.
+
+ ```
+ try:
+ ...
+ coord = Coordinator()
+ # Start a number of threads, passing the coordinator to each of them.
+ ...start thread 1...(coord, ...)
+ ...start thread N...(coord, ...)
+ # Wait for all the threads to terminate, give them 10s grace period
+ coord.join(threads, stop_grace_period_secs=10)
+  except RuntimeError:
+ ...one of the threads took more than 10s to stop after request_stop()
+ ...was called.
+ except Exception:
+ ...exception that was passed to coord.request_stop()
+ ```
+ """
+
+ def __init__(self):
+ """Create a new Coordinator."""
+ # Protects all attributes.
+ self._lock = threading.Lock()
+ # Event set when threads must stop.
+ self._stop_event = threading.Event()
+ # Python exc_info to report.
+ self._exc_info_to_raise = None
+
+ def request_stop(self, ex=None):
+ """Request that the threads stop.
+
+ After this is called, calls to should_stop() will return True.
+
+ Args:
+ ex: Optional Exception, or Python 'exc_info' tuple as returned by
+ sys.exc_info(). If this is the first call to request_stop() the
+ corresponding exception is recorded and re-raised from join().
+ """
+ with self._lock:
+ if not self._stop_event.is_set():
+ if ex and self._exc_info_to_raise is None:
+ if isinstance(ex, tuple):
+ logging.info("Error reported to Coordinator: %s", str(ex[1]))
+ self._exc_info_to_raise = ex
+ else:
+ logging.info("Error reported to Coordinator: %s", str(ex))
+ self._exc_info_to_raise = sys.exc_info()
+ self._stop_event.set()
+
+ def should_stop(self):
+ """Check if stop was requested.
+
+ Returns:
+ True if a stop was requested.
+ """
+ return self._stop_event.is_set()
+
+ def wait_for_stop(self, timeout=None):
+ """Wait till the Coordinator is told to stop.
+
+ Args:
+ timeout: float. Sleep for up to that many seconds waiting for
+ should_stop() to become True.
+
+ Returns:
+      True if the Coordinator is told to stop, False if the timeout expired.
+ """
+ return self._stop_event.wait(timeout)
+
+ def join(self, threads, stop_grace_period_secs=120):
+ """Wait for threads to terminate.
+
+ Blocks until all 'threads' have terminated or request_stop() is called.
+
+ After the threads stop, if an 'exc_info' was passed to request_stop, that
+    exception is re-raised.
+
+ Grace period handling: When request_stop() is called, threads are given
+ 'stop_grace_period_secs' seconds to terminate. If any of them is still
+ alive after that period expires, a RuntimeError is raised. Note that if
+ an 'exc_info' was passed to request_stop() then it is raised instead of
+ that RuntimeError.
+
+ Args:
+      threads: List of threading.Thread objects. The started threads to join.
+ stop_grace_period_secs: Number of seconds given to threads to stop after
+ request_stop() has been called.
+
+ Raises:
+ RuntimeError: If any thread is still alive after request_stop()
+ is called and the grace period expires.
+ """
+ # Wait for all threads to stop or for request_stop() to be called.
+ while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
+ pass
+
+ # If any thread is still alive, wait for the grace period to expire.
+ while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
+ stop_grace_period_secs -= 1.0
+ time.sleep(1.0)
+
+ # List the threads still alive after the grace period.
+ stragglers = [t.name for t in threads if t.is_alive()]
+
+ # Terminate with an exception if appropriate.
+ with self._lock:
+ if self._exc_info_to_raise:
+ exc_info = self._exc_info_to_raise
+ raise exc_info[0], exc_info[1], exc_info[2]
+ elif stragglers:
+ raise RuntimeError("Coordinator stopped with threads still running: %s",
+ " ".join(stragglers))
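
A runnable sketch of the usage pattern described in the class docstring; the worker function, thread count, and sleep intervals are illustrative.

```python
import threading
import time

from tensorflow.python.training import coordinator

def worker(coord, worker_id):
  while not coord.should_stop():
    time.sleep(0.1)            # ...do some work...
    if worker_id == 0:
      coord.request_stop()     # any thread may ask all threads to stop

coord = coordinator.Coordinator()
threads = [threading.Thread(target=worker, args=(coord, i)) for i in range(3)]
for t in threads:
  t.start()
coord.join(threads)            # re-raises any exception passed to request_stop()
```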
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
new file mode 100644
index 0000000000..ce9126caf4
--- /dev/null
+++ b/tensorflow/python/training/coordinator_test.py
@@ -0,0 +1,98 @@
+"""Tests for Coordinator."""
+import sys
+import threading
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+def StopInN(coord, n_secs):
+ time.sleep(n_secs)
+ coord.request_stop()
+
+
+def RaiseInN(coord, n_secs, ex, report_exception):
+ try:
+ time.sleep(n_secs)
+ raise ex
+ except RuntimeError, e:
+ if report_exception:
+ coord.request_stop(e)
+ else:
+ coord.request_stop(sys.exc_info())
+
+
+def SleepABit(n_secs):
+ time.sleep(n_secs)
+
+
+class CoordinatorTest(tf.test.TestCase):
+
+ def testStopAPI(self):
+ coord = tf.train.Coordinator()
+ self.assertFalse(coord.should_stop())
+ self.assertFalse(coord.wait_for_stop(0.01))
+ coord.request_stop()
+ self.assertTrue(coord.should_stop())
+ self.assertTrue(coord.wait_for_stop(0.01))
+
+ def testStopAsync(self):
+ coord = tf.train.Coordinator()
+ self.assertFalse(coord.should_stop())
+ self.assertFalse(coord.wait_for_stop(0.1))
+ threading.Thread(target=StopInN, args=(coord, 0.02)).start()
+ self.assertFalse(coord.should_stop())
+ self.assertFalse(coord.wait_for_stop(0.01))
+ self.assertTrue(coord.wait_for_stop(0.03))
+ self.assertTrue(coord.should_stop())
+
+ def testJoin(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=SleepABit, args=(0.01,)),
+ threading.Thread(target=SleepABit, args=(0.02,)),
+ threading.Thread(target=SleepABit, args=(0.01,))]
+ for t in threads:
+ t.start()
+ coord.join(threads)
+
+ def testJoinGraceExpires(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=StopInN, args=(coord, 0.01)),
+ threading.Thread(target=SleepABit, args=(10.0,))]
+ for t in threads:
+ t.daemon = True
+ t.start()
+ with self.assertRaisesRegexp(RuntimeError, "threads still running"):
+ coord.join(threads, stop_grace_period_secs=0.02)
+
+ def testJoinRaiseReportExcInfo(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.01, RuntimeError("First"), False)),
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.02, RuntimeError("Too late"), False))]
+ for t in threads:
+ t.start()
+ with self.assertRaisesRegexp(RuntimeError, "First"):
+ coord.join(threads)
+
+ def testJoinRaiseReportException(self):
+ coord = tf.train.Coordinator()
+ threads = [
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.01, RuntimeError("First"), True)),
+ threading.Thread(target=RaiseInN,
+ args=(coord, 0.02, RuntimeError("Too late"), True))]
+ for t in threads:
+ t.start()
+ with self.assertRaisesRegexp(RuntimeError, "First"):
+ coord.join(threads)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
new file mode 100644
index 0000000000..6b9471a5ed
--- /dev/null
+++ b/tensorflow/python/training/ftrl.py
@@ -0,0 +1,283 @@
+"""FTRL-Proximal for Tensor Flow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import optimizer
+
+
+def _Solve(a, b, c):
+ """Return solution of a quadratic minimization.
+
+ The optimization equation is:
+ f(a, b, c) = argmin_w{1/2 * a * w^2 + b * w + c * |w|}
+  whose optimal solution w* is:
+ w* = -(b - sign(b)*c)/a if |b| > c else w* = 0
+
+  REQUIRES: Dimensionality of a and b must be the same.
+
+ Args:
+ a: A Tensor
+ b: A Tensor
+ c: A Tensor with one element.
+
+ Returns:
+ A Tensor w, which is solution for the equation
+ """
+ with ops.name_scope("solve_" + b.op.name):
+ c = ops.convert_to_tensor(c)
+ k = array_ops.fill(array_ops.shape(b), c)
+ zero_t = array_ops.zeros(array_ops.shape(b), dtype=b.dtype)
+ w = (c * math_ops.sign(b) - b) / a
+ w = math_ops.select(math_ops.less(math_ops.abs(b), k), zero_t, w)
+ return w
+
+
+def _Compute(accum, linear, base_lr, lr_power, l1, l2):
+ """Compute "variable" given current "accum" and "linear".
+
+ REQUIRES: Dimensionality of accum and linear must be same.
+
+ Args:
+    accum: A Tensor which is the accumulated squared gradient.
+    linear: A Tensor with the same size as accum.
+    base_lr: A Tensor which is the base learning rate.
+    lr_power: A Tensor which is the learning rate power.
+    l1: A Tensor which is the l1 regularization strength.
+    l2: A Tensor which is the l2 regularization strength.
+ Returns:
+ A Tensor which is "variable" after update
+ """
+ with ops.name_scope("compute_" + accum.op.name):
+ one_t = constant_op.constant(1.0, dtype=types.float32)
+ two_t = constant_op.constant(2.0, dtype=types.float32)
+ learning_rate = math_ops.pow(accum, lr_power) * base_lr
+ quadratic = one_t / learning_rate + two_t * l2
+ w = _Solve(quadratic, linear, l1)
+ return w
+
+
+def _Update(variable, gradients, accum, linear, base_lr, lr_power, l1, l2):
+ """Update "variable", "accum", "linear" based on "gradients".
+
+ Some notations here: "variable" as W, "accum" as N, "linear" as Z,
+ "gradients" as G, N(t) means "accum" at t-step.
+ Assuming lr_power = -0.5 which means using adagrad learning rate.
+ "accum" updates as: N = N + G^2
+ "linear" updates as: Z = Z + G - W * (sqrt(N(t)) - sqrt(N(t-1)))/base_lr
+ REQUIRES: Dimensionality of variable, gradients, accum and linear
+ must be same.
+
+ Args:
+ variable: A Variable.
+    gradients: A Tensor of the same shape as 'variable'.
+    accum: A Variable containing the sum of the squares of gradients.
+    linear: A Variable containing approximation info.
+    base_lr: A constant representing the base learning rate.
+    lr_power: A constant used to adjust the learning rate.
+    l1: A constant representing the l1 regularization strength.
+    l2: A constant representing the l2 regularization strength.
+
+ Returns:
+ A group op including three Assign ops:
+ 1. Assign for "accum"
+ 2. Assign for "linear"
+ 3. Assign for "variable"
+ """
+ dtype = variable.dtype.base_dtype
+ base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
+ lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
+ l1 = ops.convert_to_tensor(l1, dtype=dtype)
+ l2 = ops.convert_to_tensor(l2, dtype=dtype)
+ # Compute the new accumulator
+ sqr_grad = math_ops.square(gradients)
+ accum_updated = sqr_grad + accum
+ # Compute the new linear
+ neg_lr_power = math_ops.neg(lr_power)
+ sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
+ accum, neg_lr_power)
+ sigma /= base_lr
+ proximal_adjust = sigma * variable
+ linear_updated = linear + gradients - proximal_adjust
+ # Compute the "variable"
+ variable_updated = _Compute(accum_updated, linear_updated, base_lr,
+ lr_power, l1, l2)
+
+ with ops.control_dependencies([sigma]):
+ accum_update_op = state_ops.assign(accum, accum_updated)
+ linear_update_op = state_ops.assign(linear, linear_updated)
+ variable_update_op = state_ops.assign(variable, variable_updated)
+ group_op = control_flow_ops.group(linear_update_op, accum_update_op,
+ variable_update_op)
+ return group_op
+
+
+# TODO(xbing): Refactor code to make _SparseUpdate and _Update share
+# common routines.
+def _SparseUpdate(variable, gradients, accum, linear, base_lr,
+ lr_power, l1, l2):
+ """Sparse Update "variable", "accum", "linear" based on sparse "gradients".
+
+ See the description in _Update.
+
+ Args:
+ variable: A Variable.
+    gradients: An ops.IndexedSlices containing the sparse gradients.
+    accum: A Variable containing the sum of the squares of gradients.
+    linear: A Variable containing approximation info.
+    base_lr: A constant representing the base learning rate.
+    lr_power: A constant used to adjust the learning rate.
+    l1: A constant representing the l1 regularization strength.
+    l2: A constant representing the l2 regularization strength.
+
+ Returns:
+ A group op including three ScatterUpdate ops:
+ 1. ScatterUpdate for "accum"
+ 2. ScatterUpdate for "linear"
+ 3. ScatterUpdate for "variable"
+ """
+ assert isinstance(gradients, ops.IndexedSlices)
+ with ops.name_scope("sparse_update_" + variable.op.name) as scope:
+ dtype = variable.dtype.base_dtype
+ base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
+ lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
+ l1 = ops.convert_to_tensor(l1, dtype=dtype)
+ l2 = ops.convert_to_tensor(l2, dtype=dtype)
+
+ # Compute the new value for the accumulator
+ previous_accum = array_ops.gather(accum, gradients.indices)
+ sqr_grad = gradients.values * gradients.values
+ accum_updated = sqr_grad + previous_accum
+
+ # Compute the new linear
+ neg_lr_power = math_ops.neg(lr_power)
+ sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
+ previous_accum, neg_lr_power)
+ sigma /= base_lr
+ variable_slice = array_ops.gather(variable, gradients.indices)
+ proximal_adjust = sigma * variable_slice
+ linear_slice = array_ops.gather(linear, gradients.indices)
+ linear_updated = linear_slice + gradients.values - proximal_adjust
+
+ # Compute the new "variable"
+ variable_updated = _Compute(accum_updated, linear_updated, base_lr,
+ lr_power, l1, l2)
+
+ with ops.control_dependencies([sigma]):
+ accum_update_op = state_ops.scatter_update(accum, gradients.indices,
+ accum_updated)
+ linear_update_op = state_ops.scatter_update(linear, gradients.indices,
+ linear_updated)
+ variable_update_op = state_ops.scatter_update(variable, gradients.indices,
+ variable_updated)
+ group_op = control_flow_ops.group(linear_update_op, accum_update_op,
+ variable_update_op, name=scope)
+ return group_op
+
+
+class FtrlOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the FTRL algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate,
+ learning_rate_power=-0.5,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0,
+ use_locking=False, name="Ftrl"):
+ """Construct a new FTRL optimizer.
+
+    The Ftrl-proximal algorithm, short for Follow-The-Regularized-Leader
+    (proximal), is described in the paper
+    [Ad Click Prediction: a View from the Trenches](
+    https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
+
+ It can give a good performance vs. sparsity tradeoff.
+
+ Ftrl-proximal uses its own global base learning rate and can behave like
+ Adagrad with `learning_rate_power=-0.5`, or like gradient descent with
+ `learning_rate_power=0.0`.
+
+ The effective learning rate is adjusted per parameter, relative to this
+ base learning rate as:
+
+ ```
+ effective_learning_rate_i = (learning_rate /
+ pow(k + summed_squared_gradients_for_i, learning_rate_power));
+ ```
+
+ where k is the small constant `initial_accumulator_value`.
+
+    Note that if you pass `l2 = lambda_2` to this optimizer, the effective
+    regularization coefficient on `|w|^2` in the objective function is
+    `1 / lambda_2`.
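+
+    A minimal usage sketch (the `loss` tensor here is a placeholder for a
+    float32 scalar defined elsewhere; it is not part of this API):
+
+    ```python
+    opt = tf.train.FtrlOptimizer(learning_rate=0.1,
+                                 l1_regularization_strength=0.001)
+    train_op = opt.minimize(loss)
+    ```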
+
+ Args:
+ learning_rate: A float value or a constant float `Tensor`.
+      learning_rate_power: A float value, must be less than or equal to
+        zero.
+ initial_accumulator_value: The starting value for accumulators.
+ Only positive values are allowed.
+ l1_regularization_strength: A float value, must be greater than or
+ equal to zero.
+ l2_regularization_strength: A float value, must be greater than or
+ equal to zero.
+ use_locking: If `True` use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Ftrl".
+
+ Raises:
+ ValueError: if one of the arguments is invalid.
+ """
+ super(FtrlOptimizer, self).__init__(use_locking, name)
+
+ if initial_accumulator_value <= 0.0:
+ raise ValueError("initial_accumulator_value %f needs to be positive" %
+ initial_accumulator_value)
+ if learning_rate_power > 0.0:
+ raise ValueError("learning_rate_power %f needs to be negative or zero" %
+ learning_rate_power)
+ if l1_regularization_strength < 0.0:
+ raise ValueError(
+ "l1_regularization_strength %f needs to be positive or zero" %
+ l1_regularization_strength)
+ if l2_regularization_strength < 0.0:
+ raise ValueError(
+ "l2_regularization_strength %f needs to be positive or zero" %
+ l2_regularization_strength)
+
+ self._learning_rate = learning_rate
+ self._learning_rate_power = learning_rate_power
+ self._initial_accumulator_value = initial_accumulator_value
+ self._l1_regularization_strength = l1_regularization_strength
+ self._l2_regularization_strength = l2_regularization_strength
+
+ def _create_slots(self, var_list):
+ # Create the "accum" and "linear" slots.
+ for v in var_list:
+ self._get_or_make_slot(
+ v,
+ constant_op.constant(self._initial_accumulator_value,
+ dtype=v.dtype, shape=v.get_shape()),
+ "accum",
+ self._name)
+ self._zeros_slot(v, "linear", self._name)
+
+ def _apply_dense(self, grad, var):
+ accum = self.get_slot(var, "accum")
+ linear = self.get_slot(var, "linear")
+ return _Update(var, grad, accum, linear,
+ self._learning_rate, self._learning_rate_power,
+ self._l1_regularization_strength,
+ self._l2_regularization_strength)
+
+ def _apply_sparse(self, grad, var):
+ accum = self.get_slot(var, "accum")
+ linear = self.get_slot(var, "linear")
+ return _SparseUpdate(var, grad, accum, linear,
+ self._learning_rate, self._learning_rate_power,
+ self._l1_regularization_strength,
+ self._l2_regularization_strength)
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
new file mode 100644
index 0000000000..eb581048f1
--- /dev/null
+++ b/tensorflow/python/training/ftrl_test.py
@@ -0,0 +1,234 @@
+"""Functional tests for Ftrl operations."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class FtrlOptimizerTest(tf.test.TestCase):
+
+ def testFtrlwithoutRegularization(self):
+ with self.test_session() as sess:
+ var0 = tf.Variable([0.0, 0.0])
+ var1 = tf.Variable([0.0, 0.0])
+ grads0 = tf.constant([0.1, 0.2])
+ grads1 = tf.constant([0.01, 0.02])
+ opt = tf.train.FtrlOptimizer(3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose([0.0, 0.0], v0_val)
+ self.assertAllClose([0.0, 0.0], v1_val)
+
+ # Run 3 steps FTRL
+ for _ in range(3):
+ update.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose(np.array([-2.60260963, -4.29698515]),
+ v0_val)
+ self.assertAllClose(np.array([-0.28432083, -0.56694895]),
+ v1_val)
+
+ def testFtrlwithoutRegularization2(self):
+ with self.test_session() as sess:
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([4.0, 3.0])
+ grads0 = tf.constant([0.1, 0.2])
+ grads1 = tf.constant([0.01, 0.02])
+
+ opt = tf.train.FtrlOptimizer(3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose([1.0, 2.0], v0_val)
+ self.assertAllClose([4.0, 3.0], v1_val)
+
+ # Run 3 steps FTRL
+ for _ in range(3):
+ update.run()
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose(np.array([-2.55607247, -3.98729396]),
+ v0_val)
+ self.assertAllClose(np.array([-0.28232238, -0.56096673]),
+ v1_val)
+
+ def testFtrlWithL1(self):
+ with self.test_session() as sess:
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([4.0, 3.0])
+ grads0 = tf.constant([0.1, 0.2])
+ grads1 = tf.constant([0.01, 0.02])
+
+ opt = tf.train.FtrlOptimizer(3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose([1.0, 2.0], v0_val)
+ self.assertAllClose([4.0, 3.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose(np.array([-7.66718769, -10.91273689]),
+ v0_val)
+ self.assertAllClose(np.array([-0.93460727, -1.86147261]),
+ v1_val)
+
+ def testFtrlWithL1_L2(self):
+ with self.test_session() as sess:
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([4.0, 3.0])
+ grads0 = tf.constant([0.1, 0.2])
+ grads1 = tf.constant([0.01, 0.02])
+
+ opt = tf.train.FtrlOptimizer(3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose([1.0, 2.0], v0_val)
+ self.assertAllClose([4.0, 3.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllClose(np.array([-0.24059935, -0.46829352]),
+ v0_val)
+ self.assertAllClose(np.array([-0.02406147, -0.04830509]),
+ v1_val)
+
+ def applyOptimizer(self, opt, steps=5, is_sparse=False):
+ if is_sparse:
+ var0 = tf.Variable([[0.0], [0.0]])
+ var1 = tf.Variable([[0.0], [0.0]])
+ grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
+ tf.constant([0]),
+ tf.constant([2, 1]))
+ grads1 = tf.IndexedSlices(tf.constant([0.02], shape=[1, 1]),
+ tf.constant([1]),
+ tf.constant([2, 1]))
+ else:
+ var0 = tf.Variable([0.0, 0.0])
+ var1 = tf.Variable([0.0, 0.0])
+ grads0 = tf.constant([0.1, 0.2])
+ grads1 = tf.constant([0.01, 0.02])
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ sess = tf.get_default_session()
+ v0_val, v1_val = sess.run([var0, var1])
+ if is_sparse:
+ self.assertAllClose([[0.0], [0.0]], v0_val)
+ self.assertAllClose([[0.0], [0.0]], v1_val)
+ else:
+ self.assertAllClose([0.0, 0.0], v0_val)
+ self.assertAllClose([0.0, 0.0], v1_val)
+
+ # Run Ftrl for a few steps
+ for _ in range(steps):
+ update.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ return v0_val, v1_val
+
+  # When variables are initialized to zero, FTRL-Proximal has two properties:
+  # 1. Without L1/L2 but with a fixed learning rate, FTRL-Proximal is
+  # identical to GradientDescent.
+  # 2. Without L1/L2 but with an adaptive learning rate, FTRL-Proximal is
+  # identical to Adagrad.
+  # Based on these two properties, the tests below check that our
+  # implementation of FTRL-Proximal performs the same updates as Adagrad or
+  # GradientDescent.
+ def testEquivAdagradwithoutRegularization(self):
+ with self.test_session():
+ val0, val1 = self.applyOptimizer(
+ tf.train.FtrlOptimizer(3.0,
+ # Adagrad learning rate
+ learning_rate_power=-0.5,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0))
+
+ with self.test_session():
+ val2, val3 = self.applyOptimizer(
+ tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1))
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+ def testEquivSparseAdagradwithoutRegularization(self):
+ with self.test_session():
+ val0, val1 = self.applyOptimizer(
+ tf.train.FtrlOptimizer(3.0,
+ # Adagrad learning rate
+ learning_rate_power=-0.5,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0),
+ is_sparse=True)
+
+ with self.test_session():
+ val2, val3 = self.applyOptimizer(
+ tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
+ is_sparse=True)
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+  def testEquivSparseGradientDescentwithoutRegularization(self):
+ with self.test_session():
+ val0, val1 = self.applyOptimizer(
+ tf.train.FtrlOptimizer(3.0,
+ # Fixed learning rate
+ learning_rate_power=-0.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0),
+ is_sparse=True)
+
+ with self.test_session():
+ val2, val3 = self.applyOptimizer(
+ tf.train.GradientDescentOptimizer(3.0), is_sparse=True)
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+  def testEquivGradientDescentwithoutRegularization(self):
+ with self.test_session():
+ val0, val1 = self.applyOptimizer(
+ tf.train.FtrlOptimizer(3.0,
+ # Fixed learning rate
+ learning_rate_power=-0.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0))
+
+ with self.test_session():
+ val2, val3 = self.applyOptimizer(
+ tf.train.GradientDescentOptimizer(3.0))
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
new file mode 100644
index 0000000000..21247aacf3
--- /dev/null
+++ b/tensorflow/python/training/gradient_descent.py
@@ -0,0 +1,44 @@
+"""GradientDescent for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+# pylint: disable=unused-import
+from tensorflow.python.ops import math_ops
+# pylint: enable=unused-import
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class GradientDescentOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the gradient descent algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
+ """Construct a new gradient descent optimizer.
+
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning
+ rate to use.
+      use_locking: If True, use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "GradientDescent".
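+
+    A minimal usage sketch (the `loss` tensor is assumed to be a float32
+    scalar defined elsewhere; shown for illustration only):
+
+    ```python
+    opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+    grads_and_vars = opt.compute_gradients(loss)
+    train_op = opt.apply_gradients(grads_and_vars)
+    ```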
+ """
+ super(GradientDescentOptimizer, self).__init__(use_locking, name)
+ self._learning_rate = learning_rate
+
+ def _apply_dense(self, grad, var):
+ return training_ops.apply_gradient_descent(
+ var,
+ self._learning_rate_tensor,
+ grad,
+ use_locking=self._use_locking).op
+
+ def _apply_sparse(self, grad, var):
+ delta = ops.IndexedSlices(grad.values * self._learning_rate_tensor,
+ grad.indices, grad.dense_shape)
+ return var.scatter_sub(delta, use_locking=self._use_locking)
+
+ def _prepare(self):
+ self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+ name="learning_rate")
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
new file mode 100644
index 0000000000..d5b0cae401
--- /dev/null
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -0,0 +1,105 @@
+"""Functional test for GradientDescent."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class GradientDescentOptimizerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval())
+ self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
+
+ def testFloat64(self):
+ with self.test_session():
+ opt = tf.train.GradientDescentOptimizer(3.0)
+
+ # compute_gradients.
+ values = [1.0, 3.0]
+ good_vars = [tf.Variable([v]) for v in values]
+ bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+ opt.compute_gradients, bad_loss, good_vars)
+ bad_vars = [
+ tf.Variable(np.array([v], np.float64), name="bad_var")
+ for v in values]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+ bad_vars)
+ opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+ # apply_gradients.
+ bad_grads = [
+ tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+ tf.constant([0.01])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+ opt.apply_gradients, zip(bad_grads, good_vars))
+ good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.apply_gradients, zip(good_grads, bad_vars))
+ opt.apply_gradients(zip(good_grads, good_vars))
+
+ def testWithGlobalStep(self):
+ with self.test_session():
+ global_step = tf.Variable(0, trainable=False)
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params and global_step
+ self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval())
+ self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
+ self.assertAllClose(1, global_step.eval())
+
+ def testSparseBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([[1.0], [2.0]])
+ var1 = tf.Variable([[3.0], [4.0]])
+ grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
+ tf.constant([0]),
+ tf.constant([2, 1]))
+ grads1 = tf.IndexedSlices(tf.constant([0.01], shape=[1, 1]),
+ tf.constant([1]),
+ tf.constant([2, 1]))
+ sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllClose([[1.0 - 3.0 * 0.1], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0 - 3.0 * 0.01]], var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
new file mode 100644
index 0000000000..413fc044f7
--- /dev/null
+++ b/tensorflow/python/training/input.py
@@ -0,0 +1,501 @@
+"""## Input pipeline
+
+TensorFlow functions for setting up an input-prefetching pipeline.
+Please see the [reading data how-to](../../how_tos/reading_data.md)
+for context.
+
+### Beginning of an input pipeline
+
+The "producer" functions add a queue to the graph and a corresponding
+QueueRunner for running the subgraph that fills that queue.
+
+@@match_filenames_once
+@@limit_epochs
+@@range_input_producer
+@@slice_input_producer
+@@string_input_producer
+
+### Batching at the end of an input pipeline
+
+These functions add a queue to the graph to assemble a batch of
+examples, with possible shuffling. They also add a QueueRunner for
+running the subgraph that fills that queue.
+
+Use [batch](#batch) or [batch_join](#batch_join) for batching examples that have
+already been well shuffled. Use [shuffle_batch](#shuffle_batch) or
+[shuffle_batch_join](#shuffle_batch_join) for examples that
+would benefit from additional shuffling.
+
+Use [batch](#batch) or [shuffle_batch](#shuffle_batch) if you want a
+single thread producing examples to batch, or if you have a
+single subgraph producing examples but you want to run it in N threads
+(where you increase N until it can keep the queue full). Use
+[batch_join](#batch_join) or [shuffle_batch_join](#shuffle_batch_join)
+if you have N different subgraphs producing examples to batch and you
+want them run by N threads.
+
+@@batch
+@@batch_join
+@@shuffle_batch
+@@shuffle_batch_join
+
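+### Example
+
+A minimal end-to-end sketch, assuming a hypothetical `read_example` function
+that turns a filename queue into a single `(image, label)` pair (the reader
+itself is not shown here):
+
+```python
+filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
+image, label = read_example(filename_queue)   # hypothetical reader
+image_batch, label_batch = tf.train.batch([image, label], batch_size=32)
+```
+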
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import summary_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import queue_runner
+
+
+def match_filenames_once(pattern, name=None):
+ """Save the list of files matching pattern, so it is only computed once.
+
+ Args:
+ pattern: A file pattern (glob).
+ name: A name for the operations (optional).
+
+ Returns:
+ A variable that is initialized to the list of files matching pattern.
+ """
+ with ops.op_scope([pattern], name, "matching_filenames") as name:
+ return variables.Variable(io_ops.matching_files(pattern), trainable=False,
+ name=name, validate_shape=False)
+
+
+def limit_epochs(tensor, num_epochs=None, name=None):
+ """Returns tensor num_epochs times and then raises an OutOfRange error.
+
+ Args:
+ tensor: Any Tensor.
+ num_epochs: An integer (optional). If specified, limits the number
+      of times the output tensor may be evaluated.
+ name: A name for the operations (optional).
+
+ Returns:
+    tensor, which can be evaluated at most num_epochs times before an
+    OutOfRange error is raised.
+ """
+ if num_epochs is None:
+ return tensor
+ if num_epochs <= 0:
+ raise ValueError("num_epochs must be > 0 not %d." % num_epochs)
+ with ops.op_scope([tensor], name, "limit_epochs") as name:
+ zero64 = constant_op.constant(0, dtype=types.int64)
+ epochs = variables.Variable(zero64, name="epochs")
+ counter = epochs.count_up_to(num_epochs)
+ with ops.control_dependencies([counter]):
+ return array_ops.identity(tensor, name=name)
+
+
+def _input_producer(input_tensor, dtype, num_epochs, shuffle, seed, capacity,
+ name, summary_name):
+ if shuffle:
+ input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
+ input_tensor = limit_epochs(input_tensor, num_epochs)
+
+ q = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=[dtype], shapes=[[]],
+ name=name)
+ enq = q.enqueue_many([input_tensor])
+ queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq]))
+ summary_ops.scalar_summary("queue/%s/%s" % (q.name, summary_name),
+ math_ops.cast(q.size(), types.float32) *
+ (1. / capacity))
+ return q
+
+
+def string_input_producer(string_tensor, num_epochs=None, shuffle=True,
+ seed=None, capacity=32, name=None):
+ """Output strings (e.g. filenames) to a queue for an input pipeline.
+
+ Args:
+ string_tensor: A 1-D string tensor with the strings to produce.
+ num_epochs: An integer (optional). If specified, `string_input_producer`
+ produces each string from `string_tensor` `num_epochs` times before
+ generating an OutOfRange error. If not specified, `string_input_producer`
+ can cycle through the strings in `string_tensor` an unlimited number of
+ times.
+ shuffle: Boolean. If true, the strings are randomly shuffled within each
+ epoch.
+ seed: An integer (optional). Seed used if shuffle == True.
+ capacity: An integer. Sets the queue capacity.
+ name: A name for the operations (optional).
+
+ Returns:
+ A queue with the output strings. A QueueRunner for the Queue
+ is added to the current Graph's QUEUE_RUNNER collection.
+ """
+ with ops.op_scope([string_tensor], name, "input_producer") as name:
+ return _input_producer(
+ string_tensor, types.string, num_epochs, shuffle, seed, capacity, name,
+ "fraction_of_%d_full" % capacity)
+
+
+def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
+ capacity=32, name=None):
+ """Produces the integers from 0 to limit-1 in a queue.
+
+ Args:
+ limit: An int32 scalar tensor.
+ num_epochs: An integer (optional). If specified, `range_input_producer`
+ produces each integer `num_epochs` times before generating an
+ OutOfRange error. If not specified, `range_input_producer` can cycle
+ through the integers an unlimited number of times.
+ shuffle: Boolean. If true, the integers are randomly shuffled within each
+ epoch.
+ seed: An integer (optional). Seed used if shuffle == True.
+ capacity: An integer. Sets the queue capacity.
+ name: A name for the operations (optional).
+
+ Returns:
+ A Queue with the output integers. A QueueRunner for the Queue
+ is added to the current Graph's QUEUE_RUNNER collection.
+ """
+ with ops.op_scope([limit], name, "input_producer") as name:
+ range_tensor = math_ops.range(0, limit)
+ return _input_producer(
+ range_tensor, types.int32, num_epochs, shuffle, seed, capacity, name,
+ "fraction_of_%d_full" % capacity)
+
+
+def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
+ capacity=32, name=None):
+ """Produces a slice of each Tensor in tensor_list.
+
+ Implemented using a Queue -- a QueueRunner for the Queue
+ is added to the current Graph's QUEUE_RUNNER collection.
+
+ Args:
+ tensor_list: A list of Tensors. Every Tensor in tensor_list must
+ have the same size in the first dimension.
+ num_epochs: An integer (optional). If specified, `slice_input_producer`
+ produces each slice `num_epochs` times before generating
+ an OutOfRange error. If not specified, `slice_input_producer` can cycle
+ through the slices an unlimited number of times.
+    shuffle: Boolean. If true, the slices are randomly shuffled within each
+      epoch.
+    seed: An integer (optional). Seed used if shuffle == True.
+ capacity: An integer. Sets the queue capacity.
+ name: A name for the operations (optional).
+
+ Returns:
+ A list of tensors, one for each element of tensor_list. If the tensor
+ in tensor_list has shape [N, a, b, .., z], then the corresponding output
+ tensor will have shape [a, b, ..., z].
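+
+  For example, a minimal sketch (the `images` and `labels` tensors are
+  hypothetical and only illustrate the call):
+
+  ```python
+  # `images` has shape [N, 28, 28] and `labels` has shape [N]; each dequeue
+  # yields one matching (image, label) pair.
+  image, label = tf.train.slice_input_producer([images, labels],
+                                               num_epochs=10)
+  ```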
+ """
+ with ops.op_scope(tensor_list, name, "input_producer"):
+ tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
+ if not tensor_list:
+ raise ValueError(
+ "Expected at least one tensor in slice_input_producer().")
+ range_size = array_ops.shape(tensor_list[0])[0]
+ # TODO(josh11b): Add an assertion that the first dimension of
+ # everything in TensorList matches. Maybe just check the inferred shapes?
+ queue = range_input_producer(range_size, num_epochs=num_epochs,
+ shuffle=shuffle, seed=seed, capacity=capacity)
+ index = queue.dequeue()
+ output = [array_ops.gather(t, index) for t in tensor_list]
+ return output
+
+
+# Helpers for the batching functions ------------------------------------------
+
+def _flatten(tensor_list_list):
+ return [tensor for tensor_list in tensor_list_list for tensor in tensor_list]
+
+
+def _validate(tensor_list):
+ tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
+ if not tensor_list:
+ raise ValueError("Expected at least one tensor in batch().")
+ return tensor_list
+
+
+def _validate_join(tensor_list_list):
+ tensor_list_list = [ops.convert_n_to_tensor_or_indexed_slices(tl)
+ for tl in tensor_list_list]
+ if not tensor_list_list:
+ raise ValueError("Expected at least one input in batch_join().")
+ return tensor_list_list
+
+
+def _dtypes(tensor_list_list):
+ all_dtypes = [[t.dtype for t in tl] for tl in tensor_list_list]
+ dtypes = all_dtypes[0]
+ for other_dtypes in all_dtypes[1:]:
+ if other_dtypes != dtypes:
+      raise TypeError("Expected types to be consistent: %s vs. %s." %
+                      (", ".join(x.name for x in dtypes),
+                       ", ".join(x.name for x in other_dtypes)))
+ return dtypes
+
+
+def _merge_shapes(shape_list, enqueue_many):
+ shape_list = [tensor_shape.as_shape(s) for s in shape_list]
+ if enqueue_many:
+ # We want the shapes without the leading batch dimension.
+ shape_list = [s.WithRankAtLeast(1)[1:] for s in shape_list]
+ merged_shape = shape_list[0]
+ for s in shape_list[1:]:
+ merged_shape.merge_with(s)
+ return merged_shape.as_list()
+
+
+def _shapes(tensor_list_list, shapes, enqueue_many):
+ if shapes is None:
+ l = len(tensor_list_list[0])
+ shapes = [_merge_shapes([tl[i].get_shape().as_list()
+ for tl in tensor_list_list],
+ enqueue_many) for i in range(l)]
+ return shapes
+
+
+def _enqueue_join(queue, tensor_list_list, enqueue_many):
+ if enqueue_many:
+ enqueue_ops = [queue.enqueue_many(tl) for tl in tensor_list_list]
+ else:
+ enqueue_ops = [queue.enqueue(tl) for tl in tensor_list_list]
+ queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
+
+
+def _enqueue(queue, tensor_list, threads, enqueue_many):
+ if enqueue_many:
+ enqueue_ops = [queue.enqueue_many(tensor_list)] * threads
+ else:
+ enqueue_ops = [queue.enqueue(tensor_list)] * threads
+ queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
+
+
+# Batching functions ----------------------------------------------------------
+
+def batch(tensor_list, batch_size, num_threads=1, capacity=32,
+ enqueue_many=False, shapes=None, name=None):
+ """Run tensor_list to fill a queue to create batches.
+
+ Implemented using a queue -- a QueueRunner for the queue
+ is added to the current Graph's QUEUE_RUNNER collection.
+
+ Args:
+ tensor_list: The list of tensors to enqueue.
+ batch_size: The new batch size pulled from the queue.
+ num_threads: The number of threads enqueuing tensor_list.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead prefetching is allowed to get, and therefore memory usage.
+ enqueue_many: If False, tensor_list is assumed to represent a
+ single example. If True, tensor_list is assumed to represent
+ a batch of examples, where the first dimension is indexed by
+ example, and all members of tensor_list should have the same
+ size in the first dimension.
+ shapes: Optional. The shapes for each example. Defaults to the
+ inferred shapes for tensor_list (leaving off the first dimension
+ if enqueue_many is True).
+ name: A name for the operations (optional).
+
+ Returns:
+ A list of tensors with the same number and types as tensor_list.
+ If enqueue_many is false, then an input tensor with shape
+ `[x, y, z]` will be output as a tensor with shape
+ `[batch_size, x, y, z]`. If enqueue_many is True, and an
+    input tensor has shape `[*, x, y, z]`, then the output will have
+ shape `[batch_size, x, y, z]`.
+ """
+ with ops.op_scope(tensor_list, name, "batch") as name:
+ tensor_list = _validate(tensor_list)
+ dtypes = _dtypes([tensor_list])
+ shapes = _shapes([tensor_list], shapes, enqueue_many)
+ # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
+ queue = data_flow_ops.FIFOQueue(
+ capacity=capacity, dtypes=dtypes, shapes=shapes)
+ _enqueue(queue, tensor_list, num_threads, enqueue_many)
+ summary_ops.scalar_summary(
+ "queue/%s/fraction_of_%d_full" % (queue.name, capacity),
+ math_ops.cast(queue.size(), types.float32) * (1. / capacity))
+ return queue.dequeue_many(batch_size, name=name)
+
+
+# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be
+# a multiple of len(tensor_list_list)?) parameter, to address the use
+# case where you want more parallelism than you can support different
+# readers (either because you don't have that many files or can't
+# read that many files in parallel due to the number of seeks required).
+# Once this is done, batch() can be written as a call to batch_join().
+def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
+ shapes=None, name=None):
+ """Run a list of tensors to fill a queue to create batches of examples.
+
+ This version enqueues a different list of tensors in different threads.
+ Implemented using a queue -- a QueueRunner for the queue
+ is added to the current Graph's QUEUE_RUNNER collection.
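+
+  For example, a sketch joining the outputs of two readers (`read_example`
+  is a hypothetical function returning a single `[example, label]` pair from
+  a filename queue):
+
+  ```python
+  example_lists = [read_example(filename_queue) for _ in range(2)]
+  example_batch, label_batch = tf.train.batch_join(example_lists,
+                                                   batch_size=32)
+  ```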
+
+ Args:
+ tensor_list_list: A list of tuples of tensors to enqueue.
+ len(tensor_list_list) threads will be started, with the i-th
+      thread enqueuing the tensors from tensor_list_list[i].
+      tensor_list_list[i1][j] must match tensor_list_list[i2][j] in type
+      and shape (except in the first dimension if enqueue_many is true).
+ batch_size: The new batch size pulled from the queue.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead prefetching is allowed to get, and therefore memory usage.
+ enqueue_many: If False, each tensor_list_list[i] is assumed to
+ represent a single example. If True, tensor_list_list[i] is
+ assumed to represent a batch of examples, where the first
+ dimension is indexed by example, and all members of
+ tensor_list_list[i] should have the same size in the first
+ dimension.
+ shapes: Optional. The shapes for each example. Defaults to the
+ inferred shapes for tensor_list_list[i] (which must match, after
+ leaving off the first dimension if enqueue_many is True).
+ name: A name for the operations (optional).
+
+ Returns:
+ A list of tensors with the same number and types as
+ tensor_list_list[i]. If enqueue_many is false, then an input
+ tensor with shape `[x, y, z]` will be output as a tensor with
+ shape `[batch_size, x, y, z]`. If enqueue_many is True, and an
+    input tensor has shape `[*, x, y, z]`, then the output will have
+ shape `[batch_size, x, y, z]`.
+ """
+ with ops.op_scope(_flatten(tensor_list_list), name, "batch_join") as name:
+ tensor_list_list = _validate_join(tensor_list_list)
+ dtypes = _dtypes(tensor_list_list)
+ shapes = _shapes(tensor_list_list, shapes, enqueue_many)
+ # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
+ queue = data_flow_ops.FIFOQueue(
+ capacity=capacity, dtypes=dtypes, shapes=shapes)
+ _enqueue_join(queue, tensor_list_list, enqueue_many)
+ summary_ops.scalar_summary(
+ "queue/%s/fraction_of_%d_full" % (queue.name, capacity),
+ math_ops.cast(queue.size(), types.float32) * (1. / capacity))
+ return queue.dequeue_many(batch_size, name=name)
+
+
+def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
+ num_threads=1, seed=None, enqueue_many=False, shapes=None,
+ name=None):
+ """Create batches by randomly shuffling tensors.
+
+ This adds:
+
+  * a shuffling queue into which tensors from tensor_list are enqueued,
+  * a dequeue_many operation to create batches from the queue, and
+  * a QueueRunner, added to the current Graph's QUEUE_RUNNER collection,
+    to enqueue the tensors from tensor_list.
+
+ Args:
+ tensor_list: The list of tensors to enqueue.
+ batch_size: The new batch size pulled from the queue.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead prefetching is allowed to get, and therefore memory usage.
+    min_after_dequeue: Minimum number of elements in the queue after a
+      dequeue, used to ensure a level of mixing of elements.
+ num_threads: The number of threads enqueuing tensor_list.
+ seed: Seed for the random shuffling within the queue.
+ enqueue_many: If False, tensor_list is assumed to represent a
+ single example. If True, tensor_list is assumed to represent
+ a batch of examples, where the first dimension is indexed by
+ example, and all members of tensor_list should have the same
+ size in the first dimension.
+ shapes: Optional. The shapes for each example. Defaults to the
+ inferred shapes for tensor_list (leaving off the first dimension
+ if enqueue_many is True).
+ name: A name for the operations (optional).
+
+ Returns:
+ A list of tensors with the same number and types as tensor_list.
+ If enqueue_many is false, then an input tensor with shape
+ `[x, y, z]` will be output as a tensor with shape
+ `[batch_size, x, y, z]`. If enqueue_many is True, and an
+    input tensor has shape `[*, x, y, z]`, then the output will have
+ shape `[batch_size, x, y, z]`.
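+
+  A minimal sketch of a call (the `image` and `label` tensors are assumed to
+  come from a reader defined elsewhere; the numbers are illustrative only):
+
+  ```python
+  image_batch, label_batch = tf.train.shuffle_batch(
+      [image, label], batch_size=32, capacity=2000, min_after_dequeue=1000)
+  ```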
+ """
+ with ops.op_scope(tensor_list, name, "shuffle_batch") as name:
+ tensor_list = _validate(tensor_list)
+ dtypes = _dtypes([tensor_list])
+ shapes = _shapes([tensor_list], shapes, enqueue_many)
+ queue = data_flow_ops.RandomShuffleQueue(
+ capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
+ dtypes=dtypes, shapes=shapes)
+ _enqueue(queue, tensor_list, num_threads, enqueue_many)
+ full = (math_ops.cast(queue.size() - min_after_dequeue, types.float32) *
+ (1. / (capacity - min_after_dequeue)))
+ # Note that name contains a '/' at the end so we intentionally do not place
+ # a '/' after %s below.
+ summary_name = (
+ "queue/%sfraction_over_%d_of_%d_full" %
+ (name, min_after_dequeue, capacity - min_after_dequeue))
+ summary_ops.scalar_summary(summary_name, full)
+
+ return queue.dequeue_many(batch_size, name=name)
+
+
+def shuffle_batch_join(tensor_list_list, batch_size, capacity,
+ min_after_dequeue, seed=None, enqueue_many=False,
+ shapes=None, name=None):
+ """Create batches by randomly shuffling tensors.
+
+ This version enqueues a different list of tensors in different threads.
+ It adds:
+
+  * a shuffling queue into which tensors from tensor_list_list are enqueued,
+  * a dequeue_many operation to create batches from the queue, and
+  * a QueueRunner, added to the current Graph's QUEUE_RUNNER collection,
+    to enqueue the tensors from tensor_list_list.
+
+ Args:
+ tensor_list_list: A list of tuples of tensors to enqueue.
+ len(tensor_list_list) threads will be started, with the i-th
+      thread enqueuing the tensors from tensor_list_list[i].
+      tensor_list_list[i1][j] must match tensor_list_list[i2][j] in type
+      and shape (except in the first dimension if enqueue_many is true).
+ batch_size: The new batch size pulled from the queue.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead prefetching is allowed to get, and therefore memory usage.
+    min_after_dequeue: Minimum number of elements in the queue after a
+      dequeue, used to ensure a level of mixing of elements.
+ seed: Seed for the random shuffling within the queue.
+ enqueue_many: If False, each tensor_list_list[i] is assumed to
+ represent a single example. If True, tensor_list_list[i] is
+ assumed to represent a batch of examples, where the first
+ dimension is indexed by example, and all members of
+ tensor_list_list[i] should have the same size in the first
+ dimension.
+ shapes: Optional. The shapes for each example. Defaults to the
+ inferred shapes for tensor_list_list[i] (which must match, after
+ leaving off the first dimension if enqueue_many is True).
+ name: A name for the operations (optional).
+
+ Returns:
+ A list of tensors with the same number and types as
+ tensor_list_list[i]. If enqueue_many is false, then an input
+ tensor with shape `[x, y, z]` will be output as a tensor with
+ shape `[batch_size, x, y, z]`. If enqueue_many is True, and an
+    input tensor has shape `[*, x, y, z]`, then the output will have
+ shape `[batch_size, x, y, z]`.
+ """
+ with ops.op_scope(
+ _flatten(tensor_list_list), name, "shuffle_batch_join") as name:
+ tensor_list_list = _validate_join(tensor_list_list)
+ dtypes = _dtypes(tensor_list_list)
+ shapes = _shapes(tensor_list_list, shapes, enqueue_many)
+ queue = data_flow_ops.RandomShuffleQueue(
+ capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
+ dtypes=dtypes, shapes=shapes)
+ _enqueue_join(queue, tensor_list_list, enqueue_many)
+ full = (math_ops.cast(queue.size() - min_after_dequeue, types.float32) *
+ (1. / (capacity - min_after_dequeue)))
+ # Note that name contains a '/' at the end so we intentionally do not place
+ # a '/' after %s below.
+ summary_name = (
+ "queue/%sfraction_over_%d_of_%d_full" %
+ (name, min_after_dequeue, capacity - min_after_dequeue))
+ summary_ops.scalar_summary(summary_name, full)
+ return queue.dequeue_many(batch_size, name=name)
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
new file mode 100644
index 0000000000..fe8c195e77
--- /dev/null
+++ b/tensorflow/python/training/input_test.py
@@ -0,0 +1,477 @@
+"""Tests for training.input."""
+
+import os
+import itertools
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class MatchFilenamesOnceTest(tf.test.TestCase):
+
+ def test(self):
+ temp_dir = self.get_temp_dir()
+ filenames = [os.path.join(temp_dir, n) for n in os.listdir(temp_dir)]
+ additional = [os.path.join(self.get_temp_dir(), "match_filenames.%d" % i)
+ for i in range(3)]
+ for name in additional:
+ open(name, "w").write("Some contents")
+ filenames += additional
+ with self.test_session():
+ star = tf.train.match_filenames_once(
+ os.path.join(self.get_temp_dir(), "*"))
+ question = tf.train.match_filenames_once(
+ os.path.join(self.get_temp_dir(), "match_filenames.?"))
+ one = tf.train.match_filenames_once(additional[1])
+ tf.initialize_all_variables().run()
+ self.assertItemsEqual(filenames, star.eval())
+ self.assertItemsEqual(additional, question.eval())
+ self.assertItemsEqual([additional[1]], one.eval())
+
+
+class LimitEpochsTest(tf.test.TestCase):
+
+ def testNoLimit(self):
+ with self.test_session():
+ seven = tf.constant(7)
+ seven_forever = tf.train.limit_epochs(seven)
+ tf.initialize_all_variables().run()
+ for i in range(100):
+ self.assertEqual(7, seven_forever.eval())
+
+ def testLimit(self):
+ with self.test_session():
+ love_me = tf.constant("Love Me")
+ love_me_two_times = tf.train.limit_epochs(love_me, num_epochs=2)
+ tf.initialize_all_variables().run()
+ self.assertEqual("Love Me", love_me_two_times.eval())
+ self.assertEqual("Love Me", love_me_two_times.eval())
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ love_me_two_times.eval()
+
+
+class StringInputProducerTest(tf.test.TestCase):
+
+ def testNoShuffle(self):
+ with self.test_session():
+ strings = ["to", "be", "or", "not", "to", "be"]
+ num_epochs = 3
+ queue = tf.train.string_input_producer(
+ strings, num_epochs=num_epochs, shuffle=False)
+ dequeue_many = queue.dequeue_many(len(strings) * num_epochs)
+ dequeue = queue.dequeue()
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # No randomness, so just see repeated copies of the input.
+ output = dequeue_many.eval()
+ self.assertAllEqual(strings * num_epochs, output)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ dequeue.eval()
+ for thread in threads:
+ thread.join()
+
+ def testShuffle(self):
+ with self.test_session():
+ strings = ["a", "b", "c"]
+ num_epochs = 600
+ queue = tf.train.string_input_producer(
+ strings, num_epochs=num_epochs, shuffle=True, seed=271828)
+ dequeue_many = queue.dequeue_many(len(strings))
+ dequeue = queue.dequeue()
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # Validate that we only shuffle the strings within an epoch and
+ # count how often each possible order appears.
+ expected = ["abc", "acb", "bac", "bca", "cab", "cba"]
+ frequency = {}
+ for e in expected:
+ frequency[e] = 0
+ for _ in range(num_epochs):
+ output = dequeue_many.eval()
+ key = "".join(output)
+ self.assertIn(key, expected)
+ frequency[key] += 1
+
+ # Expect an approximately even distribution over all possible orders.
+ expected_frequency = num_epochs / len(expected)
+ margin = expected_frequency * 0.4
+ tf.logging.info("Observed counts: %s", frequency)
+ for key in expected:
+ value = frequency[key]
+ self.assertGreater(value, expected_frequency - margin)
+ self.assertLess(value, expected_frequency + margin)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ dequeue.eval()
+ for thread in threads:
+ thread.join()
+
+
+class RangeInputProducerTest(tf.test.TestCase):
+
+ def testNoShuffle(self):
+ with self.test_session():
+ num_epochs = 3
+ range_size = 5
+ queue = tf.train.range_input_producer(
+ range_size, num_epochs=num_epochs, shuffle=False)
+ dequeue_many = queue.dequeue_many(range_size * num_epochs)
+ dequeue = queue.dequeue()
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # No randomness, so just see repeated copies of the input.
+ output = dequeue_many.eval()
+ self.assertAllEqual(range(range_size) * num_epochs, output)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ dequeue.eval()
+ for thread in threads:
+ thread.join()
+
+ def testShuffle(self):
+ with self.test_session():
+ num_epochs = 200
+ range_size = 2
+ queue = tf.train.range_input_producer(
+ range_size, num_epochs=num_epochs, shuffle=True, seed=314159)
+ dequeue_many = queue.dequeue_many(range_size)
+ dequeue = queue.dequeue()
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # Validate that we only shuffle the integers within an epoch and
+ # count how often each possible order appears.
+ expected = [12, 21]
+ frequency = {}
+ for e in expected:
+ frequency[e] = 0
+ for _ in range(num_epochs):
+ output = dequeue_many.eval()
+ key = 10 * (output[0] + 1) + (output[1] + 1)
+ self.assertIn(key, expected)
+ frequency[key] += 1
+
+ # Expect an approximately even distribution over all possible orders.
+ expected_frequency = num_epochs / len(expected)
+ margin = expected_frequency * 0.4
+ tf.logging.info("Observed counts: %s", frequency)
+ for key in expected:
+ value = frequency[key]
+ self.assertGreater(value, expected_frequency - margin)
+ self.assertLess(value, expected_frequency + margin)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ dequeue.eval()
+ for thread in threads:
+ thread.join()
+
+
+class SliceInputProducerTest(tf.test.TestCase):
+
+ def testNoShuffle(self):
+ with self.test_session() as sess:
+ num_epochs = 3
+ source_strings = ["Alpha", "Beta", "Delta", "Gamma"]
+ source_ints = [2, 3, 5, 7]
+ slices = tf.train.slice_input_producer(
+ [source_strings, source_ints], num_epochs=num_epochs, shuffle=False)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # No randomness, so just see repeated copies of the input.
+ num_items = len(source_strings) * num_epochs
+ output = [sess.run(slices) for _ in range(num_items)]
+ out_strings, out_ints = zip(*output)
+ self.assertAllEqual(source_strings * num_epochs, out_strings)
+ self.assertAllEqual(source_ints * num_epochs, out_ints)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(slices)
+ for thread in threads:
+ thread.join()
+
+ def testShuffle(self):
+ with self.test_session() as sess:
+ num_epochs = 1200
+ source_strings = ["A", "B", "D", "G"]
+ source_ints = [7, 3, 5, 2]
+ slices = tf.train.slice_input_producer(
+ [source_strings, source_ints], num_epochs=num_epochs, shuffle=True,
+ seed=161803)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # Validate that we only shuffle the integers within an epoch and
+ # count how often each possible order appears.
+ expected = [",".join(x) for x in
+ itertools.permutations(["A7", "B3", "D5", "G2"])]
+ frequency = {}
+ for e in expected:
+ frequency[e] = 0
+ for _ in range(num_epochs):
+ output = [sess.run(slices) for _ in range(len(source_strings))]
+ key = ",".join([s + str(i) for s, i in output])
+ self.assertIn(key, expected)
+ frequency[key] += 1
+
+ # Expect an approximately even distribution over all possible orders.
+ expected_frequency = num_epochs / len(expected)
+ margin = expected_frequency * 0.4
+ tf.logging.info("Observed counts: %s", frequency)
+ for key in expected:
+ value = frequency[key]
+ self.assertGreater(value, expected_frequency - margin)
+ self.assertLess(value, expected_frequency + margin)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(slices)
+ for thread in threads:
+ thread.join()
+
+
+class BatchTest(tf.test.TestCase):
+
+ def testOneThread(self):
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.batch([counter, "string"], batch_size=batch_size)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ for i in range(num_batches):
+ results = sess.run(batched)
+ self.assertAllEqual(results[0],
+ range(i * batch_size, (i + 1) * batch_size))
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testManyThreads(self):
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.batch([counter, "string"], batch_size=batch_size,
+ num_threads=4)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ all_counts = []
+ for i in range(num_batches):
+ results = sess.run(batched)
+ tf.logging.info("Batch %d: %s", i, results[0])
+ self.assertEqual(len(results[0]), batch_size)
+ all_counts.extend(results[0])
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+ self.assertItemsEqual(all_counts, range(num_batches * batch_size))
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+
+class BatchJoinTest(tf.test.TestCase):
+
+ def testTwoThreads(self):
+ with self.test_session() as sess:
+ # Two threads, the first generates (0..34, "a").
+ num_a = 35
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_a)
+
+ # The second generates (99, "b") 45 times and then stops.
+ num_b = 45
+ ninety_nine = tf.train.limit_epochs(
+ tf.constant(99, dtype=tf.int64), num_b)
+
+ # These get joined together and grouped into batches of 5.
+ batch_size = 5
+ batched = tf.train.batch_join([[counter, "a"], [ninety_nine, "b"]],
+ batch_size=batch_size)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # Should see the "a" and "b" threads mixed together.
+ all_a = []
+ seen_b = 0
+ saw_both = 0
+ num_batches = (num_a + num_b) / batch_size
+ for i in range(num_batches):
+ results = sess.run(batched)
+ tf.logging.info("Batch %d: %s", i, results[0])
+ self.assertEqual(len(results[0]), batch_size)
+ self.assertEqual(len(results[1]), batch_size)
+ which_a = [i for i, s in enumerate(results[1]) if s == "a"]
+ which_b = [i for i, s in enumerate(results[1]) if s == "b"]
+ self.assertEqual(len(which_a) + len(which_b), batch_size)
+ if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
+ all_a.extend([results[0][i] for i in which_a])
+ seen_b += len(which_b)
+ self.assertAllEqual([99] * len(which_b),
+ [results[0][i] for i in which_b])
+
+ # Some minimum level of mixing of the results of both threads.
+ self.assertGreater(saw_both, 1)
+
+ # Verify the order of results from "a" were preserved.
+ self.assertAllEqual(all_a, range(num_a))
+ self.assertEqual(seen_b, num_b)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+
+class ShuffleBatchTest(tf.test.TestCase):
+
+ def testOneThread(self):
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.shuffle_batch(
+ [counter, "string"], batch_size=batch_size, capacity=32,
+ min_after_dequeue=16, seed=141421)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ all_counts = []
+ for i in range(num_batches):
+ results = sess.run(batched)
+ self.assertEqual(len(results[0]), batch_size)
+ all_counts.extend(results[0])
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+ # Results scrambled, but include all the expected numbers.
+ deltas = [all_counts[i + 1] - all_counts[i]
+ for i in range(len(all_counts) - 1)]
+ self.assertFalse(all(d == deltas[0] for d in deltas))
+ self.assertItemsEqual(all_counts, range(num_batches * batch_size))
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testManyThreads(self):
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.shuffle_batch(
+ [counter, "string"], batch_size=batch_size, capacity=32,
+ min_after_dequeue=16, seed=173205, num_threads=4)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ all_counts = []
+ for i in range(num_batches):
+ results = sess.run(batched)
+ tf.logging.info("Batch %d: %s", i, results[0])
+ self.assertEqual(len(results[0]), batch_size)
+ all_counts.extend(results[0])
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+ # Results scrambled, but include all the expected numbers.
+ deltas = [all_counts[i + 1] - all_counts[i]
+ for i in range(len(all_counts) - 1)]
+ self.assertFalse(all(d == deltas[0] for d in deltas))
+ self.assertItemsEqual(all_counts, range(num_batches * batch_size))
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+
+class ShuffleBatchJoinTest(tf.test.TestCase):
+
+ def testTwoThreads(self):
+ with self.test_session() as sess:
+ # Two threads, the first generates (0..24, "a").
+ num_a = 25
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_a)
+
+ # The second generates (99, "b") 35 times and then stops.
+ num_b = 35
+ ninety_nine = tf.train.limit_epochs(
+ tf.constant(99, dtype=tf.int64), num_b)
+
+ # These get joined together and grouped into batches of 5.
+ batch_size = 5
+ batched = tf.train.shuffle_batch_join(
+ [[counter, "a"], [ninety_nine, "b"]], batch_size=batch_size,
+ capacity=32, min_after_dequeue=16, seed=223607)
+
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # Should see the "a" and "b" threads mixed together.
+ all_a = []
+ seen_b = 0
+ saw_both = 0
+ num_batches = (num_a + num_b) / batch_size
+ for i in range(num_batches):
+ results = sess.run(batched)
+ tf.logging.info("Batch %d: %s", i, results[0])
+ self.assertEqual(len(results[0]), batch_size)
+ self.assertEqual(len(results[1]), batch_size)
+ which_a = [i for i, s in enumerate(results[1]) if s == "a"]
+ which_b = [i for i, s in enumerate(results[1]) if s == "b"]
+ self.assertEqual(len(which_a) + len(which_b), batch_size)
+ if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
+ all_a.extend([results[0][i] for i in which_a])
+ seen_b += len(which_b)
+ self.assertAllEqual([99] * len(which_b),
+ [results[0][i] for i in which_b])
+
+ # Some minimum level of mixing of the results of both threads.
+ self.assertGreater(saw_both, 1)
+
+ # Saw all the items from "a", but scrambled.
+ self.assertItemsEqual(all_a, range(num_a))
+ deltas = [all_a[i + 1] - all_a[i]
+ for i in range(len(all_a) - 1)]
+ self.assertFalse(all(d == deltas[0] for d in deltas))
+ self.assertEqual(seen_b, num_b)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
new file mode 100644
index 0000000000..cafcb26d01
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -0,0 +1,65 @@
+"""Various learning rate decay functions."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
+def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
+ staircase=False, name=None):
+ """Applies exponential decay to the learning rate.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies an exponential decay function
+ to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns the decayed learning rate. It is computed as:
+
+ ```python
+ decayed_learning_rate = learning_rate *
+ decay_rate ^ (global_step / decay_steps)
+ ```
+
+  If the argument `staircase` is `True`, then `global_step / decay_steps` is an
+ integer division and the decayed learning rate follows a staircase function.
+
+ Example: decay every 100000 steps with a base of 0.96:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ starter_learning_rate = 0.1
+  learning_rate = tf.train.exponential_decay(starter_learning_rate,
+                                             global_step, 100000, 0.96,
+                                             staircase=True)
+  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+ # Passing global_step to minimize() will increment it at each step.
+ optimizer.minimize(...my loss..., global_step=global_step)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Must be positive. See the decay computation above.
+ decay_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The decay rate.
+    staircase: Boolean. If `True`, decay the learning rate at discrete
+      intervals.
+    name: String. Optional name of the operation. Defaults to
+      'ExponentialDecay'.
+
+ Returns:
+ A scalar `Tensor` of the same type as `learning_rate`. The decayed
+ learning rate.
+ """
+ with ops.op_scope([learning_rate, global_step, decay_steps, decay_rate],
+ name, "ExponentialDecay") as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ global_step = math_ops.cast(global_step, dtype)
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ decay_rate = math_ops.cast(decay_rate, dtype)
+ p = global_step / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ return math_ops.mul(learning_rate, math_ops.pow(decay_rate, p), name=name)
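As a quick sanity check of the formula above, here is a short NumPy sketch of the same schedule; the helper name `decayed_lr` is illustrative and not part of the API, and the numbers mirror the expectations in the test file that follows.

```python
import numpy as np

def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    # decayed = learning_rate * decay_rate ** (global_step / decay_steps)
    p = global_step / float(decay_steps)
    if staircase:
        p = np.floor(p)  # integer division -> piecewise-constant schedule
    return learning_rate * decay_rate ** p

print(decayed_lr(0.05, 5, 10, 0.96))                  # continuous: ~0.04899
print(decayed_lr(0.1, 2, 3, 0.96, staircase=True))    # still 0.1: floor(2/3) == 0
print(decayed_lr(0.1, 100, 3, 0.96, staircase=True))  # 0.1 * 0.96 ** 33
```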
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
new file mode 100644
index 0000000000..b85d58cae7
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -0,0 +1,60 @@
+"""Functional test for learning rate decay."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import learning_rate_decay
+
+
+class LRDecayTest(test_util.TensorFlowTestCase):
+
+ def testContinuous(self):
+ with self.test_session():
+ step = 5
+ decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96)
+ expected = .05 * 0.96 ** (5.0 / 10.0)
+ self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+
+ def testStaircase(self):
+ with self.test_session():
+ step = state_ops.variable_op([], types.int32)
+ assign_100 = state_ops.assign(step, 100)
+ assign_1 = state_ops.assign(step, 1)
+ assign_2 = state_ops.assign(step, 2)
+ decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
+ staircase=True)
+ # No change to learning rate
+ assign_1.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ assign_2.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ # Decayed learning rate
+ assign_100.op.run()
+ expected = .1 * 0.96 ** (100 / 3)
+ self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+
+ def testVariables(self):
+ with self.test_session():
+ step = variables.Variable(1)
+ assign_1 = step.assign(1)
+ assign_2 = step.assign(2)
+ assign_100 = step.assign(100)
+ decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
+ staircase=True)
+ variables.initialize_all_variables().run()
+ # No change to learning rate
+ assign_1.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ assign_2.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ # Decayed learning rate
+ assign_100.op.run()
+ expected = .1 * 0.96 ** (100 / 3)
+ self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
new file mode 100644
index 0000000000..fdd434359f
--- /dev/null
+++ b/tensorflow/python/training/momentum.py
@@ -0,0 +1,51 @@
+"""Momentum for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class MomentumOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the Momentum algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate, momentum,
+ use_locking=False, name="Momentum"):
+ """Construct a new Momentum optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or a floating point value. The learning rate.
+ momentum: A `Tensor` or a floating point value. The momentum.
+ use_locking: If `True` use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Momentum".
+ """
+ super(MomentumOptimizer, self).__init__(use_locking, name)
+ self._learning_rate = learning_rate
+ self._momentum = momentum
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self._zeros_slot(v, "momentum", self._name)
+
+ def _prepare(self):
+ self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+ name="learning_rate")
+ self._momentum_tensor = ops.convert_to_tensor(self._momentum,
+ name="momentum")
+
+ def _apply_dense(self, grad, var):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var, mom,
+ self._learning_rate_tensor, grad, self._momentum_tensor,
+ use_locking=self._use_locking).op
+
+ def _apply_sparse(self, grad, var):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var, mom,
+ self._learning_rate_tensor, grad.values, grad.indices,
+ self._momentum_tensor, use_locking=self._use_locking).op
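For reference, a plain NumPy sketch of the accumulator update that the fused `apply_momentum` kernel is expected to perform (`momentum_step` is an illustrative helper, not part of the API); the numbers reproduce the step-by-step expectations in the test file below.

```python
import numpy as np

def momentum_step(var, accum, grad, lr, momentum):
    # accum <- momentum * accum + grad;  var <- var - lr * accum
    accum = momentum * accum + grad
    var = var - lr * accum
    return var, accum

var, accum, grad = np.array([1.0, 2.0]), np.zeros(2), np.array([0.1, 0.1])
# Step 1: the accumulator starts at 0, so this is a plain gradient step.
var, accum = momentum_step(var, accum, grad, lr=2.0, momentum=0.9)
print(var, accum)   # [0.8 1.8] [0.1 0.1]
# Step 2: the accumulator now carries the previous update.
var, accum = momentum_step(var, accum, grad, lr=2.0, momentum=0.9)
print(var, accum)   # [0.42 1.42] [0.19 0.19]
```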
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
new file mode 100644
index 0000000000..2cf86d97c9
--- /dev/null
+++ b/tensorflow/python/training/momentum_test.py
@@ -0,0 +1,258 @@
+"""Tests for Momentum."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class MomentumOptimizerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ self.assertFalse(slot0 in tf.trainable_variables())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ self.assertFalse(slot1 in tf.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators were 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllClose(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0),
+ 2.0 - (0.1 * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0),
+ 4.0 - (0.01 * 2.0)]),
+ var1.eval())
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ slot0.eval())
+ self.assertAllClose(np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
+ var1.eval())
+
+ def testFloat64(self):
+ with self.test_session():
+ opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+
+ # compute_gradients.
+ values = [1.0, 3.0]
+ good_vars = [tf.Variable([v]) for v in values]
+ bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+ opt.compute_gradients, bad_loss, good_vars)
+ bad_vars = [
+ tf.Variable(np.array([v], np.float64), name="bad_var")
+ for v in values]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+ bad_vars)
+ opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+ # apply_gradients.
+ bad_grads = [
+ tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+ tf.constant([0.01])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+ opt.apply_gradients, zip(bad_grads, good_vars))
+ good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.apply_gradients, zip(good_grads, bad_vars))
+ opt.apply_gradients(zip(good_grads, good_vars))
+
+ def _dbParamsMom01(self):
+ """Return dist-belief momentum values.
+
+ Return values have been generated from the dist-belief momentum unittest,
+ running with a learning rate of 0.1 and a momentum of 0.1.
+
+ These values record how a parameter vector of size 10, initialized with 0.0,
+ gets updated with 10 consecutive momentum steps. It uses random gradients.
+
+ Returns:
+ db_grad: The gradients to apply
+ db_out: The parameters after the momentum update.
+ """
+ db_grad = [[]] * 10
+ db_out = [[]] * 10
+ # pylint: disable=line-too-long
+ db_grad[0] = [0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018, 0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615]
+ db_out[0] = [-9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018, -0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618]
+ db_grad[1] = [0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378, 0.5513742, 0.94687688, 0.16012503, 0.22159521]
+ db_out[1] = [-0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884, -0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544]
+ db_grad[2] = [0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965, 0.31168157, 0.43203235, 0.16792089, 0.24644311]
+ db_out[2] = [-0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978, -0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189]
+ db_grad[3] = [0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098, 0.81454384, 0.03848977, 0.89759839, 0.93665648]
+ db_out[3] = [-0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105, -0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303]
+ db_grad[4] = [0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359, 0.69107032, 0.81897682, 0.5433259, 0.67860287]
+ db_out[4] = [-0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165, -0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544]
+ db_grad[5] = [0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563, 0.84163809, 0.41172323, 0.83259648, 0.44941229]
+ db_out[5] = [-0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094, -0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717]
+ db_grad[6] = [0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221, 0.73577434, 0.16014607, 0.57500273, 0.071136251]
+ db_out[6] = [-0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685, -0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997]
+ db_grad[7] = [0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646, 0.74053431, 0.16033, 0.66625422, 0.73515922]
+ db_out[7] = [-0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838, -0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418]
+ db_grad[8] = [0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039, 0.55561525, 0.22567581, 0.93331909, 0.29438227]
+ db_out[8] = [-0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527, -0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781]
+ db_grad[9] = [0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893, 0.68593478, 0.50580865, 0.12602448, 0.093537711]
+ db_out[9] = [-0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302, -0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295]
+ # pylint: enable=line-too-long
+ return db_grad, db_out
+
+ def testLikeDistBeliefMom01(self):
+ with self.test_session():
+ db_grad, db_out = self._dbParamsMom01()
+ num_samples = len(db_grad)
+ var0 = tf.Variable([0.0] * num_samples)
+ grads0 = tf.constant([0.0] * num_samples)
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.1)
+ mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
+ tf.initialize_all_variables().run()
+ for i in xrange(num_samples):
+ mom_update.run(feed_dict={grads0: db_grad[i]})
+ self.assertAllClose(np.array(db_out[i]), var0.eval())
+
+ def testSparse(self):
+ with self.test_session():
+ var0 = tf.Variable(tf.zeros([4, 2]))
+ var1 = tf.Variable(
+ tf.constant(1.0, tf.float32, [4, 2]))
+ grads0 = tf.IndexedSlices(tf.constant([[.1, .1]]),
+ tf.constant([1]),
+ tf.constant([4, 2]))
+ grads1 = tf.IndexedSlices(tf.constant([[.01, .01], [.01, .01]]),
+ tf.constant([2, 3]),
+ tf.constant([4, 2]))
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([0, 0], var0.eval()[0])
+ self.assertAllClose([0, 0], var0.eval()[1])
+ self.assertAllClose([1, 1], var1.eval()[2])
+
+ # Step 1: the momentum accumulators are 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllClose(np.array([.1, .1]), slot0.eval()[1])
+ self.assertAllClose(np.array([.01, .01]), slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllClose(np.array([- (0.1 * 2.0),
+ - (0.1 * 2.0)]),
+ var0.eval()[1])
+ self.assertAllClose(np.array([1.0 - (0.01 * 2.0),
+ 1.0 - (0.01 * 2.0)]),
+ var1.eval()[2])
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllClose(np.array([(0.9 * 0.1 + 0.1),
+ (0.9 * 0.1 + 0.1)]),
+ slot0.eval()[1])
+ self.assertAllClose(np.array([(0.9 * 0.01 + 0.01),
+ (0.9 * 0.01 + 0.01)]),
+ slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllClose(
+ np.array([- (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]),
+ var0.eval()[1])
+ self.assertAllClose(np.array([0.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 0.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
+ var1.eval()[2])
+
+ def testSharing(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+ mom_update1 = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ mom_update2 = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators were 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update1.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllClose(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0),
+ 2.0 - (0.1 * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0),
+ 4.0 - (0.01 * 2.0)]),
+ var1.eval())
+ # Step 2: run the second update op; the momentum accumulators contain the
+ # previous update.
+ mom_update2.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ slot0.eval())
+ self.assertAllClose(np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
+ var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
new file mode 100644
index 0000000000..becc71dfa2
--- /dev/null
+++ b/tensorflow/python/training/moving_averages.py
@@ -0,0 +1,247 @@
+"""Maintain moving averages of parameters."""
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+
+
+# TODO(mdevin): switch to variables.Variable.
+def assign_moving_average(variable, value, decay, name=None):
+ """Compute the moving average of a variable.
+
+ The moving average of 'variable' updated with 'value' is:
+ variable * decay + value * (1 - decay)
+
+ The returned Operation sets 'variable' to the newly computed moving average.
+
+ The new value of 'variable' can be set with the 'AssignSub' op as:
+ variable -= (1 - decay) * (variable - value)
+
+ Args:
+ variable: A Variable.
+ value: A tensor with the same shape as 'variable'
+ decay: A float Tensor or float value. The moving average decay.
+ name: Optional name of the returned operation.
+
+ Returns:
+ An Operation that updates 'variable' with the newly computed
+ moving average.
+ """
+ with ops.op_scope([variable, value, decay], name, "AssignMovingAvg") as name:
+ with ops.device(variable.device):
+ decay = ops.convert_to_tensor(1.0 - decay, name="decay")
+ if decay.dtype != variable.dtype.base_dtype:
+ decay = math_ops.cast(decay, variable.dtype.base_dtype)
+ return state_ops.assign_sub(variable, (variable - value) * decay,
+ name=name)
+
+
+class ExponentialMovingAverage(object):
+ """Maintains moving averages of variables by employing and exponential decay.
+
+ When training a model, it is often beneficial to maintain moving averages of
+ the trained parameters. Evaluations that use averaged parameters sometimes
+ produce significantly better results than the final trained values.
+
+ The `apply()` method adds shadow copies of trained variables and adds ops
+ that maintain a moving average of the trained variables in their shadow copies.
+ It is used when building the training model. The ops that maintain moving
+ averages are typically run after each training step.
+ The `average()` and `average_name()` methods give access to the shadow
+ variables and their names. They are useful when building an evaluation
+ model, or when restoring a model from a checkpoint file. They make it easy to
+ use the moving averages in place of the last trained values for evaluations.
+
+ The moving averages are computed using exponential decay. You specify the
+ decay value when creating the `ExponentialMovingAverage` object. The shadow
+ variables are initialized with the same initial values as the trained
+ variables. When you run the ops to maintain the moving averages, each
+ shadow variable is updated with the formula:
+
+ `shadow_variable -= (1 - decay) * (shadow_variable - variable)`
+
+ This is mathematically equivalent to the classic formula below, but the use
+ of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
+ updates to the variables:
+
+ `shadow_variable = decay * shadow_variable + (1 - decay) * variable`
+
+ Reasonable values for `decay` are close to 1.0, typically in the
+ multiple-nines range: 0.999, 0.9999, etc.
+
+ Example usage when creating a training model:
+
+ ```python
+ # Create variables.
+ var0 = tf.Variable(...)
+ var1 = tf.Variable(...)
+ # ... use the variables to build a training model...
+ ...
+ # Create an op that applies the optimizer. This is what we usually
+ # would use as a training op.
+ opt_op = opt.minimize(my_loss, [var0, var1])
+
+ # Create an ExponentialMovingAverage object
+ ema = tf.train.ExponentialMovingAverage(decay=0.9999)
+
+ # Create the shadow variables, and add ops to maintain moving averages
+ # of var0 and var1.
+ maintain_averages_op = ema.apply([var0, var1])
+
+ # Create an op that will update the moving averages after each training
+ # step. This is what we will use in place of the usual training op.
+ with tf.control_dependencies([opt_op]):
+ training_op = tf.group(maintain_averages_op)
+
+ ...train the model by running training_op...
+ ```
+
+ There are two ways to use the moving averages for evaluations:
+
+ * Build a model that uses the shadow variables instead of the variables.
+ For this, use the `average()` method which returns the shadow variable
+ for a given variable.
+ * Build a model normally but load the checkpoint files to evaluate by using
+ the shadow variable names. For this use the `average_name()` method. See
+ the [Saver class](train.md#Saver) for more information on restoring saved
+ variables.
+
+ Example of restoring the shadow variable values:
+
+ ```python
+ # Create a Saver that loads variables from their saved shadow values.
+ shadow_var0_name = ema.average_name(var0)
+ shadow_var1_name = ema.average_name(var1)
+ saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
+ saver.restore(...checkpoint filename...)
+ # var0 and var1 now hold the moving average values
+ ```
+
+ @@__init__
+ @@apply
+ @@average_name
+ @@average
+ """
+
+ def __init__(self, decay, num_updates=None,
+ name="ExponentialMovingAverage"):
+ """Creates a new ExponentialMovingAverage object.
+
+ The `apply()` method has to be called to create shadow variables and add
+ ops to maintain moving averages.
+
+ The optional `num_updates` parameter allows one to tweak the decay rate
+ dynamically. It is typical to pass the count of training steps, usually
+ kept in a variable that is incremented at each step, in which case the
+ decay rate is lower at the start of training. This makes moving averages
+ move faster. If passed, the actual decay rate used is:
+
+ `min(decay, (1 + num_updates) / (10 + num_updates))`
+
+ Args:
+ decay: Float. The decay to use.
+ num_updates: Optional count of number of updates applied to variables.
+ name: String. Optional prefix name to use for the name of ops added in
+ `apply()`.
+ """
+ self._decay = decay
+ self._num_updates = num_updates
+ self._name = name
+ self._averages = {}
+
+ def apply(self, var_list=None):
+ """Maintains moving averages of variables.
+
+ `var_list` must be a list of `Variable` or `Tensor` objects. This method
+ creates shadow variables for all elements of `var_list`. Shadow variables
+ for `Variable` objects are initialized to the variable's initial value.
+ For `Tensor` objects, the shadow variables are initialized to 0.
+
+ Shadow variables are created with `trainable=False` and added to the
+ `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
+ `tf.all_variables()`.
+
+ Returns an op that updates all shadow variables as described above.
+
+ Note that `apply()` can be called multiple times with different lists of
+ variables.
+
+ Args:
+ var_list: A list of Variable or Tensor objects. The variables
+ and Tensors must be of types float32 or float64.
+
+ Returns:
+ An Operation that updates the moving averages.
+
+ Raises:
+ TypeError: If the arguments are not all float32 or float64.
+ ValueError: If the moving average of one of the variables is already
+ being computed.
+ """
+ # TODO(mdevin): op_scope
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ for var in var_list:
+ if var.dtype.base_dtype not in [types.float32, types.float64]:
+ raise TypeError("The variables must be float or double: %s" % var)
+ if var in self._averages:
+ raise ValueError("Moving average already computed for: %s" % var)
+ with ops.name_scope(var.op.name + "/" + self._name) as scope:
+ with ops.device(var.device):
+ if isinstance(var, variables.Variable):
+ initial_value = var.initialized_value()
+ else:
+ initial_value = array_ops.zeros(var.get_shape().as_list())
+ avg = variables.Variable(initial_value, name=scope, trainable=False)
+ self._averages[var] = avg
+ with ops.name_scope(self._name) as scope:
+ decay = ops.convert_to_tensor(self._decay, name="decay")
+ if self._num_updates is not None:
+ num_updates = math_ops.cast(self._num_updates, types.float32,
+ name="num_updates")
+ decay = math_ops.minimum(decay,
+ (1.0 + num_updates) / (10.0 + num_updates))
+ updates = []
+ for var in var_list:
+ updates.append(assign_moving_average(self._averages[var], var, decay))
+ return control_flow_ops.group(*updates, name=scope)
+
+ def average(self, var):
+ """Returns the `Variable` holding the average of `var`.
+
+ Args:
+ var: A `Variable` object.
+
+ Returns:
+ A `Variable` object or `None` if the moving average of `var`
+ is not maintained.
+ """
+ return self._averages.get(var, None)
+
+ def average_name(self, var):
+ """Returns the name of the `Variable` holding the average for `var`.
+
+ The typical scenario for `ExponentialMovingAverage` is to compute moving
+ averages of variables during training, and restore the variables from the
+ computed moving averages during evaluations.
+
+ To restore variables, you have to know the name of the shadow variables.
+ That name and the original variable can then be passed to a `Saver()` object
+ to restore the variable from the moving average value with:
+ `saver = tf.train.Saver({ema.average_name(var): var})`
+
+ `average_name()` can be called whether or not `apply()` has been called.
+
+ Args:
+ var: A `Variable` object.
+
+ Returns:
+ A string: the name of the variable that will be used or was used
+ by the `ExponentialMovingAverage` class to hold the moving average of
+ `var`.
+ """
+ return var.op.name + "/" + self._name
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
new file mode 100644
index 0000000000..73ee94b400
--- /dev/null
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -0,0 +1,130 @@
+"""Functional test for moving_averages.py."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import moving_averages
+
+
+class MovingAveragesTest(test_util.TensorFlowTestCase):
+
+ def testAssignMovingAverage(self):
+ with self.test_session():
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0, 2.0], types.float32)
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ variables.initialize_all_variables().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ assign.op.run()
+ self.assertAllClose([10.0 * 0.25 + 1.0 * (1.0 - 0.25),
+ 11.0 * 0.25 + 2.0 * (1.0 - 0.25)],
+ var.eval())
+
+def _Repeat(value, dim):
+ if dim == 1:
+ return value
+ return [value for _ in xrange(dim)]
+
+class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
+
+ def _CheckDecay(self, ema, actual_decay, dim):
+ tens = _Repeat(10.0, dim)
+ thirties = _Repeat(30.0, dim)
+ var0 = variables.Variable(tens, name="v0")
+ var1 = variables.Variable(thirties, name="v1")
+ variables.initialize_all_variables().run()
+ # Note that tensor2 is not a Variable but just a plain Tensor resulting
+ # from the sum operation.
+ tensor2 = var0 + var1
+ update = ema.apply([var0, var1, tensor2])
+ avg0 = ema.average(var0)
+ avg1 = ema.average(var1)
+ avg2 = ema.average(tensor2)
+
+ self.assertFalse(avg0 in variables.trainable_variables())
+ self.assertFalse(avg1 in variables.trainable_variables())
+ self.assertFalse(avg2 in variables.trainable_variables())
+ variables.initialize_all_variables().run()
+
+ self.assertEqual("v0/ExponentialMovingAverage:0", avg0.name)
+ self.assertEqual("v1/ExponentialMovingAverage:0", avg1.name)
+ self.assertEqual("add/ExponentialMovingAverage:0", avg2.name)
+
+ # Check initial values.
+ self.assertAllClose(tens, var0.eval())
+ self.assertAllClose(thirties, var1.eval())
+ self.assertAllClose(_Repeat(10.0 + 30.0, dim), tensor2.eval())
+
+ # Check that averages are initialized correctly.
+ self.assertAllClose(tens, avg0.eval())
+ self.assertAllClose(thirties, avg1.eval())
+ # Note that averages of Tensors initialize to zeros_like since no value
+ # of the Tensor is known because the Op has not been run (yet).
+ self.assertAllClose(_Repeat(0.0, dim), avg2.eval())
+
+ # Update the averages and check.
+ update.run()
+ dk = actual_decay
+
+ expected = _Repeat(10.0 * dk + 10.0 * (1 - dk), dim)
+ self.assertAllClose(expected, avg0.eval())
+ expected = _Repeat(30.0 * dk + 30.0 * (1 - dk), dim)
+ self.assertAllClose(expected, avg1.eval())
+ expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk), dim)
+ self.assertAllClose(expected, avg2.eval())
+
+ # Again, update the averages and check.
+ update.run()
+ expected = _Repeat((10.0 * dk + 10.0 * (1 - dk)) * dk + 10.0 * (1 - dk),
+ dim)
+ self.assertAllClose(expected, avg0.eval())
+ expected = _Repeat((30.0 * dk + 30.0 * (1 - dk)) * dk + 30.0 * (1 - dk),
+ dim)
+ self.assertAllClose(expected, avg1.eval())
+ expected = _Repeat(((0.0 * dk + (10.0 + 30.0) * (1 - dk)) * dk +
+ (10.0 + 30.0) * (1 - dk)),
+ dim)
+ self.assertAllClose(expected, avg2.eval())
+
+ def testAverageVariablesNoNumUpdates_Scalar(self):
+ with self.test_session():
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ self._CheckDecay(ema, actual_decay=0.25, dim=1)
+
+ def testAverageVariablesNoNumUpdates_Vector(self):
+ with self.test_session():
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ self._CheckDecay(ema, actual_decay=0.25, dim=5)
+
+ def testAverageVariablesNumUpdates_Scalar(self):
+ with self.test_session():
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=1)
+
+ def testAverageVariablesNumUpdates_Vector(self):
+ with self.test_session():
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=5)
+
+ def testAverageVariablesNames(self):
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(30.0, name="v1")
+ tensor2 = v0 + v1
+ ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+ self.assertEqual("v0/foo_avg", ema.average_name(v0))
+ self.assertEqual("v1/foo_avg", ema.average_name(v1))
+ self.assertEqual("add/foo_avg", ema.average_name(tensor2))
+ ema.apply([v0, v1, tensor2])
+ self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
+ self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
+ self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
+
+
+if __name__ == "__main__":
+ googletest.main()
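The `num_updates` expectations in the tests above follow directly from the adjusted-decay formula described in `ExponentialMovingAverage.__init__`; a one-line arithmetic check:

```python
decay, num_updates = 0.25, 1
actual_decay = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
print(actual_decay)   # 2/11 ~= 0.181818, the value asserted in the tests
```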
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
new file mode 100644
index 0000000000..1186117169
--- /dev/null
+++ b/tensorflow/python/training/optimizer.py
@@ -0,0 +1,426 @@
+"""Base class for optimizers."""
+# pylint: disable=g-bad-name
+import types
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types as tf_types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+
+
+class Optimizer(object):
+ """Base class for optimizers.
+
+ This class defines the API to add Ops to train a model. You never use this
+ class directly, but instead instantiate one of its subclasses such as
+ `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
+
+ ### Usage
+
+ ```
+ # Create an optimizer with the desired parameters.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+ # Add Ops to the graph to minimize a cost by updating a list of variables.
+ # "cost" is a Tensor, and the list of variables contains variables.Variable
+ # objects.
+ opt_op = opt.minimize(cost, <list of variables>)
+ ```
+
+ In the training program you will just have to run the returned Op.
+
+ ```
+ # Execute opt_op to do one step of training:
+ opt_op.run()
+ ```
+
+ ### Processing gradients before applying them.
+
+ Calling `minimize()` takes care of both computing the gradients and
+ applying them to the variables. If you want to process the gradients
+ before applying them you can instead use the optimizer in three steps:
+
+ 1. Compute the gradients with `compute_gradients()`.
+ 2. Process the gradients as you wish.
+ 3. Apply the processed gradients with `apply_gradients()`.
+
+ Example:
+
+ ```
+ # Create an optimizer.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Compute the gradients for a list of variables.
+ grads_and_vars = opt.compute_gradients(loss, <list of variables>)
+
+ # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
+ # need to the 'gradient' part, for example cap them, etc.
+ capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
+
+ # Ask the optimizer to apply the capped gradients.
+ opt.apply_gradients(capped_grads_and_vars)
+ ```
+
+ @@__init__
+
+ @@minimize
+ @@compute_gradients
+ @@apply_gradients
+
+ ### Gating Gradients
+
+ Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument
+ that controls the degree of parallelism during the application of the
+ gradients.
+
+ The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
+
+ <b>GATE_NONE</b>: Compute and apply gradients in parallel. This provides the
+ maximum parallelism in execution, at the cost of some non-reproducibility in
+ the results. For example, the two gradients of MatMul depend on the input
+ values: with `GATE_NONE` one of the gradients could be applied to one of the
+ inputs _before_ the other gradient is computed, resulting in non-reproducible
+ results.
+
+ <b>GATE_OP</b>: For each Op, make sure all gradients are computed before they
+ are used. This prevents race conditions for Ops that generate gradients for
+ multiple inputs where the gradients depend on the inputs.
+
+ <b>GATE_GRAPH</b>: Make sure all gradients for all variables are computed
+ before any one of them is used. This provides the least parallelism but can
+ be useful if you want to process all gradients before applying any of them.
+
+ ### Slots
+
+ Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`,
+ allocate and manage additional variables associated with the variables to
+ train. These are called <i>Slots</i>. Slots have names and you can ask the
+ optimizer for the names of the slots that it uses. Once you have a slot name
+ you can ask the optimizer for the variable it created to hold the slot value.
+
+ This can be useful if you want to log or debug a training algorithm, report
+ stats about the slots, etc.
+
+ @@get_slot_names
+ @@get_slot
+ """
+
+ # Values for gate_gradients.
+ GATE_NONE = 0
+ GATE_OP = 1
+ GATE_GRAPH = 2
+
+ def __init__(self, use_locking, name):
+ """Create a new Optimizer.
+
+ This must be called by the constructors of subclasses.
+
+ Args:
+ use_locking: Bool. If True, use locks to prevent concurrent updates
+ to variables.
+ name: A non-empty string. The name to use for accumulators created
+ for the optimizer.
+
+ Raises:
+ ValueError: if name is malformed.
+ """
+ if not name:
+ raise ValueError("Must specify the optimizer name")
+ self._use_locking = use_locking
+ self._name = name
+ # Dictionary of slots.
+ # {slot_name : { variable_to_train: slot_for_the_variable, ...}, ... }
+ self._slots = {}
+
+ def minimize(self, loss, global_step=None, var_list=None,
+ gate_gradients=GATE_OP, name=None):
+ """Add operations to minimize 'loss' by updating 'var_list'.
+
+ This method simply combines calls to compute_gradients() and
+ apply_gradients(). If you want to process the gradients before applying
+ them, call compute_gradients() and apply_gradients() explicitly instead of
+ using this function.
+
+ Args:
+ loss: A Tensor containing the value to minimize.
+ global_step: Optional Variable to increment by one after the
+ variables have been updated.
+ var_list: Optional list of variables.Variable to update to minimize
+ 'loss'. Defaults to the list of variables collected in the graph
+ under the key GraphKeys.TRAINABLE_VARIABLES.
+ gate_gradients: How to gate the computation of gradients. Can be
+ GATE_NONE, GATE_OP, or GATE_GRAPH.
+ name: Optional name for the returned operation.
+
+ Returns:
+ An Operation that updates the variables in 'var_list'. If 'global_step'
+ was not None, that operation also increments global_step.
+
+ Raises:
+ ValueError: if some of the variables are not variables.Variable objects.
+ """
+ grads_and_vars = self.compute_gradients(loss, var_list=var_list,
+ gate_gradients=gate_gradients)
+ return self.apply_gradients(grads_and_vars, global_step=global_step,
+ name=name)
+
+ def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP):
+ """Compute gradients of "loss" for the variables in "var_list".
+
+ This is the first part of minimize(). It returns a list
+ of (gradient, variable) pairs where "gradient" is the gradient
+ for "variable". Note that "gradient" can be a Tensor, a
+ IndexedSlices, or None if there is no gradient for the
+ given variable.
+
+ Args:
+ loss: A Tensor containing the value to minimize.
+ var_list: Optional list of variables.Variable to update to minimize
+ "loss". Defaults to the list of variables collected in the graph
+ under the key GraphKeys.TRAINABLE_VARIABLES.
+ gate_gradients: How to gate the computation of gradients. Can be
+ GATE_NONE, GATE_OP, or GATE_GRAPH.
+
+ Returns:
+ A list of (gradient, variable) pairs.
+
+ Raises:
+ TypeError: If var_list contains anything other than variables.Variable.
+ ValueError: If some arguments are invalid.
+ """
+ if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
+ Optimizer.GATE_GRAPH]:
+ raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
+ "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
+ gate_gradients)
+ self._assert_valid_dtypes([loss])
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ for var in var_list:
+ if not isinstance(var, variables.Variable):
+ raise TypeError("Argument is not a variables.Variable: %s" % var)
+ grads = gradients.gradients(
+ loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP))
+ if gate_gradients == Optimizer.GATE_GRAPH:
+ grads = control_flow_ops.tuple(grads)
+ grads_and_vars = zip(grads, var_list)
+ self._assert_valid_dtypes([v for g, v in grads_and_vars if g is not None])
+ return grads_and_vars
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This is the second part of minimize(). It returns an Operation that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ compute_gradients().
+ global_step: Optional Variable to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Defaults to the
+ name passed to the Optimizer constructor.
+
+ Returns:
+ An Operation that applies the specified gradients. If 'global_step'
+ was not None, that operation also increments global_step.
+
+ Raises:
+ TypeError: if grads_and_vars is malformed.
+ """
+ # This is a default implementation of apply_gradients() that can be shared
+ # by most optimizers. It relies on the subclass implementing the following
+ # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
+ for g, v in grads_and_vars:
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices, types.NoneType)):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ if not isinstance(v, variables.Variable):
+ raise TypeError(
+ "Variable must be a variables.Variable: %s" % v)
+ if g is not None:
+ self._assert_valid_dtypes([g, v])
+ self._create_slots([v for g, v in grads_and_vars if g is not None])
+ update_ops = []
+ with ops.op_scope([], name, self._name) as name:
+ self._prepare()
+ for grad, var in grads_and_vars:
+ if not grad:
+ continue
+ with ops.name_scope("update_" + var.op.name), ops.device(var.device):
+ if isinstance(grad, ops.Tensor):
+ update_ops.append(self._apply_dense(grad, var))
+ else:
+ update_ops.append(self._apply_sparse(grad, var))
+ if global_step is None:
+ return self._finish(update_ops, name)
+ else:
+ with ops.control_dependencies([self._finish(update_ops, "update")]):
+ with ops.device(global_step.device):
+ return state_ops.assign_add(global_step, 1, name=name).op
+
+ def get_slot(self, var, name):
+ """Return a slot named "name" created for "var" by the Optimizer.
+
+ Some Optimizer subclasses use additional variables. For example
+ Momentum and Adagrad use variables to accumulate updates. This method
+ gives access to these Variables if for some reason you need them.
+
+ Use get_slot_names() to get the list of slot names created by the Optimizer.
+
+ Args:
+ var: A variable passed to minimize() or apply_gradients().
+ name: A string.
+
+ Returns:
+ The Variable for the slot if it was created, None otherwise.
+ """
+ named_slots = self._slots.get(name, None)
+ if not named_slots:
+ return None
+ return named_slots.get(var, None)
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the Optimizer.
+
+ See get_slot().
+
+ Returns:
+ A list of strings.
+ """
+ return sorted(self._slots.keys())
+
+ def _assert_valid_dtypes(self, tensors):
+ """Asserts tensors are all valid types (see _valid_dtypes).
+
+ Args:
+ tensors: tensors to check.
+ Raises:
+ ValueError: if any tensor is not a valid type.
+ """
+ valid_dtypes = self._valid_dtypes()
+ for t in tensors:
+ dtype = t.dtype.base_dtype
+ if dtype not in valid_dtypes:
+ raise ValueError(
+ "Invalid type %s for %s, expected: %s." % (
+ dtype, t.name, [v for v in valid_dtypes]))
+
+ # --------------
+ # Methods to be implemented by subclasses if they want to use the
+ # inherited implementation of apply_gradients() or compute_gradients().
+ # --------------
+ def _valid_dtypes(self):
+ """Valid types for loss, variables and gradients.
+
+ Defaults to float32. Subclasses should override to allow other types.
+
+ Returns:
+ Valid types for loss, variables and gradients.
+ """
+ return set([tf_types.float32])
+
+ def _create_slots(self, var_list):
+ """Create all slots needed by the variables.
+
+ Args:
+ var_list: A list of variables.Variable.
+ """
+ # No slots needed by default
+ pass
+
+ def _prepare(self):
+ """Create all needed tensors before applying gradients.
+
+ This is called with the name_scope using the "name" that
+ users have chosen for the application of gradients.
+ """
+ pass
+
+ def _apply_dense(self, grad, var):
+ """Add ops to apply dense gradients to "var".
+
+ Args:
+ grad: A Tensor.
+ var: A variables.Variable.
+
+ Returns:
+ An Operation.
+ """
+ raise NotImplementedError()
+
+ def _apply_sparse(self, grad, var):
+ """Add ops to apply sparse gradients to "var".
+
+ Args:
+ grad: IndexedSlices.
+ var: A variables.Variable.
+
+ Returns:
+ An Operation.
+ """
+ raise NotImplementedError()
+
+ def _finish(self, update_ops, name_scope):
+ """Do what is needed to finish the update.
+
+ This is called with the name_scope using the "name" that
+ users have chosen for the application of gradients.
+
+ Args:
+ update_ops: List of Operations to update variables. This list contains
+ the values returned by the _apply_dense() and _apply_sparse() calls.
+ name_scope: string. Name to use for the returned operation.
+
+ Returns:
+ The operation to apply updates.
+ """
+ return control_flow_ops.group(*update_ops, name=name_scope)
+
+ # --------------
+ # Utility methods for subclasses.
+ # --------------
+
+ def _get_or_make_slot(self, var, val, slot_name, op_name):
+ """Find or create a slot for a variable.
+
+ Args:
+ var: A variables.Variable.
+ val: A Tensor. The initial value of the slot.
+ slot_name: Name for the slot.
+ op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A variables.Variable.
+ """
+ named_slots = self._slots.get(slot_name, None)
+ if named_slots is None:
+ named_slots = {}
+ self._slots[slot_name] = named_slots
+ slot = named_slots.get(var, None)
+ if slot is None:
+ # Scope the slot name in the namespace of the Variable and
+ # create the slot on the same device as the variable.
+ with ops.name_scope(var.op.name + "/" + op_name) as scope:
+ with ops.device(var.device):
+ slot = variables.Variable(val, name=scope, trainable=False)
+ named_slots[var] = slot
+ return slot
+
+ def _zeros_slot(self, var, slot_name, op_name):
+ """Find or create a slot initialized with 0.0.
+
+ Args:
+ var: A variables.Variable.
+ slot_name: Name for the slot.
+ op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A variables.Variable.
+ """
+ val = array_ops.zeros(var.get_shape().as_list(), dtype=var.dtype)
+ return self._get_or_make_slot(var, val, slot_name, op_name)
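To illustrate the subclassing contract described in the comments above (`_create_slots()`, `_prepare()`, `_apply_dense()`, `_apply_sparse()`), here is a hedged sketch of a minimal gradient-descent-style optimizer. It is not how the shipped optimizers are implemented (they call fused kernels in `training_ops`); it simply shows which hooks a subclass fills in, and `SimpleSGD` is an illustrative name.

```python
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer


class SimpleSGD(optimizer.Optimizer):
  """Illustrative optimizer: var <- var - learning_rate * grad."""

  def __init__(self, learning_rate, use_locking=False, name="SimpleSGD"):
    super(SimpleSGD, self).__init__(use_locking, name)
    self._learning_rate = learning_rate

  def _prepare(self):
    # Convert the Python number once, before the per-variable updates.
    self._lr_tensor = ops.convert_to_tensor(self._learning_rate,
                                            name="learning_rate")

  def _apply_dense(self, grad, var):
    # apply_gradients() calls this once per (grad, var) pair and groups
    # the returned operations with _finish().
    delta = math_ops.mul(grad, self._lr_tensor)
    return state_ops.assign_sub(var, delta,
                                use_locking=self._use_locking).op

  # _apply_sparse() is intentionally left unimplemented in this sketch, so
  # IndexedSlices gradients would raise NotImplementedError.
```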
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py
new file mode 100644
index 0000000000..fcf9927c79
--- /dev/null
+++ b/tensorflow/python/training/queue_runner.py
@@ -0,0 +1,233 @@
+"""Create threads to run multiple enqueue ops."""
+import threading
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import logging
+
+
+class QueueRunner(object):
+ """Holds a list of enqueue operations for a queue, each to be run in a thread.
+
+ Queues are a convenient TensorFlow mechanism to compute tensors
+ asynchronously using multiple threads. For example in the canonical 'Input
+ Reader' setup one set of threads generates filenames in a queue; a second set
+ of threads read records from the files, processes them, and enqueues tensors
+ on a second queue; a third set of threads dequeues these input records to
+ construct batches and runs them through training operations.
+
+ There are several delicate issues when running multiple threads that way:
+ closing the queues in sequence as the input is exhausted, correctly catching
+ and reporting exceptions, etc.
+
+ The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.
+ """
+
+ def __init__(self, queue, enqueue_ops):
+ """Create a QueueRunner.
+
+ On construction the `QueueRunner` adds an op to close the queue. That op
+ will be run if the enqueue ops raise exceptions.
+
+ When you later call the `create_threads()` method, the `QueueRunner` will
+ create one thread for each op in `enqueue_ops`. Each thread will run its
+ enqueue op in parallel with the other threads. The enqueue ops do not have
+ to all be the same op, but it is expected that they all enqueue tensors in
+ `queue`.
+
+ Args:
+ queue: A `Queue`.
+ enqueue_ops: List of enqueue ops to run in threads later.
+ """
+ self._queue = queue
+ self._enqueue_ops = enqueue_ops
+ # Close when no more will be produced, but pending enqueues should be
+ # preserved.
+ self._close_op = self._queue.close()
+ # Close and cancel pending enqueues since there was an error and we want
+ # to unblock everything so we can cleanly exit.
+ self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
+ # Protect the count of runs to wait for.
+ self._lock = threading.Lock()
+ self._runs = 0
+ # List of exceptions raised by the running threads.
+ self._exceptions_raised = []
+
+ @property
+ def exceptions_raised(self):
+ """Exceptions raised but not handled by the `QueueRunner` threads.
+
+ Exceptions raised in queue runner threads are handled in one of two ways
+ depending on whether or not a `Coordinator` was passed to
+ `create_threads()`:
+
+ * With a `Coordinator`, exceptions are reported to the coordinator and
+ forgotten by the `QueueRunner`.
+ * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
+ made available in this `exceptions_raised` property.
+
+ Returns:
+ A list of Python `Exception` objects. The list is empty if no exception
+ was captured. (No exceptions are captured when using a Coordinator.)
+ """
+ return self._exceptions_raised
+
+ # pylint: disable=broad-except
+ def _run(self, sess, enqueue_op, coord=None):
+ """Execute the enqueue op in a loop, close the queue in case of error.
+
+ Args:
+ sess: A Session.
+ enqueue_op: The Operation to run.
+ coord: Optional Coordinator object for reporting errors and checking
+ for stop conditions.
+ """
+ decremented = False
+ try:
+ while True:
+ if coord and coord.should_stop():
+ break
+ try:
+ sess.run(enqueue_op)
+ except errors.OutOfRangeError:
+ # This exception indicates that a queue was closed.
+ with self._lock:
+ self._runs -= 1
+ decremented = True
+ if self._runs == 0:
+ try:
+ sess.run(self._close_op)
+ except Exception, e:
+ # Intentionally ignore errors from close_op.
+ logging.vlog(1, "Ignored exception: %s", str(e))
+ return
+ except Exception, e:
+ # This catches all other exceptions.
+ if coord:
+ coord.request_stop(e)
+ else:
+ logging.error("Exception in QueueRunner: %s", str(e))
+ with self._lock:
+ self._exceptions_raised.append(e)
+ raise
+ finally:
+ # Make sure we account for all terminations: normal or errors.
+ if not decremented:
+ with self._lock:
+ self._runs -= 1
+
+ def _close_on_stop(self, sess, cancel_op, coord):
+ """Close the queue when the Coordinator requests stop.
+
+ Args:
+ sess: A Session.
+ cancel_op: The Operation to run.
+ coord: Coordinator.
+ """
+ coord.wait_for_stop()
+ try:
+ sess.run(cancel_op)
+ except Exception, e:
+ # Intentionally ignore errors from cancel_op.
+ logging.vlog(1, "Ignored exception: %s", str(e))
+ # pylint: enable=broad-except
+
+ def create_threads(self, sess, coord=None, daemon=False, start=False):
+ """Create threads to run the enqueue ops.
+
+ This method requires a session in which the graph was launched. It creates
+ a list of threads, optionally starting them. There is one thread for each
+ op passed in `enqueue_ops`.
+
+ The `coord` argument is an optional coordinator that the threads will use
+ to terminate together and report exceptions. If a coordinator is given,
+ this method starts an additional thread to close the queue when the
+ coordinator requests a stop.
+
+ This method may be called again as long as all threads from a previous call
+ have stopped.
+
+ Args:
+ sess: A `Session`.
+ coord: Optional `Coordinator` object for reporting errors and checking
+ stop conditions.
+ daemon: Boolean. If `True`, make the threads daemon threads.
+ start: Boolean. If `True`, starts the threads. If `False`, the
+ caller must call the `start()` method of the returned threads.
+
+ Returns:
+ A list of threads.
+
+ Raises:
+ RuntimeError: If threads from a previous call to `create_threads()` are
+ still running.
+ """
+ with self._lock:
+ if self._runs > 0:
+ raise RuntimeError(
+ "Threads are already running from a previous call to Threads() "
+ "for this queue runner.")
+ self._runs = len(self._enqueue_ops)
+ self._exceptions_raised = []
+
+ ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord))
+ for op in self._enqueue_ops]
+ if coord:
+ ret_threads.append(threading.Thread(target=self._close_on_stop,
+ args=(sess, self._cancel_op, coord)))
+ for t in ret_threads:
+ if daemon:
+ t.daemon = True
+ if start:
+ t.start()
+ return ret_threads
+
+
+def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
+ """Adds a `QueueRunner` to a collection in the graph.
+
+ When building a complex model that uses many queues it is often difficult to
+ gather all the queue runners that need to be run. This convenience function
+ allows you to add a queue runner to a well known collection in the graph.
+
+ The companion method `start_queue_runners()` can be used to start threads for
+ all the collected queue runners.
+
+ Args:
+ qr: A `QueueRunner`.
+ collection: A `GraphKey` specifying the graph collection to add
+ the queue runner to. Defaults to `GraphKeys.QUEUE_RUNNERS`.
+ """
+ ops.add_to_collection(collection, qr)
+
+
+def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
+ collection=ops.GraphKeys.QUEUE_RUNNERS):
+ """Starts all queue runners collected in the graph.
+
+ This is a companion method to `add_queue_runner()`. It just starts
+ threads for all queue runners collected in the graph. It returns
+ the list of all threads.
+
+ Args:
+ sess: `Session` used to run the queue ops. Defaults to the
+ default session.
+ coord: Optional `Coordinator` for coordinating the started threads.
+ daemon: Whether the threads should be marked as `daemons`, meaning
+ they don't block program exit.
+ start: Set to `False` to only create the threads, not start them.
+ collection: A `GraphKey` specifying the graph collection to
+ get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.
+
+ Returns:
+ A list of threads.
+ """
+ if sess is None:
+ sess = ops.get_default_session()
+ threads = []
+ for qr in ops.get_collection(collection):
+ threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
+ start=start))
+ return threads
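A short end-to-end usage sketch of the two helpers above, tying a `QueueRunner` to a `Coordinator`; the queue and tensor names are illustrative, and the `tf.train.*` symbols are the ones exercised by the test file that follows.

```python
import tensorflow.python.platform  # same bootstrap import used by the tests

import tensorflow as tf

# Graph: a small queue fed by a single enqueue op.
example_queue = tf.FIFOQueue(32, tf.float32)
enqueue_op = example_queue.enqueue((tf.constant(1.0),))
dequeued = example_queue.dequeue()

# Register the runner so start_queue_runners() can find it later.
qr = tf.train.QueueRunner(example_queue, [enqueue_op])
tf.train.add_queue_runner(qr)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  print(sess.run(dequeued))   # 1.0, produced by a runner thread
  coord.request_stop()
  coord.join(threads)
```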
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
new file mode 100644
index 0000000000..c94c02da66
--- /dev/null
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -0,0 +1,186 @@
+"""Tests for QueueRunner."""
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class QueueRunnerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ threads = qr.create_threads(sess)
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The variable should be 3.
+ self.assertEqual(3, var.eval())
+
+ def testTwoOps(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var0 = tf.Variable(zero64)
+ count_up_to_3 = var0.count_up_to(3)
+ var1 = tf.Variable(zero64)
+ count_up_to_30 = var1.count_up_to(30)
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, [count_up_to_3, count_up_to_30])
+ threads = qr.create_threads(sess)
+ tf.initialize_all_variables().run()
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ self.assertEqual(0, len(qr.exceptions_raised))
+ self.assertEqual(3, var0.eval())
+ self.assertEqual(30, var1.eval())
+
+ def testExceptionsCaptured(self):
+ with self.test_session() as sess:
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, ["i fail", "so fail"])
+ threads = qr.create_threads(sess)
+ tf.initialize_all_variables().run()
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(2, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+ self.assertTrue("Operation not in the graph" in str(exceptions[1]))
+
+ def testRealDequeueEnqueue(self):
+ with self.test_session() as sess:
+ q0 = tf.FIFOQueue(3, tf.float32)
+ enqueue0 = q0.enqueue((10.0,))
+ close0 = q0.close()
+ q1 = tf.FIFOQueue(30, tf.float32)
+ enqueue1 = q1.enqueue((q0.dequeue(),))
+ dequeue1 = q1.dequeue()
+ qr = tf.train.QueueRunner(q1, [enqueue1])
+ threads = qr.create_threads(sess)
+ for t in threads:
+ t.start()
+ # Enqueue 2 values, then close queue0.
+ enqueue0.run()
+ enqueue0.run()
+ close0.run()
+ # Wait for the queue runner to terminate.
+ for t in threads:
+ t.join()
+ # It should have terminated cleanly.
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The 2 values should be in queue1.
+ self.assertEqual(10.0, dequeue1.eval())
+ self.assertEqual(10.0, dequeue1.eval())
+ # And queue1 should now be closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError, "is closed"):
+ dequeue1.eval()
+
+ def testRespectCoordShouldStop(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+      # Ask the coordinator to stop. The queue runner should
+      # finish immediately.
+ coord = tf.train.Coordinator()
+ coord.request_stop()
+ threads = qr.create_threads(sess, coord)
+ for t in threads:
+ t.start()
+ coord.join(threads)
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The variable should be 0.
+ self.assertEqual(0, var.eval())
+
+ def testRequestStopOnException(self):
+ with self.test_session() as sess:
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, ["not an op"])
+ coord = tf.train.Coordinator()
+ threads = qr.create_threads(sess, coord)
+ for t in threads:
+ t.start()
+ # The exception should be re-raised when joining.
+ with self.assertRaisesRegexp(ValueError, "Operation not in the graph"):
+ coord.join(threads)
+
+ def testGracePeriod(self):
+ with self.test_session() as sess:
+ # The enqueue will quickly block.
+ queue = tf.FIFOQueue(2, tf.float32)
+ enqueue = queue.enqueue((10.0,))
+ dequeue = queue.dequeue()
+ qr = tf.train.QueueRunner(queue, [enqueue])
+ coord = tf.train.Coordinator()
+ threads = qr.create_threads(sess, coord, start=True)
+ # Dequeue one element and then request stop.
+ dequeue.op.run()
+ time.sleep(0.02)
+ coord.request_stop()
+      # We should be able to join because request_stop() will cause
+      # the queue to be closed and the enqueue to terminate.
+ coord.join(threads, stop_grace_period_secs=0.05)
+
+ def testNoMultiThreads(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ coord = tf.train.Coordinator()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ threads = []
+ threads.extend(qr.create_threads(sess, coord=coord))
+ with self.assertRaisesRegexp(
+ RuntimeError,
+ "Threads are already running"):
+ threads.extend(qr.create_threads(sess, coord=coord))
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=0.5)
+
+ def testThreads(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to, "bad op"])
+ threads = qr.create_threads(sess, start=True)
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(1, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+
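+      # After the previous threads have stopped, create_threads() can be
+      # called again; the bad op should be reported the same way.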
+ threads = qr.create_threads(sess, start=True)
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(1, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
new file mode 100644
index 0000000000..6dc0ce11ea
--- /dev/null
+++ b/tensorflow/python/training/rmsprop.py
@@ -0,0 +1,81 @@
+"""One-line documentation for rmsprop module.
+
+rmsprop algorithm [tieleman2012rmsprop]
+
+A detailed description of rmsprop.
+
+- maintain a moving (discounted) average of the square of gradients
+- divide gradient by the root of this average
+
+mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
+delta = - mom
+
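+For example, with decay = 0.9, learning_rate = 2.0, momentum = 0.0,
+epsilon = 1.0, a gradient of 0.1, and mean_square initialized to 1.0, the
+first step gives (a worked example matching the values in rmsprop_test.py):
+
+mean_square = 0.9 * 1.0 + 0.1 * 0.1**2 = 0.901
+delta = - 2.0 * 0.1 / sqrt(0.901 + 1.0) ~= -0.145
+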
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class RMSPropOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the RMSProp algorithm.
+
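+  A short usage sketch (assuming `loss` is a scalar Tensor to minimize):
+
+  ```python
+  opt = tf.train.RMSPropOptimizer(learning_rate=0.01, decay=0.9, momentum=0.5)
+  train_op = opt.minimize(loss)
+  ```
+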
+ @@__init__
+ """
+
+ def __init__(self, learning_rate, decay, momentum=0.0, epsilon=1e-10,
+ use_locking=False, name="RMSProp"):
+ """Construct a new RMSProp optimizer.
+
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning rate.
+      decay: Discounting factor for the history of squared gradients.
+      momentum: A scalar tensor. The momentum term.
+      epsilon: Small value added to the denominator to avoid division by zero.
+      use_locking: If True, use locks for update operations.
+      name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "RMSProp".
+ """
+ super(RMSPropOptimizer, self).__init__(use_locking, name)
+ self._learning_rate = learning_rate
+ self._decay = decay
+ self._momentum = momentum
+ self._epsilon = epsilon
+
+ # Tensors for learning rate and momentum. Created in _prepare.
+ self._learning_rate_tensor = None
+ self._decay_tensor = None
+ self._momentum_tensor = None
+ self._epsilon_tensor = None
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self._get_or_make_slot(
+ v, constant_op.constant(1.0, dtype=v.dtype, shape=v.get_shape()),
+ "rms", self._name)
+ self._zeros_slot(v, "momentum", self._name)
+
+ def _prepare(self):
+ self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+ name="learning_rate")
+ self._decay_tensor = ops.convert_to_tensor(self._decay, name="decay")
+ self._momentum_tensor = ops.convert_to_tensor(self._momentum,
+ name="momentum")
+ self._epsilon_tensor = ops.convert_to_tensor(self._epsilon,
+ name="epsilon")
+
+ def _apply_dense(self, grad, var):
+ rms = self.get_slot(var, "rms")
+ mom = self.get_slot(var, "momentum")
+ return training_ops.apply_rms_prop(
+ var, rms, mom,
+ self._learning_rate_tensor,
+ self._decay_tensor,
+ self._momentum_tensor,
+ self._epsilon_tensor,
+ grad, use_locking=self._use_locking).op
+
+ def _apply_sparse(self, grad, var):
+ raise NotImplementedError()
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
new file mode 100644
index 0000000000..520df73ca8
--- /dev/null
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -0,0 +1,158 @@
+"""Tests for rmsprop."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class RMSPropOptimizerTest(tf.test.TestCase):
+
+ def testWithoutMomentum(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+      # Step 1: the rms accumulators were 1. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllClose(np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllClose(np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the parameters.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0))]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0
+ / math.sqrt(0.90001+1.0)),
+ 4.0 - (0.01 * 2.0
+ / math.sqrt(0.90001+1.0))]),
+ var1.eval())
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllClose(np.array([0.901*0.9+0.001, 0.901*0.9+0.001]),
+ rms0.eval())
+ self.assertAllClose(np.array([0.90001*0.9+1e-5, 0.90001*0.9+1e-5]),
+ rms1.eval())
+ # Check the parameters.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0))
+ - (0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0))
+ - (0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1.0))]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0 / math.sqrt(0.90001+1.0))
+ - (0.01 * 2.0 /
+ math.sqrt(0.90001*0.9+1e-5+1.0)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001+1.0))
+ - (0.01 * 2.0 /
+ math.sqrt(0.90001*0.9+1e-5+1.0))]),
+ var1.eval())
+
+ def testWithMomentum(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.5, epsilon=1e-5)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: rms = 1, mom = 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllClose(np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllClose(np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the momentum accumulators
+ self.assertAllClose(np.array([(0.1 * 2.0 / math.sqrt(0.901+1e-5)),
+ (0.1 * 2.0 / math.sqrt(0.901+1e-5))]),
+ mom0.eval())
+ self.assertAllClose(np.array([(0.01 * 2.0/ math.sqrt(0.90001+1e-5)),
+ (0.01 * 2.0/ math.sqrt(0.90001+1e-5))]),
+ mom1.eval())
+
+      # Check the parameters.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5))]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0/ math.sqrt(0.90001+1e-5)),
+ 4.0 - (0.01 * 2.0/ math.sqrt(0.90001+1e-5))]
+ ),
+ var1.eval())
+
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllClose(np.array([0.901*0.9+0.001, 0.901*0.9+0.001]),
+ rms0.eval())
+ self.assertAllClose(np.array([0.90001*0.9+1e-5, 0.90001*0.9+1e-5]),
+ rms1.eval())
+ self.assertAllClose(np.array([0.5 * (0.1 * 2.0 / math.sqrt(0.901+1e-5)) +
+ (0.1*2.0/math.sqrt(0.901*0.9+0.001+1e-5)),
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901+1e-5)) +
+ (0.1*2.0/math.sqrt(0.901*0.9+0.001+1e-5))
+ ]), mom0.eval())
+ self.assertAllClose(np.array([0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5))+
+ (0.01 * 2.0 /math.sqrt(0.90001*0.9+2e-5)),
+ 0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5))+
+ (0.01 * 2.0 / math.sqrt(0.90001*0.9+2e-5))
+ ]), mom1.eval())
+
+ # Check the parameters.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5)) - (0.5 * (
+ 0.1 * 2.0 / math.sqrt(0.901+1e-5)) +(
+ 0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1e-5))),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5)) - (0.5 * (
+ 0.1 * 2.0 / math.sqrt(0.901+1e-5)) +(
+ 0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1e-5)))
+ ]), var0.eval())
+
+ self.assertAllClose(
+ np.array([3.0 - (0.01 * 2.0 / math.sqrt(0.90001+1e-5))
+ - (0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5)) +
+ (0.01 * 2.0 /math.sqrt(0.90001*0.9+2e-5))),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001+1e-5))
+ - (0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5)) +
+ (0.01 * 2.0 / math.sqrt(0.90001*0.9+2e-5)))]),
+ var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/saver.proto b/tensorflow/python/training/saver.proto
new file mode 100644
index 0000000000..b9ba9f7e3c
--- /dev/null
+++ b/tensorflow/python/training/saver.proto
@@ -0,0 +1,30 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// Protocol buffer representing the configuration of a Saver.
+message SaverDef {
+ // The name of the tensor in which to specify the filename when saving or
+ // restoring a model checkpoint.
+ string filename_tensor_name = 1;
+
+ // The operation to run when saving a model checkpoint.
+ string save_tensor_name = 2;
+
+ // The operation to run when restoring a model checkpoint.
+ string restore_op_name = 3;
+
+ // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
+ int32 max_to_keep = 4;
+
+ // Shard the save files, one per device that has Parameters nodes.
+ bool sharded = 5;
+
+  // How often to keep an additional checkpoint. If not specified, only the
+  // last "max_to_keep" checkpoints are kept; if specified, in addition to
+  // keeping the last "max_to_keep" checkpoints, an additional checkpoint will
+  // be kept for every n hours of training.
+ float keep_checkpoint_every_n_hours = 6;
+}
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
new file mode 100644
index 0000000000..505bbad4c6
--- /dev/null
+++ b/tensorflow/python/training/saver.py
@@ -0,0 +1,887 @@
+# pylint: disable=invalid-name
+"""Save and restore variables."""
+import collections
+import numbers
+import os.path
+import time
+
+from google.protobuf import text_format
+
+from tensorflow.python.client import graph_util
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_io_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+from tensorflow.python.training import saver_pb2
+from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+
+
+class BaseSaverBuilder(object):
+ """Base class for Savers.
+
+ Can be extended to create different Ops.
+ """
+
+ class VarToSave(object):
+ """Class used to describe variable slices that need to be saved."""
+
+ def __init__(self, var, slice_spec, name):
+ self.var = var
+ self.slice_spec = slice_spec
+ self.name = name
+
+ def __init__(self):
+ pass
+
+ def save_op(self, filename_tensor, vars_to_save):
+ """Create an Op to save 'vars_to_save'.
+
+ This is intended to be overridden by subclasses that want to generate
+ different Ops.
+
+ Args:
+ filename_tensor: String Tensor.
+ vars_to_save: a list of BaseSaverBuilder.VarToSave objects.
+
+ Returns:
+      An Operation that saves the variables.
+ """
+ return io_ops._save(
+ filename=filename_tensor,
+ tensor_names=[vs.name for vs in vars_to_save],
+ tensors=[vs.var for vs in vars_to_save],
+ tensor_slices=[vs.slice_spec for vs in vars_to_save])
+
+ def restore_op(self, filename_tensor, var_to_save, preferred_shard):
+ """Create an Op to read the variable 'var_to_save'.
+
+ This is intended to be overridden by subclasses that want to generate
+ different Ops.
+
+ Args:
+ filename_tensor: String Tensor.
+ var_to_save: a BaseSaverBuilder.VarToSave object.
+ preferred_shard: Int. Shard to open first when loading a sharded file.
+
+ Returns:
+ A Tensor resulting from reading 'var_to_save' from 'filename'.
+ """
+ return io_ops._restore_slice(
+ filename_tensor,
+ var_to_save.name,
+ var_to_save.slice_spec,
+ var_to_save.var.dtype,
+ preferred_shard=preferred_shard)
+
+ def sharded_filename(self, filename_tensor, shard, num_shards):
+ """Append sharding information to a filename.
+
+ Args:
+ filename_tensor: a string tensor.
+ shard: integer. The shard for the filename.
+ num_shards: an int Tensor for the number of shards.
+
+ Returns:
+ A string tensor.
+ """
+ return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards)
+
+ def _AddSaveOps(self, filename_tensor, vars_to_save):
+ """Add ops to save variables that are on the same shard.
+
+ Args:
+ filename_tensor: String Tensor.
+ vars_to_save: a list of _VarToSave objects.
+
+ Returns:
+ A tensor with the filename used to save.
+ """
+ save = self.save_op(filename_tensor, vars_to_save)
+ return control_flow_ops.with_dependencies([save], filename_tensor)
+
+ def _AddShardedSaveOps(self, filename_tensor, per_device):
+ """Add ops to save the params per shard.
+
+ Args:
+ filename_tensor: String Tensor.
+ per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
+ returned by _GroupByDevices().
+
+ Returns:
+ An op to save the variables.
+ """
+ num_shards = len(per_device)
+ sharded_saves = []
+ num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
+ for shard, (device, vars_to_save) in enumerate(per_device):
+ with ops.device(device):
+ sharded_filename = self.sharded_filename(
+ filename_tensor, shard, num_shards_tensor)
+ sharded_saves.append(self._AddSaveOps(sharded_filename, vars_to_save))
+ # Return the sharded name for the save path.
+ with ops.control_dependencies([x.op for x in sharded_saves]):
+ return gen_io_ops._sharded_filespec(filename_tensor, num_shards_tensor)
+
+ def _AddRestoreOps(self,
+ filename_tensor,
+ vars_to_save,
+ restore_sequentially,
+ reshape,
+ preferred_shard=-1,
+ name="restore_all"):
+ """Add operations to restore vars_to_save.
+
+ Args:
+ filename_tensor: Tensor for the path of the file to load.
+ vars_to_save: a list of _VarToSave objects.
+ restore_sequentially: True if we want to restore variables sequentially
+ within a shard.
+ reshape: True if we want to reshape loaded tensors to the shape of
+ the corresponding variable.
+ preferred_shard: Shard to open first when loading a sharded file.
+ name: Name for the returned op.
+
+ Returns:
+ An Operation that restores the variables.
+ """
+ assign_ops = []
+ for vs in vars_to_save:
+ v = vs.var
+ restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
+ # Load and optionally reshape on the CPU, as string tensors are not
+ # available on the GPU.
+ # TODO(mdevin): Re-enable restore on GPU when we can support annotating
+ # string tensors as "HostMemory" inputs.
+ with ops.device(graph_util.set_cpu0(v.device) if v.device else None):
+ with ops.control_dependencies(restore_control_inputs):
+ values = self.restore_op(filename_tensor, vs, preferred_shard)
+ if reshape:
+ shape = v.get_shape()
+ if not shape.is_fully_defined():
+ shape = array_ops.shape(v)
+ values = array_ops.reshape(values, shape)
+
+ # Assign on the same device as the variable.
+ with ops.device(v.device):
+ assign_ops.append(state_ops.assign(v,
+ values,
+ validate_shape=not reshape))
+
+ # Create a Noop that has control dependencies from all the updates.
+ return control_flow_ops.group(*assign_ops, name=name)
+
+ def _AddShardedRestoreOps(self, filename_tensor, per_device,
+ restore_sequentially, reshape):
+ """Add Ops to save variables from multiple devices.
+
+ Args:
+ filename_tensor: Tensor for the path of the file to load.
+ per_device: A list of (device, _VarToSave) pairs, as
+ returned by _GroupByDevices().
+ restore_sequentially: True if we want to restore variables sequentially
+ within a shard.
+ reshape: True if we want to reshape loaded tensors to the shape of
+ the corresponding variable.
+
+ Returns:
+ An Operation that restores the variables.
+ """
+ sharded_restores = []
+ for shard, (device, vars_to_save) in enumerate(per_device):
+ with ops.device(device):
+ sharded_restores.append(self._AddRestoreOps(
+ filename_tensor,
+ vars_to_save,
+ restore_sequentially,
+ reshape,
+ preferred_shard=shard,
+ name="restore_shard"))
+ return control_flow_ops.group(*sharded_restores, name="restore_all")
+
+ def _IsVariable(self, v):
+ return isinstance(v, ops.Tensor) and (
+ v.op.type == "Variable" or v.op.type == "AutoReloadVariable")
+
+ def _GroupByDevices(self, vars_to_save):
+ """Group Variable tensor slices per device.
+
+ TODO(mdevin): Make sure that all the devices found are on different
+ job/replica/task/cpu|gpu. It would be bad if 2 were on the same device.
+    It can happen if the devices are unspecified.
+
+ Args:
+ vars_to_save: a list of BaseSaverBuilder.VarToSave objects.
+
+ Returns:
+      A list of (device_name, BaseSaverBuilder.VarToSave) tuples, sorted
+      by ascending device_name.
+ """
+ per_device = collections.defaultdict(lambda: [])
+ for var_to_save in vars_to_save:
+ per_device[var_to_save.var.device].append(var_to_save)
+ return sorted([(dev, tup) for dev, tup in per_device.iteritems()],
+ key=lambda t: t[0])
+
+ def _VarListToDict(self, var_list):
+ """Create a dictionary of names to variable lists.
+
+ Args:
+ var_list: A list, tuple, or set of Variables.
+
+ Returns:
+ A dictionary of variable names to the variables that must be saved under
+ that name. Variables with save_slice_info are grouped together under the
+ same key in no particular order.
+
+ Raises:
+ TypeError: If the type of var_list or its elements is not supported.
+ ValueError: If at least two variables share the same name.
+ """
+ if not isinstance(var_list, (list, tuple, set)):
+ raise TypeError("Variables to save should be passed in a dict or a "
+ "list: %s" % var_list)
+ var_list = set(var_list)
+ names_to_variables = {}
+ for var in var_list:
+ # pylint: disable=protected-access
+ if isinstance(var, variables.Variable) and var._save_slice_info:
+ name = var._save_slice_info.name
+ if name in names_to_variables:
+ if not isinstance(names_to_variables[name], list):
+ raise ValueError("Mixing slices and non-slices with the same name: "
+ "%s" % name)
+ names_to_variables[name].append(var)
+ else:
+ names_to_variables[name] = [var]
+ else:
+ var = ops.convert_to_tensor(var)
+ if not self._IsVariable(var):
+ raise TypeError("Variable to save is not a Variable: %s" % var)
+ name = var.op.name
+ if name in names_to_variables:
+ raise ValueError("At least two variables have the same name: %s" %
+ name)
+ names_to_variables[name] = var
+ # pylint: enable=protected-access
+ return names_to_variables
+
+ def _ValidateAndSliceInputs(self, names_to_variables):
+ """Returns the variables and names that will be used for a Saver.
+
+ Args:
+ names_to_variables: A dict (k, v) where k is the name of a variable and v
+ is a Variable to save or a BaseSaverBuilder.Saver.
+
+ Returns:
+ A list of BaseSaverBuilder.VarToSave objects.
+
+ Raises:
+ TypeError: if any of the keys are not strings or any of the
+ values are not one of Tensor or Variable.
+ ValueError: if the same variable is given in more than one value
+ (this also applies to slices of SlicedVariables).
+ """
+ if not isinstance(names_to_variables, dict):
+ names_to_variables = self._VarListToDict(names_to_variables)
+
+ vars_to_save = []
+ seen_variables = set()
+ for name in sorted(names_to_variables.iterkeys()):
+ if not isinstance(name, basestring):
+ raise TypeError("names_to_variables must be a dict mapping string "
+ "names to variable Tensors. Name is not a string: %s" %
+ name)
+ v = names_to_variables[name]
+ if isinstance(v, (list, tuple)):
+ # A set of slices.
+ slice_name = None
+ # pylint: disable=protected-access
+ for variable in v:
+ if not isinstance(variable, variables.Variable):
+ raise ValueError("Slices must all be Variables: %s" % variable)
+ if not variable._save_slice_info:
+ raise ValueError("Slices must all be slices: %s" % variable)
+ if slice_name is None:
+ slice_name = variable._save_slice_info.name
+ elif slice_name != variable._save_slice_info.name:
+            raise ValueError("Slices must all be from the same tensor: %s != %s"
+                             % (slice_name, variable._save_slice_info.name))
+ self._AddVarToSave(vars_to_save, seen_variables,
+ variable, variable._save_slice_info.spec, name)
+ # pylint: enable=protected-access
+ else:
+ # A variable or tensor.
+ variable = ops.convert_to_tensor(v)
+ if not self._IsVariable(variable):
+ raise TypeError("names_to_variables must be a dict mapping string "
+ "names to Tensors/Variables. Not a variable: %s" %
+ variable)
+ self._AddVarToSave(vars_to_save, seen_variables, variable, "", name)
+ return vars_to_save
+
+ def _AddVarToSave(self, vars_to_save, seen_variables, variable, slice_spec,
+ name):
+ """Create a VarToSave and add it to the vars_to_save list.
+
+ Args:
+ vars_to_save: List to append the new VarToSave to.
+ seen_variables: Set of variables already processed. Used to check
+ that each variable is only saved once.
+ variable: Variable to save.
+ slice_spec: String. Slice spec for the variable.
+ name: Name to use to save the variable.
+
+ Raises:
+ ValueError: If the variable has already been processed.
+ """
+ if variable in seen_variables:
+ raise ValueError("The same variable will be restored with two names: %s",
+ variable)
+ vars_to_save.append(BaseSaverBuilder.VarToSave(variable, slice_spec, name))
+ seen_variables.add(variable)
+
+ def build(self,
+ names_to_variables,
+ reshape=False,
+ sharded=False,
+ max_to_keep=5,
+ keep_checkpoint_every_n_hours=10000.0,
+ name=None,
+ restore_sequentially=False):
+ """Adds save/restore nodes to the graph and creates a SaverDef proto.
+
+ Args:
+ names_to_variables: A dictionary mapping name to a Variable.
+ Each name will be associated with the
+ corresponding variable in the checkpoint.
+      reshape: If True, allow restoring parameters from a checkpoint
+        where the parameters have a different shape. This is only
+        needed when you try to restore from a Dist-Belief checkpoint,
+        and only sometimes.
+ sharded: If True, shard the checkpoints, one per device that has
+ Parameters nodes.
+ max_to_keep: maximum number of checkpoints to keep. As new checkpoints
+ are created, old ones are deleted. If None or 0, no checkpoints are
+ deleted. Presently the number is only roughly enforced. For example
+ in case of restarts more than max_to_keep checkpoints may be kept.
+ keep_checkpoint_every_n_hours: How often checkpoints should be kept.
+ Defaults to 10,000 hours.
+ name: string. Optional name to use as a prefix when adding operations.
+ restore_sequentially: A Bool, which if true, causes restore of different
+ variables to happen sequentially within each device.
+
+ Returns:
+ A SaverDef proto.
+
+ Raises:
+ TypeError: If 'names_to_variables' is not a dictionary mapping string
+ keys to variable Tensors.
+ ValueError: If any of the keys or values in 'names_to_variables' is not
+ unique.
+ """
+ vars_to_save = self._ValidateAndSliceInputs(names_to_variables)
+ if max_to_keep is None:
+ max_to_keep = 0
+
+ with ops.op_scope([vs.var for vs in vars_to_save], name, "save") as name:
+ # Add the Constant string tensor for the filename.
+ filename_tensor = constant_op.constant("model")
+
+ # Add the save ops.
+ if sharded:
+ per_device = self._GroupByDevices(vars_to_save)
+ save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
+ restore_op = self._AddShardedRestoreOps(
+ filename_tensor, per_device, restore_sequentially, reshape)
+ else:
+ save_tensor = self._AddSaveOps(filename_tensor, vars_to_save)
+ restore_op = self._AddRestoreOps(
+ filename_tensor, vars_to_save, restore_sequentially, reshape)
+
+ assert restore_op.name.endswith("restore_all"), restore_op.name
+
+ return saver_pb2.SaverDef(
+ filename_tensor_name=filename_tensor.name,
+ save_tensor_name=save_tensor.name,
+ restore_op_name=restore_op.name,
+ max_to_keep=max_to_keep,
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+ sharded=sharded)
+
+def _GetCheckpointFilename(save_dir, latest_filename):
+ """Returns a filename for storing the CheckpointState.
+
+ Args:
+ save_dir: The directory for saving and restoring checkpoints.
+ latest_filename: Name of the file in 'save_dir' that is used
+ to store the CheckpointState.
+
+ Returns:
+ The path of the file that contains the CheckpointState proto.
+ """
+ if latest_filename is None:
+ latest_filename = "checkpoint"
+ return os.path.join(save_dir, latest_filename)
+
+
+def update_checkpoint_state(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None,
+ latest_filename=None):
+ """Updates the content of the 'checkpoint' file.
+
+ This updates the checkpoint file containing a CheckpointState
+ proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: list of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+    latest_filename: Optional name of the checkpoint file. Defaults to
+ 'checkpoint'.
+
+ Raises:
+ RuntimeError: If the save paths conflict.
+ """
+ if all_model_checkpoint_paths is None:
+ all_model_checkpoint_paths = []
+ elif all_model_checkpoint_paths[-1] != model_checkpoint_path:
+ logging.warning(
+ "%s is not in all_model_checkpoint_paths! Manually adding it.",
+ model_checkpoint_path)
+ all_model_checkpoint_paths.append(model_checkpoint_path)
+ # Writes the "checkpoint" file for the coordinator for later restoration.
+ coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
+ if coord_checkpoint_filename == model_checkpoint_path:
+ raise RuntimeError("Save path '%s' conflicts with path used for "
+ "checkpoint state. Please use a different save path." %
+ model_checkpoint_path)
+ coord_checkpoint_proto = CheckpointState(
+ model_checkpoint_path=model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths)
+ f = gfile.FastGFile(coord_checkpoint_filename, mode="w")
+ f.write(text_format.MessageToString(coord_checkpoint_proto))
+ f.close()
+
+
+def get_checkpoint_state(checkpoint_dir, latest_filename=None):
+ """Returns CheckpointState proto from the "checkpoint" file.
+
+ If the "checkpoint" file contains a valid CheckpointState
+ proto, returns it.
+
+ Args:
+ checkpoint_dir: The directory of checkpoints.
+    latest_filename: Optional name of the checkpoint file. Defaults to
+ 'checkpoint'.
+
+ Returns:
+ A CheckpointState if the state was available, None
+ otherwise.
+ """
+ ckpt = None
+ coord_checkpoint_filename = _GetCheckpointFilename(
+ checkpoint_dir, latest_filename)
+ f = None
+ try:
+    # Check that the file exists before opening it to avoid
+ # many lines of errors from colossus in the logs.
+ if gfile.Exists(coord_checkpoint_filename):
+ f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
+ ckpt = CheckpointState()
+ text_format.Merge(f.read(), ckpt)
+ except gfile.FileError:
+ # It's ok if the file cannot be read
+ return None
+ except text_format.ParseError, e:
+ logging.warning(str(e))
+ logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
+ return None
+ finally:
+ if f:
+ f.close()
+ return ckpt
+
+
+class Saver(object):
+ """Saves and restores variables.
+
+ See [Variables](../../how_tos/variables/index.md)
+ for an overview of variables, saving and restoring.
+
+ The `Saver` class adds ops to save and restore variables to and from
+ *checkpoints*. It also provides convenience methods to run these ops.
+
+ Checkpoints are binary files in a proprietary format which map variable names
+ to tensor values. The best way to examine the contents of a checkpoint is to
+ load it using a `Saver`.
+
+ Savers can automatically number checkpoint filenames with a provided counter.
+ This lets you keep multiple checkpoints at different steps while training a
+ model. For example you can number the checkpoint filenames with the training
+ step number. To avoid filling up disks, savers manage checkpoint files
+ automatically. For example, they can keep only the N most recent files, or
+ one checkpoint for every N hours of training.
+
+ You number checkpoint filenames by passing a value to the optional
+ `global_step` argument to `save()`:
+
+ ```python
+  saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
+  ...
+  saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
+ ```
+
+ Additionally, optional arguments to the `Saver()` constructor let you control
+ the proliferation of checkpoint files on disk:
+
+ * `max_to_keep` indicates the maximum number of recent checkpoint files to
+ keep. As new files are created, older files are deleted. If None or 0,
+ all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
+    checkpoint files are kept).
+
+ * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
+ `max_to_keep` checkpoint files, you might want to keep one checkpoint file
+ for every N hours of training. This can be useful if you want to later
+ analyze how a model progressed during a long training session. For
+ example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
+ one checkpoint file for every 2 hours of training. The default value of
+ 10,000 hours effectively disables the feature.
+
+ Note that you still have to call the `save()` method to save the model.
+ Passing these arguments to the constructor will not save variables
+ automatically for you.
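+
+  For example (illustrative; `...variables...` is a placeholder, as in the
+  other examples in this docstring):
+
+  ```python
+  # Keep the 10 most recent checkpoints, plus one checkpoint for every
+  # 2 hours of training.
+  saver = tf.train.Saver(...variables..., max_to_keep=10,
+                         keep_checkpoint_every_n_hours=2)
+  ```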
+
+ A training program that saves regularly looks like:
+
+ ```python
+ ...
+ # Create a saver.
+ saver = tf.train.Saver(...variables...)
+ # Launch the graph and train, saving the model every 1,000 steps.
+ sess = tf.Session()
+ for step in xrange(1000000):
+ sess.run(..training_op..)
+ if step % 1000 == 0:
+ # Append the step number to the checkpoint name:
+ saver.save(sess, 'my-model', global_step=step)
+ ```
+
+ In addition to checkpoint files, savers keep a protocol buffer on disk with
+ the list of recent checkpoints. This is used to manage numbered checkpoint
+ files and by `latest_checkpoint()`, which makes it easy to discover the path
+ to the most recent checkpoint. That protocol buffer is stored in a file named
+ 'checkpoint' next to the checkpoint files.
+
+ If you create several savers, you can specify a different filename for the
+ protocol buffer file in the call to `save()`.
+
+ @@__init__
+ @@save
+ @@restore
+
+ Other utility methods.
+
+ @@last_checkpoints
+ @@set_last_checkpoints
+ @@as_saver_def
+ """
+
+ def __init__(self,
+ var_list=None,
+ reshape=False,
+ sharded=False,
+ max_to_keep=5,
+ keep_checkpoint_every_n_hours=10000.0,
+ name=None,
+ restore_sequentially=False,
+ saver_def=None,
+ builder=None):
+ """Creates a `Saver`.
+
+ The constructor adds ops to save and restore variables.
+
+ `var_list` specifies the variables that will be saved and restored. It can
+ be passed as a `dict` or a list:
+
+ * A `dict` of names to variables: The keys are the names that will be
+ used to save or restore the variables in the checkpoint files.
+ * A list of variables: The variables will be keyed with their op name in
+ the checkpoint files.
+
+ For example:
+
+ ```python
+ v1 = tf.Variable(..., name='v1')
+ v2 = tf.Variable(..., name='v2')
+
+ # Pass the variables as a dict:
+ saver = tf.train.Saver({'v1': v1, 'v2': v2})
+
+ # Or pass them as a list.
+ saver = tf.train.Saver([v1, v2])
+ # Passing a list is equivalent to passing a dict with the variable op names
+ # as keys:
+ saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
+ ```
+
+ The optional `reshape` argument, if True, allows restoring a variable from
+ a save file where the variable had a different shape, but the same number
+ of elements and type. This is useful if you have reshaped a variable and
+ want to reload it from an older checkpoint.
+
+ The optional `sharded` argument, if True, instructs the saver to shard
+ checkpoints per device.
+
+ Args:
+ var_list: A list of Variables or a dictionary mapping names to
+ Variables. If None, defaults to the list of all variables.
+ reshape: If True, allows restoring parameters from a checkpoint
+ where the variables have a different shape.
+ sharded: If True, shard the checkpoints, one per device.
+ max_to_keep: maximum number of recent checkpoints to keep.
+        Defaults to 5.
+ keep_checkpoint_every_n_hours: How often to keep checkpoints.
+ Defaults to 10,000 hours.
+ name: string. Optional name to use as a prefix when adding operations.
+ restore_sequentially: A Bool, which if true, causes restore of different
+ variables to happen sequentially within each device. This can lower
+ memory usage when restoring very large models.
+ saver_def: Optional SaverDef proto to use instead of running the builder.
+ This is only useful for specialty code that wants to recreate a Saver
+ object for a previously built Graph that had a Saver. The saver_def
+ proto should be the one returned by the as_saver_def() call of the
+ Saver that was created for that Graph.
+ builder: Optional SaverBuilder to use if a saver_def was not provided.
+ Defaults to BaseSaverBuilder().
+
+ Raises:
+ TypeError: If `var_list` is invalid.
+ ValueError: If any of the keys or values in `var_list` is not unique.
+ """
+ if saver_def is None:
+ if builder is None:
+ builder = BaseSaverBuilder()
+ if var_list is None:
+ var_list = variables.all_variables()
+ if not var_list:
+ raise ValueError("No variables to save")
+ saver_def = builder.build(
+ var_list,
+ reshape=reshape,
+ sharded=sharded,
+ max_to_keep=max_to_keep,
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+ name=name,
+ restore_sequentially=restore_sequentially)
+ if not isinstance(saver_def, saver_pb2.SaverDef):
+ raise ValueError("saver_def must if a saver_pb2.SaverDef: %s" % saver_def)
+ if not saver_def.save_tensor_name:
+ raise ValueError("saver_def must specify the save_tensor_name: %s"
+ % str(saver_def))
+ if not saver_def.restore_op_name:
+ raise ValueError("saver_def must specify the restore_op_name: %s"
+ % str(saver_def))
+ self._filename_tensor_name = saver_def.filename_tensor_name
+ self._save_tensor_name = saver_def.save_tensor_name
+ self._restore_op_name = saver_def.restore_op_name
+ self._max_to_keep = saver_def.max_to_keep
+ # If keep_checkpoint_every_n_hours is not set, set it to 10000 hours.
+ self._keep_checkpoint_every_n_hours = (
+ saver_def.keep_checkpoint_every_n_hours if
+ saver_def.keep_checkpoint_every_n_hours else 10000)
+ self._next_checkpoint_time = (
+ time.time() + self._keep_checkpoint_every_n_hours * 3600)
+ self._sharded = saver_def.sharded
+ self._last_checkpoints = []
+
+ def _CheckpointFilename(self, p):
+ """Returns the checkpoint file name.
+
+    If p is a (filename, time) pair, return p[0]; otherwise return p.
+
+ Args:
+ p: (filename, time) pair or just checkpoint filename.
+
+ Returns:
+ Checkpoint file name.
+ """
+ return p[0] if isinstance(p, tuple) else p
+
+ def _MaybeDeleteOldCheckpoints(self, latest_save_path):
+ """Deletes old checkpoints if necessary.
+
+ Always keep the last max_to_keep checkpoints. If
+ keep_checkpoint_every_n_hours was specified, keep an additional checkpoint
+ every N hours. For example, if N is 0.5, an additional checkpoint is kept
+ for every 0.5 hours of training; if N is 10, an additional checkpoint is
+ kept for every 10 hours of training.
+
+ Args:
+ latest_save_path: Name including path of checkpoint file to save.
+ """
+ if not self._max_to_keep:
+ return
+ # Remove first from list if the same name was used before.
+ for p in self._last_checkpoints:
+ if latest_save_path == self._CheckpointFilename(p):
+ self._last_checkpoints.remove(p)
+ # Append new path to list
+ self._last_checkpoints.append((latest_save_path, time.time()))
+ # If more than max_to_keep, remove oldest.
+ if len(self._last_checkpoints) > self._max_to_keep:
+ p = self._last_checkpoints.pop(0)
+      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
+ # have reached N hours of training.
+ should_keep = p[1] > self._next_checkpoint_time
+ if should_keep:
+ self._next_checkpoint_time += (
+ self._keep_checkpoint_every_n_hours * 3600)
+ return
+ # Otherwise delete the files.
+ for f in gfile.Glob(self._CheckpointFilename(p)):
+ try:
+ gfile.Remove(f)
+ except gfile.GOSError, e:
+ logging.warning("Ignoring: %s", str(e))
+
+ def as_saver_def(self):
+ """Generates a `SaverDef` representation of this saver.
+
+ Returns:
+ A `SaverDef` proto.
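+
+    For example, a sketch of recreating a `Saver` for the same graph
+    (assuming `saver` was created for that graph):
+
+    ```python
+    saver_def = saver.as_saver_def()
+    new_saver = tf.train.Saver(saver_def=saver_def)
+    ```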
+ """
+ return saver_pb2.SaverDef(
+ filename_tensor_name=self._filename_tensor_name,
+ save_tensor_name=self._save_tensor_name,
+ restore_op_name=self._restore_op_name,
+ max_to_keep=self._max_to_keep,
+ keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
+ sharded=self._sharded)
+
+ @property
+ def last_checkpoints(self):
+ """List of not-yet-deleted checkpoint filenames.
+
+ You can pass any of the returned values to `restore()`.
+
+ Returns:
+ A list of checkpoint filenames, sorted from oldest to newest.
+ """
+ return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
+
+ def set_last_checkpoints(self, last_checkpoints):
+ """Sets the list of not-yet-deleted checkpoint filenames.
+
+ Args:
+ last_checkpoints: a list of checkpoint filenames.
+
+ Raises:
+ AssertionError: if the list of checkpoint filenames has already been set.
+ """
+ assert not self._last_checkpoints
+ assert isinstance(last_checkpoints, list)
+ self._last_checkpoints = list(last_checkpoints)
+
+ def save(self, sess, save_path, global_step=None, latest_filename=None):
+ """Saves variables.
+
+ This method runs the ops added by the constructor for saving variables.
+ It requires a session in which the graph was launched. The variables to
+ save must also have been initialized.
+
+ The method returns the path of the newly created checkpoint file. This
+ path can be passed directly to a call to `restore()`.
+
+ Args:
+      sess: A Session to use to save the variables.
+ save_path: string. Path to the checkpoint filename. If the saver is
+ `sharded`, this is the prefix of the sharded checkpoint filename.
+ global_step: If provided the global step number is appended to
+ `save_path` to create the checkpoint filename. The optional argument
+ can be a Tensor, a Tensor name or an integer.
+ latest_filename: Optional name for the protocol buffer file that will
+        contain the list of most recent checkpoint filenames. That file,
+ kept in the same directory as the checkpoint files, is automatically
+ managed by the saver to keep track of recent checkpoints. Defaults to
+ 'checkpoint'.
+
+ Returns:
+ A string: path at which the variables were saved. If the saver is
+ sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
+ is the number of shards created.
+
+ Raises:
+ TypeError: If `sess` is not a Session.
+ """
+ if latest_filename is None:
+ latest_filename = "checkpoint"
+ if global_step is not None:
+ if not isinstance(global_step, numbers.Number):
+ global_step = training_util.global_step(sess, global_step)
+ checkpoint_file = "%s-%d" % (save_path, global_step)
+ else:
+ checkpoint_file = save_path
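+    # update_checkpoint_state() below expects the directory that holds the
+    # checkpoint files, so strip the filename part from save_path.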
+ save_path = os.path.dirname(save_path)
+ if not isinstance(sess, session.SessionInterface):
+ raise TypeError("'sess' must be a Session; %s" % sess)
+
+ model_checkpoint_path = sess.run(
+ self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
+ model_checkpoint_path = str(model_checkpoint_path)
+ self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
+ update_checkpoint_state(save_path, model_checkpoint_path,
+ self.last_checkpoints, latest_filename)
+ return model_checkpoint_path
+
+ def restore(self, sess, save_path):
+ """Restores previously saved variables.
+
+ This method runs the ops added by the constructor for restoring variables.
+ It requires a session in which the graph was launched. The variables to
+ restore do not have to have been initialized, as restoring is itself a way
+ to initialize variables.
+
+ The `save_path` argument is typically a value previously returned from a
+ `save()` call, or a call to `latest_checkpoint()`.
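+
+    For example (a sketch; 'my-model-1000' stands for a path previously
+    returned by `save()`):
+
+    ```python
+    saver.restore(sess, 'my-model-1000')
+    ```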
+
+ Args:
+ sess: A Session to use to restore the parameters.
+ save_path: Path where parameters were previously saved.
+ """
+ sess.run([self._restore_op_name], {self._filename_tensor_name: save_path})
+
+
+def latest_checkpoint(checkpoint_dir, latest_filename=None):
+ """Finds the filename of latest saved checkpoint file.
+
+ Args:
+ checkpoint_dir: Directory where the variables were saved.
+ latest_filename: Optional name for the protocol buffer file that
+ contains the list of most recent checkpoint filenames.
+ See the corresponding argument to `Saver.save()`.
+
+ Returns:
+ The full path to the latest checkpoint or None if no checkpoint was found.
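+
+  A typical usage sketch (assuming `saver` and `sess` already exist;
+  '/tmp/train_dir' is a hypothetical checkpoint directory):
+
+  ```python
+  ckpt_path = latest_checkpoint('/tmp/train_dir')
+  if ckpt_path:
+    saver.restore(sess, ckpt_path)
+  ```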
+ """
+ # Pick the latest checkpoint based on checkpoint state.
+ ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
+ if ckpt and ckpt.model_checkpoint_path:
+ checkpoint_full_path = os.path.join(
+ checkpoint_dir, ckpt.model_checkpoint_path)
+ if gfile.Exists(checkpoint_full_path):
+ return checkpoint_full_path
+
+ return None
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
new file mode 100644
index 0000000000..db378e9637
--- /dev/null
+++ b/tensorflow/python/training/saver_test.py
@@ -0,0 +1,563 @@
+"""Tests for tensorflow.ops.io_ops."""
+import os.path
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+import numpy as np
+
+from tensorflow.python.platform import gfile
+
+
+class SaverTest(tf.test.TestCase):
+
+ def testBasics(self):
+ save_path = os.path.join(self.get_temp_dir(), "basics")
+
+ with self.test_session() as sess:
+ # Build a graph with 2 parameter nodes, and Save and
+ # Restore nodes for them.
+ v0 = tf.Variable(10.0, name="v0")
+ v1 = tf.Variable(20.0, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Save the initialized values in the file at "save_path"
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ # Start a second session. In that session the parameter nodes
+ # have not been initialized either.
+ with self.test_session() as sess:
+ v0 = tf.Variable(-1.0, name="v0")
+ v1 = tf.Variable(-1.0, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1})
+
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v0" in e.message):
+ sess.run(v0)
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v1" in e.message):
+ sess.run(v1)
+
+ # Restore the saved values in the parameter nodes.
+ save.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Build another graph with 2 nodes, initialized
+ # differently, and a Restore node for them.
+ with self.test_session() as sess:
+ v0_2 = tf.Variable(1000.0, name="v0")
+ v1_2 = tf.Variable(2000.0, name="v1")
+ save2 = tf.train.Saver({"v0": v0_2, "v1": v1_2})
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(1000.0, v0_2.eval())
+ self.assertEqual(2000.0, v1_2.eval())
+ # Restore the values saved earlier in the parameter nodes.
+ save2.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0_2.eval())
+ self.assertEqual(20.0, v1_2.eval())
+
+ def testInt64(self):
+ save_path = os.path.join(self.get_temp_dir(), "int64")
+
+ with self.test_session() as sess:
+ # Build a graph with 1 node, and save and restore for them.
+ v = tf.Variable(np.int64(15), name="v")
+ save = tf.train.Saver({"v": v}, restore_sequentially=True)
+ tf.initialize_all_variables().run()
+
+ # Save the initialized values in the file at "save_path"
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ with self.test_session() as sess:
+ v = tf.Variable(np.int64(-1), name="v")
+ save = tf.train.Saver({"v": v})
+
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v" in e.message):
+ sess.run(v)
+
+ # Restore the saved values in the parameter nodes.
+ save.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(np.int64(15), v.eval())
+
+ def testSomeErrors(self):
+ with tf.Graph().as_default():
+ v0 = tf.Variable([10.0], name="v0")
+ v1 = tf.Variable([20.0], name="v1")
+ v2 = tf.Variable([20.0], name="v2")
+ v2._set_save_slice_info(tf.Variable.SaveSliceInfo("v1", ""))
+
+ # By default the name used for "v2" will be "v1" and raise an error.
+ with self.assertRaisesRegexp(ValueError, "same name: v1"):
+ tf.train.Saver([v0, v1, v2])
+
+ # The names are different and will work.
+ tf.train.Saver({"vee1": v1, "other": [v2]})
+
+ def testBasicsWithListOfVariables(self):
+ save_path = os.path.join(self.get_temp_dir(), "basics_with_list")
+
+ with self.test_session(graph=tf.Graph()) as sess:
+ # Build a graph with 2 parameter nodes, and Save and
+ # Restore nodes for them.
+ v0 = tf.Variable(10.0, name="v0")
+ v1 = tf.Variable(20.0, name="v1")
+ save = tf.train.Saver([v0, v1])
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Save the initialized values in the file at "save_path"
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ # Start a second session. In that session the variables
+ # have not been initialized either.
+ with self.test_session(graph=tf.Graph()) as sess:
+ v0 = tf.Variable(-1.0, name="v0")
+ v1 = tf.Variable(-1.0, name="v1")
+ save = tf.train.Saver([v0, v1])
+
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v0" in e.message):
+ sess.run(v0)
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v1" in e.message):
+ sess.run(v1)
+
+ # Restore the saved values in the parameter nodes.
+ save.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Build another graph with 2 nodes, initialized
+ # differently, and a Restore node for them.
+ with self.test_session(graph=tf.Graph()) as sess:
+ v0_2 = tf.Variable(1000.0, name="v0")
+ v1_2 = tf.Variable(2000.0, name="v1")
+ save2 = tf.train.Saver([v0_2, v1_2])
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(1000.0, v0_2.eval())
+ self.assertEqual(2000.0, v1_2.eval())
+ # Restore the values saved earlier in the parameter nodes.
+ save2.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0_2.eval())
+ self.assertEqual(20.0, v1_2.eval())
+
+ def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
+ with self.test_session() as sess:
+ var = tf.Variable(var_value, name=var_name)
+ save = tf.train.Saver({var_name: var})
+ var.initializer.run()
+ val = save.save(sess, save_path)
+ self.assertEqual(save_path, val)
+ with self.test_session() as sess:
+ var = tf.Variable(other_value, name=var_name)
+ save = tf.train.Saver({var_name: var})
+ save.restore(sess, save_path)
+ self.assertAllClose(var_value, var.eval())
+
+ def testCacheRereadsFile(self):
+ save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
+ # Save and reload one Variable named "var0".
+ self._SaveAndLoad("var0", 0.0, 1.0, save_path)
+ # Save and reload one Variable named "var1" in the same file.
+ # The cached readers should know to re-read the file.
+ self._SaveAndLoad("var1", 1.1, 2.2, save_path)
+
+ def testGPU(self):
+ if not tf.test.IsBuiltWithCuda():
+ return
+ save_path = os.path.join(self.get_temp_dir(), "gpu")
+ with tf.Session("", graph=tf.Graph()) as sess:
+ with sess.graph.device("/gpu:0"):
+ v0_1 = tf.Variable(123.45)
+ save = tf.train.Saver({"v0": v0_1})
+ tf.initialize_all_variables().run()
+ save.save(sess, save_path)
+
+ with tf.Session("", graph=tf.Graph()) as sess:
+ with sess.graph.device("/gpu:0"):
+ v0_2 = tf.Variable(543.21)
+ save = tf.train.Saver({"v0": v0_2})
+ tf.initialize_all_variables().run()
+ self.assertAllClose(543.21, v0_2.eval())
+ save.restore(sess, save_path)
+ self.assertAllClose(123.45, v0_2.eval())
+
+ def testVariables(self):
+ save_path = os.path.join(self.get_temp_dir(), "variables")
+ with tf.Session("", graph=tf.Graph()) as sess:
+ one = tf.Variable(1.0)
+ twos = tf.Variable([2.0, 2.0, 2.0])
+ init = tf.initialize_all_variables()
+ save = tf.train.Saver(tf.all_variables())
+ init.run()
+ save.save(sess, save_path)
+
+ with tf.Session("", graph=tf.Graph()) as sess:
+ one = tf.Variable(0.0)
+ twos = tf.Variable([0.0, 0.0, 0.0])
+ # Saver with no arg, defaults to 'all variables'.
+ save = tf.train.Saver()
+ save.restore(sess, save_path)
+ self.assertAllClose(1.0, one.eval())
+ self.assertAllClose([2.0, 2.0, 2.0], twos.eval())
+
+ def testSaveWithGlobalStep(self):
+ save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
+ global_step_int = 5
+ # Save and reload one Variable named "var0".
+ self._SaveAndLoad("var0", 0.0, 1.0, save_path)
+ for use_tensor in [True, False]:
+ with self.test_session() as sess:
+ var = tf.Variable(1.0, name="var0")
+ save = tf.train.Saver({var.op.name: var})
+ var.initializer.run()
+ if use_tensor:
+ global_step = tf.constant(global_step_int)
+ val = save.save(sess, save_path, global_step=global_step)
+ else:
+ val = save.save(sess, save_path, global_step=global_step_int)
+ expected_save_path = "%s-%d" % (save_path, global_step_int)
+ self.assertEqual(expected_save_path, val)
+
+
+class SaveRestoreShardedTest(tf.test.TestCase):
+
+ def testBasics(self):
+ save_path = os.path.join(self.get_temp_dir(), "sharded")
+
+ # Build a graph with 2 parameter nodes on different devices.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(10, name="v0")
+ with sess.graph.device("/cpu:1"):
+ v1 = tf.Variable(20, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, sharded=True)
+ tf.initialize_all_variables().run()
+ val = save.save(sess, save_path)
+ self.assertEqual(save_path + "-?????-of-00002", val)
+
+ # Restore a different "v0" from shard 0 of the saved files.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(111, name="v0")
+ save = tf.train.Saver({"v0": v0}, sharded=True)
+ tf.initialize_all_variables().run()
+ self.assertEqual(111, v0.eval())
+ save.restore(sess, save_path + "-00000-of-00002")
+ self.assertEqual(10, v0.eval())
+
+ # Restore a different "v1" from shard 1 of the saved files.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v1 = tf.Variable(222)
+ save = tf.train.Saver({"v1": v1}, sharded=True)
+ tf.initialize_all_variables().run()
+ self.assertEqual(222, v1.eval())
+ save.restore(sess, save_path + "-00001-of-00002")
+ self.assertEqual(20, v1.eval())
+
+ # Now try a restore with the sharded filename.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(111, name="v0")
+ with sess.graph.device("/cpu:1"):
+ v1 = tf.Variable(222, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, sharded=True)
+ tf.initialize_all_variables().run()
+ self.assertEqual(111, v0.eval())
+ self.assertEqual(222, v1.eval())
+ save_path = os.path.join(self.get_temp_dir(), "sharded")
+ save.restore(sess, save_path + "-?????-of-?????")
+ self.assertEqual(10, v0.eval())
+ self.assertEqual(20, v1.eval())
+
+ def testSaverDef(self):
+ with self.test_session():
+ v0 = tf.Variable(123, name="v0")
+ save = tf.train.Saver({"v0": v0}, sharded=True)
+ sd = save.as_saver_def()
+ self.assertTrue(sd.sharded)
+
+
+class MaxToKeepTest(tf.test.TestCase):
+
+ def testNonSharded(self):
+ save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded")
+ try:
+ gfile.DeleteRecursively(save_dir)
+ except gfile.GOSError, _:
+ pass # Ignore
+ gfile.MakeDirs(save_dir)
+
+ with self.test_session() as sess:
+ v = tf.Variable(10.0, name="v")
+ save = tf.train.Saver({"v": v}, max_to_keep=2)
+ tf.initialize_all_variables().run()
+ self.assertEqual([], save.last_checkpoints)
+
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s1], save.last_checkpoints)
+ self.assertTrue(gfile.Exists(s1))
+
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s1, s2], save.last_checkpoints)
+ self.assertTrue(gfile.Exists(s1))
+ self.assertTrue(gfile.Exists(s2))
+
+ s3 = save.save(sess, os.path.join(save_dir, "s3"))
+ self.assertEqual([s2, s3], save.last_checkpoints)
+ self.assertFalse(gfile.Exists(s1))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s3))
+
+ # Create a second helper, identical to the first.
+ save2 = tf.train.Saver(saver_def=save.as_saver_def())
+ save2.set_last_checkpoints(save.last_checkpoints)
+
+ # Create a third helper, with the same configuration but no knowledge of
+ # previous checkpoints.
+ save3 = tf.train.Saver(saver_def=save.as_saver_def())
+
+ # Exercise the first helper.
+
+ # Adding s2 again (old s2 is removed first, then new s2 appended)
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s3, s2], save.last_checkpoints)
+ self.assertFalse(gfile.Exists(s1))
+ self.assertTrue(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+
+ # Adding s1 (s3 should now be deleted as oldest in list)
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s2, s1], save.last_checkpoints)
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s1))
+
+ # Exercise the second helper.
+
+ # Adding s2 again (old s2 is removed first, then new s2 appended)
+ s2 = save2.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s3, s2], save2.last_checkpoints)
+ # Created by the first helper.
+ self.assertTrue(gfile.Exists(s1))
+ # Deleted by the first helper.
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+
+ # Adding s1 (s3 should now be deleted as oldest in list)
+ s1 = save2.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s2, s1], save2.last_checkpoints)
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s1))
+
+ # Exercise the third helper.
+
+ # Adding s2 again (but helper is unaware of previous s2)
+ s2 = save3.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s2], save3.last_checkpoints)
+ # Created by the first helper.
+ self.assertTrue(gfile.Exists(s1))
+ # Deleted by the first helper.
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+
+ # Adding s1 (s3 should not be deleted because helper is unaware of it)
+ s1 = save3.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s2, s1], save3.last_checkpoints)
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s1))
+
+ def testSharded(self):
+ save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_sharded")
+ try:
+ gfile.DeleteRecursively(save_dir)
+    except gfile.GOSError as _:
+ pass # Ignore
+ gfile.MakeDirs(save_dir)
+
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(111, name="v0")
+ with sess.graph.device("/cpu:1"):
+ v1 = tf.Variable(222, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, sharded=True, max_to_keep=2)
+ tf.initialize_all_variables().run()
+ self.assertEqual([], save.last_checkpoints)
+
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s1], save.last_checkpoints)
+ self.assertEquals(2, len(gfile.Glob(s1)))
+
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s1, s2], save.last_checkpoints)
+ self.assertEquals(2, len(gfile.Glob(s1)))
+ self.assertEquals(2, len(gfile.Glob(s2)))
+
+ s3 = save.save(sess, os.path.join(save_dir, "s3"))
+ self.assertEqual([s2, s3], save.last_checkpoints)
+ self.assertEquals(0, len(gfile.Glob(s1)))
+ self.assertEquals(2, len(gfile.Glob(s2)))
+ self.assertEquals(2, len(gfile.Glob(s3)))
+
+
+class KeepCheckpointEveryNHoursTest(tf.test.TestCase):
+
+ def testNonSharded(self):
+ save_dir = os.path.join(self.get_temp_dir(),
+ "keep_checkpoint_every_n_hours")
+ try:
+ gfile.DeleteRecursively(save_dir)
+    except gfile.GOSError as _:
+ pass # Ignore
+ gfile.MakeDirs(save_dir)
+
+ with self.test_session() as sess:
+ v = tf.Variable([10.0], name="v")
+ # Run the initializer NOW to avoid the 0.5s overhead of the first Run()
+ # call, which throws the test timing off in fastbuild mode.
+ tf.initialize_all_variables().run()
+ # Create a saver that will keep the last 2 checkpoints plus one every 0.7
+ # seconds.
+ start_time = time.time()
+ save = tf.train.Saver({"v": v}, max_to_keep=2,
+ keep_checkpoint_every_n_hours=0.7 / 3600)
+ self.assertEqual([], save.last_checkpoints)
+
+      # Wait till 0.7 seconds have elapsed so s1 will be old enough to keep.
+ time.sleep((time.time() + 0.7) - start_time)
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s1], save.last_checkpoints)
+
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s1, s2], save.last_checkpoints)
+
+      # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save()
+      # would normally delete s1, because max_to_keep is 2. However, s1 is
+      # older than 0.7s, so we must keep it.
+ s3 = save.save(sess, os.path.join(save_dir, "s3"))
+ self.assertEqual([s2, s3], save.last_checkpoints)
+
+      # s1 should still be here; we are not checking now, to reduce timing
+      # variance in the test.
+
+      # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next
+      # call to Save() will delete s2, because max_to_keep is 2 and because
+      # we already kept the old s1. s2 is very close in time to s1, so it
+      # gets deleted.
+ s4 = save.save(sess, os.path.join(save_dir, "s4"))
+ self.assertEqual([s3, s4], save.last_checkpoints)
+
+ # Check that s1 is still here, but s2 is gone.
+ self.assertTrue(gfile.Exists(s1))
+ self.assertFalse(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s4))
+
+
+class SaveRestoreWithVariableNameMap(tf.test.TestCase):
+
+ def testNonReshape(self):
+ save_path = os.path.join(self.get_temp_dir(), "basics")
+
+ with self.test_session() as sess:
+ # Build a graph with 2 parameter nodes, and Save and
+ # Restore nodes for them.
+ v0 = tf.Variable(10.0, name="v0")
+ v1 = tf.Variable(20.0, name="v1")
+ save = tf.train.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Save the initialized values in the file at "save_path"
+ # Use a variable name map to set the saved tensor names
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ # Verify that the original names are not in the Saved file
+ save = tf.train.Saver({"v0": v0, "v1": v1})
+ with self.assertRaisesOpError("not found in checkpoint"):
+ save.restore(sess, save_path)
+
+ # Verify that the mapped names are present in the Saved file and can be
+ # Restored using remapped names.
+ with self.test_session() as sess:
+ v0 = tf.Variable(-1.0, name="v0")
+ v1 = tf.Variable(-1.0, name="v1")
+
+ with self.assertRaisesOpError("uninitialized value v0"):
+ sess.run(v0)
+ with self.assertRaisesOpError("uninitialized value v1"):
+ sess.run(v1)
+
+ save = tf.train.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
+ save.restore(sess, save_path)
+
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Add a prefix to the node names in the current graph and Restore using
+ # remapped names.
+ with self.test_session() as sess:
+ v0 = tf.Variable(-1.0, name="restore_prefix/v0")
+ v1 = tf.Variable(-1.0, name="restore_prefix/v1")
+
+ with self.assertRaisesOpError("uninitialized value restore_prefix/v0"):
+ sess.run(v0)
+ with self.assertRaisesOpError("uninitialized value restore_prefix/v1"):
+ sess.run(v1)
+
+ # Restore the saved values in the parameter nodes.
+ save = tf.train.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
+ save.restore(sess, save_path)
+
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
new file mode 100644
index 0000000000..dd994c5311
--- /dev/null
+++ b/tensorflow/python/training/summary_io.py
@@ -0,0 +1,226 @@
+"""Reads Summaries from and writes Summaries to event files."""
+
+import os.path
+import Queue
+import threading
+import time
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.platform import gfile
+
+
+class SummaryWriter(object):
+ """Writes `Summary` protocol buffers to event files.
+
+ The `SummaryWriter` class provides a mechanism to create an event file in a
+ given directory and add summaries and events to it. The class updates the
+ file contents asynchronously. This allows a training program to call methods
+ to add data to the file directly from the training loop, without slowing down
+ training.
+
+ @@__init__
+
+ @@add_summary
+ @@add_event
+ @@add_graph
+
+ @@flush
+ @@close
+ """
+
+ def __init__(self, logdir, graph_def=None, max_queue=10, flush_secs=120):
+ """Creates a `SummaryWriter` and an event file.
+
+ On construction the summary writer creates a new event file in `logdir`.
+ This event file will contain `Event` protocol buffers constructed when you
+ call one of the following functions: `add_summary()`, `add_event()`, or
+ `add_graph()`.
+
+ If you pass a `graph_def` protocol buffer to the constructor it is added to
+ the event file. (This is equivalent to calling `add_graph()` later).
+
+    TensorBoard will pick up the graph from the file and display it
+    graphically so you can interactively explore the graph you built. You
+    will usually pass the graph from the session in which you launched it:
+
+ ```python
+ ...create a graph...
+ # Launch the graph in a session.
+ sess = tf.Session()
+ # Create a summary writer, add the 'graph_def' to the event file.
+ writer = tf.train.SummaryWriter(<some-directory>, sess.graph_def)
+ ```
+
+ The other arguments to the constructor control the asynchronous writes to
+ the event file:
+
+ * `flush_secs`: How often, in seconds, to flush the added summaries
+ and events to disk.
+    * `max_queue`: Maximum number of summaries or events pending to be
+      written to disk before one of the 'add' calls blocks.
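+
+    For example, a minimal sketch of a writer tuned for fewer disk writes
+    (the log directory name is only illustrative):
+
+    ```python
+    # Queue up to 100 pending events; flush them every 10 minutes.
+    writer = tf.train.SummaryWriter("/tmp/mylogs", max_queue=100,
+                                    flush_secs=600)
+    ```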
+
+ Args:
+ logdir: A string. Directory where event file will be written.
+ graph_def: A `GraphDef` protocol buffer.
+ max_queue: Integer. Size of the queue for pending events and summaries.
+ flush_secs: Number. How often, in seconds, to flush the
+ pending events and summaries to disk.
+ """
+ self._logdir = logdir
+ if not gfile.IsDirectory(self._logdir):
+ gfile.MakeDirs(self._logdir)
+ self._event_queue = Queue.Queue(max_queue)
+ self._ev_writer = pywrap_tensorflow.EventsWriter(
+ os.path.join(self._logdir, "events"))
+ self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
+ flush_secs)
+ self._worker.start()
+ if graph_def is not None:
+ self.add_graph(graph_def)
+
+ def add_summary(self, summary, global_step=None):
+ """Adds a `Summary` protocol buffer to the event file.
+
+    This method wraps the provided summary in an `Event` protocol buffer
+ and adds it to the event file.
+
+ You can pass the output of any summary op, as-is, to this function. You
+    can also pass a `Summary` protocol buffer that you manufacture with your
+ own data. This is commonly done to report evaluation results in event
+ files.
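+
+    For example, a minimal sketch of reporting a hand-built evaluation
+    result, assuming `writer` is a `SummaryWriter`; the tag and values are
+    only illustrative:
+
+    ```python
+    summary = tf.Summary(value=[tf.Summary.Value(tag="eval/accuracy",
+                                                 simple_value=0.87)])
+    writer.add_summary(summary, global_step=1000)
+    ```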
+
+ Args:
+ summary: A `Summary` protocol buffer, optionally serialized as a string.
+ global_step: Number. Optional global step value to record with the
+ summary.
+ """
+ if isinstance(summary, basestring):
+ summ = summary_pb2.Summary()
+ summ.ParseFromString(summary)
+ summary = summ
+ event = event_pb2.Event(wall_time=time.time(), summary=summary)
+ if global_step is not None:
+ event.step = long(global_step)
+ self.add_event(event)
+
+ def add_event(self, event):
+ """Adds an event to the event file.
+
+ Args:
+ event: An `Event` protocol buffer.
+ """
+ self._event_queue.put(event)
+
+ def add_graph(self, graph_def, global_step=None):
+ """Adds a `GraphDef` protocol buffer to the event file.
+
+ The graph described by the protocol buffer will be displayed by
+ TensorBoard. Most users pass a graph in the constructor instead.
+
+ Args:
+ graph_def: A `GraphDef` protocol buffer.
+ global_step: Number. Optional global step counter to record with the
+ graph.
+ """
+ event = event_pb2.Event(wall_time=time.time(), graph_def=graph_def)
+ if global_step is not None:
+ event.step = long(global_step)
+ self._event_queue.put(event)
+
+ def flush(self):
+ """Flushes the event file to disk.
+
+ Call this method to make sure that all pending events have been written to
+ disk.
+ """
+ self._event_queue.join()
+ self._ev_writer.Flush()
+
+ def close(self):
+    """Flushes the event file to disk and closes the file.
+
+ Call this method when you do not need the summary writer anymore.
+ """
+ self.flush()
+ self._ev_writer.Close()
+
+
+class _EventLoggerThread(threading.Thread):
+ """Thread that logs events."""
+
+ def __init__(self, queue, ev_writer, flush_secs):
+ """Creates an _EventLoggerThread.
+
+ Args:
+ queue: a Queue from which to dequeue events.
+      ev_writer: an event writer. Used to log events for
+        the visualizer.
+ flush_secs: How often, in seconds, to flush the
+ pending file to disk.
+ """
+ threading.Thread.__init__(self)
+ self.daemon = True
+ self._queue = queue
+ self._ev_writer = ev_writer
+ self._flush_secs = flush_secs
+ # The first event will be flushed immediately.
+ self._next_event_flush_time = 0
+
+ def run(self):
+ while True:
+ event = self._queue.get()
+ try:
+ self._ev_writer.WriteEvent(event)
+ # Flush the event writer every so often.
+ now = time.time()
+ if now > self._next_event_flush_time:
+ self._ev_writer.Flush()
+          # Schedule the next flush `flush_secs` seconds from now.
+ self._next_event_flush_time = now + self._flush_secs
+ finally:
+ self._queue.task_done()
+
+
+def summary_iterator(path):
+ """An iterator for reading `Event` protocol buffers from an event file.
+
+ You can use this function to read events written to an event file. It returns
+ a Python iterator that yields `Event` protocol buffers.
+
+ Example: Print the contents of an events file.
+
+ ```python
+ for e in tf.summary_iterator(path to events file):
+ print e
+ ```
+
+ Example: Print selected summary values.
+
+ ```python
+  # This example supposes that the events file contains summaries with a
+  # summary value tag 'loss'. These could have been added by calling
+  # `add_summary()`, passing the output of a scalar summary op created
+  # with: `tf.scalar_summary(['loss'], loss_tensor)`.
+ for e in tf.summary_iterator(path to events file):
+ for v in e.summary.value:
+ if v.tag == 'loss':
+ print v.simple_value
+ ```
+
+ See the protocol buffer definitions of
+ [Event](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/util/event.proto)
+ and
+ [Summary](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+ for more information about their attributes.
+
+ Args:
+ path: The path to an event file created by a `SummaryWriter`.
+
+ Yields:
+ `Event` protocol buffers.
+ """
+ for r in tf_record.tf_record_iterator(path):
+ yield event_pb2.Event.FromString(r)
diff --git a/tensorflow/python/training/summary_writer_test.py b/tensorflow/python/training/summary_writer_test.py
new file mode 100644
index 0000000000..2ec416f68f
--- /dev/null
+++ b/tensorflow/python/training/summary_writer_test.py
@@ -0,0 +1,151 @@
+"""Tests for training_coordinator.py."""
+import glob
+import os.path
+import shutil
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class SummaryWriterTestCase(tf.test.TestCase):
+
+ def _TestDir(self, test_name):
+ test_dir = os.path.join(self.get_temp_dir(), test_name)
+ return test_dir
+
+ def _CleanTestDir(self, test_name):
+ test_dir = self._TestDir(test_name)
+ if os.path.exists(test_dir):
+ shutil.rmtree(test_dir)
+ return test_dir
+
+ def _EventsReader(self, test_dir):
+ event_paths = glob.glob(os.path.join(test_dir, "event*"))
+    # If the test runs multiple times in the same directory, we can have
+    # more than one matching event file. We only want to read the last one.
+ self.assertTrue(event_paths)
+ return tf.train.summary_iterator(event_paths[-1])
+
+ def _assertRecent(self, t):
+ self.assertTrue(abs(t - time.time()) < 5)
+
+ def testBasics(self):
+ test_dir = self._CleanTestDir("basics")
+ sw = tf.train.SummaryWriter(test_dir)
+ sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="mee",
+ simple_value=10.0)]),
+ 10)
+ sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="boo",
+ simple_value=20.0)]),
+ 20)
+ with tf.Graph().as_default() as g:
+ tf.constant([0], name="zero")
+ gd = g.as_graph_def()
+ sw.add_graph(gd, global_step=30)
+ sw.close()
+ rr = self._EventsReader(test_dir)
+
+ # The first event should list the file_version.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals("brain.Event:1", ev.file_version)
+
+ # The next event should have the value 'mee=10.0'.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(10, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'mee' simple_value: 10.0 }
+ """, ev.summary)
+
+ # The next event should have the value 'boo=20.0'.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(20, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'boo' simple_value: 20.0 }
+ """, ev.summary)
+
+ # The next event should have the graph_def.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(30, ev.step)
+ self.assertProtoEquals(gd, ev.graph_def)
+
+ # We should be done.
+ self.assertRaises(StopIteration, lambda: next(rr))
+
+ def testConstructWithGraph(self):
+ test_dir = self._CleanTestDir("basics_with_graph")
+ with tf.Graph().as_default() as g:
+ tf.constant([12], name="douze")
+ gd = g.as_graph_def()
+ sw = tf.train.SummaryWriter(test_dir, graph_def=gd)
+ sw.close()
+ rr = self._EventsReader(test_dir)
+
+ # The first event should list the file_version.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals("brain.Event:1", ev.file_version)
+
+ # The next event should have the graph.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(0, ev.step)
+ self.assertProtoEquals(gd, ev.graph_def)
+
+ # We should be done.
+ self.assertRaises(StopIteration, lambda: next(rr))
+
+ # Checks that values returned from session Run() calls are added correctly to
+ # summaries. These are numpy types so we need to check they fit in the
+ # protocol buffers correctly.
+ def testSummariesAndStopFromSessionRunCalls(self):
+ test_dir = self._CleanTestDir("global_step")
+ sw = tf.train.SummaryWriter(test_dir)
+ with self.test_session():
+ i = tf.constant(1, dtype=tf.int32, shape=[])
+ l = tf.constant(2, dtype=tf.int64, shape=[])
+ # Test the summary can be passed serialized.
+ summ = tf.Summary(value=[tf.Summary.Value(tag="i", simple_value=1.0)])
+ sw.add_summary(summ.SerializeToString(), i.eval())
+ sw.add_summary(tf.Summary(value=[tf.Summary.Value(tag="l",
+ simple_value=2.0)]),
+ l.eval())
+ sw.close()
+
+ rr = self._EventsReader(test_dir)
+
+ # File_version.
+ ev = next(rr)
+ self.assertTrue(ev)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals("brain.Event:1", ev.file_version)
+
+ # Summary passed serialized.
+ ev = next(rr)
+ self.assertTrue(ev)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(1, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'i' simple_value: 1.0 }
+ """, ev.summary)
+
+ # Summary passed as SummaryObject.
+ ev = next(rr)
+ self.assertTrue(ev)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(2, ev.step)
+ self.assertProtoEquals("""
+ value { tag: 'l' simple_value: 2.0 }
+ """, ev.summary)
+
+ # We should be done.
+ self.assertRaises(StopIteration, lambda: next(rr))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
new file mode 100644
index 0000000000..a400e9fa7d
--- /dev/null
+++ b/tensorflow/python/training/training.py
@@ -0,0 +1,138 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long
+"""This library provides a set of classes and functions that helps train models.
+
+## Optimizers.
+
+The Optimizer base class provides methods to compute gradients for a loss and
+apply gradients to variables. A collection of subclasses implement classic
+optimization algorithms such as GradientDescent and Adagrad.
+
+You never instantiate the Optimizer class itself, but instead instantiate one
+of the subclasses.
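+
+For example, a minimal sketch, assuming a scalar `loss` tensor already exists
+in the graph:
+
+```python
+opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
+train_op = opt.minimize(loss)
+```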
+
+@@Optimizer
+
+@@GradientDescentOptimizer
+@@AdagradOptimizer
+@@MomentumOptimizer
+@@AdamOptimizer
+@@FtrlOptimizer
+@@RMSPropOptimizer
+
+## Gradient Computation.
+
+TensorFlow provides functions to compute the derivatives for a given
+TensorFlow computation graph, adding operations to the graph. The
+optimizer classes automatically compute derivatives on your graph, but
+creators of new Optimizers or expert users can call the lower-level
+functions below.
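+
+A minimal sketch, assuming a scalar `loss` and variables `w` and `b` already
+exist in the graph:
+
+```python
+grad_w, grad_b = tf.gradients(loss, [w, b])
+```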
+
+@@gradients
+@@AggregationMethod
+
+@@stop_gradient
+
+
+## Gradient Clipping
+
+TensorFlow provides several operations that you can use to add clipping
+functions to your graph. You can use these functions to perform general data
+clipping, but they're particularly useful for handling exploding or vanishing
+gradients.
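+
+A minimal sketch, assuming `grads` is a list of gradient tensors:
+
+```python
+clipped_grads, global_norm = tf.clip_by_global_norm(grads, clip_norm=5.0)
+```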
+
+@@clip_by_value
+@@clip_by_norm
+@@clip_by_average_norm
+@@clip_by_global_norm
+@@global_norm
+
+## Decaying the learning rate.
+@@exponential_decay
+
+## Moving Averages.
+
+Some training algorithms, such as GradientDescent and Momentum, often benefit
+from maintaining a moving average of variables during optimization. Using the
+moving averages for evaluations often improves results significantly.
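+
+A minimal sketch, assuming variables `var0` and `var1` already exist in the
+graph:
+
+```python
+ema = tf.train.ExponentialMovingAverage(decay=0.999)
+# Op that updates the moving averages; run it after each training step.
+maintain_averages_op = ema.apply([var0, var1])
+```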
+
+@@ExponentialMovingAverage
+
+## Coordinator and QueueRunner.
+
+See [Threading and Queues](../../how_tos/threading_and_queues/index.md)
+for how to use threads and queues. For documentation on the Queue API,
+see [Queues](../../api_docs/python/io_ops.md#queues).
+
+@@Coordinator
+@@QueueRunner
+@@add_queue_runner
+@@start_queue_runners
+
+## Summary Operations.
+
+The following ops output
+[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+protocol buffers as serialized string tensors.
+
+You can fetch the output of a summary op in a session, and pass it to a
+[SummaryWriter](train.md#SummaryWriter) to append it to an event file. You can
+then use TensorBoard to visualize the contents of the event files. See
+[TensorBoard and Summaries](../../how_tos/summaries_and_tensorboard/index.md)
+for more details.
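+
+A minimal sketch, assuming a scalar `loss` tensor, a running session `sess`,
+and an integer `step` (all illustrative names):
+
+```python
+loss_summary = tf.scalar_summary("loss", loss)
+writer = tf.train.SummaryWriter("/tmp/logs", sess.graph_def)
+summary_str = sess.run(loss_summary)
+writer.add_summary(summary_str, step)
+```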
+
+@@scalar_summary
+@@image_summary
+@@histogram_summary
+@@zero_fraction
+
+@@merge_summary
+@@merge_all_summaries
+
+## Adding Summaries to Event Files.
+
+See [Summaries and
+TensorBoard](../../how_tos/summaries_and_tensorboard/index.md) for an
+overview of summaries, event files, and visualization in TensorBoard.
+
+@@SummaryWriter
+@@summary_iterator
+
+## Training utilities.
+
+@@global_step
+@@write_graph
+"""
+
+# Optimizers.
+from tensorflow.python.training.adagrad import AdagradOptimizer
+from tensorflow.python.training.adam import AdamOptimizer
+from tensorflow.python.training.ftrl import FtrlOptimizer
+from tensorflow.python.training.momentum import MomentumOptimizer
+from tensorflow.python.training.moving_averages import ExponentialMovingAverage
+from tensorflow.python.training.optimizer import Optimizer
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
+from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
+
+# Utility classes for training.
+from tensorflow.python.training.coordinator import Coordinator
+from tensorflow.python.training.queue_runner import *
+
+# For the module level doc.
+from tensorflow.python.training import input as _input
+from tensorflow.python.training.input import *
+
+from tensorflow.python.training.saver import get_checkpoint_state
+from tensorflow.python.training.saver import latest_checkpoint
+from tensorflow.python.training.saver import Saver
+from tensorflow.python.training.saver import update_checkpoint_state
+from tensorflow.python.training.summary_io import summary_iterator
+from tensorflow.python.training.summary_io import SummaryWriter
+from tensorflow.python.training.training_util import write_graph
+from tensorflow.python.training.training_util import global_step
+
+# Training data protos.
+from tensorflow.core.example.example_pb2 import *
+from tensorflow.core.example.feature_pb2 import *
+
+# Utility op. Open Source. TODO(mdevin): move to nn?
+from tensorflow.python.training.learning_rate_decay import exponential_decay
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
new file mode 100644
index 0000000000..410b23e04d
--- /dev/null
+++ b/tensorflow/python/training/training_ops.py
@@ -0,0 +1,115 @@
+"""Python wrappers for training ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.training import gen_training_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.training.gen_training_ops import *
+# pylint: enable=wildcard-import
+
+
+# Shape functions for fused training ops
+# --------------------------------------
+#
+# The fused training ops all have the same basic structure: they take
+# one or more variables with the same shape, and emit a reference to
+# the original variable (which has the same shape as the first
+# input). In addition, they take one or more scalar tensors containing
+# hyperparameters.
+#
+# The sparse ops take the gradients as a Python IndexedSlices, which
+# means that the indices are a vector of length N, and the gradient
+# values are a tensor whose size is the same as the original variable,
+# except for the 0th dimension, which has size N.
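+#
+# For example (an illustrative sketch, not code from this file): for a
+# variable of shape [10, 3], a sparse update touching rows 2 and 5 would
+# pass an indices vector [2, 5] (length N = 2) and a gradient of shape
+# [2, 3].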
+
+
+def _AssertInputIsScalar(op, index):
+ """Raises ValueError if `op.inputs[index]` is not scalar."""
+ op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
+
+@ops.RegisterShape("ApplyAdagrad")
+def _ApplyAdagradShape(op):
+ """Shape function for the ApplyAdagrad op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
+ grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyAdam")
+def _ApplyAdamShape(op):
+ """Shape function for the ApplyAdam op."""
+ var_shape = op.inputs[0].get_shape()
+ m_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ v_shape = op.inputs[2].get_shape().merge_with(m_shape)
+ _AssertInputIsScalar(op, 3) # beta1_power
+ _AssertInputIsScalar(op, 4) # beta2_power
+ _AssertInputIsScalar(op, 5) # lr
+ _AssertInputIsScalar(op, 6) # beta1
+ _AssertInputIsScalar(op, 7) # beta2
+ _AssertInputIsScalar(op, 8) # epsilon
+ grad_shape = op.inputs[9].get_shape().merge_with(v_shape)
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyMomentum")
+def _ApplyMomentumShape(op):
+ """Shape function for the ApplyMomentum op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
+ grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
+ _AssertInputIsScalar(op, 4) # momentum
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyRMSProp")
+def _ApplyRMSPropShape(op):
+ """Shape function for the ApplyRMSProp op."""
+ var_shape = op.inputs[0].get_shape()
+ ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
+ _AssertInputIsScalar(op, 3) # lr
+ _AssertInputIsScalar(op, 4) # rho
+ _AssertInputIsScalar(op, 5) # momentum
+ _AssertInputIsScalar(op, 6) # epsilon
+ grad_shape = op.inputs[7].get_shape().merge_with(mom_shape)
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyGradientDescent")
+def _ApplyGradientDescentShape(op):
+ """Shape function for the ApplyGradientDescent op."""
+ var_shape = op.inputs[0].get_shape()
+ _AssertInputIsScalar(op, 1) # alpha
+ delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
+ return [delta_shape]
+
+
+@ops.RegisterShape("SparseApplyAdagrad")
+def _SparseApplyAdagradShape(op):
+ """Shape function for the SparseApplyAdagrad op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
+ grad_shape = op.inputs[3].get_shape().merge_with(
+ tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
+ unused_indices_shape = op.inputs[4].get_shape().merge_with(
+ tensor_shape.vector(grad_shape[0]))
+ return [accum_shape]
+
+
+@ops.RegisterShape("SparseApplyMomentum")
+def _SparseApplyMomentumShape(op):
+ """Shape function for the SparseApplyMomentum op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
+ grad_shape = op.inputs[3].get_shape().merge_with(
+ tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
+ unused_indices_shape = op.inputs[4].get_shape().merge_with(
+ tensor_shape.vector(grad_shape[0]))
+ _AssertInputIsScalar(op, 5) # momentum
+ return [accum_shape]
diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py
new file mode 100644
index 0000000000..902b9b0d78
--- /dev/null
+++ b/tensorflow/python/training/training_ops_test.py
@@ -0,0 +1,159 @@
+"""Tests for tensorflow.learning.training_ops."""
+
+import itertools
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import types
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import training_ops
+
+
+class TrainingOpsTest(TensorFlowTestCase):
+
+ def _toType(self, dtype):
+ if dtype == np.float32:
+ return types.float32
+ elif dtype == np.float64:
+ return types.float64
+ elif dtype == np.int32:
+ return types.int32
+ elif dtype == np.int64:
+ return types.int64
+ else:
+ assert False, (dtype)
+
+ def _testTypes(self, x, alpha, delta, use_gpu=None):
+ self.setUp()
+ with self.test_session(use_gpu=use_gpu):
+ var = variables.Variable(x)
+ variables.initialize_all_variables().run()
+ self.assertAllEqual(x, var.eval())
+ apply_sgd = training_ops.apply_gradient_descent(var, alpha, delta)
+ out = apply_sgd.eval()
+ self.assertShapeEqual(out, apply_sgd)
+ self.assertAllEqual(x - alpha * delta, out)
+
+ def testApplyGradientDescent(self):
+ for (dtype, use_gpu) in itertools.product(
+ [np.float32, np.float64], [False, True]):
+ x = np.arange(100).astype(dtype)
+ alpha = np.array(2.0).astype(dtype)
+ delta = np.arange(100).astype(dtype)
+ self._testTypes(x, alpha, delta, use_gpu)
+
+ def _testTypesForAdagrad(self, x, y, lr, grad, use_gpu=None):
+ self.setUp()
+ with self.test_session(use_gpu=use_gpu):
+ var = variables.Variable(x)
+ accum = variables.Variable(y)
+ variables.initialize_all_variables().run()
+
+ self.assertAllEqual(x, var.eval())
+ apply_adagrad = training_ops.apply_adagrad(var, accum, lr, grad)
+ out = apply_adagrad.eval()
+ self.assertShapeEqual(out, apply_adagrad)
+ self.assertAllClose(
+ x - lr * grad * (y + grad * grad) ** (-0.5), out)
+ self.assertAllEqual(y + grad * grad, accum.eval())
+
+ def testApplyAdagrad(self):
+ for (dtype, use_gpu) in itertools.product(
+ [np.float32, np.float64], [False, True]):
+ x = np.arange(100).astype(dtype)
+ y = np.arange(1, 101).astype(dtype)
+ lr = np.array(2.0).astype(dtype)
+ grad = np.arange(100).astype(dtype)
+ self._testTypesForAdagrad(x, y, lr, grad, use_gpu)
+
+ def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices):
+ self.setUp()
+ with self.test_session(use_gpu=False):
+ var = variables.Variable(x)
+ accum = variables.Variable(y)
+ variables.initialize_all_variables().run()
+
+ self.assertAllEqual(x, var.eval())
+ sparse_apply_adagrad = training_ops.sparse_apply_adagrad(
+ var, accum, lr, grad,
+ constant_op.constant(indices, self._toType(indices.dtype)))
+ out = sparse_apply_adagrad.eval()
+ self.assertShapeEqual(out, sparse_apply_adagrad)
+
+ for (i, index) in enumerate(indices):
+ self.assertAllClose(
+ x[index] - lr * grad[i] * (y[index] + grad[i] * grad[i]) ** (-0.5),
+ var.eval()[index])
+ self.assertAllEqual(y[index] + grad[i] * grad[i], accum.eval()[index])
+
+ def testSparseApplyAdagrad(self):
+ for (dtype, index_type) in itertools.product(
+ [np.float32, np.float64], [np.int32, np.int64]):
+ x_val = [range(10), range(10, 20), range(20, 30)]
+ y_val = [range(1, 11), range(11, 21), range(21, 31)]
+ x = np.array(x_val).astype(dtype)
+ y = np.array(y_val).astype(dtype)
+ lr = np.array(2.0).astype(dtype)
+ grad_val = [range(10), range(10)]
+ grad = np.array(grad_val).astype(dtype)
+ indices = np.array([0, 2]).astype(index_type)
+ self._testTypesForSparseAdagrad(x, y, lr, grad, indices)
+
+ def testApplyAdam(self):
+ for dtype, use_gpu in itertools.product(
+ [np.float32, np.float64], [False, True]):
+ var = np.arange(100).astype(dtype)
+ m = np.arange(1, 101).astype(dtype)
+ v = np.arange(101, 201).astype(dtype)
+ grad = np.arange(100).astype(dtype)
+ self._testTypesForAdam(var, m, v, grad, use_gpu)
+
+ def _testTypesForAdam(self, var, m, v, grad, use_gpu):
+ self.setUp()
+ with self.test_session(use_gpu=use_gpu):
+ var_t = variables.Variable(var)
+ m_t = variables.Variable(m)
+ v_t = variables.Variable(v)
+
+ t = 1
+ beta1 = np.array(0.9, dtype=var.dtype)
+ beta2 = np.array(0.999, dtype=var.dtype)
+ beta1_power = beta1**t
+ beta2_power = beta2**t
+ lr = np.array(0.001, dtype=var.dtype)
+ epsilon = np.array(1e-8, dtype=var.dtype)
+ beta1_t = constant_op.constant(beta1, self._toType(var.dtype), [])
+ beta2_t = constant_op.constant(beta2, self._toType(var.dtype), [])
+ beta1_power_t = variables.Variable(beta1_power)
+ beta2_power_t = variables.Variable(beta2_power)
+ lr_t = constant_op.constant(lr, self._toType(var.dtype), [])
+ epsilon_t = constant_op.constant(epsilon, self._toType(var.dtype), [])
+ variables.initialize_all_variables().run()
+
+ self.assertAllEqual(var, var_t.eval())
+ new_var, _, _ = self._adamUpdateNumpy(var, grad, t, m, v,
+ lr, beta1, beta2, epsilon)
+ apply_adam = training_ops.apply_adam(var_t, m_t, v_t, beta1_power_t,
+ beta2_power_t, lr_t,
+ beta1_t, beta2_t, epsilon_t, grad)
+ out = apply_adam.eval()
+ self.assertShapeEqual(out, apply_adam)
+ self.assertAllClose(new_var, out)
+
+ def _adamUpdateNumpy(self, param, g_t, t, m, v, alpha, beta1,
+ beta2, epsilon):
+ alpha_t = alpha * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ return param_t, m_t, v_t
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
new file mode 100644
index 0000000000..14166e25c6
--- /dev/null
+++ b/tensorflow/python/training/training_util.py
@@ -0,0 +1,57 @@
+"""Utility functions for training."""
+import os.path
+
+from tensorflow.python.platform import gfile
+
+
+def global_step(sess, global_step_tensor):
+ """Small helper to get the global step.
+
+ ```python
+ # Creates a variable to hold the global_step.
+ global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
+ # Creates a session.
+ sess = tf.Session()
+ # Initializes the variable.
+ sess.run(global_step_tensor.initializer)
+ print 'global_step:', tf.train.global_step(sess, global_step_tensor)
+  # Prints: global_step: 10
+ ```
+
+ Args:
+    sess: A TensorFlow `Session` object.
+ global_step_tensor: `Tensor` or the `name` of the operation that contains
+ the global step.
+
+ Returns:
+ The global step value.
+ """
+ return int(sess.run(global_step_tensor))
+
+
+def write_graph(graph_def, logdir, name, as_text=True):
+ """Writes a graph proto on disk.
+
+ The graph is written as a binary proto unless as_text is `True`.
+
+ ```python
+ v = tf.Variable(0, name='my_variable')
+ sess = tf.Session()
+ tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
+ ```
+
+ Args:
+ graph_def: A `GraphDef` protocol buffer.
+ logdir: Directory where to write the graph.
+ name: Filename for the graph.
+ as_text: If `True`, writes the graph as an ASCII proto.
+ """
+ path = os.path.join(logdir, name)
+ gfile.MakeDirs(os.path.dirname(path))
+ f = gfile.FastGFile(path, "w")
+ if as_text:
+ f.write(str(graph_def))
+ else:
+ f.write(graph_def.SerializeToString())
+ f.close()
diff --git a/tensorflow/python/user_ops/__init__.py b/tensorflow/python/user_ops/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/user_ops/__init__.py
diff --git a/tensorflow/python/user_ops/user_ops.py b/tensorflow/python/user_ops/user_ops.py
new file mode 100644
index 0000000000..20e2604e05
--- /dev/null
+++ b/tensorflow/python/user_ops/user_ops.py
@@ -0,0 +1,10 @@
+"""All user ops."""
+
+import tensorflow.python.platform
+from tensorflow.python.ops import gen_user_ops
+from tensorflow.python.ops.gen_user_ops import *
+
+
+def my_fact():
+ """Example of overriding the generated code for an Op."""
+ return gen_user_ops._fact()
diff --git a/tensorflow/python/util/__init__.py b/tensorflow/python/util/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/util/__init__.py
diff --git a/tensorflow/python/util/port.i b/tensorflow/python/util/port.i
new file mode 100644
index 0000000000..fdb217dcc7
--- /dev/null
+++ b/tensorflow/python/util/port.i
@@ -0,0 +1,11 @@
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/core/util/port.h"
+%}
+
+%ignoreall
+%unignore tensorflow;
+%unignore tensorflow::IsGoogleCudaEnabled;
+%include "tensorflow/core/util/port.h"
+%unignoreall
diff --git a/tensorflow/python/util/protobuf/__init__.py b/tensorflow/python/util/protobuf/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/util/protobuf/__init__.py
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
new file mode 100644
index 0000000000..19f7128f4e
--- /dev/null
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -0,0 +1,384 @@
+#!/usr/bin/python2.4
+
+"""Utility functions for comparing proto2 messages in Python.
+
+Proto2Cmp() is a cmp-style comparison function. It can be passed to sort(), etc.
+See its docstring for details.
+
+ClearDefaultValuedFields() recursively clears the fields that are set to their
+default values. This is useful for comparing protocol buffers where the
+semantics of unset fields and default valued fields are the same.
+
+NormalizeRepeatedFields() sorts and optionally de-dupes repeated fields. This
+is useful for treating repeated fields as sets instead of lists.
+
+assertProto2Equal() and assertProto2SameElements() are useful for unit tests.
+They produce much more helpful output than assertEqual() and friends for proto2
+messages, e.g. this:
+
+ outer {
+ inner {
+- strings: "x"
+? ^
++ strings: "y"
+? ^
+ }
+ }
+
+...compared to the default output from assertEqual() that looks like this:
+
+AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>
+
+Call them inside your unit test's googletest.TestCase subclasses like this:
+
+ from tensorflow.python.util.protobuf import compare
+
+ class MyTest(googletest.TestCase):
+ ...
+ def testXXX(self):
+ ...
+ compare.assertProto2Equal(self, a, b)
+ compare.assertProto2SameElements(self, a, c)
+
+Alternatively:
+
+ from tensorflow.python.util.protobuf import compare
+
+ class MyTest(compare.Proto2Assertions, googletest.TestCase):
+ ...
+ def testXXX(self):
+ ...
+ self.assertProto2Equal(a, b)
+ self.assertProto2SameElements(a, c)
+"""
+
+import copy
+
+from google.protobuf import descriptor
+from google.protobuf import message
+from google.protobuf import text_format
+
+
+def assertProto2Equal(self, a, b, check_initialized=True,
+ normalize_numbers=False, msg=None):
+ """Fails with a useful error if a and b aren't equal.
+
+ Comparison of repeated fields matches the semantics of
+  unittest.TestCase.assertEqual(), i.e. order and extra duplicate fields matter.
+
+ Args:
+ self: googletest.TestCase
+ a: proto2 PB instance, or text string representing one
+ b: proto2 PB instance -- message.Message or subclass thereof
+ check_initialized: boolean, whether to fail if either a or b isn't
+ initialized
+ normalize_numbers: boolean, whether to normalize types and precision of
+ numbers before comparison.
+ msg: if specified, is used as the error message on failure
+ """
+ if isinstance(a, basestring):
+ a = text_format.Merge(a, b.__class__())
+
+ for pb in a, b:
+ if check_initialized:
+ errors = pb.FindInitializationErrors()
+ if errors:
+ self.fail('Initialization errors: %s\n%s' % (errors, pb))
+ if normalize_numbers:
+ NormalizeNumberFields(pb)
+
+ self.assertMultiLineEqual(text_format.MessageToString(a),
+ text_format.MessageToString(b),
+ msg=msg)
+
+
+def assertProto2SameElements(self, a, b, number_matters=False,
+ check_initialized=True, normalize_numbers=False,
+ msg=None):
+ """Fails with a useful error if a and b aren't equivalent.
+
+  When comparing repeated fields, order doesn't matter and the number of times
+  each element appears (i.e. duplicates) only matters if number_matters is True.
+
+ By default, comparison of repeated fields follows set semantics and matches
+ googletest.TestCase.assertSameElements(): neither order nor number of a given
+ element matters.
+
+ Args:
+ self: googletest.TestCase
+ a: proto2 PB instance, or text string representing one
+ b: proto2 PB instance -- message.Message or subclass thereof
+    number_matters: boolean, whether the number of each element must match
+ check_initialized: boolean, whether to fail if either a or b isn't
+ initialized
+ normalize_numbers: boolean, whether to normalize types and precision of
+ numbers before comparison.
+ msg: if specified, is used as the error message on failure
+ """
+ if isinstance(a, basestring):
+ a = text_format.Merge(a, b.__class__())
+ else:
+ a = copy.deepcopy(a)
+ b = copy.deepcopy(b)
+ for pb in a, b:
+ NormalizeRepeatedFields(pb, dedupe=not number_matters)
+ assertProto2Equal(
+ self, a, b, check_initialized=check_initialized,
+ normalize_numbers=normalize_numbers, msg=msg)
+
+
+def assertProto2Contains(self, a, b, # pylint: disable=invalid-name
+ number_matters=False, check_initialized=True,
+ msg=None):
+ """Fails with a useful error if fields in a are not in b.
+
+ Useful to test if expected fields are in b, allows tests to define
+ expected fields in string format.
+
+ Example:
+ compare.assertProto2Contains('group { field: "value" }', test_pb2)
+
+ Args:
+ self: googletest.TestCase
+ a: proto2 PB instance, or text string representing one
+ b: proto2 PB instance
+    number_matters: boolean, whether the number of each field must match
+ check_initialized: boolean, whether to fail if b isn't initialized
+ msg: if specified, is used as the error message on failure
+ """
+ if isinstance(a, basestring):
+ a = text_format.Merge(a, b.__class__())
+ else:
+ a = copy.deepcopy(a)
+ completed_a = copy.deepcopy(b)
+ completed_a.MergeFrom(a)
+ assertProto2SameElements(self, completed_a, b, number_matters=number_matters,
+ check_initialized=check_initialized, msg=msg)
+
+
+def ClearDefaultValuedFields(pb):
+ """Clears all fields in a proto2 message that are set to their default values.
+
+  The result has a more compact text / JSON / binary representation. It's also
+  easier to compare to other protos when the distinction between unset fields
+  and fields set to their default values doesn't change the proto's semantics.
+
+ Args:
+ pb: A proto2 message.
+ """
+ for field, value in pb.ListFields():
+ if field.type == field.TYPE_MESSAGE:
+ if field.label == field.LABEL_REPEATED:
+ for item in value:
+ ClearDefaultValuedFields(item)
+ else:
+ ClearDefaultValuedFields(value)
+ if field.label == field.LABEL_OPTIONAL and not value.ListFields():
+ pb.ClearField(field.name)
+ elif field.label == field.LABEL_OPTIONAL and value == field.default_value:
+ pb.ClearField(field.name)
+
+
+def NormalizeRepeatedFields(pb, dedupe=True):
+ """Sorts all repeated fields and optionally removes duplicates.
+
+ Modifies pb in place. Recurses into nested objects. Uses Proto2Cmp for
+ sorting.
+
+ Args:
+ pb: proto2 message
+ dedupe: boolean, whether to remove duplicates
+
+ Returns: the given pb, modified in place
+ """
+ for desc, values in pb.ListFields():
+ if desc.label is not descriptor.FieldDescriptor.LABEL_REPEATED:
+ values = [values]
+
+ if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ desc.message_type.has_options and
+ desc.message_type.GetOptions().map_entry):
+ # This is a map, only recurse if the values have a message type.
+ if (desc.message_type.fields_by_number[2].type ==
+ descriptor.FieldDescriptor.TYPE_MESSAGE):
+ for v in values.itervalues():
+ NormalizeRepeatedFields(v, dedupe=dedupe)
+ else:
+ if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
+ desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
+ for v in values:
+ # recursive step
+ NormalizeRepeatedFields(v, dedupe=dedupe)
+
+ values.sort(Proto2Cmp)
+
+ if dedupe:
+ # De-dupe in place. Can't use set, etc. because messages aren't
+        # hashable. This is a heavily discussed toy problem. The code below is
+ # a simplified version of http://code.activestate.com/recipes/52560/
+ # and it requires that values is sorted.
+ for i in xrange(len(values) - 1, 0, -1):
+ if values[i] == values[i - 1]:
+ del values[i]
+
+ return pb
+
+
+def NormalizeNumberFields(pb):
+ """Normalizes types and precisions of number fields in a protocol buffer.
+
+  Due to subtleties in the Python protocol buffer implementation, it is
+  possible for values to have different types and precision depending on
+  whether they were set and retrieved directly or deserialized from a
+  protobuf. This function normalizes integer values to ints and longs based on
+  width, rounds 32-bit floats to a fixed precision to account for Python
+  always storing them as 64-bit, and ensures doubles are floating point even
+  when they were set to integers.
+
+ Modifies pb in place. Recurses into nested objects.
+
+ Args:
+ pb: proto2 message
+
+ Returns:
+ the given pb, modified in place
+ """
+ for desc, values in pb.ListFields():
+ is_repeated = True
+ if desc.label is not descriptor.FieldDescriptor.LABEL_REPEATED:
+ is_repeated = False
+ values = [values]
+
+ normalized_values = None
+
+ # We force 32-bit values to int and 64-bit values to long to make
+ # alternate implementations where the distinction is more significant
+ # (e.g. the C++ implementation) simpler.
+ if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
+ descriptor.FieldDescriptor.TYPE_UINT64,
+ descriptor.FieldDescriptor.TYPE_SINT64):
+ normalized_values = [long(x) for x in values]
+ elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
+ descriptor.FieldDescriptor.TYPE_UINT32,
+ descriptor.FieldDescriptor.TYPE_SINT32,
+ descriptor.FieldDescriptor.TYPE_ENUM):
+ normalized_values = [int(x) for x in values]
+ elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
+ normalized_values = [round(x, 6) for x in values]
+ elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
+ normalized_values = [round(float(x), 7) for x in values]
+
+ if normalized_values is not None:
+ if is_repeated:
+ pb.ClearField(desc.name)
+ getattr(pb, desc.name).extend(normalized_values)
+ else:
+ setattr(pb, desc.name, normalized_values[0])
+
+ if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
+ desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
+ if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ desc.message_type.has_options and
+ desc.message_type.GetOptions().map_entry):
+ # This is a map, only recurse if the values have a message type.
+ if (desc.message_type.fields_by_number[2].type ==
+ descriptor.FieldDescriptor.TYPE_MESSAGE):
+ for v in values.itervalues():
+ NormalizeNumberFields(v)
+ else:
+ for v in values:
+ # recursive step
+ NormalizeNumberFields(v)
+
+ return pb
+
+
+def _IsRepeatedContainer(value):
+ if isinstance(value, basestring):
+ return False
+ try:
+ iter(value)
+ return True
+ except TypeError:
+ return False
+
+
+def Proto2Cmp(a, b):
+ """Compares two proto2 objects field by field, in ascending tag order.
+
+ Recurses into nested messages. Uses list (not set) semantics for comparing
+ repeated fields, ie duplicates and order matter. If one field is a prefix of
+ the other, the longer field is greater.
+
+ This function is intended to be used as a python cmp function, e.g. in sort.
+
+ Ordering fields by tag number has precedent in other google code, but it's
+ still somewhat arbitrary. The main value is to provide *some* stable ordering
+ for proto2 messages.
+
+  This would be easier as a __cmp__ method or a set of __le__, __gt__, etc.
+  methods in the proto2 Message class itself. That would take a little more
+  care, though, and probably some significant debate over whether they should
+  exist at all, so this was the easier route.
+
+ Args:
+ a, b: proto2 messages or primitives
+
+ Returns: integer > 0 if a > b, < 0 if a < b, 0 if a == b
+ """
+ def Format(pb):
+ """Returns a dictionary that maps tag number (for messages) or element index
+ (for repeated fields) to value, or just pb unchanged if it's neither."""
+ if isinstance(pb, message.Message):
+ return dict((desc.number, value) for desc, value in pb.ListFields())
+ elif _IsRepeatedContainer(pb):
+ return dict(enumerate(list(pb)))
+ else:
+ return pb
+
+ a, b = Format(a), Format(b)
+
+ # base case
+ if not isinstance(a, dict) or not isinstance(b, dict):
+ return cmp(a, b)
+
+  # This loop performs double duty: it compares two messages by tag value *or*
+  # two repeated fields by element, in order. The magic is in the Format()
+  # function, which converts them both to the same easily comparable format.
+ for tag in sorted(set(a.keys() + b.keys())):
+ if tag not in a:
+ return -1 # b is greater
+ elif tag not in b:
+ return 1 # a is greater
+ else:
+ # recursive step
+ cmped = Proto2Cmp(a[tag], b[tag])
+ if cmped != 0:
+ return cmped
+
+ # didn't find any values that differed, so they're equal!
+ return 0
+
+
+class Proto2Assertions(object):
+ """Mix this into a googletest.TestCase class to get proto2 assertions.
+
+ Usage:
+
+ class SomeTestCase(compare.Proto2Assertions, googletest.TestCase):
+ ...
+ def testSomething(self):
+ ...
+ self.assertProto2Equal(a, b)
+
+ See module-level definitions for method documentation.
+ """
+
+ # pylint: disable=invalid-name
+ def assertProto2Equal(self, *args, **kwargs):
+ return assertProto2Equal(self, *args, **kwargs)
+
+ def assertProto2SameElements(self, *args, **kwargs):
+ return assertProto2SameElements(self, *args, **kwargs)
+
+ def assertProto2Contains(self, *args, **kwargs):
+ return assertProto2Contains(self, *args, **kwargs)
diff --git a/tensorflow/python/util/protobuf/compare_test.proto b/tensorflow/python/util/protobuf/compare_test.proto
new file mode 100644
index 0000000000..fa0b5de9f0
--- /dev/null
+++ b/tensorflow/python/util/protobuf/compare_test.proto
@@ -0,0 +1,49 @@
+// Test messages used in compare_test.py.
+syntax = "proto2";
+
+package compare_test;
+// option cc_enable_arenas = true;
+
+enum Enum {
+ A = 0;
+ B = 1;
+ C = 2;
+}
+
+message Small {
+ repeated string strings = 1;
+};
+
+message Medium {
+ repeated int32 int32s = 1;
+ repeated Small smalls = 2;
+ repeated group GroupA = 3 {
+ repeated group GroupB = 4 {
+ required string strings = 5;
+ }
+ }
+ repeated float floats = 6;
+};
+
+message Large {
+ optional string string_ = 1;
+ optional int64 int64_ = 2;
+ optional float float_ = 3;
+ optional bool bool_ = 4;
+ optional Enum enum_ = 5;
+ repeated int64 int64s = 6;
+ optional Medium medium = 7;
+ optional Small small = 8;
+ optional double double_ = 9;
+ optional WithMap with_map = 10;
+};
+
+message Labeled {
+ required int32 required = 1;
+ optional int32 optional = 2;
+}
+
+message WithMap {
+ map<int32, Small> value_message = 1;
+ map<string, string> value_string = 2;
+}
diff --git a/tensorflow/python/util/protobuf/compare_test.py b/tensorflow/python/util/protobuf/compare_test.py
new file mode 100644
index 0000000000..9a03d123ae
--- /dev/null
+++ b/tensorflow/python/util/protobuf/compare_test.py
@@ -0,0 +1,652 @@
+#!/usr/bin/python2.4
+
+"""Tests for python.util.protobuf.compare."""
+
+import copy
+import re
+import textwrap
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.util.protobuf import compare
+from tensorflow.python.util.protobuf import compare_test_pb2
+
+from google.protobuf import text_format
+
+
+def LargePbs(*args):
+ """Converts ASCII string Large PBs to messages."""
+ pbs = []
+ for arg in args:
+ pb = compare_test_pb2.Large()
+ text_format.Merge(arg, pb)
+ pbs.append(pb)
+
+ return pbs
+
+
+class Proto2CmpTest(googletest.TestCase):
+
+ def assertGreater(self, a, b):
+ """Asserts that Proto2Cmp says a > b."""
+ a, b = LargePbs(a, b)
+ googletest.TestCase.assertGreater(self, compare.Proto2Cmp(a, b), 0)
+ googletest.TestCase.assertLess(self, compare.Proto2Cmp(b, a), 0)
+
+ def assertEquals(self, a, b):
+ """Asserts that Proto2Cmp says a == b."""
+ a, b = LargePbs(a, b)
+ googletest.TestCase.assertEquals(self, compare.Proto2Cmp(a, b), 0)
+
+ def testPrimitives(self):
+ googletest.TestCase.assertEqual(self, 0, compare.Proto2Cmp('a', 'a'))
+ googletest.TestCase.assertLess(self, 0, compare.Proto2Cmp('b', 'a'))
+
+ pb = compare_test_pb2.Large()
+ googletest.TestCase.assertEquals(self, cmp('a', pb), compare.Proto2Cmp('a', pb))
+ googletest.TestCase.assertEqual(self, cmp(pb, 'a'), compare.Proto2Cmp(pb, 'a'))
+
+ def testEmpty(self):
+ self.assertEquals('', '')
+
+ def testPrimitiveFields(self):
+ self.assertGreater('string_: "a"', '')
+ self.assertEquals('string_: "a"', 'string_: "a"')
+ self.assertGreater('string_: "b"', 'string_: "a"')
+ self.assertGreater('string_: "ab"', 'string_: "aa"')
+
+ self.assertGreater('int64_: 0', '')
+ self.assertEquals('int64_: 0', 'int64_: 0')
+ self.assertGreater('int64_: -1', '')
+ self.assertGreater('int64_: 1', 'int64_: 0')
+ self.assertGreater('int64_: 0', 'int64_: -1')
+
+ self.assertGreater('float_: 0.0', '')
+ self.assertEquals('float_: 0.0', 'float_: 0.0')
+ self.assertGreater('float_: -0.1', '')
+ self.assertGreater('float_: 3.14', 'float_: 0')
+ self.assertGreater('float_: 0', 'float_: -0.1')
+ self.assertEquals('float_: -0.1', 'float_: -0.1')
+
+ self.assertGreater('bool_: true', '')
+ self.assertGreater('bool_: false', '')
+ self.assertGreater('bool_: true', 'bool_: false')
+ self.assertEquals('bool_: false', 'bool_: false')
+ self.assertEquals('bool_: true', 'bool_: true')
+
+ self.assertGreater('enum_: A', '')
+ self.assertGreater('enum_: B', 'enum_: A')
+ self.assertGreater('enum_: C', 'enum_: B')
+ self.assertEquals('enum_: C', 'enum_: C')
+
+ def testRepeatedPrimitives(self):
+ self.assertGreater('int64s: 0', '')
+ self.assertEquals('int64s: 0', 'int64s: 0')
+ self.assertGreater('int64s: 1', 'int64s: 0')
+ self.assertGreater('int64s: 0 int64s: 0', '')
+ self.assertGreater('int64s: 0 int64s: 0', 'int64s: 0')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0')
+ self.assertGreater('int64s: 0 int64s: 1', 'int64s: 0')
+ self.assertGreater('int64s: 1', 'int64s: 0 int64s: 2')
+ self.assertGreater('int64s: 2 int64s: 0', 'int64s: 1')
+ self.assertEquals('int64s: 0 int64s: 0', 'int64s: 0 int64s: 0')
+ self.assertEquals('int64s: 0 int64s: 1', 'int64s: 0 int64s: 1')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0 int64s: 0')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0 int64s: 1')
+ self.assertGreater('int64s: 1 int64s: 0', 'int64s: 0 int64s: 2')
+ self.assertGreater('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0')
+ self.assertGreater('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0 int64s: 2')
+
+ def testMessage(self):
+ self.assertGreater('small <>', '')
+ self.assertEquals('small <>', 'small <>')
+ self.assertGreater('small < strings: "a" >', '')
+ self.assertGreater('small < strings: "a" >', 'small <>')
+ self.assertEquals('small < strings: "a" >', 'small < strings: "a" >')
+ self.assertGreater('small < strings: "b" >', 'small < strings: "a" >')
+ self.assertGreater('small < strings: "a" strings: "b" >',
+ 'small < strings: "a" >')
+
+ self.assertGreater('string_: "a"', 'small <>')
+ self.assertGreater('string_: "a"', 'small < strings: "b" >')
+ self.assertGreater('string_: "a"', 'small < strings: "b" strings: "c" >')
+ self.assertGreater('string_: "a" small <>', 'small <>')
+ self.assertGreater('string_: "a" small <>', 'small < strings: "b" >')
+ self.assertEquals('string_: "a" small <>', 'string_: "a" small <>')
+ self.assertGreater('string_: "a" small < strings: "a" >',
+ 'string_: "a" small <>')
+ self.assertEquals('string_: "a" small < strings: "a" >',
+ 'string_: "a" small < strings: "a" >')
+ self.assertGreater('string_: "a" small < strings: "a" >',
+ 'int64_: 1 small < strings: "a" >')
+ self.assertGreater('string_: "a" small < strings: "a" >', 'int64_: 1')
+ self.assertGreater('string_: "a"', 'int64_: 1 small < strings: "a" >')
+ self.assertGreater('string_: "a" int64_: 0 small < strings: "a" >',
+ 'int64_: 1 small < strings: "a" >')
+ self.assertGreater('string_: "a" int64_: 1 small < strings: "a" >',
+ 'string_: "a" int64_: 0 small < strings: "a" >')
+ self.assertEquals('string_: "a" int64_: 0 small < strings: "a" >',
+ 'string_: "a" int64_: 0 small < strings: "a" >')
+
+ def testNestedMessage(self):
+ self.assertGreater('medium <>', '')
+ self.assertEquals('medium <>', 'medium <>')
+ self.assertGreater('medium < smalls <> >', 'medium <>')
+ self.assertEquals('medium < smalls <> >', 'medium < smalls <> >')
+ self.assertGreater('medium < smalls <> smalls <> >', 'medium < smalls <> >')
+ self.assertEquals('medium < smalls <> smalls <> >',
+ 'medium < smalls <> smalls <> >')
+
+ self.assertGreater('medium < int32s: 0 >', 'medium < smalls <> >')
+
+ self.assertGreater('medium < smalls < strings: "a"> >',
+ 'medium < smalls <> >')
+
+ def testTagOrder(self):
+ """Tests that different fields are ordered by tag number.
+
+ For reference, here are the relevant tag numbers from compare_test.proto:
+ optional string string_ = 1;
+ optional int64 int64_ = 2;
+ optional float float_ = 3;
+      optional Medium medium = 7;
+      optional Small small = 8;
+ """
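+    # The cases below show that fields are compared in tag order and that the
+    # first differing field decides: a message that sets string_ (tag 1) beats
+    # one that leaves it unset, and when both set it, string_ is compared
+    # before int64_ (tag 2), float_ (tag 3), medium (7) or small (8).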
+ self.assertGreater('string_: "a" ',
+ ' int64_: 1 ')
+ self.assertGreater('string_: "a" int64_: 2 ',
+ ' int64_: 1 ')
+ self.assertGreater('string_: "b" int64_: 1 ',
+ 'string_: "a" int64_: 2 ')
+ self.assertEquals( 'string_: "a" int64_: 1 ',
+ 'string_: "a" int64_: 1 ')
+ self.assertGreater('string_: "a" int64_: 1 float_: 0.0',
+ 'string_: "a" int64_: 1 ')
+ self.assertEquals( 'string_: "a" int64_: 1 float_: 0.0',
+ 'string_: "a" int64_: 1 float_: 0.0')
+ self.assertGreater('string_: "a" int64_: 1 float_: 0.1',
+ 'string_: "a" int64_: 1 float_: 0.0')
+ self.assertGreater('string_: "a" int64_: 2 float_: 0.0',
+ 'string_: "a" int64_: 1 float_: 0.1')
+ self.assertGreater('string_: "a" ',
+ ' int64_: 1 float_: 0.1')
+ self.assertGreater('string_: "a" float_: 0.0',
+ ' int64_: 1 ')
+ self.assertGreater('string_: "b" float_: 0.0',
+ 'string_: "a" int64_: 1 ')
+
+ self.assertGreater('string_: "a"',
+ 'small < strings: "a" >')
+ self.assertGreater('string_: "a" small < strings: "a" >',
+ 'small < strings: "b" >')
+ self.assertGreater('string_: "a" small < strings: "b" >',
+ 'string_: "a" small < strings: "a" >')
+ self.assertEquals('string_: "a" small < strings: "a" >',
+ 'string_: "a" small < strings: "a" >')
+
+ self.assertGreater('string_: "a" medium <>',
+ 'string_: "a" small < strings: "a" >')
+ self.assertGreater('string_: "a" medium < smalls <> >',
+ 'string_: "a" small < strings: "a" >')
+ self.assertGreater('medium <>', 'small < strings: "a" >')
+ self.assertGreater('medium <> small <>', 'small < strings: "a" >')
+ self.assertGreater('medium < smalls <> >', 'small < strings: "a" >')
+ self.assertGreater('medium < smalls < strings: "a" > >',
+ 'small < strings: "b" >')
+
+
+class NormalizeRepeatedFieldsTest(googletest.TestCase):
+
+ def assertNormalizes(self, orig, expected_no_dedupe, expected_dedupe):
+ """Checks NormalizeRepeatedFields(orig) against the two expected results."""
+ orig, expected_no_dedupe, expected_dedupe = LargePbs(
+ orig, expected_no_dedupe, expected_dedupe)
+
+ actual = compare.NormalizeRepeatedFields(copy.deepcopy(orig), dedupe=False)
+ self.assertEqual(expected_no_dedupe, actual)
+
+ actual = compare.NormalizeRepeatedFields(copy.deepcopy(orig), dedupe=True)
+ self.assertEqual(expected_dedupe, actual)
+
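+  # As the cases below show, NormalizeRepeatedFields() returns a message whose
+  # repeated fields are sorted; with dedupe=True, equal values are collapsed.
+  # E.g. (from testRepeatedPrimitive):
+  #   'int64s: 3 int64s: -1 int64s: 2 int64s: -1 int64s: 3'
+  #   sorted:           'int64s: -1 int64s: -1 int64s: 2 int64s: 3 int64s: 3'
+  #   sorted + deduped: 'int64s: -1 int64s: 2 int64s: 3'
+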
+ def testIgnoreNonRepeatedFields(self):
+ orig = """string_: "a" int64_: 1 float_: 0.1 bool_: true enum_: A
+ medium: {} small: {}"""
+ self.assertNormalizes(orig, orig, orig)
+
+ def testRepeatedPrimitive(self):
+ self.assertNormalizes('int64s: 3 int64s: -1 int64s: 2 int64s: -1 int64s: 3',
+ 'int64s: -1 int64s: -1 int64s: 2 int64s: 3 int64s: 3',
+ 'int64s: -1 int64s: 2 int64s: 3')
+
+ def testRepeatedMessage(self):
+ self.assertNormalizes("""medium: { smalls: { strings: "c" }
+ smalls: { strings: "a" }
+ smalls: { strings: "b" }
+ smalls: { strings: "a" }
+ smalls: { strings: "c" } }
+ """,
+ """medium: { smalls: { strings: "a" }
+ smalls: { strings: "a" }
+ smalls: { strings: "b" }
+ smalls: { strings: "c" }
+ smalls: { strings: "c" } }
+ """,
+ """medium: { smalls: { strings: "a" }
+ smalls: { strings: "b" }
+ smalls: { strings: "c" } }
+ """)
+
+ def testNestedRepeatedGroup(self):
+ self.assertNormalizes("""medium { GroupA { GroupB { strings: "c" }
+ GroupB { strings: "a" }
+ GroupB { strings: "b" }
+ GroupB { strings: "a" }
+ GroupB { strings: "c" } } }
+ """,
+ """medium { GroupA { GroupB { strings: "a" }
+ GroupB { strings: "a" }
+ GroupB { strings: "b" }
+ GroupB { strings: "c" }
+ GroupB { strings: "c" } } }
+ """,
+ """medium { GroupA { GroupB { strings: "a" }
+ GroupB { strings: "b" }
+ GroupB { strings: "c" } } }
+ """)
+
+ def testMapNormalizes(self):
+ self.assertNormalizes(
+ """with_map: { value_message: { key: 2, value: { strings: "k2v1",
+ strings: "k2v2",
+ strings: "k2v1" } },
+ value_message: { key: 1, value: { strings: "k1v2",
+ strings: "k1v1" } } }
+ """,
+ """with_map: { value_message: { key: 1, value: { strings: "k1v1",
+ strings: "k1v2" } },
+ value_message: { key: 2, value: { strings: "k2v1",
+ strings: "k2v1",
+ strings: "k2v2" } } }
+ """,
+ """with_map: { value_message: { key: 1, value: { strings: "k1v1",
+ strings: "k1v2" } },
+ value_message: { key: 2, value: { strings: "k2v1",
+ strings: "k2v2" } } }
+ """)
+
+
+class NormalizeNumbersTest(googletest.TestCase):
+ """Tests for NormalizeNumberFields()."""
+
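+  # As the cases below show, NormalizeNumberFields() coerces integer fields to
+  # long and normalizes float/double fields to a fixed precision, so values
+  # that differ only in type or in trailing digits compare equal afterwards.
+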
+ def testNormalizesInts(self):
+ pb = compare_test_pb2.Large()
+ pb.int64_ = 4
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64_, long))
+
+ pb.int64_ = 4L
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64_, long))
+
+ pb.int64_ = 9999999999999999L
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64_, long))
+
+ def testNormalizesRepeatedInts(self):
+ pb = compare_test_pb2.Large()
+ pb.int64s.extend([1L, 400, 999999999999999L])
+ compare.NormalizeNumberFields(pb)
+ self.assertTrue(isinstance(pb.int64s[0], long))
+ self.assertTrue(isinstance(pb.int64s[1], long))
+ self.assertTrue(isinstance(pb.int64s[2], long))
+
+ def testNormalizesFloats(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.float_ = 1.2314352351231
+ pb2 = compare_test_pb2.Large()
+ pb2.float_ = 1.231435
+ self.assertNotEqual(pb1.float_, pb2.float_)
+ compare.NormalizeNumberFields(pb1)
+ compare.NormalizeNumberFields(pb2)
+ self.assertEqual(pb1.float_, pb2.float_)
+
+ def testNormalizesRepeatedFloats(self):
+ pb = compare_test_pb2.Large()
+ pb.medium.floats.extend([0.111111111, 0.111111])
+ compare.NormalizeNumberFields(pb)
+ for value in pb.medium.floats:
+ self.assertAlmostEqual(0.111111, value)
+
+ def testNormalizesDoubles(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.double_ = 1.2314352351231
+ pb2 = compare_test_pb2.Large()
+ pb2.double_ = 1.2314352
+ self.assertNotEqual(pb1.double_, pb2.double_)
+ compare.NormalizeNumberFields(pb1)
+ compare.NormalizeNumberFields(pb2)
+ self.assertEqual(pb1.double_, pb2.double_)
+
+ def testNormalizesMaps(self):
+ pb = compare_test_pb2.WithMap()
+ pb.value_message[4].strings.extend(['a', 'b', 'c'])
+ pb.value_string['d'] = 'e'
+ compare.NormalizeNumberFields(pb)
+
+
+class AssertTest(googletest.TestCase):
+ """Tests both assertProto2Equal() and assertProto2SameElements()."""
+
+  def assertProto2Equal(self, a, b, **kwargs):
+ if isinstance(a, basestring) and isinstance(b, basestring):
+ a, b = LargePbs(a, b)
+ compare.assertProto2Equal(self, a, b, **kwargs)
+
+ def assertProto2SameElements(self, a, b, **kwargs):
+ if isinstance(a, basestring) and isinstance(b, basestring):
+ a, b = LargePbs(a, b)
+ compare.assertProto2SameElements(self, a, b, **kwargs)
+
+ def assertAll(self, a, **kwargs):
+ """Checks that all possible asserts pass."""
+ self.assertProto2Equal(a, a, **kwargs)
+ self.assertProto2SameElements(a, a, number_matters=False, **kwargs)
+ self.assertProto2SameElements(a, a, number_matters=True, **kwargs)
+
+ def assertSameNotEqual(self, a, b):
+ """Checks that assertProto2SameElements() passes with number_matters=False
+ and number_matters=True but not assertProto2Equal().
+ """
+ self.assertProto2SameElements(a, b, number_matters=False)
+ self.assertProto2SameElements(a, b, number_matters=True)
+ self.assertRaises(AssertionError, self.assertProto2Equal, a, b)
+
+ def assertSameExceptNumber(self, a, b):
+ """Checks that assertProto2SameElements() passes with number_matters=False
+ but not number_matters=True or assertProto2Equal().
+ """
+ self.assertProto2SameElements(a, b, number_matters=False)
+ self.assertRaises(AssertionError, self.assertProto2SameElements, a, b,
+ number_matters=True)
+ self.assertRaises(AssertionError, self.assertProto2Equal, a, b)
+
+ def assertNone(self, a, b, message, **kwargs):
+ """Checks that all possible asserts fail with the given message."""
+ message = re.escape(textwrap.dedent(message))
+ self.assertRaisesRegexp(AssertionError, message,
+ self.assertProto2SameElements, a, b,
+ number_matters=False, **kwargs)
+ self.assertRaisesRegexp(AssertionError, message,
+ self.assertProto2SameElements, a, b,
+ number_matters=True, **kwargs)
+ self.assertRaisesRegexp(AssertionError, message,
+ self.assertProto2Equal, a, b, **kwargs)
+
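+  # The expected `message` strings in the tests below are difflib-style diffs
+  # of the two protos' text formats: '-' lines come from a, '+' lines from b,
+  # and '? ^' guide lines point at the characters that differ.  assertNone()
+  # dedents and regex-escapes that text before matching it against the error
+  # raised by each assertion variant.
+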
+ def testCheckInitialized(self):
+ # neither is initialized
+ a = compare_test_pb2.Labeled()
+ a.optional = 1
+ self.assertNone(a, a, 'Initialization errors: ', check_initialized=True)
+ self.assertAll(a, check_initialized=False)
+
+ # a is initialized, b isn't
+ b = copy.deepcopy(a)
+ a.required = 2
+ self.assertNone(a, b, 'Initialization errors: ', check_initialized=True)
+ self.assertNone(a, b,
+ """
+ - required: 2
+ optional: 1
+ """,
+ check_initialized=False)
+
+ # both are initialized
+ a = compare_test_pb2.Labeled()
+ a.required = 2
+ self.assertAll(a, check_initialized=True)
+ self.assertAll(a, check_initialized=False)
+
+ b = copy.deepcopy(a)
+ b.required = 3
+ message = """
+ - required: 2
+ ? ^
+ + required: 3
+ ? ^
+ """
+ self.assertNone(a, b, message, check_initialized=True)
+ self.assertNone(a, b, message, check_initialized=False)
+
+ def testAssertEqualWithStringArg(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ compare.assertProto2Equal(
+ self,
+ """
+ string_: 'abc'
+ float_: 1.234
+ """,
+ pb)
+
+ def testAssertSameElementsWithStringArg(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ pb.int64s.extend([7, 3, 5])
+ compare.assertProto2SameElements(
+ self,
+ """
+ string_: 'abc'
+ float_: 1.234
+ int64s: 3
+ int64s: 7
+ int64s: 5
+ """,
+ pb)
+
+ def testProto2ContainsString(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ pb.small.strings.append('xyz')
+ compare.assertProto2Contains(
+ self,
+ """
+ small {
+ strings: "xyz"
+ }
+ """,
+ pb)
+
+ def testProto2ContainsProto(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ pb.small.strings.append('xyz')
+ pb2 = compare_test_pb2.Large()
+ pb2.small.strings.append('xyz')
+ compare.assertProto2Contains(
+ self, pb2, pb)
+
+ def testNormalizesNumbers(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.int64_ = 4
+ pb2 = compare_test_pb2.Large()
+ pb2.int64_ = 4L
+ compare.assertProto2Equal(self, pb1, pb2)
+
+ def testNormalizesFloat(self):
+ pb1 = compare_test_pb2.Large()
+ pb1.double_ = 4.0
+ pb2 = compare_test_pb2.Large()
+ pb2.double_ = 4L
+ compare.assertProto2Equal(self, pb1, pb2, normalize_numbers=True)
+
+ pb1 = compare_test_pb2.Medium()
+ pb1.floats.extend([4.0, 6.0])
+ pb2 = compare_test_pb2.Medium()
+ pb2.floats.extend([6L, 4L])
+ compare.assertProto2SameElements(self, pb1, pb2, normalize_numbers=True)
+
+ def testPrimitives(self):
+ self.assertAll('string_: "x"')
+ self.assertNone('string_: "x"',
+ 'string_: "y"',
+ """
+ - string_: "x"
+ ? ^
+ + string_: "y"
+ ? ^
+ """)
+
+ def testRepeatedPrimitives(self):
+ self.assertAll('int64s: 0 int64s: 1')
+
+ self.assertSameNotEqual('int64s: 0 int64s: 1', 'int64s: 1 int64s: 0')
+ self.assertSameNotEqual('int64s: 0 int64s: 1 int64s: 2',
+ 'int64s: 2 int64s: 1 int64s: 0')
+
+ self.assertSameExceptNumber('int64s: 0', 'int64s: 0 int64s: 0')
+ self.assertSameExceptNumber('int64s: 0 int64s: 1',
+ 'int64s: 1 int64s: 0 int64s: 1')
+
+ self.assertNone('int64s: 0',
+ 'int64s: 0 int64s: 2',
+ """
+ int64s: 0
+ + int64s: 2
+ """)
+ self.assertNone('int64s: 0 int64s: 1',
+ 'int64s: 0 int64s: 2',
+ """
+ int64s: 0
+ - int64s: 1
+ ? ^
+ + int64s: 2
+ ? ^
+ """)
+
+ def testMessage(self):
+ self.assertAll('medium: {}')
+ self.assertAll('medium: { smalls: {} }')
+ self.assertAll('medium: { int32s: 1 smalls: {} }')
+ self.assertAll('medium: { smalls: { strings: "x" } }')
+ self.assertAll('medium: { smalls: { strings: "x" } } small: { strings: "y" }')
+
+ self.assertSameNotEqual(
+ 'medium: { smalls: { strings: "x" strings: "y" } }',
+ 'medium: { smalls: { strings: "y" strings: "x" } }')
+ self.assertSameNotEqual(
+ 'medium: { smalls: { strings: "x" } smalls: { strings: "y" } }',
+ 'medium: { smalls: { strings: "y" } smalls: { strings: "x" } }')
+
+ self.assertSameExceptNumber(
+ 'medium: { smalls: { strings: "x" strings: "y" strings: "x" } }',
+ 'medium: { smalls: { strings: "y" strings: "x" } }')
+ self.assertSameExceptNumber(
+ 'medium: { smalls: { strings: "x" } int32s: 0 }',
+ 'medium: { int32s: 0 smalls: { strings: "x" } int32s: 0 }')
+
+ self.assertNone('medium: {}',
+ 'medium: { smalls: { strings: "x" } }',
+ """
+ medium {
+ + smalls {
+ + strings: "x"
+ + }
+ }
+ """)
+ self.assertNone('medium: { smalls: { strings: "x" } }',
+ 'medium: { smalls: {} }',
+ """
+ medium {
+ smalls {
+ - strings: "x"
+ }
+ }
+ """)
+ self.assertNone('medium: { int32s: 0 }',
+ 'medium: { int32s: 1 }',
+ """
+ medium {
+ - int32s: 0
+ ? ^
+ + int32s: 1
+ ? ^
+ }
+ """)
+
+ def testMsgPassdown(self):
+ self.assertRaisesRegexp(AssertionError, 'test message passed down',
+ self.assertProto2Equal,
+ 'medium: {}',
+ 'medium: { smalls: { strings: "x" } }',
+ msg='test message passed down')
+
+ def testRepeatedMessage(self):
+ self.assertAll('medium: { smalls: {} smalls: {} }')
+ self.assertAll('medium: { smalls: { strings: "x" } } medium: {}')
+ self.assertAll('medium: { smalls: { strings: "x" } } medium: { int32s: 0 }')
+ self.assertAll('medium: { smalls: {} smalls: { strings: "x" } } small: {}')
+
+ self.assertSameNotEqual('medium: { smalls: { strings: "x" } smalls: {} }',
+ 'medium: { smalls: {} smalls: { strings: "x" } }')
+
+ self.assertSameExceptNumber('medium: { smalls: {} }',
+ 'medium: { smalls: {} smalls: {} }')
+ self.assertSameExceptNumber('medium: { smalls: {} smalls: {} } medium: {}',
+ 'medium: {} medium: {} medium: { smalls: {} }')
+ self.assertSameExceptNumber(
+ 'medium: { smalls: { strings: "x" } smalls: {} }',
+ 'medium: { smalls: {} smalls: { strings: "x" } smalls: {} }')
+
+ self.assertNone('medium: {}',
+ 'medium: {} medium { smalls: {} }',
+ """
+ medium {
+ + smalls {
+ + }
+ }
+ """)
+ self.assertNone('medium: { smalls: {} smalls: { strings: "x" } }',
+ 'medium: { smalls: {} smalls: { strings: "y" } }',
+ """
+ medium {
+ smalls {
+ }
+ smalls {
+ - strings: "x"
+ ? ^
+ + strings: "y"
+ ? ^
+ }
+ }
+ """)
+
+
+class MixinTests(compare.Proto2Assertions, googletest.TestCase):
+
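+  # Mixing compare.Proto2Assertions into the TestCase exposes the module-level
+  # asserts as instance methods, so the tests below call self.assertProto2Equal
+  # and self.assertProto2SameElements directly instead of passing `self`.
+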
+ def testAssertEqualWithStringArg(self):
+ pb = compare_test_pb2.Large()
+ pb.string_ = 'abc'
+ pb.float_ = 1.234
+ self.assertProto2Equal(
+ """
+ string_: 'abc'
+ float_: 1.234
+ """,
+ pb)
+
+ def testAssertSameElements(self):
+ a = compare_test_pb2.Large()
+ a.string_ = 'abc'
+ a.float_ = 1.234
+ a.int64s[:] = [4, 3, 2]
+ b = compare_test_pb2.Large()
+ b.CopyFrom(a)
+ b.int64s[:] = [2, 4, 3]
+ self.assertProto2SameElements(a, b)
+
+
+if __name__ == '__main__':
+ googletest.main()