-rw-r--r--  RELEASE.md | 8
-rw-r--r--  WORKSPACE | 7
-rw-r--r--  eigen.BUILD | 2
-rw-r--r--  farmhash.BUILD | 8
-rw-r--r--  gif.BUILD | 26
-rw-r--r--  gmock.BUILD | 2
-rw-r--r--  grpc.BUILD | 2
-rw-r--r--  jsoncpp.BUILD | 2
-rw-r--r--  nanopb.BUILD | 2
-rw-r--r--  png.BUILD | 2
-rw-r--r--  six.BUILD | 2
-rw-r--r--  tensorflow/BUILD | 3
-rw-r--r--  tensorflow/cc/saved_model/BUILD | 10
-rw-r--r--  tensorflow/cc/saved_model/loader_test.cc | 5
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/assets/foo.txt | 1
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/saved_model.pbtxt | 1951
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001 | bin 8 -> 0 bytes
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index | bin 134 -> 0 bytes
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_sharded/assets/foo.txt | 1
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb | bin 7331 -> 0 bytes
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001 | bin 8 -> 0 bytes
-rw-r--r--  tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index | bin 134 -> 0 bytes
-rw-r--r--  tensorflow/cc/training/queue_runner.cc | 15
-rw-r--r--  tensorflow/cc/training/queue_runner.h | 18
-rw-r--r--  tensorflow/cc/training/queue_runner_test.cc | 16
-rw-r--r--  tensorflow/contrib/android/cmake/CMakeLists.txt | 61
-rw-r--r--  tensorflow/contrib/android/cmake/README.md | 44
-rw-r--r--  tensorflow/contrib/android/cmake/build.gradle | 97
-rw-r--r--  tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml | 9
-rw-r--r--  tensorflow/contrib/android/cmake/src/main/res/values/strings.xml | 3
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/special_math.py | 9
-rw-r--r--  tensorflow/contrib/distributions/python/ops/beta.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijector.py | 4
-rw-r--r--  tensorflow/contrib/distributions/python/ops/dirichlet.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/ops/distribution.py | 5
-rw-r--r--  tensorflow/contrib/distributions/python/ops/distribution_util.py | 10
-rw-r--r--  tensorflow/contrib/distributions/python/ops/gamma.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/ops/gumbel.py | 205
-rw-r--r--  tensorflow/contrib/distributions/python/ops/inverse_gamma.py | 4
-rw-r--r--  tensorflow/contrib/distributions/python/ops/logistic.py | 210
-rw-r--r--  tensorflow/contrib/distributions/python/ops/onehot_categorical.py | 262
-rw-r--r--  tensorflow/contrib/distributions/python/ops/quantized_distribution.py | 26
-rw-r--r--  tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py | 213
-rw-r--r--  tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py | 420
-rw-r--r--  tensorflow/contrib/distributions/python/ops/shape.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/ops/student_t.py | 6
-rw-r--r--  tensorflow/contrib/distributions/python/ops/transformed_distribution.py | 6
-rw-r--r--  tensorflow/contrib/distributions/python/ops/uniform.py | 8
-rw-r--r--  tensorflow/contrib/distributions/python/ops/wishart.py | 2
-rw-r--r--  tensorflow/contrib/factorization/python/ops/kmeans.py | 2
-rw-r--r--  tensorflow/contrib/framework/python/framework/tensor_util.py | 107
-rw-r--r--  tensorflow/contrib/framework/python/ops/variables.py | 42
-rw-r--r--  tensorflow/contrib/labeled_tensor/python/ops/ops.py | 8
-rw-r--r--  tensorflow/contrib/layers/__init__.py | 8
-rw-r--r--  tensorflow/contrib/layers/python/layers/embedding_ops.py | 2
-rw-r--r--  tensorflow/contrib/layers/python/layers/feature_column.py | 20
-rw-r--r--  tensorflow/contrib/layers/python/layers/feature_column_ops_test.py | 19
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py | 114
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py | 28
-rw-r--r--  tensorflow/contrib/layers/python/layers/optimizers.py | 2
-rw-r--r--  tensorflow/contrib/learn/BUILD | 30
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/constants.py (renamed from tensorflow/python/util/net_lib.py) | 12
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn.py | 69
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py | 62
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py | 90
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_test.py | 119
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py | 181
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py | 136
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator.py | 114
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/estimator_test.py | 117
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head.py | 716
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head_test.py | 232
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/linear.py | 33
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/linear_test.py | 4
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/model_fn.py | 35
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/random_forest.py | 17
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/svm.py | 18
-rw-r--r--  tensorflow/contrib/learn/python/learn/monitors.py | 6
-rw-r--r--  tensorflow/contrib/learn/python/learn/ops/array_ops.py | 21
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/gc.py | 205
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/gc_test.py | 120
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py | 97
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py | 248
-rw-r--r--  tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py | 228
-rw-r--r--  tensorflow/contrib/linalg/BUILD | 2
-rw-r--r--  tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py | 103
-rw-r--r--  tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py | 46
-rw-r--r--  tensorflow/contrib/linalg/python/ops/linear_operator.py | 96
-rw-r--r--  tensorflow/contrib/linalg/python/ops/linear_operator_diag.py | 117
-rw-r--r--  tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py | 226
-rw-r--r--  tensorflow/contrib/linear_optimizer/BUILD | 44
-rw-r--r--  tensorflow/contrib/linear_optimizer/__init__.py | 2
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 93
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py | 233
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py | 167
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py | 97
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py | 114
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py | 51
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py | 7
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops.py | 68
-rw-r--r--  tensorflow/contrib/makefile/Makefile | 2
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt | 2
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py | 10
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py | 3
-rw-r--r--  tensorflow/contrib/session_bundle/BUILD | 9
-rw-r--r--  tensorflow/contrib/session_bundle/bundle_shim.cc | 23
-rw-r--r--  tensorflow/contrib/session_bundle/bundle_shim_test.cc | 45
-rw-r--r--  tensorflow/contrib/session_bundle/session_bundle.cc | 11
-rw-r--r--  tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/saved_model.pb | bin 6491 -> 0 bytes
-rw-r--r--  tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.data-00000-of-00001 | bin 8 -> 0 bytes
-rw-r--r--  tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index | bin 134 -> 0 bytes
-rw-r--r--  tensorflow/contrib/slim/python/slim/evaluation.py | 285
-rw-r--r--  tensorflow/contrib/slim/python/slim/evaluation_test.py | 144
-rw-r--r--  tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc | 27
-rw-r--r--  tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc | 25
-rw-r--r--  tensorflow/contrib/tensor_forest/core/ops/tree_utils.h | 5
-rw-r--r--  tensorflow/contrib/tensorboard/BUILD | 30
-rw-r--r--  tensorflow/contrib/tensorboard/plugins/__init__.py | 1
-rw-r--r--  tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py | 2
-rw-r--r--  tensorflow/contrib/tensorboard/plugins/trace/__init__.py (renamed from tensorflow/python/util/net_lib_test.py) | 25
-rw-r--r--  tensorflow/contrib/tensorboard/plugins/trace/trace.py | 162
-rw-r--r--  tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto | 60
-rw-r--r--  tensorflow/contrib/tensorboard/plugins/trace/trace_test.py | 91
-rw-r--r--  tensorflow/contrib/testing/python/framework/fake_summary_writer.py | 22
-rw-r--r--  tensorflow/contrib/training/python/training/evaluation.py | 64
-rw-r--r--  tensorflow/contrib/training/python/training/resample.py | 21
-rw-r--r--  tensorflow/contrib/training/python/training/resample_test.py | 6
-rw-r--r--  tensorflow/contrib/training/python/training/sampling_ops.py | 2
-rw-r--r--  tensorflow/contrib/training/python/training/sampling_ops_test.py | 3
-rw-r--r--  tensorflow/core/BUILD | 5
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 28
-rw-r--r--  tensorflow/core/common_runtime/direct_session.h | 5
-rw-r--r--  tensorflow/core/common_runtime/executor.cc | 8
-rw-r--r--  tensorflow/core/common_runtime/executor.h | 2
-rw-r--r--  tensorflow/core/common_runtime/function.cc | 4
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.cc | 21
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.h | 2
-rw-r--r--  tensorflow/core/framework/allocator.h | 5
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc | 13
-rw-r--r--  tensorflow/core/framework/function.h | 5
-rw-r--r--  tensorflow/core/framework/op_kernel.cc | 2
-rw-r--r--  tensorflow/core/framework/op_kernel.h | 11
-rw-r--r--  tensorflow/core/framework/resource_mgr.h | 27
-rw-r--r--  tensorflow/core/graph/graph_partition.cc | 38
-rw-r--r--  tensorflow/core/graph/graph_partition_test.cc | 32
-rw-r--r--  tensorflow/core/kernels/BUILD | 1581
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc | 3
-rw-r--r--  tensorflow/core/kernels/eigen_pooling.h | 6
-rw-r--r--  tensorflow/core/kernels/hexagon/BUILD | 4
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer.cc | 205
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer.h | 69
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer_test.cc | 25
-rw-r--r--  tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc | 43
-rw-r--r--  tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc | 2
-rw-r--r--  tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h | 2
-rw-r--r--  tensorflow/core/kernels/quantization_utils.h | 7
-rw-r--r--  tensorflow/core/kernels/quantization_utils_test.cc | 59
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op.cc | 16
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op_test.cc | 10
-rw-r--r--  tensorflow/core/kernels/sparse_dense_binary_op_shared.cc | 4
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op.h | 106
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op_test.cc | 12
-rw-r--r--  tensorflow/core/kernels/stack_ops.cc | 19
-rw-r--r--  tensorflow/core/kernels/string_split_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc | 29
-rw-r--r--  tensorflow/core/kernels/variable_ops.h | 11
-rw-r--r--  tensorflow/core/lib/strings/str_util.h | 24
-rw-r--r--  tensorflow/core/lib/strings/str_util_test.cc | 6
-rw-r--r--  tensorflow/core/ops/array_ops_test.cc | 40
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 4
-rw-r--r--  tensorflow/core/ops/string_ops.cc | 9
-rw-r--r--  tensorflow/core/platform/default/build_config/BUILD | 13
-rw-r--r--  tensorflow/core/platform/default/logging.h | 2
-rw-r--r--  tensorflow/core/platform/default/stacktrace.h | 4
-rw-r--r--  tensorflow/core/platform/env_test.cc | 17
-rw-r--r--  tensorflow/core/protobuf/config.proto | 13
-rw-r--r--  tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc | 5
-rw-r--r--  tensorflow/examples/android/AndroidManifest.xml | 9
-rw-r--r--  tensorflow/examples/android/BUILD | 25
-rw-r--r--  tensorflow/examples/android/README.md | 39
-rw-r--r--  tensorflow/examples/android/jni/box_coder_jni.cc | 92
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/config.h | 300
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/flow_cache.h | 306
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/frame_pair.cc | 308
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/frame_pair.h | 103
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/geom.h | 319
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/gl_utils.h | 55
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image-inl.h | 642
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image.h | 346
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image_data.h | 270
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image_neon.cc | 270
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/image_utils.h | 301
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/integral_image.h | 187
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/jni_utils.h | 62
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/keypoint.h | 48
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc | 549
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/keypoint_detector.h | 133
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/log_streaming.h | 37
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_detector.cc | 27
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_detector.h | 232
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_model.h | 101
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_tracker.cc | 690
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_tracker.h | 271
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc | 463
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/optical_flow.cc | 490
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/optical_flow.h | 111
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/sprite.h | 205
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/time_log.cc | 29
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/time_log.h | 138
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/tracked_object.cc | 163
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/tracked_object.h | 191
-rw-r--r--  tensorflow/examples/android/jni/object_tracking/utils.h | 386
-rwxr-xr-x  tensorflow/examples/android/jni/object_tracking/utils_neon.cc | 151
-rw-r--r--  tensorflow/examples/android/proto/box_coder.proto | 42
-rw-r--r--  tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml | 30
-rw-r--r--  tensorflow/examples/android/res/values/base-strings.xml | 5
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java | 11
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java | 317
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java | 218
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java | 381
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java | 649
-rw-r--r--  tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py | 2
-rw-r--r--  tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py | 2
-rw-r--r--  tensorflow/examples/image_retraining/retrain.py | 132
-rw-r--r--  tensorflow/examples/learn/wide_n_deep_tutorial.py | 4
-rw-r--r--  tensorflow/examples/tutorials/mnist/fully_connected_feed.py | 2
-rw-r--r--  tensorflow/examples/tutorials/mnist/mnist_with_summaries.py | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/client.md | 4
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.distributions.md | 46
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.framework.md | 28
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.layers.md | 58
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.learn.md | 98
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md | 6
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.linalg.md | 145
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.losses.md | 1
-rw-r--r--  tensorflow/g3doc/api_docs/python/contrib.training.md | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/control_flow_ops.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/framework.md | 78
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md | 13
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md | 9
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md | 78
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.FileWriterCache.clear.md | 4
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md | 4
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md | 76
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md | 46
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md | 4
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md | 9
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md | 1
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md | 4
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md | 6
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md | 13
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md | 33
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md | 14
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md | 19
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.FileWriterCache.get.md | 13
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md | 4
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriterCache.md | 26
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md | 184
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md | 8
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md | 15
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md | 4
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md | 82
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md | 5
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md | 33
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md | 67
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/index.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/math_ops.md | 34
-rw-r--r--  tensorflow/g3doc/api_docs/python/state_ops.md | 87
-rw-r--r--  tensorflow/g3doc/api_docs/python/string_ops.md | 13
-rw-r--r--  tensorflow/g3doc/api_docs/python/summary.md | 33
-rw-r--r--  tensorflow/g3doc/api_docs/python/train.md | 200
-rw-r--r--  tensorflow/g3doc/get_started/os_setup.md | 2
-rw-r--r--  tensorflow/g3doc/how_tos/adding_an_op/index.md | 10
-rw-r--r--  tensorflow/g3doc/how_tos/graph_viz/index.md | 8
-rw-r--r--  tensorflow/g3doc/how_tos/hadoop/index.md | 8
-rw-r--r--  tensorflow/g3doc/how_tos/image_retraining/index.md | 37
-rw-r--r--  tensorflow/g3doc/how_tos/threading_and_queues/index.md | 4
-rw-r--r--  tensorflow/g3doc/how_tos/variable_scope/index.md | 10
-rw-r--r--  tensorflow/g3doc/tutorials/seq2seq/index.md | 10
-rw-r--r--  tensorflow/g3doc/tutorials/wide/index.md | 4
-rw-r--r--  tensorflow/g3doc/tutorials/wide_and_deep/index.md | 4
-rw-r--r--  tensorflow/g3doc/tutorials/word2vec/index.md | 4
-rw-r--r--  tensorflow/go/README.md | 3
-rw-r--r--  tensorflow/java/BUILD | 61
-rw-r--r--  tensorflow/java/README.md | 72
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java | 28
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/examples/.gitignore | 3
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/examples/BUILD | 25
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/examples/Example.java | 29
-rw-r--r--  tensorflow/java/src/main/native/BUILD | 66
-rw-r--r--  tensorflow/java/src/main/native/tensorflow.cc (renamed from tensorflow/python/client/net_lib.i) | 20
-rw-r--r--  tensorflow/java/src/main/native/tensorflow.h | 36
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java | 31
-rw-r--r--  tensorflow/models/embedding/word2vec.py | 6
-rw-r--r--  tensorflow/models/image/cifar10/cifar10.py | 17
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_eval.py | 4
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_input.py | 2
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py | 14
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_train.py | 1
-rw-r--r--  tensorflow/models/rnn/ptb/ptb_word_lm.py | 6
-rw-r--r--  tensorflow/python/BUILD | 28
-rw-r--r--  tensorflow/python/__init__.py | 2
-rw-r--r--  tensorflow/python/client/session.py | 7
-rw-r--r--  tensorflow/python/client/session_test.py | 12
-rw-r--r--  tensorflow/python/client/timeline.py | 19
-rw-r--r--  tensorflow/python/client/timeline_test.py | 17
-rw-r--r--  tensorflow/python/framework/graph_util.py | 2
-rw-r--r--  tensorflow/python/framework/graph_util_impl.py | 18
-rw-r--r--  tensorflow/python/framework/ops.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/concat_op_test.py | 20
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 25
-rw-r--r--  tensorflow/python/kernel_tests/cwise_ops_test.py | 28
-rw-r--r--  tensorflow/python/kernel_tests/parsing_ops_test.py | 143
-rw-r--r--  tensorflow/python/kernel_tests/pooling_ops_test.py | 4
-rw-r--r--  tensorflow/python/kernel_tests/reader_ops_test.py | 15
-rw-r--r--  tensorflow/python/kernel_tests/relu_op_test.py | 32
-rw-r--r--  tensorflow/python/kernel_tests/rnn_cell_test.py | 7
-rw-r--r--  tensorflow/python/kernel_tests/scatter_nd_ops_test.py | 18
-rw-r--r--  tensorflow/python/kernel_tests/string_split_op_test.py | 30
-rw-r--r--  tensorflow/python/kernel_tests/summary_ops_test.py | 12
-rw-r--r--  tensorflow/python/kernel_tests/template_test.py | 29
-rw-r--r--  tensorflow/python/kernel_tests/tensor_array_ops_test.py | 4
-rw-r--r--  tensorflow/python/layers/__init__.py | 32
-rw-r--r--  tensorflow/python/layers/base.py | 58
-rw-r--r--  tensorflow/python/layers/base_test.py | 35
-rw-r--r--  tensorflow/python/layers/core.py | 159
-rw-r--r--  tensorflow/python/layers/core_test.py | 108
-rw-r--r--  tensorflow/python/layers/layers.py | 39
-rw-r--r--  tensorflow/python/lib/io/file_io.py | 4
-rw-r--r--  tensorflow/python/lib/io/py_record_reader.cc | 10
-rw-r--r--  tensorflow/python/lib/io/py_record_reader.h | 10
-rw-r--r--  tensorflow/python/lib/io/tf_record.py | 7
-rw-r--r--  tensorflow/python/ops/array_ops.py | 2
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 6
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt | 6
-rw-r--r--  tensorflow/python/ops/io_ops.py | 1
-rw-r--r--  tensorflow/python/ops/math_grad.py | 14
-rw-r--r--  tensorflow/python/ops/math_ops.py | 148
-rw-r--r--  tensorflow/python/ops/math_ops_test.py | 6
-rw-r--r--  tensorflow/python/ops/nn_grad.py | 9
-rw-r--r--  tensorflow/python/ops/nn_impl.py | 8
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 4
-rw-r--r--  tensorflow/python/ops/parsing_ops.py | 176
-rw-r--r--  tensorflow/python/ops/rnn.py | 7
-rw-r--r--  tensorflow/python/ops/rnn_cell.py | 859
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py | 872
-rw-r--r--  tensorflow/python/ops/seq2seq.py | 3
-rw-r--r--  tensorflow/python/ops/string_ops.py | 14
-rw-r--r--  tensorflow/python/ops/template.py | 18
-rw-r--r--  tensorflow/python/ops/variables.py | 6
-rw-r--r--  tensorflow/python/platform/app.py | 16
-rw-r--r--  tensorflow/python/saved_model/BUILD | 36
-rw-r--r--  tensorflow/python/saved_model/example/BUILD | 16
-rw-r--r--  tensorflow/python/saved_model/example/saved_model_half_plus_two.py | 39
-rw-r--r--  tensorflow/python/saved_model/saved_model_test.py | 11
-rw-r--r--  tensorflow/python/saved_model/signature_def_utils.py | 158
-rw-r--r--  tensorflow/python/saved_model/signature_def_utils_test.py | 156
-rw-r--r--  tensorflow/python/saved_model/utils.py | 28
-rw-r--r--  tensorflow/python/saved_model/utils_test.py | 30
-rw-r--r--  tensorflow/python/summary/event_accumulator_test.py | 8
-rw-r--r--  tensorflow/python/summary/impl/event_file_loader.py | 10
-rw-r--r--  tensorflow/python/summary/impl/event_file_loader_test.py | 9
-rw-r--r--  tensorflow/python/summary/summary.py | 4
-rw-r--r--  tensorflow/python/summary/summary_iterator.py | 4
-rw-r--r--  tensorflow/python/summary/writer/writer.py | 4
-rw-r--r--  tensorflow/python/summary/writer/writer_cache.py | 28
-rw-r--r--  tensorflow/python/summary/writer/writer_test.py | 30
-rw-r--r--  tensorflow/python/tensorflow.i | 1
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks.py | 2
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks_test.py | 30
-rw-r--r--  tensorflow/python/training/moving_averages.py | 9
-rw-r--r--  tensorflow/python/training/moving_averages_test.py | 96
-rw-r--r--  tensorflow/python/training/saver.py | 21
-rw-r--r--  tensorflow/python/training/summary_io.py | 3
-rw-r--r--  tensorflow/python/training/supervisor_test.py | 2
-rw-r--r--  tensorflow/python/training/tensorboard_logging_test.py | 2
-rw-r--r--  tensorflow/python/util/deprecation.py | 80
-rw-r--r--  tensorflow/python/util/deprecation_test.py | 33
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 49
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.h | 5
-rw-r--r--  tensorflow/stream_executor/kernel.h | 269
-rw-r--r--  tensorflow/stream_executor/stream_executor.h | 28
-rw-r--r--  tensorflow/stream_executor/stream_executor_internal.h | 5
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc | 7
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h | 10
-rw-r--r--  tensorflow/stream_executor/trace_listener.h | 2
-rw-r--r--  tensorflow/tensorboard/backend/server_test.py | 2
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/data.ts | 59
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts | 197
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/renderContext.ts | 8
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/scatterPlot.ts | 58
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts | 4
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts | 76
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts | 24
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts | 199
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts | 42
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts | 42
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts | 28
-rw-r--r--  tensorflow/tensorboard/components/vz_projector/vz-projector.ts | 112
-rw-r--r--  tensorflow/tensorboard/scripts/generate_testdata.py | 8
-rw-r--r--  tensorflow/tensorflow.bzl | 27
-rw-r--r--  tensorflow/tools/benchmark/benchmark_model.cc | 5
-rwxr-xr-x  tensorflow/tools/ci_build/ci_sanity.sh | 69
-rw-r--r--  tensorflow/tools/dist_test/python/census_widendeep.py | 4
-rw-r--r--  tensorflow/tools/dist_test/server/Dockerfile.test | 4
-rw-r--r--  tensorflow/tools/pip_package/BUILD | 24
-rw-r--r--  tensorflow/workspace.bzl | 16
-rw-r--r--  third_party/eigen3/BUILD | 2
-rw-r--r--  third_party/hadoop/BUILD | 2
-rw-r--r--  third_party/hadoop/LICENSE.txt | 284
-rw-r--r--  third_party/llvm/llvm.BUILD | 25
-rw-r--r--  third_party/pcre.BUILD | 2
-rwxr-xr-x  third_party/sycl/sycl/BUILD.tpl | 2
-rw-r--r--  third_party/sycl/sycl/LICENSE.text.tpl | 268
-rw-r--r--  third_party/sycl/sycl_configure.bzl | 1
424 files changed, 22987 insertions, 7028 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 0bb2a92ba6..ead29f0c54 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -59,6 +59,13 @@
* Removes RegisterShape from public API. Use C++ shape function registration instead.
indexing now starts from 1 instead of 0, and `bus_id==0` is used where
previously `BUS_ANY` was used.
+* Most RNN cells and RNN functions now use different variable scopes to be
+ consistent with layers (`tf.contrib.layers`). This means old checkpoints
+ written using this code will not load after this change without providing
+ `Saver` a list of variable renames. Examples of variable scope changes
+ include `RNN` -> `rnn` in `tf.nn.rnn`, `tf.nn.dynamic_rnn` and moving from
+ `Linear/Matrix` -> `weights` and `Linear/Bias` -> `biases` in most RNN cells.
+* Deprecated tf.select op. tf.where should be used instead.
* `Env::FileExists` and `FileSystem::FileExists` now return a
`tensorflow::Status` instead of a bool. Any callers to this function can be
converted to a bool by adding `.ok()` to the call.
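
To make the RNN checkpoint note above concrete, here is a minimal sketch of restoring a pre-rename checkpoint by handing `tf.train.Saver` a map from old checkpoint names to the variables in the new graph. The reverse substitutions and the checkpoint path are illustrative assumptions, not part of this commit; the authoritative old names should be read from the checkpoint itself, e.g. with `tf.train.NewCheckpointReader`.

```python
import tensorflow as tf

# Build the graph with the new (post-rename) variable scopes.
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=64)
inputs = [tf.placeholder(tf.float32, [None, 32]) for _ in range(5)]
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)

# Map old checkpoint names -> new variables. The substitutions invert
# the renames listed above (RNN -> rnn, Linear/Matrix -> weights,
# Linear/Bias -> biases); the exact old names for a given model are an
# assumption here and should be verified against the checkpoint.
renames = {}
for var in tf.global_variables():
    old_name = (var.op.name
                .replace("rnn", "RNN", 1)
                .replace("weights", "Linear/Matrix")
                .replace("biases", "Linear/Bias"))
    renames[old_name] = var

saver = tf.train.Saver(var_list=renames)
with tf.Session() as sess:
    saver.restore(sess, "/tmp/old_rnn_model.ckpt")  # hypothetical path
```

Similarly, `tf.where(condition, x, y)` takes the place of the deprecated `tf.select` with the same argument order.
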
@@ -118,7 +125,6 @@ Yuming Wang, Zafar Takhirov, @zhongyuk, Ziming Dong, @guotong1988
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.
->>>>>>> r0.12
# Release 0.11.0
diff --git a/WORKSPACE b/WORKSPACE
index 30aba396b8..20c0285084 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -29,6 +29,13 @@ new_http_archive(
sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364"
)
+new_http_archive(
+ name = "mobile_multibox",
+ build_file = "models.BUILD",
+ url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip",
+ sha256 = "b4c178fd6236dcf0a20d25d07c45eebe85281263978c6a6f1dfc49d75befc45f"
+)
+
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
new_http_archive(
diff --git a/eigen.BUILD b/eigen.BUILD
index 8ce28ac076..3fd710dfd4 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -9,6 +9,8 @@ licenses([
"notice", # Portions BSD
])
+exports_files(["COPYING.MPL2",])
+
# License-restricted (i.e. not reciprocal or notice) files inside Eigen/...
EIGEN_RESTRICTED_FILES = [
"Eigen/src/OrderingMethods/Amd.h",
diff --git a/farmhash.BUILD b/farmhash.BUILD
index b41c799f8f..d054797a56 100644
--- a/farmhash.BUILD
+++ b/farmhash.BUILD
@@ -1,5 +1,7 @@
licenses(["notice"]) # MIT
+exports_files(["COPYING"])
+
config_setting(
name = "windows",
values = {
@@ -10,13 +12,13 @@ config_setting(
cc_library(
name = "farmhash",
- srcs = ["farmhash.cc"],
- hdrs = ["farmhash.h"],
+ srcs = ["src/farmhash.cc"],
+ hdrs = ["src/farmhash.h"],
# Disable __builtin_expect support on Windows
copts = select({
":windows" : ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"],
"//conditions:default" : [],
}),
- includes = ["."],
+ includes = ["src/."],
visibility = ["//visibility:public"],
)
diff --git a/gif.BUILD b/gif.BUILD
index 22ccda52e4..fec7449130 100644
--- a/gif.BUILD
+++ b/gif.BUILD
@@ -3,22 +3,24 @@
licenses(["notice"]) # MIT
+exports_files(["COPYING"])
+
cc_library(
name = "gif",
srcs = [
- "dgif_lib.c",
- "egif_lib.c",
- "gif_err.c",
- "gif_font.c",
- "gif_hash.c",
- "gif_hash.h",
- "gif_lib_private.h",
- "gifalloc.c",
- "openbsd-reallocarray.c",
- "quantize.c",
+ "lib/dgif_lib.c",
+ "lib/egif_lib.c",
+ "lib/gif_err.c",
+ "lib/gif_font.c",
+ "lib/gif_hash.c",
+ "lib/gif_hash.h",
+ "lib/gif_lib_private.h",
+ "lib/gifalloc.c",
+ "lib/openbsd-reallocarray.c",
+ "lib/quantize.c",
],
- hdrs = ["gif_lib.h"],
- includes = ["."],
+ hdrs = ["lib/gif_lib.h"],
+ includes = ["lib/."],
visibility = ["//visibility:public"],
deps = select({
":windows": [":windows_polyfill"],
diff --git a/gmock.BUILD b/gmock.BUILD
index 66ed60750d..501e322529 100644
--- a/gmock.BUILD
+++ b/gmock.BUILD
@@ -4,6 +4,8 @@
licenses(["notice"]) # 3-clause BSD
+exports_files(["LICENSE"])
+
cc_library(
name = "gtest",
srcs = [
diff --git a/grpc.BUILD b/grpc.BUILD
index e74da683e3..e501db57e5 100644
--- a/grpc.BUILD
+++ b/grpc.BUILD
@@ -45,6 +45,8 @@ licenses(["notice"]) # 3-clause BSD
package(default_visibility = ["//visibility:public"])
+exports_files(["LICENSE"])
+
genrule(
name = "pb_h",
outs = ["third_party/nanopb/pb.h"],
diff --git a/jsoncpp.BUILD b/jsoncpp.BUILD
index 765bf15129..ce672a72ec 100644
--- a/jsoncpp.BUILD
+++ b/jsoncpp.BUILD
@@ -1,5 +1,7 @@
licenses(["unencumbered"]) # Public Domain or MIT
+exports_files(["LICENSE"])
+
cc_library(
name = "jsoncpp",
srcs = [
diff --git a/nanopb.BUILD b/nanopb.BUILD
index 8b428689e1..d21866911b 100644
--- a/nanopb.BUILD
+++ b/nanopb.BUILD
@@ -3,6 +3,8 @@
licenses(["notice"]) # zlib license
+exports_files(["LICENSE.txt"])
+
cc_library(
name = "nanopb",
srcs = [
diff --git a/png.BUILD b/png.BUILD
index 9ff982bc90..6a7ad719aa 100644
--- a/png.BUILD
+++ b/png.BUILD
@@ -3,6 +3,8 @@
licenses(["notice"]) # BSD/MIT-like license
+exports_files(["LICENSE"])
+
cc_library(
name = "png",
srcs = [
diff --git a/six.BUILD b/six.BUILD
index fd3d0cc16f..a1b2f7b20c 100644
--- a/six.BUILD
+++ b/six.BUILD
@@ -4,6 +4,8 @@
licenses(["notice"]) # MIT
+exports_files(["LICENSE"])
+
py_library(
name = "six",
srcs = ["six.py"],
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 1e6c5d5947..2d7c28feb7 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -160,6 +160,9 @@ filegroup(
"//tensorflow/g3doc/how_tos/adding_an_op:all_files",
"//tensorflow/g3doc/tutorials:all_files",
"//tensorflow/go:all_files",
+ "//tensorflow/java:all_files",
+ "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files",
+ "//tensorflow/java/src/main/native:all_files",
"//tensorflow/models/embedding:all_files",
"//tensorflow/models/image/alexnet:all_files",
"//tensorflow/models/image/cifar10:all_files",
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index 90c87210b1..bfa1386d96 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -45,7 +45,7 @@ tf_cc_test(
name = "loader_test",
srcs = ["loader_test.cc"],
data = [
- ":saved_model_half_plus_two",
+ "//tensorflow/python/saved_model/example:saved_model_half_plus_two_data",
],
linkstatic = 1,
deps = [
@@ -61,14 +61,6 @@ tf_cc_test(
],
)
-filegroup(
- name = "saved_model_half_plus_two",
- srcs = glob([
- "testdata/half_plus_two_pbtxt/**",
- "testdata/half_plus_two_sharded/**",
- ]),
-)
-
# -----------------------------------------------------------------------------
# Google-internal targets.
diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc
index 82f30c23f6..dbbcc79802 100644
--- a/tensorflow/cc/saved_model/loader_test.cc
+++ b/tensorflow/cc/saved_model/loader_test.cc
@@ -29,9 +29,10 @@ limitations under the License.
namespace tensorflow {
namespace {
-constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt";
+constexpr char kTestDataPbTxt[] =
+ "python/saved_model/example/saved_model_half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
- "cc/saved_model/testdata/half_plus_two_sharded";
+ "python/saved_model/example/saved_model_half_plus_two/00000123";
class LoaderTest : public ::testing::Test {
protected:
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/assets/foo.txt
deleted file mode 100644
index f9ff036688..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/assets/foo.txt
+++ /dev/null
@@ -1 +0,0 @@
-asset-file-contents
\ No newline at end of file
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/saved_model.pbtxt b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/saved_model.pbtxt
deleted file mode 100644
index 693262eb4d..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/saved_model.pbtxt
+++ /dev/null
@@ -1,1951 +0,0 @@
-saved_model_schema_version: 1
-meta_graphs {
- meta_info_def {
- stripped_op_list {
- op {
- name: "Add"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_STRING
- }
- }
- }
- }
- op {
- name: "Assign"
- input_arg {
- name: "ref"
- type_attr: "T"
- is_ref: true
- }
- input_arg {
- name: "value"
- type_attr: "T"
- }
- output_arg {
- name: "output_ref"
- type_attr: "T"
- is_ref: true
- }
- attr {
- name: "T"
- type: "type"
- }
- attr {
- name: "validate_shape"
- type: "bool"
- default_value {
- b: true
- }
- }
- attr {
- name: "use_locking"
- type: "bool"
- default_value {
- b: true
- }
- }
- allows_uninitialized_input: true
- }
- op {
- name: "Const"
- output_arg {
- name: "output"
- type_attr: "dtype"
- }
- attr {
- name: "value"
- type: "tensor"
- }
- attr {
- name: "dtype"
- type: "type"
- }
- }
- op {
- name: "Identity"
- input_arg {
- name: "input"
- type_attr: "T"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- }
- }
- op {
- name: "MergeV2Checkpoints"
- input_arg {
- name: "checkpoint_prefixes"
- type: DT_STRING
- }
- input_arg {
- name: "destination_prefix"
- type: DT_STRING
- }
- attr {
- name: "delete_old_dirs"
- type: "bool"
- default_value {
- b: true
- }
- }
- }
- op {
- name: "Mul"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
- is_commutative: true
- }
- op {
- name: "NoOp"
- }
- op {
- name: "Pack"
- input_arg {
- name: "values"
- type_attr: "T"
- number_attr: "N"
- }
- output_arg {
- name: "output"
- type_attr: "T"
- }
- attr {
- name: "N"
- type: "int"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "T"
- type: "type"
- }
- attr {
- name: "axis"
- type: "int"
- default_value {
- i: 0
- }
- }
- }
- op {
- name: "ParseExample"
- input_arg {
- name: "serialized"
- type: DT_STRING
- }
- input_arg {
- name: "names"
- type: DT_STRING
- }
- input_arg {
- name: "sparse_keys"
- type: DT_STRING
- number_attr: "Nsparse"
- }
- input_arg {
- name: "dense_keys"
- type: DT_STRING
- number_attr: "Ndense"
- }
- input_arg {
- name: "dense_defaults"
- type_list_attr: "Tdense"
- }
- output_arg {
- name: "sparse_indices"
- type: DT_INT64
- number_attr: "Nsparse"
- }
- output_arg {
- name: "sparse_values"
- type_list_attr: "sparse_types"
- }
- output_arg {
- name: "sparse_shapes"
- type: DT_INT64
- number_attr: "Nsparse"
- }
- output_arg {
- name: "dense_values"
- type_list_attr: "Tdense"
- }
- attr {
- name: "Nsparse"
- type: "int"
- has_minimum: true
- }
- attr {
- name: "Ndense"
- type: "int"
- has_minimum: true
- }
- attr {
- name: "sparse_types"
- type: "list(type)"
- has_minimum: true
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_INT64
- type: DT_STRING
- }
- }
- }
- attr {
- name: "Tdense"
- type: "list(type)"
- has_minimum: true
- allowed_values {
- list {
- type: DT_FLOAT
- type: DT_INT64
- type: DT_STRING
- }
- }
- }
- attr {
- name: "dense_shapes"
- type: "list(shape)"
- has_minimum: true
- }
- }
- op {
- name: "Placeholder"
- output_arg {
- name: "output"
- type_attr: "dtype"
- }
- attr {
- name: "dtype"
- type: "type"
- }
- attr {
- name: "shape"
- type: "shape"
- default_value {
- shape {
- }
- }
- }
- }
- op {
- name: "RestoreV2"
- input_arg {
- name: "prefix"
- type: DT_STRING
- }
- input_arg {
- name: "tensor_names"
- type: DT_STRING
- }
- input_arg {
- name: "shape_and_slices"
- type: DT_STRING
- }
- output_arg {
- name: "tensors"
- type_list_attr: "dtypes"
- }
- attr {
- name: "dtypes"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- }
- op {
- name: "SaveV2"
- input_arg {
- name: "prefix"
- type: DT_STRING
- }
- input_arg {
- name: "tensor_names"
- type: DT_STRING
- }
- input_arg {
- name: "shape_and_slices"
- type: DT_STRING
- }
- input_arg {
- name: "tensors"
- type_list_attr: "dtypes"
- }
- attr {
- name: "dtypes"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- }
- op {
- name: "ShardedFilename"
- input_arg {
- name: "basename"
- type: DT_STRING
- }
- input_arg {
- name: "shard"
- type: DT_INT32
- }
- input_arg {
- name: "num_shards"
- type: DT_INT32
- }
- output_arg {
- name: "filename"
- type: DT_STRING
- }
- }
- op {
- name: "StringJoin"
- input_arg {
- name: "inputs"
- type: DT_STRING
- number_attr: "N"
- }
- output_arg {
- name: "output"
- type: DT_STRING
- }
- attr {
- name: "N"
- type: "int"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "separator"
- type: "string"
- default_value {
- s: ""
- }
- }
- }
- op {
- name: "Variable"
- output_arg {
- name: "ref"
- type_attr: "dtype"
- is_ref: true
- }
- attr {
- name: "shape"
- type: "shape"
- }
- attr {
- name: "dtype"
- type: "type"
- }
- attr {
- name: "container"
- type: "string"
- default_value {
- s: ""
- }
- }
- attr {
- name: "shared_name"
- type: "string"
- default_value {
- s: ""
- }
- }
- is_stateful: true
- }
- }
- tags: "serve"
- }
- graph_def {
- node {
- name: "a/initial_value"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_FLOAT
- tensor_shape {
- }
- float_val: 0.5
- }
- }
- }
- }
- node {
- name: "a"
- op: "Variable"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "container"
- value {
- s: ""
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "shape"
- value {
- shape {
- }
- }
- }
- attr {
- key: "shared_name"
- value {
- s: ""
- }
- }
- }
- node {
- name: "a/Assign"
- op: "Assign"
- input: "a"
- input: "a/initial_value"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@a"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "use_locking"
- value {
- b: true
- }
- }
- attr {
- key: "validate_shape"
- value {
- b: true
- }
- }
- }
- node {
- name: "a/read"
- op: "Identity"
- input: "a"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@a"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- }
- node {
- name: "b/initial_value"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_FLOAT
- tensor_shape {
- }
- float_val: 2.0
- }
- }
- }
- }
- node {
- name: "b"
- op: "Variable"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "container"
- value {
- s: ""
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "shape"
- value {
- shape {
- }
- }
- }
- attr {
- key: "shared_name"
- value {
- s: ""
- }
- }
- }
- node {
- name: "b/Assign"
- op: "Assign"
- input: "b"
- input: "b/initial_value"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@b"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "use_locking"
- value {
- b: true
- }
- }
- attr {
- key: "validate_shape"
- value {
- b: true
- }
- }
- }
- node {
- name: "b/read"
- op: "Identity"
- input: "b"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@b"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- }
- node {
- name: "tf_example"
- op: "Placeholder"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- unknown_rank: true
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "shape"
- value {
- shape {
- }
- }
- }
- }
- node {
- name: "ParseExample/Const"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_FLOAT
- tensor_shape {
- dim {
- }
- }
- }
- }
- }
- }
- node {
- name: "ParseExample/ParseExample/names"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- }
- }
- }
- }
- }
- }
- node {
- name: "ParseExample/ParseExample/dense_keys_0"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- }
- string_val: "x"
- }
- }
- }
- }
- node {
- name: "ParseExample/ParseExample"
- op: "ParseExample"
- input: "tf_example"
- input: "ParseExample/ParseExample/names"
- input: "ParseExample/ParseExample/dense_keys_0"
- input: "ParseExample/Const"
- attr {
- key: "Ndense"
- value {
- i: 1
- }
- }
- attr {
- key: "Nsparse"
- value {
- i: 0
- }
- }
- attr {
- key: "Tdense"
- value {
- list {
- type: DT_FLOAT
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: -1
- }
- dim {
- size: 1
- }
- }
- }
- }
- }
- attr {
- key: "dense_shapes"
- value {
- list {
- shape {
- dim {
- size: 1
- }
- }
- }
- }
- }
- attr {
- key: "sparse_types"
- value {
- list {
- }
- }
- }
- }
- node {
- name: "x"
- op: "Identity"
- input: "ParseExample/ParseExample"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: -1
- }
- dim {
- size: 1
- }
- }
- }
- }
- }
- }
- node {
- name: "Mul"
- op: "Mul"
- input: "a/read"
- input: "x"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: -1
- }
- dim {
- size: 1
- }
- }
- }
- }
- }
- }
- node {
- name: "y"
- op: "Add"
- input: "Mul"
- input: "b/read"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: -1
- }
- dim {
- size: 1
- }
- }
- }
- }
- }
- }
- node {
- name: "Const"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- }
- string_val: "/tmp/original/export/assets/foo.txt"
- }
- }
- }
- }
- node {
- name: "filename_tensor/initial_value"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- }
- string_val: "foo.txt"
- }
- }
- }
- }
- node {
- name: "filename_tensor"
- op: "Variable"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "container"
- value {
- s: ""
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "shape"
- value {
- shape {
- }
- }
- }
- attr {
- key: "shared_name"
- value {
- s: ""
- }
- }
- }
- node {
- name: "filename_tensor/Assign"
- op: "Assign"
- input: "filename_tensor"
- input: "filename_tensor/initial_value"
- attr {
- key: "T"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@filename_tensor"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "use_locking"
- value {
- b: true
- }
- }
- attr {
- key: "validate_shape"
- value {
- b: true
- }
- }
- }
- node {
- name: "filename_tensor/read"
- op: "Identity"
- input: "filename_tensor"
- attr {
- key: "T"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@filename_tensor"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- }
- node {
- name: "Assign/value"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- }
- string_val: "foo.txt"
- }
- }
- }
- }
- node {
- name: "Assign"
- op: "Assign"
- input: "filename_tensor"
- input: "Assign/value"
- attr {
- key: "T"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@filename_tensor"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "use_locking"
- value {
- b: false
- }
- }
- attr {
- key: "validate_shape"
- value {
- b: true
- }
- }
- }
- node {
- name: "Identity"
- op: "Identity"
- input: "y"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: -1
- }
- dim {
- size: 1
- }
- }
- }
- }
- }
- }
- node {
- name: "init"
- op: "NoOp"
- input: "^a/Assign"
- input: "^b/Assign"
- }
- node {
- name: "group_deps"
- op: "NoOp"
- input: "^Assign"
- }
- node {
- name: "save/Const"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- }
- string_val: "model"
- }
- }
- }
- }
- node {
- name: "save/StringJoin/inputs_1"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- }
- string_val: "_temp_ff2bd25218b646ea9ed224eecdce5e79/part"
- }
- }
- }
- }
- node {
- name: "save/StringJoin"
- op: "StringJoin"
- input: "save/Const"
- input: "save/StringJoin/inputs_1"
- attr {
- key: "N"
- value {
- i: 2
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "separator"
- value {
- s: ""
- }
- }
- }
- node {
- name: "save/num_shards"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_INT32
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
- }
- int_val: 1
- }
- }
- }
- }
- node {
- name: "save/ShardedFilename/shard"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_INT32
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_INT32
- tensor_shape {
- }
- int_val: 0
- }
- }
- }
- }
- node {
- name: "save/ShardedFilename"
- op: "ShardedFilename"
- input: "save/StringJoin"
- input: "save/ShardedFilename/shard"
- input: "save/num_shards"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- }
- node {
- name: "save/SaveV2/tensor_names"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: 2
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- size: 2
- }
- }
- string_val: "a"
- string_val: "b"
- }
- }
- }
- }
- node {
- name: "save/SaveV2/shape_and_slices"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: 2
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- size: 2
- }
- }
- string_val: ""
- string_val: ""
- }
- }
- }
- }
- node {
- name: "save/SaveV2"
- op: "SaveV2"
- input: "save/ShardedFilename"
- input: "save/SaveV2/tensor_names"
- input: "save/SaveV2/shape_and_slices"
- input: "a"
- input: "b"
- attr {
- key: "dtypes"
- value {
- list {
- type: DT_FLOAT
- type: DT_FLOAT
- }
- }
- }
- }
- node {
- name: "save/control_dependency"
- op: "Identity"
- input: "save/ShardedFilename"
- input: "^save/SaveV2"
- attr {
- key: "T"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@save/ShardedFilename"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- }
- node {
- name: "save/MergeV2Checkpoints/checkpoint_prefixes"
- op: "Pack"
- input: "save/ShardedFilename"
- input: "^save/control_dependency"
- attr {
- key: "N"
- value {
- i: 1
- }
- }
- attr {
- key: "T"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: 1
- }
- }
- }
- }
- }
- attr {
- key: "axis"
- value {
- i: 0
- }
- }
- }
- node {
- name: "save/MergeV2Checkpoints"
- op: "MergeV2Checkpoints"
- input: "save/MergeV2Checkpoints/checkpoint_prefixes"
- input: "save/Const"
- attr {
- key: "delete_old_dirs"
- value {
- b: true
- }
- }
- }
- node {
- name: "save/Identity"
- op: "Identity"
- input: "save/Const"
- input: "^save/control_dependency"
- input: "^save/MergeV2Checkpoints"
- attr {
- key: "T"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- }
- node {
- name: "save/RestoreV2/tensor_names"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: 1
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- size: 1
- }
- }
- string_val: "a"
- }
- }
- }
- }
- node {
- name: "save/RestoreV2/shape_and_slices"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: 1
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- size: 1
- }
- }
- string_val: ""
- }
- }
- }
- }
- node {
- name: "save/RestoreV2"
- op: "RestoreV2"
- input: "save/Const"
- input: "save/RestoreV2/tensor_names"
- input: "save/RestoreV2/shape_and_slices"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- unknown_rank: true
- }
- }
- }
- }
- attr {
- key: "dtypes"
- value {
- list {
- type: DT_FLOAT
- }
- }
- }
- }
- node {
- name: "save/Assign"
- op: "Assign"
- input: "a"
- input: "save/RestoreV2"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@a"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "use_locking"
- value {
- b: true
- }
- }
- attr {
- key: "validate_shape"
- value {
- b: true
- }
- }
- }
- node {
- name: "save/RestoreV2_1/tensor_names"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: 1
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- size: 1
- }
- }
- string_val: "b"
- }
- }
- }
- }
- node {
- name: "save/RestoreV2_1/shape_and_slices"
- op: "Const"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- dim {
- size: 1
- }
- }
- }
- }
- }
- attr {
- key: "dtype"
- value {
- type: DT_STRING
- }
- }
- attr {
- key: "value"
- value {
- tensor {
- dtype: DT_STRING
- tensor_shape {
- dim {
- size: 1
- }
- }
- string_val: ""
- }
- }
- }
- }
- node {
- name: "save/RestoreV2_1"
- op: "RestoreV2"
- input: "save/Const"
- input: "save/RestoreV2_1/tensor_names"
- input: "save/RestoreV2_1/shape_and_slices"
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- unknown_rank: true
- }
- }
- }
- }
- attr {
- key: "dtypes"
- value {
- list {
- type: DT_FLOAT
- }
- }
- }
- }
- node {
- name: "save/Assign_1"
- op: "Assign"
- input: "b"
- input: "save/RestoreV2_1"
- attr {
- key: "T"
- value {
- type: DT_FLOAT
- }
- }
- attr {
- key: "_class"
- value {
- list {
- s: "loc:@b"
- }
- }
- }
- attr {
- key: "_output_shapes"
- value {
- list {
- shape {
- }
- }
- }
- }
- attr {
- key: "use_locking"
- value {
- b: true
- }
- }
- attr {
- key: "validate_shape"
- value {
- b: true
- }
- }
- }
- node {
- name: "save/restore_shard"
- op: "NoOp"
- input: "^save/Assign"
- input: "^save/Assign_1"
- }
- node {
- name: "save/restore_all"
- op: "NoOp"
- input: "^save/restore_shard"
- }
- versions {
- producer: 15
- }
- }
- saver_def {
- filename_tensor_name: "save/Const:0"
- save_tensor_name: "save/Identity:0"
- restore_op_name: "save/restore_all"
- max_to_keep: 5
- sharded: true
- keep_checkpoint_every_n_hours: 10000.0
- version: V2
- }
- collection_def {
- key: "asset_filepaths"
- value {
- node_list {
- value: "Const:0"
- }
- }
- }
- collection_def {
- key: "legacy_init_op"
- value {
- node_list {
- value: "group_deps"
- }
- }
- }
- collection_def {
- key: "saved_model_assets"
- value {
- any_list {
- value {
- type_url: "type.googleapis.com/tensorflow.AssetFileDef"
- value: "\n\t\n\007Const:0\022\007foo.txt"
- }
- }
- }
- }
- collection_def {
- key: "trainable_variables"
- value {
- bytes_list {
- value: "\n\003a:0\022\010a/Assign\032\010a/read:0"
- value: "\n\003b:0\022\010b/Assign\032\010b/read:0"
- }
- }
- }
- collection_def {
- key: "variables"
- value {
- bytes_list {
- value: "\n\003a:0\022\010a/Assign\032\010a/read:0"
- value: "\n\003b:0\022\010b/Assign\032\010b/read:0"
- }
- }
- }
- signature_def {
- key: "tensorflow/serving/regress"
- value {
- inputs {
- key: "inputs"
- value {
- name: "tf_example:0"
- }
- }
- outputs {
- key: "outputs"
- value {
- name: "Identity:0"
- }
- }
- method_name: "tensorflow/serving/regress"
- }
- }
-}
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001
deleted file mode 100644
index 20bc7d454d..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001
+++ /dev/null
Binary files differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index
deleted file mode 100644
index e7df518f5b..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index
+++ /dev/null
Binary files differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/assets/foo.txt
deleted file mode 100644
index f9ff036688..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/assets/foo.txt
+++ /dev/null
@@ -1 +0,0 @@
-asset-file-contents \ No newline at end of file
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb
deleted file mode 100644
index 0df49f2168..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb
+++ /dev/null
Binary files differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001
deleted file mode 100644
index 20bc7d454d..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001
+++ /dev/null
Binary files differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index
deleted file mode 100644
index e7df518f5b..0000000000
--- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index
+++ /dev/null
Binary files differ
diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc
index 10ff80c9cd..5d6710ea5c 100644
--- a/tensorflow/cc/training/queue_runner.cc
+++ b/tensorflow/cc/training/queue_runner.cc
@@ -33,6 +33,16 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
return (*result)->Init(queue_runner_def);
}
+void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) {
+ mutex_lock l(cb_mu_);
+ callbacks_.push_back(cb);
+}
+
+void QueueRunner::ClearErrorCallbacks() {
+ mutex_lock l(cb_mu_);
+ callbacks_.clear();
+}
+
Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
queue_name_ = queue_runner_def.queue_name();
enqueue_op_names_.clear();
@@ -100,7 +110,6 @@ Status QueueRunner::Start(Session* sess, int wait_for) {
}
void QueueRunner::Stop(Session* sess) {
- DCHECK(coord_ != nullptr);
if (cancel_op_name_.empty()) {
return;
}
@@ -127,6 +136,10 @@ void QueueRunner::UpdateStatus(const Status& status) {
if (coord_) {
coord_->ReportStatus(status);
}
+ mutex_lock l(cb_mu_);
+ for (auto& cb : callbacks_) {
+ cb(status);
+ }
}
void QueueRunner::Run(Session* sess, const string& enqueue_op) {
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 273eb39671..fd9f97a958 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -46,6 +46,12 @@ class QueueRunner : public RunnerInterface {
static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
std::unique_ptr<QueueRunner>* result);
+ // Adds a callback that the queue runner will call when it detects an error.
+ void AddErrorCallback(const std::function<void(Status)>& cb);
+
+ // Deletes the previously registered callbacks.
+ void ClearErrorCallbacks();
+
// The destructor would join all the threads.
~QueueRunner();
@@ -56,6 +62,11 @@ class QueueRunner : public RunnerInterface {
// specified time (in milliseconds) for the queues to start to fill up.
Status Start(Session* sess, int wait_for_ms);
+ // Requests to stop and runs the cancel op. It is called in a separate
+ // thread when a coordinator is set. If there is no coordinator, it should
+ // be called before calling Join.
+ void Stop(Session* sess);
+
// Joins all the threads. Returns okay if all threads run successfully;
// otherwise returns the first captured failure status.
Status Join() final;
@@ -72,10 +83,6 @@ class QueueRunner : public RunnerInterface {
// The Run function for each thread.
void Run(Session* sess, const string& enqueue_op);
- // Requests to stop and runs the cancel op. It would be called in a separate
- // thread when coordinator is set.
- void Stop(Session* sess);
-
// Updates the internal status; it only keeps OK or the first unexpected error
// status.
void UpdateStatus(const Status& status);
@@ -100,6 +107,9 @@ class QueueRunner : public RunnerInterface {
std::unique_ptr<BlockingCounter> counter_;
Coordinator* coord_;
+
+ mutex cb_mu_;
+ std::vector<std::function<void(Status)>> callbacks_;
};
} // namespace tensorflow
diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc
index 0e7e94b40f..1661c5c91b 100644
--- a/tensorflow/cc/training/queue_runner_test.cc
+++ b/tensorflow/cc/training/queue_runner_test.cc
@@ -328,5 +328,21 @@ TEST(QueueRunnerTest, TestCoordinatorStop) {
TF_EXPECT_OK(coord.Join());
}
+TEST(QueueRunnerTest, CallbackCalledOnError) {
+ GraphDef graph_def = BuildSimpleGraph();
+ auto session = BuildSessionAndInitVariable(graph_def);
+
+ QueueRunnerDef queue_runner_def = BuildQueueRunnerDef(
+ kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {});
+
+ std::unique_ptr<QueueRunner> qr;
+ TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr));
+ bool error_caught = false;
+ qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; });
+ TF_EXPECT_OK(qr->Start(session.get()));
+ qr->Join();
+ EXPECT_TRUE(error_caught);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt
new file mode 100644
index 0000000000..a7e0d581aa
--- /dev/null
+++ b/tensorflow/contrib/android/cmake/CMakeLists.txt
@@ -0,0 +1,61 @@
+#
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+cmake_minimum_required(VERSION 3.4.1)
+include(ExternalProject)
+
+# TENSORFLOW_ROOT_DIR:
+# root directory of tensorflow repo
+# used for shared source files and pre-built libs
+get_filename_component(TENSORFLOW_ROOT_DIR ../../../.. ABSOLUTE)
+set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen)
+
+add_library(lib_proto STATIC IMPORTED )
+set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
+ ${PREBUILT_DIR}/protobuf/lib/libprotobuf.a)
+
+add_library(lib_tf STATIC IMPORTED )
+set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
+ ${PREBUILT_DIR}/lib/libtensorflow-core.a)
+# Changes to the compile flags should be replicated in the bazel build file
+# LINT.IfChange
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -fno-rtti -fno-exceptions \
+ -fpic -O2 -mfpu=neon -DTF_LEAN_BINARY \
+ -DGOOGLE_PROTOBUF_NO_RTTI \
+ -DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
+# LINT.ThenChange(//tensorflow/tensorflow.bzl)
+
+set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
+ -Wl,--allow-multiple-definition \
+ -Wl,--whole-archive")
+
+file(GLOB tensorflow_inference_sources
+ ${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc)
+add_library(tensorflow_inference SHARED ${tensorflow_inference_sources})
+
+# Include libraries needed for hello-jni lib
+target_link_libraries(tensorflow_inference
+ android
+ log
+ m
+ z
+ lib_tf
+ lib_proto)
+include_directories(
+ ${PREBUILT_DIR}/proto
+ ${PREBUILT_DIR}/protobuf/include
+ ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen
+ ${TENSORFLOW_ROOT_DIR}
+ ${CMAKE_CURRENT_SOURCE_DIR}/..)
diff --git a/tensorflow/contrib/android/cmake/README.md b/tensorflow/contrib/android/cmake/README.md
new file mode 100644
index 0000000000..ad9e1720c7
--- /dev/null
+++ b/tensorflow/contrib/android/cmake/README.md
@@ -0,0 +1,44 @@
+TensorFlow-Android-Inference
+============================
+Android Java interface to the TensorFlow native APIs
+
+Usage
+-----
+Add TensorFlow-Android-Inference as a dependency of your Android application
+
+* settings.gradle
+
+```
+include ':TensorFlow-Android-Inference'
+findProject(":TensorFlow-Android-Inference").projectDir =
+ new File("${/path/to/tensorflow_repo}/contrib/android/cmake")
+```
+
+* application's build.gradle (adding dependency):
+
+```
+debugCompile project(path: ':tensorflow_inference', configuration: 'debug')
+releaseCompile project(path: ':tensorflow_inference', configuration: 'release')
+```
+Note: this makes native code in the lib traceable from your app.
+
+Dependencies
+------------
+TensorFlow-Android-Inference depends on the TensorFlow static libs already
+built in your local TensorFlow repo directory. On Linux/Mac OS,
+build_all_android.sh is invoked from build.gradle to build them. Building the
+core libs DOES take time, so the step is commented out by default to avoid
+confusion (otherwise Android Studio would appear to hang while opening the
+project). To enable it, refer to the comment in
+
+* build.gradle
+
+Output
+------
+- TensorFlow-Inference-debug.aar
+- TensorFlow-Inference-release.aar
+
+The file libtensorflow_inference.so should be packed under jni/${ANDROID_ABI}/
+in the above aar; this is transparent to the app, which accesses it via the
+equivalent Java APIs.
+
diff --git a/tensorflow/contrib/android/cmake/build.gradle b/tensorflow/contrib/android/cmake/build.gradle
new file mode 100644
index 0000000000..8791fac18a
--- /dev/null
+++ b/tensorflow/contrib/android/cmake/build.gradle
@@ -0,0 +1,97 @@
+apply plugin: 'com.android.library'
+
+android {
+ compileSdkVersion 24
+ buildToolsVersion "24.0.2"
+
+ // for debugging native code purposes
+ publishNonDefault true
+
+ defaultConfig {
+ archivesBaseName = "Tensorflow-Android-Inference"
+ minSdkVersion 21
+ targetSdkVersion 21
+ versionCode 1
+ versionName "1.0"
+ ndk {
+ abiFilters 'armeabi-v7a'
+ }
+ externalNativeBuild {
+ cmake {
+ arguments '-DANDROID_TOOLCHAIN=gcc',
+ '-DANDROID_STL=gnustl_static'
+ }
+ }
+ }
+ sourceSets {
+ main {
+ java.srcDirs = ["../java"]
+ }
+ }
+
+ externalNativeBuild {
+ cmake {
+ path 'CMakeLists.txt'
+ }
+ }
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android.txt'),
+ 'proguard-rules.pro'
+ }
+ }
+}
+
+// Build libtensorflow-core.a if necessary
+// Note: the environment needs to be set up already
+// [ such as installing autoconf, make, etc. ]
+// How to use:
+// 1) install all of the necessary tools to build libtensorflow-core.a
+// 2) inside the Android Studio IDE, uncomment buildTensorflow in
+// whenTaskAdded{...}
+// 3) re-sync and re-build. It could take a long time if NOT building
+// with multiple processes.
+import org.apache.tools.ant.taskdefs.condition.Os
+
+Properties properties = new Properties()
+properties.load(project.rootProject.file('local.properties')
+ .newDataInputStream())
+def ndkDir = properties.getProperty('ndk.dir')
+if (ndkDir == null || ndkDir == "") {
+ ndkDir = System.getenv('ANDROID_NDK_HOME')
+}
+
+if(! Os.isFamily(Os.FAMILY_WINDOWS)) {
+ // This script is for non-Windows OS. For Windows OS, MANUALLY build
+ // (or copy the built) libs/headers to the
+ // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen
+ // refer to CMakeLists.txt about lib and header directories for details
+ task buildTensorflow(type: Exec) {
+ group 'buildTensorflowLib'
+ workingDir getProjectDir().toString() + '/../../../../'
+ environment PATH: '/opt/local/bin:/opt/local/sbin:' +
+ System.getenv('PATH')
+ environment NDK_ROOT: ndkDir
+ commandLine 'tensorflow/contrib/makefile/build_all_android.sh'
+ }
+
+ tasks.whenTaskAdded { task ->
+ group 'buildTensorflowLib'
+ if (task.name.toLowerCase().contains('sources')) {
+ def tensorflowTarget = new File(getProjectDir().toString() +
+ '/../../makefile/gen/lib/libtensorflow-core.a')
+ if (!tensorflowTarget.exists()) {
+ // Note:
+ // just uncomment the line below to use it;
+ // the build can take a long time, so it is disabled by default
+ // to avoid a confusing first impression
+ // task.dependsOn buildTensorflow
+ }
+ }
+ }
+}
+
+dependencies {
+ compile fileTree(dir: 'libs', include: ['*.jar'])
+}
diff --git a/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000..bced47e046
--- /dev/null
+++ b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml
@@ -0,0 +1,9 @@
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.contrib.android">
+
+ <application android:allowBackup="true" android:label="@string/app_name"
+ android:supportsRtl="true">
+
+ </application>
+
+</manifest>
diff --git a/tensorflow/contrib/android/cmake/src/main/res/values/strings.xml b/tensorflow/contrib/android/cmake/src/main/res/values/strings.xml
new file mode 100644
index 0000000000..92dc3a1baf
--- /dev/null
+++ b/tensorflow/contrib/android/cmake/src/main/res/values/strings.xml
@@ -0,0 +1,3 @@
+<resources>
+ <string name="app_name">TensorFlowInference</string>
+</resources>
diff --git a/tensorflow/contrib/bayesflow/python/ops/special_math.py b/tensorflow/contrib/bayesflow/python/ops/special_math.py
index 77e7c0e093..5e5cde5c1f 100644
--- a/tensorflow/contrib/bayesflow/python/ops/special_math.py
+++ b/tensorflow/contrib/bayesflow/python/ops/special_math.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
__all__ = [
@@ -90,9 +91,9 @@ def _ndtr(x):
0.5 * math.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
w = x * half_sqrt_2
z = math_ops.abs(w)
- y = math_ops.select(math_ops.less(z, half_sqrt_2),
+ y = array_ops.where(math_ops.less(z, half_sqrt_2),
1. + math_ops.erf(w),
- math_ops.select(math_ops.greater(w, 0.),
+ array_ops.where(math_ops.greater(w, 0.),
2. - math_ops.erfc(z),
math_ops.erfc(z)))
return 0.5 * y
@@ -180,10 +181,10 @@ def log_ndtr(x, series_order=3, name="log_ndtr"):
# the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
# regardless of whether dy is finite. Note that the minimum is a NOP if
# the branch is chosen.
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(x, upper_segment),
-_ndtr(-x), # log(1-x) ~= -x, x << 1
- math_ops.select(math_ops.greater(x, lower_segment),
+ array_ops.where(math_ops.greater(x, lower_segment),
math_ops.log(_ndtr(math_ops.maximum(x, lower_segment))),
_log_ndtr_lower(math_ops.minimum(x, lower_segment),
series_order)))
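The hunk above, and several hunks below, apply one mechanical rename: `math_ops.select(cond, x, y)` becomes `array_ops.where(cond, x, y)`, with identical elementwise semantics (pick from `x` where `cond` holds, else from `y`). A minimal sketch of the new spelling, using an invented `clipped_log` helper purely for illustration:

```python
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

def clipped_log(x):
  # log(x) where x > 0, and 0 elsewhere. maximum() keeps the untaken branch
  # finite, mirroring the NaN-gradient guard discussed in log_ndtr above.
  safe_x = math_ops.maximum(x, 1e-8)
  return array_ops.where(x > 0., math_ops.log(safe_x),
                         array_ops.zeros_like(x))
```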
diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py
index 1abe77e295..84839f40c4 100644
--- a/tensorflow/contrib/distributions/python/ops/beta.py
+++ b/tensorflow/contrib/distributions/python/ops/beta.py
@@ -252,7 +252,7 @@ class Beta(distribution.Distribution):
mode = (self.a - 1.)/ (self.a_b_sum - 2.)
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return math_ops.select(
+ return array_ops.where(
math_ops.logical_and(
math_ops.greater(self.a, 1.),
math_ops.greater(self.b, 1.)),
diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py
index c89108b3f6..b29c272405 100644
--- a/tensorflow/contrib/distributions/python/ops/bijector.py
+++ b/tensorflow/contrib/distributions/python/ops/bijector.py
@@ -1468,7 +1468,7 @@ class ScaleAndShift(Bijector):
batch_ndims: `Tensor` (0D, `int32`). The ndims of the `batch` portion.
"""
ndims = array_ops.rank(scale)
- left = math_ops.select(
+ left = array_ops.where(
math_ops.reduce_any([
math_ops.reduce_all([
math_ops.equal(ndims, 0),
@@ -1478,7 +1478,7 @@ class ScaleAndShift(Bijector):
math_ops.equal(ndims, 2),
math_ops.equal(event_ndims, 1)
])]), 1, 0)
- right = math_ops.select(math_ops.equal(event_ndims, 0), 2, 0)
+ right = array_ops.where(math_ops.equal(event_ndims, 0), 2, 0)
pad = array_ops.concat(0, (
array_ops.ones([left], dtype=dtypes.int32),
array_ops.shape(scale),
diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py
index 2a2ea4ec26..c485038fb2 100644
--- a/tensorflow/contrib/distributions/python/ops/dirichlet.py
+++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py
@@ -239,7 +239,7 @@ class Dirichlet(distribution.Distribution):
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
shape = array_ops.concat(0, (self.batch_shape(), self.event_shape()))
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(self.alpha, 1.),
mode,
array_ops.fill(shape, nan, name="nan"))
diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py
index cd193f4d6d..9f52e1f0dd 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution.py
@@ -130,7 +130,10 @@ class _DistributionMeta(abc.ABCMeta):
if not baseclasses: # Nothing to be done for Distribution
raise TypeError("Expected non-empty baseclass. Does Distribution "
"not subclass _BaseDistribution?")
- base = baseclasses[0]
+ which_base = [
+ base for base in baseclasses
+ if base == _BaseDistribution or issubclass(base, Distribution)]
+ base = which_base[0]
if base == _BaseDistribution: # Nothing to be done for Distribution
return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
if not issubclass(base, Distribution):
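The fix above matters for distributions declared with more than one base class: the metaclass previously assumed the `Distribution` parent was `baseclasses[0]`. A self-contained sketch of the failure mode, with hypothetical stand-in classes (not the real ones):

```python
class _BaseDistribution(object):  # stand-in for the real base
  pass

class Distribution(_BaseDistribution):  # stand-in for the real class
  pass

class Mixin(object):
  pass

class MyDist(Mixin, Distribution):  # the Distribution parent is not first
  pass

baseclasses = MyDist.__bases__
# Old behavior: baseclasses[0] is Mixin, which fails the issubclass check.
# New behavior: scan the bases for the Distribution subclass.
which_base = [b for b in baseclasses
              if b == _BaseDistribution or issubclass(b, Distribution)]
assert which_base[0] is Distribution
```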
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 094854b672..1da931c08e 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -335,7 +335,7 @@ def rotate_transpose(x, shift, name="rotate_transpose"):
# Finally, we transform shift by modulo length so it can be specified
# independently from the array upon which it operates (like python).
ndims = array_ops.rank(x)
- shift = math_ops.select(math_ops.less(shift, 0),
+ shift = array_ops.where(math_ops.less(shift, 0),
math_ops.mod(-shift, ndims),
ndims - math_ops.mod(shift, ndims))
first = math_ops.range(0, shift)
@@ -396,8 +396,8 @@ def pick_vector(cond,
false_vector.name, false_vector.dtype))
n = array_ops.shape(true_vector)[0]
return array_ops.slice(array_ops.concat(0, (true_vector, false_vector)),
- [math_ops.select(cond, 0, n)],
- [math_ops.select(cond, n, -1)])
+ [array_ops.where(cond, 0, n)],
+ [array_ops.where(cond, n, -1)])
def gen_new_seed(seed, salt):
@@ -578,8 +578,8 @@ class AppendDocstring(object):
if "\n" in value:
raise ValueError(
"Parameter description for \"%s\" contains newlines." % key)
- bullets.append("* <b>`%s`</b>: %s" % (key, value))
- self._additional_note += ("\n\n##### <b>`condition_kwargs`</b>:\n\n" +
+ bullets.append("* `%s`: %s" % (key, value))
+ self._additional_note += ("\n\n##### `condition_kwargs`:\n\n" +
"\n".join(bullets))
def __call__(self, fn):
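For context on the `pick_vector` hunk above: the helper returns one of two vectors chosen by a graph-level boolean, implemented as a concat followed by a slice whose begin/size values come from `where`. A NumPy sketch of the same trick (not the library code; `cond` is a Python bool here rather than a tensor):

```python
import numpy as np

def pick_vector(cond, true_vector, false_vector):
  v = np.concatenate([true_vector, false_vector])
  n = len(true_vector)
  begin = 0 if cond else n                 # array_ops.where(cond, 0, n)
  size = n if cond else len(false_vector)  # where(cond, n, -1); -1 = "rest"
  return v[begin:begin + size]

print(pick_vector(True, np.array([1, 2, 3]), np.array([4, 5])))   # [1 2 3]
print(pick_vector(False, np.array([1, 2, 3]), np.array([4, 5])))  # [4 5]
```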
diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py
index 6d2a1ee953..fcc5281c55 100644
--- a/tensorflow/contrib/distributions/python/ops/gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/gamma.py
@@ -208,7 +208,7 @@ class Gamma(distribution.Distribution):
mode = (self.alpha - 1.) / self.beta
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return math_ops.select(
+ return array_ops.where(
self.alpha >= 1.,
mode,
array_ops.fill(self.batch_shape(), nan, name="nan"))
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
new file mode 100644
index 0000000000..d0f3ce4933
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -0,0 +1,205 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Gumbel distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
+from tensorflow.python.framework import common_shapes
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+class _Gumbel(distribution.Distribution):
+ """The scalar Gumbel distribution with location and scale parameters.
+
+ #### Mathematical details
+
+ The PDF of this distribution is:
+
+ ```pdf(x) = exp(-(x - loc)/scale - exp(-(x - loc)/scale))```
+
+ with support on (-inf, inf). The CDF of this distribution is:
+
+ ```cdf(x) = exp(-exp(-(x - loc)/scale))```
+
+ #### Examples
+
+ Examples of initialization of one or a batch of distributions.
+
+ ```python
+ # Define a single scalar Gumbel distribution.
+ dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.)
+
+ # Evaluate the cdf at 1, returning a scalar.
+ dist.cdf(1.)
+
+ # Define a batch of two scalar valued Gumbels.
+  # The first has loc 1 and scale 11, the second 2 and 22.
+ dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.])
+
+ # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
+ # returning a length two tensor.
+ dist.pdf([0, 1.5])
+
+ # Get 3 samples, returning a 3 x 2 tensor.
+ dist.sample([3])
+ ```
+
+ Arguments are broadcast when possible.
+
+ ```python
+  # Define a batch of two scalar valued Gumbels.
+ # Both have mean 1, but different scales.
+ dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.])
+
+ # Evaluate the pdf of both distributions on the same point, 3.0,
+ # returning a length 2 tensor.
+ dist.pdf(3.0)
+ ```
+
+ """
+
+ def __init__(self,
+ loc,
+ scale,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="Gumbel"):
+ """Construct Gumbel distributions with location and scale `loc` and `scale`.
+
+ The parameters `loc` and `scale` must be shaped in a way that supports
+ broadcasting (e.g. `loc + scale` is a valid operation).
+
+ Args:
+ loc: Floating point tensor, the means of the distribution(s).
+ scale: Floating point tensor, the scales of the distribution(s).
+ scale must contain only positive values.
+ validate_args: `Boolean`, default `False`. Whether to assert that
+ `scale > 0`. If `validate_args` is `False`, correct output is not
+ guaranteed when input is invalid.
+ allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
+ name: The name to give Ops created by the initializer.
+
+ Raises:
+ TypeError: if loc and scale are different dtypes.
+ """
+ parameters = locals()
+ parameters.pop("self")
+ with ops.name_scope(name, values=[loc, scale]) as ns:
+ with ops.control_dependencies([check_ops.assert_positive(scale)] if
+ validate_args else []):
+ self._loc = array_ops.identity(loc, name="loc")
+ self._scale = array_ops.identity(scale, name="scale")
+ contrib_tensor_util.assert_same_float_dtype((self._loc, self._scale))
+ super(_Gumbel, self).__init__(
+ dtype=self._scale.dtype,
+ is_continuous=True,
+ is_reparameterized=True,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=[self._loc, self._scale],
+ name=ns)
+
+ @staticmethod
+ def _param_shapes(sample_shape):
+ return dict(
+ zip(("loc", "scale"), ([ops.convert_to_tensor(
+ sample_shape, dtype=dtypes.int32)] * 2)))
+
+ @property
+ def loc(self):
+ """Distribution parameter for the location."""
+ return self._loc
+
+ @property
+ def scale(self):
+ """Distribution parameter for scale."""
+ return self._scale
+
+ def _batch_shape(self):
+ return array_ops.shape(self.loc + self.scale)
+
+ def _get_batch_shape(self):
+ return common_shapes.broadcast_shape(self.loc.get_shape(),
+ self.scale.get_shape())
+
+ def _event_shape(self):
+ return constant_op.constant([], dtype=dtypes.int32)
+
+ def _get_event_shape(self):
+ return tensor_shape.scalar()
+
+ def _sample_n(self, n, seed=None):
+ shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
+ np_dtype = self.dtype.as_numpy_dtype()
+ minval = np.nextafter(np_dtype(0), np_dtype(1))
+ uniform = random_ops.random_uniform(shape=shape,
+ minval=minval,
+ maxval=1,
+ dtype=self.dtype,
+ seed=seed)
+ sampled = -math_ops.log(-math_ops.log(uniform))
+ return sampled * self.scale + self.loc
+
+ def _log_prob(self, x):
+ z = self._z(x)
+ return - z - math_ops.log(self.scale) - math_ops.exp(-z)
+
+ def _prob(self, x):
+ return math_ops.exp(self._log_prob(x))
+
+ def _log_cdf(self, x):
+ return -math_ops.exp(-self._z(x))
+
+ def _cdf(self, x):
+ return math_ops.exp(-math_ops.exp(-self._z(x)))
+
+ def _entropy(self):
+    # Use broadcasting rules to calculate the full broadcast scale.
+ scale = self.scale * array_ops.ones_like(self.loc)
+ return 1 + math_ops.log(scale) + np.euler_gamma
+
+ def _mean(self):
+ return self.loc + self.scale * np.euler_gamma
+
+ def _variance(self):
+ return math_ops.square(self.std())
+
+ def _std(self):
+ return self.scale * array_ops.ones_like(self.loc) * math.pi / math.sqrt(6)
+
+ def _mode(self):
+ return self.loc * array_ops.ones_like(self.scale)
+
+ def _z(self, x):
+ """Standardize input `x` to a unit logistic."""
+ with ops.name_scope("standardize", values=[x]):
+ return (x - self.loc) / self.scale
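The `_sample_n` method above is an inverse-CDF sampler: for `U ~ Uniform(0, 1)`, `loc - scale * log(-log(U))` has the Gumbel CDF `exp(-exp(-(x - loc)/scale))`. A NumPy sketch (not the library code) of the same math, checking the sample mean against `loc + scale * euler_gamma`:

```python
import numpy as np

loc, scale = 1.0, 2.0
# Draw from (tiny, 1) as _sample_n does, so the logs are always finite.
u = np.random.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=100000)
samples = loc - scale * np.log(-np.log(u))  # inverse of the Gumbel CDF

print(samples.mean(), loc + scale * np.euler_gamma)  # should roughly agree
```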
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index ffe72f09a9..feb0bf2f90 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -185,7 +185,7 @@ class InverseGamma(distribution.Distribution):
mean = self.beta / (self.alpha - 1.)
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return math_ops.select(
+ return array_ops.where(
self.alpha > 1., mean,
array_ops.fill(self.batch_shape(), nan, name="nan"))
else:
@@ -204,7 +204,7 @@ class InverseGamma(distribution.Distribution):
(math_ops.square(self.alpha - 1.) * (self.alpha - 2.)))
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return math_ops.select(
+ return array_ops.where(
self.alpha > 2., var,
array_ops.fill(self.batch_shape(), nan, name="nan"))
else:
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
new file mode 100644
index 0000000000..9a20c653ae
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -0,0 +1,210 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Logistic distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
+from tensorflow.python.framework import common_shapes
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+
+
+class _Logistic(distribution.Distribution):
+ """The scalar Logistic distribution with location and scale parameters.
+
+ #### Mathematical details
+
+ The CDF of this distribution is:
+
+ ```cdf(x) = 1/(1+exp(-(x - loc) / scale))```
+
+ with support on (-inf, inf).
+
+ #### Examples
+
+ Examples of initialization of one or a batch of distributions.
+
+ ```python
+ # Define a single scalar Logistic distribution.
+ dist = tf.contrib.distributions.Logistic(loc=0., scale=3.)
+
+ # Evaluate the cdf at 1, returning a scalar.
+ dist.cdf(1.)
+
+ # Define a batch of two scalar valued Logistics.
+ # The first has mean 1 and scale 11, the second 2 and 22.
+ dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.])
+
+ # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
+ # returning a length two tensor.
+ dist.pdf([0, 1.5])
+
+ # Get 3 samples, returning a 3 x 2 tensor.
+ dist.sample([3])
+ ```
+
+ Arguments are broadcast when possible.
+
+ ```python
+ # Define a batch of two scalar valued Logistics.
+ # Both have mean 1, but different scales.
+ dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.])
+
+ # Evaluate the pdf of both distributions on the same point, 3.0,
+ # returning a length 2 tensor.
+ dist.pdf(3.0)
+ ```
+
+ """
+
+ def __init__(self,
+ loc,
+ scale,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="Logistic"):
+ """Construct Logistic distributions with mean and scale `loc` and `scale`.
+
+ The parameters `loc` and `scale` must be shaped in a way that supports
+ broadcasting (e.g. `loc + scale` is a valid operation).
+
+ Args:
+ loc: Floating point tensor, the means of the distribution(s).
+ scale: Floating point tensor, the scales of the distribution(s).
+ scale must contain only positive values.
+ validate_args: `Boolean`, default `False`. Whether to assert that
+ `scale > 0`. If `validate_args` is `False`, correct output is not
+ guaranteed when input is invalid.
+ allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
+ name: The name to give Ops created by the initializer.
+
+ Raises:
+ TypeError: if loc and scale are different dtypes.
+ """
+ parameters = locals()
+ parameters.pop("self")
+ with ops.name_scope(name, values=[loc, scale]) as ns:
+ with ops.control_dependencies([check_ops.assert_positive(scale)] if
+ validate_args else []):
+ self._loc = array_ops.identity(loc, name="loc")
+ self._scale = array_ops.identity(scale, name="scale")
+ contrib_tensor_util.assert_same_float_dtype((self._loc, self._scale))
+ super(_Logistic, self).__init__(
+ dtype=self._scale.dtype,
+ is_continuous=True,
+ is_reparameterized=True,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=[self._loc, self._scale],
+ name=ns)
+
+ @staticmethod
+ def _param_shapes(sample_shape):
+ return dict(
+ zip(("loc", "scale"), ([ops.convert_to_tensor(
+ sample_shape, dtype=dtypes.int32)] * 2)))
+
+ @property
+ def loc(self):
+ """Distribution parameter for the location."""
+ return self._loc
+
+ @property
+ def scale(self):
+ """Distribution parameter for scale."""
+ return self._scale
+
+ def _batch_shape(self):
+ return array_ops.shape(self.loc + self.scale)
+
+ def _get_batch_shape(self):
+ return common_shapes.broadcast_shape(self.loc.get_shape(),
+ self.scale.get_shape())
+
+ def _event_shape(self):
+ return constant_op.constant([], dtype=dtypes.int32)
+
+ def _get_event_shape(self):
+ return tensor_shape.scalar()
+
+ def _sample_n(self, n, seed=None):
+ shape = array_ops.concat(0, ([n], array_ops.shape(self.mean())))
+ np_dtype = self.dtype.as_numpy_dtype()
+ minval = np.nextafter(np_dtype(0), np_dtype(1))
+ uniform = random_ops.random_uniform(shape=shape,
+ minval=minval,
+ maxval=1,
+ dtype=self.dtype,
+ seed=seed)
+ sampled = math_ops.log(uniform) - math_ops.log(1-uniform)
+ return sampled * self.scale + self.loc
+
+ def _log_prob(self, x):
+ z = self._z(x)
+ return - z - math_ops.log(self.scale) - 2*nn_ops.softplus(-z)
+
+ def _prob(self, x):
+ return math_ops.exp(self._log_prob(x))
+
+ def _log_cdf(self, x):
+ return nn_ops.softplus(-self._z(x))
+
+ def _cdf(self, x):
+ return math_ops.sigmoid(self._z(x))
+
+ def _log_survival_function(self, x):
+ return nn_ops.softplus(self._z(x))
+
+ def _survival_function(self, x):
+ return math_ops.sigmoid(-self._z(x))
+
+ def _entropy(self):
+    # Use broadcasting rules to calculate the full broadcast scale.
+ scale = self.scale * array_ops.ones_like(self.loc)
+ return 2 + math_ops.log(scale)
+
+ def _mean(self):
+ return self.loc * array_ops.ones_like(self.scale)
+
+ def _variance(self):
+ return math_ops.square(self.std())
+
+ def _std(self):
+ return self.scale * array_ops.ones_like(self.loc) * math.pi / math.sqrt(3)
+
+ def _mode(self):
+ return self._mean()
+
+ def _z(self, x):
+ """Standardize input `x` to a unit logistic."""
+ with ops.name_scope("standardize", values=[x]):
+ return (x - self.loc) / self.scale
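Like the Gumbel above, `_sample_n` here inverts the CDF: since the Logistic CDF is `sigmoid((x - loc)/scale)`, `loc + scale * logit(U)` for uniform `U` is a Logistic draw. A NumPy sketch (not the library code):

```python
import numpy as np

loc, scale = 1.0, 2.0
u = np.random.uniform(low=np.finfo(np.float64).tiny, high=1.0, size=100000)
samples = loc + scale * (np.log(u) - np.log1p(-u))  # logit(u)

print(samples.mean(), loc)                        # mean -> loc
print(samples.std(), scale * np.pi / np.sqrt(3))  # std -> scale*pi/sqrt(3)
```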
diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
new file mode 100644
index 0000000000..bb05c90a12
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
@@ -0,0 +1,262 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The OneHotCategorical distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import kullback_leibler
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+
+
+class _OneHotCategorical(distribution.Distribution):
+ """OneHotCategorical distribution.
+
+ The categorical distribution is parameterized by the log-probabilities
+ of a set of classes. The difference between OneHotCategorical and Categorical
+ distributions is that OneHotCategorical is a discrete distribution over
+ one-hot bit vectors whereas Categorical is a discrete distribution over
+ positive integers.
+
+ This class provides methods to create indexed batches of OneHotCategorical
+ distributions. If the provided `logits` or `p` is rank 2 or higher, for
+ every fixed set of leading dimensions, the last dimension represents one
+ single OneHotCategorical distribution. When calling distribution
+ functions (e.g. `dist.prob(x)`), `logits` and `x` are broadcast to the
+ same shape (if possible). In all cases, the last dimension of `logits/x`
+ represents single OneHotCategorical distributions.
+
+ #### Examples
+
+  Creates a 3-class distribution, with the 2nd class the most likely to be
+ drawn from.
+
+ ```python
+ p = [0.1, 0.5, 0.4]
+ dist = OneHotCategorical(p=p)
+ ```
+
+  Creates a 3-class distribution, with the 2nd class the most likely to be
+ drawn from, using logits.
+
+ ```python
+ logits = [-2, 2, 0]
+ dist = OneHotCategorical(logits=logits)
+ ```
+
+  Creates a 3-class distribution, with the 3rd class the most likely to be drawn.
+
+ ```python
+  # p defines a single 3-class distribution.
+ p = [0.1, 0.4, 0.5]
+ dist = OneHotCategorical(p=p)
+ dist.pmf([0,1,0]) # Shape []
+
+ # p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match.
+ samples = [[0,1,0], [1,0,0]]
+ dist.pmf(samples) # Shape [2]
+ ```
+
+ """
+
+ def __init__(
+ self,
+ logits=None,
+ p=None,
+ dtype=dtypes.int32,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="OneHotCategorical"):
+ """Initialize OneHotCategorical distributions using class log-probabilities.
+
+ Args:
+ logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
+ of a set of Categorical distributions. The first `N - 1` dimensions
+ index into a batch of independent distributions and the last dimension
+ represents a vector of logits for each class. Only one of `logits` or
+ `p` should be passed in.
+ p: An N-D `Tensor`, `N >= 1`, representing the probabilities
+ of a set of Categorical distributions. The first `N - 1` dimensions
+ index into a batch of independent distributions and the last dimension
+ represents a vector of probabilities for each class. Only one of
+ `logits` or `p` should be passed in.
+ dtype: The type of the event samples (default: int32).
+ validate_args: Unused in this distribution.
+ allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
+ name: A name for this distribution (optional).
+ """
+ parameters = locals()
+ parameters.pop("self")
+ with ops.name_scope(name, values=[logits]) as ns:
+ self._logits, self._p = distribution_util.get_logits_and_prob(
+ name=name, logits=logits, p=p, validate_args=validate_args,
+ multidimensional=True)
+
+ logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
+ if logits_shape_static.ndims is not None:
+ self._batch_rank = ops.convert_to_tensor(
+ logits_shape_static.ndims - 1,
+ dtype=dtypes.int32,
+ name="batch_rank")
+ else:
+ with ops.name_scope(name="batch_rank"):
+ self._batch_rank = array_ops.rank(self._logits) - 1
+
+ logits_shape = array_ops.shape(self._logits, name="logits_shape")
+ if logits_shape_static[-1].value is not None:
+ self._num_classes = ops.convert_to_tensor(
+ logits_shape_static[-1].value,
+ dtype=dtypes.int32,
+ name="num_classes")
+ else:
+ self._num_classes = array_ops.gather(logits_shape,
+ self._batch_rank,
+ name="num_classes")
+
+ if logits_shape_static[:-1].is_fully_defined():
+ self._batch_shape_val = constant_op.constant(
+ logits_shape_static[:-1].as_list(),
+ dtype=dtypes.int32,
+ name="batch_shape")
+ else:
+ with ops.name_scope(name="batch_shape"):
+ self._batch_shape_val = logits_shape[:-1]
+ super(_OneHotCategorical, self).__init__(
+ dtype=dtype,
+ is_continuous=False,
+ is_reparameterized=False,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=[self._logits, self._num_classes],
+ name=ns)
+
+ @property
+ def num_classes(self):
+ """Scalar `int32` tensor: the number of classes."""
+ return self._num_classes
+
+ @property
+ def logits(self):
+ """Vector of coordinatewise logits."""
+ return self._logits
+
+ @property
+ def p(self):
+ """Vector of probabilities summing to one.
+
+ Each element is the probability of drawing that coordinate."""
+ return self._p
+
+ def _batch_shape(self):
+ # Use identity to inherit callers "name".
+ return array_ops.identity(self._batch_shape_val)
+
+ def _get_batch_shape(self):
+ return self.logits.get_shape()[:-1]
+
+ def _event_shape(self):
+ return array_ops.shape(self.logits)[-1]
+
+ def _get_event_shape(self):
+ return self.logits.get_shape().with_rank_at_least(1)[-1:]
+
+ def _sample_n(self, n, seed=None):
+ sample_shape = array_ops.concat(0, ([n], array_ops.shape(self.logits)))
+ logits = self.logits
+ if logits.get_shape().ndims == 2:
+ logits_2d = logits
+ else:
+ logits_2d = array_ops.reshape(logits, [-1, self.num_classes])
+ samples = random_ops.multinomial(logits_2d, n, seed=seed)
+ samples = array_ops.transpose(samples)
+ samples = array_ops.one_hot(samples, self.num_classes, dtype=self.dtype)
+ ret = array_ops.reshape(samples, sample_shape)
+ return ret
+
+ def _log_prob(self, x):
+ x = ops.convert_to_tensor(x, name="x")
+ # broadcast logits or x if need be.
+ logits = self.logits
+ if (not x.get_shape().is_fully_defined() or
+ not logits.get_shape().is_fully_defined() or
+ x.get_shape() != logits.get_shape()):
+ logits = array_ops.ones_like(x, dtype=logits.dtype) * logits
+ x = array_ops.ones_like(logits, dtype=x.dtype) * x
+
+ logits_shape = array_ops.shape(logits)
+ if logits.get_shape().ndims == 2:
+ logits_2d = logits
+ x_2d = x
+ else:
+ logits_2d = array_ops.reshape(logits, [-1, self.num_classes])
+ x_2d = array_ops.reshape(x, [-1, self.num_classes])
+ ret = -nn_ops.softmax_cross_entropy_with_logits(logits_2d, x_2d)
+ ret = array_ops.reshape(ret, logits_shape)
+ return ret
+
+ def _prob(self, x):
+ return math_ops.exp(self._log_prob(x))
+
+ def _entropy(self):
+ if self.logits.get_shape().ndims == 2:
+ logits_2d = self.logits
+ else:
+ logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])
+ histogram_2d = nn_ops.softmax(logits_2d)
+ ret = array_ops.reshape(
+ nn_ops.softmax_cross_entropy_with_logits(logits_2d, histogram_2d),
+ self.batch_shape())
+ ret.set_shape(self.get_batch_shape())
+ return ret
+
+ def _mode(self):
+ ret = math_ops.argmax(self.logits, axis=self._batch_rank)
+ ret = array_ops.one_hot(ret, self.num_classes, dtype=self.dtype)
+ ret.set_shape(self.logits.get_shape())
+ return ret
+
+
+@kullback_leibler.RegisterKL(_OneHotCategorical, _OneHotCategorical)
+def _kl_categorical_categorical(a, b, name=None):
+ """Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical.
+
+ Args:
+ a: instance of a OneHotCategorical distribution object.
+ b: instance of a OneHotCategorical distribution object.
+ name: (optional) Name to use for created operations.
+ default is "kl_categorical_categorical".
+
+ Returns:
+ Batchwise KL(a || b)
+ """
+ with ops.name_scope(
+ name, "kl_categorical_categorical", [a.logits, b.logits]):
+ # sum(p*ln(p/q))
+ return math_ops.reduce_sum(
+ nn_ops.softmax(a.logits)*(nn_ops.log_softmax(a.logits)
+ - nn_ops.log_softmax(b.logits)), reduction_indices=[-1])
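The registered KL above is the usual categorical formula evaluated from logits. A NumPy sketch (not the library code) of the same computation for a single batch member:

```python
import numpy as np

def log_softmax(logits):
  z = logits - logits.max(axis=-1, keepdims=True)  # stabilize the exp
  return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

a = np.array([-2.0, 2.0, 0.0])  # logits of distribution a
b = np.array([0.0, 0.0, 0.0])   # logits of distribution b (uniform)
# KL(a || b) = sum_k p_k * (log p_k - log q_k), summed over classes.
kl = np.sum(np.exp(log_softmax(a)) * (log_softmax(a) - log_softmax(b)),
            axis=-1)
print(kl)
```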
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index bca3f99604..fd3ec553c0 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -284,11 +284,11 @@ class QuantizedDistribution(distributions.Distribution):
result_so_far = math_ops.ceil(x_samps)
if lower_cutoff is not None:
- result_so_far = math_ops.select(result_so_far < lower_cutoff,
+ result_so_far = array_ops.where(result_so_far < lower_cutoff,
lower_cutoff * ones, result_so_far)
if upper_cutoff is not None:
- result_so_far = math_ops.select(result_so_far > upper_cutoff,
+ result_so_far = array_ops.where(result_so_far > upper_cutoff,
upper_cutoff * ones, result_so_far)
return result_so_far
@@ -327,8 +327,8 @@ class QuantizedDistribution(distributions.Distribution):
# In either case, we are doing Log[ exp{big} - exp{small} ]
# We want to use the sf items precisely when we are on the right side of the
# median, which occurs when logsf_y < logcdf_y.
- big = math_ops.select(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y)
- small = math_ops.select(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1)
+ big = array_ops.where(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y)
+ small = array_ops.where(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1)
return _logsum_expbig_minus_expsmall(big, small)
@@ -357,7 +357,7 @@ class QuantizedDistribution(distributions.Distribution):
cdf_y_minus_1 = self.cdf(y - 1)
# sf_prob has greater precision iff we're on the right side of the median.
- return math_ops.select(
+ return array_ops.where(
sf_y < cdf_y, # True iff we're on the right side of the median.
sf_y_minus_1 - sf_y,
cdf_y - cdf_y_minus_1)
@@ -386,9 +386,9 @@ class QuantizedDistribution(distributions.Distribution):
# Re-define values at the cutoffs.
if lower_cutoff is not None:
neg_inf = -np.inf * array_ops.ones_like(result_so_far)
- result_so_far = math_ops.select(j < lower_cutoff, neg_inf, result_so_far)
+ result_so_far = array_ops.where(j < lower_cutoff, neg_inf, result_so_far)
if upper_cutoff is not None:
- result_so_far = math_ops.select(j >= upper_cutoff,
+ result_so_far = array_ops.where(j >= upper_cutoff,
array_ops.zeros_like(result_so_far),
result_so_far)
@@ -418,11 +418,11 @@ class QuantizedDistribution(distributions.Distribution):
# Re-define values at the cutoffs.
if lower_cutoff is not None:
- result_so_far = math_ops.select(j < lower_cutoff,
+ result_so_far = array_ops.where(j < lower_cutoff,
array_ops.zeros_like(result_so_far),
result_so_far)
if upper_cutoff is not None:
- result_so_far = math_ops.select(j >= upper_cutoff,
+ result_so_far = array_ops.where(j >= upper_cutoff,
array_ops.ones_like(result_so_far),
result_so_far)
@@ -452,12 +452,12 @@ class QuantizedDistribution(distributions.Distribution):
# Re-define values at the cutoffs.
if lower_cutoff is not None:
- result_so_far = math_ops.select(j < lower_cutoff,
+ result_so_far = array_ops.where(j < lower_cutoff,
array_ops.zeros_like(result_so_far),
result_so_far)
if upper_cutoff is not None:
neg_inf = -np.inf * array_ops.ones_like(result_so_far)
- result_so_far = math_ops.select(j >= upper_cutoff, neg_inf, result_so_far)
+ result_so_far = array_ops.where(j >= upper_cutoff, neg_inf, result_so_far)
return result_so_far
@@ -485,11 +485,11 @@ class QuantizedDistribution(distributions.Distribution):
# Re-define values at the cutoffs.
if lower_cutoff is not None:
- result_so_far = math_ops.select(j < lower_cutoff,
+ result_so_far = array_ops.where(j < lower_cutoff,
array_ops.ones_like(result_so_far),
result_so_far)
if upper_cutoff is not None:
- result_so_far = math_ops.select(j >= upper_cutoff,
+ result_so_far = array_ops.where(j >= upper_cutoff,
array_ops.zeros_like(result_so_far),
result_so_far)
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
new file mode 100644
index 0000000000..7994c5d433
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
@@ -0,0 +1,213 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The RelaxedBernoulli distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import bijector
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import logistic
+from tensorflow.contrib.distributions.python.ops import transformed_distribution
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+
+
+class _RelaxedBernoulli(transformed_distribution.TransformedDistribution):
+ """RelaxedBernoulli distribution with temperature and logits parameters.
+
+ The RelaxedBernoulli is a distribution over the unit interval (0,1), which
+ continuously approximates a Bernoulli. The degree of approximation is
+  controlled by a temperature: as the temperature goes to 0 the RelaxedBernoulli
+ becomes discrete with a distribution described by the `logits` or `p`
+  parameters; as the temperature goes to infinity the RelaxedBernoulli
+ becomes the constant distribution that is identically 0.5.
+
+ The RelaxedBernoulli distribution is a reparameterized continuous
+ distribution that is the binary special case of the RelaxedOneHotCategorical
+ distribution (Maddison et al., 2016; Jang et al., 2016). For details on the
+ binary special case see the appendix of Maddison et al. (2016) where it is
+ referred to as BinConcrete. If you use this distribution, please cite both
+ papers.
+
+ Some care needs to be taken for loss functions that depend on the
+ log-probability of RelaxedBernoullis, because computing log-probabilities of
+  the RelaxedBernoulli can suffer from underflow issues. In many cases, loss
+  functions such as these are invariant under invertible transformations of
+  the random variables. The KL divergence, found in the variational
+  autoencoder loss, is an example. Because RelaxedBernoullis are sampled by a
+  Logistic random variable followed by a `tf.sigmoid` op, one solution is to
+  treat the Logistic as the random variable and `tf.sigmoid` as downstream.
+  The KL divergence of two Logistics, which are always followed by a
+  `tf.sigmoid` op, is equivalent to the KL divergence of the corresponding
+  RelaxedBernoullis. See Maddison et al., 2016 for more details, where this
+  distribution is called the BinConcrete.
+
+ #### Examples
+
+ Creates three continuous distributions, which approximate 3 Bernoullis with
+ probabilities (0.1, 0.5, 0.4). Samples from these distributions will be in
+ the unit interval (0,1).
+
+ ```python
+ temperature = 0.5
+ p = [0.1, 0.5, 0.4]
+ dist = RelaxedBernoulli(temperature, p=p)
+ ```
+
+ Creates three continuous distributions, which approximate 3 Bernoullis with
+ logits (-2, 2, 0). Samples from these distributions will be in
+ the unit interval (0,1).
+
+ ```python
+ temperature = 0.5
+ logits = [-2, 2, 0]
+ dist = RelaxedBernoulli(temperature, logits=logits)
+ ```
+
+  Creates three continuous distributions, whose sigmoids approximate 3 Bernoullis
+ with logits (-2, 2, 0).
+
+ ```python
+ temperature = 0.5
+ logits = [-2, 2, 0]
+ dist = Logistic(logits/temperature, 1./temperature)
+ samples = dist.sample()
+ sigmoid_samples = tf.sigmoid(samples)
+ # sigmoid_samples has the same distribution as samples from
+ # RelaxedBernoulli(temperature, logits=logits)
+ ```
+
+ Creates three continuous distributions, which approximate 3 Bernoullis with
+ logits (-2, 2, 0). Samples from these distributions will be in
+ the unit interval (0,1). Because the temperature is very low, samples from
+ these distributions are almost discrete, usually taking values very close to 0
+ or 1.
+
+ ```python
+ temperature = 1e-5
+ logits = [-2, 2, 0]
+ dist = RelaxedBernoulli(temperature, logits=logits)
+ ```
+
+ Creates three continuous distributions, which approximate 3 Bernoullis with
+ logits (-2, 2, 0). Samples from these distributions will be in
+ the unit interval (0,1). Because the temperature is very high, samples from
+ these distributions are usually close to the (0.5, 0.5, 0.5) vector.
+
+ ```python
+ temperature = 100
+ logits = [-2, 2, 0]
+ dist = RelaxedBernoulli(temperature, logits=logits)
+ ```
+
+ Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution:
+ A Continuous Relaxation of Discrete Random Variables. 2016.
+
+ Eric Jang, Shixiang Gu, and Ben Poole. Categorical Reparameterization with
+ Gumbel-Softmax. 2016.
+ """
+
+ def __init__(self,
+ temperature,
+ logits=None,
+ p=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="RelaxedBernoulli"):
+ """Construct RelaxedBernoulli distributions.
+
+ Args:
+ temperature: An 0-D `Tensor`, representing the temperature
+ of a set of RelaxedBernoulli distributions. The temperature should be
+ positive.
+ logits: An N-D `Tensor` representing the log-odds
+ of a positive event. Each entry in the `Tensor` parametrizes
+ an independent RelaxedBernoulli distribution where the probability of an
+ event is sigmoid(logits). Only one of `logits` or `p` should be passed
+ in.
+ p: An N-D `Tensor` representing the probability of a positive
+ event. Each entry in the `Tensor` parameterizes an independent
+ Bernoulli distribution. Only one of `logits` or `p` should be passed
+ in.
+ validate_args: `Boolean`, default `False`. Whether to validate that
+ `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are
+ invalid, methods like `log_pmf` may return `NaN` values.
+ allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
+ name: A name for this distribution.
+
+ Raises:
+ ValueError: If both `p` and `logits` are passed, or if neither is passed.
+ """
+ parameters = locals()
+ parameters.pop("self")
+ with ops.name_scope(name, values=[logits, p, temperature]) as ns:
+ with ops.control_dependencies([check_ops.assert_positive(temperature)]
+ if validate_args else []):
+ self._temperature = array_ops.identity(temperature, name="temperature")
+
+ self._logits, self._p = distribution_util.get_logits_and_prob(
+ logits=logits, p=p, validate_args=validate_args)
+ with ops.name_scope("q"):
+ self._q = 1. - self._p
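+ # A RelaxedBernoulli sample is tf.sigmoid(L), where L is a sample from
+ # Logistic(logits/temperature, 1/temperature).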
+ dist = logistic._Logistic(self._logits / self._temperature,
+ 1./self._temperature,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ name=ns)
+
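+ # The inverse of the sigmoid bijector is the logit function
+ # log(y) - log(1 - y); its log det Jacobian is -(log(y) + log(1 - y)),
+ # summed over the event dimensions.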
+ def inverse_log_det_jacobian_fn(y):
+ return -math_ops.reduce_sum(math_ops.log(y) + math_ops.log(1-y),
+ reduction_indices=-1)
+
+ sigmoidbijector = bijector.Inline(
+ forward_fn=math_ops.sigmoid,
+ inverse_fn=(lambda y: math_ops.log(y) - math_ops.log(1-y)),
+ inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn,
+ name="sigmoid")
+ super(_RelaxedBernoulli, self).__init__(dist,
+ sigmoidbijector,
+ name=name)
+
+ @staticmethod
+ def _param_shapes(sample_shape):
+ return {"logits": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
+
+ @property
+ def temperature(self):
+ """Distribution parameter for the location."""
+ return self._temperature
+
+ @property
+ def logits(self):
+ """Log-odds of success."""
+ return self._logits
+
+ @property
+ def p(self):
+ """Probability of success."""
+ return self._p
+
+ @property
+ def q(self):
+ """Probability of failure."""
+ return self._q
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
new file mode 100644
index 0000000000..1b60b32ff6
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -0,0 +1,420 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Relaxed OneHotCategorical distribution classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib.distributions.python.ops import bijector
+from tensorflow.contrib.distributions.python.ops import distribution
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.contrib.distributions.python.ops import transformed_distribution
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+
+
+class _ExpRelaxedOneHotCategorical(distribution.Distribution):
+ """ExpRelaxedOneHotCategorical distribution with temperature and logits.
+
+ An ExpRelaxedOneHotCategorical distribution is a log-transformed
+ RelaxedOneHotCategorical distribution. The RelaxedOneHotCategorical is a
+ distribution over random probability vectors, vectors of positive real
+ values that sum to one, which continuously approximates a OneHotCategorical.
+ The degree of approximation is controlled by a temperature: as the temperature
+ goes to 0 the RelaxedOneHotCategorical becomes discrete with a distribution
+ described by the logits, and as the temperature goes to infinity it becomes
+ the constant distribution that is identically the vector of
+ (1/num_classes, ..., 1/num_classes).
+
+ Because computing log-probabilities of the RelaxedOneHotCategorical can
+ suffer from underflow issues, this class is one solution for loss
+ functions that depend on log-probabilities, such as the KL divergence found
+ in the variational autoencoder loss. The KL divergence between two
+ distributions is invariant under invertible transformations, so evaluating
+ KL divergences of ExpRelaxedOneHotCategorical samples, which are always
+ followed by a `tf.exp` op, is equivalent to evaluating KL divergences of
+ RelaxedOneHotCategorical samples. See the appendix of Maddison et al., 2016
+ for more mathematical details, where this distribution is called the
+ ExpConcrete.
+
+ #### Examples
+
+ Creates a continuous distribution, whose exp approximates a 3-class one-hot
+ categorical distribution. The 2nd class is the most likely to be the
+ largest component in samples drawn from this distribution. If those samples
+ are followed by a `tf.exp` op, then they are distributed as a relaxed onehot
+ categorical.
+
+ ```python
+ temperature = 0.5
+ p = [0.1, 0.5, 0.4]
+ dist = ExpRelaxedOneHotCategorical(temperature, p=p)
+ samples = dist.sample()
+ exp_samples = tf.exp(samples)
+ # exp_samples has the same distribution as samples from
+ # RelaxedOneHotCategorical(temperature, p=p)
+ ```
+
+ Creates a continuous distribution, whose exp approximates a 3-class one-hot
+ categorical distribution. The 2nd class is the most likely to be the
+ largest component in samples drawn from this distribution.
+
+ ```python
+ temperature = 0.5
+ logits = [-2, 2, 0]
+ dist = ExpRelaxedOneHotCategorical(temperature, logits=logits)
+ samples = dist.sample()
+ exp_samples = tf.exp(samples)
+ # exp_samples has the same distribution as samples from
+ # RelaxedOneHotCategorical(temperature, logits=logits)
+ ```
+
+ Creates a continuous distribution, whose exp approximates a 3-class one-hot
+ categorical distribution. Because the temperature is very low, samples from
+ this distribution are almost discrete, with one component almost 0 and the
+ others very negative. The 2nd class is the most likely to be the largest
+ component in samples drawn from this distribution.
+
+ ```python
+ temperature = 1e-5
+ logits = [-2, 2, 0]
+ dist = ExpRelaxedOneHotCategorical(temperature, logits=logits)
+ samples = dist.sample()
+ exp_samples = tf.exp(samples)
+ # exp_samples has the same distribution as samples from
+ # RelaxedOneHotCategorical(temperature, logits=logits)
+ ```
+
+ Creates a continuous distribution, whose exp approximates a 3-class one-hot
+ categorical distribution. Because the temperature is very high, samples from
+ this distribution are usually close to the (-log(3), -log(3), -log(3)) vector.
+ The 2nd class is still the most likely to be the largest component
+ in samples drawn from this distribution.
+
+ ```python
+ temperature = 10
+ logits = [-2, 2, 0]
+ dist = ExpRelaxedOneHotCategorical(temperature, logits=logits)
+ samples = dist.sample()
+ exp_samples = tf.exp(samples)
+ # exp_samples has the same distribution as samples from
+ # RelaxedOneHotCategorical(temperature, logits=logits)
+ ```
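+
+ As a sketch of the KL computation this class targets, the KL divergence
+ between two relaxed distributions can be estimated with a single Monte
+ Carlo sample taken in the log (ExpRelaxed) space, where `log_prob` is
+ numerically stable. Here `q_logits` and `p_logits` are hypothetical logits
+ `Tensor`s standing in for a posterior and a prior; they are not part of
+ this API.
+
+ ```python
+ temperature = 0.5
+ q = ExpRelaxedOneHotCategorical(temperature, logits=q_logits)
+ p = ExpRelaxedOneHotCategorical(temperature, logits=p_logits)
+ z = q.sample()  # tf.exp(z) would be a RelaxedOneHotCategorical sample
+ kl_estimate = q.log_prob(z) - p.log_prob(z)  # single-sample MC estimate
+ ```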
+
+ Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution:
+ A Continuous Relaxation of Discrete Random Variables. 2016.
+ """
+
+ def __init__(
+ self,
+ temperature,
+ logits=None,
+ p=None,
+ dtype=dtypes.float32,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="ExpRelaxedOneHotCategorical"):
+ """Initialize ExpRelaxedOneHotCategorical using class log-probabilities.
+
+ Args:
+ temperature: A 0-D `Tensor`, representing the temperature of a set of
+ ExpRelaxedOneHotCategorical distributions. The temperature should be
+ positive.
+ logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
+ of a set of ExpRelaxedOneHotCategorical distributions. The first
+ `N - 1` dimensions index into a batch of independent distributions and
+ the last dimension represents a vector of logits for each class. Only
+ one of `logits` or `p` should be passed in.
+ p: An N-D `Tensor`, `N >= 1`, representing the probabilities
+ of a set of ExpRelaxedOneHotCategorical distributions. The first
+ `N - 1` dimensions index into a batch of independent distributions and
+ the last dimension represents a vector of probabilities for each
+ class. Only one of `logits` or `p` should be passed in.
+ dtype: The type of the event samples (default: float32).
+ validate_args: `Boolean`, default `False`. If `True`, assert that the
+ temperature is positive and that samples are valid log-probability
+ vectors.
+ allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
+ name: A name for this distribution (optional).
+ """
+ parameters = locals()
+ parameters.pop("self")
+ with ops.name_scope(name, values=[logits, p, temperature]) as ns:
+ with ops.control_dependencies([check_ops.assert_positive(temperature)]
+ if validate_args else []):
+ self._temperature = array_ops.identity(temperature, name="temperature")
+ self._logits, self._p = distribution_util.get_logits_and_prob(
+ name=name, logits=logits, p=p, validate_args=validate_args,
+ multidimensional=True)
+
+ logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
+ if logits_shape_static.ndims is not None:
+ self._batch_rank = ops.convert_to_tensor(
+ logits_shape_static.ndims - 1,
+ dtype=dtypes.int32,
+ name="batch_rank")
+ else:
+ with ops.name_scope(name="batch_rank"):
+ self._batch_rank = array_ops.rank(self._logits) - 1
+
+ logits_shape = array_ops.shape(self._logits, name="logits_shape")
+ if logits_shape_static[-1].value is not None:
+ self._num_classes = ops.convert_to_tensor(
+ logits_shape_static[-1].value,
+ dtype=dtypes.int32,
+ name="num_classes")
+ else:
+ self._num_classes = array_ops.gather(logits_shape,
+ self._batch_rank,
+ name="num_classes")
+
+ if logits_shape_static[:-1].is_fully_defined():
+ self._batch_shape_val = constant_op.constant(
+ logits_shape_static[:-1].as_list(),
+ dtype=dtypes.int32,
+ name="batch_shape")
+ else:
+ with ops.name_scope(name="batch_shape"):
+ self._batch_shape_val = logits_shape[:-1]
+ super(_ExpRelaxedOneHotCategorical, self).__init__(
+ dtype=dtype,
+ is_continuous=True,
+ is_reparameterized=True,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=[self._logits, self._temperature, self._num_classes],
+ name=ns)
+
+ @property
+ def num_classes(self):
+ """Scalar `int32` tensor: the number of classes."""
+ return self._num_classes
+
+ @property
+ def temperature(self):
+ """A scalar representing the temperature."""
+ return self._temperature
+
+ @property
+ def logits(self):
+ """Vector of coordinatewise logits."""
+ return self._logits
+
+ @property
+ def p(self):
+ """Vector of probabilities summing to one."""
+ return self._p
+
+ def _batch_shape(self):
+ # Use identity to inherit the caller's "name".
+ return array_ops.identity(self._batch_shape_val)
+
+ def _get_batch_shape(self):
+ return self.logits.get_shape()[:-1]
+
+ def _event_shape(self):
+ return array_ops.shape(self.logits)[-1]
+
+ def _get_event_shape(self):
+ return self.logits.get_shape().with_rank_at_least(1)[-1:]
+
+ def _sample_n(self, n, seed=None):
+ sample_shape = array_ops.concat(0, ([n], array_ops.shape(self.logits)))
+ logits = self.logits * array_ops.ones(sample_shape)
+ if logits.get_shape().ndims == 2:
+ logits_2d = logits
+ else:
+ logits_2d = array_ops.reshape(logits, [-1, self.num_classes])
+ np_dtype = self.dtype.as_numpy_dtype()
+ minval = np.nextafter(np_dtype(0), np_dtype(1))
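+ # np.nextafter(0, 1) is the smallest positive float, so uniform lies in
+ # (0, 1) and the double log below stays finite.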
+ uniform = random_ops.random_uniform(shape=array_ops.shape(logits_2d),
+ minval=minval,
+ maxval=1,
+ dtype=self.dtype,
+ seed=seed)
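+ # -log(-log(U)) transforms U ~ Uniform(0, 1) into standard Gumbel noise
+ # (the inverse-CDF trick).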
+ gumbel = - math_ops.log(- math_ops.log(uniform))
+ noisy_logits = math_ops.div(gumbel + logits_2d, self.temperature)
+ samples = nn_ops.log_softmax(noisy_logits)
+ ret = array_ops.reshape(samples, sample_shape)
+ return ret
+
+ def _log_prob(self, x):
+ x = ops.convert_to_tensor(x, name="x")
+ x = self._assert_valid_sample(x)
+ # broadcast logits or x if need be.
+ logits = self.logits
+ if (not x.get_shape().is_fully_defined() or
+ not logits.get_shape().is_fully_defined() or
+ x.get_shape() != logits.get_shape()):
+ logits = array_ops.ones_like(x, dtype=logits.dtype) * logits
+ x = array_ops.ones_like(logits, dtype=x.dtype) * x
+
+ logits_shape = array_ops.shape(logits)
+ if logits.get_shape().ndims == 2:
+ logits_2d = logits
+ x_2d = x
+ else:
+ logits_2d = array_ops.reshape(logits, [-1, self.num_classes])
+ x_2d = array_ops.reshape(x, [-1, self.num_classes])
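+ # ExpConcrete log density (Maddison et al., 2016, appendix):
+ # log p(x) = log((n-1)!) + (n-1)*log(temperature)
+ #            + sum_k log_softmax(logits - temperature * x)_k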
+ # compute the normalization constant
+ log_norm_const = (math_ops.lgamma(self.num_classes)
+ + (self.num_classes - 1)
+ * math_ops.log(self.temperature))
+ # compute the unnormalized density
+ log_softmax = nn_ops.log_softmax(logits_2d - x_2d * self.temperature)
+ log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keep_dims=False)
+ # combine unnormalized density with normalization constant
+ log_prob = log_norm_const + log_unnorm_prob
+ ret = array_ops.reshape(log_prob, logits_shape)
+ return ret
+
+ def _prob(self, x):
+ return math_ops.exp(self._log_prob(x))
+
+ def _assert_valid_sample(self, x):
+ if not self.validate_args: return x
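+ # A valid sample lies in the log-simplex: every component is non-positive
+ # and logsumexp over the last axis is 0 (i.e. exp(x) sums to one).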
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_non_positive(x),
+ distribution_util.assert_close(
+ array_ops.zeros((), dtype=self.dtype),
+ math_ops.reduce_logsumexp(x, reduction_indices=[-1])),
+ ], x)
+
+
+class _RelaxedOneHotCategorical(
+ transformed_distribution.TransformedDistribution):
+ """RelaxedOneHotCategorical distribution with temperature and logits.
+
+ The RelaxedOneHotCategorical is a distribution over random probability
+ vectors, vectors of positive real values that sum to one, which continuously
+ approximates a OneHotCategorical. The degree of approximation is controlled by
+ a temperature: as the temperature goes to 0 the RelaxedOneHotCategorical
+ becomes discrete with a distribution described by the `logits` or `p`
+ parameters, and as the temperature goes to infinity it becomes the constant
+ distribution that is identically the vector of
+ (1/num_classes, ..., 1/num_classes).
+
+ The RelaxedOneHotCategorical distribution was concurrently introduced as the
+ Gumbel-Softmax (Jang et al., 2016) and Concrete (Maddison et al., 2016)
+ distributions for use as a reparameterized continuous approximation to the
+ `Categorical` one-hot distribution. If you use this distribution, please cite
+ both papers.
+
+ #### Examples
+
+ Creates a continuous distribution, which approximates a 3-class one-hot
+ categorical distribution. The 2nd class is the most likely to be the
+ largest component in samples drawn from this distribution.
+
+ ```python
+ temperature = 0.5
+ p = [0.1, 0.5, 0.4]
+ dist = RelaxedOneHotCategorical(temperature, p=p)
+ ```
+
+ Creates a continuous distribution, which approximates a 3-class one-hot
+ categorical distribution. The 2nd class is the most likely to be the
+ largest component in samples drawn from this distribution.
+
+ ```python
+ temperature = 0.5
+ logits = [-2, 2, 0]
+ dist = RelaxedOneHotCategorical(temperature, logits=logits)
+ ```
+
+ Creates a continuous distribution, which approximates a 3-class one-hot
+ categorical distribution. Because the temperature is very low, samples from
+ this distribution are almost discrete, with one component almost 1 and the
+ others nearly 0. The 2nd class is the most likely to be the largest component
+ in samples drawn from this distribution.
+
+ ```python
+ temperature = 1e-5
+ logits = [-2, 2, 0]
+ dist = RelaxedOneHotCategorical(temperature, logits=logits)
+ ```
+
+ Creates a continuous distribution, which approximates a 3-class one-hot
+ categorical distribution. Because the temperature is very high, samples from
+ this distribution are usually close to the (1/3, 1/3, 1/3) vector. The 2nd
+ class is still the most likely to be the largest component
+ in samples drawn from this distribution.
+
+ ```python
+ temperature = 10
+ logits = [-2, 2, 0]
+ dist = RelaxedOneHotCategorical(temperature, logits=logits)
+ ```
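+
+ Because the distribution is reparameterized, gradients of a downstream loss
+ flow through its samples back to `logits`. The following is a minimal
+ sketch; the target vector and squared-error loss are illustrative only.
+
+ ```python
+ temperature = 0.5
+ logits = tf.Variable([-2., 2., 0.])
+ dist = RelaxedOneHotCategorical(temperature, logits=logits)
+ sample = dist.sample()
+ loss = tf.reduce_sum(tf.square(sample - tf.constant([0., 1., 0.])))
+ grads = tf.gradients(loss, [logits])  # well-defined via the reparameterization
+ ```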
+
+ Eric Jang, Shixiang Gu, and Ben Poole. Categorical Reparameterization with
+ Gumbel-Softmax. 2016.
+
+ Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution:
+ A Continuous Relaxation of Discrete Random Variables. 2016.
+ """
+
+ def __init__(
+ self,
+ temperature,
+ logits=None,
+ p=None,
+ dtype=dtypes.float32,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="RelaxedOneHotCategorical"):
+ """Initialize RelaxedOneHotCategorical using class log-probabilities.
+
+ Args:
+ temperature: A 0-D `Tensor`, representing the temperature
+ of a set of RelaxedOneHotCategorical distributions. The temperature
+ should be positive.
+ logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
+ of a set of RelaxedOneHotCategorical distributions. The first
+ `N - 1` dimensions index into a batch of independent distributions and
+ the last dimension represents a vector of logits for each class. Only
+ one of `logits` or `p` should be passed in.
+ p: An N-D `Tensor`, `N >= 1`, representing the probabilities
+ of a set of RelaxedOneHotCategorical distributions. The first
+ `N - 1` dimensions index into a batch of independent distributions and
+ the last dimension represents a vector of probabilities for each
+ class. Only one of `logits` or `p` should be passed in.
+ dtype: The type of the event samples (default: float32).
+ validate_args: `Boolean`, default `False`. If `True`, assert that the
+ temperature is positive and that samples are valid probability vectors.
+ allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
+ exception if a statistic (e.g. mean/mode/etc...) is undefined for any
+ batch member. If `True`, batch members with valid parameters leading to
+ undefined statistics will return NaN for this statistic.
+ name: A name for this distribution (optional).
+ """
+ dist = _ExpRelaxedOneHotCategorical(temperature,
+ logits=logits,
+ p=p,
+ dtype=dtype,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats)
+ super(_RelaxedOneHotCategorical, self).__init__(dist,
+ bijector.Exp(),
+ name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py
index f5b2d94a8a..fa1596d555 100644
--- a/tensorflow/contrib/distributions/python/ops/shape.py
+++ b/tensorflow/contrib/distributions/python/ops/shape.py
@@ -422,7 +422,7 @@ class _DistributionShape(object):
batch_shape = array_ops.slice(s, (1,), (self.batch_ndims,))
# Since sample_dims=1 and is left-most, we add 1 to the number of
# batch_ndims to get the event start dim.
- event_start = math_ops.select(
+ event_start = array_ops.where(
self._batch_ndims_is_0, 2, 1 + self.batch_ndims)
event_shape = array_ops.slice(s, (event_start,), (self.event_ndims,))
new_shape = array_ops.concat(0, (sample_shape, batch_shape, event_shape))
diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py
index cf21aedc37..9c086d126c 100644
--- a/tensorflow/contrib/distributions/python/ops/student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/student_t.py
@@ -230,7 +230,7 @@ class StudentT(distribution.Distribution):
mean = self.mu * self._ones()
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(self.df, self._ones()), mean,
array_ops.fill(self.batch_shape(), nan, name="nan"))
else:
@@ -255,14 +255,14 @@ class StudentT(distribution.Distribution):
math_ops.square(self.sigma) * self.df / (self.df - 2))
# When 1 < df <= 2, variance is infinite.
inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype())
- result_where_defined = math_ops.select(
+ result_where_defined = array_ops.where(
math_ops.greater(self.df, array_ops.fill(self.batch_shape(), 2.)),
var,
array_ops.fill(self.batch_shape(), inf, name="inf"))
if self.allow_nan_stats:
nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(self.df, self._ones()),
result_where_defined,
array_ops.fill(self.batch_shape(), nan, name="nan"))
diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
index 074c51b8d5..74115fe542 100644
--- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py
@@ -120,7 +120,7 @@ class TransformedDistribution(distributions.Distribution):
forward_fn=tf.exp,
inverse_fn=tf.log,
inverse_log_det_jacobian_fn=(
- lambda y: -tf.reduce_sum(tf.log(x), reduction_indices=-1)),
+ lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)),
name="LogNormalTransformedDistribution")
```
@@ -144,7 +144,7 @@ class TransformedDistribution(distributions.Distribution):
"""Construct a Transformed Distribution.
Args:
- distribution: The base distribution class to transform. Typically an
+ distribution: The base distribution instance to transform. Typically an
instance of `Distribution`.
bijector: The object responsible for calculating the transformation.
Typically an instance of `Bijector`.
@@ -244,7 +244,7 @@ class TransformedDistribution(distributions.Distribution):
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self.bijector.inverse(y, **bijector_kwargs)
- return self.distribution.log_cdf(x, distribution_kwargs)
+ return self.distribution.log_cdf(x, **distribution_kwargs)
@distribution_util.AppendDocstring(
condition_kwargs_dict=_condition_kwargs_dict)
diff --git a/tensorflow/contrib/distributions/python/ops/uniform.py b/tensorflow/contrib/distributions/python/ops/uniform.py
index 6b64ca4669..9c9fdf42bd 100644
--- a/tensorflow/contrib/distributions/python/ops/uniform.py
+++ b/tensorflow/contrib/distributions/python/ops/uniform.py
@@ -148,10 +148,10 @@ class Uniform(distribution.Distribution):
def _prob(self, x):
broadcasted_x = x * array_ops.ones(self.batch_shape())
- return math_ops.select(
+ return array_ops.where(
math_ops.is_nan(broadcasted_x),
broadcasted_x,
- math_ops.select(
+ array_ops.where(
math_ops.logical_or(broadcasted_x < self.a,
broadcasted_x > self.b),
array_ops.zeros_like(broadcasted_x),
@@ -164,9 +164,9 @@ class Uniform(distribution.Distribution):
broadcasted_x = x * array_ops.ones(self.batch_shape())
zeros = array_ops.zeros_like(x + self.a + self.b, dtype=self.dtype)
ones = array_ops.ones_like(x + self.a + self.b, dtype=self.dtype)
- result_if_not_big = math_ops.select(
+ result_if_not_big = array_ops.where(
x < self.a, zeros, (broadcasted_x - self.a) / self.range())
- return math_ops.select(x >= self.b, ones, result_if_not_big)
+ return array_ops.where(x >= self.b, ones, result_if_not_big)
def _entropy(self):
return math_ops.log(self.range())
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index ae106628cc..b478a12d36 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -367,7 +367,7 @@ class _WishartOperatorPD(distribution.Distribution):
def _mode(self):
s = self.df - self.dimension - 1.
- s = math_ops.select(
+ s = array_ops.where(
math_ops.less(s, 0.),
constant_op.constant(float("NaN"), dtype=self.dtype, name="nan"),
s)
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 3228c1f3df..7784c6dbda 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -243,7 +243,7 @@ class KMeansClustering(estimator.Estimator,
).training_graph()
incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
self._loss = tf.reduce_sum(losses)
- tf.scalar_summary('loss/raw', self._loss)
+ tf.contrib.deprecated.scalar_summary('loss/raw', self._loss)
training_op = with_dependencies([training_op, incr_step], self._loss)
return training_op, self._loss
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 34420fc87f..c149d14849 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -98,29 +98,32 @@ def assert_same_float_dtype(tensors=None, dtype=None):
return dtype
-def assert_scalar_int(tensor):
+def assert_scalar_int(tensor, name=None):
"""Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`.
Args:
- tensor: Tensor to test.
+ tensor: `Tensor` to test.
+ name: Name of the op and of the new `Tensor` if one is created.
Returns:
`tensor`, for chaining.
Raises:
ValueError: if `tensor` is not 0-D, of type `tf.int32` or `tf.int64`.
"""
- tensor = ops.convert_to_tensor(tensor)
- data_type = tensor.dtype
- if data_type.base_dtype not in [dtypes.int32, dtypes.int64]:
- raise ValueError('Unexpected type %s for %s.' % (data_type, tensor.name))
- assert_scalar(tensor)
-
-
-def assert_scalar(tensor):
- tensor = ops.convert_to_tensor(tensor)
- shape = tensor.get_shape()
- if shape.ndims != 0:
- raise ValueError('Unexpected shape %s for %s.' % (shape, tensor.name))
- return tensor
+ with ops.name_scope(name, 'assert_scalar_int', [tensor]) as name_scope:
+ tensor = ops.convert_to_tensor(tensor)
+ data_type = tensor.dtype
+ if data_type.base_dtype not in [dtypes.int32, dtypes.int64]:
+ raise ValueError('Unexpected type %s for %s.' % (data_type, tensor.name))
+ return assert_scalar(tensor, name=name_scope)
+
+
+def assert_scalar(tensor, name=None):
+ with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
+ tensor = ops.convert_to_tensor(tensor, name=name_scope)
+ shape = tensor.get_shape()
+ if shape.ndims != 0:
+ raise ValueError('Unexpected shape %s for %s.' % (shape, tensor.name))
+ return tensor
def reduce_sum_n(tensors, name=None):
@@ -141,14 +144,15 @@ def reduce_sum_n(tensors, name=None):
"""
if not tensors:
raise ValueError('No tensors provided.')
- tensors = [math_ops.reduce_sum(t, name='%s/sum' % t.op.name) for t in tensors]
- if len(tensors) == 1:
- return tensors[0]
- with ops.name_scope(name, 'reduce_sum_n', tensors) as scope:
- return math_ops.add_n(tensors, name=scope)
+ with ops.name_scope(name, 'reduce_sum_n', tensors) as name_scope:
+ tensors = [
+ math_ops.reduce_sum(t, name='%s/sum' % t.op.name) for t in tensors]
+ if len(tensors) == 1:
+ return tensors[0]
+ return math_ops.add_n(tensors, name=name_scope)
-def remove_squeezable_dimensions(predictions, labels):
+def remove_squeezable_dimensions(predictions, labels, name=None):
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
This will use static shape if available. Otherwise, it will add graph
@@ -157,41 +161,44 @@ def remove_squeezable_dimensions(predictions, labels):
Args:
predictions: Predicted values, a `Tensor` of arbitrary dimensions.
labels: Label values, a `Tensor` whose dimensions match `predictions`.
+ name: Name of the op.
Returns:
Tuple of `predictions` and `labels`, possibly with last dim squeezed.
"""
- predictions = ops.convert_to_tensor(predictions)
- labels = ops.convert_to_tensor(labels)
- predictions_shape = predictions.get_shape()
- predictions_rank = predictions_shape.ndims
- labels_shape = labels.get_shape()
- labels_rank = labels_shape.ndims
- if (labels_rank is not None) and (predictions_rank is not None):
- # Use static rank.
- rank_diff = predictions_rank - labels_rank
- if rank_diff == -1:
- labels = array_ops.squeeze(labels, [-1])
- elif rank_diff == 1:
- predictions = array_ops.squeeze(predictions, [-1])
+ with ops.name_scope(name, 'remove_squeezable_dimensions',
+ [predictions, labels]):
+ predictions = ops.convert_to_tensor(predictions)
+ labels = ops.convert_to_tensor(labels)
+ predictions_shape = predictions.get_shape()
+ predictions_rank = predictions_shape.ndims
+ labels_shape = labels.get_shape()
+ labels_rank = labels_shape.ndims
+ if (labels_rank is not None) and (predictions_rank is not None):
+ # Use static rank.
+ rank_diff = predictions_rank - labels_rank
+ if rank_diff == -1:
+ labels = array_ops.squeeze(labels, [-1])
+ elif rank_diff == 1:
+ predictions = array_ops.squeeze(predictions, [-1])
+ return predictions, labels
+
+ # Use dynamic rank.
+ rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
+ if (predictions_rank is None) or (
+ predictions_shape.dims[-1].is_compatible_with(1)):
+ predictions = control_flow_ops.cond(
+ math_ops.equal(1, rank_diff),
+ lambda: array_ops.squeeze(predictions, [-1]),
+ lambda: predictions)
+ if (labels_rank is None) or (
+ labels_shape.dims[-1].is_compatible_with(1)):
+ labels = control_flow_ops.cond(
+ math_ops.equal(-1, rank_diff),
+ lambda: array_ops.squeeze(labels, [-1]),
+ lambda: labels)
return predictions, labels
- # Use dynamic rank.
- rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
- if (predictions_rank is None) or (
- predictions_shape.dims[-1].is_compatible_with(1)):
- predictions = control_flow_ops.cond(
- math_ops.equal(1, rank_diff),
- lambda: array_ops.squeeze(predictions, [-1]),
- lambda: predictions)
- if (labels_rank is None) or (
- labels_shape.dims[-1].is_compatible_with(1)):
- labels = control_flow_ops.cond(
- math_ops.equal(-1, rank_diff),
- lambda: array_ops.squeeze(labels, [-1]),
- lambda: labels)
- return predictions, labels
-
def _all_equal(tensor0, tensor1):
with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 3ec9de0af2..2db91cd889 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -171,7 +171,8 @@ def local_variable(initial_value, validate_shape=True, name=None):
@contrib_add_arg_scope
def variable(name, shape=None, dtype=None, initializer=None,
regularizer=None, trainable=True, collections=None,
- caching_device=None, device=None):
+ caching_device=None, device=None,
+ partitioner=None, custom_getter=None):
"""Gets an existing variable with these parameters or creates a new one.
Args:
@@ -191,6 +192,11 @@ def variable(name, shape=None, dtype=None, initializer=None,
device.
device: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
+ partitioner: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+ custom_getter: Callable that allows overriding the internal
+ `get_variable` method; it must have the same signature.
Returns:
The created or existing variable.
@@ -199,19 +205,24 @@ def variable(name, shape=None, dtype=None, initializer=None,
# Remove duplicates
collections = set(collections)
+ getter = variable_scope.get_variable
+ if custom_getter is not None:
+ getter = custom_getter
with ops.device(device or ''):
- return variable_scope.get_variable(name, shape=shape, dtype=dtype,
- initializer=initializer,
- regularizer=regularizer,
- trainable=trainable,
- collections=collections,
- caching_device=caching_device)
+ return getter(name, shape=shape, dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner)
@contrib_add_arg_scope
def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
regularizer=None, trainable=True, collections=None,
- caching_device=None, device=None):
+ caching_device=None, device=None, partitioner=None,
+ custom_getter=None):
"""Gets an existing model variable with these parameters or creates a new one.
Args:
@@ -232,16 +243,23 @@ def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
device.
device: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
+ partitioner: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+ custom_getter: Callable that allows overriding the internal
+ `get_variable` method; it must have the same signature.
Returns:
The created or existing variable.
"""
collections = list(collections or [])
collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
- return variable(name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer,
- trainable=trainable, collections=collections,
- caching_device=caching_device, device=device)
+ var = variable(name, shape=shape, dtype=dtype,
+ initializer=initializer, regularizer=regularizer,
+ trainable=trainable, collections=collections,
+ caching_device=caching_device, device=device,
+ partitioner=partitioner, custom_getter=custom_getter)
+ return var
def add_model_variable(var):
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
index 590590bf7b..d846b013fe 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
@@ -727,10 +727,10 @@ def matmul(a, b, name=None):
b = core.convert_to_labeled_tensor(b)
if len(a.axes) > 2 or len(b.axes) > 2:
- # We could use tf.batch_matmul to make this work, but we would also need
- # to use tf.tile and/or tf.transpose. These are more expensive than doing
- # reshapes, so it's not clear if it's a good idea to do this
- # automatically.
+ # We could pass batched inputs to tf.matmul to make this work, but we
+ # would also need to use tf.tile and/or tf.transpose. These are more
+ # expensive than doing reshapes, so it's not clear if it's a good idea to
+ # do this automatically.
raise NotImplementedError(
'matmul currently requires inputs with rank 2 or less, but '
'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index 57debbc148..b7832be73f 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -33,12 +33,14 @@ common machine learning algorithms.
@@repeat
@@safe_embedding_lookup_sparse
@@separable_convolution2d
-@@stack
@@unit_norm
Aliases for fully_connected which set a default activation function are
available: `relu`, `relu6` and `linear`.
+ The `stack` operation is also available. It builds a stack of layers by
+ applying a layer repeatedly.
+
## Regularizers
Regularization can help prevent overfitting. These have the signature
@@ -118,4 +120,8 @@ from tensorflow.contrib.layers.python.ops import sparse_ops
from tensorflow.python.util.all_util import make_all
# pylint: enable=unused-import,wildcard-import
+
+ # Note: the `stack` operation is available; it is excluded from the module
+ # doc above due to a name collision with tf.stack.
+
__all__ = make_all(__name__)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 4713d0b5c7..25a871cd15 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -150,7 +150,7 @@ def safe_embedding_lookup_sparse(embedding_weights,
array_ops.reshape(is_row_empty, [-1, 1]),
array_ops.pack([1, array_ops.shape(result)[1]]))
- result = math_ops.select(is_row_empty,
+ result = array_ops.where(is_row_empty,
array_ops.zeros_like(result),
result,
name=scope)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index d8e5485373..e3ef7328a4 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -850,6 +850,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
shared_embedding_name: (Optional). The common name for shared embedding.
shared_vocab_size: (Optional). The common vocab_size used for shared
embedding space.
+ max_norm: (Optional). If not None, embedding values are l2-normalized to
+ the value of max_norm.
Raises:
ValueError: if `initializer` is specified and is not callable. Also,
@@ -959,7 +961,8 @@ def embedding_column(sparse_id_column,
combiner=None,
initializer=None,
ckpt_to_load_from=None,
- tensor_name_in_ckpt=None):
+ tensor_name_in_ckpt=None,
+ max_norm=None):
"""Creates an `_EmbeddingColumn` for feeding sparse data into a DNN.
Args:
@@ -984,6 +987,8 @@ def embedding_column(sparse_id_column,
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
+ max_norm: (Optional). If not None, embedding values are l2-normalized to
+ the value of max_norm.
Returns:
An `_EmbeddingColumn`.
@@ -993,7 +998,8 @@ def embedding_column(sparse_id_column,
"to \"sqrtn\" after 2016/11/01.")
combiner = "mean"
return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer,
- ckpt_to_load_from, tensor_name_in_ckpt)
+ ckpt_to_load_from, tensor_name_in_ckpt,
+ max_norm=max_norm)
def shared_embedding_columns(sparse_id_columns,
@@ -1002,7 +1008,8 @@ def shared_embedding_columns(sparse_id_columns,
shared_embedding_name=None,
initializer=None,
ckpt_to_load_from=None,
- tensor_name_in_ckpt=None):
+ tensor_name_in_ckpt=None,
+ max_norm=None):
"""Creates a list of `_EmbeddingColumn` sharing the same embedding.
Args:
@@ -1030,6 +1037,8 @@ def shared_embedding_columns(sparse_id_columns,
tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
+ max_norm: (Optional). If not None, embedding values are l2-normalized to
+ the value of max_norm.
Returns:
A tuple of `_EmbeddingColumn` with shared embedding space.
@@ -1061,7 +1070,7 @@ def shared_embedding_columns(sparse_id_columns,
return [
_EmbeddingColumn(sparse_id_columns[0], dimension, combiner, initializer,
ckpt_to_load_from, tensor_name_in_ckpt,
- shared_embedding_name)]
+ shared_embedding_name, max_norm=max_norm)]
else:
# check compatibility of sparse_id_columns
compatible = True
@@ -1090,7 +1099,8 @@ def shared_embedding_columns(sparse_id_columns,
embedded_columns.append(
_EmbeddingColumn(column, dimension, combiner, initializer,
ckpt_to_load_from, tensor_name_in_ckpt,
- shared_embedding_name, shared_vocab_size))
+ shared_embedding_name, shared_vocab_size,
+ max_norm=max_norm))
return tuple(embedded_columns)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index af6cfa9418..8a49e14c08 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -778,6 +778,25 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase):
# score: (number of values)
self.assertAllEqual(output.eval(), [[1.], [2.], [0.]])
+ def testEmbeddingColumnWithMaxNormForDNN(self):
+ hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
+ wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ shape=[3, 2])
+ features = {"wire": wire_tensor}
+ embedded_sparse = tf.contrib.layers.embedding_column(
+ hashed_sparse,
+ 1,
+ combiner="sum",
+ initializer=init_ops.ones_initializer(),
+ max_norm=0.5)
+ output = tf.contrib.layers.input_from_feature_columns(features,
+ [embedded_sparse])
+ with self.test_session():
+ tf.global_variables_initializer().run()
+ # score: (number of values * 0.5)
+ self.assertAllClose(output.eval(), [[0.5], [1.], [0.]])
+
def testEmbeddingColumnWithWeightedSparseColumnForDNN(self):
ids = tf.contrib.layers.sparse_column_with_keys(
"ids", ["marlo", "omar", "stringer"])
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 30f6690c68..5c6559b826 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import moving_averages
# TODO(b/28426988): Replace legacy_* fns migrated from slim.
@@ -1153,12 +1154,16 @@ def dropout(inputs,
Returns:
a tensor representing the output of the operation.
"""
- with ops.name_scope(scope, 'Dropout', [inputs]) as sc:
+ with variable_scope.variable_scope(
+ scope, 'Dropout', [inputs], custom_getter=_model_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
- dropout_fn = lambda: nn.dropout(inputs, keep_prob, noise_shape)
- id_fn = lambda: array_ops.identity(inputs)
- outputs = utils.smart_cond(is_training, dropout_fn, id_fn)
- return utils.collect_named_outputs(outputs_collections, sc, outputs)
+ layer = core_layers.Dropout(rate=1 - keep_prob,
+ noise_shape=noise_shape,
+ name=sc.name,
+ _scope=sc)
+ outputs = layer.apply(inputs, training=is_training)
+ return utils.collect_named_outputs(
+ outputs_collections, sc.original_name_scope, outputs)
@add_arg_scope
@@ -1264,6 +1269,31 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None):
return utils.collect_named_outputs(output_collections, sc, flattened)
+def _model_variable_getter(getter, name, shape=None, dtype=None,
+ initializer=None, regularizer=None, trainable=True,
+ collections=None, caching_device=None,
+ partitioner=None, **_):
+ """Getter that uses model_variable for compatibility with core layers."""
+ return variables.model_variable(
+ name, shape=shape, dtype=dtype, initializer=initializer,
+ regularizer=regularizer, collections=collections, trainable=trainable,
+ caching_device=caching_device, partitioner=partitioner,
+ custom_getter=getter)
+
+
+def _add_variable_to_collections(variable, collections_set, collections_name):
+ """Adds variable (or all its parts) to all collections with that name."""
+ collections = utils.get_variable_collections(
+ collections_set, collections_name) or []
+ variables_list = [variable]
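+ # A PartitionedVariable iterates over its concrete per-partition variables.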
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ variables_list = [v for v in variable]
+ for collection in collections:
+ for var in variables_list:
+ if var not in ops.get_collection(collection):
+ ops.add_to_collection(collection, var)
+
+
@add_arg_scope
def fully_connected(inputs,
num_outputs,
@@ -1325,41 +1355,17 @@ def fully_connected(inputs,
if not (isinstance(num_outputs, six.integer_types)):
raise ValueError('num_outputs should be int or long, got %s.', num_outputs)
- # Currently, the layers in this module do not create variables via
- # `tf.get_variable`, rather they use their own variable management system
- # which wraps `tf.get_variable` (the `model_variable()` interface from Slim).
- # This interface is globally-configured via an argscope. This global
- # configuration mechanism is used for instance by Slim-deploy to globally
- # configure the target device of the variables of a model.
- #
- # We have the following the constraints:
- # - Argscopes are not currently moving into core, thus core layers cannot
- # rely on the Slim variable wrapper, and should instead
- # use `tf.get_variable`.
- # - Contrib layers require to use the argscope-enabled Slim variable wrapper
- # rather than raw TF variables.
- # - We want to be able to reuse at least the logic across core layers
- # and contrib layers.
- #
- # We use the following strategy:
- # - We instantiate variables in the contrib layer via the Slim interface.
- # - We instantiate a core layer and set its variables to be the Slim ones.
- # - We call the core layer.
- #
- # This enables us to reuse the `call` method across both implementations.
-
- with variable_scope.variable_scope(scope, 'fully_connected', [inputs],
- reuse=reuse) as sc:
+ with variable_scope.variable_scope(
+ scope, 'fully_connected', [inputs],
+ reuse=reuse, custom_getter=_model_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
-
- # Instantiate the FullyConnected layer.
layer = core_layers.FullyConnected(
- num_outputs,
+ units=num_outputs,
activation=None,
use_bias=not normalizer_fn and biases_initializer,
- w_initializer=weights_initializer,
+ weights_initializer=weights_initializer,
bias_initializer=biases_initializer,
- w_regularizer=weights_regularizer,
+ weights_regularizer=weights_regularizer,
bias_regularizer=biases_regularizer,
activity_regularizer=None,
trainable=trainable,
@@ -1367,39 +1373,12 @@ def fully_connected(inputs,
dtype=inputs.dtype.base_dtype,
_scope=sc,
_reuse_weights=reuse)
+ outputs = layer.apply(inputs)
- dtype = inputs.dtype.base_dtype
- inputs_shape = inputs.get_shape()
- num_input_units = utils.last_dimension(inputs_shape, min_rank=2)
-
- static_shape = inputs_shape.as_list()
- static_shape[-1] = num_outputs
-
- weights_shape = [num_input_units, num_outputs]
- weights_collections = utils.get_variable_collections(
- variables_collections, 'weights')
- weights = variables.model_variable('weights',
- shape=weights_shape,
- dtype=dtype,
- initializer=weights_initializer,
- regularizer=weights_regularizer,
- collections=weights_collections,
- trainable=trainable)
- layer.w = weights
-
- if layer.use_bias:
- biases_collections = utils.get_variable_collections(
- variables_collections, 'biases')
- biases = variables.model_variable('biases',
- shape=[num_outputs,],
- dtype=dtype,
- initializer=biases_initializer,
- regularizer=biases_regularizer,
- collections=biases_collections,
- trainable=trainable)
- layer.bias = biases
-
- outputs = layer.call(inputs)
+ # Add variables to collections.
+ _add_variable_to_collections(layer.w, variables_collections, 'weights')
+ if layer.bias:
+ _add_variable_to_collections(layer.bias, variables_collections, 'biases')
# Apply normalizer function / layer.
if normalizer_fn is not None:
@@ -2099,4 +2078,3 @@ conv2d = convolution2d
conv2d_transpose = convolution2d_transpose
conv2d_in_plane = convolution2d_in_plane
separable_conv2d = separable_convolution2d
-
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index a9c8997885..b28e3363dc 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -628,7 +628,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithStrideOneValidPaddingNCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 32
input_size = [5, 3, 10, 12]
@@ -644,7 +644,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithStrideTwoValidPaddingNCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 32
input_size = [5, 3, 9, 11]
@@ -661,7 +661,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWith1x1StrideTwoSamePaddingNCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 1
input_size = [1, 1, 1, 1]
@@ -678,7 +678,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWith1x1StrideTwoValidPaddingNCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 1
input_size = [1, 1, 1, 1]
@@ -693,7 +693,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWith2x2StrideTwoSamePaddingNCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 1
input_size = [1, 1, 2, 2]
@@ -708,7 +708,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWith2x2StrideTwoValidPaddingNCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 1
input_size = [1, 1, 2, 2]
@@ -723,7 +723,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithStride2x1NCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 1
input_size = [1, 1, 3, 2]
@@ -738,7 +738,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithStride2x4NCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 1
input_size = [1, 1, 3, 2]
@@ -753,7 +753,7 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertListEqual(list(output.eval().shape), expected_size)
def testOutputSizeWithStride2x5NCHW(self):
- if tf.test.is_gpu_available():
+ if tf.test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True) as sess:
num_filters = 1
input_size = [1, 1, 3, 2]
@@ -1181,7 +1181,6 @@ class DropoutTest(tf.test.TestCase):
is_training = tf.constant(True)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images, is_training=is_training)
- self.assertEqual(output.op.name, 'Dropout/dropout/mul')
output.get_shape().assert_is_compatible_with(images.get_shape())
def testCreateDropoutWithConstantFalse(self):
@@ -1190,7 +1189,6 @@ class DropoutTest(tf.test.TestCase):
is_training = tf.constant(False)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images, is_training=is_training)
- self.assertEqual(output.op.name, 'Dropout/Identity')
output.get_shape().assert_is_compatible_with(images.get_shape())
def testCreateDropoutWithPlaceholder(self):
@@ -1220,8 +1218,8 @@ class DropoutTest(tf.test.TestCase):
num_elem = tf.reduce_mean(tf.to_float(output > 0))
sess.run(tf.global_variables_initializer())
num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial])
- self.assertLess(num_elem, num_elem_initial/2 + 0.1)
- self.assertGreater(num_elem, num_elem_initial/2 - 0.1)
+ self.assertLess(num_elem, num_elem_initial / 2 + 0.1)
+ self.assertGreater(num_elem, num_elem_initial / 2 - 0.1)
def testCreateDropoutNoTraining(self):
height, width = 3, 3
@@ -1246,8 +1244,8 @@ class DropoutTest(tf.test.TestCase):
num_elem = tf.reduce_mean(tf.to_float(output > 0))
sess.run(tf.global_variables_initializer())
num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial])
- self.assertLess(num_elem, num_elem_initial/2 + 0.1)
- self.assertGreater(num_elem, num_elem_initial/2 - 0.1)
+ self.assertLess(num_elem, num_elem_initial / 2 + 0.1)
+ self.assertGreater(num_elem, num_elem_initial / 2 - 0.1)
def testCreateFCWithDropout(self):
height, width = 3, 3
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 2908096c6c..ada679d83b 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -360,7 +360,7 @@ def adaptive_clipping_fn(std_factor=2.,
summary.scalar("global_norm/adaptive_max_gradient_norm", max_norm)
# factor will be 1. if norm is smaller than max_norm
- factor = math_ops.select(norm < max_norm,
+ factor = array_ops.where(norm < max_norm,
array_ops.ones_like(norm),
math_ops.exp(log_mean) / norm)
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 5b7a2d76d8..764971935f 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -21,6 +21,10 @@ py_library(
"//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/contrib/tensor_forest:client_lib",
"//tensorflow/python:framework",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow/python/saved_model:tag_constants",
],
)
@@ -663,6 +667,32 @@ py_test(
)
py_test(
+ name = "gc_test",
+ size = "small",
+ srcs = ["python/learn/utils/gc_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
+ name = "saved_model_export_utils_test",
+ size = "small",
+ srcs = ["python/learn/utils/saved_model_export_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
name = "stability_test",
size = "small",
srcs = ["python/learn/estimators/stability_test.py"],
diff --git a/tensorflow/python/util/net_lib.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py
index d8566eb7c7..aee4541627 100644
--- a/tensorflow/python/util/net_lib.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""A Python interface for creating TensorFlow tests."""
+"""Constants regarding Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python import pywrap_tensorflow
-
-def pick_unused_port_or_die():
- """Find an unused port on localhost."""
- return pywrap_tensorflow.PickUnusedPortOrDie()
+class ProblemType(object):
+ UNSPECIFIED = 0
+ CLASSIFICATION = 1
+ LINEAR_REGRESSION = 2
+ LOGISTIC_REGRESSION = 3
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index db0a24e508..98947cc6d4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -19,9 +19,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import evaluable
@@ -89,6 +92,9 @@ def _dnn_model_fn(features, labels, mode, params):
* gradient_clip_norm: A float > 0. If provided, gradients are
clipped to their global norm with this clipping ratio.
* num_ps_replicas: The number of parameter server replicas.
+ * embedding_lr_multipliers: Optional. A dictionary from
+ `EmbeddingColumn` to a `float` multiplier. The multiplier is used to
+ scale the learning rate for the embedding variables.
Returns:
predictions: A dict of `Tensor` objects.
@@ -103,6 +109,7 @@ def _dnn_model_fn(features, labels, mode, params):
dropout = params.get("dropout")
gradient_clip_norm = params.get("gradient_clip_norm")
num_ps_replicas = params.get("num_ps_replicas", 0)
+ embedding_lr_multipliers = params.get("embedding_lr_multipliers", {})
features = _get_feature_dict(features)
parent_scope = "dnn"
@@ -111,9 +118,10 @@ def _dnn_model_fn(features, labels, mode, params):
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
+ input_layer_scope = parent_scope + "/input_from_feature_columns"
with variable_scope.variable_scope(
- parent_scope + "/input_from_feature_columns",
- values=features.values(),
+ input_layer_scope,
+ values=list(six.itervalues(features)),
partitioner=input_layer_partitioner) as scope:
net = layers.input_from_feature_columns(
columns_to_tensors=features,
@@ -160,6 +168,9 @@ def _dnn_model_fn(features, labels, mode, params):
global_step=contrib_variables.get_global_step(),
learning_rate=_LEARNING_RATE,
optimizer=_get_optimizer(optimizer),
+ gradient_multipliers=(
+ dnn_linear_combined._extract_embedding_lr_multipliers( # pylint: disable=protected-access
+ embedding_lr_multipliers, parent_scope, input_layer_scope)),
clip_gradients=gradient_clip_norm,
name=parent_scope,
# Empty summaries to prevent optimizers from logging the training_loss.
@@ -234,7 +245,8 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
gradient_clip_norm=None,
enable_centered_bias=False,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ embedding_lr_multipliers=None):
"""Initializes a DNNClassifier instance.
Args:
@@ -271,6 +283,9 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
+ embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
+        a `float` multiplier. The multiplier is applied to the learning rate
+        for the embedding variables.
Returns:
A `DNNClassifier` estimator.
@@ -287,17 +302,27 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
model_dir=model_dir,
config=config,
params={
- "head": head_lib._multi_class_head( # pylint: disable=protected-access
- n_classes,
- weight_column_name=weight_column_name,
- enable_centered_bias=enable_centered_bias),
- "hidden_units": hidden_units,
- "feature_columns": feature_columns,
- "optimizer": optimizer,
- "activation_fn": activation_fn,
- "dropout": dropout,
- "gradient_clip_norm": gradient_clip_norm,
- "num_ps_replicas": config.num_ps_replicas if config else 0,
+ "head":
+ head_lib._multi_class_head( # pylint: disable=protected-access
+ n_classes,
+ weight_column_name=weight_column_name,
+ enable_centered_bias=enable_centered_bias),
+ "hidden_units":
+ hidden_units,
+ "feature_columns":
+ feature_columns,
+ "optimizer":
+ optimizer,
+ "activation_fn":
+ activation_fn,
+ "dropout":
+ dropout,
+ "gradient_clip_norm":
+ gradient_clip_norm,
+ "num_ps_replicas":
+ config.num_ps_replicas if config else 0,
+ "embedding_lr_multipliers":
+ embedding_lr_multipliers,
},
feature_engineering_fn=feature_engineering_fn)
@@ -428,6 +453,22 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
def model_dir(self):
return self._estimator.model_dir
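A hedged usage sketch of the `embedding_lr_multipliers` argument added above; the column definition and the 0.5 value are illustrative only:

    import tensorflow as tf

    # Any column built by tf.contrib.layers.embedding_column qualifies.
    language = tf.contrib.layers.embedding_column(
        tf.contrib.layers.sparse_column_with_hash_bucket('language', 100),
        dimension=8)

    classifier = tf.contrib.learn.DNNClassifier(
        feature_columns=[language],
        hidden_units=[32, 16],
        # Gradients of the 'language' embedding variables are scaled by 0.5,
        # so the embedding effectively trains at half the base learning rate.
        embedding_lr_multipliers={language: 0.5})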
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 10b288f1fb..256e074079 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -26,7 +26,9 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import evaluable
@@ -365,6 +367,31 @@ def _add_hidden_layer_summary(value, tag):
logging_ops.histogram_summary("%s:activation" % tag, value)
+def _get_embedding_variable(column, collection_key, input_layer_scope):
+ return ops.get_collection(collection_key,
+ input_layer_scope + "/" + column.name)
+
+
+def _extract_embedding_lr_multipliers(embedding_lr_multipliers, collection_key,
+ input_layer_scope):
+  """Converts embedding lr multipliers to per-variable gradient multipliers."""
+ if not embedding_lr_multipliers:
+ return None
+ gradient_multipliers = {}
+ for column, lr_mult in embedding_lr_multipliers.items():
+ if not isinstance(column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access
+ raise ValueError(
+          "learning rate multiplier can only be defined for embedding columns. "
+ "It is defined for {}".format(column))
+ embedding = _get_embedding_variable(
+ column, collection_key, input_layer_scope)
+ if not embedding:
+ raise ValueError("Couldn't find a variable for column {}".format(column))
+ for v in embedding:
+ gradient_multipliers[v] = lr_mult
+ return gradient_multipliers
+
+
def _dnn_linear_combined_model_fn(features, labels, mode, params):
"""Deep Neural Net and Linear combined model_fn.
@@ -396,6 +423,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
* gradient_clip_norm: A float > 0. If provided, gradients are
clipped to their global norm with this clipping ratio.
* num_ps_replicas: The number of parameter server replicas.
+ * embedding_lr_multipliers: Optional. A dictionary from
+        `EmbeddingColumn` to a `float` multiplier. The multiplier is applied
+        to the learning rate for the embedding variables.
Returns:
`ModelFnOps`
@@ -414,7 +444,8 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
dnn_activation_fn = params.get("dnn_activation_fn")
dnn_dropout = params.get("dnn_dropout")
gradient_clip_norm = params.get("gradient_clip_norm")
- num_ps_replicas = params["num_ps_replicas"]
+ num_ps_replicas = params.get("num_ps_replicas", 0)
+ embedding_lr_multipliers = params.get("embedding_lr_multipliers", {})
if not linear_feature_columns and not dnn_feature_columns:
raise ValueError(
@@ -432,8 +463,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
+ input_layer_scope = dnn_parent_scope + "/input_from_feature_columns"
with variable_scope.variable_scope(
- dnn_parent_scope + "/input_from_feature_columns",
+ input_layer_scope,
values=features.values(),
partitioner=input_layer_partitioner) as scope:
net = layers.input_from_feature_columns(
@@ -521,6 +553,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
global_step=contrib_variables.get_global_step(),
learning_rate=_DNN_LEARNING_RATE,
optimizer=_get_optimizer(dnn_optimizer),
+ gradient_multipliers=_extract_embedding_lr_multipliers( # pylint: disable=protected-access
+ embedding_lr_multipliers, dnn_parent_scope,
+ input_layer_scope),
clip_gradients=gradient_clip_norm,
variables=ops.get_collection(dnn_parent_scope),
name=dnn_parent_scope,
@@ -612,7 +647,8 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
gradient_clip_norm=None,
enable_centered_bias=False,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ embedding_lr_multipliers=None):
"""Constructs a DNNLinearCombinedClassifier instance.
Args:
@@ -656,6 +692,9 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
+ embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
+        a `float` multiplier. The multiplier is applied to the learning rate
+        for the embedding variables.
Raises:
ValueError: If `n_classes` < 2.
@@ -695,6 +734,7 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
"dnn_dropout": dnn_dropout,
"gradient_clip_norm": gradient_clip_norm,
"num_ps_replicas": config.num_ps_replicas if config else 0,
+ "embedding_lr_multipliers": embedding_lr_multipliers,
},
feature_engineering_fn=feature_engineering_fn)
@@ -829,6 +869,22 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
def model_dir(self):
return self._estimator.model_dir
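In effect, `_extract_embedding_lr_multipliers` above resolves each `EmbeddingColumn` to the variables created for it under the input-layer scope and returns the variable-keyed dict that `optimizers.optimize_loss(gradient_multipliers=...)` consumes. A simplified restatement of that mapping, assuming a `{column: float}` dict named `embedding_lr_multipliers`:

    from tensorflow.python.framework import ops

    gradient_multipliers = {}
    for column, lr_mult in embedding_lr_multipliers.items():
      # Variables for the column live in the 'dnn' collection under the
      # 'dnn/input_from_feature_columns/<column name>' scope.
      for var in ops.get_collection(
          'dnn', 'dnn/input_from_feature_columns/' + column.name):
        gradient_multipliers[var] = lr_mult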
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index 26c38d0789..33d0d2eb4f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -27,7 +27,9 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
+from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
+from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import test_data
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
@@ -39,6 +41,82 @@ def _assert_metrics_in_range(keys, metrics):
0.0 - epsilon, 1.0 + epsilon, key, metrics)
+class EmbeddingMultiplierTest(tf.test.TestCase):
+  """Tests for _dnn_linear_combined_model_fn."""
+
+ def testRaisesNonEmbeddingColumn(self):
+ one_hot_language = tf.contrib.layers.one_hot_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('language', 10))
+
+ params = {
+ 'dnn_feature_columns': [one_hot_language],
+ 'head': head_lib._multi_class_head(2),
+ 'dnn_hidden_units': [1],
+ # Set lr mult to 0. to keep embeddings constant.
+ 'embedding_lr_multipliers': {
+ one_hot_language: 0.0
+ },
+ 'dnn_optimizer': 'Adagrad',
+ }
+ features = {
+ 'language':
+ tf.SparseTensor(
+ values=['en', 'fr', 'zh'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 1]),
+ }
+ labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
+ with self.assertRaisesRegexp(
+ ValueError, 'can only be defined for embedding columns'):
+ dnn_linear_combined._dnn_linear_combined_model_fn(
+ features, labels, tf.contrib.learn.ModeKeys.TRAIN, params)
+
+ def testMultipliesGradient(self):
+ embedding_language = tf.contrib.layers.embedding_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('language', 10),
+ dimension=1, initializer=tf.constant_initializer(0.1))
+ embedding_wire = tf.contrib.layers.embedding_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('wire', 10),
+ dimension=1, initializer=tf.constant_initializer(0.1))
+
+ params = {
+ 'dnn_feature_columns': [embedding_language, embedding_wire],
+ 'head': head_lib._multi_class_head(2),
+ 'dnn_hidden_units': [1],
+ # Set lr mult to 0. to keep embeddings constant.
+ 'embedding_lr_multipliers': {
+ embedding_language: 0.0
+ },
+ 'dnn_optimizer': 'Adagrad',
+ }
+ features = {
+ 'language':
+ tf.SparseTensor(
+ values=['en', 'fr', 'zh'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 1]),
+ 'wire':
+ tf.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 1]),
+ }
+ labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
+ model_ops = dnn_linear_combined._dnn_linear_combined_model_fn(
+ features, labels, tf.contrib.learn.ModeKeys.TRAIN, params)
+ with tf.train.MonitoredSession() as sess:
+ language_var = dnn_linear_combined._get_embedding_variable(
+ embedding_language, 'dnn', 'dnn/input_from_feature_columns')
+ wire_var = dnn_linear_combined._get_embedding_variable(
+ embedding_wire, 'dnn', 'dnn/input_from_feature_columns')
+ for _ in range(2):
+ _, language_value, wire_value = sess.run(
+ [model_ops.train_op, language_var, wire_var])
+ initial_value = np.full_like(language_value, 0.1)
+ self.assertTrue(np.all(np.isclose(language_value, initial_value)))
+ self.assertFalse(np.all(np.isclose(wire_value, initial_value)))
+
+
class DNNLinearCombinedClassifierTest(tf.test.TestCase):
def testEstimatorContract(self):
@@ -54,6 +132,18 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
dnn_feature_columns=None,
dnn_hidden_units=[3, 3])
+ def testEmbeddingMultiplier(self):
+ embedding_language = tf.contrib.layers.embedding_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('language', 10),
+ dimension=1, initializer=tf.constant_initializer(0.1))
+ classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+ dnn_feature_columns=[embedding_language],
+ dnn_hidden_units=[3, 3],
+ embedding_lr_multipliers={embedding_language: 0.8})
+ self.assertEqual(
+ {embedding_language: 0.8},
+ classifier._estimator.params['embedding_lr_multipliers'])
+
def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""
iris = test_data.prepare_iris_data_for_logistic_regression()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
index a55ca08b1b..9196d78d22 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py
@@ -27,12 +27,89 @@ import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
+from tensorflow.contrib.learn.python.learn.estimators import dnn
+from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
+from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import test_data
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
from tensorflow.python.ops import math_ops
+class EmbeddingMultiplierTest(tf.test.TestCase):
+  """Tests for _dnn_model_fn."""
+
+ def testRaisesNonEmbeddingColumn(self):
+ one_hot_language = tf.contrib.layers.one_hot_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('language', 10))
+
+ params = {
+ 'feature_columns': [one_hot_language],
+ 'head': head_lib._multi_class_head(2),
+ 'hidden_units': [1],
+ # Set lr mult to 0. to keep embeddings constant.
+ 'embedding_lr_multipliers': {
+ one_hot_language: 0.0
+ },
+ }
+ features = {
+ 'language':
+ tf.SparseTensor(
+ values=['en', 'fr', 'zh'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 1]),
+ }
+ labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
+ with self.assertRaisesRegexp(
+ ValueError, 'can only be defined for embedding columns'):
+ dnn._dnn_model_fn(features, labels,
+ tf.contrib.learn.ModeKeys.TRAIN, params)
+
+ def testMultipliesGradient(self):
+ embedding_language = tf.contrib.layers.embedding_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('language', 10),
+ dimension=1, initializer=tf.constant_initializer(0.1))
+ embedding_wire = tf.contrib.layers.embedding_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('wire', 10),
+ dimension=1, initializer=tf.constant_initializer(0.1))
+
+ params = {
+ 'feature_columns': [embedding_language, embedding_wire],
+ 'head': head_lib._multi_class_head(2),
+ 'hidden_units': [1],
+ # Set lr mult to 0. to keep embeddings constant.
+ 'embedding_lr_multipliers': {
+ embedding_language: 0.0
+ },
+ }
+ features = {
+ 'language':
+ tf.SparseTensor(
+ values=['en', 'fr', 'zh'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 1]),
+ 'wire':
+ tf.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 1]),
+ }
+ labels = tf.constant([[0], [0], [0]], dtype=tf.int32)
+ model_ops = dnn._dnn_model_fn(features, labels,
+ tf.contrib.learn.ModeKeys.TRAIN, params)
+ with tf.train.MonitoredSession() as sess:
+ language_var = dnn_linear_combined._get_embedding_variable(
+ embedding_language, 'dnn', 'dnn/input_from_feature_columns')
+ wire_var = dnn_linear_combined._get_embedding_variable(
+ embedding_wire, 'dnn', 'dnn/input_from_feature_columns')
+ for _ in range(2):
+ _, language_value, wire_value = sess.run(
+ [model_ops.train_op, language_var, wire_var])
+ initial_value = np.full_like(language_value, 0.1)
+ self.assertTrue(np.all(np.isclose(language_value, initial_value)))
+ self.assertFalse(np.all(np.isclose(wire_value, initial_value)))
+
+
class DNNClassifierTest(tf.test.TestCase):
def _assertInRange(self, expected_min, expected_max, actual):
@@ -43,6 +120,18 @@ class DNNClassifierTest(tf.test.TestCase):
estimator_test_utils.assert_estimator_contract(
self, tf.contrib.learn.DNNClassifier)
+ def testEmbeddingMultiplier(self):
+ embedding_language = tf.contrib.layers.embedding_column(
+ tf.contrib.layers.sparse_column_with_hash_bucket('language', 10),
+ dimension=1, initializer=tf.constant_initializer(0.1))
+ classifier = tf.contrib.learn.DNNClassifier(
+ feature_columns=[embedding_language],
+ hidden_units=[3, 3],
+ embedding_lr_multipliers={embedding_language: 0.8})
+ self.assertEqual(
+ {embedding_language: 0.8},
+ classifier._estimator.params['embedding_lr_multipliers'])
+
def testLogisticRegression_MatrixData(self):
"""Tests binary classification using matrix data as input."""
cont_features = [
@@ -118,10 +207,10 @@ class DNNClassifierTest(tf.test.TestCase):
classifier = tf.contrib.learn.DNNClassifier(
n_classes=2,
feature_columns=feature_columns,
- hidden_units=[3, 3],
+ hidden_units=[10, 10],
config=tf.contrib.learn.RunConfig(tf_random_seed=1))
- classifier.fit(input_fn=_input_fn, steps=5)
+ classifier.fit(input_fn=_input_fn, steps=50)
scores = classifier.evaluate(input_fn=_input_fn, steps=1)
self._assertInRange(0.0, 1.0, scores['accuracy'])
@@ -222,7 +311,7 @@ class DNNClassifierTest(tf.test.TestCase):
n_classes=3,
feature_columns=feature_columns,
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
classifier.fit(x=train_x, y=train_y, steps=200)
scores = classifier.evaluate(x=train_x, y=train_y, steps=1)
@@ -310,7 +399,7 @@ class DNNClassifierTest(tf.test.TestCase):
weight_column_name='w',
feature_columns=[tf.contrib.layers.real_valued_column('x')],
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
classifier.fit(input_fn=_input_fn_train, steps=5)
scores = classifier.evaluate(input_fn=_input_fn_eval, steps=1)
@@ -339,8 +428,8 @@ class DNNClassifierTest(tf.test.TestCase):
classifier = tf.contrib.learn.DNNClassifier(
n_classes=3,
feature_columns=feature_columns,
- hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ hidden_units=[10, 10],
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
classifier.fit(input_fn=_input_fn, steps=100)
@@ -524,7 +613,7 @@ class DNNClassifierTest(tf.test.TestCase):
}
with tf.test.mock.patch.dict('os.environ',
{'TF_CONFIG': json.dumps(tf_config)}):
- config = tf.contrib.learn.RunConfig(tf_random_seed=5)
+ config = tf.contrib.learn.RunConfig(tf_random_seed=1)
# Because we did not start a distributed cluster, we need to pass an
# empty ClusterSpec, otherwise the device_setter will look for
# distributed jobs, such as "/job:ps" which are not present.
@@ -707,7 +796,7 @@ class DNNRegressorTest(tf.test.TestCase):
regressor = tf.contrib.learn.DNNRegressor(
feature_columns=[tf.contrib.layers.real_valued_column('x')],
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
regressor.fit(input_fn=_input_fn_train, steps=5)
scores = regressor.evaluate(input_fn=_input_fn_train, steps=1)
@@ -772,7 +861,7 @@ class DNNRegressorTest(tf.test.TestCase):
weight_column_name='w',
feature_columns=[tf.contrib.layers.real_valued_column('x')],
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
regressor.fit(input_fn=_input_fn_train, steps=5)
scores = regressor.evaluate(input_fn=_input_fn_eval, steps=1)
@@ -803,7 +892,7 @@ class DNNRegressorTest(tf.test.TestCase):
regressor = tf.contrib.learn.DNNRegressor(
feature_columns=feature_columns,
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
regressor.fit(input_fn=_input_fn, steps=200)
@@ -837,7 +926,7 @@ class DNNRegressorTest(tf.test.TestCase):
regressor = tf.contrib.learn.DNNRegressor(
feature_columns=feature_columns,
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
regressor.fit(input_fn=_input_fn, steps=200)
@@ -918,7 +1007,7 @@ class DNNRegressorTest(tf.test.TestCase):
model_dir=model_dir,
feature_columns=feature_columns,
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
regressor.fit(input_fn=_input_fn, steps=5)
predict_input_fn = functools.partial(_input_fn, num_epochs=1)
@@ -929,7 +1018,7 @@ class DNNRegressorTest(tf.test.TestCase):
model_dir=model_dir,
feature_columns=feature_columns,
hidden_units=[3, 3],
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
predictions2 = list(regressor2.predict(input_fn=predict_input_fn))
self.assertAllClose(predictions, predictions2)
@@ -1004,7 +1093,7 @@ class DNNRegressorTest(tf.test.TestCase):
feature_columns=feature_columns,
hidden_units=[3, 3],
enable_centered_bias=True,
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
regressor.fit(input_fn=_input_fn, steps=5)
self.assertIn('centered_bias_weight', regressor.get_variable_names())
@@ -1037,7 +1126,7 @@ class DNNRegressorTest(tf.test.TestCase):
feature_columns=feature_columns,
hidden_units=[3, 3],
enable_centered_bias=False,
- config=tf.contrib.learn.RunConfig(tf_random_seed=3))
+ config=tf.contrib.learn.RunConfig(tf_random_seed=1))
regressor.fit(input_fn=_input_fn, steps=5)
self.assertNotIn('centered_bias_weight', regressor.get_variable_names())
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index cdd8300e2e..be28f07a7f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -59,25 +59,6 @@ _CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell,
'gru': rnn_cell.GRUCell,}
-# TODO(jamieas): move `padding_mask` to array_ops.
-def padding_mask(sequence_lengths, padded_length):
- """Creates a mask used for calculating losses with padded input.
-
- Args:
- sequence_lengths: A `Tensor` of shape `[batch_size]` containing the unpadded
- length of each sequence.
- padded_length: A scalar `Tensor` indicating the length of the sequences
- after padding
- Returns:
- A boolean `Tensor` M of shape `[batch_size, padded_length]` where
- `M[i, j] == True` when `lengths[i] > j`.
-
- """
- range_tensor = math_ops.range(padded_length)
- return math_ops.less(array_ops.expand_dims(range_tensor, 0),
- array_ops.expand_dims(sequence_lengths, 1))
-
-
def mask_activations_and_labels(activations, labels, sequence_lengths):
"""Remove entries outside `sequence_lengths` and returned flattened results.
@@ -89,7 +70,7 @@ def mask_activations_and_labels(activations, labels, sequence_lengths):
Returns:
activations_masked: `logit` values with those beyond `sequence_lengths`
- removed for each batch. Batches are then concatenated. Shape
+ removed for each batch. Batches are then concatenated. Shape
`[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and
shape `[batch_size * padded_length, k]` otherwise.
labels_masked: Label values after removing unneeded entries. Shape
@@ -107,7 +88,7 @@ def mask_activations_and_labels(activations, labels, sequence_lengths):
[flattened_dimension, -1])
labels_masked = array_ops.reshape(labels, [flattened_dimension])
else:
- mask = padding_mask(sequence_lengths, padded_length)
+ mask = array_ops.sequence_mask(sequence_lengths, padded_length)
activations_masked = array_ops.boolean_mask(activations, mask)
labels_masked = array_ops.boolean_mask(labels, mask)
return activations_masked, labels_masked
@@ -236,7 +217,7 @@ def construct_rnn(initial_state,
num_label_columns,
dtype=dtypes.float32,
parallel_iterations=32,
- swap_memory=False):
+ swap_memory=True):
"""Build an RNN and apply a fully connected layer to get the desired output.
Args:
@@ -273,6 +254,9 @@ def construct_rnn(initial_state,
num_outputs=num_label_columns,
activation_fn=None,
trainable=True)
+  # Use `identity` to rename `final_state`.
+ final_state = array_ops.identity(
+ final_state, name=RNNKeys.FINAL_STATE_KEY)
return activations, final_state
@@ -371,13 +355,15 @@ def _multi_value_predictions(
probability_shape = array_ops.concat(0, [activations_shape[:2], [2]])
else:
probability_shape = activations_shape
- probabilities = array_ops.reshape(flat_probabilities, probability_shape)
+ probabilities = array_ops.reshape(
+ flat_probabilities, probability_shape, name=RNNKeys.PROBABILITIES_KEY)
prediction_dict[RNNKeys.PROBABILITIES_KEY] = probabilities
else:
flat_predictions = target_column.logits_to_predictions(
flattened_activations, proba=False)
predictions = array_ops.reshape(
- flat_predictions, [activations_shape[0], activations_shape[1]])
+ flat_predictions, [activations_shape[0], activations_shape[1]],
+ name=RNNKeys.PREDICTIONS_KEY)
prediction_dict[RNNKeys.PREDICTIONS_KEY] = predictions
return prediction_dict
@@ -474,7 +460,7 @@ def apply_dropout(
cell: An `RNNCell`.
input_keep_probability: Probability to keep inputs to `cell`. If `None`,
no dropout is applied.
- output_keep_probability: Probability to keep outputs to `cell`. If `None`,
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
no dropout is applied.
random_seed: Seed for random dropout.
@@ -509,13 +495,12 @@ def _get_dynamic_rnn_model_fn(cell,
initial_state_key=RNNKeys.INITIAL_STATE_KEY,
dtype=dtypes.float32,
parallel_iterations=None,
- swap_memory=False,
+ swap_memory=True,
name='DynamicRNNModel'):
"""Creates an RNN model function for an `Estimator`.
Args:
cell: An initialized `RNNCell` to be used in the RNN.
- 'basic_rnn,' 'lstm' or 'gru'.
target_column: An initialized `TargetColumn`, used to calculate prediction
and loss.
problem_type: `ProblemType.CLASSIFICATION` or`ProblemType.REGRESSION`.
@@ -527,23 +512,23 @@ def _get_dynamic_rnn_model_fn(cell,
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
context_feature_columns: An iterable containing all the feature columns
- describing context features i.e. features that apply accross all time
+      describing context features, i.e., features that apply across all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
predict_probabilities: A boolean indicating whether to predict probabilities
- for all classes. Should only be used with `ProblemType.CLASSIFICATION`.
+ for all classes. Must only be used with `ProblemType.CLASSIFICATION`.
learning_rate: Learning rate used for optimization. This argument has no
effect if `optimizer` is an instance of an `Optimizer`.
gradient_clipping_norm: A float. Gradients will be clipped to this value.
input_keep_probability: Probability to keep inputs to `cell`. If `None`,
no dropout is applied.
- output_keep_probability: Probability to keep outputs to `cell`. If `None`,
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
no dropout is applied.
sequence_length_key: The key that will be used to look up sequence length in
the `features` dict.
initial_state_key: The key that will be used to look up initial_state in
the `features` dict.
- dtype: The dtype of the state and output for the given `cell_num`
+ dtype: The dtype of the state and output of the given `cell`.
parallel_iterations: Number of iterations to run in parallel. Values >> 1
use more memory but take less time, while smaller values use less memory
but computations take longer.
@@ -601,30 +586,41 @@ def _get_dynamic_rnn_model_fn(cell,
dtype=dtype,
parallel_iterations=parallel_iterations,
swap_memory=swap_memory)
+
+ loss = None # Created below for modes TRAIN and EVAL.
if prediction_type == PredictionType.MULTIPLE_VALUE:
prediction_dict = _multi_value_predictions(
rnn_activations, target_column, predict_probabilities)
- loss = _multi_value_loss(
- rnn_activations, labels, sequence_length, target_column, features)
+ if mode != model_fn.ModeKeys.INFER:
+ loss = _multi_value_loss(
+ rnn_activations, labels, sequence_length, target_column, features)
elif prediction_type == PredictionType.SINGLE_VALUE:
prediction_dict = _single_value_predictions(
rnn_activations, sequence_length, target_column,
predict_probabilities)
- loss = _single_value_loss(
- rnn_activations, labels, sequence_length, target_column, features)
- # TODO(roumposg): Return eval_metric_ops here, instead of default_metrics.
- default_metrics = _get_default_metrics(
- problem_type, prediction_type, sequence_length)
+ if mode != model_fn.ModeKeys.INFER:
+ loss = _single_value_loss(
+ rnn_activations, labels, sequence_length, target_column, features)
prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state
- eval_metric_ops = estimator._make_metrics_ops( # pylint: disable=protected-access
- default_metrics, features, labels, prediction_dict)
- train_op = optimizers.optimize_loss(
- loss=loss,
- global_step=None,
- learning_rate=learning_rate,
- optimizer=optimizer,
- clip_gradients=gradient_clipping_norm,
- summaries=optimizers.OPTIMIZER_SUMMARIES)
+
+ eval_metric_ops = None
+ if mode != model_fn.ModeKeys.INFER:
+ # TODO(roumposg): Return eval_metric_ops instead of default_metrics.
+ default_metrics = _get_default_metrics(
+ problem_type, prediction_type, sequence_length)
+ eval_metric_ops = estimator._make_metrics_ops( # pylint: disable=protected-access
+ default_metrics, features, labels, prediction_dict)
+
+ train_op = None
+ if mode == model_fn.ModeKeys.TRAIN:
+ train_op = optimizers.optimize_loss(
+ loss=loss,
+ global_step=None, # Get it internally.
+ learning_rate=learning_rate,
+ optimizer=optimizer,
+ clip_gradients=gradient_clipping_norm,
+ summaries=optimizers.OPTIMIZER_SUMMARIES)
+
return model_fn.ModelFnOps(mode=mode,
predictions=prediction_dict,
loss=loss,
@@ -674,43 +670,43 @@ def multi_value_rnn_regressor(num_units,
optimizer_type='SGD',
learning_rate=0.1,
momentum=None,
- gradient_clipping_norm=10.0,
+ gradient_clipping_norm=5.0,
input_keep_probability=None,
output_keep_probability=None,
model_dir=None,
config=None,
- params=None,
feature_engineering_fn=None):
-
"""Creates a RNN `Estimator` that predicts sequences of values.
Args:
- num_units: The size of the RNN cells.
+ num_units: The size of the RNN cells. This argument has no effect
+ if `cell_type` is an instance of `RNNCell`.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
context_feature_columns: An iterable containing all the feature columns
- describing context features i.e. features that apply accross all time
+      describing context features, i.e., features that apply across all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
- cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of
+ cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
'basic_rnn,' 'lstm' or 'gru'.
- num_rnn_layers: Number of RNN layers.
+      num_rnn_layers: Number of RNN layers. Leave this at its default value
+        of 1 if passing a `cell_type` that is already a `MultiRNNCell`.
optimizer_type: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer` or a string. Strings must be
one of 'Adagrad', 'Momentum' or 'SGD'.
- learning_rate: Learning rate.
+ learning_rate: Learning rate. This argument has no effect if `optimizer`
+ is an instance of an `Optimizer`.
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
input_keep_probability: Probability to keep inputs to `cell`. If `None`,
no dropout is applied.
- output_keep_probability: Probability to keep outputs to `cell`. If `None`,
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
no dropout is applied.
- model_dir: Directory to use for The directory in which to save and restore
- the model graph, parameters, etc.
+ model_dir: The directory in which to save and restore the model graph,
+ parameters, etc.
config: A `RunConfig` instance.
- params: `dict` of hyperparameters. Passed through to `Estimator`.
feature_engineering_fn: Takes features and labels which are the output of
`input_fn` and returns features and labels which will be fed into
`model_fn`. Please check `model_fn` for a definition of features and
@@ -739,7 +735,6 @@ def multi_value_rnn_regressor(num_units,
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
model_dir=model_dir,
config=config,
- params=params,
feature_engineering_fn=feature_engineering_fn)
@@ -754,32 +749,34 @@ def multi_value_rnn_classifier(num_classes,
learning_rate=0.1,
predict_probabilities=False,
momentum=None,
- gradient_clipping_norm=10.0,
+ gradient_clipping_norm=5.0,
input_keep_probability=None,
output_keep_probability=None,
model_dir=None,
config=None,
- params=None,
feature_engineering_fn=None):
"""Creates a RNN `Estimator` that predicts sequences of labels.
Args:
num_classes: The number of classes for categorization.
- num_units: The size of the RNN cells.
+ num_units: The size of the RNN cells. This argument has no effect
+ if `cell_type` is an instance of `RNNCell`.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
context_feature_columns: An iterable containing all the feature columns
- describing context features i.e. features that apply accross all time
+      describing context features, i.e., features that apply across all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
     cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
'basic_rnn,' 'lstm' or 'gru'.
- num_rnn_layers: Number of RNN layers.
+      num_rnn_layers: Number of RNN layers. Leave this at its default value
+        of 1 if passing a `cell_type` that is already a `MultiRNNCell`.
optimizer_type: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer` or a string. Strings must be
one of 'Adagrad', 'Momentum' or 'SGD'.
- learning_rate: Learning rate.
+ learning_rate: Learning rate. This argument has no effect if `optimizer`
+ is an instance of an `Optimizer`.
predict_probabilities: A boolean indicating whether to predict probabilities
for all classes.
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
@@ -787,12 +784,11 @@ def multi_value_rnn_classifier(num_classes,
then no clipping is performed.
input_keep_probability: Probability to keep inputs to `cell`. If `None`,
no dropout is applied.
- output_keep_probability: Probability to keep outputs to `cell`. If `None`,
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
no dropout is applied.
- model_dir: Directory to use for The directory in which to save and restore
- the model graph, parameters, etc.
+ model_dir: The directory in which to save and restore the model graph,
+ parameters, etc.
config: A `RunConfig` instance.
- params: `dict` of hyperparameters. Passed through to `Estimator`.
feature_engineering_fn: Takes features and labels which are the output of
`input_fn` and returns features and labels which will be fed into
`model_fn`. Please check `model_fn` for a definition of features and
@@ -822,7 +818,6 @@ def multi_value_rnn_classifier(num_classes,
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
model_dir=model_dir,
config=config,
- params=params,
feature_engineering_fn=feature_engineering_fn)
@@ -835,42 +830,43 @@ def single_value_rnn_regressor(num_units,
optimizer_type='SGD',
learning_rate=0.1,
momentum=None,
- gradient_clipping_norm=10.0,
+ gradient_clipping_norm=5.0,
input_keep_probability=None,
output_keep_probability=None,
model_dir=None,
config=None,
- params=None,
feature_engineering_fn=None):
"""Create a RNN `Estimator` that predicts single values.
Args:
- num_units: The size of the RNN cells.
+ num_units: The size of the RNN cells. This argument has no effect
+ if `cell_type` is an instance of `RNNCell`.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
context_feature_columns: An iterable containing all the feature columns
- describing context features i.e. features that apply accross all time
+      describing context features, i.e., features that apply across all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
     cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
'basic_rnn,' 'lstm' or 'gru'.
- num_rnn_layers: Number of RNN layers.
+      num_rnn_layers: Number of RNN layers. Leave this at its default value
+        of 1 if passing a `cell_type` that is already a `MultiRNNCell`.
optimizer_type: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer` or a string. Strings must be
one of 'Adagrad', 'Momentum' or 'SGD'.
- learning_rate: Learning rate.
+ learning_rate: Learning rate. This argument has no effect if `optimizer`
+ is an instance of an `Optimizer`.
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
then no clipping is performed.
input_keep_probability: Probability to keep inputs to `cell`. If `None`,
no dropout is applied.
- output_keep_probability: Probability to keep outputs to `cell`. If `None`,
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
no dropout is applied.
- model_dir: Directory to use for The directory in which to save and restore
- the model graph, parameters, etc.
+ model_dir: The directory in which to save and restore the model graph,
+ parameters, etc.
config: A `RunConfig` instance.
- params: `dict` of hyperparameters. Passed through to `Estimator`.
feature_engineering_fn: Takes features and labels which are the output of
`input_fn` and returns features and labels which will be fed into
`model_fn`. Please check `model_fn` for a definition of features and
@@ -899,7 +895,6 @@ def single_value_rnn_regressor(num_units,
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
model_dir=model_dir,
config=config,
- params=params,
feature_engineering_fn=feature_engineering_fn)
@@ -914,32 +909,34 @@ def single_value_rnn_classifier(num_classes,
learning_rate=0.1,
predict_probabilities=False,
momentum=None,
- gradient_clipping_norm=10.0,
+ gradient_clipping_norm=5.0,
input_keep_probability=None,
output_keep_probability=None,
model_dir=None,
config=None,
- params=None,
feature_engineering_fn=None):
"""Creates a RNN `Estimator` that predicts single labels.
Args:
num_classes: The number of classes for categorization.
- num_units: The size of the RNN cells.
+ num_units: The size of the RNN cells. This argument has no effect
+ if `cell_type` is an instance of `RNNCell`.
sequence_feature_columns: An iterable containing all the feature columns
describing sequence features. All items in the set should be instances
of classes derived from `FeatureColumn`.
context_feature_columns: An iterable containing all the feature columns
- describing context features i.e. features that apply accross all time
+      describing context features, i.e., features that apply across all time
steps. All items in the set should be instances of classes derived from
`FeatureColumn`.
     cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of
'basic_rnn,' 'lstm' or 'gru'.
- num_rnn_layers: Number of RNN layers.
+      num_rnn_layers: Number of RNN layers. Leave this at its default value
+        of 1 if passing a `cell_type` that is already a `MultiRNNCell`.
optimizer_type: The type of optimizer to use. Either a subclass of
`Optimizer`, an instance of an `Optimizer` or a string. Strings must be
one of 'Adagrad', 'Momentum' or 'SGD'.
- learning_rate: Learning rate.
+ learning_rate: Learning rate. This argument has no effect if `optimizer`
+ is an instance of an `Optimizer`.
predict_probabilities: A boolean indicating whether to predict probabilities
for all classes.
momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
@@ -947,12 +944,11 @@ def single_value_rnn_classifier(num_classes,
then no clipping is performed.
input_keep_probability: Probability to keep inputs to `cell`. If `None`,
no dropout is applied.
- output_keep_probability: Probability to keep outputs to `cell`. If `None`,
+ output_keep_probability: Probability to keep outputs of `cell`. If `None`,
no dropout is applied.
- model_dir: Directory to use for The directory in which to save and restore
- the model graph, parameters, etc.
+ model_dir: The directory in which to save and restore the model graph,
+ parameters, etc.
config: A `RunConfig` instance.
- params: `dict` of hyperparameters. Passed through to `Estimator`.
feature_engineering_fn: Takes features and labels which are the output of
`input_fn` and returns features and labels which will be fed into
`model_fn`. Please check `model_fn` for a definition of features and
@@ -982,5 +978,4 @@ def single_value_rnn_classifier(num_classes,
return estimator.Estimator(model_fn=dynamic_rnn_model_fn,
model_dir=model_dir,
config=config,
- params=params,
feature_engineering_fn=feature_engineering_fn)
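The deleted `padding_mask` helper is subsumed by `array_ops.sequence_mask` (exposed as `tf.sequence_mask`), which produces the same boolean matrix `M` with `M[i, j] == True` exactly when `j < sequence_lengths[i]`. A quick illustration:

    import tensorflow as tf

    lengths = tf.constant([1, 3, 2])
    mask = tf.sequence_mask(lengths, maxlen=3)
    # mask evaluates to:
    # [[ True, False, False],
    #  [ True,  True,  True],
    #  [ True,  True, False]]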
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index a2df6de6fd..f534789270 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import tempfile
+
import numpy as np
import tensorflow as tf
@@ -69,17 +71,6 @@ class MockTargetColumn(object):
self._num_label_columns = n
-class MockOptimizer(object):
-
- def compute_gradients(self, loss, var_list):
- raise NotImplementedError(
- 'MockOptimizer.compute_gradients called unexpectedly.')
-
- def apply_gradients(self, processed_gradients, global_step):
- raise NotImplementedError(
- 'MockOptimizer.apply_gradients called unexpectedly.')
-
-
def sequence_length_mask(values, lengths):
masked = values
for i, length in enumerate(lengths):
@@ -95,6 +86,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
'inputs', dimension=NUM_LABEL_COLUMNS)
def setUp(self):
+ super(DynamicRnnEstimatorTest, self).setUp()
self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
self.mock_target_column = MockTargetColumn(
num_label_columns=self.NUM_LABEL_COLUMNS)
@@ -112,7 +104,9 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
'measurements', dimension=2)
self.sequence_feature_columns = [measurements, wire_cast_embedded]
- self.columns_to_tensors = {
+ def GetColumnsToTensors(self):
+ """Get columns_to_tensors matching setUp(), in the current default graph."""
+ return {
'location': tf.SparseTensor(
indices=[[0, 0], [1, 0], [2, 0]],
values=['west_side', 'west_side', 'nyc'],
@@ -125,11 +119,16 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
b'omar', b'stringer', b'marlo',
b'marlo'],
shape=[3, 2, 2]),
- 'measurements': tf.random_uniform([3, 2, 2])}
+ 'measurements': tf.random_uniform([3, 2, 2], seed=4711)}
+
+ def GetClassificationTargetsOrNone(self, mode):
+ """Get targets matching setUp() and mode, in the current default graph."""
+ return (tf.random_uniform([3, 2, 1], 0, 2, dtype=tf.int64, seed=1412)
+ if mode != tf.contrib.learn.ModeKeys.INFER else None)
def testBuildSequenceInputInput(self):
sequence_input = dynamic_rnn_estimator.build_sequence_input(
- self.columns_to_tensors,
+ self.GetColumnsToTensors(),
self.sequence_feature_columns,
self.context_feature_columns)
with self.test_session() as sess:
@@ -146,7 +145,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
def testConstructRNN(self):
initial_state = None
sequence_input = dynamic_rnn_estimator.build_sequence_input(
- self.columns_to_tensors,
+ self.GetColumnsToTensors(),
self.sequence_feature_columns,
self.context_feature_columns)
activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn(
@@ -166,30 +165,6 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS])
self.assertAllEqual(expected_state_shape, final_state.shape)
- def testPaddingMask(self):
- """Test `padding_mask`."""
- batch_size = 16
- padded_length = 32
- np.random.seed(1234)
- sequence_lengths = np.random.randint(0, padded_length + 1, batch_size)
-
- padding_mask_t = dynamic_rnn_estimator.padding_mask(
- tf.constant(sequence_lengths, dtype=tf.int32),
- tf.constant(padded_length, dtype=tf.int32))
-
- with tf.Session() as sess:
- padding_mask = sess.run(padding_mask_t)
-
- for i in range(batch_size):
- actual_mask = padding_mask[i]
- expected_mask = np.concatenate(
- [np.ones(sequence_lengths[i]),
- np.zeros(padded_length - sequence_lengths[i])],
- axis=0)
- np.testing.assert_equal(actual_mask, expected_mask,
- 'Mismatch on row {}. Got {}; expected {}.'.format(
- i, actual_mask, expected_mask))
-
def testMaskActivationsAndLabels(self):
"""Test `mask_activations_and_labels`."""
batch_size = 4
@@ -275,9 +250,90 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
' Expected {}; got {}.'.format(i, expected_activations,
actual_activations))
+ # testGetDynamicRnnModelFn{Train,Eval,Infer}() test which fields
+ # of ModelFnOps are set depending on mode.
+ def testGetDynamicRnnModelFnTrain(self):
+ model_fn_ops = self._GetModelFnOpsForMode(tf.contrib.learn.ModeKeys.TRAIN)
+ self.assertIsNotNone(model_fn_ops.predictions)
+ self.assertIsNotNone(model_fn_ops.loss)
+ self.assertIsNotNone(model_fn_ops.train_op)
+ # None may get normalized to {}; we accept neither.
+ self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0)
+
+ def testGetDynamicRnnModelFnEval(self):
+ model_fn_ops = self._GetModelFnOpsForMode(tf.contrib.learn.ModeKeys.EVAL)
+ self.assertIsNotNone(model_fn_ops.predictions)
+ self.assertIsNotNone(model_fn_ops.loss)
+ self.assertIsNone(model_fn_ops.train_op)
+ # None may get normalized to {}; we accept neither.
+ self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0)
+
+ def testGetDynamicRnnModelFnInfer(self):
+ model_fn_ops = self._GetModelFnOpsForMode(tf.contrib.learn.ModeKeys.INFER)
+ self.assertIsNotNone(model_fn_ops.predictions)
+ self.assertIsNone(model_fn_ops.loss)
+ self.assertIsNone(model_fn_ops.train_op)
+ # None may get normalized to {}; we accept both.
+ self.assertFalse(model_fn_ops.eval_metric_ops)
+
+ def _GetModelFnOpsForMode(self, mode):
+ """Helper for testGetDynamicRnnModelFn{Train,Eval,Infer}()."""
+ model_fn = dynamic_rnn_estimator._get_dynamic_rnn_model_fn(
+ self.rnn_cell,
+ target_column=tf.contrib.layers.multi_class_target(n_classes=2),
+ # Only CLASSIFICATION yields eval metrics to test for.
+ problem_type=dynamic_rnn_estimator.ProblemType.CLASSIFICATION,
+ prediction_type=dynamic_rnn_estimator.PredictionType.MULTIPLE_VALUE,
+ optimizer='SGD',
+ sequence_feature_columns=self.sequence_feature_columns,
+ context_feature_columns=self.context_feature_columns,
+ learning_rate=0.1)
+ labels = self.GetClassificationTargetsOrNone(mode)
+ model_fn_ops = model_fn(features=self.GetColumnsToTensors(),
+ labels=labels, mode=mode)
+ return model_fn_ops
+
+ def testExport(self):
+ input_feature_key = 'magic_input_feature_key'
+ def get_input_fn(mode):
+ def input_fn():
+ features = self.GetColumnsToTensors()
+ if mode == tf.contrib.learn.ModeKeys.INFER:
+ input_examples = tf.placeholder(tf.string)
+ features[input_feature_key] = input_examples
+ # Real code would now parse features out of input_examples,
+ # but this test can just stick to the constants above.
+ return features, self.GetClassificationTargetsOrNone(mode)
+ return input_fn
+
+ model_dir = tempfile.mkdtemp()
+ def estimator_fn():
+ return dynamic_rnn_estimator.multi_value_rnn_classifier(
+ num_classes=2,
+ num_units=self.NUM_RNN_CELL_UNITS,
+ sequence_feature_columns=self.sequence_feature_columns,
+ context_feature_columns=self.context_feature_columns,
+ predict_probabilities=True,
+ model_dir=model_dir)
+
+ # Train a bit to create an exportable checkpoint.
+ estimator_fn().fit(
+ input_fn=get_input_fn(tf.contrib.learn.ModeKeys.TRAIN), steps=100)
+ # Now export, but from a fresh estimator instance, like you would
+ # in an export binary. That means .export() has to work without
+ # .fit() being called on the same object.
+ export_dir = tempfile.mkdtemp()
+ print('Exporting to', export_dir)
+ estimator_fn().export(
+ export_dir,
+ input_fn=get_input_fn(tf.contrib.learn.ModeKeys.INFER),
+ use_deprecated_input_fn=False,
+ input_feature_key=input_feature_key)
+
+
# TODO(jamieas): move all tests below to a benchmark test.
class DynamicRNNEstimatorLearningTest(tf.test.TestCase):
- """Learning tests for dymanic RNN Estimators."""
+ """Learning tests for dynamic RNN Estimators."""
def testLearnSineFunction(self):
"""Tests learning a sine function."""
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 5402c20297..ade126d6d1 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -38,6 +38,8 @@ from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
+from tensorflow.contrib.framework.python.framework import experimental
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import metric_spec
@@ -51,14 +53,21 @@ from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import export
-
+from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
+from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import builder as saved_model_builder
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import device_setter
from tensorflow.python.training import saver
+from tensorflow.python.util import compat
AS_ITERABLE_DATE = '2016-09-15'
@@ -553,13 +562,12 @@ class BaseEstimator(
use_deprecated_input_fn=use_deprecated_input_fn,
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
- # pylint: enable=protected-access
@abc.abstractproperty
def _get_train_ops(self, features, labels):
"""Method that builds model graph and returns trainer ops.
- Expected to be overriden by sub-classes that require custom support.
+ Expected to be overridden by sub-classes that require custom support.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
@@ -1126,6 +1134,106 @@ class Estimator(BaseEstimator):
self._labels_info[model_fn_lib.ModeKeys.INFER])
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)
+ @experimental
+ def export_savedmodel(
+ self, export_dir_base, input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+    """Exports the inference graph as a SavedModel into the given directory.
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ input_fn: A function that takes no argument and
+ returns an `InputFnOps`.
+      default_output_alternative_key: The name of the head to serve when
+        none is specified.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel. Each key should give the destination
+ path (including the filename) relative to the assets.extra directory.
+ The corresponding value gives the full path of the source file to be
+ copied. For example, the simple case of copying a single file without
+ renaming it is specified as
+ `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+      as_text: Whether to write the SavedModel proto in text format.
+ exports_to_keep: Number of exports to keep.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+      ValueError: If `input_fn` is not provided.
+ """
+ if input_fn is None:
+ raise ValueError('input_fn must be defined.')
+
+ with ops.Graph().as_default() as g:
+ contrib_variables.create_global_step(g)
+
+ # Call the input_fn and collect the input alternatives.
+ input_ops = input_fn()
+ input_alternatives, features = (
+ saved_model_export_utils.get_input_alternatives(input_ops))
+
+ # Call the model_fn and collect the output alternatives.
+ model_fn_ops = self._call_model_fn(features, None,
+ model_fn_lib.ModeKeys.INFER)
+ output_alternatives, actual_default_output_alternative_key = (
+ saved_model_export_utils.get_output_alternatives(
+ model_fn_ops, default_output_alternative_key))
+
+ # Build the SignatureDefs from all pairs of input and output signatures
+ signature_def_map = saved_model_export_utils.build_all_signature_defs(
+ input_alternatives, output_alternatives,
+ actual_default_output_alternative_key)
+
+ # Locate the latest checkpoint
+ # TODO(soergel): does it help that we know we have one from this step?
+ checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise NotFittedError("Couldn't find trained model at %s."
+ % self._model_dir)
+
+ export_dir = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+
+ with tf_session.Session('') as session:
+ variables.initialize_local_variables()
+ data_flow_ops.initialize_all_tables()
+ saver_for_restore = saver.Saver(
+ variables.global_variables(),
+ sharded=True)
+ saver_for_restore.restore(session, checkpoint_path)
+
+ init_op = control_flow_ops.group(
+ variables.local_variables_initializer(),
+ data_flow_ops.initialize_all_tables())
+
+ # Perform the export
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+ builder.add_meta_graph_and_variables(
+ session, [tag_constants.SERVING],
+ signature_def_map=signature_def_map,
+ assets_collection=ops.get_collection(
+ ops.GraphKeys.ASSET_FILEPATHS),
+ legacy_init_op=init_op)
+ builder.save(as_text)
+
+ # Add the extra assets
+ if assets_extra:
+ assets_extra_path = os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra'))
+ for dest_relative, source in assets_extra.items():
+ dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
+ compat.as_bytes(dest_relative))
+ dest_path = os.path.dirname(dest_absolute)
+ gfile.MakeDirs(dest_path)
+ gfile.Copy(source, dest_absolute)
+
+ return export_dir
+
# For time of deprecation x,y from Estimator allow direct access
# pylint: disable=protected-access
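For orientation, a minimal sketch of how the new `export_savedmodel` entry point is meant to be called (the estimator, feature columns, and paths below are illustrative, not part of this change):

    # Sketch only: assumes a fitted contrib.learn estimator `est` whose model
    # was built from `feature_columns`.
    from tensorflow.contrib.layers import create_feature_spec_for_parsing
    from tensorflow.contrib.learn.python.learn.utils import input_fn_utils

    feature_spec = create_feature_spec_for_parsing(feature_columns)
    serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
    export_dir = est.export_savedmodel(
        '/tmp/export_base',  # a timestamped subdirectory is created under this
        serving_input_fn,
        assets_extra={'some/sub/dir/extra.txt': '/path/to/extra.txt'},
        as_text=False)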
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index ae38e7a79e..a43b960a96 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
import functools
import itertools
import json
+import os
import tempfile
import numpy as np
@@ -33,6 +34,11 @@ from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
+from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.util import compat
_BOSTON_INPUT_DIM = 13
@@ -105,6 +111,8 @@ def linear_model_fn(features, labels, mode):
tf.contrib.learn.ModeKeys.TRAIN,
tf.contrib.learn.ModeKeys.EVAL,
tf.contrib.learn.ModeKeys.INFER)
+ if isinstance(features, dict):
+ (_, features), = features.items()
prediction, loss = (
tf.contrib.learn.models.linear_regression_zero_init(features, labels)
)
@@ -144,6 +152,45 @@ def logistic_model_no_mode_fn(features, labels):
learning_rate=0.1)
return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op
+VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
+EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'
+
+
+def _build_estimator_for_export_tests(tmpdir):
+ def _input_fn():
+ iris = tf.contrib.learn.datasets.load_iris()
+ return {
+ 'feature': tf.constant(iris.data, dtype=tf.float32)
+ }, tf.constant(iris.target, shape=[150], dtype=tf.int32)
+
+ feature_columns = [tf.contrib.layers.real_valued_column('feature',
+ dimension=4)]
+
+ est = tf.contrib.learn.LinearRegressor(feature_columns)
+ est.fit(input_fn=_input_fn, steps=20)
+
+ feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
+ feature_columns)
+ export_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
+
+  # Hack in an op that uses an asset, in order to test asset export.
+  # This is not actually valid, of course.
+ def export_input_fn_with_asset():
+ features, labels, inputs = export_input_fn()
+
+ vocab_file_name = os.path.join(tmpdir, 'my_vocab_file')
+ vocab_file = tf.gfile.GFile(vocab_file_name, mode='w')
+ vocab_file.write(VOCAB_FILE_CONTENT)
+ vocab_file.close()
+ hashtable = tf.contrib.lookup.HashTable(
+ tf.contrib.lookup.TextFileStringTableInitializer(vocab_file_name), 'x')
+ features['bogus_lookup'] = hashtable.lookup(
+ tf.to_int64(features['feature']))
+
+ return input_fn_utils.InputFnOps(features, labels, inputs)
+
+ return est, export_input_fn_with_asset
+
class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
@@ -585,6 +632,76 @@ class EstimatorTest(tf.test.TestCase):
self.assertEquals(expected, actual)
+ def test_export_savedmodel(self):
+ tmpdir = tempfile.mkdtemp()
+ est, export_input_fn = _build_estimator_for_export_tests(tmpdir)
+
+ extra_file_name = os.path.join(compat.as_bytes(tmpdir),
+ compat.as_bytes('my_extra_file'))
+ extra_file = tf.gfile.GFile(extra_file_name, mode='w')
+ extra_file.write(EXTRA_FILE_CONTENT)
+ extra_file.close()
+ assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
+
+ export_dir_base = os.path.join(compat.as_bytes(tmpdir),
+ compat.as_bytes('export'))
+ export_dir = est.export_savedmodel(export_dir_base, export_input_fn,
+ assets_extra=assets_extra)
+
+ self.assertTrue(tf.gfile.Exists(export_dir_base))
+ self.assertTrue(tf.gfile.Exists(export_dir))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('variables'))))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(tf.gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir), compat.as_bytes('assets'))))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets/my_vocab_file'))))
+ self.assertEqual(
+ compat.as_bytes(VOCAB_FILE_CONTENT),
+ compat.as_bytes(tf.gfile.GFile(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets/my_vocab_file'))).read()))
+
+ expected_extra_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra'))))
+ self.assertTrue(tf.gfile.Exists(expected_extra_path))
+ self.assertEqual(
+ compat.as_bytes(EXTRA_FILE_CONTENT),
+ compat.as_bytes(tf.gfile.GFile(expected_extra_path).read()))
+
+ expected_vocab_file = os.path.join(compat.as_bytes(tmpdir),
+ compat.as_bytes('my_vocab_file'))
+ # Restore, to validate that the export was well-formed.
+ with tf.Graph().as_default() as graph:
+ with tf.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ assets = [x.eval()
+ for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)]
+ self.assertItemsEqual([expected_vocab_file], assets)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertTrue('linear/linear/feature/matmul' in graph_ops)
+
+ # cleanup
+ tf.gfile.DeleteRecursively(tmpdir)
+
class InferRealValuedColumnsTest(tf.test.TestCase):
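The test above builds its serving input_fn with `build_parsing_serving_input_fn`; a hand-rolled equivalent under the same `InputFnOps` contract (features, labels, default inputs) would look roughly like the following sketch (the feature name and shape are illustrative):

    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.utils import input_fn_utils

    def serving_input_fn():
      # Serialized tf.Example protos arrive through a string placeholder.
      examples = tf.placeholder(dtype=tf.string, name='input_example_tensor')
      features = tf.parse_example(
          examples, {'feature': tf.FixedLenFeature([4], dtype=tf.float32)})
      # No labels at serving time; the placeholder is the default input.
      return input_fn_utils.InputFnOps(features, None, {'examples': examples})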
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 389b5b2b62..ad7a3f5a46 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -19,10 +19,12 @@ from __future__ import division
from __future__ import print_function
import abc
+import six
from tensorflow.contrib import losses
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn import metric_spec
+from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import model_fn
@@ -33,9 +35,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
-from tensorflow.python.ops import variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.training import training
@@ -64,8 +67,7 @@ def _regression_head(label_name=None,
Returns:
An instance of _Head
"""
- return _RegressionHead(loss_fn=_mean_squared_loss,
- label_name=label_name,
+ return _RegressionHead(label_name=label_name,
weight_column_name=weight_column_name,
label_dimension=label_dimension,
enable_centered_bias=enable_centered_bias,
@@ -198,6 +200,9 @@ class _Head(object):
"""
__metaclass__ = abc.ABCMeta
+ def __init__(self, head_name):
+ self._head_name = head_name
+
@abc.abstractproperty
def logits_dimension(self):
raise NotImplementedError("Calling an abstract method.")
@@ -215,8 +220,7 @@ class _Head(object):
optimize with the loss.
logits: logits to be used for the head.
logits_input: tensor to build logits from.
- scope: Optional scope for variable_scope. Only used by heads which create
- variables.
+ scope: Optional scope for variable_scope.
Returns:
`ModelFnOps`.
@@ -226,16 +230,48 @@ class _Head(object):
"""
raise NotImplementedError("Calling an abstract method.")
+ def _create_output_alternatives(self, predictions):
+  """Creates output alternatives for the Head.
+
+ Args:
+ predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
+ symbolic name for an output Tensor possibly but not necessarily taken
+ from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
+ itself.
+
+ Returns:
+ `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
+ 'submodel_name' is a submodel identifier that should be consistent across
+ the pipeline (here likely taken from the head_name),
+ 'problem_type' is a `ProblemType`,
+ 'tensor_name' is a symbolic name for an output Tensor possibly but not
+ necessarily taken from `PredictionKey`, and
+ 'Tensor' is the corresponding output Tensor itself.
+ """
+ return {self._head_name: (self._problem_type, predictions)}
+
+
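Concretely, a regression head named 'my_head' would hand back a structure like this (tensor repr illustrative):

    # {'my_head': (constants.ProblemType.LINEAR_REGRESSION,
    #              {'scores': <Tensor 'predictions/scores:0' shape=(?,)>})}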
+# TODO(zakaria): use contrib losses.
+def _mean_squared_loss(logits, labels):
+ with ops.name_scope(None, "mean_squared_loss", (logits, labels)) as name:
+ # To prevent broadcasting inside "-".
+ if len(labels.get_shape()) == 1:
+ labels = array_ops.expand_dims(labels, dim=(1,))
+ # TODO(zakaria): make sure it does not recreate the broadcast bug.
+ if len(logits.get_shape()) == 1:
+ logits = array_ops.expand_dims(logits, dim=(1,))
+ logits.get_shape().assert_is_compatible_with(labels.get_shape())
+ return math_ops.square(logits - math_ops.to_float(labels), name=name)
+
class _RegressionHead(_Head):
"""_Head for regression."""
- def __init__(self, loss_fn, label_name, weight_column_name, label_dimension,
- enable_centered_bias, head_name):
+ def __init__(self, label_name, weight_column_name, label_dimension,
+ enable_centered_bias, head_name, loss_fn=_mean_squared_loss):
"""Base type for all single heads.
Args:
- loss_fn: Loss function.
label_name: String, name of the key in label dict. Can be null if label
is a tensor (single headed models).
weight_column_name: A string defining feature column name representing
@@ -247,15 +283,16 @@ class _RegressionHead(_Head):
residual after centered bias.
head_name: name of the head. If provided, predictions, summary and metrics
keys will be prefixed by the head_name and an underscore.
+ loss_fn: Loss function.
"""
+ super(_RegressionHead, self).__init__(head_name=head_name)
+
self._loss_fn = loss_fn
self._logits_dimension = label_dimension
self._label_name = label_name
self._weight_column_name = weight_column_name
- self._head_name = head_name
self._enable_centered_bias = enable_centered_bias
- self._centered_bias_weight_collection = _head_prefixed(head_name,
- "centered_bias")
+ self._problem_type = constants.ProblemType.LINEAR_REGRESSION
@property
def logits_dimension(self):
@@ -266,16 +303,29 @@ class _RegressionHead(_Head):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
- predictions = self._predictions(logits)
- if (mode == model_fn.ModeKeys.INFER) or (labels is None):
- loss = None
- train_op = None
- eval_metric_ops = None
- else:
- loss = self._training_loss(features, labels, logits)
- train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
- else self._train_op(loss, labels, train_op_fn))
- eval_metric_ops = self._eval_metric_ops(features, labels, predictions)
+
+ centered_bias = None
+ if self._enable_centered_bias:
+ centered_bias = _centered_bias(self._logits_dimension)
+ logits = nn.bias_add(logits, centered_bias)
+
+ predictions = self._logits_to_predictions(logits)
+ loss = None
+ train_op = None
+ eval_metric_ops = None
+ if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
+ labels = _check_labels(labels, self._label_name)
+ loss = _training_loss(
+ features, labels, logits,
+ loss_fn=self._loss_fn,
+ weight_column_name=self._weight_column_name,
+ head_name=self._head_name)
+ if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
+ train_op = _train_op(
+ loss, labels, train_op_fn, centered_bias, self.logits_dimension,
+ self._loss_fn)
+ eval_metric_ops = _eval_metric_ops(
+ self._default_metrics(), features, labels, predictions)
return model_fn.ModelFnOps(
mode=mode,
@@ -283,79 +333,8 @@ class _RegressionHead(_Head):
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
- signature_fn=self._signature_fn())
-
- def _training_loss(self, features, labels, logits, name="training_loss"):
- """Returns training loss tensor for this head.
-
- Training loss is different from the loss reported on the tensorboard as we
- should respect the example weights when computing the gradient.
-
- L = sum_{i} w_{i} * l_{i} / B
-
- where B is the number of examples in the batch, l_{i}, w_{i} are individual
- losses, and example weight.
-
- Args:
- features: features dict.
- labels: either a tensor for labels or in multihead case, a dict of string
- to labels tensor.
- logits: logits, a float tensor.
- name: Op name.
-
- Returns:
- A loss `Tensor`.
- """
- labels = _check_labels(labels, self._label_name)
-
- if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
- self.logits_dimension,
- self._centered_bias_weight_collection))
-
- loss_unweighted = self._loss_fn(logits, labels)
- loss, weighted_average_loss = _loss(
- loss_unweighted,
- _weight_tensor(features, self._weight_column_name),
- name=name)
- summary.scalar(
- _head_prefixed(self._head_name, "loss"), weighted_average_loss)
- return loss
-
- def _train_op(self, loss, labels, train_op_fn):
- """Returns op for the training step."""
- train_op = train_op_fn(loss)
-
- if self._enable_centered_bias:
- centered_bias_step = [_centered_bias_step(
- self.logits_dimension,
- self._centered_bias_weight_collection,
- labels,
- self._loss_fn)]
- train_op = control_flow_ops.group(train_op, *centered_bias_step)
-
- return train_op
-
- def _eval_metric_ops(self, features, labels, predictions):
- """Returns a dict of metric ops keyed by name."""
- labels = _check_labels(labels, self._label_name)
- return estimator._make_metrics_ops( # pylint: disable=protected-access
- self._default_metrics(), features, labels, predictions)
-
- def _predictions(self, logits):
- """Returns a dict of predictions.
-
- Args:
- logits: logits `Tensor` before applying possible centered bias.
-
- Returns:
- Dict of prediction `Tensor` keyed by `PredictionKey`.
- """
- if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
- self.logits_dimension,
- self._centered_bias_weight_collection))
- return self._logits_to_predictions(logits)
+ signature_fn=self._signature_fn(),
+ output_alternatives=self._create_output_alternatives(predictions))
def _logits_to_predictions(self, logits):
"""Returns a dict of predictions.
@@ -366,13 +345,11 @@ class _RegressionHead(_Head):
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
"""
- predictions = {}
- if self.logits_dimension == 1:
- predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
- logits, squeeze_dims=[1], name=prediction_key.PredictionKey.SCORES)
- else:
- predictions[prediction_key.PredictionKey.SCORES] = logits
- return predictions
+ key = prediction_key.PredictionKey.SCORES
+ with ops.name_scope(None, "predictions", (logits,)):
+ if self.logits_dimension == 1:
+ logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key)
+ return {key: logits}
def _signature_fn(self):
"""Returns the signature_fn to be used in exporting."""
@@ -399,11 +376,17 @@ class _RegressionHead(_Head):
def _log_loss_with_two_classes(logits, labels):
- # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
- if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, dim=[1])
- return nn.sigmoid_cross_entropy_with_logits(
- logits, math_ops.to_float(labels))
+ with ops.name_scope(
+ None, "log_loss_with_two_classes", (logits, labels)) as name:
+ # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
+ if len(labels.get_shape()) == 1:
+ labels = array_ops.expand_dims(labels, dim=(1,))
+ return nn.sigmoid_cross_entropy_with_logits(
+ logits, math_ops.to_float(labels), name=name)
+
+
+def _one_class_to_two_class_logits(logits):
+ return array_ops.concat(1, (array_ops.zeros_like(logits), logits))
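This padding works because softmax([0, x]) = [1 - sigmoid(x), sigmoid(x)], so probabilities and argmax classes derived from the padded logits agree with the logistic prediction. A quick numeric check (sketch, NumPy only):

    import numpy as np
    x = 2.0
    two_class = np.array([0.0, x])
    probs = np.exp(two_class) / np.sum(np.exp(two_class))
    # probs[1] == 1 / (1 + np.exp(-x)) == sigmoid(x) ~= 0.8808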
class _BinaryLogisticHead(_Head):
@@ -430,34 +413,45 @@ class _BinaryLogisticHead(_Head):
Raises:
ValueError: if n_classes is invalid.
"""
- self._thresholds = thresholds if thresholds else [.5]
+ super(_BinaryLogisticHead, self).__init__(head_name=head_name)
+ self._thresholds = thresholds if thresholds else (.5,)
self._label_name = label_name
self._weight_column_name = weight_column_name
- self._head_name = head_name
self._loss_fn = loss_fn
self._enable_centered_bias = enable_centered_bias
- self._centered_bias_weight_collection = _head_prefixed(head_name,
- "centered_bias")
@property
def logits_dimension(self):
return 1
def head_ops(self, features, labels, mode, train_op_fn, logits=None,
- logits_input=None):
+ logits_input=None, scope=None):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
- predictions = self._predictions(logits)
- if (mode == model_fn.ModeKeys.INFER) or (labels is None):
- loss = None
- train_op = None
- eval_metric_ops = None
- else:
- loss = self._training_loss(features, labels, logits)
- train_op = (None if train_op_fn is None
- else self._train_op(loss, labels, train_op_fn))
- eval_metric_ops = self._eval_metric_ops(features, labels, predictions)
+
+ centered_bias = None
+ if self._enable_centered_bias:
+ centered_bias = _centered_bias(1)
+ logits = nn.bias_add(logits, centered_bias)
+
+ predictions = self._logits_to_predictions(logits)
+ loss = None
+ train_op = None
+ eval_metric_ops = None
+ if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
+ labels = _check_labels(labels, self._label_name)
+ loss = _training_loss(
+ features, labels, logits,
+ loss_fn=self._loss_fn,
+ weight_column_name=self._weight_column_name,
+ head_name=self._head_name)
+ if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
+ train_op = _train_op(
+ loss, labels, train_op_fn, centered_bias, self.logits_dimension,
+ self._loss_fn)
+ eval_metric_ops = _eval_metric_ops(
+ self._default_metrics(), features, labels, predictions)
return model_fn.ModelFnOps(
mode=mode,
@@ -467,78 +461,6 @@ class _BinaryLogisticHead(_Head):
eval_metric_ops=eval_metric_ops,
signature_fn=self._signature_fn())
- def _training_loss(self, features, labels, logits=None, name="training_loss"):
- """Returns training loss tensor for this head.
-
- Training loss is different from the loss reported on the tensorboard as we
- should respect the example weights when computing the gradient.
-
- L = sum_{i} w_{i} * l_{i} / B
-
- where B is the number of examples in the batch, l_{i}, w_{i} are individual
- losses, and example weight.
-
- Args:
- features: features dict.
- labels: either a tensor for labels or in multihead case, a dict of string
- to labels tensor.
- logits: logits, a float tensor.
- name: Op name.
-
- Returns:
- A loss `Output`.
- """
- labels = _check_labels(labels, self._label_name)
-
- if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
- self.logits_dimension,
- self._centered_bias_weight_collection))
-
- loss_unweighted = self._loss_fn(logits, labels)
- loss, weighted_average_loss = _loss(
- loss_unweighted,
- _weight_tensor(features, self._weight_column_name),
- name=name)
- summary.scalar(
- _head_prefixed(self._head_name, "loss"), weighted_average_loss)
- return loss
-
- def _train_op(self, loss, labels, train_op_fn):
- """Returns op for the training step."""
- train_op = train_op_fn(loss)
-
- if self._enable_centered_bias:
- centered_bias_step = [_centered_bias_step(
- self.logits_dimension,
- self._centered_bias_weight_collection,
- labels,
- self._loss_fn)]
- train_op = control_flow_ops.group(train_op, *centered_bias_step)
-
- return train_op
-
- def _eval_metric_ops(self, features, labels, predictions):
- """Returns a dict of metric ops keyed by name."""
- labels = _check_labels(labels, self._label_name)
- return estimator._make_metrics_ops( # pylint: disable=protected-access
- self._default_metrics(), features, labels, predictions)
-
- def _predictions(self, logits):
- """Returns a dict of predictions.
-
- Args:
- logits: logits `Output` before applying possible centered bias.
-
- Returns:
- Dict of prediction `Output` keyed by `PredictionKey`.
- """
- if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
- self.logits_dimension,
- self._centered_bias_weight_collection))
- return self._logits_to_predictions(logits)
-
def _logits_to_predictions(self, logits):
"""Returns a dict of predictions.
@@ -548,15 +470,18 @@ class _BinaryLogisticHead(_Head):
Returns:
Dict of prediction `Output` keyed by `PredictionKey`.
"""
- predictions = {prediction_key.PredictionKey.LOGITS: logits}
- predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
- logits)
- logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
- predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
- logits)
- predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
- logits, 1)
- return predictions
+ with ops.name_scope(None, "predictions", (logits,)):
+ two_class_logits = _one_class_to_two_class_logits(logits)
+ return {
+ prediction_key.PredictionKey.LOGITS: logits,
+ prediction_key.PredictionKey.LOGISTIC: math_ops.sigmoid(
+ logits, name=prediction_key.PredictionKey.LOGISTIC),
+ prediction_key.PredictionKey.PROBABILITIES: nn.softmax(
+ two_class_logits,
+ name=prediction_key.PredictionKey.PROBABILITIES),
+ prediction_key.PredictionKey.CLASSES: math_ops.argmax(
+ two_class_logits, 1, name=prediction_key.PredictionKey.CLASSES)
+ }
def _signature_fn(self):
"""Returns the signature_fn to be used in exporting."""
@@ -628,14 +553,17 @@ class _BinaryLogisticHead(_Head):
def _softmax_cross_entropy_loss(logits, labels):
- # Check that we got integer for classification.
- if not labels.dtype.is_integer:
- raise ValueError("Labels dtype should be integer "
- "Instead got %s." % labels.dtype)
- # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
- if len(labels.get_shape()) == 2:
- labels = array_ops.squeeze(labels, squeeze_dims=[1])
- return nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
+ with ops.name_scope(
+ None, "softmax_cross_entropy_loss", (logits, labels,)) as name:
+ # Check that we got integer for classification.
+ if not labels.dtype.is_integer:
+ raise ValueError("Labels dtype should be integer "
+ "Instead got %s." % labels.dtype)
+ # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
+ if len(labels.get_shape()) == 2:
+ labels = array_ops.squeeze(labels, squeeze_dims=(1,))
+ return nn.sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name=name)
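For reference, with logits (1., 0., 0.) and sparse label 2 (the `testMultiClass` case in head_test.py below), this loss is log(e^1 + e^0 + e^0) - logits[2] = log(4.71828) - 0 ≈ 1.5514446. A quick check (sketch):

    import numpy as np
    logits = np.array([1., 0., 0.])
    label = 2
    loss = np.log(np.sum(np.exp(logits))) - logits[label]  # ~= 1.5514446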
class _MultiClassHead(_Head):
@@ -665,18 +593,17 @@ class _MultiClassHead(_Head):
Raises:
ValueError: if n_classes is invalid.
"""
+ super(_MultiClassHead, self).__init__(head_name=head_name)
+
if (n_classes is None) or (n_classes <= 2):
raise ValueError("n_classes must be > 2: %s." % n_classes)
- self._thresholds = thresholds if thresholds else [.5]
-
+ self._thresholds = thresholds if thresholds else (.5,)
self._logits_dimension = n_classes
self._label_name = label_name
self._weight_column_name = weight_column_name
- self._head_name = head_name
self._loss_fn = loss_fn
self._enable_centered_bias = enable_centered_bias
- self._centered_bias_weight_collection = _head_prefixed(head_name,
- "centered_bias")
+ self._problem_type = constants.ProblemType.CLASSIFICATION
@property
def logits_dimension(self):
@@ -687,16 +614,29 @@ class _MultiClassHead(_Head):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
- predictions = self._predictions(logits)
- if (mode == model_fn.ModeKeys.INFER) or (labels is None):
- loss = None
- train_op = None
- eval_metric_ops = None
- else:
- loss = self._training_loss(features, labels, logits)
- train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
- else self._train_op(loss, labels, train_op_fn))
- eval_metric_ops = self._eval_metric_ops(features, labels, predictions)
+
+ centered_bias = None
+ if self._enable_centered_bias:
+ centered_bias = _centered_bias(self._logits_dimension)
+ logits = nn.bias_add(logits, centered_bias)
+
+ predictions = self._logits_to_predictions(logits)
+ loss = None
+ train_op = None
+ eval_metric_ops = None
+ if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
+ labels = _check_labels(labels, self._label_name)
+ loss = _training_loss(
+ features, labels, logits,
+ loss_fn=self._loss_fn,
+ weight_column_name=self._weight_column_name,
+ head_name=self._head_name)
+ if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None):
+ train_op = _train_op(
+ loss, labels, train_op_fn, centered_bias, self._logits_dimension,
+ self._loss_fn)
+ eval_metric_ops = _eval_metric_ops(
+ self._default_metrics(), features, labels, predictions)
return model_fn.ModelFnOps(
mode=mode,
@@ -704,79 +644,8 @@ class _MultiClassHead(_Head):
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
- signature_fn=self._signature_fn())
-
- def _training_loss(self, features, labels, logits=None, name="training_loss"):
- """Returns training loss tensor for this head.
-
- Training loss is different from the loss reported on the tensorboard as we
- should respect the example weights when computing the gradient.
-
- L = sum_{i} w_{i} * l_{i} / B
-
- where B is the number of examples in the batch, l_{i}, w_{i} are individual
- losses, and example weight.
-
- Args:
- features: features dict.
- labels: either a tensor for labels or in multihead case, a dict of string
- to labels tensor.
- logits: logits, a float tensor.
- name: Op name.
-
- Returns:
- A loss `Tensor`.
- """
- labels = _check_labels(labels, self._label_name)
-
- if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
- self.logits_dimension,
- self._centered_bias_weight_collection))
-
- loss_unweighted = self._loss_fn(logits, labels)
- loss, weighted_average_loss = _loss(
- loss_unweighted,
- _weight_tensor(features, self._weight_column_name),
- name=name)
- summary.scalar(
- _head_prefixed(self._head_name, "loss"), weighted_average_loss)
- return loss
-
- def _train_op(self, loss, labels, train_op_fn):
- """Returns op for the training step."""
- train_op = train_op_fn(loss)
-
- if self._enable_centered_bias:
- centered_bias_step = [_centered_bias_step(
- self.logits_dimension,
- self._centered_bias_weight_collection,
- labels,
- self._loss_fn)]
- train_op = control_flow_ops.group(train_op, *centered_bias_step)
-
- return train_op
-
- def _eval_metric_ops(self, features, labels, predictions):
- """Returns a dict of metric ops keyed by name."""
- labels = _check_labels(labels, self._label_name)
- return estimator._make_metrics_ops( # pylint: disable=protected-access
- self._default_metrics(), features, labels, predictions)
-
- def _predictions(self, logits):
- """Returns a dict of predictions.
-
- Args:
- logits: logits `Tensor` before applying possible centered bias.
-
- Returns:
- Dict of prediction `Tensor` keyed by `PredictionKey`.
- """
- if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
- self.logits_dimension,
- self._centered_bias_weight_collection))
- return self._logits_to_predictions(logits)
+ signature_fn=self._signature_fn(),
+ output_alternatives=self._create_output_alternatives(predictions))
def _logits_to_predictions(self, logits):
"""Returns a dict of predictions.
@@ -787,13 +656,14 @@ class _MultiClassHead(_Head):
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
"""
- predictions = {prediction_key.PredictionKey.LOGITS: logits}
- predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
- logits)
- predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
- logits, 1)
-
- return predictions
+ with ops.name_scope(None, "predictions", (logits,)):
+ return {
+ prediction_key.PredictionKey.LOGITS: logits,
+ prediction_key.PredictionKey.PROBABILITIES: nn.softmax(
+ logits, name=prediction_key.PredictionKey.PROBABILITIES),
+ prediction_key.PredictionKey.CLASSES: math_ops.argmax(
+ logits, 1, name=prediction_key.PredictionKey.CLASSES)
+ }
def _signature_fn(self):
"""Returns the signature_fn to be used in exporting."""
@@ -849,31 +719,32 @@ class _BinarySvmHead(_BinaryLogisticHead):
def __init__(self, label_name, weight_column_name, enable_centered_bias,
head_name, thresholds):
def _loss_fn(logits, labels):
- check_shape_op = control_flow_ops.Assert(
- math_ops.less_equal(array_ops.rank(labels), 2),
- ["labels shape should be either [batch_size, 1] or [batch_size]"])
- with ops.control_dependencies([check_shape_op]):
- labels = array_ops.reshape(
- labels, shape=[array_ops.shape(labels)[0], 1])
- return losses.hinge_loss(logits, labels)
+ with ops.name_scope(None, "hinge_loss", (logits, labels)) as name:
+ check_shape_op = control_flow_ops.Assert(
+ math_ops.less_equal(array_ops.rank(labels), 2),
+ ("labels shape should be either [batch_size, 1] or [batch_size]",))
+ with ops.control_dependencies((check_shape_op,)):
+ labels = array_ops.reshape(
+ labels, shape=(array_ops.shape(labels)[0], 1))
+ return losses.hinge_loss(logits, labels, scope=name)
super(_BinarySvmHead, self).__init__(
- loss_fn=_loss_fn,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
head_name=head_name,
+ loss_fn=_loss_fn,
thresholds=thresholds)
def _logits_to_predictions(self, logits):
"""See `_MultiClassHead`."""
- predictions = {}
- predictions[prediction_key.PredictionKey.LOGITS] = logits
- logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
- predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
- logits, 1, name=prediction_key.PredictionKey.CLASSES)
-
- return predictions
+ with ops.name_scope(None, "predictions", (logits,)):
+ return {
+ prediction_key.PredictionKey.LOGITS: logits,
+ prediction_key.PredictionKey.CLASSES: math_ops.argmax(
+ _one_class_to_two_class_logits(logits), 1,
+ name=prediction_key.PredictionKey.CLASSES)
+ }
def _default_metrics(self):
"""See `_MultiClassHead`."""
@@ -901,60 +772,62 @@ class _MultiLabelHead(_MultiClassHead):
thresholds):
super(_MultiLabelHead, self).__init__(
- loss_fn=_sigmoid_cross_entropy_loss,
n_classes=n_classes,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
head_name=head_name,
+ loss_fn=_sigmoid_cross_entropy_loss,
thresholds=thresholds)
def _logits_to_predictions(self, logits):
"""See `_MultiClassHead`."""
- predictions = {prediction_key.PredictionKey.LOGITS: logits}
- if self.logits_dimension == 1:
- predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
- logits, name=prediction_key.PredictionKey.LOGISTIC)
- logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
- predictions[
- prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
- logits, name=prediction_key.PredictionKey.PROBABILITIES)
- predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
- math_ops.greater(logits, 0),
- name=prediction_key.PredictionKey.CLASSES)
- return predictions
+ with ops.name_scope(None, "predictions", (logits,)):
+ return {
+ prediction_key.PredictionKey.LOGITS: logits,
+ prediction_key.PredictionKey.PROBABILITIES: math_ops.sigmoid(
+ logits, name=prediction_key.PredictionKey.PROBABILITIES),
+ prediction_key.PredictionKey.CLASSES: math_ops.to_int64(
+ math_ops.greater(logits, 0),
+ name=prediction_key.PredictionKey.CLASSES)
+ }
def _weighted_loss(loss, weight):
"""Returns cumulative weighted loss."""
- unweighted_loss = array_ops.reshape(loss, shape=(-1,))
- weighted_loss = math_ops.mul(unweighted_loss,
- array_ops.reshape(
- weight, shape=(-1,)))
- return weighted_loss
+ with ops.name_scope(None, "weighted_loss", (loss, weight)) as name:
+ unweighted_loss = array_ops.reshape(loss, shape=(-1,))
+ weighted_loss = math_ops.mul(unweighted_loss,
+ array_ops.reshape(
+ weight, shape=(-1,)),
+ name=name)
+ return weighted_loss
def _weight_tensor(features, weight_column_name):
if not weight_column_name:
return None
- else:
+ with ops.name_scope(
+ None, "weight_tensor", tuple(six.itervalues(features))) as name:
return array_ops.reshape(
math_ops.to_float(features[weight_column_name]),
- shape=(-1,))
+ shape=(-1,),
+ name=name)
def _loss(loss_unweighted, weight, name):
- """Returns loss."""
- if weight is None:
- loss = math_ops.reduce_mean(loss_unweighted, name=name)
- return loss, loss
- loss_weighted = _weighted_loss(loss_unweighted, weight)
- weighted_average_loss = math_ops.div(
- math_ops.reduce_sum(loss_weighted),
- math_ops.to_float(math_ops.reduce_sum(weight)),
- name="weighted_average_loss")
- loss = math_ops.reduce_mean(loss_weighted, name=name)
- return loss, weighted_average_loss
+ """Returns a tuple of (loss, weighted_average_loss)."""
+ with ops.name_scope(name, values=(loss_unweighted, weight)) as name_scope:
+ if weight is None:
+ loss = math_ops.reduce_mean(loss_unweighted, name=name_scope)
+ return loss, loss
+ loss_weighted = _weighted_loss(loss_unweighted, weight)
+ weighted_average_loss = math_ops.div(
+ math_ops.reduce_sum(loss_weighted),
+ math_ops.to_float(math_ops.reduce_sum(weight)),
+ name="weighted_average_loss")
+ loss = math_ops.reduce_mean(loss_weighted, name=name_scope)
+ return loss, weighted_average_loss
def _check_logits_input_not_supported(logits, logits_input):
@@ -971,63 +844,128 @@ def _check_mode_valid(mode):
raise ValueError("mode=%s unrecognized." % str(mode))
-def _centered_bias(logits_dimension, weight_collection):
- """Creates and returns centered bias."""
- centered_bias = variables.Variable(
- array_ops.zeros([logits_dimension]),
- collections=[weight_collection, ops.GraphKeys.GLOBAL_VARIABLES],
- name="centered_bias_weight")
+def _centered_bias(logits_dimension):
+  """Creates and returns the centered bias `Variable`.
+
+ Args:
+ logits_dimension: Last dimension of `logits`. Must be >= 1.
+
+ Returns:
+ Centered bias `Variable`.
- biases = array_ops.reshape(centered_bias, [-1])
- for cb in range(logits_dimension):
- summary.scalar("centered_bias_%d" % cb, biases[cb])
+ Raises:
+ ValueError: if `logits_dimension` is invalid.
+ """
+ if (logits_dimension is None) or (logits_dimension < 1):
+ raise ValueError("Invalid logits_dimension %s." % logits_dimension)
+ centered_bias = variable_scope.get_variable(
+ name="centered_bias_weight",
+ shape=(logits_dimension,),
+ initializer=init_ops.zeros_initializer,
+ trainable=True)
+ for dim in range(logits_dimension):
+ summary.scalar("centered_bias_%d" % dim, centered_bias[dim])
return centered_bias
-def _centered_bias_step(logits_dimension, weight_collection, labels, loss_fn):
+def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn):
"""Creates and returns training op for centered bias."""
- centered_bias = ops.get_collection(weight_collection)
- batch_size = array_ops.shape(labels)[0]
- logits = array_ops.reshape(
- array_ops.tile(centered_bias[0], [batch_size]),
- [batch_size, logits_dimension])
- with ops.name_scope(None, "centered_bias", (labels, logits)):
- centered_bias_loss = math_ops.reduce_mean(
- loss_fn(logits, labels), name="training_loss")
- # Learn central bias by an optimizer. 0.1 is a convervative lr for a
- # single variable.
- return training.AdagradOptimizer(0.1).minimize(
- centered_bias_loss, var_list=centered_bias)
+ if (logits_dimension is None) or (logits_dimension < 1):
+ raise ValueError("Invalid logits_dimension %s." % logits_dimension)
+ with ops.name_scope(None, "centered_bias_step", (labels,)) as name:
+ batch_size = array_ops.shape(labels)[0]
+ logits = array_ops.reshape(
+ array_ops.tile(centered_bias, (batch_size,)),
+ (batch_size, logits_dimension))
+ with ops.name_scope(None, "centered_bias", (labels, logits)):
+ centered_bias_loss = math_ops.reduce_mean(
+ loss_fn(logits, labels), name="training_loss")
+      # Learn centered bias with an optimizer. 0.1 is a conservative lr for a
+      # single variable.
+ return training.AdagradOptimizer(0.1).minimize(
+ centered_bias_loss, var_list=(centered_bias,), name=name)
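In effect the centered bias is a separately optimized intercept: this Adagrad step fits `centered_bias` alone against the labels, so it drifts toward whatever constant logit minimizes `loss_fn`, while the main training op fits the rest of the model to the residual. A sketch of the intuition for the unweighted squared-loss case (assumption, not part of the change):

    # Minimizing mean((bias - labels)**2) over a scalar bias alone gives
    # bias* = mean(labels); the main model then only has to fit
    # labels - bias*, i.e. the residual after centered bias.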
def _head_prefixed(head_name, val):
return "%s_%s" % (head_name, val) if head_name else val
-# TODO(zakaria): use contrib losses.
-def _mean_squared_loss(logits, labels):
- # To prevent broadcasting inside "-".
- if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, dim=[1])
- # TODO(zakaria): make sure it does not recreate the broadcast bug.
- if len(logits.get_shape()) == 1:
- logits = array_ops.expand_dims(logits, dim=[1])
- logits.get_shape().assert_is_compatible_with(labels.get_shape())
- return math_ops.square(logits - math_ops.to_float(labels))
+def _training_loss(
+ features, labels, logits, loss_fn, weight_column_name=None, head_name=None):
+ """Returns training loss tensor.
+
+ Training loss is different from the loss reported on the tensorboard as we
+ should respect the example weights when computing the gradient.
+
+ L = sum_{i} w_{i} * l_{i} / B
+
+  where B is the number of examples in the batch, and l_{i}, w_{i} are the
+  individual losses and example weights, respectively.
+
+ Args:
+ features: Features `dict`.
+ labels: Either a `Tensor` for labels or in multihead case, a `dict` of
+ string to `Tensor`.
+ logits: logits, a float `Tensor`. Shape is `(batch_size, logits_dimension)`.
+ loss_fn: Function taking `logits` and `labels`, and returning the raw
+ unweighted loss.
+ weight_column_name: Key for weights `Tensor` in `features`, if applicable.
+ head_name: Head name, used for summary.
+
+ Returns:
+    A loss `Tensor`.
+ """
+ with ops.name_scope(
+ None, "training_loss",
+ tuple(six.itervalues(features)) + (labels, logits)) as name:
+ loss, weighted_average_loss = _loss(
+ loss_fn(logits, labels),
+ _weight_tensor(features, weight_column_name),
+ name=name)
+ summary.scalar(_head_prefixed(head_name, "loss"), weighted_average_loss)
+ return loss
+
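A worked example of L = sum_{i} w_{i} * l_{i} / B, matching `testRegressionWithWeights` in head_test.py: with per-example squared losses (1., 0., 4.) and weights (2., 5., 0.),

    import numpy as np
    l = np.array([1., 0., 4.])  # per-example losses
    w = np.array([2., 5., 0.])  # example weights
    training_loss = np.sum(w * l) / len(l)        # 2/3, drives the gradient
    weighted_average = np.sum(w * l) / np.sum(w)  # 2/7, reported to summaries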
+
+def _train_op(
+ loss, labels, train_op_fn, centered_bias=None, logits_dimension=None,
+ loss_fn=None):
+ """Returns op for the training step."""
+ with ops.name_scope(None, "train_op", (loss, labels)):
+ train_op = train_op_fn(loss)
+ if centered_bias is not None:
+ centered_bias_step = _centered_bias_step(
+ centered_bias, logits_dimension, labels, loss_fn)
+ train_op = control_flow_ops.group(train_op, centered_bias_step)
+ return train_op
+
+
+def _eval_metric_ops(metrics, features, labels, predictions):
+ with ops.name_scope(
+ None, "metrics",
+ (tuple(six.itervalues(features)) +
+ (labels,) +
+ tuple(six.itervalues(predictions)))):
+ # pylint: disable=protected-access
+ return estimator._make_metrics_ops(metrics, features, labels, predictions)
+ # pylint: enable=protected-access
def _sigmoid_cross_entropy_loss(logits, labels):
- # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
- return nn.sigmoid_cross_entropy_with_logits(logits, math_ops.to_float(labels))
+ with ops.name_scope(
+ None, "sigmoid_cross_entropy_loss", (logits, labels)) as name:
+ # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
+ return nn.sigmoid_cross_entropy_with_logits(
+ logits, math_ops.to_float(labels), name=name)
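Numerically, for logits (1., 0., 0.) and labels (0, 0, 1) (the `testMultiLabel` case), the per-class losses are log(1 + e^1) ≈ 1.3132617 and log(2) ≈ 0.6931472 twice, with mean ≈ 0.89985204. A quick check using the numerically stable form (sketch):

    import numpy as np
    x = np.array([1., 0., 0.])
    z = np.array([0., 0., 1.])
    per_class = np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))
    print(per_class.mean())  # ~= 0.89985204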
def _float_weights_or_none(weights):
if weights is None:
return None
- return math_ops.to_float(weights)
+ with ops.name_scope(None, "float_weights", (weights,)) as name:
+ return math_ops.to_float(weights, name=name)
-def _weighted_average_loss_metric_spec(loss_fn, predictoin_key,
+def _weighted_average_loss_metric_spec(loss_fn, pred_key,
label_key, weight_key):
def _streaming_weighted_average_loss(predictions, labels, weights=None):
loss_unweighted = loss_fn(predictions, labels)
@@ -1038,7 +976,7 @@ def _weighted_average_loss_metric_spec(loss_fn, predictoin_key,
name="eval_loss")
return metrics_lib.streaming_mean(weighted_average_loss)
return metric_spec.MetricSpec(_streaming_weighted_average_loss,
- predictoin_key, label_key, weight_key)
+ pred_key, label_key, weight_key)
def _labels_streaming_mean(unused_predictions, labels, weights=None):
@@ -1070,7 +1008,7 @@ def _streaming_at_threshold(streaming_metrics_fn, threshold):
def _streaming_metrics(predictions, labels, weights=None):
precision_tensor, update_op = streaming_metrics_fn(
- predictions, labels=labels, thresholds=[threshold],
+ predictions, labels=labels, thresholds=(threshold,),
weights=_float_weights_or_none(weights))
return array_ops.squeeze(precision_tensor), update_op
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 440615371d..673fbaefbb 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -18,13 +18,38 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+import six
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+def _assert_variables(
+ test_case, expected_global=None, expected_model=None,
+ expected_trainable=None):
+ test_case.assertItemsEqual(
+ [] if expected_global is None else expected_global,
+ [k.name for k in tf.global_variables()])
+ test_case.assertItemsEqual(
+ [] if expected_model is None else expected_model,
+ [k.name for k in tf.model_variables()])
+ test_case.assertItemsEqual(
+ [] if expected_trainable is None else expected_trainable,
+ [k.name for k in tf.trainable_variables()])
+
+
+def _assert_no_variables(test_case):
+ _assert_variables(test_case, set([]), set([]), set([]))
+
+
class RegressionModelHeadTest(tf.test.TestCase):
+ def _assert_metrics(self, model_fn_ops):
+ self.assertItemsEqual((
+ "loss",
+ ), six.iterkeys(model_fn_ops.eval_metric_ops))
+
  # TODO(zakaria): test multilabel regression.
def testRegression(self):
head = head_lib._regression_head()
@@ -34,8 +59,15 @@ class RegressionModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=prediction)
+ self._assert_metrics(model_fn_ops)
+ _assert_no_variables(self)
self.assertAlmostEqual(5. / 3, sess.run(model_fn_ops.loss))
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.EVAL,
+ _noop_train_op, logits=prediction)
+ self.assertIsNone(model_fn_ops.train_op)
+
def testRegressionWithWeights(self):
head = head_lib._regression_head(
weight_column_name="label_weight")
@@ -46,6 +78,28 @@ class RegressionModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops(features, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=prediction)
+ self._assert_metrics(model_fn_ops)
+ _assert_no_variables(self)
+ self.assertAlmostEqual(2. / 3, sess.run(model_fn_ops.loss), places=3)
+
+ def testRegressionWithCenteredBias(self):
+ head = head_lib._regression_head(
+ weight_column_name="label_weight", enable_centered_bias=True)
+ with tf.Graph().as_default(), tf.Session() as sess:
+ features = {"label_weight": tf.constant([[2.], [5.], [0.]])}
+ prediction = tf.constant([[1.], [1.], [3.]])
+ labels = tf.constant([[0.], [1.], [1.]])
+ model_fn_ops = head.head_ops(features, labels,
+ tf.contrib.learn.ModeKeys.TRAIN,
+ _noop_train_op, logits=prediction)
+ self._assert_metrics(model_fn_ops)
+ _assert_variables(self, expected_global=(
+ "centered_bias_weight:0",
+ "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ ), expected_trainable=(
+ "centered_bias_weight:0",
+ ))
+ tf.global_variables_initializer().run()
self.assertAlmostEqual(2. / 3, sess.run(model_fn_ops.loss), places=3)
def testErrorInSparseTensorLabels(self):
@@ -64,6 +118,12 @@ class RegressionModelHeadTest(tf.test.TestCase):
class MultiLabelModelHeadTest(tf.test.TestCase):
+ def _assert_metrics(self, model_fn_ops):
+ self.assertItemsEqual((
+ "accuracy",
+ "loss",
+ ), six.iterkeys(model_fn_ops.eval_metric_ops))
+
def testMultiLabel(self):
head = head_lib._multi_label_head(n_classes=3)
with tf.Graph().as_default(), tf.Session() as sess:
@@ -72,8 +132,15 @@ class MultiLabelModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=logits)
+ self._assert_metrics(model_fn_ops)
+ _assert_no_variables(self)
self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss))
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.EVAL,
+ _noop_train_op, logits=logits)
+ self.assertIsNone(model_fn_ops.train_op)
+
def testMultiLabelWithWeight(self):
head = head_lib._multi_label_head(
n_classes=3, weight_column_name="label_weight")
@@ -84,11 +151,44 @@ class MultiLabelModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops(features, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=logits)
+ self._assert_metrics(model_fn_ops)
+ _assert_no_variables(self)
self.assertAlmostEqual(0.089985214, sess.run(model_fn_ops.loss))
+ def testMultiLabelWithCenteredBias(self):
+ head = head_lib._multi_label_head(n_classes=3, enable_centered_bias=True)
+ with tf.Graph().as_default(), tf.Session() as sess:
+ logits = tf.constant([[1., 0., 0.]])
+ labels = tf.constant([[0, 0, 1]])
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.TRAIN,
+ _noop_train_op, logits=logits)
+ self._assert_metrics(model_fn_ops)
+ _assert_variables(self, expected_global=(
+ "centered_bias_weight:0",
+ "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ ), expected_trainable=(
+ "centered_bias_weight:0",
+ ))
+ tf.global_variables_initializer().run()
+ self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss))
+
class MultiClassModelHeadTest(tf.test.TestCase):
+ def _assert_binary_metrics(self, model_fn_ops):
+ self.assertItemsEqual((
+ "accuracy",
+ "accuracy/baseline_label_mean",
+ "accuracy/threshold_0.500000_mean",
+ "auc",
+ "labels/actual_label_mean",
+ "labels/prediction_mean",
+ "loss",
+ "precision/positive_threshold_0.500000_mean",
+ "recall/positive_threshold_0.500000_mean",
+ ), six.iterkeys(model_fn_ops.eval_metric_ops))
+
def testBinaryClassification(self):
head = head_lib._multi_class_head(n_classes=2)
with tf.Graph().as_default(), tf.Session() as sess:
@@ -99,8 +199,14 @@ class MultiClassModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=logits)
+ self._assert_binary_metrics(model_fn_ops)
+ _assert_no_variables(self)
self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss),
delta=1e-6)
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.EVAL,
+ _noop_train_op, logits=logits)
+ self.assertIsNone(model_fn_ops.train_op)
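The expected 0.81326175 follows from the log loss z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)): with logit x = 1 for both examples, the losses are -log(sigmoid(1)) ≈ 0.3132617 for the label-1 example and -log(1 - sigmoid(1)) ≈ 1.3132617 for the label-0 example, whose mean is ≈ 0.81326175.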
def testErrorInSparseTensorLabels(self):
head = head_lib._multi_class_head(n_classes=2)
@@ -127,11 +233,41 @@ class MultiClassModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops(features, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=logits)
+ self._assert_binary_metrics(model_fn_ops)
+ _assert_no_variables(self)
self.assertAlmostEqual(.31326166 / 2, sess.run(model_fn_ops.loss),
delta=1e-6)
+ def testBinaryClassificationWithCenteredBias(self):
+ head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True)
+ with tf.Graph().as_default(), tf.Session() as sess:
+ logits = tf.constant([[1.], [1.]])
+ labels = tf.constant([[1.], [0.]])
+ # logloss: z:label, x:logit
+ # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.TRAIN,
+ _noop_train_op, logits=logits)
+ self._assert_binary_metrics(model_fn_ops)
+ _assert_variables(self, expected_global=(
+ "centered_bias_weight:0",
+ "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ ), expected_trainable=(
+ "centered_bias_weight:0",
+ ))
+ tf.global_variables_initializer().run()
+ self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss),
+ delta=1e-6)
+
+ def _assert_multi_class_metrics(self, model_fn_ops):
+ self.assertItemsEqual((
+ "accuracy",
+ "loss",
+ ), six.iterkeys(model_fn_ops.eval_metric_ops))
+
def testMultiClass(self):
- head = head_lib._multi_class_head(n_classes=3)
+ n_classes = 3
+ head = head_lib._multi_class_head(n_classes=n_classes)
with tf.Graph().as_default(), tf.Session() as sess:
logits = tf.constant([[1., 0., 0.]])
labels = tf.constant([2])
@@ -140,11 +276,18 @@ class MultiClassModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=logits)
+ self._assert_multi_class_metrics(model_fn_ops)
+ _assert_no_variables(self)
self.assertAlmostEqual(1.5514446, sess.run(model_fn_ops.loss))
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.EVAL,
+ _noop_train_op, logits=logits)
+ self.assertIsNone(model_fn_ops.train_op)
def testMultiClassWithWeight(self):
+ n_classes = 3
head = head_lib._multi_class_head(
- n_classes=3, weight_column_name="label_weight")
+ n_classes=n_classes, weight_column_name="label_weight")
with tf.Graph().as_default(), tf.Session() as sess:
features = {"label_weight": tf.constant([0.1])}
logits = tf.constant([[1., 0., 0.]])
@@ -154,6 +297,8 @@ class MultiClassModelHeadTest(tf.test.TestCase):
model_fn_ops = head.head_ops(features, labels,
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=logits)
+ self._assert_multi_class_metrics(model_fn_ops)
+ _assert_no_variables(self)
self.assertAlmostEqual(.15514446, sess.run(model_fn_ops.loss))
def testInvalidNClasses(self):
@@ -164,34 +309,73 @@ class MultiClassModelHeadTest(tf.test.TestCase):
class BinarySvmModelHeadTest(tf.test.TestCase):
+ def setUp(self):
+    # The prediction for the first example is on the correct side of the
+    # hyperplane (i.e., < 0) but within the [-1, 1] margin, so it incurs a 0.5
+    # loss. The second prediction is outside the margin, so it incurs no loss
+    # at all.
+ self._predictions = ((-0.5,), (1.2,))
+ self._labels = (0, 1)
+ self._expected_losses = (0.5, 0.0)
+
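These expected losses follow from the hinge loss max(0, 1 - y * logit) with labels mapped to y ∈ {-1, 1}: the first example gives max(0, 1 - (-1)·(-0.5)) = 0.5 and the second gives max(0, 1 - 1·1.2) = 0.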
+ def _assert_metrics(self, model_fn_ops):
+ self.assertItemsEqual((
+ "accuracy",
+ "loss",
+ ), six.iterkeys(model_fn_ops.eval_metric_ops))
+
def testBinarySVMDefaultWeights(self):
head = head_lib._binary_svm_head()
- predictions = tf.constant([[-0.5], [1.2]])
- labels = tf.constant([0, 1])
+ with tf.Graph().as_default(), tf.Session():
+ predictions = tf.constant(self._predictions)
+ labels = tf.constant(self._labels)
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.TRAIN,
+ _noop_train_op, logits=predictions)
+ self._assert_metrics(model_fn_ops)
+ _assert_no_variables(self)
+ self.assertAlmostEqual(
+ np.average(self._expected_losses), model_fn_ops.loss.eval())
+
model_fn_ops = head.head_ops({}, labels,
- tf.contrib.learn.ModeKeys.TRAIN,
+ tf.contrib.learn.ModeKeys.EVAL,
_noop_train_op, logits=predictions)
- # Prediction for first example is in the right side of the hyperplane (i.e.,
- # < 0) but it is within the [-1,1] margin. There is a 0.5 loss incurred by
- # this example. The 2nd prediction is outside the margin so it incurs no
- # loss at all. The overall (normalized) loss is therefore 0.5/(1+1) = 0.25.
- with tf.Session() as sess:
- self.assertAlmostEqual(0.25, sess.run(model_fn_ops.loss))
+ self.assertIsNone(model_fn_ops.train_op)
def testBinarySVMWithWeights(self):
- head = head_lib._binary_svm_head(
- weight_column_name="weights")
- predictions = tf.constant([[-0.7], [0.2]])
- labels = tf.constant([0, 1])
- features = {"weights": tf.constant([2.0, 10.0])}
- model_fn_ops = head.head_ops(features, labels,
- tf.contrib.learn.ModeKeys.TRAIN,
- _noop_train_op, logits=predictions)
- # Prediction for both examples are in the right side of the hyperplane but
- # within the margin. The (weighted) loss incurred is 2*0.3=0.6 and 10*0.8=8
- # respectively. The overall (normalized) loss is therefore 8.6/12.
- with tf.Session() as sess:
- self.assertAlmostEqual(8.6 / 2, sess.run(model_fn_ops.loss), places=3)
+ head = head_lib._binary_svm_head(weight_column_name="weights")
+ with tf.Graph().as_default(), tf.Session():
+ predictions = tf.constant(self._predictions)
+ labels = tf.constant(self._labels)
+ weights = (7.0, 11.0)
+ features = {"weights": tf.constant(weights)}
+ model_fn_ops = head.head_ops(features, labels,
+ tf.contrib.learn.ModeKeys.TRAIN,
+ _noop_train_op, logits=predictions)
+ self._assert_metrics(model_fn_ops)
+ _assert_no_variables(self)
+ self.assertAlmostEqual(
+ np.sum(np.multiply(weights, self._expected_losses)) / 2.0,
+ model_fn_ops.loss.eval())
+
+ def testBinarySVMWithCenteredBias(self):
+ head = head_lib._binary_svm_head(enable_centered_bias=True)
+ with tf.Graph().as_default(), tf.Session():
+ predictions = tf.constant(self._predictions)
+ labels = tf.constant(self._labels)
+ model_fn_ops = head.head_ops({}, labels,
+ tf.contrib.learn.ModeKeys.TRAIN,
+ _noop_train_op, logits=predictions)
+ self._assert_metrics(model_fn_ops)
+ _assert_variables(self, expected_global=(
+ "centered_bias_weight:0",
+ "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ ), expected_trainable=(
+ "centered_bias_weight:0",
+ ))
+ tf.global_variables_initializer().run()
+ self.assertAlmostEqual(
+ np.average(self._expected_losses), model_fn_ops.loss.eval())
def _noop_train_op(unused_loss):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index d043468654..0405eb0476 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -27,6 +27,7 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
@@ -519,6 +520,22 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
@deprecated("2016-10-30",
"This method will be removed after the deprecation date. "
@@ -761,6 +778,22 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
@deprecated("2016-10-30",
"This method will be removed after the deprecation date. "
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index 0c99de5dd1..50f0d2d75d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -533,7 +533,7 @@ class LinearClassifierTest(tf.test.TestCase):
classifier = tf.contrib.learn.LinearClassifier(
feature_columns=[age, language], enable_centered_bias=False)
classifier.fit(input_fn=input_fn, steps=100)
- self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
+ self.assertNotIn('centered_bias_weight', classifier.get_variable_names())
def testEnableCenteredBias(self):
"""Tests that we can disable centered bias."""
@@ -552,7 +552,7 @@ class LinearClassifierTest(tf.test.TestCase):
classifier = tf.contrib.learn.LinearClassifier(
feature_columns=[age, language], enable_centered_bias=True)
classifier.fit(input_fn=input_fn, steps=100)
- self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
+ self.assertIn('centered_bias_weight', classifier.get_variable_names())
def testTrainOptimizerWithL1Reg(self):
"""Tests l1 regularized model has higher loss."""
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
index 3f9351ce22..42f21bd196 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
@@ -49,13 +49,33 @@ class ModeKeys(object):
# TODO(roumposg): Pass output_signature_fn instead of signature_fn.
class ModelFnOps(collections.namedtuple(
'ModelFnOps',
- ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn'])):
+ ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn',
+ 'output_alternatives'])):
"""Ops returned from a model_fn."""
+ # TODO(soergel): remove signature_fn once sessionbundle export is deprecated.
+
def __new__(cls, mode, predictions=None, loss=None, train_op=None,
- eval_metric_ops=None, signature_fn=None):
+ eval_metric_ops=None, signature_fn=None,
+ output_alternatives=None):
"""Creates a validated `ModelFnOps` instance.
+ For a multi-headed model, the predictions dict here will contain the outputs
+ of all of the heads. However, at serving time, requests will be made
+ specifically for one or more heads, and the RPCs used for these requests may
+ differ by problem type (e.g., regression, classification, or other). The
+ purpose of the output_alternatives dict is to aid in exporting a SavedModel
+ from which such head-specific queries can be served. These
+ output_alternatives will be combined with input_alternatives (see
+ `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying
+ the valid requests that can be served from this model.
+
+ For a single-headed model, it is still advisable to provide
+ output_alternatives with a single entry, because this is how the problem
+ type is communicated for export and serving. If output_alternatives is not
+ given, the resulting SavedModel will support only one head of unspecified
+ type.
+
Args:
      mode: One of `ModeKeys`. Specifies if this is training, evaluation, or
prediction.
@@ -65,6 +85,14 @@ class ModelFnOps(collections.namedtuple(
eval_metric_ops: Dict of metric results keyed by name. The values of the
dict are the results of calling a metric function, such as `Tensor`.
signature_fn: The signature_fn used for exporting.
+ output_alternatives: a dict of
+ `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where
+ `submodel_name` is a submodel identifier that should be consistent
+ across the pipeline (here likely taken from the name of each `Head`,
+ for models that use them), `problem_type` is a `ProblemType`,
+ `tensor_name` is a symbolic name for an output Tensor possibly but not
+ necessarily taken from `PredictionKey`, and `Tensor` is the
+ corresponding output Tensor itself.
Returns:
A validated `ModelFnOps` object.
@@ -122,4 +150,5 @@ class ModelFnOps(collections.namedtuple(
raise ValueError('signature_fn is not callable.')
return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op,
- eval_metric_ops, signature_fn)
+ eval_metric_ops, signature_fn,
+ output_alternatives)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
index c2c41255c9..deb55efc9f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
@@ -352,3 +353,19 @@ class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable):
self._estimator._model_fn = orig_model_fn
# pylint: enable=protected-access
return result
+
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py
index eeee673c5a..a6e4e7b6a3 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/svm.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py
@@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function
import inspect
+import re
import tempfile
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import estimator
@@ -235,6 +237,22 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
def weights_(self):
values = {}
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 9c70cc8dea..edd363b728 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -924,9 +924,9 @@ class ExportMonitor(EveryN):
`None`).
input_feature_key: String key into the features dict returned by
`input_fn` that corresponds to the raw `Example` strings `Tensor` that
- the exported model will take as input. Can only be `None` if you're
- using a custom `signature_fn` that does not use the first arg
- (examples).
+ the exported model will take as input. Should be `None` if and only if
+ you're passing in a `signature_fn` that does not use the first arg
+ (`Tensor` of `Example` strings).
exports_to_keep: int, number of exports to keep.
signature_fn: Function that returns a default signature and a named
signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s
diff --git a/tensorflow/contrib/learn/python/learn/ops/array_ops.py b/tensorflow/contrib/learn/python/learn/ops/array_ops.py
index a04e91b830..9196a9b9ad 100644
--- a/tensorflow/contrib/learn/python/learn/ops/array_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/array_ops.py
@@ -21,25 +21,32 @@ from __future__ import print_function
from tensorflow.contrib.framework import deprecated
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_
from tensorflow.python.ops import math_ops
@deprecated('2016-12-01', 'Use `tf.one_hot` instead.')
-def one_hot_matrix(tensor_in, num_classes, on_value=1.0, off_value=0.0):
+def one_hot_matrix(tensor_in, num_classes, on_value=1.0, off_value=0.0,
+ name=None):
"""Encodes indices from given tensor as one-hot tensor.
TODO(ilblackdragon): Ideally implementation should be
part of TensorFlow with Eigen-native operation.
Args:
- tensor_in: Input tensor of shape [N1, N2].
+ tensor_in: Input `Tensor` of shape [N1, N2].
num_classes: Number of classes to expand index into.
- on_value: Tensor or float, value to fill-in given index.
- off_value: Tensor or float, value to fill-in everything else.
+ on_value: `Tensor` or float, value to fill-in given index.
+ off_value: `Tensor` or float, value to fill-in everything else.
+ name: Name of the op.
Returns:
- Tensor of shape [N1, N2, num_classes] with 1.0 for each id in original
+ `Tensor` of shape `[N1, N2, num_classes]` with 1.0 for each id in original
tensor.
"""
- return array_ops_.one_hot(
- math_ops.cast(tensor_in, dtypes.int64), num_classes, on_value, off_value)
+ with ops.name_scope(
+ name, 'one_hot_matrix',
+ [tensor_in, num_classes, on_value, off_value]) as name_scope:
+ return array_ops_.one_hot(
+ math_ops.cast(tensor_in, dtypes.int64), num_classes, on_value,
+ off_value, name=name_scope)
diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py
new file mode 100644
index 0000000000..dd4376f051
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/gc.py
@@ -0,0 +1,205 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""System for specifying garbage collection (GC) of path based data.
+
+This framework allows for GC of data specified by path names, for example files
+on disk. gc.Path objects each represent a single item stored at a path and may
+be a base directory,
+ /tmp/exports/0/...
+ /tmp/exports/1/...
+ ...
+or a fully qualified file,
+ /tmp/train-1.ckpt
+ /tmp/train-2.ckpt
+ ...
+
+A gc filter function takes and returns a list of gc.Path items. Filter
+functions are responsible for selecting Path items for preservation or deletion.
+Note that functions should always return a sorted list.
+
+For example,
+ base_dir = "/tmp"
+ # create the directories
+ for e in range(10):
+ os.mkdir("%s/%d" % (base_dir, e), 0o755)
+
+ # create a simple parser that pulls the export_version from the directory
+ def parser(path):
+ match = re.match("^" + base_dir + "/(\\d+)$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ path_list = gc.get_paths("/tmp", parser) # contains all ten Paths
+
+ every_fifth = gc.mod_export_version(5)
+ print(every_fifth(path_list)) # shows ["/tmp/0", "/tmp/5"]
+
+ largest_three = gc.largest_export_versions(3)
+ print(largest_three(path_list)) # shows ["/tmp/7", "/tmp/8", "/tmp/9"]
+
+ both = gc.union(every_fifth, largest_three)
+ print(both(path_list)) # shows ["/tmp/0", "/tmp/5",
+ # "/tmp/7", "/tmp/8", "/tmp/9"]
+ # delete everything not in 'both'
+ to_delete = gc.negation(both)
+ for p in to_delete(path_list):
+ gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2",
+ # "/tmp/3", "/tmp/4", "/tmp/6"
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import heapq
+import math
+import os
+
+from tensorflow.python.platform import gfile
+
+Path = collections.namedtuple('Path', 'path export_version')
+
+
+def largest_export_versions(n):
+ """Creates a filter that keeps the largest n export versions.
+
+ Args:
+ n: number of versions to keep.
+
+ Returns:
+ A filter function that keeps the n largest paths.
+ """
+ def keep(paths):
+ heap = []
+ for idx, path in enumerate(paths):
+ if path.export_version is not None:
+ heapq.heappush(heap, (path.export_version, idx))
+ keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
+ return sorted(keepers)
+
+ return keep
+
+
+def one_of_every_n_export_versions(n):
+ """Creates a filter that keeps one of every n export versions.
+
+ Args:
+ n: interval size.
+
+ Returns:
+ A filter function that keeps exactly one path from each interval
+ [0, n], (n, 2n], (2n, 3n], etc. If more than one path exists in an
+ interval, the largest is kept.
+ """
+ def keep(paths):
+ """A filter function that keeps exactly one out of every n paths."""
+
+ keeper_map = {} # map from interval to largest path seen in that interval
+ for p in paths:
+ if p.export_version is None:
+ # Skip missing export_versions.
+ continue
+ # Find the interval (with a special case to map export_version = 0 to
+ # interval 0).
+ interval = math.floor(
+ (p.export_version - 1) / n) if p.export_version else 0
+ existing = keeper_map.get(interval, None)
+ if (not existing) or (existing.export_version < p.export_version):
+ keeper_map[interval] = p
+ return sorted(keeper_map.values())
+
+ return keep
+
+
+def mod_export_version(n):
+ """Creates a filter that keeps every export that is a multiple of n.
+
+ Args:
+ n: step size.
+
+ Returns:
+ A filter function that keeps paths where export_version % n == 0.
+ """
+ def keep(paths):
+ keepers = []
+ for p in paths:
+ if p.export_version % n == 0:
+ keepers.append(p)
+ return sorted(keepers)
+ return keep
+
+
+def union(lf, rf):
+ """Creates a filter that keeps the union of two filters.
+
+ Args:
+ lf: first filter.
+ rf: second filter.
+
+ Returns:
+ A filter function that keeps the union of the paths kept by both filters.
+ """
+ def keep(paths):
+ l = set(lf(paths))
+ r = set(rf(paths))
+ return sorted(list(l|r))
+ return keep
+
+
+def negation(f):
+ """Negate a filter.
+
+ Args:
+ f: filter function to invert
+
+ Returns:
+ A filter function that returns the negation of f.
+ """
+ def keep(paths):
+ l = set(paths)
+ r = set(f(paths))
+ return sorted(list(l-r))
+ return keep
+
+
+def get_paths(base_dir, parser):
+ """Gets a list of Paths in a given directory.
+
+ Args:
+ base_dir: directory.
+ parser: a function which gets the raw Path and can augment it with
+ information such as the export_version, or ignore the path by returning
+ None. An example parser may extract the export version from a path
+ such as "/tmp/exports/100" an another may extract from a full file
+ name such as "/tmp/checkpoint-99.out".
+
+ Returns:
+ A list of Paths contained in the base directory with the parsing function
+ applied.
+ By default the following fields are populated:
+ - Path.path
+ The parsing function is responsible for populating:
+ - Path.export_version
+ """
+ raw_paths = gfile.ListDirectory(base_dir)
+ paths = []
+ for r in raw_paths:
+ p = parser(Path(os.path.join(base_dir, r), None))
+ if p:
+ paths.append(p)
+ return sorted(paths)
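As a concrete check of the interval logic in one_of_every_n_export_versions, this sketch mirrors the case exercised in the test below: for n = 3, versions 0, 1, 3 land in [0, 3], 5 and 6 in (3, 6], 7 and 8 in (6, 9], and 33 in its own interval, so the largest version in each occupied interval survives:

    from tensorflow.contrib.learn.python.learn.utils import gc

    paths = [gc.Path('/foo', v) for v in (0, 1, 3, 5, 6, 7, 8, 33)]
    keep = gc.one_of_every_n_export_versions(3)
    print([p.export_version for p in keep(paths)])  # [3, 6, 8, 33]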
diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
new file mode 100644
index 0000000000..dbe3304f21
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
@@ -0,0 +1,120 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for learn.utils.gc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python.learn.utils import gc
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+
+
+def tearDownModule():
+ gfile.DeleteRecursively(tf.test.get_temp_dir())
+
+
+class GcTest(test_util.TensorFlowTestCase):
+
+ def testLargestExportVersions(self):
+ paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
+ newest = gc.largest_export_versions(2)
+ n = newest(paths)
+ self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
+
+ def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
+ newest = gc.largest_export_versions(2)
+ n = newest(paths)
+ self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
+
+ def testModExportVersion(self):
+ paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 9)]
+ mod = gc.mod_export_version(2)
+ self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
+ mod = gc.mod_export_version(3)
+ self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
+
+ def testOneOfEveryNExportVersions(self):
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3),
+ gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7),
+ gc.Path("/foo", 8), gc.Path("/foo", 33)]
+ one_of = gc.one_of_every_n_export_versions(3)
+ self.assertEquals(one_of(paths),
+ [gc.Path("/foo", 3), gc.Path("/foo", 6),
+ gc.Path("/foo", 8), gc.Path("/foo", 33)])
+
+ def testOneOfEveryNExportVersionsZero(self):
+ # Zero is a special case since it gets rolled into the first interval.
+ # Test that here.
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
+ one_of = gc.one_of_every_n_export_versions(3)
+ self.assertEquals(one_of(paths),
+ [gc.Path("/foo", 0), gc.Path("/foo", 5)])
+
+ def testUnion(self):
+ paths = []
+ for i in xrange(10):
+ paths.append(gc.Path("/foo", i))
+ f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
+ self.assertEquals(
+ f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
+ gc.Path("/foo", 6), gc.Path("/foo", 7),
+ gc.Path("/foo", 8), gc.Path("/foo", 9)])
+
+ def testNegation(self):
+ paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 9)]
+ mod = gc.negation(gc.mod_export_version(2))
+ self.assertEquals(
+ mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
+ mod = gc.negation(gc.mod_export_version(3))
+ self.assertEquals(
+ mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
+
+ def testPathsWithParse(self):
+ base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
+ self.assertFalse(gfile.Exists(base_dir))
+ for p in xrange(3):
+ gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
+ # add a base_directory to ignore
+ gfile.MakeDirs(os.path.join(base_dir, "ignore"))
+
+ # create a simple parser that pulls the export_version from the directory.
+ def parser(path):
+ match = re.match("^" + base_dir + "/(\\d+)$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ self.assertEquals(
+ gc.get_paths(base_dir, parser=parser),
+ [gc.Path(os.path.join(base_dir, "0"), 0),
+ gc.Path(os.path.join(base_dir, "1"), 1),
+ gc.Path(os.path.join(base_dir, "2"), 2)])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py
new file mode 100644
index 0000000000..2cb7173d5a
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py
@@ -0,0 +1,97 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for creating input_fns."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import parsing_ops
+
+
+# A return type allowing input_fns to return multiple values in a well-
+# defined way (analogous to ModelFnOps).
+# The expected return values are:
+# features: a dict of string to Tensor, giving the features to be passed to
+# the model.
+# labels: a dict of string to Tensor, giving labels (aka targets) for training.
+# default_inputs: a dict of string to Tensor, giving the input Tensors (if
+# any) that this input_fn expects to be fed.
+InputFnOps = collections.namedtuple('InputFnOps',
+ ['features',
+ 'labels',
+ 'default_inputs'])
+
+
+def build_parsing_serving_input_fn(feature_spec, default_batch_size=1):
+ """Build an input_fn appropriate for serving, expecting fed tf.Examples.
+
+ Creates an input_fn that expects a serialized tf.Example fed into a string
+ placeholder. The function parses the tf.Example according to the provided
+ feature_spec, and returns all parsed Tensors as features. This input_fn is
+ for use at serving time, so the labels return value is always None.
+
+ Args:
+ feature_spec: a dict of string to `VarLenFeature`/`FixedLenFeature`.
+ default_batch_size: the number of query examples expected per batch.
+
+ Returns:
+ An input_fn suitable for use in serving.
+ """
+ def input_fn():
+ """An input_fn that expects a serialized tf.Example."""
+ serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
+ shape=[default_batch_size],
+ name='input_example_tensor')
+ inputs = {'examples': serialized_tf_example}
+ features = parsing_ops.parse_example(serialized_tf_example, feature_spec)
+ labels = None # these are not known in serving!
+ return InputFnOps(features, labels, inputs)
+ return input_fn
+
+
+def build_default_serving_input_fn(features, default_batch_size=1):
+ """Build an input_fn appropriate for serving, expecting feature Tensors.
+
+ Creates an input_fn that expects all features to be fed directly.
+ This input_fn is for use at serving time, so the labels return value is always
+ None.
+
+ Args:
+ features: a dict of string to `Tensor`.
+ default_batch_size: the number of query examples expected per batch.
+
+ Returns:
+ An input_fn suitable for use in serving.
+ """
+ def input_fn():
+ """an input_fn that expects all features to be fed directly."""
+ features_placeholders = {}
+ for name, t in features.items():
+ shape_list = t.get_shape().as_list()
+ shape_list[0] = default_batch_size
+ shape = tensor_shape.TensorShape(shape_list)
+
+ features_placeholders[name] = array_ops.placeholder(dtype=t.dtype,
+ shape=shape,
+ name=t.name)
+ labels = None # these are not known in serving!
+ return InputFnOps(features_placeholders, labels, features_placeholders)
+ return input_fn
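A short usage sketch for the parsing variant; the feature names and spec are hypothetical:

    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.utils import input_fn_utils

    feature_spec = {'age': tf.FixedLenFeature([1], dtype=tf.float32),
                    'query': tf.VarLenFeature(dtype=tf.string)}
    serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(
        feature_spec)
    features, labels, default_inputs = serving_input_fn()
    # labels is None at serving time; default_inputs holds the single string
    # placeholder that must be fed with serialized tf.Example protos.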
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
new file mode 100644
index 0000000000..54bb0fb3d7
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -0,0 +1,248 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities supporting export to SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import os
+import re
+import time
+
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.learn.python.learn.utils import gc
+from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.python.platform import gfile
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+
+from tensorflow.python.util import compat
+
+# A key for use in the input_alternatives dict indicating the default input.
+# This is the input that will be expected when a serving request does not
+# specify a specific signature.
+# The default input alternative specifies placeholders that the input_fn
+# requires to be fed (in the typical case, a single placeholder for a
+# serialized tf.Example).
+DEFAULT_INPUT_ALTERNATIVE_KEY = 'default_input_alternative'
+
+# A key for use in the input_alternatives dict indicating the features input.
+# The features inputs alternative specifies the feature Tensors provided as
+# input to the model_fn, i.e. the outputs of the input_fn.
+FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative'
+
+# A key for use in the output_alternatives dict indicating the default output.
+# This is the output that will be provided when a serving request does not
+# specify a specific signature.
+# In a single-headed model, the single output is automatically the default.
+# In a multi-headed model, the name of the desired default head should be
+# provided to get_output_alternatives.
+DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative'
+
+
+def build_standardized_signature_def(
+ input_tensors, output_tensors, problem_type):
+ """Build a SignatureDef using problem type and input and output Tensors.
+
+ Note that this delegates the actual creation of the signatures to methods in
+ //third_party/tensorflow/python/saved_model/signature_def_utils.py, which may
+ assign names to the input and output tensors (depending on the problem type)
+ that are standardized in the context of SavedModel.
+
+ Args:
+ input_tensors: a dict of string key to `Tensor`
+ output_tensors: a dict of string key to `Tensor`
+ problem_type: an instance of constants.ProblemType, specifying
+ classification, regression, etc.
+
+ Returns:
+ A SignatureDef using SavedModel standard keys where possible.
+
+ Raises:
+ ValueError: if input_tensors or output_tensors is None or empty.
+ """
+
+ if not input_tensors:
+ raise ValueError('input_tensors must be provided.')
+ if not output_tensors:
+ raise ValueError('output_tensors must be provided.')
+
+ # Per-method signature_def functions will standardize the keys if possible
+ if _is_classification_problem(problem_type, input_tensors, output_tensors):
+ (_, examples), = input_tensors.items()
+ classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
+ scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
+ if not (classes or scores):
+ (_, classes), = output_tensors.items()
+ return signature_def_utils.classification_signature_def(
+ examples, classes, scores)
+ elif _is_regression_problem(problem_type, input_tensors, output_tensors):
+ (_, examples), = input_tensors.items()
+ (_, predictions), = output_tensors.items()
+ return signature_def_utils.regression_signature_def(examples, predictions)
+ else:
+ return signature_def_utils.predict_signature_def(
+ input_tensors, output_tensors)
+
+
+def _is_classification_problem(problem_type, input_tensors, output_tensors):
+ classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
+ scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
+ return ((problem_type == constants.ProblemType.CLASSIFICATION or
+ problem_type == constants.ProblemType.LOGISTIC_REGRESSION)
+ and len(input_tensors) == 1
+ and (classes or scores or len(output_tensors) == 1))
+
+
+def _is_regression_problem(problem_type, input_tensors, output_tensors):
+ return (problem_type == constants.ProblemType.LINEAR_REGRESSION
+ and len(input_tensors) == 1
+ and len(output_tensors) == 1)
+
+
+def get_input_alternatives(input_ops):
+ """Obtain all input alternatives using the input_fn output and heuristics."""
+ input_alternatives = {}
+ if isinstance(input_ops, input_fn_utils.InputFnOps):
+ features, unused_labels, default_inputs = input_ops
+ input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] = default_inputs
+ else:
+ features, unused_labels = input_ops
+
+ if not features:
+ raise ValueError('Features must be defined.')
+
+ # Add the "features" input_signature in any case.
+ # Note defensive copy because model_fns alter the features dict.
+ input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = (
+ copy.copy(features))
+
+ return input_alternatives, features
+
+
+def get_output_alternatives(
+ model_fn_ops,
+ default_output_alternative_key=DEFAULT_OUTPUT_ALTERNATIVE_KEY):
+ """Obtain all output alternatives using the model_fn output and heuristics."""
+ output_alternatives = model_fn_ops.output_alternatives
+
+ # Identify the default outputs, creating them if needed.
+ if (output_alternatives
+ and default_output_alternative_key not in output_alternatives):
+ raise ValueError('default_output_alternative_key not in '
+ 'output_alternatives: %s' % default_output_alternative_key)
+
+ if (output_alternatives
+ and default_output_alternative_key in output_alternatives):
+ # If a default head is provided, use it.
+ actual_default_output_alternative_key = default_output_alternative_key
+ return output_alternatives, actual_default_output_alternative_key
+
+ if output_alternatives and len(output_alternatives) == 1:
+ # If there is only one head, use it as the default.
+ (actual_default_output_alternative_key, _), = output_alternatives.items()
+ return output_alternatives, actual_default_output_alternative_key
+
+ # Lacking provided output alternatives, the best we can do is to
+ # interpret the model as single-headed of unknown type.
+ default_problem_type = constants.ProblemType.UNSPECIFIED
+ default_outputs = model_fn_ops.predictions
+ actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY
+ output_alternatives = {actual_default_output_alternative_key:
+ (default_problem_type, default_outputs)}
+ return output_alternatives, actual_default_output_alternative_key
+
+
+def build_all_signature_defs(input_alternatives, output_alternatives,
+ actual_default_output_alternative_key):
+ """Build `SignatureDef`s from all pairs of input and output alternatives."""
+
+ signature_def_map = {
+ ('%s:%s' % (input_key, output_key or 'None')):
+ build_standardized_signature_def(
+ inputs, outputs, problem_type)
+ for input_key, inputs in input_alternatives.items()
+ for output_key, (problem_type, outputs)
+ in output_alternatives.items()}
+
+ # Add the default SignatureDef
+ default_inputs = input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY]
+ if not default_inputs:
+ default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY]
+ # default outputs are guaranteed to exist above
+ (default_problem_type, default_outputs) = (
+ output_alternatives[actual_default_output_alternative_key])
+ signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
+ build_standardized_signature_def(
+ default_inputs, default_outputs, default_problem_type))
+
+ return signature_def_map
+
+
+def get_timestamped_export_dir(export_dir_base):
+ """Builds a path to a new subdirectory within the base directory.
+
+ Each export is written into a new subdirectory named using the
+ current time. This guarantees monotonically increasing version
+ numbers even across multiple runs of the pipeline.
+ The timestamp used is the number of milliseconds since epoch UTC.
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ Returns:
+ The full path of the new subdirectory (which is not actually created yet).
+ """
+ export_timestamp = int(time.time() * 1e3)
+
+ export_dir = os.path.join(
+ compat.as_bytes(export_dir_base),
+ compat.as_bytes(str(export_timestamp)))
+ return export_dir
+
+
+def garbage_collect_exports(export_dir_base, exports_to_keep):
+ """Deletes older exports, retaining only a given number of the most recent.
+
+ Export subdirectories are assumed to be named with monotonically increasing
+ integers; the most recent are taken to be those with the largest values.
+
+ Args:
+ export_dir_base: the base directory under which each export is in a
+ versioned subdirectory.
+ exports_to_keep: the number of recent exports to retain.
+ """
+ if exports_to_keep is None:
+ return
+
+ keep_filter = gc.largest_export_versions(exports_to_keep)
+ delete_filter = gc.negation(keep_filter)
+
+ # Export dir must not end with / or it will break the re match below.
+ if export_dir_base.endswith('/'):
+ export_dir_base = export_dir_base[:-1]
+
+ # create a simple parser that pulls the export_version from the directory.
+ def parser(path):
+ match = re.match('^' + export_dir_base + '/(\\d{13})$', path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ for p in delete_filter(gc.get_paths(export_dir_base, parser=parser)):
+ gfile.DeleteRecursively(p.path)
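Taken together, the last two helpers support a simple retention policy over a series of exports; a sketch (the base path is arbitrary, and the SavedModel write itself is elided):

    export_dir = saved_model_export_utils.get_timestamped_export_dir(
        '/tmp/my_exports')  # hypothetical base directory
    # ... write the SavedModel into export_dir ...
    # Afterwards, keep only the five most recent timestamped exports.
    saved_model_export_utils.garbage_collect_exports('/tmp/my_exports', 5)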
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
new file mode 100644
index 0000000000..538e0ab104
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
@@ -0,0 +1,228 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests of utilities supporting export to SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+import time
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
+from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+
+
+class SavedModelExportUtilsTest(tf.test.TestCase):
+
+ def test_build_standardized_signature_def(self):
+ input_tensors = {
+ "input-1": tf.placeholder(tf.float32, 1, name="input-tensor-1")}
+ output_tensors = {
+ "output-1": tf.placeholder(tf.float32, 1, name="output-tensor-1")}
+ problem_type = constants.ProblemType.LINEAR_REGRESSION
+ regression_signature_def = (
+ saved_model_export_utils.build_standardized_signature_def(
+ input_tensors, output_tensors, problem_type))
+ expected_regression_signature_def = meta_graph_pb2.SignatureDef()
+ shape = tensor_shape_pb2.TensorShapeProto(
+ dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
+ dtype = types_pb2.DataType.Value("DT_FLOAT")
+ expected_regression_signature_def.inputs[
+ signature_constants.REGRESS_INPUTS].CopyFrom(
+ meta_graph_pb2.TensorInfo(name="input-tensor-1:0",
+ dtype=dtype,
+ tensor_shape=shape))
+ expected_regression_signature_def.outputs[
+ signature_constants.REGRESS_OUTPUTS].CopyFrom(
+ meta_graph_pb2.TensorInfo(name="output-tensor-1:0",
+ dtype=dtype,
+ tensor_shape=shape))
+
+ expected_regression_signature_def.method_name = (
+ signature_constants.REGRESS_METHOD_NAME)
+ self.assertEqual(regression_signature_def,
+ expected_regression_signature_def)
+
+ def test_get_input_alternatives(self):
+ input_ops = input_fn_utils.InputFnOps("bogus features dict", None,
+ "bogus default input dict")
+
+ input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
+ input_ops)
+ self.assertEqual(
+ input_alternatives[
+ saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY],
+ "bogus default input dict")
+ self.assertEqual(
+ input_alternatives[
+ saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY],
+ "bogus features dict")
+
+ def test_get_output_alternatives_explicit(self):
+ provided_output_alternatives = {
+ "head-1": (constants.ProblemType.LINEAR_REGRESSION,
+ "bogus output dict"),
+ "head-2": (constants.ProblemType.CLASSIFICATION,
+ "bogus output dict 2"),
+ "head-3": (constants.ProblemType.UNSPECIFIED,
+ "bogus output dict 3"),
+ }
+ model_fn_ops = model_fn.ModelFnOps(
+ model_fn.ModeKeys.INFER,
+ predictions={"some_output": "bogus_tensor"},
+ output_alternatives=provided_output_alternatives)
+ output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
+ model_fn_ops, "head-1")
+
+ self.assertEqual(provided_output_alternatives, output_alternatives)
+
+ def test_get_output_alternatives_implicit(self):
+ prediction_tensor = tf.constant(["bogus"])
+ model_fn_ops = model_fn.ModelFnOps(
+ model_fn.ModeKeys.INFER,
+ predictions={"some_output": prediction_tensor},
+ output_alternatives=None)
+
+ output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
+ model_fn_ops, "some_output")
+ self.assertEqual(
+ {"default_output_alternative": (constants.ProblemType.UNSPECIFIED,
+ {"some_output": prediction_tensor})},
+ output_alternatives)
+
+ def test_build_all_signature_defs(self):
+ input_features = tf.constant(["10"])
+ input_example = tf.constant(["11"])
+ input_ops = input_fn_utils.InputFnOps(
+ {"features": input_features},
+ None,
+ {"default input": input_example})
+ input_alternatives, _ = (
+ saved_model_export_utils.get_input_alternatives(input_ops))
+ output_1 = tf.constant(["1"])
+ output_2 = tf.constant(["2"])
+ output_3 = tf.constant(["3"])
+ provided_output_alternatives = {
+ "head-1": (constants.ProblemType.LINEAR_REGRESSION,
+ {"some_output_1": output_1}),
+ "head-2": (constants.ProblemType.CLASSIFICATION,
+ {"some_output_2": output_2}),
+ "head-3": (constants.ProblemType.UNSPECIFIED,
+ {"some_output_3": output_3}),
+ }
+ model_fn_ops = model_fn.ModelFnOps(
+ model_fn.ModeKeys.INFER,
+ predictions={"some_output": tf.constant(["4"])},
+ output_alternatives=provided_output_alternatives)
+ output_alternatives, _ = (
+ saved_model_export_utils.get_output_alternatives(model_fn_ops,
+ "head-1"))
+
+ signature_defs = saved_model_export_utils.build_all_signature_defs(
+ input_alternatives, output_alternatives, "head-1")
+
+ expected_signature_defs = {
+ "serving_default":
+ signature_def_utils.regression_signature_def(
+ input_example, output_1),
+ "default_input_alternative:head-1":
+ signature_def_utils.regression_signature_def(
+ input_example, output_1),
+ "default_input_alternative:head-2":
+ signature_def_utils.classification_signature_def(
+ input_example, output_2, None),
+ "default_input_alternative:head-3":
+ signature_def_utils.predict_signature_def(
+ {"input": input_example}, {"output": output_3}),
+ "features_input_alternative:head-1":
+ signature_def_utils.regression_signature_def(
+ input_features, output_1),
+ "features_input_alternative:head-2":
+ signature_def_utils.classification_signature_def(
+ input_features, output_2, None),
+ "features_input_alternative:head-3":
+ signature_def_utils.predict_signature_def(
+ {"input": input_features}, {"output": output_3}),
+ }
+
+ self.assertDictEqual(expected_signature_defs, signature_defs)
+
+ def test_get_timestamped_export_dir(self):
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ export_dir_1 = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+ time.sleep(0.001)
+ export_dir_2 = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+ time.sleep(0.001)
+ export_dir_3 = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+
+ # Export directories should be named using a timestamp that is milliseconds
+ # since epoch. Such a timestamp is 13 digits long.
+ time_1 = os.path.basename(export_dir_1)
+ self.assertEqual(13, len(time_1))
+ time_2 = os.path.basename(export_dir_2)
+ self.assertEqual(13, len(time_2))
+ time_3 = os.path.basename(export_dir_3)
+ self.assertEqual(13, len(time_3))
+
+ self.assertTrue(int(time_1) < int(time_2))
+ self.assertTrue(int(time_2) < int(time_3))
+
+ def test_garbage_collect_exports(self):
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ tf.gfile.MkDir(export_dir_base)
+ export_dir_1 = _create_test_export_dir(export_dir_base)
+ export_dir_2 = _create_test_export_dir(export_dir_base)
+ export_dir_3 = _create_test_export_dir(export_dir_base)
+ export_dir_4 = _create_test_export_dir(export_dir_base)
+
+ self.assertTrue(tf.gfile.Exists(export_dir_1))
+ self.assertTrue(tf.gfile.Exists(export_dir_2))
+ self.assertTrue(tf.gfile.Exists(export_dir_3))
+ self.assertTrue(tf.gfile.Exists(export_dir_4))
+
+ # Garbage collect all but the most recent 2 exports,
+ # where recency is determined based on the timestamp directory names.
+ saved_model_export_utils.garbage_collect_exports(export_dir_base, 2)
+
+ self.assertFalse(tf.gfile.Exists(export_dir_1))
+ self.assertFalse(tf.gfile.Exists(export_dir_2))
+ self.assertTrue(tf.gfile.Exists(export_dir_3))
+ self.assertTrue(tf.gfile.Exists(export_dir_4))
+
+
+def _create_test_export_dir(export_dir_base):
+ export_dir = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+ tf.gfile.MkDir(export_dir)
+ time.sleep(0.001)
+ return export_dir
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
index a2f012d349..e3ed248dd5 100644
--- a/tensorflow/contrib/linalg/BUILD
+++ b/tensorflow/contrib/linalg/BUILD
@@ -24,7 +24,7 @@ cuda_py_tests(
cuda_py_tests(
name = "linear_operator_diag_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/linear_operator_diag_test.py"],
additional_deps = [
":linalg_py",
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
index 98eac39683..d03fb1d66f 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
@@ -26,50 +26,18 @@ linalg = tf.contrib.linalg
tf.set_random_seed(23)
-class LinearOperatorDiagtest(
- linear_operator_test_util.LinearOperatorDerivedClassTest):
+class LinearOperatorDiagTest(
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- @property
- def _dtypes_to_test(self):
- return [tf.float32, tf.float64]
-
- @property
- def _shapes_to_test(self):
- # non-batch operators (n, n) and batch operators.
- return [(0, 0), (1, 1), (1, 3, 3), (3, 2, 2), (2, 1, 3, 3)]
-
- def _make_rhs(self, operator):
- # This operator is square, so rhs and x will have same shape.
- return self._make_x(operator)
-
- def _make_x(self, operator):
- # Return the number of systems to solve, R, equal to 1 or 2.
- r = self._get_num_systems(operator)
- # If operator.shape = [B1,...,Bb, N, N] this returns a random matrix of
- # shape [B1,...,Bb, N, R], R = 1 or 2.
- if operator.shape.is_fully_defined():
- batch_shape = operator.batch_shape.as_list()
- n = operator.domain_dimension.value
- rhs_shape = batch_shape + [n, r]
- else:
- batch_shape = operator.batch_shape_dynamic()
- n = operator.domain_dimension_dynamic()
- rhs_shape = tf.concat(0, (batch_shape, [n, r]))
- return tf.random_normal(shape=rhs_shape, dtype=operator.dtype)
-
- def _get_num_systems(self, operator):
- """Get some number, either 1 or 2, depending on operator."""
- if operator.tensor_rank is None or operator.tensor_rank % 2:
- return 1
- else:
- return 2
-
def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
shape = list(shape)
diag_shape = shape[:-1]
- diag = tf.random_normal(diag_shape, dtype=dtype)
+ diag = tf.random_normal(diag_shape, dtype=dtype.real_dtype)
+ if dtype.is_complex:
+ diag = tf.complex(
+ diag, tf.random_normal(diag_shape, dtype=dtype.real_dtype))
diag_ph = tf.placeholder(dtype=dtype)
if use_placeholder:
@@ -87,15 +55,32 @@ class LinearOperatorDiagtest(
return operator, mat, feed_dict
- def test_assert_positive_definite(self):
- # Singlular matrix with one positive eigenvalue and one zero eigenvalue.
+ def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
+ # Matrix with one positive eigenvalue and one zero eigenvalue.
+ with self.test_session():
+ diag = [1.0, 0.0]
+ operator = linalg.LinearOperatorDiag(diag)
+ with self.assertRaisesOpError("non-positive.*not positive definite"):
+ operator.assert_positive_definite().run()
+
+ def test_assert_positive_definite_raises_for_negative_real_eigvalues(self):
with self.test_session():
- diag = [1.0, -1.0]
+ diag_x = [1.0, -2.0]
+ diag_y = [0., 0.] # Imaginary eigenvalues should not matter.
+ diag = tf.complex(diag_x, diag_y)
operator = linalg.LinearOperatorDiag(diag)
- with self.assertRaisesOpError("was not positive definite"):
+ with self.assertRaisesOpError("non-positive real.*not positive definite"):
operator.assert_positive_definite().run()
- def test_assert_non_singular(self):
+ def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self):
+ with self.test_session():
+ x = [1., 2.]
+ y = [1., 0.]
+ diag = tf.complex(x, y) # Re[diag] > 0.
+ # Should not fail
+ linalg.LinearOperatorDiag(diag).assert_positive_definite().run()
+
+ def test_assert_non_singular_raises_if_zero_eigenvalue(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
with self.test_session():
diag = [1.0, 0.0]
@@ -103,10 +88,36 @@ class LinearOperatorDiagtest(
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
+ def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self):
+ with self.test_session():
+ x = [1., 0.]
+ y = [0., 1.]
+ diag = tf.complex(x, y)
+ # Should not raise.
+ linalg.LinearOperatorDiag(diag).assert_non_singular().run()
+
+ def test_assert_self_adjoint_raises_if_diag_has_complex_part(self):
+ with self.test_session():
+ x = [1., 0.]
+ y = [0., 1.]
+ diag = tf.complex(x, y)
+ operator = linalg.LinearOperatorDiag(diag)
+ with self.assertRaisesOpError("imaginary.*not self-adjoint"):
+ operator.assert_self_adjoint().run()
+
+ def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self):
+ with self.test_session():
+ x = [1., 0.]
+ y = [0., 0.]
+ diag = tf.complex(x, y)
+ operator = linalg.LinearOperatorDiag(diag)
+ # Should not raise
+ operator.assert_self_adjoint().run()
+
def test_broadcast_apply_and_solve(self):
# These cannot be done in the automated (base test class) tests since they
- # test shapes that tf.batch_matmul cannot handle.
- # In particular, tf.batch_matmul does not broadcast.
+ # test shapes that tf.matmul cannot handle.
+ # In particular, tf.matmul does not broadcast.
with self.test_session() as sess:
x = tf.random_normal(shape=(2, 2, 3, 4))
@@ -122,7 +133,7 @@ class LinearOperatorDiagtest(
self.assertAllEqual((2, 2, 3, 3), mat.get_shape()) # being pedantic.
operator_apply = operator.apply(x)
- mat_apply = tf.batch_matmul(mat, x)
+ mat_apply = tf.matmul(mat, x)
self.assertAllEqual(operator_apply.get_shape(), mat_apply.get_shape())
self.assertAllClose(*sess.run([operator_apply, mat_apply]))
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
index eb279177ab..4228903388 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
@@ -16,9 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
import tensorflow as tf
linalg = tf.contrib.linalg
+rng = np.random.RandomState(123)
class LinearOperatorShape(linalg.LinearOperator):
@@ -44,6 +46,31 @@ class LinearOperatorShape(linalg.LinearOperator):
return tf.constant(self._stored_shape, dtype=tf.int32)
+class LinearOperatorApplyOnly(linalg.LinearOperator):
+ """LinearOperator that simply wraps a [batch] matrix and implements apply."""
+
+ def __init__(self,
+ matrix,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None):
+ self._matrix = tf.convert_to_tensor(matrix, name="matrix")
+ super(LinearOperatorApplyOnly, self).__init__(
+ dtype=matrix.dtype,
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite)
+
+ def _shape(self):
+ return self._matrix.get_shape()
+
+ def _shape_dynamic(self):
+ return tf.shape(self._matrix)
+
+ def _apply(self, x, adjoint=False):
+ return tf.matmul(self._matrix, x, adjoint_a=adjoint)
+
+
class LinearOperatorTest(tf.test.TestCase):
def test_all_shape_properties_defined_by_the_one_property_shape(self):
@@ -78,6 +105,23 @@ class LinearOperatorTest(tf.test.TestCase):
self.assertTrue(operator.is_self_adjoint)
self.assertFalse(operator.is_positive_definite)
+ def test_generic_to_dense_method_non_square_matrix_static(self):
+ matrix = rng.randn(2, 3, 4)
+ operator = LinearOperatorApplyOnly(matrix)
+ with self.test_session():
+ operator_dense = operator.to_dense()
+ self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
+ self.assertAllClose(matrix, operator_dense.eval())
+
+ def test_generic_to_dense_method_non_square_matrix_dynamic(self):
+ matrix = rng.randn(2, 3, 4)
+ matrix_ph = tf.placeholder(tf.float64)
+ operator = LinearOperatorApplyOnly(matrix_ph)
+ with self.test_session():
+ operator_dense = operator.to_dense()
+ self.assertAllClose(
+ matrix, operator_dense.eval(feed_dict={matrix_ph: matrix}))
+
-if __name__ == '__main__':
+if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py
index d5aa3fdf25..6199518af0 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py
@@ -23,6 +23,7 @@ import contextlib
from tensorflow.contrib import framework as contrib_framework
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
__all__ = ["LinearOperator"]
@@ -114,6 +115,19 @@ class LinearOperator(object):
### Performance
FILL THIS IN
+
+ ### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+ for `X = non_singular, self_adjoint`, etc.
+ These have the following meaning:
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
"""
def __init__(self,
@@ -123,21 +137,11 @@ class LinearOperator(object):
is_self_adjoint=None,
is_positive_definite=None,
name=None):
- """Initialize the `LinearOperator`.
+ r"""Initialize the `LinearOperator`.
**This is a private method for subclass use.**
**Subclasses should copy-paste this `__init__` documentation.**
- For `X = non_singular, self_adjoint` etc...
- `is_X` is a Python `bool` initialization argument with the following meaning
- * If `is_X == True`, callers should expect the operator to have the
- attribute `X`. This is a promise that should be fulfilled, but is *not* a
- runtime assert. Issues, such as floating point error, could mean the
- operator violates this promise.
- * If `is_X == False`, callers should expect the operator to not have `X`.
- * If `is_X == None` (the default), callers should have no expectation either
- way.
-
Args:
      dtype: The type of this `LinearOperator`. Arguments to `apply` and
`solve` will have to be this type.
@@ -146,18 +150,21 @@ class LinearOperator(object):
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. If `dtype` is real, this is equivalent to being symmetric.
- is_positive_definite: Expect that this operator is positive definite.
- name: A name for this `LinearOperator`. Default: subclass name.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix\
+ #Extension_for_non_symmetric_matrices
+ name: A name for this `LinearOperator`.
Raises:
ValueError: if any member of graph_parents is `None` or not a `Tensor`.
"""
- if is_positive_definite and not is_self_adjoint:
- raise ValueError(
- "A positive definite matrix is by definition self adjoint")
- if is_positive_definite and not is_non_singular:
- raise ValueError(
- "A positive definite matrix is by definition non-singular")
+ # Check and auto-set flags.
+ if is_positive_definite:
+ if is_non_singular is False:
+ raise ValueError("A positive definite matrix is always non-singular.")
+ is_non_singular = True
graph_parents = [] if graph_parents is None else graph_parents
for i, t in enumerate(graph_parents):
@@ -384,10 +391,28 @@ class LinearOperator(object):
raise NotImplementedError("assert_positive_definite is not implemented.")
def assert_positive_definite(self, name="assert_positive_definite"):
- """Returns an `Op` that asserts this operator is positive definite."""
+ """Returns an `Op` that asserts this operator is positive definite.
+
+ Here, positive definite means the real part of all eigenvalues is positive.
+ We do not require the operator to be self-adjoint.
+
+ Args:
+ name: A name to give this `Op`.
+
+ Returns:
+ An `Op` that asserts this operator is positive definite.
+ """
with self._name_scope(name):
return self._assert_positive_definite()
+ def _assert_self_adjoint(self):
+ raise NotImplementedError("assert_self_adjoint is not implemented.")
+
+ def assert_self_adjoint(self, name="assert_self_adjoint"):
+ """Returns an `Op` that asserts this operator is self-adjoint."""
+ with self._name_scope(name):
+ return self._assert_self_adjoint()
+
def _apply(self, x, adjoint=False):
raise NotImplementedError("_apply is not implemented.")
@@ -485,9 +510,38 @@ class LinearOperator(object):
return self._solve(rhs, adjoint=adjoint)
def _to_dense(self):
- raise NotImplementedError("_to_dense is not implemented.")
+ """Generic and often inefficient implementation. Override often."""
+ if self.batch_shape.is_fully_defined():
+ batch_shape = self.batch_shape
+ else:
+ batch_shape = self.batch_shape_dynamic()
+
+ if self.domain_dimension.value is not None:
+ n = self.domain_dimension.value
+ else:
+ n = self.domain_dimension_dynamic()
+
+ eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
+ return self.apply(eye)
def to_dense(self, name="to_dense"):
"""Return a dense (batch) matrix representing this operator."""
with self._name_scope(name):
return self._to_dense()
+
+ def _add_to_tensor(self, x):
+ raise NotImplementedError("_add_to_tensor is not implemented.")
+
+ def add_to_tensor(self, x, name="add_to_tensor"):
+ """Add matrix represented by this operator to `x`. Equivalent to `A + x`.
+
+ Args:
+ x: `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
+ name: A name to give this `Op`.
+
+ Returns:
+ A `Tensor` with broadcast shape and same `dtype` as `self`.
+ """
+ with self._name_scope(name, values=[x]):
+ x = ops.convert_to_tensor(x, name="x")
+ return self._add_to_tensor(x)
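
The generic `_to_dense` added above is worth a second look: rather than raising, it materializes the operator by applying it to a (batch) identity matrix, since `A @ I == A`. A minimal NumPy sketch of the same idea, with `apply_fn` standing in for `operator.apply` (the names here are illustrative, not part of the API):

import numpy as np

def to_dense(apply_fn, n, dtype=np.float64):
  # Applying a linear operator to the identity recovers its dense form,
  # one column at a time: A @ I == A.
  eye = np.eye(n, dtype=dtype)
  return apply_fn(eye)

# A diagonal operator represented only by its action on a matrix.
diag = np.array([1., 2., 3.])
dense = to_dense(lambda x: diag[:, np.newaxis] * x, n=3)
assert np.allclose(dense, np.diag(diag))

Subclasses with structure (diagonal, triangular, etc.) will usually override this, as the identity-apply route costs a full dense matmul's worth of work.
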
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
index 6f1c769758..f65ed9a6c8 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.linalg.python.ops import linear_operator
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -34,7 +35,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
This operator acts like a [batch] matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
- an `m x n` matrix. Again, this matrix `A` may not be materialized, but for
+ an `N x N` matrix. This matrix `A` is not materialized, but for
purposes of broadcasting this shape will be relevant.
`LinearOperatorDiag` is initialized with a (batch) vector.
@@ -48,7 +49,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
==> [[1., 0.]
[0., -1.]]
- operator.shape()
+ operator.shape
==> [2, 2]
operator.log_determinant()
@@ -83,7 +84,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
### Performance
- Suppose `operator` is a `LinearOperatorDiag` is of shape `[N, N]`,
+ Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
and `x.shape = [N, R]`. Then
* `operator.apply(x)` involves `N*R` multiplications.
@@ -92,6 +93,19 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
`[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
+
+ ### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+  for `X = non_singular, self_adjoint` etc.
+  These have the following meaning:
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
"""
def __init__(self,
@@ -102,44 +116,45 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
name="LinearOperatorDiag"):
"""Initialize a `LinearOperatorDiag`.
- For `X = non_singular, self_adjoint` etc...
- `is_X` is a Python `bool` initialization argument with the following meaning
- * If `is_X == True`, callers should expect the operator to have the
- attribute `X`. This is a promise that should be fulfilled, but is *not* a
- runtime assert. Issues, such as floating point error, could mean the
- operator violates this promise.
- * If `is_X == False`, callers should expect the operator to not have `X`.
- * If `is_X == None` (the default), callers should have no expectation either
- way.
-
Args:
- diag: Shape `[B1,...,Bb, N]` real float type `Tensor` with `b >= 0`,
- `N >= 0`. The diagonal of the operator.
+      diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0`, `N >= 0`.
+ The diagonal of the operator. Allowed dtypes: `float32`, `float64`,
+ `complex64`, `complex128`.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
      transpose. If `diag` is real, this operator is always self adjoint.
- is_positive_definite: Expect that this operator is positive definite.
- name: A name for this `LinearOperator`. Default: subclass name.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix
+ #Extension_for_non_symmetric_matrices
+ name: A name for this `LinearOperator`.
Raises:
- ValueError: If `diag.dtype` is not floating point.
+ TypeError: If `diag.dtype` is not an allowed type.
    ValueError: If `diag` is real and `is_self_adjoint` is not `True`.
"""
+ allowed_dtypes = [
+ dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
+
with ops.name_scope(name, values=[diag]):
self._diag = ops.convert_to_tensor(diag, name="diag")
- if not self._diag.dtype.is_floating:
- raise ValueError("Only real floating point matrices are supported.")
- if not is_self_adjoint:
- raise ValueError("A real diagonal matrix is always self adjoint.")
+ dtype = self._diag.dtype
+ if dtype not in allowed_dtypes:
+ raise TypeError(
+ "Argument diag must have dtype in %s. Found: %s"
+ % (allowed_dtypes, dtype))
+ if dtype.is_floating and not is_self_adjoint:
+ raise ValueError("A real diagonal operator is always self adjoint.")
super(LinearOperatorDiag, self).__init__(
- dtype=self._diag.dtype,
+ dtype=dtype,
graph_parents=[self._diag],
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
- is_positive_definite=is_non_singular,
+ is_positive_definite=is_positive_definite,
name=name)
def _shape(self):
@@ -153,20 +168,42 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
return array_ops.concat(0, (d_shape, [k]))
def _assert_non_singular(self):
+ if self.dtype.is_complex:
+ should_be_nonzero = math_ops.complex_abs(self._diag)
+ else:
+ should_be_nonzero = self._diag
+
nonzero_diag = math_ops.reduce_all(
- math_ops.logical_not(math_ops.equal(self._diag, 0)))
+ math_ops.logical_not(math_ops.equal(should_be_nonzero, 0)))
+
return control_flow_ops.Assert(
nonzero_diag,
data=["Singular operator: diag contained zero values.", self._diag])
def _assert_positive_definite(self):
+ if self.dtype.is_complex:
+ message = (
+ "Diagonal operator had diagonal entries with non-positive real part, "
+ "thus was not positive definite.")
+ else:
+ message = (
+ "Real diagonal operator had non-positive diagonal entries, "
+ "thus was not positive definite.")
+
return check_ops.assert_positive(
+ math_ops.real(self._diag),
+ message=message)
+
+ def _assert_self_adjoint(self):
+ return _assert_imag_part_zero(
self._diag,
- message="Operator was not positive definite: diag was not all positive")
+ message=(
+          "This diagonal operator contained non-zero imaginary values. "
+          "Thus it was not self-adjoint."))
def _apply(self, x, adjoint=False):
- # adjoint has no effect since this matrix is self-adjoint.
- diag_mat = array_ops.expand_dims(self._diag, -1)
+ diag_term = math_ops.conj(self._diag) if adjoint else self._diag
+ diag_mat = array_ops.expand_dims(diag_term, -1)
return diag_mat * x
def _determinant(self):
@@ -177,9 +214,29 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
def _solve(self, rhs, adjoint=False):
- # adjoint has no effect since this matrix is self-adjoint.
- inv_diag_mat = array_ops.expand_dims(1. / self._diag, -1)
+ diag_term = math_ops.conj(self._diag) if adjoint else self._diag
+ inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1)
return rhs * inv_diag_mat
def _to_dense(self):
return array_ops.matrix_diag(self._diag)
+
+ def _add_to_tensor(self, x):
+ x_diag = array_ops.matrix_diag_part(x)
+ new_diag = self._diag + x_diag
+ return array_ops.matrix_set_diag(x, new_diag)
+
+
+def _assert_imag_part_zero(x, message=None):
+ """Assert that floating or complex 'x' is real."""
+ dtype = x.dtype.base_dtype
+ if dtype.is_floating:
+ return control_flow_ops.no_op()
+
+ if not dtype.is_complex:
+ raise TypeError(
+ "imag_part_zero only handles float or complex types. Found: %s"
+ % dtype)
+
+ zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
+ return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
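
The adjoint handling added to `_apply` and `_solve` above reduces to conjugating the diagonal, since for `A = diag(d)` the adjoint is `diag(conj(d))`. A standalone NumPy sketch of the arithmetic (not the TF ops themselves):

import numpy as np

d = np.array([1 + 2j, 3 - 1j])            # the operator's diagonal
x = np.array([[1.0 + 0j], [2.0 + 0j]])    # rhs of shape [N, R]

def apply_diag(d, x, adjoint=False):
  diag = np.conj(d) if adjoint else d
  return diag[:, np.newaxis] * x           # elementwise, no matmul needed

def solve_diag(d, rhs, adjoint=False):
  diag = np.conj(d) if adjoint else d
  return rhs / diag[:, np.newaxis]

# solve inverts apply, in both the plain and the adjoint cases.
assert np.allclose(apply_diag(d, solve_diag(d, x)), x)
assert np.allclose(
    apply_diag(d, solve_diag(d, x, adjoint=True), adjoint=True), x)
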
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
index adbdb9b3d2..20136bfbd0 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
@@ -31,10 +31,25 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
test methods to work.
"""
- @abc.abstractproperty
+ # Absolute/relative tolerance for tests.
+ _atol = {
+ tf.float16: 1e-3, tf.float32: 1e-6, tf.float64: 1e-12, tf.complex64: 1e-6,
+ tf.complex128: 1e-12}
+ _rtol = {
+ tf.float16: 1e-3, tf.float32: 1e-6, tf.float64: 1e-12, tf.complex64: 1e-6,
+ tf.complex128: 1e-12}
+
+ def assertAC(self, x, y):
+ """Derived classes can set _atol, _rtol to get different tolerance."""
+ dtype = tf.as_dtype(x.dtype)
+ atol = self._atol[dtype]
+ rtol = self._rtol[dtype]
+ self.assertAllClose(x, y, atol=atol, rtol=rtol)
+
+ @property
def _dtypes_to_test(self):
- """Returns list of numpy or tensorflow dtypes. Each will be tested."""
- raise NotImplementedError("dtypes_to_test has not been implemented.")
+ # TODO(langmore) Test tf.float16 once tf.matrix_diag works in 16bit.
+ return [tf.float32, tf.float64, tf.complex64, tf.complex128]
@abc.abstractproperty
def _shapes_to_test(self):
@@ -57,8 +72,9 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
Returns:
operator: `LinearOperator` subclass instance.
mat: `Tensor` representing operator.
- feed_dict: Dictionary. If placholder is True, this must be fed to
- sess.run calls at runtime to make the operator work.
+      feed_dict: Dictionary.
+        If use_placeholder is True, this must contain everything that needs
+        to be fed to sess.run calls at runtime to make the operator work.
"""
# Create a matrix as a numpy array with desired shape/dtype.
# Create a LinearOperator that should have the same behavior as the matrix.
@@ -74,107 +90,145 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
"""Make a rhs appropriate for calling operator.apply(rhs)."""
raise NotImplementedError("_make_x is not defined.")
- def _maybe_adjoint(self, x, adjoint):
- if adjoint:
- return tf.matrix_transpose(x)
- else:
- return x
+ @property
+ def _tests_to_skip(self):
+ """List of test names to skip."""
+    # Subclasses should override this if they want to skip some tests.
+ # To skip "test_foo", add "foo" to this list.
+ return []
+
+ def _maybe_skip(self, test_name):
+ if test_name in self._tests_to_skip:
+      self.skipTest(
+          "%s skipped because it was added to self._tests_to_skip."
+          % test_name)
def test_to_dense(self):
+ self._maybe_skip("to_dense")
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
- operator, mat, _ = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=False)
- op_dense = operator.to_dense()
- self.assertAllEqual(shape, op_dense.get_shape())
- op_dense_v, mat_v = sess.run([op_dense, mat])
- self.assertAllClose(op_dense_v, mat_v)
-
- def test_to_dense_dynamic(self):
- with self.test_session() as sess:
- for shape in self._shapes_to_test:
- for dtype in self._dtypes_to_test:
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=True)
- op_dense_v, mat_v = sess.run(
- [operator.to_dense(), mat], feed_dict=feed_dict)
- self.assertAllClose(op_dense_v, mat_v)
+ for use_placeholder in False, True:
+ operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ shape, dtype, use_placeholder=use_placeholder)
+ op_dense = operator.to_dense()
+ if not use_placeholder:
+ self.assertAllEqual(shape, op_dense.get_shape())
+ op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict)
+ self.assertAC(op_dense_v, mat_v)
def test_det(self):
+ self._maybe_skip("det")
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
- operator, mat, _ = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=False)
- op_det = operator.determinant()
- self.assertAllEqual(shape[:-2], op_det.get_shape())
- op_det_v, mat_det_v = sess.run([op_det, tf.matrix_determinant(mat)])
- self.assertAllClose(op_det_v, mat_det_v)
-
- def test_det_dynamic(self):
- with self.test_session() as sess:
- for shape in self._shapes_to_test:
- for dtype in self._dtypes_to_test:
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=True)
- op_det_v, mat_det_v = sess.run(
- [operator.determinant(), tf.matrix_determinant(mat)],
- feed_dict=feed_dict)
- self.assertAllClose(op_det_v, mat_det_v)
+ if dtype.is_complex:
+ self.skipTest(
+ "tf.matrix_determinant does not work with complex, so this test"
+ " is being skipped.")
+ for use_placeholder in False, True:
+ operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ shape, dtype, use_placeholder=use_placeholder)
+ op_det = operator.determinant()
+ if not use_placeholder:
+ self.assertAllEqual(shape[:-2], op_det.get_shape())
+ op_det_v, mat_det_v = sess.run(
+ [op_det, tf.matrix_determinant(mat)], feed_dict=feed_dict)
+ self.assertAC(op_det_v, mat_det_v)
def test_apply(self):
+ self._maybe_skip("apply")
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
- operator, mat, _ = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=False)
- for adjoint in [False, True]:
- if adjoint and operator.is_self_adjoint:
- continue
- x = self._make_x(operator)
- op_apply = operator.apply(x, adjoint=adjoint)
- mat_apply = tf.batch_matmul(self._maybe_adjoint(mat, adjoint), x)
- self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape())
- op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply])
- self.assertAllClose(op_apply_v, mat_apply_v)
-
- def test_apply_dynamic(self):
- with self.test_session() as sess:
- for shape in self._shapes_to_test:
- for dtype in self._dtypes_to_test:
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=True)
- x = self._make_x(operator)
- op_apply_v, mat_apply_v = sess.run(
- [operator.apply(x), tf.batch_matmul(mat, x)],
- feed_dict=feed_dict)
- self.assertAllClose(op_apply_v, mat_apply_v)
+ for use_placeholder in False, True:
+ for adjoint in [False, True]:
+ operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ shape, dtype, use_placeholder=use_placeholder)
+ x = self._make_x(operator)
+ op_apply = operator.apply(x, adjoint=adjoint)
+ mat_apply = tf.matmul(mat, x, adjoint_a=adjoint)
+ if not use_placeholder:
+ self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape())
+ op_apply_v, mat_apply_v = sess.run(
+ [op_apply, mat_apply], feed_dict=feed_dict)
+ self.assertAC(op_apply_v, mat_apply_v)
def test_solve(self):
+ self._maybe_skip("solve")
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
- operator, mat, _ = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=False)
- for adjoint in [False, True]:
- if adjoint and operator.is_self_adjoint:
- continue
- rhs = self._make_rhs(operator)
- op_solve = operator.solve(rhs, adjoint=adjoint)
- mat_solve = tf.matrix_solve(self._maybe_adjoint(mat, adjoint), rhs)
- self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
- op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
- self.assertAllClose(op_solve_v, mat_solve_v)
-
- def test_solve_dynamic(self):
+ for use_placeholder in False, True:
+ for adjoint in [False, True]:
+ operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ shape, dtype, use_placeholder=use_placeholder)
+ rhs = self._make_rhs(operator)
+ op_solve = operator.solve(rhs, adjoint=adjoint)
+ mat_solve = tf.matrix_solve(mat, rhs, adjoint=adjoint)
+ if not use_placeholder:
+ self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
+ op_solve_v, mat_solve_v = sess.run(
+ [op_solve, mat_solve], feed_dict=feed_dict)
+ self.assertAC(op_solve_v, mat_solve_v)
+
+ def test_add_to_tensor(self):
+ self._maybe_skip("add_to_tensor")
with self.test_session() as sess:
for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test:
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
- shape, dtype, use_placeholder=True)
- rhs = self._make_rhs(operator)
- op_solve_v, mat_solve_v = sess.run(
- [operator.solve(rhs), tf.matrix_solve(mat, rhs)],
- feed_dict=feed_dict)
- self.assertAllClose(op_solve_v, mat_solve_v)
+ for use_placeholder in False, True:
+ operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ shape, dtype, use_placeholder=use_placeholder)
+ op_plus_2mat = operator.add_to_tensor(2 * mat)
+
+ if not use_placeholder:
+ self.assertAllEqual(shape, op_plus_2mat.get_shape())
+
+ op_plus_2mat_v, mat_v = sess.run(
+ [op_plus_2mat, mat], feed_dict=feed_dict)
+
+ self.assertAC(op_plus_2mat_v, 3 * mat_v)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
+ """Base test class appropriate for square operators.
+
+ Sub-classes must still define all abstractmethods from
+ LinearOperatorDerivedClassTest that are not defined here.
+ """
+
+ @property
+ def _shapes_to_test(self):
+ # non-batch operators (n, n) and batch operators.
+ return [(0, 0), (1, 1), (1, 3, 3), (3, 4, 4), (2, 1, 4, 4)]
+
+ def _make_rhs(self, operator):
+ # This operator is square, so rhs and x will have same shape.
+ return self._make_x(operator)
+
+ def _make_x(self, operator):
+ # Return the number of systems to solve, R, equal to 1 or 2.
+ r = self._get_num_systems(operator)
+ # If operator.shape = [B1,...,Bb, N, N] this returns a random matrix of
+ # shape [B1,...,Bb, N, R], R = 1 or 2.
+ if operator.shape.is_fully_defined():
+ batch_shape = operator.batch_shape.as_list()
+ n = operator.domain_dimension.value
+ rhs_shape = batch_shape + [n, r]
+ else:
+ batch_shape = operator.batch_shape_dynamic()
+ n = operator.domain_dimension_dynamic()
+ rhs_shape = tf.concat(0, (batch_shape, [n, r]))
+
+ x = tf.random_normal(shape=rhs_shape, dtype=operator.dtype.real_dtype)
+ if operator.dtype.is_complex:
+ x = tf.complex(
+ x, tf.random_normal(shape=rhs_shape, dtype=operator.dtype.real_dtype))
+ return x
+
+ def _get_num_systems(self, operator):
+ """Get some number, either 1 or 2, depending on operator."""
+ if operator.tensor_rank is None or operator.tensor_rank % 2:
+ return 1
+ else:
+ return 2
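
With the harness above, a derived test class only overrides the hooks it needs. A hypothetical sketch (the class name and the skipped test are made up for illustration; the abstract `_operator_and_mat_and_feed_dict` must still be supplied):

import tensorflow as tf
from tensorflow.contrib.linalg.python.ops import linear_operator_test_util

class MyOperatorTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):

  # Loosen the float32 tolerances for an operator with more round-off.
  _atol = dict(
      linear_operator_test_util.LinearOperatorDerivedClassTest._atol)
  _atol[tf.float32] = 1e-5
  _rtol = dict(
      linear_operator_test_util.LinearOperatorDerivedClassTest._rtol)
  _rtol[tf.float32] = 1e-5

  @property
  def _tests_to_skip(self):
    # Opt out of test_solve, e.g. while solve is not yet implemented.
    return ["solve"]
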
diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD
index c607707bbb..3b68c51413 100644
--- a/tensorflow/contrib/linear_optimizer/BUILD
+++ b/tensorflow/contrib/linear_optimizer/BUILD
@@ -17,7 +17,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/lookup:lookup_py",
+ ":sharded_mutable_dense_hashtable_py",
+ ":sparse_feature_column_py",
],
)
@@ -34,6 +35,47 @@ py_test(
],
)
+py_library(
+ name = "sharded_mutable_dense_hashtable_py",
+ srcs = ["python/ops/sharded_mutable_dense_hashtable.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/lookup:lookup_py",
+ ],
+)
+
+py_test(
+ name = "sharded_mutable_dense_hashtable_test",
+ size = "small",
+ srcs = ["python/ops/sharded_mutable_dense_hashtable_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":sharded_mutable_dense_hashtable_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+py_library(
+ name = "sparse_feature_column_py",
+ srcs = ["python/ops/sparse_feature_column.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "sparse_feature_column_test",
+ size = "small",
+ srcs = ["python/ops/sparse_feature_column_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":sparse_feature_column_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/linear_optimizer/__init__.py b/tensorflow/contrib/linear_optimizer/__init__.py
index 40445c456f..83bd8b5fcf 100644
--- a/tensorflow/contrib/linear_optimizer/__init__.py
+++ b/tensorflow/contrib/linear_optimizer/__init__.py
@@ -23,5 +23,5 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
-from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SparseFeatureColumn
+from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn
from tensorflow.contrib.linear_optimizer.python.sdca_optimizer import SDCAOptimizer
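
Since the package `__init__` re-exports the moved class, downstream code that imports through the package root is unaffected by the file move. A quick sanity check of the import path, assuming the contrib package is importable:

import tensorflow as tf
from tensorflow.contrib.linear_optimizer import SparseFeatureColumn

# Both spellings resolve to the same class after this change.
assert SparseFeatureColumn is tf.contrib.linear_optimizer.SparseFeatureColumn
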
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 40a6404881..8e918e8880 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -22,9 +22,8 @@ from threading import Thread
import tensorflow as tf
-from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _ShardedMutableDenseHashTable
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
-from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SparseFeatureColumn
+from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.platform import googletest
@@ -980,27 +979,6 @@ class SdcaWithSmoothHingeLossTest(SdcaModelTest):
self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02)
-class SparseFeatureColumnTest(SdcaModelTest):
- """Tests for SparseFeatureColumn.
- """
-
- def testBasic(self):
- expected_example_indices = [1, 1, 1, 2]
- expected_feature_indices = [0, 1, 2, 0]
- sfc = SparseFeatureColumn(expected_example_indices,
- expected_feature_indices, None)
- self.assertTrue(isinstance(sfc.example_indices, tf.Tensor))
- self.assertTrue(isinstance(sfc.feature_indices, tf.Tensor))
- self.assertEqual(sfc.feature_values, None)
- with self._single_threaded_test_session():
- self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
- self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())
- expected_feature_values = [1.0, 2.0, 3.0, 4.0]
- sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],
- expected_feature_values)
- with self._single_threaded_test_session():
- self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())
-
class SdcaFprintTest(SdcaModelTest):
"""Tests for the SdcaFprint op.
@@ -1020,74 +998,5 @@ class SdcaFprintTest(SdcaModelTest):
[603227410218889250, 8762207001949257490]],
out_data.eval())
-
-class ShardedMutableDenseHashTableTest(SdcaModelTest):
- """Tests for the _ShardedMutableHashTable class."""
-
- def testShardedMutableHashTable(self):
- for num_shards in [1, 3, 10]:
- with self._single_threaded_test_session():
- default_val = -1
- empty_key = 0
- keys = tf.constant([11, 12, 13], tf.int64)
- values = tf.constant([0, 1, 2], tf.int64)
- table = _ShardedMutableDenseHashTable(
- tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
- self.assertAllEqual(0, table.size().eval())
-
- table.insert(keys, values).run()
- self.assertAllEqual(3, table.size().eval())
-
- input_string = tf.constant([11, 12, 14], tf.int64)
- output = table.lookup(input_string)
- self.assertAllEqual([3], output.get_shape())
- self.assertAllEqual([0, 1, -1], output.eval())
-
- def testShardedMutableHashTableVectors(self):
- for num_shards in [1, 3, 10]:
- with self._single_threaded_test_session():
- default_val = [-0.1, 0.2]
- empty_key = [0, 1]
- keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64)
- values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32)
- table = _ShardedMutableDenseHashTable(
- tf.int64, tf.float32, default_val, empty_key, num_shards=num_shards)
- self.assertAllEqual(0, table.size().eval())
-
- table.insert(keys, values).run()
- self.assertAllEqual(3, table.size().eval())
-
- input_string = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64)
- output = table.lookup(input_string)
- self.assertAllEqual([3, 2], output.get_shape())
- self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]],
- output.eval())
-
- def testExportSharded(self):
- with self._single_threaded_test_session():
- empty_key = -2
- default_val = -1
- num_shards = 2
- keys = tf.constant([10, 11, 12], tf.int64)
- values = tf.constant([2, 3, 4], tf.int64)
- table = _ShardedMutableDenseHashTable(
- tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
- self.assertAllEqual(0, table.size().eval())
-
- table.insert(keys, values).run()
- self.assertAllEqual(3, table.size().eval())
-
- keys_list, values_list = table.export_sharded()
- self.assertAllEqual(num_shards, len(keys_list))
- self.assertAllEqual(num_shards, len(values_list))
-
- # Exported keys include empty key buckets set to the empty_key
- self.assertAllEqual(set([-2, 10, 12]), set(keys_list[0].eval().flatten()))
- self.assertAllEqual(set([-2, 11]), set(keys_list[1].eval().flatten()))
- # Exported values include empty value buckets set to 0
- self.assertAllEqual(set([0, 2, 4]), set(values_list[0].eval().flatten()))
- self.assertAllEqual(set([0, 3]), set(values_list[1].eval().flatten()))
-
-
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 7143520e3f..415aa752ac 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -22,16 +22,14 @@ import collections
from six.moves import range
-from tensorflow.contrib.lookup import lookup_ops
+from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable
from tensorflow.python import summary
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.ops import internal_convert_to_tensor
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -41,233 +39,6 @@ from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
__all__ = ['SdcaModel']
-class _ShardedMutableDenseHashTable(lookup_ops.LookupInterface):
- """A sharded version of MutableDenseHashTable.
-
- It is designed to be interface compatible with LookupInterface and
- MutableDenseHashTable, with the exception of the export method, which is
- replaced by an export_sharded method.
-
- The _ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable
- internally. The shard is computed via the modulo operation on the key.
- """
-
- # TODO(andreasst): consider moving this to lookup_ops
-
- def __init__(self,
- key_dtype,
- value_dtype,
- default_value,
- empty_key,
- num_shards=1,
- name='ShardedMutableHashTable'):
- with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
- super(_ShardedMutableDenseHashTable, self).__init__(key_dtype,
- value_dtype, scope)
- table_shards = []
- for i in range(num_shards):
- table_shards.append(
- lookup_ops.MutableDenseHashTable(
- key_dtype=key_dtype,
- value_dtype=value_dtype,
- default_value=default_value,
- empty_key=empty_key,
- name='%s-%d-of-%d' % (name, i + 1, num_shards)))
- self._table_shards = table_shards
- # TODO(andreasst): add a value_shape() method to LookupInterface
- # pylint: disable=protected-access
- self._value_shape = self._table_shards[0]._value_shape
- # pylint: enable=protected-access
-
- @property
- def _num_shards(self):
- return len(self._table_shards)
-
- @property
- def table_shards(self):
- return self._table_shards
-
- def size(self, name=None):
- with ops.name_scope(name, 'sharded_mutable_hash_table_size'):
- sizes = [
- self._table_shards[i].size() for i in range(self._num_shards)
- ]
- return math_ops.add_n(sizes)
-
- def _shard_indices(self, keys):
- key_shape = keys.get_shape()
- if key_shape.ndims > 1:
- # If keys are a matrix (i.e. a single key is a vector), we use the first
- # element of each key vector to determine the shard.
- keys = array_ops.slice(keys, [0, 0], [key_shape[0].value, 1])
- keys = array_ops.reshape(keys, [-1])
- indices = math_ops.mod(math_ops.abs(keys), self._num_shards)
- return math_ops.cast(indices, dtypes.int32)
-
- def _check_keys(self, keys):
- if not keys.get_shape().is_fully_defined():
- raise ValueError('Key shape must be fully defined, got %s.' %
- keys.get_shape())
- if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2:
- raise ValueError('Expected a vector or matrix for keys, got %s.' %
- keys.get_shape())
-
- def lookup(self, keys, name=None):
- if keys.dtype != self._key_dtype:
- raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
- (self._key_dtype, keys.dtype))
- self._check_keys(keys)
- num_shards = self._num_shards
- if num_shards == 1:
- return self._table_shards[0].lookup(keys, name=name)
-
- shard_indices = self._shard_indices(keys)
- # TODO(andreasst): support 'keys' that are not vectors
- key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
- num_shards)
- value_shards = [
- self._table_shards[i].lookup(key_shards[i], name=name)
- for i in range(num_shards)
- ]
-
- num_keys = keys.get_shape().dims[0]
- original_indices = math_ops.range(num_keys)
- partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
- shard_indices,
- num_shards)
- result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
- result.set_shape(
- tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
- return result
-
- def insert(self, keys, values, name=None):
- self._check_keys(keys)
- num_shards = self._num_shards
- if num_shards == 1:
- return self._table_shards[0].insert(keys, values, name=name)
-
- shard_indices = self._shard_indices(keys)
- # TODO(andreasst): support 'keys' that are not vectors
- key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
- num_shards)
- value_shards = data_flow_ops.dynamic_partition(values, shard_indices,
- num_shards)
- return_values = [
- self._table_shards[i].insert(key_shards[i], value_shards[i], name=name)
- for i in range(num_shards)
- ]
-
- return control_flow_ops.group(*return_values)
-
- def export_sharded(self, name=None):
- """Returns lists of the keys and values tensors in the sharded table.
-
- Returns:
- A pair of lists with the first list containing the key tensors and the
- second list containing the value tensors from each shard.
- """
- keys_list = []
- values_list = []
- for table_shard in self._table_shards:
- exported_keys, exported_values = table_shard.export(name=name)
- keys_list.append(exported_keys)
- values_list.append(exported_values)
- return keys_list, values_list
-
-
-class SparseFeatureColumn(object):
- """Represents a sparse feature column.
-
- Contains three tensors representing a sparse feature column, they are
- example indices (int64), feature indices (int64), and feature values (float).
- Feature weights are optional, and are treated as 1.0f if missing.
-
- For example, consider a batch of 4 examples, which contains the following
- features in a particular SparseFeatureColumn:
- Example 0: feature 5, value 1
- Example 1: feature 6, value 1 and feature 10, value 0.5
- Example 2: no features
- Example 3: two copies of feature 2, value 1
-
- This SparseFeatureColumn will be represented as follows:
- <0, 5, 1>
- <1, 6, 1>
- <1, 10, 0.5>
- <3, 2, 1>
- <3, 2, 1>
-
- For a batch of 2 examples below:
- Example 0: feature 5
- Example 1: feature 6
-
- is represented by SparseFeatureColumn as:
- <0, 5, 1>
- <1, 6, 1>
-
- ```
-
- @@__init__
- @@example_indices
- @@feature_indices
- @@feature_values
- """
-
- def __init__(self, example_indices, feature_indices, feature_values):
- """Creates a `SparseFeatureColumn` representation.
-
- Args:
- example_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts
- python lists, or numpy arrays.
- feature_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts
- python lists, or numpy arrays.
- feature_values: An optional 1-D tensor float tensor of shape `[N]`. Also,
- accepts python lists, or numpy arrays.
-
- Returns:
- A `SparseFeatureColumn`
- """
- with name_scope(None, 'SparseFeatureColumn',
- [example_indices, feature_indices]):
- self._example_indices = internal_convert_to_tensor(example_indices,
- name='example_indices',
- dtype=dtypes.int64)
- self._feature_indices = internal_convert_to_tensor(feature_indices,
- name='feature_indices',
- dtype=dtypes.int64)
- self._feature_values = None
- if feature_values is not None:
- with name_scope(None, 'SparseFeatureColumn', [feature_values]):
- self._feature_values = internal_convert_to_tensor(feature_values,
- name='feature_values',
- dtype=dtypes.float32)
-
- @property
- def example_indices(self):
- """The example indices represented as a dense tensor.
-
- Returns:
- A 1-D Tensor of int64 with shape `[N]`.
- """
- return self._example_indices
-
- @property
- def feature_indices(self):
- """The feature indices represented as a dense tensor.
-
- Returns:
- A 1-D Tensor of int64 with shape `[N]`.
- """
- return self._feature_indices
-
- @property
- def feature_values(self):
- """The feature values represented as a dense tensor.
-
- Returns:
- May return None, or a 1-D Tensor of float32 with shape `[N]`.
- """
- return self._feature_values
-
# TODO(sibyl-Aix6ihai): add name_scope to appropriate methods.
class SdcaModel(object):
@@ -372,7 +143,7 @@ class SdcaModel(object):
self._variables = variables
self._options = options
self._create_slots()
- self._hashtable = _ShardedMutableDenseHashTable(
+ self._hashtable = ShardedMutableDenseHashTable(
key_dtype=dtypes.int64,
value_dtype=dtypes.float32,
num_shards=self._num_table_shards(),
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
new file mode 100644
index 0000000000..494dfb6c99
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py
@@ -0,0 +1,167 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sharded mutable dense hash table."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import range
+
+from tensorflow.contrib.lookup import lookup_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+class ShardedMutableDenseHashTable(lookup_ops.LookupInterface):
+ """A sharded version of MutableDenseHashTable.
+
+ It is designed to be interface compatible with LookupInterface and
+ MutableDenseHashTable, with the exception of the export method, which is
+ replaced by an export_sharded method.
+
+  The ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable
+  shards internally. The shard for a given key is computed via the modulo
+  operation on the key.
+ """
+
+ # TODO(andreasst): consider moving this to lookup_ops
+
+ def __init__(self,
+ key_dtype,
+ value_dtype,
+ default_value,
+ empty_key,
+ num_shards=1,
+ name='ShardedMutableHashTable'):
+ with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
+ super(ShardedMutableDenseHashTable, self).__init__(key_dtype,
+ value_dtype, scope)
+ table_shards = []
+ for i in range(num_shards):
+ table_shards.append(
+ lookup_ops.MutableDenseHashTable(
+ key_dtype=key_dtype,
+ value_dtype=value_dtype,
+ default_value=default_value,
+ empty_key=empty_key,
+ name='%s-%d-of-%d' % (name, i + 1, num_shards)))
+ self._table_shards = table_shards
+ # TODO(andreasst): add a value_shape() method to LookupInterface
+ # pylint: disable=protected-access
+ self._value_shape = self._table_shards[0]._value_shape
+ # pylint: enable=protected-access
+
+ @property
+ def _num_shards(self):
+ return len(self._table_shards)
+
+ @property
+ def table_shards(self):
+ return self._table_shards
+
+ def size(self, name=None):
+ with ops.name_scope(name, 'sharded_mutable_hash_table_size'):
+ sizes = [
+ self._table_shards[i].size() for i in range(self._num_shards)
+ ]
+ return math_ops.add_n(sizes)
+
+ def _shard_indices(self, keys):
+ key_shape = keys.get_shape()
+ if key_shape.ndims > 1:
+ # If keys are a matrix (i.e. a single key is a vector), we use the first
+ # element of each key vector to determine the shard.
+ keys = array_ops.slice(keys, [0, 0], [key_shape[0].value, 1])
+ keys = array_ops.reshape(keys, [-1])
+ indices = math_ops.mod(math_ops.abs(keys), self._num_shards)
+ return math_ops.cast(indices, dtypes.int32)
+
+ def _check_keys(self, keys):
+ if not keys.get_shape().is_fully_defined():
+ raise ValueError('Key shape must be fully defined, got %s.' %
+ keys.get_shape())
+ if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2:
+ raise ValueError('Expected a vector or matrix for keys, got %s.' %
+ keys.get_shape())
+
+ def lookup(self, keys, name=None):
+ if keys.dtype != self._key_dtype:
+ raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
+ (self._key_dtype, keys.dtype))
+ self._check_keys(keys)
+ num_shards = self._num_shards
+ if num_shards == 1:
+ return self._table_shards[0].lookup(keys, name=name)
+
+ shard_indices = self._shard_indices(keys)
+ # TODO(andreasst): support 'keys' that are not vectors
+ key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
+ num_shards)
+ value_shards = [
+ self._table_shards[i].lookup(key_shards[i], name=name)
+ for i in range(num_shards)
+ ]
+
+ num_keys = keys.get_shape().dims[0]
+ original_indices = math_ops.range(num_keys)
+ partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
+ shard_indices,
+ num_shards)
+ result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
+ result.set_shape(
+ tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
+ return result
+
+ def insert(self, keys, values, name=None):
+ self._check_keys(keys)
+ num_shards = self._num_shards
+ if num_shards == 1:
+ return self._table_shards[0].insert(keys, values, name=name)
+
+ shard_indices = self._shard_indices(keys)
+ # TODO(andreasst): support 'keys' that are not vectors
+ key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
+ num_shards)
+ value_shards = data_flow_ops.dynamic_partition(values, shard_indices,
+ num_shards)
+ return_values = [
+ self._table_shards[i].insert(key_shards[i], value_shards[i], name=name)
+ for i in range(num_shards)
+ ]
+
+ return control_flow_ops.group(*return_values)
+
+ def export_sharded(self, name=None):
+ """Returns lists of the keys and values tensors in the sharded table.
+
+ Args:
+      name: A name for the export ops.
+
+ Returns:
+ A pair of lists with the first list containing the key tensors and the
+ second list containing the value tensors from each shard.
+ """
+ keys_list = []
+ values_list = []
+ for table_shard in self._table_shards:
+ exported_keys, exported_values = table_shard.export(name=name)
+ keys_list.append(exported_keys)
+ values_list.append(exported_values)
+ return keys_list, values_list
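
A short usage sketch of the extracted class, mirroring the tests that follow (session-style API of this TF generation assumed):

import tensorflow as tf
from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable

# A 3-way sharded int64 -> int64 table; key k lands in shard abs(k) % 3.
table = ShardedMutableDenseHashTable(
    key_dtype=tf.int64, value_dtype=tf.int64,
    default_value=-1, empty_key=0, num_shards=3)

keys = tf.constant([11, 12, 13], tf.int64)
values = tf.constant([0, 1, 2], tf.int64)

with tf.Session() as sess:
  sess.run(table.insert(keys, values))
  # Missing key 14 falls back to default_value.
  print(sess.run(table.lookup(tf.constant([11, 14], tf.int64))))  # [ 0 -1]
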
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
new file mode 100644
index 0000000000..8c83700d51
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -0,0 +1,97 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sharded_mutable_dense_hashtable.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.platform import googletest
+
+
+class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
+  """Tests for the ShardedMutableDenseHashTable class."""
+
+ def testShardedMutableHashTable(self):
+ for num_shards in [1, 3, 10]:
+ with self.test_session():
+ default_val = -1
+ empty_key = 0
+ keys = tf.constant([11, 12, 13], tf.int64)
+ values = tf.constant([0, 1, 2], tf.int64)
+ table = ShardedMutableDenseHashTable(
+ tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
+ self.assertAllEqual(0, table.size().eval())
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant([11, 12, 14], tf.int64)
+ output = table.lookup(input_string)
+ self.assertAllEqual([3], output.get_shape())
+ self.assertAllEqual([0, 1, -1], output.eval())
+
+ def testShardedMutableHashTableVectors(self):
+ for num_shards in [1, 3, 10]:
+ with self.test_session():
+ default_val = [-0.1, 0.2]
+ empty_key = [0, 1]
+ keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64)
+ values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32)
+ table = ShardedMutableDenseHashTable(
+ tf.int64, tf.float32, default_val, empty_key, num_shards=num_shards)
+ self.assertAllEqual(0, table.size().eval())
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64)
+ output = table.lookup(input_string)
+ self.assertAllEqual([3, 2], output.get_shape())
+ self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]],
+ output.eval())
+
+ def testExportSharded(self):
+ with self.test_session():
+ empty_key = -2
+ default_val = -1
+ num_shards = 2
+ keys = tf.constant([10, 11, 12], tf.int64)
+ values = tf.constant([2, 3, 4], tf.int64)
+ table = ShardedMutableDenseHashTable(
+ tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
+ self.assertAllEqual(0, table.size().eval())
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ keys_list, values_list = table.export_sharded()
+ self.assertAllEqual(num_shards, len(keys_list))
+ self.assertAllEqual(num_shards, len(values_list))
+
+ # Exported keys include empty key buckets set to the empty_key
+ self.assertAllEqual(set([-2, 10, 12]), set(keys_list[0].eval().flatten()))
+ self.assertAllEqual(set([-2, 11]), set(keys_list[1].eval().flatten()))
+ # Exported values include empty value buckets set to 0
+ self.assertAllEqual(set([0, 2, 4]), set(values_list[0].eval().flatten()))
+ self.assertAllEqual(set([0, 3]), set(values_list[1].eval().flatten()))
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py
new file mode 100644
index 0000000000..ed7105b5c9
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py
@@ -0,0 +1,114 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sparse feature column."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework.ops import internal_convert_to_tensor
+from tensorflow.python.framework.ops import name_scope
+
+
+class SparseFeatureColumn(object):
+ """Represents a sparse feature column.
+
+  Contains three tensors representing a sparse feature column: example
+  indices (int64), feature indices (int64), and feature values (float).
+ Feature weights are optional, and are treated as 1.0f if missing.
+
+ For example, consider a batch of 4 examples, which contains the following
+ features in a particular SparseFeatureColumn:
+ Example 0: feature 5, value 1
+ Example 1: feature 6, value 1 and feature 10, value 0.5
+ Example 2: no features
+ Example 3: two copies of feature 2, value 1
+
+ This SparseFeatureColumn will be represented as follows:
+ <0, 5, 1>
+ <1, 6, 1>
+ <1, 10, 0.5>
+ <3, 2, 1>
+ <3, 2, 1>
+
+  Likewise, the following batch of 2 examples:
+ Example 0: feature 5
+ Example 1: feature 6
+
+ is represented by SparseFeatureColumn as:
+ <0, 5, 1>
+ <1, 6, 1>
+
+
+ @@__init__
+ @@example_indices
+ @@feature_indices
+ @@feature_values
+ """
+
+ def __init__(self, example_indices, feature_indices, feature_values):
+ """Creates a `SparseFeatureColumn` representation.
+
+ Args:
+      example_indices: A 1-D int64 tensor of shape `[N]`. Also accepts
+        python lists and numpy arrays.
+      feature_indices: A 1-D int64 tensor of shape `[N]`. Also accepts
+        python lists and numpy arrays.
+      feature_values: An optional 1-D float tensor of shape `[N]`. Also
+        accepts python lists and numpy arrays.
+
+ Returns:
+ A `SparseFeatureColumn`
+ """
+ with name_scope(None, 'SparseFeatureColumn',
+ [example_indices, feature_indices]):
+ self._example_indices = internal_convert_to_tensor(
+ example_indices, name='example_indices', dtype=dtypes.int64)
+ self._feature_indices = internal_convert_to_tensor(
+ feature_indices, name='feature_indices', dtype=dtypes.int64)
+ self._feature_values = None
+ if feature_values is not None:
+ with name_scope(None, 'SparseFeatureColumn', [feature_values]):
+ self._feature_values = internal_convert_to_tensor(
+ feature_values, name='feature_values', dtype=dtypes.float32)
+
+ @property
+ def example_indices(self):
+ """The example indices represented as a dense tensor.
+
+ Returns:
+ A 1-D Tensor of int64 with shape `[N]`.
+ """
+ return self._example_indices
+
+ @property
+ def feature_indices(self):
+ """The feature indices represented as a dense tensor.
+
+ Returns:
+ A 1-D Tensor of int64 with shape `[N]`.
+ """
+ return self._feature_indices
+
+ @property
+ def feature_values(self):
+ """The feature values represented as a dense tensor.
+
+ Returns:
+ May return None, or a 1-D Tensor of float32 with shape `[N]`.
+ """
+ return self._feature_values
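
And a usage sketch matching the batch-of-4 docstring example above, fed in as three parallel lists of <example, feature, value> triples:

import tensorflow as tf
from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn

# <0, 5, 1>, <1, 6, 1>, <1, 10, 0.5>, <3, 2, 1>, <3, 2, 1>
sfc = SparseFeatureColumn(
    example_indices=[0, 1, 1, 3, 3],
    feature_indices=[5, 6, 10, 2, 2],
    feature_values=[1.0, 1.0, 0.5, 1.0, 1.0])

with tf.Session() as sess:
  print(sess.run(sfc.example_indices))   # [0 1 1 3 3]
  print(sess.run(sfc.feature_values))    # [ 1.   1.   0.5  1.   1. ]
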
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
new file mode 100644
index 0000000000..f2e4ca0c88
--- /dev/null
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
@@ -0,0 +1,51 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sparse_feature_column.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.platform import googletest
+
+
+class SparseFeatureColumnTest(TensorFlowTestCase):
+ """Tests for SparseFeatureColumn.
+ """
+
+ def testBasic(self):
+ expected_example_indices = [1, 1, 1, 2]
+ expected_feature_indices = [0, 1, 2, 0]
+ sfc = SparseFeatureColumn(expected_example_indices,
+ expected_feature_indices, None)
+ self.assertTrue(isinstance(sfc.example_indices, tf.Tensor))
+ self.assertTrue(isinstance(sfc.feature_indices, tf.Tensor))
+ self.assertEqual(sfc.feature_values, None)
+ with self.test_session():
+ self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
+ self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())
+ expected_feature_values = [1.0, 2.0, 3.0, 4.0]
+ sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],
+ expected_feature_values)
+ with self.test_session():
+ self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 6ff4bf3175..644347f0b5 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -18,6 +18,7 @@ from __future__ import print_function
from tensorflow.contrib import layers
from tensorflow.contrib.linear_optimizer.python.ops import sdca_ops
+from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -86,7 +87,7 @@ class SDCAOptimizer(object):
sparse_values = array_ops.gather_nd(dense_tensor, sparse_indices)
# TODO(sibyl-Aix6ihai, sibyl-vie3Poto): Makes this efficient, as now SDCA supports
# very sparse features with weights and not weights.
- return sdca_ops.SparseFeatureColumn(
+ return SparseFeatureColumn(
array_ops.reshape(
array_ops.split(1, 2, sparse_indices)[0], [-1]),
array_ops.reshape(
@@ -134,7 +135,7 @@ class SDCAOptimizer(object):
columns_to_variables[column][0])
elif isinstance(column, (layers.feature_column._CrossedColumn,
layers.feature_column._SparseColumn)):
- sparse_features.append(sdca_ops.SparseFeatureColumn(
+ sparse_features.append(SparseFeatureColumn(
array_ops.reshape(
array_ops.split(1, 2, transformed_tensor.indices)[0], [-1]),
array_ops.reshape(transformed_tensor.values, [-1]), None))
@@ -142,7 +143,7 @@ class SDCAOptimizer(object):
elif isinstance(column, layers.feature_column._WeightedSparseColumn):
id_tensor = column.id_tensor(transformed_tensor)
weight_tensor = column.weight_tensor(transformed_tensor)
- sparse_feature_with_values.append(sdca_ops.SparseFeatureColumn(
+ sparse_feature_with_values.append(SparseFeatureColumn(
array_ops.reshape(
array_ops.split(1, 2, id_tensor.indices)[0], [-1]),
array_ops.reshape(id_tensor.values, [-1]), array_ops.reshape(
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 7610f9275f..c17b251d3e 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -117,9 +117,9 @@ def _safe_div(numerator, denominator, name="value"):
Returns:
The element-wise value of the numerator divided by the denominator.
"""
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(denominator, 0),
- math_ops.div(numerator, math_ops.select(
+ math_ops.div(numerator, array_ops.where(
math_ops.equal(denominator, 0),
array_ops.ones_like(denominator), denominator)),
array_ops.zeros_like(numerator),
@@ -144,12 +144,13 @@ def _safe_mean(losses, num_present):
@deprecated_args(
"2016-11-25", "`weight` is being deprecated, use `weights`.", "weight")
def compute_weighted_loss(
- losses, weights=_WEIGHT_SENTINEL, weight=_WEIGHT_SENTINEL):
+ losses, weights=_WEIGHT_SENTINEL, scope=None, weight=_WEIGHT_SENTINEL):
"""Computes the weighted loss.
Args:
losses: A tensor of size [batch_size, d1, ... dN].
weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
+ scope: the scope for the operations performed in computing the loss.
weight: Deprecated alias for `weights`.
Returns:
@@ -161,27 +162,28 @@ def compute_weighted_loss(
`weights` is missing.
"""
weights = _weights(weights, weight)
- losses = ops.convert_to_tensor(losses)
- input_dtype = losses.dtype
- losses = math_ops.to_float(losses)
- weights = math_ops.to_float(ops.convert_to_tensor(weights))
+ with ops.name_scope(scope, "weighted_loss", [losses, weights]):
+ losses = ops.convert_to_tensor(losses)
+ input_dtype = losses.dtype
+ losses = math_ops.to_float(losses)
+ weights = math_ops.to_float(ops.convert_to_tensor(weights))
- if losses.get_shape().ndims is None:
- raise ValueError("losses.get_shape().ndims cannot be None")
- weights_shape = weights.get_shape()
- if weights_shape.ndims is None:
- raise ValueError("weight.get_shape().ndims cannot be None")
+ if losses.get_shape().ndims is None:
+ raise ValueError("losses.get_shape().ndims cannot be None")
+ weights_shape = weights.get_shape()
+ if weights_shape.ndims is None:
+ raise ValueError("weight.get_shape().ndims cannot be None")
- if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1):
- weights = array_ops.squeeze(weights, [-1])
+ if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1):
+ weights = array_ops.squeeze(weights, [-1])
- total_loss = _scale_losses(losses, weights)
- num_present = _num_present(losses, weights)
- mean_loss = _safe_mean(total_loss, num_present)
- # convert the result back to the input type
- mean_loss = math_ops.cast(mean_loss, input_dtype)
- add_loss(mean_loss)
- return mean_loss
+ total_loss = _scale_losses(losses, weights)
+ num_present = _num_present(losses, weights)
+ mean_loss = _safe_mean(total_loss, num_present)
+ # convert the result back to the input type
+ mean_loss = math_ops.cast(mean_loss, input_dtype)
+ add_loss(mean_loss)
+ return mean_loss
def _num_present(losses, weights, per_batch=False):
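
The net effect of threading `scope` into `compute_weighted_loss` is that the reduction and cast ops it creates now land under the caller's name scope instead of the graph root, which keeps TensorBoard graphs readable. A hedged sketch of the call pattern (the loss values themselves are unchanged):

import tensorflow as tf

logits = tf.constant([[2.0, -1.0]])
labels = tf.constant([[1.0, 0.0]])

# Ops created inside compute_weighted_loss are grouped under
# "my_loss/..." because the enclosing scope is now passed through.
loss = tf.contrib.losses.sigmoid_cross_entropy(
    logits, labels, scope="my_loss")
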
@@ -211,7 +213,7 @@ def _num_present(losses, weights, per_batch=False):
[0], [1]), [])
num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
math_ops.to_float(batch_size))
- num_per_batch = math_ops.select(math_ops.equal(weights, 0),
+ num_per_batch = array_ops.where(math_ops.equal(weights, 0),
0.0, num_per_batch)
num_per_batch = math_ops.mul(array_ops.ones(
array_ops.reshape(batch_size, [1])), num_per_batch)
@@ -334,7 +336,7 @@ def absolute_difference(
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
losses = math_ops.abs(math_ops.sub(predictions, labels))
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -373,7 +375,7 @@ def sigmoid_cross_entropy(
"""
weights = _weights(weights, weight)
with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
- [logits, multi_class_labels, weights]):
+ [logits, multi_class_labels, weights]) as scope:
logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
@@ -384,7 +386,7 @@ def sigmoid_cross_entropy(
losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
name="xentropy")
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -421,7 +423,7 @@ def softmax_cross_entropy(
"""
weights = _weights(weights, weight)
with ops.name_scope(scope, "softmax_cross_entropy_loss",
- [logits, onehot_labels, weights]):
+ [logits, onehot_labels, weights]) as scope:
logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
@@ -435,7 +437,7 @@ def softmax_cross_entropy(
losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
name="xentropy")
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -468,13 +470,13 @@ def sparse_softmax_cross_entropy(
"""
weights = _weights(weights, weight)
with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
- [logits, labels, weights]):
+ [logits, labels, weights]) as scope:
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
weights = array_ops.squeeze(weights)
losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
name="xentropy")
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -523,7 +525,7 @@ def log_loss(
labels,
math_ops.log(predictions + epsilon)) - math_ops.mul(
(1 - labels), math_ops.log(1 - predictions + epsilon))
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -597,7 +599,7 @@ def mean_squared_error(
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
losses = math_ops.square(math_ops.sub(predictions, labels))
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -681,7 +683,7 @@ def mean_pairwise_squared_error(
loss = _scale_losses(term1 - term2, weights)
- mean_loss = math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0,
+ mean_loss = array_ops.where(math_ops.reduce_sum(num_present_per_batch) > 0,
loss,
array_ops.zeros_like(loss),
name="value")
@@ -732,4 +734,4 @@ def cosine_distance(
radial_diffs = math_ops.mul(predictions, labels)
losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,])
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
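The substitution of array_ops.where for the deprecated math_ops.select (repeated in metric_ops.py below) keeps the element-wise (condition, true_value, false_value) semantics, and compute_weighted_loss now threads a scope argument so each loss groups its ops under the caller's name scope. A minimal sketch of the resulting _safe_div pattern, assuming a TensorFlow build containing this commit; the tensor values are illustrative only:

    import tensorflow as tf
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import math_ops

    numerator = tf.constant([1.0, 2.0, 3.0])
    denominator = tf.constant([2.0, 0.0, 4.0])

    # Divide where denominator > 0, emit 0 elsewhere. The inner where()
    # swaps zeros for ones so the division itself never sees a zero.
    safe = array_ops.where(
        math_ops.greater(denominator, 0),
        math_ops.div(numerator,
                     array_ops.where(math_ops.equal(denominator, 0),
                                     array_ops.ones_like(denominator),
                                     denominator)),
        array_ops.zeros_like(numerator))

    with tf.Session() as sess:
        print(sess.run(safe))  # [0.5, 0.0, 0.75]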
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index 380f2b440a..4c0d32f115 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -204,6 +204,7 @@ ifeq ($(TARGET),PI)
endif
# Set up Android building
+# LINT.IfChange
ifeq ($(TARGET),ANDROID)
# Override NDK_ROOT on the command line with your own NDK location, e.g.
# make -f tensorflow/contrib/makefile/Makefile TARGET=ANDROID \
@@ -276,6 +277,7 @@ ifeq ($(TARGET),ANDROID)
endif
endif # ANDROID
+# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
# Settings for iOS.
ifeq ($(TARGET),IOS)
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index d07d1508a3..d39dc1d430 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -75,6 +75,8 @@ tensorflow/core/kernels/padding_fifo_queue.cc
tensorflow/core/kernels/pad_op.cc
tensorflow/core/kernels/pack_op.cc
tensorflow/core/kernels/ops_util.cc
+tensorflow/core/kernels/one_hot_op.cc
+tensorflow/core/kernels/non_max_suppression_op.cc
tensorflow/core/kernels/no_op.cc
tensorflow/core/kernels/mirror_pad_op.cc
tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 1199f8dd95..c6d6b50e90 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -52,7 +52,7 @@ def _safe_div(numerator, denominator, name):
Returns:
0 if `denominator` <= 0, else `numerator` / `denominator`
"""
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(denominator, 0),
math_ops.truediv(numerator, denominator),
0,
@@ -587,7 +587,7 @@ def streaming_precision(predictions, labels, weights=None,
updates_collections=None, name=None)
def compute_precision(name):
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(true_positives + false_positives, 0),
math_ops.div(true_positives, true_positives + false_positives),
0,
@@ -661,7 +661,7 @@ def streaming_recall(predictions, labels, weights=None,
updates_collections=None, name=None)
def compute_recall(true_positives, false_negatives, name):
- return math_ops.select(
+ return array_ops.where(
math_ops.greater(true_positives + false_negatives, 0),
math_ops.div(true_positives, true_positives + false_negatives),
0,
@@ -2388,7 +2388,7 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None,
predictions, normalizer = tensor_util.remove_squeezable_dimensions(
predictions, normalizer)
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
- relative_errors = math_ops.select(
+ relative_errors = array_ops.where(
math_ops.equal(normalizer, 0.0),
array_ops.zeros_like(labels),
math_ops.div(math_ops.abs(labels - predictions), normalizer))
@@ -2923,7 +2923,7 @@ def streaming_mean_iou(predictions,
# If the value of the denominator is 0, set it to 1 to avoid
# zero division.
- denominator = math_ops.select(
+ denominator = array_ops.where(
math_ops.greater(denominator, 0),
denominator,
array_ops.ones_like(denominator))
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 9890e712c1..80ca709b3a 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -998,7 +999,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
# pylint: disable=protected-access
-_linear = rnn_cell._linear
+_linear = rnn_cell_impl._linear
# pylint: enable=protected-access
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD
index 7750e54569..b6bfcc748a 100644
--- a/tensorflow/contrib/session_bundle/BUILD
+++ b/tensorflow/contrib/session_bundle/BUILD
@@ -254,13 +254,6 @@ cc_library(
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
-filegroup(
- name = "saved_model_half_plus_two",
- srcs = glob([
- "testdata/saved_model_half_plus_two/**",
- ]),
-)
-
cc_library(
name = "bundle_shim",
srcs = ["bundle_shim.cc"],
@@ -287,8 +280,8 @@ cc_test(
size = "small",
srcs = ["bundle_shim_test.cc"],
data = [
- ":saved_model_half_plus_two",
"//tensorflow/contrib/session_bundle/example:half_plus_two",
+ "//tensorflow/python/saved_model/example:saved_model_half_plus_two_data",
],
# Link in all registered kernels.
linkstatic = 1,
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 47a0935472..1ce2753c57 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -127,10 +127,16 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures,
AddOutputToSignatureDef(map_entry.second.tensor_name(), map_entry.first,
&signature_def);
}
- // Add the `default` key to the signature def map of the meta graph def and
- // map it to the constructed signature def.
- (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
- signature_def;
+ // Add the constructed signature def to the signature def map of the meta
+ // graph def. Use the default key if it isn't already in use.
+ const bool already_has_default_signature =
+ meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
+ meta_graph_def->signature_def().end();
+ const string signature_def_key =
+ already_has_default_signature
+ ? strings::StrCat(kDefaultServingSignatureDefKey, "_from_named")
+ : kDefaultServingSignatureDefKey;
+ (*meta_graph_def->mutable_signature_def())[signature_def_key] = signature_def;
return Status::OK();
}
@@ -138,9 +144,12 @@ Status ConvertSignaturesToSignatureDef(MetaGraphDef* meta_graph_def) {
Signatures signatures;
GetSignatures(*meta_graph_def, &signatures);
if (signatures.has_default_signature()) {
- return ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def);
- } else if (!signatures.named_signatures().empty()) {
- return ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def);
+ TF_RETURN_IF_ERROR(
+ ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def));
+ }
+ if (!signatures.named_signatures().empty()) {
+ TF_RETURN_IF_ERROR(
+ ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def));
}
return Status::OK();
}
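Taken together, the two bundle_shim.cc hunks mean both the default and the named signatures are now converted (previously one or the other), and a named-signature conversion no longer clobbers an existing default entry. A hedged Python sketch of the key selection, with a plain dict standing in for the protobuf map and an assumed literal for kDefaultServingSignatureDefKey:

    # 'serving_default' is an assumed stand-in for kDefaultServingSignatureDefKey.
    DEFAULT_KEY = 'serving_default'

    def pick_signature_def_key(signature_def_map):
        # Mirrors StrCat(kDefaultServingSignatureDefKey, "_from_named"):
        # fall back to a suffixed key when the default slot is taken.
        if DEFAULT_KEY in signature_def_map:
            return DEFAULT_KEY + '_from_named'
        return DEFAULT_KEY

    sig_map = {DEFAULT_KEY: 'converted default signature'}
    sig_map[pick_signature_def_key(sig_map)] = 'converted named signature'
    print(sorted(sig_map))  # ['serving_default', 'serving_default_from_named']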
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index cfdd05e608..a8dca12195 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -35,7 +35,7 @@ constexpr char kSessionBundlePath[] =
constexpr char kSessionBundleMetaGraphFilename[] = "export.meta";
constexpr char kSessionBundleVariablesFilename[] = "export-00000-of-00001";
constexpr char kSavedModelBundlePath[] =
- "contrib/session_bundle/testdata/saved_model_half_plus_two";
+ "python/saved_model/example/saved_model_half_plus_two/00000123";
string MakeSerializedExample(float x) {
tensorflow::Example example;
@@ -72,16 +72,20 @@ void LoadAndValidateSavedModelBundle(const string& export_dir,
session_options, run_options, export_dir, tags, &saved_model_bundle));
const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
const auto& signature_def_map = meta_graph_def.signature_def();
- EXPECT_EQ(1, signature_def_map.size());
const auto& regression_entry = signature_def_map.find(signature_def_key);
+ ASSERT_FALSE(regression_entry == signature_def_map.end());
SignatureDef regression_signature_def = regression_entry->second;
EXPECT_EQ(1, regression_signature_def.inputs_size());
+ ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) ==
+ regression_signature_def.inputs().end());
TensorInfo input_tensor_info =
regression_signature_def.inputs().find(kRegressInputs)->second;
EXPECT_EQ(1, regression_signature_def.outputs_size());
+ ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) ==
+ regression_signature_def.outputs().end());
TensorInfo output_tensor_info =
regression_signature_def.outputs().find(kRegressOutputs)->second;
ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(),
@@ -261,9 +265,14 @@ TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) {
EXPECT_EQ(1, meta_graph_def.signature_def_size());
const auto actual_signature_def =
meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
+ ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end());
+ ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") ==
+ actual_signature_def->second.inputs().end());
EXPECT_EQ(
"foo-input",
actual_signature_def->second.inputs().find("foo-input")->second.name());
+ ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") ==
+ actual_signature_def->second.outputs().end());
EXPECT_EQ(
"foo-output",
actual_signature_def->second.outputs().find("foo-output")->second.name());
@@ -318,10 +327,40 @@ TEST(BundleShimTest, NamedSignatureGenericOnlyInput) {
// Checks a basic up conversion for half plus two for SessionBundle.
TEST(BundleShimTest, BasicExportSessionBundle) {
+ const std::unordered_set<string> tags = {"tag"};
const string session_bundle_export_dir =
test_util::TestSrcDirPath(kSessionBundlePath);
- LoadAndValidateSavedModelBundle(session_bundle_export_dir, {"tag"},
+ LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
kDefaultServingSignatureDefKey);
+
+ // Verify that the named signature is also present.
+ SessionOptions session_options;
+ RunOptions run_options;
+ SavedModelBundle saved_model_bundle;
+ TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
+ session_bundle_export_dir,
+ tags, &saved_model_bundle));
+ const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
+ const auto& signature_def_map = meta_graph_def.signature_def();
+ bool found_named_signature = false;
+ for (const auto& entry : signature_def_map) {
+ const string& key = entry.first;
+ const SignatureDef& signature_def = entry.second;
+
+ // We're looking for the key that is *not* kDefaultServingSignatureDefKey.
+ if (key == kDefaultServingSignatureDefKey) {
+ continue;
+ }
+ found_named_signature = true;
+
+ EXPECT_EQ(1, signature_def.inputs_size());
+ EXPECT_FALSE(signature_def.inputs().find("x") ==
+ signature_def.inputs().end());
+ EXPECT_EQ(1, signature_def.outputs_size());
+ EXPECT_FALSE(signature_def.outputs().find("y") ==
+ signature_def.outputs().end());
+ }
+ EXPECT_TRUE(found_named_signature);
}
// Checks a basic load for half plus two for SavedModelBundle.
diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc
index 2b608a1348..bc6fdcd4de 100644
--- a/tensorflow/contrib/session_bundle/session_bundle.cc
+++ b/tensorflow/contrib/session_bundle/session_bundle.cc
@@ -43,12 +43,13 @@ namespace serving {
namespace {
auto* load_attempt_count = monitoring::Counter<2>::New(
- "/tensorflow/contrib/session_bundle/load_attempt_count", "model_path",
- "status",
- "The number of times a SessionBundle was requested to be loaded.");
+ "/tensorflow/contrib/session_bundle/load_attempt_count",
+ "The number of times a SessionBundle was requested to be loaded.",
+ "model_path", "status");
auto* load_latency = monitoring::Counter<1>::New(
- "/tensorflow/contrib/session_bundle/load_latency", "model_path",
- "Latency in microseconds for SessionBundles that were successfully loaded.");
+ "/tensorflow/contrib/session_bundle/load_latency",
+ "Latency in microseconds for SessionBundles that were successfully loaded.",
+ "model_path");
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";
diff --git a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/saved_model.pb b/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/saved_model.pb
deleted file mode 100644
index e894f9b101..0000000000
--- a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/saved_model.pb
+++ /dev/null
Binary files differ
diff --git a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.data-00000-of-00001 b/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.data-00000-of-00001
deleted file mode 100644
index 20bc7d454d..0000000000
--- a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.data-00000-of-00001
+++ /dev/null
Binary files differ
diff --git a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index b/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index
deleted file mode 100644
index e7df518f5b..0000000000
--- a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index
+++ /dev/null
Binary files differ
diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py
index b122799a32..b89eca46ea 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation.py
@@ -86,7 +86,7 @@ more summaries and call the evaluation_loop method:
logdir,
num_evals=num_evals,
eval_op=names_to_updates.values(),
- summary_op=tf.merge_summary(summary_ops),
+ summary_op=tf.contrib.deprecated.merge_summary(summary_ops),
eval_interval_secs=600)
**************************************************
@@ -113,7 +113,7 @@ with only summaries. The user need only leave out the 'eval_op' argument:
checkpoint_dir,
logdir,
num_evals=1,
- summary_op=tf.merge_summary(summary_ops),
+ summary_op=tf.contrib.deprecated.merge_summary(summary_ops),
eval_interval_secs=600)
"""
@@ -122,160 +122,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import time
-
-from tensorflow.contrib.framework.python.ops import variables
+from tensorflow.contrib.training.python.training import evaluation
from tensorflow.python import summary
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver as tf_saver
-from tensorflow.python.training import summary_io
-from tensorflow.python.training import supervisor
-from tensorflow.python.training import training_util
+from tensorflow.python.training import monitored_session
__all__ = [
'evaluate_once',
- 'evaluation',
'evaluation_loop',
'wait_for_new_checkpoint',
'checkpoints_iterator',
]
-
-def wait_for_new_checkpoint(checkpoint_dir,
- last_checkpoint,
- seconds_to_sleep=1,
- timeout=None):
- """Waits until a new checkpoint file is found.
-
- Args:
- checkpoint_dir: The directory in which checkpoints are saved.
- last_checkpoint: The last checkpoint path used.
- seconds_to_sleep: The number of seconds to sleep for before looking for a
- new checkpoint.
- timeout: The maximum amount of time to wait. If left as `None`, then the
- process will wait indefinitely.
-
- Returns:
- a new checkpoint path, or None if the timeout was reached.
- """
- logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
- stop_time = time.time() + timeout if timeout is not None else None
- while True:
- checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
- if checkpoint_path is None or checkpoint_path == last_checkpoint:
- if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
- return None
- time.sleep(seconds_to_sleep)
- else:
- logging.info('Found new checkpoint at %s', checkpoint_path)
- return checkpoint_path
-
-
-def checkpoints_iterator(checkpoint_dir,
- min_interval_secs=0,
- timeout=None):
- """Continuously yield new checkpoint files as they appear.
-
- The iterator only checks for new checkpoints when control flow has been
- reverted to it. This means it can miss checkpoints if your code takes longer
- to run between iterations than `min_interval_secs` or the interval at which
- new checkpoints are written.
-
- Args:
- checkpoint_dir: The directory in which checkpoints are saved.
- min_interval_secs: The minimum number of seconds between yielding
- checkpoints.
- timeout: The maximum amount of time to wait between checkpoints. If left as
- `None`, then the process will wait indefinitely.
-
- Yields:
- String paths to latest checkpoint files as they arrive. Stops yielding only
- if/when waiting for a checkpoint times out.
- """
- checkpoint_path = None
- while True:
- checkpoint_path = wait_for_new_checkpoint(
- checkpoint_dir, checkpoint_path, timeout=timeout)
- if checkpoint_path is None:
- # timed out
- return
- start = time.time()
- yield checkpoint_path
- time_to_next_eval = start + min_interval_secs - time.time()
- if time_to_next_eval > 0:
- time.sleep(time_to_next_eval)
-
-
-def evaluation(sess,
- num_evals=1,
- initial_op=None,
- initial_op_feed_dict=None,
- eval_op=None,
- eval_op_feed_dict=None,
- final_op=None,
- final_op_feed_dict=None,
- summary_op=None,
- summary_op_feed_dict=None,
- summary_writer=None,
- global_step=None):
- """Performs a single evaluation run.
-
- A single evaluation consists of several steps run in the following order:
- (1) an initialization op, (2) an evaluation op which is executed `num_evals`
- times (3) a finalization op and (4) the execution of a summary op which is
- written out using a summary writer.
-
- Args:
- sess: The current TensorFlow `Session`.
- num_evals: The number of times to execute `eval_op`.
- initial_op: An operation run at the beginning of evaluation.
- initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
- eval_op: A operation run `num_evals` times.
- eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
- final_op: An operation to execute after all of the `eval_op` executions. The
- value of `final_op` is returned.
- final_op_feed_dict: A feed dictionary to use when executing `final_op`.
- summary_op: A summary op executed after `eval_op` and `finalize_op`.
- summary_op_feed_dict: An optional feed dictionary to use when executing the
- `summary_op`.
- summary_writer: The summery writer used if `summary_op` is provided.
- global_step: the global step variable. If left as `None`, then
- slim.variables.global_step() is used.
-
- Returns:
- The value of `final_op` or `None` if `final_op` is `None`.
-
- Raises:
- ValueError: if `summary_op` is provided but `global_step` is `None`.
- """
- if initial_op is not None:
- logging.info('Executing initial eval op')
- sess.run(initial_op, initial_op_feed_dict)
-
- if eval_op is not None:
- logging.info('Executing eval ops')
- for i in range(int(num_evals)):
- logging.info('Executing eval_op %d/%d', i + 1, num_evals)
- sess.run(eval_op, eval_op_feed_dict)
-
- if final_op is not None:
- logging.info('Executing final op')
- final_op_value = sess.run(final_op, final_op_feed_dict)
- else:
- final_op_value = None
-
- if summary_op is not None:
- logging.info('Executing summary op')
- if global_step is None:
- global_step = variables.get_or_create_global_step()
-
- global_step = training_util.global_step(sess, global_step)
- summary_str = sess.run(summary_op, summary_op_feed_dict)
- summary_writer.add_summary(summary_str, global_step)
- summary_writer.flush()
-
- return final_op_value
+wait_for_new_checkpoint = evaluation.wait_for_new_checkpoint
+checkpoints_iterator = evaluation.checkpoints_iterator
_USE_DEFAULT = 0
@@ -325,43 +184,27 @@ def evaluate_once(master,
if summary_op == _USE_DEFAULT:
summary_op = summary.merge_all()
- global_step = variables.get_or_create_global_step()
-
- saver = tf_saver.Saver(variables_to_restore or
- variables.get_variables_to_restore())
-
- summary_writer = summary_io.SummaryWriter(logdir)
-
- sv = supervisor.Supervisor(graph=ops.get_default_graph(),
- logdir=logdir,
- summary_op=None,
- summary_writer=None,
- global_step=None,
- saver=None)
-
- logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
- time.gmtime()))
- with sv.managed_session(
- master, start_standard_services=False, config=session_config) as sess:
- saver.restore(sess, checkpoint_path)
- sv.start_queue_runners(sess)
- final_op_value = evaluation(sess,
- num_evals=num_evals,
- initial_op=initial_op,
- initial_op_feed_dict=initial_op_feed_dict,
- eval_op=eval_op,
- eval_op_feed_dict=eval_op_feed_dict,
- final_op=final_op,
- final_op_feed_dict=final_op_feed_dict,
- summary_op=summary_op,
- summary_op_feed_dict=summary_op_feed_dict,
- summary_writer=summary_writer,
- global_step=global_step)
-
- logging.info('Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
- time.gmtime()))
-
- return final_op_value
+ hooks = [
+ evaluation.StopAfterNEvalsHook(num_evals),
+ ]
+
+ if summary_op is not None:
+ hooks.append(
+ evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict))
+
+ return evaluation.evaluate_once(
+ checkpoint_path,
+ master=master,
+ scaffold=monitored_session.Scaffold(
+ init_op=initial_op,
+ init_feed_dict=initial_op_feed_dict),
+ eval_ops=eval_op,
+ feed_dict=eval_op_feed_dict,
+ final_ops=final_op,
+ final_ops_feed_dict=final_op_feed_dict,
+ variables_to_restore=variables_to_restore,
+ hooks=hooks,
+ config=session_config)
def evaluation_loop(master,
@@ -416,53 +259,27 @@ def evaluation_loop(master,
if summary_op == _USE_DEFAULT:
summary_op = summary.merge_all()
- global_step = variables.get_or_create_global_step()
-
- saver = tf_saver.Saver(variables_to_restore or
- variables.get_variables_to_restore())
-
- summary_writer = summary_io.SummaryWriter(logdir)
-
- sv = supervisor.Supervisor(graph=ops.get_default_graph(),
- logdir=logdir,
- summary_op=None,
- summary_writer=None,
- global_step=None,
- saver=saver)
-
- number_of_evaluations = 0
- for checkpoint_path in checkpoints_iterator(checkpoint_dir,
- eval_interval_secs,
- timeout):
- logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
- time.gmtime()))
-
- with sv.managed_session(
- master, start_standard_services=False, config=session_config) as sess:
- sv.saver.restore(sess, checkpoint_path)
- sv.start_queue_runners(sess)
- final_op_value = evaluation(sess,
- num_evals=num_evals,
- initial_op=initial_op,
- initial_op_feed_dict=initial_op_feed_dict,
- eval_op=eval_op,
- eval_op_feed_dict=eval_op_feed_dict,
- final_op=final_op,
- final_op_feed_dict=final_op_feed_dict,
- summary_op=summary_op,
- summary_op_feed_dict=summary_op_feed_dict,
- summary_writer=summary_writer,
- global_step=global_step)
-
- logging.info('Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
- time.gmtime()))
- number_of_evaluations += 1
- if (max_number_of_evaluations and
- number_of_evaluations >= max_number_of_evaluations):
- logging.info('Reached max_number_of_evaluations=%s. Exit',
- max_number_of_evaluations)
- return final_op_value
-
- logging.info(
- 'Timed-out waiting for new checkpoint file. Exiting evaluation loop.')
- return final_op_value
+ hooks = [
+ evaluation.StopAfterNEvalsHook(num_evals),
+ ]
+
+ if summary_op is not None:
+ hooks.append(
+ evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict))
+
+ return evaluation.evaluate_repeatedly(
+ checkpoint_dir,
+ master=master,
+ scaffold=monitored_session.Scaffold(
+ init_op=initial_op,
+ init_feed_dict=initial_op_feed_dict),
+ eval_ops=eval_op,
+ feed_dict=eval_op_feed_dict,
+ final_ops=final_op,
+ final_ops_feed_dict=final_op_feed_dict,
+ variables_to_restore=variables_to_restore,
+ eval_interval_secs=eval_interval_secs,
+ hooks=hooks,
+ config=session_config,
+ max_number_of_evaluations=max_number_of_evaluations,
+ timeout=timeout)
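The public slim entry points keep their signatures; they now simply build the hooks and forward to tf.contrib.training's evaluation routines. A hedged usage sketch modeled on the surviving evaluation_loop test, with placeholder directories (a real checkpoint must already exist under the checkpoint dir):

    import tensorflow as tf
    slim = tf.contrib.slim

    # Illustrative metric; any streaming metric's update op works here.
    labels = tf.constant([1, 0, 1], dtype=tf.int64)
    predictions = tf.constant([1, 0, 1], dtype=tf.int64)
    accuracy, update_op = slim.metrics.streaming_accuracy(predictions, labels)

    # Internally this now wraps a StopAfterNEvalsHook (plus an optional
    # SummaryAtEndHook) around evaluation.evaluate_repeatedly.
    slim.evaluation.evaluation_loop(
        '', '/tmp/chkpt_dir', '/tmp/eval_logs',
        num_evals=1, eval_op=update_op,
        eval_interval_secs=2.0, timeout=6.0)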
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index b12c82e985..a308a515bd 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -73,28 +73,6 @@ class EvaluationTest(tf.test.TestCase):
self._labels = tf.constant(labels, dtype=tf.int64)
self._predictions, self._scale = TestModel(self._inputs)
- def testUpdateOpsAreEvaluated(self):
- accuracy, update_op = slim.metrics.streaming_accuracy(
- self._predictions, self._labels)
- initial_op = tf.group(tf.global_variables_initializer(),
- tf.local_variables_initializer())
-
- with self.test_session() as sess:
- slim.evaluation.evaluation(
- sess, initial_op=initial_op, eval_op=update_op)
- self.assertAlmostEqual(accuracy.eval(), self._expected_accuracy)
-
- def testFinalOpsIsEvaluated(self):
- _, update_op = slim.metrics.streaming_accuracy(
- self._predictions, self._labels)
- initial_op = tf.group(tf.global_variables_initializer(),
- tf.local_variables_initializer())
-
- with self.test_session() as sess:
- accuracy_value = slim.evaluation.evaluation(
- sess, initial_op=initial_op, final_op=update_op)
- self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
-
def testFinalOpsOnEvaluationLoop(self):
value_op, update_op = slim.metrics.streaming_accuracy(
self._predictions, self._labels)
@@ -153,96 +131,6 @@ class EvaluationTest(tf.test.TestCase):
for name in names_to_values:
self.assertAlmostEqual(names_to_values[name], saved_results[name])
- def testSummariesAreFlushedToDisk(self):
- output_dir = os.path.join(self.get_temp_dir(), 'flush_test')
- if tf.gfile.Exists(output_dir): # For running on jenkins.
- tf.gfile.DeleteRecursively(output_dir)
-
- names_to_metrics, names_to_updates = self._create_names_to_metrics(
- self._predictions, self._labels)
-
- for k in names_to_metrics:
- v = names_to_metrics[k]
- tf.summary.scalar(k, v)
-
- summary_writer = tf.train.SummaryWriter(output_dir)
-
- initial_op = tf.group(tf.global_variables_initializer(),
- tf.local_variables_initializer())
- eval_op = tf.group(*names_to_updates.values())
-
- with self.test_session() as sess:
- slim.evaluation.evaluation(
- sess,
- initial_op=initial_op,
- eval_op=eval_op,
- summary_op=tf.summary.merge_all(),
- summary_writer=summary_writer,
- global_step=self._global_step)
-
- names_to_values = {name: names_to_metrics[name].eval()
- for name in names_to_metrics}
- self._verify_summaries(output_dir, names_to_values)
-
- def testSummariesAreFlushedToDiskWithoutGlobalStep(self):
- output_dir = os.path.join(self.get_temp_dir(), 'flush_test_no_global_step')
- if tf.gfile.Exists(output_dir): # For running on jenkins.
- tf.gfile.DeleteRecursively(output_dir)
-
- names_to_metrics, names_to_updates = self._create_names_to_metrics(
- self._predictions, self._labels)
-
- for k in names_to_metrics:
- v = names_to_metrics[k]
- tf.summary.scalar(k, v)
-
- summary_writer = tf.train.SummaryWriter(output_dir)
-
- initial_op = tf.group(tf.global_variables_initializer(),
- tf.local_variables_initializer())
- eval_op = tf.group(*names_to_updates.values())
-
- with self.test_session() as sess:
- slim.evaluation.evaluation(
- sess,
- initial_op=initial_op,
- eval_op=eval_op,
- summary_op=tf.summary.merge_all(),
- summary_writer=summary_writer)
-
- names_to_values = {name: names_to_metrics[name].eval()
- for name in names_to_metrics}
- self._verify_summaries(output_dir, names_to_values)
-
- def testWithFeedDict(self):
- accuracy, update_op = slim.metrics.streaming_accuracy(
- self._predictions, self._labels)
- initial_op = tf.group(tf.global_variables_initializer(),
- tf.local_variables_initializer())
-
- with self.test_session() as sess:
- slim.evaluation.evaluation(
- sess,
- initial_op=initial_op,
- eval_op=update_op,
- eval_op_feed_dict={self._scale: np.ones([], dtype=np.float32)})
- self.assertAlmostEqual(accuracy.eval(), self._expected_accuracy)
-
- def testWithQueueRunning(self):
- strings = ['the', 'cat', 'in', 'the', 'hat']
- _ = tf.train.string_input_producer(strings, capacity=5)
-
- accuracy, update_op = slim.metrics.streaming_accuracy(
- self._predictions, self._labels)
-
- initial_op = tf.group(tf.global_variables_initializer(),
- tf.local_variables_initializer())
-
- with self.test_session() as sess:
- slim.evaluation.evaluation(
- sess, initial_op=initial_op, eval_op=update_op)
- self.assertAlmostEqual(accuracy.eval(), self._expected_accuracy)
-
def testLatestCheckpointReturnsNoneAfterTimeout(self):
start = time.time()
ret = slim.evaluation.wait_for_new_checkpoint(
@@ -259,38 +147,6 @@ class EvaluationTest(tf.test.TestCase):
'/non-existent-dir', timeout=0))
self.assertEqual(ret, [])
- def testEvaluationLoopTimeout(self):
- _, update_op = slim.metrics.streaming_accuracy(
- self._predictions, self._labels)
- init_op = tf.group(tf.global_variables_initializer(),
- tf.local_variables_initializer())
-
- # Create checkpoint and log directories.
- chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
- gfile.MakeDirs(chkpt_dir)
- logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
- gfile.MakeDirs(logdir)
-
- # Save initialized variables to checkpoint directory.
- saver = tf.train.Saver()
- with self.test_session() as sess:
- init_op.run()
- saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))
-
- # Run the evaluation loop with a timeout.
- with self.test_session() as sess:
- start = time.time()
- slim.evaluation.evaluation_loop(
- '', chkpt_dir, logdir, eval_op=update_op,
- eval_interval_secs=2.0, timeout=6.0)
- end = time.time()
-
- # Check we've waited for the timeout.
- self.assertGreater(end - start, 6.0)
-
- # Then the timeout kicked in and stops the loop.
- self.assertLess(end - start, 8.0)
-
class SingleEvaluationTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
index 77d7f4290d..626b6a1daf 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc
@@ -214,26 +214,15 @@ class TreePredictions : public OpKernel {
errors::InvalidArgument("node_index not in valid range."))
const int32 left_child = tree(node_index, CHILDREN_INDEX);
if (left_child == LEAF_NODE) {
- float sum = node_pcw(node_index, 0);
- float parent_weight = 0.0;
- if (sum < valid_leaf_threshold_ && parent >= 0) {
- VLOG(1) << "not enough samples at leaf, including parent counts."
- << "child sum = " << sum;
- float parent_sum = node_pcw(parent, 0);
- // Weight the parent's counts just enough so that the new sum is
- // valid_leaf_threshold_, but never give any counts a weight of
- // more than 1.
- parent_weight = std::min(1.0f,
- (valid_leaf_threshold_ - sum) / parent_sum);
- sum += parent_weight * parent_sum;
- VLOG(1) << "Sum w/ parent included = " << sum;
- }
+ const int32 flat_leaf_index = node_index * num_classes + 1;
+ const int32 flat_parent_index = parent * num_classes + 1;
+ std::vector<float> means(num_classes - 1);
+ tensorforest::GetParentWeightedMean(
+ node_pcw(node_index, 0), node_pcw.data() + flat_leaf_index,
+ node_pcw(parent, 0), node_pcw.data() + flat_parent_index,
+ valid_leaf_threshold_, num_classes - 1, &means);
for (int c = 1; c < num_classes; c++) {
- float w = node_pcw(node_index, c);
- if (parent_weight > 0.0) {
- w += parent_weight * node_pcw(parent, c);
- }
- out(i, c - 1) = w / sum;
+ out(i, c - 1) = means[c - 1];
}
break;
} else if (left_child == FREE_NODE) {
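The smoothing previously inlined in the leaf branch moves to tensorforest::GetParentWeightedMean (defined in tree_utils.cc below). Its arithmetic, restated as a hedged Python sketch that mirrors the C++, including its assumption that parent_sum is positive whenever smoothing triggers:

    def parent_weighted_mean(leaf_sum, leaf_counts, parent_sum, parent_counts,
                             valid_leaf_threshold):
        # Weight the parent's counts just enough for the new sum to reach
        # valid_leaf_threshold, but never weight any count by more than 1.
        parent_weight = 0.0
        if leaf_sum < valid_leaf_threshold and parent_sum >= 0:
            parent_weight = min(
                1.0, (valid_leaf_threshold - leaf_sum) / parent_sum)
            leaf_sum += parent_weight * parent_sum
        return [(c + parent_weight * p) / leaf_sum
                for c, p in zip(leaf_counts, parent_counts)]

    # leaf sum 2.0 is below threshold 5.0, so 0.3 of the parent blends in.
    print(parent_weighted_mean(2.0, [1.0, 1.0], 10.0, [8.0, 2.0], 5.0))
    # [0.68, 0.32]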
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
index 544336b1ba..5f538c9e41 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc
@@ -555,6 +555,31 @@ bool IsAllInitialized(const Tensor& features) {
return feature_vec(feature_vec.size() - 1) >= 0;
}
+void GetParentWeightedMean(float leaf_sum, const float* leaf_data,
+ float parent_sum, const float* parent_data,
+ float valid_leaf_threshold, int num_outputs,
+ std::vector<float>* mean) {
+ float parent_weight = 0.0;
+ if (leaf_sum < valid_leaf_threshold && parent_sum >= 0) {
+ VLOG(1) << "not enough samples at leaf, including parent counts."
+ << "child sum = " << leaf_sum;
+ // Weight the parent's counts just enough so that the new sum is
+ // valid_leaf_threshold, but never give any counts a weight of
+ // more than 1.
+ parent_weight =
+ std::min(1.0f, (valid_leaf_threshold - leaf_sum) / parent_sum);
+ leaf_sum += parent_weight * parent_sum;
+ VLOG(1) << "Sum w/ parent included = " << leaf_sum;
+ }
+
+ for (int c = 0; c < num_outputs; c++) {
+ float w = leaf_data[c];
+ if (parent_weight > 0.0) {
+ w += parent_weight * parent_data[c];
+ }
+ (*mean)[c] = w / leaf_sum;
+ }
+}
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
index 7c7193f0f4..a17622d8f5 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
@@ -228,6 +228,11 @@ inline bool CheckTensorBounds(OpKernelContext* context, const Tensor& tensor) {
return true;
}
+void GetParentWeightedMean(float leaf_sum, const float* leaf_data,
+ float parent_sum, const float* parent_data,
+ float valid_leaf_threshold, int num_outputs,
+ std::vector<float>* mean);
+
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD
index 94b7222737..3321abb7e9 100644
--- a/tensorflow/contrib/tensorboard/BUILD
+++ b/tensorflow/contrib/tensorboard/BUILD
@@ -30,7 +30,10 @@ py_library(
name = "plugins",
srcs = ["plugins/__init__.py"],
srcs_version = "PY2AND3",
- deps = [":projector"],
+ deps = [
+ ":projector",
+ ":trace",
+ ],
)
# API methods and protos in `tf.contrib.tensorboard.plugins.projector` package.
@@ -55,6 +58,31 @@ py_test(
],
)
+# API methods and protos in `tf.contrib.tensorboard.plugins.trace` package.
+py_library(
+ name = "trace",
+ srcs = glob(
+ ["plugins/trace/**/*.py"],
+ exclude = ["**/*test*"],
+ ),
+ srcs_version = "PY2AND3",
+ deps = [
+ ":protos_all_py",
+ "//tensorflow/python:lib",
+ ],
+)
+
+py_test(
+ name = "trace_test",
+ size = "small",
+ srcs = ["plugins/trace/trace_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":trace",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/tensorboard/plugins/__init__.py b/tensorflow/contrib/tensorboard/plugins/__init__.py
index 88336714a7..41aa77910c 100644
--- a/tensorflow/contrib/tensorboard/plugins/__init__.py
+++ b/tensorflow/contrib/tensorboard/plugins/__init__.py
@@ -20,3 +20,4 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.tensorboard.plugins
from tensorflow.contrib.tensorboard.plugins import projector
+from tensorflow.contrib.tensorboard.plugins import trace
diff --git a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py
index 284f3ba24e..6bb310db3e 100644
--- a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py
+++ b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py
@@ -38,7 +38,7 @@ class ProjectorApiTest(tf.test.TestCase):
# Call the API method to save the configuration to a temporary dir.
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
- writer = tf.train.SummaryWriter(temp_dir)
+ writer = tf.summary.FileWriter(temp_dir)
tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer,
config)
diff --git a/tensorflow/python/util/net_lib_test.py b/tensorflow/contrib/tensorboard/plugins/trace/__init__.py
index 1e2ad53cda..2c99f4077e 100644
--- a/tensorflow/python/util/net_lib_test.py
+++ b/tensorflow/contrib/tensorboard/plugins/trace/__init__.py
@@ -12,28 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Public API for the Trace plugin."""
-"""Tests for the SWIG-wrapped test lib."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
-
-from tensorflow.python.util import net_lib
-
-
-class TestLibTest(tf.test.TestCase):
-
- def testPickUnusedPortOrDie(self):
- port0 = net_lib.pick_unused_port_or_die()
- port1 = net_lib.pick_unused_port_or_die()
- self.assertGreater(port0, 0)
- self.assertLess(port0, 65536)
- self.assertGreater(port1, 0)
- self.assertLess(port1, 65536)
- self.assertNotEqual(port0, port1)
-
-
-if __name__ == "__main__":
- tf.test.main()
+# pylint: disable=wildcard-import
+from tensorflow.contrib.tensorboard.plugins.trace.trace import *
+from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import *
+# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace.py b/tensorflow/contrib/tensorboard/plugins/trace/trace.py
new file mode 100644
index 0000000000..0c645889af
--- /dev/null
+++ b/tensorflow/contrib/tensorboard/plugins/trace/trace.py
@@ -0,0 +1,162 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stores debugging information regarding TensorFlow model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import parser
+import re
+import token
+import tensorflow as tf
+
+from google.protobuf import json_format
+from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import TraceInfo
+
+# List of regex patterns that match files in the core tensorflow library.
+TF_LIB_REGEX_FPATHS = [os.sep + os.path.join('tensorflow', 'python')]
+
+LEFT_TOKENS = [token.LPAR, token.LSQB, token.LBRACE]
+RIGHT_TOKENS = [token.RPAR, token.RSQB, token.RBRACE]
+TOKENS = LEFT_TOKENS + RIGHT_TOKENS
+
+
+def store_trace_info(output_file_path, graph=None,
+ ignore_regex_fpaths=None):
+ """Collects and stores trace information for a TensorFlow model.
+
+ The output proto is stored in json format.
+
+ Args:
+ output_file_path: The path where to store the output proto.
+ graph: Optional. The data flow graph. Defaults to `tf.get_default_graph()`.
+ ignore_regex_fpaths: Optional. Files whose path matches any of the regexes
+ in this list will be ignored. Defaults to patterns that match the core
+ tensorflow python library.
+ """
+ if graph is None:
+   graph = tf.get_default_graph()
+
+ if not ignore_regex_fpaths:
+   ignore_regex_fpaths = TF_LIB_REGEX_FPATHS
+
+ trace_info = TraceInfo()
+ # Extract trace information for every op in the graph.
+ source_fpaths = set()
+ for op in graph.get_operations():
+ op_info = trace_info.ops.add()
+ op_info.name = op.name
+ op_info.op_type = op.type
+ op_info.device = op.device
+ for trace in op.traceback:
+ fname, lineno, _, _ = trace
+ # Ignore traces in specified file paths.
+ if os.path.isabs(fname) and not _ignore_file_path(fname,
+ ignore_regex_fpaths):
+ line_trace = op_info.traceback.add()
+ line_trace.file_path = fname
+ line_trace.line_number = lineno
+ source_fpaths.add(fname)
+ _add_data_from_tensors(op.inputs, op_info.inputs)
+ _add_data_from_tensors(op.outputs, op_info.outputs)
+
+ # Read the source files involved in the graph construction.
+ for fpath in source_fpaths:
+ file_info = trace_info.files.add()
+
+ with tf.gfile.Open(fpath, 'r') as f:
+ source = f.read().decode('utf-8')
+
+ file_info.file_path = fpath
+ file_info.source_code = source
+
+ line2start = find_multiline_statements(source)
+
+ for key, value in line2start.items():
+ file_info.multiline_statements[key] = value
+
+ # Make sure the directory for the output file exists.
+ output_file_path = os.path.expanduser(output_file_path)
+ output_dir = os.path.dirname(output_file_path)
+ if not tf.gfile.Exists(output_dir):
+ tf.gfile.MakeDirs(output_dir)
+
+ # Store the debug information.
+ with tf.gfile.Open(output_file_path, 'w') as f:
+ f.write(json_format.MessageToJson(trace_info))
+
+
+def find_multiline_statements(source):
+ """Parses the python source and finds multiline statements.
+
+ Based on counting the number of open and closed parenthesis on each line.
+
+ Args:
+ source: The source code string.
+
+ Returns:
+ A dict that maps a line index A to a line index B, where A is the end of a
+ multiline statement and B is the start. Line indexing is 0-based.
+ """
+ # Get the AST.
+ tree = parser.suite(source)
+ line2paren_count = [0] * (source.count('\n') + 1)
+ _count_brackets_braces_parenthesis(tree.totuple(True), line2paren_count)
+
+ line2start = {}
+ for end in range(len(line2paren_count)):
+ if line2paren_count[end] >= 0:
+ # This is not the end of a multiline statement.
+ continue
+ cumulative_paren_count = 0
+ for start in range(end, -1, -1):
+ cumulative_paren_count += line2paren_count[start]
+ if cumulative_paren_count == 0:
+ line2start[end] = start
+ break
+ return line2start
+
+
+def _add_data_from_tensors(tensors, info):
+ for t in tensors:
+ tensor_info = info.add()
+
+ shape = t.get_shape()
+ if shape.ndims:
+ shape = [(-1 if s is None else s) for s in shape.as_list()]
+ tensor_info.shape.extend(shape)
+ tensor_info.dtype = t.dtype.name
+ tensor_info.num_bytes_per_elem = t.dtype.size
+
+ for c in t.consumers():
+ tensor_info.consumers.append(c.name)
+
+
+def _ignore_file_path(fname, ignore_regex_fpaths):
+ for regex_pattern in ignore_regex_fpaths:
+ if re.search(regex_pattern, fname):
+ return True
+ return False
+
+
+def _count_brackets_braces_parenthesis(node, line2par):
+ if isinstance(node[1], tuple):
+ for child in node[1:]:
+ _count_brackets_braces_parenthesis(child, line2par)
+ else:
+ tok = node[0]
+ if tok in TOKENS:
+ lineno = node[2]
+ line2par[lineno - 1] += (1 if tok in LEFT_TOKENS else -1)
+ return line2par
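A usage sketch for the new module, following the pattern exercised by trace_test.py below; the output path is a placeholder and the default graph is assumed to hold the ops of interest:

    import tensorflow as tf
    from tensorflow.contrib.tensorboard.plugins import trace

    tf.constant(0)  # any graph construction leaves a traceback to record

    # Writes the TraceInfo proto as JSON: per-op tracebacks, tensor shapes
    # and dtypes, plus the source files involved in graph construction.
    trace.store_trace_info('/tmp/trace/trace.json')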
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto b/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
new file mode 100644
index 0000000000..09013c6387
--- /dev/null
+++ b/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
@@ -0,0 +1,60 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+
+message TraceInfo {
+ repeated OpInfo ops = 1;
+ repeated FileInfo files = 2;
+}
+
+message OpInfo {
+ string name = 1;
+ string op_type = 2;
+ string device = 3;
+ repeated LineTrace traceback = 4;
+ repeated TensorInfo inputs = 5;
+ repeated TensorInfo outputs = 6;
+}
+
+message LineTrace {
+ // Absolute file path.
+ string file_path = 1;
+ // 1-based line number.
+ uint32 line_number = 2;
+}
+
+message TensorInfo {
+ // Size of the tensor for each dimension. Value of -1 denotes "unknown"
+ // size for that dimension.
+ repeated int32 shape = 1;
+ // The data type of the tensor.
+ string dtype = 2;
+ // Number of bytes per element in the tensor.
+ uint32 num_bytes_per_elem = 3;
+ // List of operation names that consume this tensor.
+ repeated string consumers = 4;
+}
+
+message FileInfo {
+ // Absolute file path to the source code.
+ string file_path = 1;
+ string source_code = 2;
+ // Map from end of statement to start of statement. End and start are 0-based
+ // line indexes.
+ map<uint32, uint32> multiline_statements = 3;
+}
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py b/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
new file mode 100644
index 0000000000..e67bde9d59
--- /dev/null
+++ b/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
@@ -0,0 +1,91 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.contrib.tensorboard.plugins.trace package."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import tensorflow as tf
+
+from google.protobuf import json_format
+from tensorflow.contrib.tensorboard.plugins import trace
+
+
+class TraceTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._temp_dir = tempfile.mkdtemp()
+ self._temp_trace_json = self._temp_dir + '/trace.json'
+
+ def tearDown(self):
+ tf.gfile.DeleteRecursively(self._temp_dir)
+
+ def testEmptyGraph(self):
+ trace_info = self._store_and_read_trace_info()
+ self.assertEqual(len(trace_info.ops), 0)
+
+ def testHasSourceCodeOfThisFile(self):
+ tf.constant(0)
+ trace_info = self._store_and_read_trace_info()
+
+ self.assertTrue(trace_info.files)
+ for file_info in trace_info.files:
+ if file_info.file_path.endswith('trace_test.py'):
+ return
+ self.fail('trace_test file not found in the trace info json')
+
+ def testHasTheConstantOp(self):
+ tf.constant(0)
+ trace_info = self._store_and_read_trace_info()
+
+ self.assertTrue(trace_info.ops)
+
+ for op in trace_info.ops:
+ if op.op_type == 'Const':
+ return
+ self.fail('Could not find operation of type `Const` in the graph')
+
+ def testMultilineStatements(self):
+ source = """def test():
+ a(4,
+ 3,
+ 1)
+
+ b(3, 4, 5)
+
+ c((4, 3),
+ (),
+ )
+ """
+ line2start = trace.find_multiline_statements(source)
+
+ self.assertEqual(line2start[3], 1)
+ self.assertEqual(line2start[9], 7)
+ self.assertEqual(len(line2start), 2)
+
+ def _store_and_read_trace_info(self):
+ trace.store_trace_info(self._temp_trace_json)
+ trace_info = trace.TraceInfo()
+
+ with tf.gfile.Open(self._temp_trace_json) as f:
+ text = f.read().decode('utf-8')
+ json_format.Parse(text, trace_info)
+
+ return trace_info
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
index 7966b1a2c0..3b06bcc0e0 100644
--- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
+++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
@@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import summary_pb2
+from tensorflow.python import summary
from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import summary_io
# TODO(ptucker): Replace with mock framework.
@@ -33,16 +33,16 @@ class FakeSummaryWriter(object):
def install(cls):
if cls._replaced_summary_writer:
raise ValueError('FakeSummaryWriter already installed.')
- cls._replaced_summary_writer = summary_io.SummaryWriter
- summary_io.SummaryWriter = FakeSummaryWriter
- writer_cache.SummaryWriter = FakeSummaryWriter
+ cls._replaced_summary_writer = summary.FileWriter
+ summary.FileWriter = FakeSummaryWriter
+ writer_cache.FileWriter = FakeSummaryWriter
@classmethod
def uninstall(cls):
if not cls._replaced_summary_writer:
raise ValueError('FakeSummaryWriter not installed.')
- summary_io.SummaryWriter = cls._replaced_summary_writer
- writer_cache.SummaryWriter = cls._replaced_summary_writer
+ summary.FileWriter = cls._replaced_summary_writer
+ writer_cache.FileWriter = cls._replaced_summary_writer
cls._replaced_summary_writer = None
def __init__(self, logdir, graph=None):
@@ -86,18 +86,18 @@ class FakeSummaryWriter(object):
if expected_session_logs is not None:
test_case.assertEqual(expected_session_logs, self._added_session_logs)
- def add_summary(self, summary, current_global_step):
+ def add_summary(self, summ, current_global_step):
"""Add summary."""
- if isinstance(summary, bytes):
+ if isinstance(summ, bytes):
summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary)
- summary = summary_proto
+ summary_proto.ParseFromString(summ)
+ summ = summary_proto
if current_global_step in self._summaries:
step_summaries = self._summaries[current_global_step]
else:
step_summaries = []
self._summaries[current_global_step] = step_summaries
- step_summaries.append(summary)
+ step_summaries.append(summ)
# NOTE: Ignore global_step since its value is non-deterministic.
def add_graph(self, graph, global_step=None, graph_def=None):
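This tracks the repository-wide rename: the writer formerly reachable as tf.train.SummaryWriter (summary_io.SummaryWriter internally) is now tf.summary.FileWriter, constructor unchanged. A one-line sketch with a placeholder logdir:

    import tensorflow as tf

    writer = tf.summary.FileWriter('/tmp/eval_logs')  # was tf.train.SummaryWriter
    writer.close()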
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index ee4208d312..9070cd6e8d 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -144,7 +144,6 @@ from tensorflow.contrib.framework.python.ops import variables
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import summary
from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
@@ -254,7 +253,7 @@ def get_or_create_eval_step():
class StopAfterNEvalsHook(session_run_hook.SessionRunHook):
- """A run hook used by the evaluation routines to run the `eval_ops` N times."""
+ """Run hook used by the evaluation routines to run the `eval_ops` N times."""
def __init__(self, num_evals):
"""Constructs the run hook.
@@ -274,6 +273,7 @@ class StopAfterNEvalsHook(session_run_hook.SessionRunHook):
def after_run(self, run_context, run_values):
evals_completed = run_values.results['evals_completed']
+ logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
if evals_completed >= self._num_evals:
run_context.request_stop()
@@ -299,7 +299,7 @@ class _FinalOpsHook(session_run_hook.SessionRunHook):
return self._final_ops_values
def end(self, session):
- if self._final_ops:
+ if self._final_ops is not None:
self._final_ops_values = session.run(self._final_ops,
feed_dict=self._final_ops_feed_dict)
@@ -379,14 +379,14 @@ def evaluate_once(
the requested number of times.
Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
- `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is evaluated
- a single time after `eval_ops` has finished running and the fetched values of
- `final_ops` are returned. If `final_ops` is left as `None`, then `None` is
- returned.
+ `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
+ evaluated a single time after `eval_ops` has finished running and the fetched
+ values of `final_ops` are returned. If `final_ops` is left as `None`, then
+ `None` is returned.
One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
- summaries after the `eval_ops` have run. If `eval_ops` is `None`, the summaries
- run immedietly after the model checkpoint has been restored.
+ summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
+ summaries run immedietly after the model checkpoint has been restored.
Note that `evaluate_once` creates a local variable used to track the number of
evaluations run via `tf.contrib.training.get_or_create_eval_step`.
@@ -403,8 +403,8 @@ def evaluate_once(
eval_ops: A operation which is run until the session is requested to stop,
commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
feed_dict: The feed dictionary to use when executing the `eval_ops`.
- final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
- `Tensors`.
+ final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
+ to `Tensors`.
final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
variables_to_restore: A list of TensorFlow variables to restore during
evaluation. If the argument is left as `None` then
@@ -420,9 +420,14 @@ def evaluate_once(
eval_step = get_or_create_eval_step()
if eval_ops is not None:
- eval_ops = control_flow_ops.with_dependencies(
- [eval_ops],
- state_ops.assign_add(eval_step, 1))
+ update_eval_step = state_ops.assign_add(eval_step, 1)
+
+ if isinstance(eval_ops, dict):
+ eval_ops['update_eval_step'] = update_eval_step
+ elif isinstance(eval_ops, (tuple, list)):
+ eval_ops = list(eval_ops) + [update_eval_step]
+ else:
+ eval_ops = [eval_ops, update_eval_step]
# Must come before the scaffold check.
if scaffold and scaffold.saver:
@@ -484,14 +489,14 @@ def evaluate_repeatedly(
the requested number of times.
Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
- `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is evaluated
- a single time after `eval_ops` has finished running and the fetched values of
- `final_ops` are returned. If `final_ops` is left as `None`, then `None` is
- returned.
+ `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
+ evaluated a single time after `eval_ops` has finished running and the fetched
+ values of `final_ops` are returned. If `final_ops` is left as `None`, then
+ `None` is returned.
One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
- summaries after the `eval_ops` have run. If `eval_ops` is `None`, the summaries
- run immedietly after the model checkpoint has been restored.
+ summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
+ summaries run immediately after the model checkpoint has been restored.
Note that `evaluate_once` creates a local variable used to track the number of
evaluations run via `tf.contrib.training.get_or_create_eval_step`.
@@ -508,8 +513,8 @@ def evaluate_repeatedly(
eval_ops: An operation which is run until the session is requested to stop,
commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
feed_dict: The feed dictionary to use when executing the `eval_ops`.
- final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
- `Tensors`.
+ final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
+ to `Tensors`.
final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
variables_to_restore: A list of TensorFlow variables to restore during
evaluation. If the argument is left as `None` then
@@ -530,9 +535,14 @@ def evaluate_repeatedly(
eval_step = get_or_create_eval_step()
if eval_ops is not None:
- eval_ops = control_flow_ops.with_dependencies(
- [eval_ops],
- state_ops.assign_add(eval_step, 1))
+ update_eval_step = state_ops.assign_add(eval_step, 1)
+
+ if isinstance(eval_ops, dict):
+ eval_ops['update_eval_step'] = update_eval_step
+ elif isinstance(eval_ops, (tuple, list)):
+ eval_ops = list(eval_ops) + [update_eval_step]
+ else:
+ eval_ops = [eval_ops, update_eval_step]
# Must come before the scaffold check.
if scaffold and scaffold.saver:
@@ -572,7 +582,9 @@ def evaluate_repeatedly(
'Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
num_evaluations += 1
- if num_evaluations >= max_number_of_evaluations:
+
+ reached_max = num_evaluations >= max_number_of_evaluations
+ if max_number_of_evaluations and reached_max:
return final_ops_hook.final_ops_values
logging.info('Timed-out waiting for a checkpoint.')
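The hunks above (in both `evaluate_once` and `evaluate_repeatedly`) replace the `with_dependencies` wrapper: rather than gating the counter increment on `eval_ops`, the `assign_add` op is simply fetched alongside them, preserving whatever structure the caller passed. A minimal sketch of that normalization in plain Python (`update_op` stands in for the `assign_add` tensor; names are illustrative):

def merge_update_op(eval_ops, update_op):
    """Attach a counter update to eval_ops, preserving their structure."""
    if isinstance(eval_ops, dict):
        # Dicts gain an extra named fetch (mutated in place, as above).
        eval_ops['update_eval_step'] = update_op
        return eval_ops
    if isinstance(eval_ops, (tuple, list)):
        # Sequences gain an extra element.
        return list(eval_ops) + [update_op]
    # A lone op is wrapped into a two-element list.
    return [eval_ops, update_op]

# Every session.run of the merged fetches now also bumps the eval step.
assert merge_update_op('loss', 'incr') == ['loss', 'incr']
assert merge_update_op({'a': 1}, 2) == {'a': 1, 'update_eval_step': 2}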
diff --git a/tensorflow/contrib/training/python/training/resample.py b/tensorflow/contrib/training/python/training/resample.py
index 513b7f59a0..ffaf6b8a03 100644
--- a/tensorflow/contrib/training/python/training/resample.py
+++ b/tensorflow/contrib/training/python/training/resample.py
@@ -128,7 +128,7 @@ def resample_at_rate(inputs, rates, scope=None, seed=None, back_prop=False):
def weighted_resample(inputs, weights, overall_rate, scope=None,
- mean_decay=0.999, warmup=10, seed=None):
+ mean_decay=0.999, seed=None):
"""Performs an approximate weighted resampling of `inputs`.
This method chooses elements from `inputs` where each item's rate of
@@ -142,9 +142,6 @@ def weighted_resample(inputs, weights, overall_rate, scope=None,
overall_rate: Desired overall rate of resampling.
scope: Scope to use for the op.
mean_decay: How quickly to decay the running estimate of the mean weight.
- warmup: Until the resulting tensor has been evaluated `warmup`
- times, the resampling menthod uses the true mean over all calls
- as its weight estimate, rather than a decayed mean.
seed: Random seed.
Returns:
@@ -158,26 +155,16 @@ def weighted_resample(inputs, weights, overall_rate, scope=None,
# overall rate, and a weight twice the average has twice the rate,
# etc.
with ops.name_scope(scope, 'weighted_resample', inputs) as opscope:
- # First: Maintain a running estimated mean weight, with decay
- # adjusted (by also maintaining an invocation count) during the
- # warmup period so that at the beginning, there aren't too many
- # zeros mixed in, throwing the average off.
+ # First: Maintain a running estimated mean weight, with zero debiasing
+ # enabled (by default) to avoid throwing the average off.
with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
- count_so_far = variable_scope.get_local_variable(
- 'resample_count', initializer=0)
-
estimated_mean = variable_scope.get_local_variable(
'estimated_mean', initializer=0.0)
- count = count_so_far.assign_add(1)
- real_decay = math_ops.minimum(
- math_ops.truediv((count - 1), math_ops.minimum(count, warmup)),
- mean_decay)
-
batch_mean = math_ops.reduce_mean(weights)
mean = moving_averages.assign_moving_average(
- estimated_mean, batch_mean, real_decay, zero_debias=False)
+ estimated_mean, batch_mean, mean_decay)
# Then, normalize the weights into rates using the mean weight and
# overall target rate:
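The removed warmup bookkeeping is subsumed by `assign_moving_average`'s zero debiasing, which the new call leaves at its default. A rough sketch of the debiasing arithmetic in plain Python, assuming the standard `1 - decay**t` correction; this illustrates the idea, not the TF implementation:

def debiased_ema(values, decay=0.999):
    """Yield a zero-debiased exponential moving average of values.

    The raw EMA starts at 0 and is biased toward 0 early on; dividing by
    (1 - decay**t) removes that bias, which is what the warmup counter
    used to approximate by blending in the true mean over all calls.
    """
    ema = 0.0
    for t, x in enumerate(values, start=1):
        ema = decay * ema + (1.0 - decay) * x
        yield ema / (1.0 - decay ** t)

# With constant input the debiased estimate is correct from step one.
for v in debiased_ema([5.0, 5.0, 5.0]):
    assert abs(v - 5.0) < 1e-9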
diff --git a/tensorflow/contrib/training/python/training/resample_test.py b/tensorflow/contrib/training/python/training/resample_test.py
index 9324a940d3..1f8d332125 100644
--- a/tensorflow/contrib/training/python/training/resample_test.py
+++ b/tensorflow/contrib/training/python/training/resample_test.py
@@ -40,7 +40,8 @@ class ResampleTest(tf.test.TestCase):
resampled_back_out = tf.contrib.training.resample_at_rate(
resampled_in, 1.0/rates, seed=456)
- init = tf.local_variables_initializer()
+ init = tf.group(tf.local_variables_initializer(),
+ tf.global_variables_initializer())
with self.test_session() as s:
s.run(init) # initialize
@@ -81,7 +82,8 @@ class ResampleTest(tf.test.TestCase):
invrates = 1.0/rates
- init = tf.local_variables_initializer()
+ init = tf.group(tf.local_variables_initializer(),
+ tf.global_variables_initializer())
expected_sum_op = tf.reduce_sum(vals)
with self.test_session() as s:
s.run(init)
diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py
index 2efc50cb4e..d5e6878e75 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops.py
@@ -387,7 +387,7 @@ def _calculate_acceptance_probabilities(init_probs, target_probs):
ratio_l = target_probs / init_probs
# Replace NaNs with 0s.
- ratio_l = math_ops.select(math_ops.is_nan(ratio_l),
+ ratio_l = array_ops.where(math_ops.is_nan(ratio_l),
array_ops.zeros_like(ratio_l),
ratio_l)
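`math_ops.select` is deprecated in favor of `array_ops.where`, which keeps the same three-argument elementwise selection. The NaN-scrubbing pattern as user-facing TF 1.x code (a sketch, assuming the 1.x `tf.where`/`tf.is_nan` API):

import tensorflow as tf

ratio = tf.constant([0.5, float('nan'), 2.0])
# 0/0 in target_probs / init_probs yields NaN; replace those with zeros.
clean = tf.where(tf.is_nan(ratio), tf.zeros_like(ratio), ratio)

with tf.Session() as sess:
    print(sess.run(clean))  # [ 0.5  0.   2. ]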
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py
index f0501b3f3e..788e01efd7 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py
@@ -133,7 +133,8 @@ class StratifiedSampleTest(tf.test.TestCase):
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
batches += tf.contrib.training.stratified_sample(
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
- summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES))
+ summary_op = tf.contrib.deprecated.merge_summary(
+ tf.get_collection(tf.GraphKeys.SUMMARIES))
with self.test_session() as sess:
coord = tf.train.Coordinator()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3b81fa859a..3f9b94128a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1292,7 +1292,10 @@ cc_library(
hdrs = [
"platform/regexp.h",
],
- visibility = ["//tensorflow/tools/tfprof:__subpackages__"],
+ visibility = [
+ "//tensorflow/compiler:__subpackages__",
+ "//tensorflow/tools/tfprof:__subpackages__",
+ ],
deps = [":lib_internal"],
)
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 8a49c7d3ab..e1f2c55230 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -407,7 +407,9 @@ Status DirectSession::Run(const RunOptions& run_options,
&executors_and_keys, &run_state_args));
// Create a run state and start execution.
- RunState run_state(input_tensor_names, output_names);
+ Executor::Args args;
+ args.step_id = step_id_counter_.fetch_add(1);
+ RunState run_state(input_tensor_names, output_names, args.step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
CancellationManager step_cancellation_manager;
@@ -425,8 +427,6 @@ Status DirectSession::Run(const RunOptions& run_options,
run_state.executors_done.Notify();
});
- Executor::Args args;
- args.step_id = step_id_counter_.fetch_add(1);
args.rendezvous = run_state.rendez;
args.cancellation_manager = &step_cancellation_manager;
args.runner = [this, pool](Executor::Args::Closure c) {
@@ -434,7 +434,7 @@ Status DirectSession::Run(const RunOptions& run_options,
};
args.session_state = &session_state_;
args.tensor_store = &run_state.tensor_store;
- args.step_resource_manager = &run_state.step_resource_manager;
+ args.step_container = &run_state.step_container;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
@@ -582,7 +582,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
&run_state_args));
// Create the run state and save it for future PRun calls.
- RunState* run_state = new RunState(input_names, output_names);
+ Executor::Args args;
+ args.step_id = step_id_counter_.fetch_add(1);
+ RunState* run_state =
+ new RunState(input_names, output_names, args.step_id, &devices_);
run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
{
mutex_lock l(executor_lock_);
@@ -606,8 +609,6 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
run_state->executors_done.Notify();
});
- Executor::Args args;
- args.step_id = step_id_counter_.fetch_add(1);
args.rendezvous = run_state->rendez;
args.cancellation_manager = cancellation_manager_;
args.runner = [this, pool](Executor::Args::Closure c) {
@@ -615,7 +616,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
};
args.session_state = &session_state_;
args.tensor_store = &run_state->tensor_store;
- args.step_resource_manager = &run_state->step_resource_manager;
+ args.step_container = &run_state->step_container;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
@@ -1173,7 +1174,16 @@ Status DirectSession::CreateGraphs(
}
DirectSession::RunState::RunState(const std::vector<string>& input_names,
- const std::vector<string>& output_names) {
+ const std::vector<string>& output_names,
+ int64 step_id,
+ const std::vector<Device*>* devices)
+ : step_container(step_id, [devices](const string& name) {
+ for (auto d : *devices) {
+ if (!d->resource_manager()->Cleanup(name).ok()) {
+ // Do nothing...
+ }
+ }
+ }) {
// Initially all the feeds and fetches are pending.
for (auto& name : input_names) {
pending_inputs.emplace(name);
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 37f0277a40..127c08d0a4 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -151,10 +151,11 @@ class DirectSession : public Session {
std::unordered_set<string> pending_inputs;
std::unordered_set<string> pending_outputs;
TensorStore tensor_store;
- ResourceMgr step_resource_manager;
+ ScopedStepContainer step_container;
RunState(const std::vector<string>& input_names,
- const std::vector<string>& output_names);
+ const std::vector<string>& output_names, int64 step_id,
+ const std::vector<Device*>* devices);
~RunState();
};
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index ef531dc6c5..542ed70b4c 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -873,8 +873,8 @@ class ExecutorState {
Rendezvous* rendezvous_;
SessionState* session_state_;
TensorStore* tensor_store_;
- // Step-local resource manager.
- ResourceMgr* step_resource_manager_;
+ // Step-local container.
+ ScopedStepContainer* step_container_;
StepStatsCollector* stats_collector_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
@@ -992,7 +992,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
rendezvous_(args.rendezvous),
session_state_(args.session_state),
tensor_store_(args.tensor_store),
- step_resource_manager_(args.step_resource_manager),
+ step_container_(args.step_container),
stats_collector_(args.stats_collector),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
@@ -1220,7 +1220,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
params.call_frame = call_frame_;
params.function_library = impl_->params_.function_library;
params.resource_manager = device->resource_manager();
- params.step_resource_manager = step_resource_manager_;
+ params.step_container = step_container_;
params.slice_reader_cache = slice_reader_cache_;
params.inputs = &inputs;
params.input_device_contexts = &input_device_contexts;
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 2e9990d951..8cca22fb6f 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -88,7 +88,7 @@ class Executor {
CancellationManager* cancellation_manager = nullptr;
SessionState* session_state = nullptr;
TensorStore* tensor_store = nullptr;
- ResourceMgr* step_resource_manager = nullptr;
+ ScopedStepContainer* step_container = nullptr;
// If true, calls Sync() on the device.
bool sync_on_finish = false;
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 2d29f5176d..695c7244ae 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -261,7 +261,7 @@ class CallOp : public AsyncOpKernel {
done);
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
- opts.step_resource_manager = ctx->step_resource_manager();
+ opts.step_container = ctx->step_container();
opts.runner = ctx->runner();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
@@ -558,7 +558,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
Executor::Args exec_args;
// Inherit the step_id from the caller.
exec_args.step_id = opts.step_id;
- exec_args.step_resource_manager = opts.step_resource_manager;
+ exec_args.step_container = opts.step_container;
exec_args.call_frame = frame;
exec_args.cancellation_manager = opts.cancellation_manager;
exec_args.runner = *opts.runner;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 731edd6ac3..4e89226752 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -359,8 +359,8 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return;
}
- StartParallelExecutors(handle, item, rendezvous, collector, cost_graph,
- cancellation_manager,
+ StartParallelExecutors(handle, step_id, item, rendezvous, collector,
+ cost_graph, cancellation_manager,
[this, item, rendezvous, done](const Status& s) {
done(s);
rendezvous->Unref();
@@ -368,22 +368,25 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
});
}
-void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
- Rendezvous* rendezvous,
+void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
+ Item* item, Rendezvous* rendezvous,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
StatusCallback done) {
const int num_units = item->units.size();
CHECK_GE(num_units, 1);
- ResourceMgr* step_resource_manager = new ResourceMgr;
+ ScopedStepContainer* step_container =
+ new ScopedStepContainer(step_id, [this](const string& name) {
+ worker_env_->device_mgr->ClearContainers({name});
+ });
// NOTE: Transfer one ref of rendezvous and item.
ExecutorBarrier* barrier = new ExecutorBarrier(
- num_units, rendezvous, [this, item, collector, cost_graph,
- step_resource_manager, done](const Status& s) {
+ num_units, rendezvous, [this, item, collector, cost_graph, step_container,
+ done](const Status& s) {
BuildCostModel(item, collector, cost_graph);
done(s);
- delete step_resource_manager;
+ delete step_container;
});
Executor::Args args;
{
@@ -393,7 +396,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, Item* item,
args.rendezvous = rendezvous;
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
- args.step_resource_manager = step_resource_manager;
+ args.step_container = step_container;
args.sync_on_finish = true;
if (LogMemory::IsEnabled()) {
LogMemory::RecordStep(args.step_id, handle);
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index a3771e6747..e9b8d415ed 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -140,7 +140,7 @@ class GraphMgr {
// mechanism to gc these graphs.
std::unordered_map<string, Item*> table_;
- void StartParallelExecutors(const string& handle, Item* item,
+ void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
Rendezvous* rendezvous,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 4f8eb04c95..06859c5290 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -66,8 +66,13 @@ struct AllocatorStats {
// device memory.
class Allocator {
public:
+#ifdef EIGEN_VECTORIZE_AVX512
+ // Align to 64 byte boundary.
+ static constexpr size_t kAllocatorAlignment = 64;
+#else
// Align to 32 byte boundary.
static constexpr size_t kAllocatorAlignment = 32;
+#endif
virtual ~Allocator();
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index dcf0ae40d5..bc1441ac6e 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -739,6 +739,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
// Merge all the non-concat dims, and sum the concat dim to make an output
// shape.
const int32 concat_dim = concat_dim_t->scalar<int32>()();
+
// Minimum required number of dimensions.
const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
@@ -749,7 +750,11 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
DimensionHandle output_middle = c->Dim(input, concat_dim);
- TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
+ if (concat_dim == -1) {
+ output_after = c->Scalar(); // no dimensions.
+ } else {
+ TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
+ }
for (int i = end_value_index - 2; i >= start_value_index; --i) {
ShapeHandle before;
@@ -758,7 +763,11 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
DimensionHandle middle = c->Dim(input, concat_dim);
- TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
+ if (concat_dim == -1) {
+ after = c->Scalar();
+ } else {
+ TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
+ }
TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
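Both guards cover the same edge case: `concat_dim == -1` addresses the last dimension, so the "after" subshape must be empty, whereas `Subshape(input, concat_dim + 1, ...)` would start at index 0 and return the whole shape. The same wraparound pitfall, shown with plain Python slicing (`dims` stands in for a shape):

def split_around_axis(dims, axis):
    """Split dims into (before, middle, after) around a concat axis.

    axis may be negative. For axis == -1, a naive dims[axis + 1:] would
    wrap to dims[0:] and return every dimension instead of none.
    """
    before = dims[:axis]
    middle = dims[axis]
    after = [] if axis == -1 else dims[axis + 1:]
    return before, middle, after

assert split_around_axis([2, 3, 4], 1) == ([2], 3, [4])
assert split_around_axis([2, 3, 4], -1) == ([2, 3], 4, [])  # guarded case
assert [2, 3, 4][-1 + 1:] == [2, 3, 4]                      # the pitfall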
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 5cb4e28faf..1fa3aee517 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -35,6 +35,7 @@ namespace tensorflow {
class CancellationManager;
class OpKernel;
class ResourceMgr;
+class ScopedStepContainer;
// FunctionDefHelper::Create is a convenient helper to construct a
// FunctionDef proto.
@@ -381,8 +382,8 @@ class FunctionLibraryRuntime {
// The id of the step that is calling this function.
int64 step_id = 0;
- // Per-step resource manager. Does not take ownership.
- ResourceMgr* step_resource_manager = nullptr;
+ // Per-step container.
+ ScopedStepContainer* step_container;
std::function<void(std::function<void()>)>* runner = nullptr;
};
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 50520bb3fd..c4023d2ced 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -222,7 +222,7 @@ OpKernelContext::~OpKernelContext() {
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
Allocator* allocator =
- params_->device->GetStepAllocator(attr, step_resource_manager());
+ params_->device->GetStepAllocator(attr, resource_manager());
if (params_->track_allocations) {
mutex_lock lock(mu_);
for (const auto& wrapped : wrapped_allocators_) {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 4a66d43e50..7318a2dc7d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -511,8 +511,9 @@ class OpKernelContext {
// Shared resources accessible by this op kernel invocation.
ResourceMgr* resource_manager = nullptr;
- // Per-step resources accessible by this op kernel invocation.
- ResourceMgr* step_resource_manager = nullptr;
+ // Per-step resources accessible by this op kernel invocation should be
+ // stored in this container.
+ ScopedStepContainer* step_container = nullptr;
// Mechanism used by this op kernel invocation to communicate with
// computations running on other devices.
@@ -938,9 +939,9 @@ class OpKernelContext {
// not be called from Op kernels.
void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
- // Per-step resource manager for use by white-listed internal ops.
- ResourceMgr* step_resource_manager() const {
- return params_->step_resource_manager;
+ // Per-step container for use by white-listed internal ops.
+ ScopedStepContainer* step_container() const {
+ return params_->step_container;
}
// Helper routines for the OP_REQUIRES macros
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index ae4186ee71..a1053669b7 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -79,6 +79,24 @@ class ResourceBase : public core::RefCounted {
virtual string DebugString() = 0;
};
+// Container used for per-step resources.
+class ScopedStepContainer {
+ public:
+ // step_id: the unique ID of this step. Doesn't have to be sequential, just
+ // has to be unique.
+ // cleanup: callback to delete a container of this name.
+ ScopedStepContainer(const int64 step_id,
+ std::function<void(const string&)> cleanup)
+ : name_(strings::StrCat("__per_step_", step_id)), cleanup_(cleanup) {}
+ ~ScopedStepContainer() { cleanup_(name_); }
+
+ const string& name() const { return name_; }
+
+ private:
+ const string name_;
+ const std::function<void(const string&)> cleanup_;
+};
+
class ResourceMgr {
public:
ResourceMgr();
@@ -165,6 +183,9 @@ class ResourceMgr {
template <typename T>
ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container,
const string& name);
+template <typename T>
+ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
+ const string& name);
// Returns a resource handle from a numbered op input.
ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
@@ -385,6 +406,12 @@ ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container,
return result;
}
+template <typename T>
+ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
+ const string& name) {
+ return MakeResourceHandle<T>(ctx, ctx->step_container()->name(), name);
+}
+
namespace internal {
template <typename T>
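`ScopedStepContainer` ties per-step cleanup to object lifetime: the constructor derives the container name `__per_step_<id>` from the step id, and the destructor invokes the cleanup callback with that name (RAII, as used by `DirectSession::RunState` and `GraphMgr` above). A rough Python analogue of the same scoped-cleanup pattern, purely illustrative:

class StepContainer(object):
    """Names a per-step resource container; cleans it up on scope exit."""

    def __init__(self, step_id, cleanup):
        self.name = '__per_step_%d' % step_id
        self._cleanup = cleanup

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self._cleanup(self.name)  # runs even if the step raised
        return False

resources = {'__per_step_7': ['temp tensor']}
with StepContainer(7, resources.pop) as container:
    assert container.name == '__per_step_7'
assert '__per_step_7' not in resources  # cleanup fired at scope exit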
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 8e9eceb699..a35f3ff15c 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -74,6 +74,28 @@ struct RecvInfo {
typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
DupRecvTable;
+struct DupControlKey {
+ int dst_node_id; // Edge's dst node id
+ GraphDef* src_graph; // Edge's src node is in this subgraph
+};
+
+struct DupControlKeyHash {
+ size_t operator()(const DupControlKey& k) const {
+ return Hash64(reinterpret_cast<const char*>(&k.src_graph),
+ sizeof(k.src_graph), k.dst_node_id);
+ }
+};
+
+struct DupControlKeyEq {
+ bool operator()(const DupControlKey& x, const DupControlKey& y) const {
+ return (x.dst_node_id == y.dst_node_id) && (x.src_graph == y.src_graph);
+ }
+};
+
+typedef std::unordered_map<DupControlKey, NodeDef*, DupControlKeyHash,
+ DupControlKeyEq>
+ DupControlTable;
+
struct PairIntHash {
public:
std::size_t operator()(const std::pair<int, int>& x) const {
@@ -825,6 +847,7 @@ Status Partition(const PartitionOptions& opts, Graph* g,
string dstp;
std::vector<const Edge*> inputs;
DupRecvTable dup_recv(3);
+ DupControlTable dup_control(3);
// For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
// edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
// edge to dst. We will add a control edge for every pair in
@@ -918,7 +941,9 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
// Check whether there is already a send/recv pair transferring
- // the same tensor/control from the src to dst partition.
+ // the same tensor/control from src to the dst partition. This
+ // handles the dedup case where a single source in one partition
+ // goes to multiple destinations in another partition.
const bool on_host = IsDstInputOnHost(edge, g_info);
DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
auto iter = dup_recv.find(key);
@@ -943,6 +968,16 @@ Status Partition(const PartitionOptions& opts, Graph* g,
NodeDefBuilder::NodeOut send_from;
if (edge->IsControlEdge()) {
+ // This handles the dedup case where multiple control edges go from
+ // one partition to a single destination in another partition.
+ DupControlKey key{dst->id(), src_graph};
+ auto iter = dup_control.find(key);
+ if (iter != dup_control.end()) {
+ // This could cause start_time(src) > start_time(iter->second).
+ AddInput(iter->second, src->name(), Graph::kControlSlot);
+ continue;
+ }
+
// Insert a dummy const node that will generate a tiny
// data element to be sent from send to recv.
VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
@@ -956,6 +991,7 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
AddInput(dummy, src->name(), Graph::kControlSlot);
send_from.Reset(dummy->name(), 0, DT_FLOAT);
+ dup_control[key] = dummy;
} else {
send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
}
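`dup_control` deduplicates cross-partition control edges: every control edge from one source subgraph into the same destination node shares a single dummy const, and hence a single Send/Recv pair; later sources are just added as extra control inputs on the existing dummy. The lookup-or-create pattern in plain Python, with keys mirroring `DupControlKey` (destination node id plus source-subgraph identity):

dup_control = {}

def shared_dummy(dst_node_id, src_graph, src_name, make_dummy):
    """Return the one dummy node for (dst, src_graph), creating it once."""
    key = (dst_node_id, id(src_graph))  # id() plays the src_graph pointer
    node = dup_control.get(key)
    if node is None:
        node = make_dummy()             # first edge: new dummy + Send/Recv
        dup_control[key] = node
    node['control_inputs'].append(src_name)
    return node

graph_a = object()
make = lambda: {'control_inputs': []}
d1 = shared_dummy(5, graph_a, 'A1', make)
d2 = shared_dummy(5, graph_a, 'A2', make)  # reuses d1; no second Send/Recv
assert d1 is d2 and d1['control_inputs'] == ['A1', 'A2']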
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index d8322e6077..fd259f0b40 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -398,5 +398,37 @@ TEST_F(GraphPartitionTest, PartitionIncompleteGraph) {
EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status;
}
+TEST_F(GraphPartitionTest, CrossDevice_MultiControl) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ auto a1 = Input(in_.WithOpName("A1"));
+ auto a2 = Input(in_.WithOpName("A2"));
+ auto b1 = Input(in_.WithOpName("B1"));
+ Combine(
+ in_.WithOpName("B2").WithControlDependencies(a1).WithControlDependencies(
+ a2),
+ b1, b1);
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(scope_a_.WithOpName("A1"));
+ a2 = Input(scope_a_.WithOpName("A2"));
+ auto c = Const(scope_a_.WithOpName("A1/_0")
+ .WithControlDependencies(a1)
+ .WithControlDependencies(a2),
+ {});
+ _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
+ ExpectMatchA();
+
+ auto recv =
+ _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
+ auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
+ b1 = Input(scope_b_.WithOpName("B1"));
+ Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
+ ExpectMatchB();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cd76a40a47..3e8da9884e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -30,7 +30,6 @@ load(
"tf_cc_tests",
"tf_copts",
"tf_opts_nortti_if_android",
- "tf_kernel_libraries",
"tf_kernel_library",
"cc_header_only_library",
)
@@ -396,7 +395,6 @@ ARRAY_DEPS = [
":fill_functor",
":gather_functor",
":ops_util",
- ":strided_slice_op",
":transpose_functor",
"//tensorflow/core:array_grad",
"//tensorflow/core:array_ops_op_lib",
@@ -432,49 +430,185 @@ tf_kernel_library(
],
)
-tf_kernel_libraries(
+cc_library(
name = "array",
- libs = [
+ deps = [
":batch_space_ops",
+ ":bcast_ops",
+ ":bitcast_op",
+ ":concat_op",
+ ":constant_op",
":depth_space_ops",
+ ":diag_op",
+ ":edit_distance_op",
":extract_image_patches_op",
+ ":gather_nd_op",
+ ":gather_op",
+ ":identity_op",
+ ":listdiff_op",
+ ":matrix_band_part_op",
+ ":matrix_diag_op",
+ ":matrix_set_diag_op",
+ ":mirror_pad_op",
+ ":one_hot_op",
+ ":pack_op",
+ ":pad_op",
+ ":quantize_and_dequantize_op",
+ ":reshape_op",
+ ":reverse_op",
+ ":reverse_sequence_op",
+ ":shape_ops",
+ ":slice_op",
":split_op",
":split_v_op",
+ ":strided_slice_op",
+ ":tile_ops",
+ ":transpose_op",
+ ":unique_op",
":unpack_op",
+ ":where_op",
],
- prefixes = [
- "bcast_ops",
- "bitcast_op",
- "concat_op",
- "constant_op",
- "diag_op",
- "matrix_band_part_op",
- "matrix_diag_op",
- "matrix_set_diag_op",
- "edit_distance_op",
- "gather_op",
- "gather_nd_op",
- "identity_op",
- "listdiff_op",
- "mirror_pad_op",
- "one_hot_op",
- "pack_op",
- "pad_op",
- "quantize_and_dequantize_op",
- "reshape_op",
- "reverse_op",
- "reverse_sequence_op",
- "shape_ops",
- "slice_op",
- "tile_ops",
- "transpose_op",
- "unique_op",
- "where_op",
- ],
+)
+
+tf_kernel_library(
+ name = "bcast_ops",
+ prefix = "bcast_ops",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "bitcast_op",
+ prefix = "bitcast_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "concat_op",
+ prefix = "concat_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "constant_op",
+ prefix = "constant_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "diag_op",
+ prefix = "diag_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "edit_distance_op",
+ prefix = "edit_distance_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "gather_nd_op",
+ prefix = "gather_nd_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "gather_op",
+ prefix = "gather_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "identity_op",
+ prefix = "identity_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "listdiff_op",
+ prefix = "listdiff_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "matrix_band_part_op",
+ prefix = "matrix_band_part_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "matrix_diag_op",
+ prefix = "matrix_diag_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "matrix_set_diag_op",
+ prefix = "matrix_set_diag_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "mirror_pad_op",
+ prefix = "mirror_pad_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "one_hot_op",
+ prefix = "one_hot_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "pack_op",
+ prefix = "pack_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "pad_op",
+ prefix = "pad_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "quantize_and_dequantize_op",
+ prefix = "quantize_and_dequantize_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "reshape_op",
+ prefix = "reshape_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "reverse_op",
+ prefix = "reverse_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "reverse_sequence_op",
+ prefix = "reverse_sequence_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "shape_ops",
+ prefix = "shape_ops",
deps = ARRAY_DEPS,
)
tf_kernel_library(
+ name = "slice_op",
+ prefix = "slice_op",
+ deps = ARRAY_DEPS + [":strided_slice_op"],
+)
+
+tf_kernel_library(
name = "split_op",
gpu_srcs = ["cuda_device_array.h"],
prefix = "split_op",
@@ -489,11 +623,35 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "tile_ops",
+ prefix = "tile_ops",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "transpose_op",
+ prefix = "transpose_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
+ name = "unique_op",
+ prefix = "unique_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "unpack_op",
prefix = "unpack_op",
deps = ARRAY_DEPS + [":split_lib"],
)
+tf_kernel_library(
+ name = "where_op",
+ prefix = "where_op",
+ deps = ARRAY_DEPS,
+)
+
tf_cc_test(
name = "batch_norm_op_test",
size = "small",
@@ -918,84 +1076,167 @@ tf_cc_test(
],
)
-tf_kernel_libraries(
+cc_library(
name = "data_flow",
- libs = [
- ":dynamic",
- ":lookup",
- ],
- prefixes = [
- "conditional_accumulator_base_op",
- "conditional_accumulator_op",
- "barrier_ops",
- "fifo_queue_op",
- "priority_queue_op",
- "padding_fifo_queue_op",
- "queue_ops",
- "random_shuffle_queue_op",
- "session_ops",
- "sparse_conditional_accumulator_op",
- "stack_ops",
- "tensor_array_ops",
- ],
deps = [
- ":bounds_check",
- ":concat_lib",
- ":conditional_accumulator",
- ":conditional_accumulator_base",
- ":fifo_queue",
- ":initializable_lookup_table",
- ":lookup_util",
- ":padding_fifo_queue",
- ":priority_queue",
- ":queue_base",
- ":queue_op",
- ":sparse_conditional_accumulator",
- ":split_lib",
- ":tensor_array",
- ":typed_conditional_accumulator_base",
- ":typed_queue",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:data_flow_ops_op_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//third_party/eigen3",
- ],
+ ":barrier_ops",
+ ":conditional_accumulator_base_op",
+ ":conditional_accumulator_op",
+ ":dynamic_partition_op",
+ ":dynamic_stitch_op",
+ ":fifo_queue_op",
+ ":lookup_table_init_op",
+ ":lookup_table_op",
+ ":padding_fifo_queue_op",
+ ":priority_queue_op",
+ ":queue_ops",
+ ":random_shuffle_queue_op",
+ ":session_ops",
+ ":sparse_conditional_accumulator_op",
+ ":stack_ops",
+ ":tensor_array_ops",
+ ],
+)
+
+DATA_FLOW_DEPS = [
+ ":bounds_check",
+ ":concat_lib",
+ ":conditional_accumulator",
+ ":conditional_accumulator_base",
+ ":fifo_queue",
+ ":initializable_lookup_table",
+ ":lookup_util",
+ ":padding_fifo_queue",
+ ":priority_queue",
+ ":queue_base",
+ ":queue_op",
+ ":sparse_conditional_accumulator",
+ ":split_lib",
+ ":tensor_array",
+ ":typed_conditional_accumulator_base",
+ ":typed_queue",
+ "//third_party/eigen3",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:data_flow_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+]
+
+tf_kernel_library(
+ name = "conditional_accumulator_base_op",
+ prefix = "conditional_accumulator_base_op",
+ deps = DATA_FLOW_DEPS,
)
-tf_kernel_libraries(
- name = "dynamic",
- prefixes = [
- "dynamic_partition_op",
- "dynamic_stitch_op",
- ],
- deps = [
- ":bounds_check",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:data_flow_ops_op_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
+tf_kernel_library(
+ name = "conditional_accumulator_op",
+ prefix = "conditional_accumulator_op",
+ deps = DATA_FLOW_DEPS,
)
-tf_kernel_libraries(
- name = "lookup",
- prefixes = [
- "lookup_table_init_op",
- "lookup_table_op",
- ],
- deps = [
- ":bounds_check",
- ":initializable_lookup_table",
- ":lookup_util",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:data_flow_ops_op_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
+tf_kernel_library(
+ name = "barrier_ops",
+ prefix = "barrier_ops",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "fifo_queue_op",
+ prefix = "fifo_queue_op",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "padding_fifo_queue_op",
+ prefix = "padding_fifo_queue_op",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "priority_queue_op",
+ prefix = "priority_queue_op",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "queue_ops",
+ prefix = "queue_ops",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "random_shuffle_queue_op",
+ prefix = "random_shuffle_queue_op",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "session_ops",
+ prefix = "session_ops",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_conditional_accumulator_op",
+ prefix = "sparse_conditional_accumulator_op",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "stack_ops",
+ prefix = "stack_ops",
+ deps = DATA_FLOW_DEPS,
+)
+
+tf_kernel_library(
+ name = "tensor_array_ops",
+ prefix = "tensor_array_ops",
+ deps = DATA_FLOW_DEPS,
+)
+
+DYNAMIC_DEPS = [
+ ":bounds_check",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:data_flow_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+]
+
+tf_kernel_library(
+ name = "dynamic_partition_op",
+ prefix = "dynamic_partition_op",
+ deps = DYNAMIC_DEPS,
+)
+
+tf_kernel_library(
+ name = "dynamic_stitch_op",
+ prefix = "dynamic_stitch_op",
+ deps = DYNAMIC_DEPS,
+)
+
+LOOKUP_DEPS = [
+ ":bounds_check",
+ ":initializable_lookup_table",
+ ":lookup_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:data_flow_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+]
+
+tf_kernel_library(
+ name = "lookup_table_init_op",
+ prefix = "lookup_table_init_op",
+ deps = LOOKUP_DEPS,
+)
+
+tf_kernel_library(
+ name = "lookup_table_op",
+ prefix = "lookup_table_op",
+ deps = LOOKUP_DEPS,
)
tf_cc_tests(
@@ -1136,41 +1377,150 @@ tf_kernel_library(
],
)
-tf_kernel_libraries(
+cc_library(
name = "image",
- prefixes = [
- "adjust_contrast_op",
- "adjust_hue_op",
- "colorspace_op",
- "crop_and_resize_op",
- "decode_jpeg_op",
- "decode_png_op",
- "decode_gif_op",
- "draw_bounding_box_op",
- "encode_jpeg_op",
- "attention_ops",
- "encode_png_op",
- "non_max_suppression_op",
- "random_crop_op",
- "resize_area_op",
- "resize_bicubic_op",
- "resize_bilinear_op",
- "resize_nearest_neighbor_op",
- "sample_distorted_bounding_box_op",
- ],
deps = [
- ":bounds_check",
- ":eigen_helpers",
- ":image_resizer_state",
- "//tensorflow/core:framework",
- "//tensorflow/core:gif_internal",
- "//tensorflow/core:image_ops_op_lib",
- "//tensorflow/core:jpeg_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//third_party/eigen3",
- ],
+ ":adjust_contrast_op",
+ ":adjust_hue_op",
+ ":attention_ops",
+ ":colorspace_op",
+ ":crop_and_resize_op",
+ ":decode_gif_op",
+ ":decode_jpeg_op",
+ ":decode_png_op",
+ ":draw_bounding_box_op",
+ ":encode_jpeg_op",
+ ":encode_png_op",
+ ":non_max_suppression_op",
+ ":random_crop_op",
+ ":resize_area_op",
+ ":resize_bicubic_op",
+ ":resize_bilinear_op",
+ ":resize_nearest_neighbor_op",
+ ":sample_distorted_bounding_box_op",
+ ],
+)
+
+IMAGE_DEPS = [
+ ":bounds_check",
+ ":eigen_helpers",
+ ":image_resizer_state",
+ "//third_party/eigen3",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gif_internal",
+ "//tensorflow/core:image_ops_op_lib",
+ "//tensorflow/core:jpeg_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+]
+
+tf_kernel_library(
+ name = "adjust_contrast_op",
+ prefix = "adjust_contrast_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "adjust_hue_op",
+ prefix = "adjust_hue_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "attention_ops",
+ prefix = "attention_ops",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "colorspace_op",
+ prefix = "colorspace_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "crop_and_resize_op",
+ prefix = "crop_and_resize_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "decode_jpeg_op",
+ prefix = "decode_jpeg_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "decode_png_op",
+ prefix = "decode_png_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "decode_gif_op",
+ prefix = "decode_gif_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "draw_bounding_box_op",
+ prefix = "draw_bounding_box_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "encode_jpeg_op",
+ prefix = "encode_jpeg_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "encode_png_op",
+ prefix = "encode_png_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "non_max_suppression_op",
+ prefix = "non_max_suppression_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "random_crop_op",
+ prefix = "random_crop_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "resize_area_op",
+ prefix = "resize_area_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "resize_bicubic_op",
+ prefix = "resize_bicubic_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "resize_bilinear_op",
+ prefix = "resize_bilinear_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "resize_nearest_neighbor_op",
+ prefix = "resize_nearest_neighbor_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sample_distorted_bounding_box_op",
+ prefix = "sample_distorted_bounding_box_op",
+ deps = IMAGE_DEPS,
)
tf_cc_tests(
@@ -1254,47 +1604,102 @@ tf_cuda_cc_test(
],
)
-tf_kernel_libraries(
+cc_library(
name = "io",
- libs = [":save_restore"],
- prefixes = [
- "fixed_length_record_reader_op",
- "identity_reader_op",
- "matching_files_op",
- "reader_ops",
- "text_line_reader_op",
- "tf_record_reader_op",
- "whole_file_read_ops",
- ],
deps = [
- ":ops_util",
- ":reader_base",
- "//tensorflow/core:framework",
- "//tensorflow/core:io_ops_op_lib",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/util/tensor_bundle",
+ ":fixed_length_record_reader_op",
+ ":identity_reader_op",
+ ":matching_files_op",
+ ":reader_ops",
+ ":restore_op",
+ ":save_op",
+ ":save_restore_v2_ops",
+ ":text_line_reader_op",
+ ":tf_record_reader_op",
+ ":whole_file_read_ops",
],
)
-tf_kernel_libraries(
- name = "save_restore",
- prefixes = [
- "restore_op",
- "save_op",
- "save_restore_v2_ops",
- ],
- deps = [
- ":bounds_check_lib",
- ":save_restore_tensor",
- "//tensorflow/core:framework",
- "//tensorflow/core:io_ops_op_lib",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/util/tensor_bundle",
- ],
+IO_DEPS = [
+ ":ops_util",
+ ":reader_base",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:io_ops_op_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/util/tensor_bundle",
+]
+
+tf_kernel_library(
+ name = "fixed_length_record_reader_op",
+ prefix = "fixed_length_record_reader_op",
+ deps = IO_DEPS,
+)
+
+tf_kernel_library(
+ name = "identity_reader_op",
+ prefix = "identity_reader_op",
+ deps = IO_DEPS,
+)
+
+tf_kernel_library(
+ name = "matching_files_op",
+ prefix = "matching_files_op",
+ deps = IO_DEPS,
+)
+
+tf_kernel_library(
+ name = "reader_ops",
+ prefix = "reader_ops",
+ deps = IO_DEPS,
+)
+
+SAVE_RESTORE_DEPS = [
+ ":bounds_check_lib",
+ ":save_restore_tensor",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:io_ops_op_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/util/tensor_bundle",
+]
+
+tf_kernel_library(
+ name = "restore_op",
+ prefix = "restore_op",
+ deps = SAVE_RESTORE_DEPS,
+)
+
+tf_kernel_library(
+ name = "save_op",
+ prefix = "save_op",
+ deps = SAVE_RESTORE_DEPS,
+)
+
+tf_kernel_library(
+ name = "save_restore_v2_ops",
+ prefix = "save_restore_v2_ops",
+ deps = SAVE_RESTORE_DEPS,
+)
+
+tf_kernel_library(
+ name = "text_line_reader_op",
+ prefix = "text_line_reader_op",
+ deps = IO_DEPS,
+)
+
+tf_kernel_library(
+ name = "tf_record_reader_op",
+ prefix = "tf_record_reader_op",
+ deps = IO_DEPS,
+)
+
+tf_kernel_library(
+ name = "whole_file_read_ops",
+ prefix = "whole_file_read_ops",
+ deps = IO_DEPS,
)
tf_cc_tests(
@@ -1323,30 +1728,97 @@ tf_cc_tests(
],
)
-tf_kernel_libraries(
+cc_library(
name = "linalg",
- prefixes = [
- "cholesky_op",
- "cholesky_grad",
- "determinant_op",
- "self_adjoint_eig_op",
- "self_adjoint_eig_v2_op",
- "matrix_inverse_op",
- "matrix_solve_ls_op",
- "matrix_solve_op",
- "matrix_triangular_solve_op",
- "qr_op",
- "svd_op",
- ],
deps = [
- ":linalg_ops_common",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:linalg_ops_op_lib",
- "//third_party/eigen3",
+ ":cholesky_grad",
+ ":cholesky_op",
+ ":determinant_op",
+ ":matrix_inverse_op",
+ ":matrix_solve_ls_op",
+ ":matrix_solve_op",
+ ":matrix_triangular_solve_op",
+ ":qr_op",
+ ":self_adjoint_eig_op",
+ ":self_adjoint_eig_v2_op",
+ ":svd_op",
],
)
+LINALG_DEPS = [
+ ":linalg_ops_common",
+ "//third_party/eigen3",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:linalg_ops_op_lib",
+]
+
+tf_kernel_library(
+ name = "cholesky_op",
+ prefix = "cholesky_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "cholesky_grad",
+ prefix = "cholesky_grad",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "determinant_op",
+ prefix = "determinant_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "self_adjoint_eig_op",
+ prefix = "self_adjoint_eig_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "self_adjoint_eig_v2_op",
+ prefix = "self_adjoint_eig_v2_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "matrix_inverse_op",
+ prefix = "matrix_inverse_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "matrix_solve_ls_op",
+ prefix = "matrix_solve_ls_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "matrix_solve_op",
+ prefix = "matrix_solve_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "matrix_triangular_solve_op",
+ prefix = "matrix_triangular_solve_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "qr_op",
+ prefix = "qr_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
+ name = "svd_op",
+ prefix = "svd_op",
+ deps = LINALG_DEPS,
+)
+
cc_library(
name = "linalg_ops_common",
srcs = ["linalg_ops_common.cc"],
@@ -1359,24 +1831,55 @@ cc_library(
],
)
-tf_kernel_libraries(
+cc_library(
name = "logging",
- prefixes = [
- "logging_ops",
- "summary_audio_op",
- "summary_image_op",
- "summary_op",
- "summary_tensor_op",
- ],
deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:logging_ops_op_lib",
- "//tensorflow/core:protos_all_cc",
+ ":logging_ops",
+ ":summary_audio_op",
+ ":summary_image_op",
+ ":summary_op",
+ ":summary_tensor_op",
],
)
+LOGGING_DEPS = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:logging_ops_op_lib",
+ "//tensorflow/core:protos_all_cc",
+]
+
+tf_kernel_library(
+ name = "logging_ops",
+ prefix = "logging_ops",
+ deps = LOGGING_DEPS,
+)
+
+tf_kernel_library(
+ name = "summary_audio_op",
+ prefix = "summary_audio_op",
+ deps = LOGGING_DEPS,
+)
+
+tf_kernel_library(
+ name = "summary_image_op",
+ prefix = "summary_image_op",
+ deps = LOGGING_DEPS,
+)
+
+tf_kernel_library(
+ name = "summary_op",
+ prefix = "summary_op",
+ deps = LOGGING_DEPS,
+)
+
+tf_kernel_library(
+ name = "summary_tensor_op",
+ prefix = "summary_tensor_op",
+ deps = LOGGING_DEPS,
+)
+
tf_cc_tests(
size = "small",
srcs = [
@@ -1411,32 +1914,120 @@ MATH_DEPS = [
"//third_party/eigen3",
]
-tf_kernel_libraries(
+cc_library(
name = "math_not_windows",
- prefixes = [
- "sparse_matmul_op",
+ deps = [
+ ":sparse_matmul_op",
],
+)
+
+tf_kernel_library(
+ name = "sparse_matmul_op",
+ prefix = "sparse_matmul_op",
deps = MATH_DEPS,
)
-tf_kernel_libraries(
+cc_library(
name = "math",
- prefixes = [
- "aggregate_ops",
- "argmax_op",
- "batch_matmul_op",
- "betainc_op",
- "cast_op",
- "check_numerics_op",
- "cross_op",
- "cwise_op",
- "fft_ops",
- "matmul_op",
- "reduction_ops",
- "segment_reduction_ops",
- "scan_ops",
- "sequence_ops",
+ deps = [
+ ":aggregate_ops",
+ ":argmax_op",
+ ":batch_matmul_op",
+ ":betainc_op",
+ ":cast_op",
+ ":check_numerics_op",
+ ":cross_op",
+ ":cwise_op",
+ ":fft_ops",
+ ":matmul_op",
+ ":reduction_ops",
+ ":scan_ops",
+ ":segment_reduction_ops",
+ ":sequence_ops",
],
+)
+
+tf_kernel_library(
+ name = "aggregate_ops",
+ prefix = "aggregate_ops",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "argmax_op",
+ prefix = "argmax_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "batch_matmul_op",
+ prefix = "batch_matmul_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "betainc_op",
+ prefix = "betainc_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "cast_op",
+ prefix = "cast_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "check_numerics_op",
+ prefix = "check_numerics_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "cross_op",
+ prefix = "cross_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "cwise_op",
+ prefix = "cwise_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "fft_ops",
+ prefix = "fft_ops",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "matmul_op",
+ prefix = "matmul_op",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "reduction_ops",
+ prefix = "reduction_ops",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "segment_reduction_ops",
+ prefix = "segment_reduction_ops",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "scan_ops",
+ prefix = "scan_ops",
+ deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+ name = "sequence_ops",
+ prefix = "sequence_ops",
deps = MATH_DEPS,
)
@@ -1696,42 +2287,104 @@ tf_kernel_library(
],
)
-tf_kernel_libraries(
+cc_library(
name = "nn",
- libs = [
- ":l2loss_op",
- ],
- prefixes = [
- "batch_norm_op",
- "bias_op",
- "fused_batch_norm_op",
- "in_topk_op",
- "lrn_op",
- "relu_op",
- "softmax_op",
- "softplus_op",
- "softsign_op",
- "topk_op",
- "xent_op",
- ],
deps = [
- ":bounds_check",
- ":conv_2d",
+ ":batch_norm_op",
+ ":bias_op",
":conv_ops",
":dilation_ops",
- ":fused_batch_norm_util_gpu",
- ":ops_util",
- ":pooling_ops",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:nn_grad",
- "//tensorflow/core:nn_ops_op_lib",
- "//third_party/eigen3",
- ] + if_not_windows([
- ":depthwise_conv_grad_op",
- ":depthwise_conv_op",
- ]),
+ ":fused_batch_norm_op",
+ ":in_topk_op",
+ ":l2loss_op",
+ ":lrn_op",
+ ":relu_op",
+ ":softmax_op",
+ ":softplus_op",
+ ":softsign_op",
+ ":topk_op",
+ ":xent_op",
+ ] + if_not_windows([":depthwise_conv_op"]),
+)
+
+NN_DEPS = [
+ ":bounds_check",
+ ":conv_2d",
+ ":fused_batch_norm_util_gpu",
+ ":ops_util",
+ ":pooling_ops",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:nn_grad",
+ "//tensorflow/core:nn_ops_op_lib",
+ "//third_party/eigen3",
+]
+
+tf_kernel_library(
+ name = "batch_norm_op",
+ prefix = "batch_norm_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "bias_op",
+ prefix = "bias_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "fused_batch_norm_op",
+ prefix = "fused_batch_norm_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "in_topk_op",
+ prefix = "in_topk_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "lrn_op",
+ prefix = "lrn_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "relu_op",
+ prefix = "relu_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "softmax_op",
+ prefix = "softmax_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "softplus_op",
+ prefix = "softplus_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "softsign_op",
+ prefix = "softsign_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "topk_op",
+ prefix = "topk_op",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
+ name = "xent_op",
+ prefix = "xent_op",
+ deps = NN_DEPS,
)
tf_kernel_library(
@@ -1965,39 +2618,83 @@ tf_kernel_library(
],
)
-tf_kernel_libraries(
+cc_library(
name = "parsing",
- prefixes = [
- "decode_csv_op",
- "decode_raw_op",
- "example_parsing_ops",
- "parse_tensor_op",
- "string_to_number_op",
- ],
deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:parsing_ops_op_lib",
- "//tensorflow/core:proto_text",
- "//tensorflow/core:protos_all_cc",
+ ":decode_csv_op",
+ ":decode_raw_op",
+ ":example_parsing_ops",
+ ":parse_tensor_op",
+ ":string_to_number_op",
],
)
-tf_kernel_libraries(
+PARSING_DEPS = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:parsing_ops_op_lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:protos_all_cc",
+]
+
+tf_kernel_library(
+ name = "decode_csv_op",
+ prefix = "decode_csv_op",
+ deps = PARSING_DEPS,
+)
+
+tf_kernel_library(
+ name = "decode_raw_op",
+ prefix = "decode_raw_op",
+ deps = PARSING_DEPS,
+)
+
+tf_kernel_library(
+ name = "example_parsing_ops",
+ prefix = "example_parsing_ops",
+ deps = PARSING_DEPS,
+)
+
+tf_kernel_library(
+ name = "parse_tensor_op",
+ prefix = "parse_tensor_op",
+ deps = PARSING_DEPS,
+)
+
+tf_kernel_library(
+ name = "string_to_number_op",
+ prefix = "string_to_number_op",
+ deps = PARSING_DEPS,
+)
+
+cc_library(
name = "random_ops",
- prefixes = [
- "random_op",
- "random_shuffle_op",
- ],
deps = [
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:random_ops_op_lib",
+ ":random_op",
+ ":random_shuffle_op",
],
)
+RANDOM_OPS_DEPS = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:random_ops_op_lib",
+]
+
+tf_kernel_library(
+ name = "random_op",
+ prefix = "random_op",
+ deps = RANDOM_OPS_DEPS,
+)
+
+tf_kernel_library(
+ name = "random_shuffle_op",
+ prefix = "random_shuffle_op",
+ deps = RANDOM_OPS_DEPS,
+)
+
tf_cuda_cc_test(
name = "random_op_test",
size = "small",
@@ -2013,52 +2710,162 @@ tf_cuda_cc_test(
],
)
-tf_kernel_libraries(
+cc_library(
name = "required",
- prefixes = [
- "no_op",
- "sendrecv_ops",
- ],
deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:no_op_op_lib",
- "//tensorflow/core:sendrecv_ops_op_lib",
+ ":no_op",
+ ":sendrecv_ops",
],
)
-tf_kernel_libraries(
+REQUIRED_DEPS = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:no_op_op_lib",
+ "//tensorflow/core:sendrecv_ops_op_lib",
+]
+
+tf_kernel_library(
+ name = "no_op",
+ prefix = "no_op",
+ deps = REQUIRED_DEPS,
+)
+
+tf_kernel_library(
+ name = "sendrecv_ops",
+ prefix = "sendrecv_ops",
+ deps = REQUIRED_DEPS,
+)
+
+cc_library(
name = "sparse",
- prefixes = [
- "sparse_add_grad_op",
- "sparse_add_op",
- "sparse_concat_op",
- "sparse_reduce_sum_op",
- "sparse_dense_binary_op_shared",
- "sparse_sparse_binary_op_shared",
- "sparse_reorder_op",
- "sparse_reshape_op",
- "sparse_softmax",
- "sparse_split_op",
- "sparse_tensor_dense_add_op",
- "sparse_tensor_dense_matmul_op",
- "sparse_to_dense_op",
- "sparse_xent_op",
- "serialize_sparse_op",
- "sparse_tensors_map_ops",
- ],
deps = [
- ":bounds_check",
- ":cwise_op",
- ":fill_functor",
- ":scatter_functor",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:sparse_ops_op_lib",
- "//third_party/eigen3",
+ ":serialize_sparse_op",
+ ":sparse_add_grad_op",
+ ":sparse_add_op",
+ ":sparse_concat_op",
+ ":sparse_dense_binary_op_shared",
+ ":sparse_reduce_sum_op",
+ ":sparse_reorder_op",
+ ":sparse_reshape_op",
+ ":sparse_softmax",
+ ":sparse_sparse_binary_op_shared",
+ ":sparse_split_op",
+ ":sparse_tensor_dense_add_op",
+ ":sparse_tensor_dense_matmul_op",
+ ":sparse_tensors_map_ops",
+ ":sparse_to_dense_op",
+ ":sparse_xent_op",
],
)
+SPARSE_DEPS = [
+ ":bounds_check",
+ ":cwise_op",
+ ":fill_functor",
+ ":scatter_functor",
+ "//third_party/eigen3",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:sparse_ops_op_lib",
+]
+
+tf_kernel_library(
+ name = "sparse_add_grad_op",
+ prefix = "sparse_add_grad_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_add_op",
+ prefix = "sparse_add_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_concat_op",
+ prefix = "sparse_concat_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_reduce_sum_op",
+ prefix = "sparse_reduce_sum_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_dense_binary_op_shared",
+ prefix = "sparse_dense_binary_op_shared",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_sparse_binary_op_shared",
+ prefix = "sparse_sparse_binary_op_shared",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_reorder_op",
+ prefix = "sparse_reorder_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_reshape_op",
+ prefix = "sparse_reshape_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_softmax",
+ prefix = "sparse_softmax",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_split_op",
+ prefix = "sparse_split_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_tensor_dense_add_op",
+ prefix = "sparse_tensor_dense_add_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_tensor_dense_matmul_op",
+ prefix = "sparse_tensor_dense_matmul_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_to_dense_op",
+ prefix = "sparse_to_dense_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_xent_op",
+ prefix = "sparse_xent_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "serialize_sparse_op",
+ prefix = "serialize_sparse_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
+ name = "sparse_tensors_map_ops",
+ prefix = "sparse_tensors_map_ops",
+ deps = SPARSE_DEPS,
+)
+
tf_cuda_cc_tests(
size = "small",
srcs = [
@@ -2151,27 +2958,58 @@ cc_library(
],
)
-tf_kernel_libraries(
+cc_library(
name = "state",
- prefixes = [
- "count_up_to_op",
- "dense_update_ops",
- "scatter_op",
- "scatter_nd_op",
- "variable_ops",
- ],
deps = [
- ":assign_op",
- ":bounds_check",
- ":fill_functor",
- ":scatter_functor",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:state_ops_op_lib",
- "//third_party/eigen3",
+ ":count_up_to_op",
+ ":dense_update_ops",
+ ":scatter_nd_op",
+ ":scatter_op",
+ ":variable_ops",
],
)
+STATE_DEPS = [
+ ":assign_op",
+ ":bounds_check",
+ ":fill_functor",
+ ":scatter_functor",
+ "//third_party/eigen3",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:state_ops_op_lib",
+]
+
+tf_kernel_library(
+ name = "count_up_to_op",
+ prefix = "count_up_to_op",
+ deps = STATE_DEPS,
+)
+
+tf_kernel_library(
+ name = "dense_update_ops",
+ prefix = "dense_update_ops",
+ deps = STATE_DEPS,
+)
+
+tf_kernel_library(
+ name = "scatter_op",
+ prefix = "scatter_op",
+ deps = STATE_DEPS,
+)
+
+tf_kernel_library(
+ name = "scatter_nd_op",
+ prefix = "scatter_nd_op",
+ deps = STATE_DEPS,
+)
+
+tf_kernel_library(
+ name = "variable_ops",
+ prefix = "variable_ops",
+ deps = STATE_DEPS,
+)
+
tf_cc_test(
name = "scatter_op_test",
size = "small",
@@ -2208,27 +3046,70 @@ tf_cc_test(
],
)
-tf_kernel_libraries(
+cc_library(
name = "string",
- prefixes = [
- "string_to_hash_bucket_op",
- "reduce_join_op",
- "string_join_op",
- "string_split_op",
- "substr_op",
- "as_string_op",
- "base64_ops",
- ],
deps = [
- ":bounds_check",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:string_ops_op_lib",
- "//third_party/eigen3",
+ ":as_string_op",
+ ":base64_ops",
+ ":reduce_join_op",
+ ":string_join_op",
+ ":string_split_op",
+ ":string_to_hash_bucket_op",
+ ":substr_op",
],
)
+STRING_DEPS = [
+ ":bounds_check",
+ "//third_party/eigen3",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:string_ops_op_lib",
+]
+
+tf_kernel_library(
+ name = "string_to_hash_bucket_op",
+ prefix = "string_to_hash_bucket_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
+ name = "reduce_join_op",
+ prefix = "reduce_join_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
+ name = "string_join_op",
+ prefix = "string_join_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
+ name = "string_split_op",
+ prefix = "string_split_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
+ name = "substr_op",
+ prefix = "substr_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
+ name = "as_string_op",
+ prefix = "as_string_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
+ name = "base64_ops",
+ prefix = "base64_ops",
+ deps = STRING_DEPS,
+)
+
tf_kernel_library(
name = "training_ops",
prefix = "training_ops",
@@ -2398,6 +3279,10 @@ filegroup(
"matmul_op.h",
"no_op.cc",
"no_op.h",
+ "non_max_suppression_op.cc",
+ "non_max_suppression_op.h",
+ "one_hot_op.cc",
+ "one_hot_op.h",
"ops_util.h",
"pack_op.cc",
"pooling_ops_common.h",
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 7ab37a8abd..2d1b21d9e4 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -507,7 +507,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch(
transformed_output.template flat<T>().size());
static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
+ // The default value is in bytes despite the name of the environment variable
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB
);
int device_id = stream->parent()->device_ordinal();
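The new comment documents a real trap: despite the _IN_MB suffix, the value of TF_CUDNN_WORKSPACE_LIMIT_IN_MB is interpreted as a byte count, with 1LL << 32 (4GB) as the default. A minimal sketch of that byte-denominated lookup, using a hypothetical helper name in place of the GetCudnnWorkspaceLimit defined elsewhere in this file:

    #include <cstdint>
    #include <cstdlib>

    // Hypothetical sketch: read an environment variable as a byte count,
    // mirroring how GetCudnnWorkspaceLimit treats TF_CUDNN_WORKSPACE_LIMIT_IN_MB.
    int64_t WorkspaceLimitBytes(const char* env_var, int64_t default_bytes) {
      const char* value = std::getenv(env_var);
      if (value == nullptr) return default_bytes;  // e.g. 1LL << 32 for 4GB
      // The parsed value is taken as bytes, not megabytes, despite the name.
      return std::strtoll(value, nullptr, 10);
    }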
diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h
index 8eea1b0f9d..e13c8b9835 100644
--- a/tensorflow/core/kernels/eigen_pooling.h
+++ b/tensorflow/core/kernels/eigen_pooling.h
@@ -329,7 +329,11 @@ struct AvgPoolMeanReducer {
}
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
-#ifdef EIGEN_VECTORIZE_AVX
+#ifdef EIGEN_VECTORIZE_AVX512
+#define pequal(a, b) \
+ _mm512_castsi512_ps( \
+ _mm512_maskz_set1_epi32(_mm512_cmp_ps_mask(a, b, _CMP_EQ_UQ), -1))
+#define psel(a, b, false_mask) \
+ _mm512_castsi512_ps(_mm512_ternarylogic_epi64( \
+ _mm512_castps_si512(a), _mm512_castps_si512(b), \
+ _mm512_castps_si512(false_mask), 0xd8))
+#elif defined EIGEN_VECTORIZE_AVX
#define pequal(a, b) _mm256_cmp_ps(a, b, _CMP_EQ_UQ)
#define psel(a, b, false_mask) _mm256_blendv_ps(a, b, false_mask)
#else
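The vpternlog immediate encodes an arbitrary three-input bitwise truth table; with operands (a, b, mask), 0xd8 selects b where the mask bit is set and a where it is clear, the same per-bit behavior as the _mm256_blendv_ps used in the AVX branch. A scalar model of that table, using plain integers in place of 512-bit registers:

    #include <cassert>
    #include <cstdint>

    // Scalar model of _mm512_ternarylogic_epi64(a, b, mask, 0xd8):
    // each result bit takes b where the mask bit is 1 and a where it is 0.
    uint64_t TernlogSelect(uint64_t a, uint64_t b, uint64_t mask) {
      return (mask & b) | (~mask & a);
    }

    int main() {
      assert(TernlogSelect(0xAAAA, 0x5555, ~0ULL) == 0x5555);  // all-ones mask -> b
      assert(TernlogSelect(0xAAAA, 0x5555, 0ULL) == 0xAAAA);   // zero mask -> a
    }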
diff --git a/tensorflow/core/kernels/hexagon/BUILD b/tensorflow/core/kernels/hexagon/BUILD
index f6af111493..444180f986 100644
--- a/tensorflow/core/kernels/hexagon/BUILD
+++ b/tensorflow/core/kernels/hexagon/BUILD
@@ -59,6 +59,10 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core/kernels:quantized_ops",
+ "//tensorflow/core/kernels:reduction_ops",
+ "//tensorflow/core/kernels:reshape_op",
+ "//tensorflow/core/kernels:softmax_op",
],
)
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index 422be39d54..da13d64052 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
@@ -31,6 +32,9 @@ namespace tensorflow {
constexpr bool DBG_DUMP_VERIFICATION_STRING = false;
constexpr bool DBG_DUMP_PARAMS = false;
+const string RESHAPE_NODE_TYPE_STRING = "Reshape";
+const string SOURCE_NODE_NAME = "_SOURCE";
+const string SINK_NODE_NAME = "_SINK";
const string INPUTS_NODE_PREFIX = "inputs_for_";
const string OUTPUTS_NODE_PREFIX = "outputs_for_";
const string DATA_NODE_PREFIX = "data_for_op_";
@@ -83,9 +87,13 @@ Status GraphTransferer::LoadGraphFromProto(
}
for (const Node* const node : graph.nodes()) {
- RegisterNodeIfAllInputsAreCached(ops_definitions, shape_refiner, *node,
- false, input_node_info_list,
- output_node_names, output_tensor_map);
+ status = RegisterNodeIfAllInputsAreCached(
+ ops_definitions, shape_refiner, *node, false, input_node_info_list,
+ output_node_names, output_tensor_map);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to transfer graph " << status;
+ return status;
+ }
}
ClearCache();
if (DBG_DUMP_PARAMS) {
@@ -101,11 +109,13 @@ Status GraphTransferer::LoadGraphFromProtoFile(
const IGraphTransferOpsDefinitions& ops_definitions,
const string& graph_def_path,
const std::vector<InputNodeInfo>& input_node_info_list,
- const std::vector<string>& output_node_names,
- const OutputTensorMap& output_tensor_map, const bool is_text_proto) {
+ const std::vector<string>& output_node_names, const bool is_text_proto,
+ const bool dry_run_for_unknown_shape,
+ OutputTensorInfo* output_tensor_info) {
GraphDef graph_def;
string output;
Status status;
+ VLOG(1) << "Parse file " << graph_def_path;
if (is_text_proto) {
status = ReadFileToString(Env::Default(), graph_def_path, &output);
if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) {
@@ -115,30 +125,21 @@ Status GraphTransferer::LoadGraphFromProtoFile(
status = ReadBinaryProto(Env::Default(), graph_def_path, &graph_def);
}
if (!status.ok()) {
+ VLOG(1) << "Failed to load graph " << status;
return status;
}
- return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list,
- output_node_names, output_tensor_map);
-}
-
-Status GraphTransferer::LoadGraphFromProtoFile(
- const IGraphTransferOpsDefinitions& ops_definitions,
- const string& graph_def_path,
- const std::vector<InputNodeInfo>& input_node_info_list,
- const std::vector<string>& output_node_names,
- const OutputTensorMap& output_tensor_map) {
- GraphDef graph_def;
- string output;
- Status status = ReadFileToString(Env::Default(), graph_def_path, &output);
- if (!status.ok()) {
- return status;
- }
- if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) {
- return errors::InvalidArgument("Cannot parse proto string.");
+ if (dry_run_for_unknown_shape) {
+ VLOG(1) << "Dry run graph to obtain shape of nodes";
+ status = DryRunInferenceForAllNode(graph_def, input_node_info_list, true,
+ output_tensor_info);
+ if (!status.ok()) {
+ return status;
+ }
}
- LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list,
- output_node_names, output_tensor_map);
- return Status();
+ VLOG(1) << "Load graph with output tensors";
+ return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list,
+ output_node_names,
+ output_tensor_info->output_tensor_map);
}
/**
@@ -172,17 +173,17 @@ Status GraphTransferer::LoadGraphFromProtoFile(
switch (data_type) {
case DT_INT32: {
auto int_tensor = input_tensor.flat<int32>();
- int_tensor = int_tensor.constant(0.0);
+ int_tensor = int_tensor.constant(0);
break;
}
case DT_FLOAT: {
auto float_tensor = input_tensor.flat<float>();
- float_tensor = float_tensor.constant(0.0);
+ float_tensor = float_tensor.constant(0.0f);
break;
}
case DT_QUINT8: {
auto int_tensor = input_tensor.flat<quint8>();
- int_tensor = int_tensor.constant(0.0);
+ int_tensor = int_tensor.constant(0);
break;
}
default:
@@ -234,7 +235,12 @@ Status GraphTransferer::LoadGraphFromProtoFile(
const Status status =
DryRunInference(graph_def, input_node_info_list, output_node_names,
initialize_by_zero, &output_tensors);
- CHECK(output_node_names.size() == output_tensors.size());
+ if (!status.ok()) {
+ VLOG(1) << "Failed to dryrun " << status;
+ return status;
+ }
+ CHECK(output_node_names.size() == output_tensors.size())
+ << output_node_names.size() << ", " << output_tensors.size();
// Append output tensor of input node in advance to create a map
// to avoid memory reallocation inside vector
@@ -257,6 +263,10 @@ Status GraphTransferer::LoadGraphFromProtoFile(
return status;
}
+void GraphTransferer::EnableStrictCheckMode(const bool enable) {
+ strict_check_mode_ = enable;
+}
+
const std::vector<GraphTransferer::ConstNodeTransferParams>&
GraphTransferer::GetConstNodeParams() const {
return const_node_transfer_params_list_;
@@ -314,13 +324,16 @@ bool GraphTransferer::AreAllInputsCached(const Node& node) const {
return true;
}
-void GraphTransferer::RegisterNode(
+Status GraphTransferer::RegisterNode(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
const Node& node, const std::vector<InputNodeInfo>& input_node_info_list,
const std::vector<string>& output_node_names) {
VLOG(1) << "Register node: " << node.name();
- if (IsInputNode(input_node_info_list, node.name())) {
+ if (node.name() == SOURCE_NODE_NAME || node.name() == SINK_NODE_NAME) {
+ // Just ignore the source and sink nodes
+ return Status();
+ } else if (IsInputNode(input_node_info_list, node.name())) {
RegisterInputNode(ops_definitions, shape_refiner, output_tensor_map, node);
} else if (std::find(output_node_names.begin(), output_node_names.end(),
node.name()) != output_node_names.end()) {
@@ -330,10 +343,18 @@ void GraphTransferer::RegisterNode(
} else if (HasPaddingAndStrides(node)) {
RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner,
output_tensor_map, node);
+ } else if (IsNodeFlattenReshape(node, output_tensor_map, shape_refiner)) {
+ RegisterFlattenNode(ops_definitions, shape_refiner, output_tensor_map,
+ node);
+ } else if (ops_definitions.GetOpIdFor(node.type_string()) !=
+ IGraphTransferOpsDefinitions::INVALID_OP_ID) {
+ RegisterGenericNode(ops_definitions, shape_refiner, output_tensor_map,
+ node);
} else {
- // TODO(satok): register params for nodes which are supported by SOC
- VLOG(1) << "Not implemented for " << node.type_string();
+ return errors::InvalidArgument(node.type_string() +
+ " is not implemented yet.");
}
+ return Status();
}
void GraphTransferer::RegisterConstantNode(
@@ -348,8 +369,8 @@ void GraphTransferer::RegisterConstantNode(
// TODO(satok): support multiple outputs?
const int output_index = 0;
const DataType dt = node.output_type(output_index);
- const size_t max_bytes_per_data =
- checkpoint::TensorSliceWriter::MaxBytesPerElement(dt);
+ const size_t max_bytes_per_data = DataTypeSize(dt);
+ CHECK(max_bytes_per_data > 0);
shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
shape_inference::ShapeHandle shape_handle = context->output(output_index);
const shape_inference::DimensionHandle num_elements_dim =
@@ -402,6 +423,45 @@ bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
node.def().attr().count(STRIDES_ATTR_NAME) > 0;
}
+bool GraphTransferer::IsNodeFlattenReshape(
+ const Node& node, const OutputTensorMap& output_tensor_map,
+ const ShapeRefiner& shape_refiner) {
+ // Check if node is reshape op
+ if (node.type_string() != RESHAPE_NODE_TYPE_STRING) {
+ return false;
+ }
+
+ shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
+ // Check if output count is valid
+ if (context->num_outputs() != 1) {
+ return false;
+ }
+
+ shape_inference::ShapeHandle shape_handle = context->output(0);
+ std::array<int64, SHAPE_ARRAY_SIZE> shape;
+ const shape_inference::DimensionHandle dim_handle =
+ context->NumElements(shape_handle);
+
+ // Obtain shape of output of node
+ if (context->ValueKnown(dim_handle)) {
+ shape = BuildShapeArray(shape_handle, context);
+ } else {
+ // Use output tensor for unknown shape
+ // TODO(satok): Remove this fallback
+ CHECK(!output_tensor_map.empty());
+ const TensorShape& tensor_shape =
+ output_tensor_map.at(node.name())->shape();
+ shape = ToTensorShapeArray(tensor_shape);
+ }
+
+ // Check whether the reshape op just flattens its input.
+ return shape[0] == 1 && shape[1] == 1 && shape[2] == 1;
+}
+
void GraphTransferer::RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
@@ -428,7 +488,8 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
padding == VALID ? PADDING_VALID_STR : PADDING_SAME_STR;
const int op_type_id = ops_definitions.GetOpIdFor(node.type_string());
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
- << node.type_string();
+ << "Op " << node.type_string() << " not found in map(id = " << op_type_id
+ << ")";
AppendNodeParamsWithIoParams(shape_refiner, output_tensor_map, node,
node.name(), id, node.type_string(), op_type_id,
padding_str, node.num_inputs(), extra_inputs,
@@ -448,9 +509,8 @@ void GraphTransferer::RegisterInputNode(
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
AppendNodeParamsWithIoParams(
shape_refiner, output_tensor_map, node, node.name(), id,
- IGraphTransferOpsDefinitions::INPUT_OP_NAME, op_type_id, PADDING_NA,
- node.num_inputs(), {}, node.num_outputs(), true /* append_input */,
- true /* append_output */);
+ node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {},
+ node.num_outputs(), true /* append_input */, true /* append_output */);
}
void GraphTransferer::RegisterOutputNode(
@@ -465,14 +525,47 @@ void GraphTransferer::RegisterOutputNode(
CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
// TODO(satok): Set output for output node?
AppendNodeParamsWithIoParams(
- shape_refiner, output_tensor_map, node, node.name(), id, op_type,
- op_type_id, PADDING_NA, node.num_inputs(), {}, 0 /* outputs_size */,
- true /* append_input */, false /* append_output */);
+ shape_refiner, output_tensor_map, node, node.name(), id,
+ node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {},
+ 0 /* outputs_size */, true /* append_input */, false /* append_output */);
+}
+
+void GraphTransferer::RegisterFlattenNode(
+ const IGraphTransferOpsDefinitions& ops_definitions,
+ const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
+ const Node& node) {
+ VLOG(1) << "Register flatten node: " << node.name();
+ CHECK(node_name_to_id_cache_map_.count(node.name()) == 1);
+ const int id = node_name_to_id_cache_map_[node.name()];
+ const string op_type = IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
+ const int op_type_id = ops_definitions.GetOpIdFor(op_type);
+ CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
+
+ AppendNodeParamsWithIoParams(
+ shape_refiner, output_tensor_map, node, node.name(), id,
+ node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {},
+ node.num_outputs(), true /* append_input */, true /* append_output */);
+}
+
+void GraphTransferer::RegisterGenericNode(
+ const IGraphTransferOpsDefinitions& ops_definitions,
+ const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map,
+ const Node& node) {
+ VLOG(1) << "Register generic node: " << node.name();
+ CHECK(node_name_to_id_cache_map_.count(node.name()) == 1);
+ const int id = node_name_to_id_cache_map_[node.name()];
+ const int op_type_id = ops_definitions.GetOpIdFor(node.type_string());
+ CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
+
+ AppendNodeParamsWithIoParams(
+ shape_refiner, output_tensor_map, node, node.name(), id,
+ node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {},
+ node.num_outputs(), true /* append_input */, true /* append_output */);
}
// TODO(satok): Remove this function.
// TODO(satok): Remove only_register_const_node.
-bool GraphTransferer::RegisterNodeIfAllInputsAreCached(
+Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node,
@@ -480,12 +573,11 @@ bool GraphTransferer::RegisterNodeIfAllInputsAreCached(
const std::vector<string>& output_node_names,
const OutputTensorMap& output_tensor_map) {
if (only_register_const_node && !node.IsConstant()) {
- return false;
+ return Status();
}
CHECK(AreAllInputsCached(node));
- RegisterNode(ops_definitions, shape_refiner, output_tensor_map, node,
- input_node_info_list, output_node_names);
- return true;
+ return RegisterNode(ops_definitions, shape_refiner, output_tensor_map, node,
+ input_node_info_list, output_node_names);
}
// CAVEAT: Append inputs and outputs params accordingly
@@ -542,7 +634,7 @@ void GraphTransferer::AppendNodeOutputParams(
output_node = output_edge->src();
}
}
- CHECK(output_node != nullptr);
+ CHECK(output_node != nullptr) << node.name() << ", " << node.type_string();
const int output_index = i;
const DataType dt = node.output_type(output_index);
const size_t max_bytes_per_data =
@@ -556,11 +648,14 @@ void GraphTransferer::AppendNodeOutputParams(
if (context->ValueKnown(num_elements_dim)) {
const int64 num_output_elements = context->Value(num_elements_dim);
data_size = max_bytes_per_data * num_output_elements;
- if (!output_tensor_map.empty()) {
+ if (!output_tensor_map.empty() && strict_check_mode_) {
CHECK(output_tensor_map.count(node.name()) == 1) << node.name();
const TensorShape& tensor_shape =
output_tensor_map.at(node.name())->shape();
- CHECK(num_output_elements == tensor_shape.num_elements());
+ CHECK(num_output_elements == tensor_shape.num_elements())
+ << "num elements of node " << node.name() << " doesn't match "
+ << num_output_elements << " vs " << tensor_shape.num_elements()
+ << ", " << node.type_string();
}
} else {
// Use dryrun result to get the output data size
@@ -718,7 +813,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
for (const ConstNodeTransferParams& params :
const_node_transfer_params_list_) {
std::stringstream sstream;
- sstream << "---(CONST) [" << std::hex << params.node_id << ","
+ sstream << "---(CONST) [" << std::hex << params.node_id << std::dec << ","
<< params.shape[0] << "," << params.shape[1] << ","
<< params.shape[2] << "," << params.shape[3] << ","
<< params.data_name << "," << params.data_size << "," << params.name
@@ -729,7 +824,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
for (const NodeTransferParams& params : node_transfer_params_list_) {
std::stringstream sstream;
sstream << "---(OP) [" << params.name.c_str() << "," << std::hex
- << params.node_id << "," << params.soc_op_id << ","
+ << params.node_id << std::dec << "," << params.soc_op_id << ","
<< params.padding << "," << params.inputs_name << ","
<< params.inputs_size << "," << params.outputs_name << ","
<< params.outputs_size << "," << params.type << "]";
@@ -738,7 +833,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
LOG(INFO) << "Op node count = " << node_transfer_params_list_.size();
for (const NodeInputParams& params : node_input_params_list_) {
std::stringstream sstream;
- sstream << "---(INPUT) [" << std::hex << params.node_id;
+ sstream << "---(INPUT) [" << std::hex << params.node_id << std::dec;
for (const std::tuple<int, int>& pair :
params.input_node_id_and_output_port_list) {
sstream << "," << std::get<0>(pair) << "," << std::get<1>(pair);
@@ -749,7 +844,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
LOG(INFO) << "Input params count = " << node_input_params_list_.size();
for (const NodeOutputParams& params : node_output_params_list_) {
std::stringstream sstream;
- sstream << "---(OUTPUT) [" << std::hex << params.node_id;
+ sstream << "---(OUTPUT) [" << std::hex << params.node_id << std::dec;
for (const int max_size : params.max_sizes) {
sstream << "," << max_size;
}
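Several hunks above insert std::dec after printing params.node_id in hex. The reason is that iostream base manipulators are sticky: without the reset, every integer streamed afterwards (shapes, sizes, ports) would silently be printed in hex as well. A small demonstration:

    #include <iostream>
    #include <sstream>

    int main() {
      std::stringstream leaky, fixed;
      leaky << std::hex << 255 << "," << 16;              // "ff,10": 16 leaks as hex
      fixed << std::hex << 255 << std::dec << "," << 16;  // "ff,16"
      std::cout << leaky.str() << "\n" << fixed.str() << "\n";
    }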
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index 666c5889ad..71bd1d3375 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -100,16 +100,9 @@ class GraphTransferer {
const IGraphTransferOpsDefinitions& ops_definitions,
const string& graph_def_path,
const std::vector<InputNodeInfo>& input_node_info_list,
- const std::vector<string>& output_node_names,
- const OutputTensorMap& output_tensor_map, const bool is_text_proto);
-
- // Load graph structure into GraphTransferer from protobuf file
- Status LoadGraphFromProtoFile(
- const IGraphTransferOpsDefinitions& ops_definitions,
- const string& graph_def_path,
- const std::vector<InputNodeInfo>& input_node_info_list,
- const std::vector<string>& output_node_names,
- const OutputTensorMap& output_tensor_map);
+ const std::vector<string>& output_node_names, const bool is_text_proto,
+ const bool dry_run_for_unknown_shape,
+ OutputTensorInfo* output_tensor_info);
// Dry run inference and cache the result to get memory mapping
static Status DryRunInference(
@@ -128,6 +121,8 @@ class GraphTransferer {
const std::vector<InputNodeInfo>& input_node_info_list,
const bool initialize_by_zero, OutputTensorInfo* output_tensor_info);
+ void EnableStrictCheckMode(bool enable);
+
// Return const node parameters for transfer
const std::vector<ConstNodeTransferParams>& GetConstNodeParams() const;
@@ -142,51 +137,84 @@ class GraphTransferer {
private:
int CacheNode(const Node& node);
+
static bool IsInputNode(
const std::vector<InputNodeInfo>& input_node_info_list,
const string& node_name);
+
bool AreAllInputsCached(const Node& node) const;
- void RegisterNode(const IGraphTransferOpsDefinitions& ops_definitions,
- const ShapeRefiner& shape_refiner,
- const OutputTensorMap& output_tensor_map, const Node& node,
- const std::vector<InputNodeInfo>& input_node_info_list,
- const std::vector<string>& output_node_names);
+
+ Status RegisterNode(const IGraphTransferOpsDefinitions& ops_definitions,
+ const ShapeRefiner& shape_refiner,
+ const OutputTensorMap& output_tensor_map,
+ const Node& node,
+ const std::vector<InputNodeInfo>& input_node_info_list,
+ const std::vector<string>& output_node_names);
+
void RegisterConstantNode(const ShapeRefiner& shape_refiner, const Node& node,
const OutputTensorMap& output_tensor_map);
+
int RegisterConstantShape(const std::vector<int>& shape);
+
bool HasPaddingAndStrides(const Node& node);
+
+ // Returns true if the node is a reshape op that just flattens its input.
+ // TODO(satok): Remove this method once a generic reshape op is implemented
+ // in SOC.
+ bool IsNodeFlattenReshape(const Node& node,
+ const OutputTensorMap& output_tensor_map,
+ const ShapeRefiner& shape_refiner);
+
void RegisterNodeWithPaddingAndStrides(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map, const Node& node);
+
void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map,
const Node& node);
+
void RegisterOutputNode(const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map,
const Node& node);
- bool RegisterNodeIfAllInputsAreCached(
+
+ void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions,
+ const ShapeRefiner& shape_refiner,
+ const OutputTensorMap& output_tensor_map,
+ const Node& node);
+
+ void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions,
+ const ShapeRefiner& shape_refiner,
+ const OutputTensorMap& output_tensor_map,
+ const Node& node);
+
+ Status RegisterNodeIfAllInputsAreCached(
const IGraphTransferOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node,
const std::vector<InputNodeInfo>& input_node_info_list,
const std::vector<string>& output_node_names,
const OutputTensorMap& output_tensor_map);
+
void AppendNodeParams(const string& name, const int id, const string& type,
const int type_id, const string& padding_str,
const int inputs_size,
const std::vector<int>& extra_inputs,
const int outputs_size);
+
void AppendNodeInputParams(const int id, const Node& node,
const std::vector<int>& extra_inputs);
+
void AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map,
const int id, const Node& node);
+
static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
const shape_inference::ShapeHandle& shape_handle,
shape_inference::InferenceContext* context);
+
void AppendNodeParamsWithIoParams(
const ShapeRefiner& shape_refiner,
const OutputTensorMap& output_tensor_map, const Node& node,
@@ -194,14 +222,19 @@ class GraphTransferer {
const string& padding_str, const int inputs_size,
const std::vector<int>& extra_inputs, const int outputs_size,
const bool append_input_params, const bool append_output_params);
+
static std::array<int64, SHAPE_ARRAY_SIZE> ToTensorShapeArray(
const TensorShape& shape);
+
static void CheckShape(const OutputTensorMap& output_tensor_map,
const string& node_name,
const std::array<int64, SHAPE_ARRAY_SIZE>& actual);
+
void ClearCache();
+
// Dump pretty print of parameters
void DumpNodeTransferParams() const;
+
// Dump verification string of parameters to verify with offline tools
void DumpVerificationStringOfNodeTransferParams() const;
@@ -213,6 +246,10 @@ class GraphTransferer {
std::vector<const Node*> node_name_cache_list_;
std::unordered_map<string, int> node_name_to_id_cache_map_;
+ // Strict check mode is enabled by default. Disable it if an op's shape
+ // inference is not implemented correctly.
+ bool strict_check_mode_{true};
+
TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferer);
};
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
index 71f19e1ea2..23d57ff3e9 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc
@@ -46,8 +46,8 @@ class GraphTransfererTest : public ::testing::Test {
GraphTransferer gt_;
};
-static const std::vector<string> OP_TYPES{"INPUT", "OUTPUT", "Conv2D",
- "MaxPool"};
+static const std::vector<string> OP_TYPES{"INPUT", "OUTPUT", "Conv2D",
+ "MaxPool", "NoOp", "Add"};
const GraphTransferer::OutputTensorMap EMPTY_OUTPUT_TENSOR_MAP;
class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions {
@@ -193,7 +193,8 @@ static void SanityCheckNodes(const GraphTransferer& gt) {
TEST_F(GraphTransfererTest, LoadAddGraph) {
GraphDef def = CreateAddGraphDef();
ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
- {}, {}, EMPTY_OUTPUT_TENSOR_MAP)
+ {}, std::vector<string>{NAME_A_PLUS_B},
+ EMPTY_OUTPUT_TENSOR_MAP)
.ok());
SanityCheckNodes(gt_);
@@ -399,21 +400,29 @@ TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
}
TEST(GraphTransferer, LoadGraphFromProtoFile) {
+ const IGraphTransferOpsDefinitions* ops_definitions =
+ &TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
string filename =
io::JoinPath(testing::TensorFlowSrcRoot(),
"core/example/testdata/parse_example_graph_def.pbtxt");
std::vector<GraphTransferer::InputNodeInfo> input_node_info_list = {};
std::vector<string> output_node_names = {};
bool is_text_proto = true;
+
// Keep following comments for debugging purpose for now
- // filename = "";
- // input_node_names = { "Mul" };
- // output_node_names = { "softmax" };
+ // filename = "v3_stripped_quantized_graph_opt.pb";
+ // input_node_info_list.emplace_back(
+ // GraphTransferer::InputNodeInfo{"Mul", Tensor{DT_FLOAT, {1,299,299,3}}});
+ // output_node_names.emplace_back("softmax");
// is_text_proto = false;
+ // ops_definitions = &HexagonOpsDefinitions::getInstance();
+
+ GraphTransferer::OutputTensorInfo output_tensor_info;
GraphTransferer gt;
+ gt.EnableStrictCheckMode(false);
Status status = gt.LoadGraphFromProtoFile(
- TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, filename, input_node_info_list,
- output_node_names, EMPTY_OUTPUT_TENSOR_MAP, is_text_proto);
+ *ops_definitions, filename, input_node_info_list, output_node_names,
+ is_text_proto, true, &output_tensor_info);
// TODO(satok): Uncomment following assert once we fix the loader problem
// ASSERT_TRUE(status.ok()) << status;
}
diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc
index 8db1ee4b04..f170a4d556 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc
@@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow {
+// Names of the ops supported internally by HVX
enum class SupportedOpType {
INPUT,
OUTPUT,
@@ -69,26 +70,28 @@ enum class SupportedOpType {
SUPPORTED_OP_TYPE_COUNT,
};
-static const std::unordered_map<string, SupportedOpType>
- OP_NAME_TO_SOC_OP_TYPE_MAP{
- // Custom Op name
- {IGraphTransferOpsDefinitions::INPUT_OP_NAME, SupportedOpType::INPUT},
- {IGraphTransferOpsDefinitions::OUTPUT_OP_NAME, SupportedOpType::OUTPUT},
- // Tensorflow op name
- {"QuantizedConv2D", SupportedOpType::QUANTIZEDCONV2D_8X8TO32},
- {"QuantizedMatMul", SupportedOpType::QUANTIZEDMATMUL_8X8TO32},
- {"QuantizeDownAndShrinkRange",
- SupportedOpType::QUANTIZEDOWNANDSHRINKRANGE_32TO8},
- {"QuantizedRelu", SupportedOpType::QUANTIZEDRELU_8},
- {"QuantizedReluX", SupportedOpType::QUANTIZEDRELUX_8},
- {"QuantizedMaxPool", SupportedOpType::QUANTIZEDMAXPOOL_8},
- {"QuantizedAvgPool", SupportedOpType::QUANTIZEDAVGPOOL_8},
- {"QuantizedConcat", SupportedOpType::QUANTIZEDCONCAT_8},
- {"QuantizedBiasAdd", SupportedOpType::QUANTIZEDBIASADD_8P8TO32},
- {"Min", SupportedOpType::MIN_F},
- {"Max", SupportedOpType::MAX_F},
- {"QuantizeV2", SupportedOpType::QUANTIZE},
- };
+const std::unordered_map<string, SupportedOpType> OP_NAME_TO_SOC_OP_TYPE_MAP{
+ // Custom Op name
+ {IGraphTransferOpsDefinitions::INPUT_OP_NAME, SupportedOpType::INPUT},
+ {IGraphTransferOpsDefinitions::OUTPUT_OP_NAME, SupportedOpType::OUTPUT},
+ {"NoOp", SupportedOpType::NOP},
+ {IGraphTransferOpsDefinitions::FLATTEN_OP_NAME, SupportedOpType::FLATTEN},
+ // Tensorflow op name
+ {"QuantizedConv2D", SupportedOpType::QUANTIZEDCONV2D_8X8TO32},
+ {"QuantizedMatMul", SupportedOpType::QUANTIZEDMATMUL_8X8TO32},
+ {"QuantizeDownAndShrinkRange",
+ SupportedOpType::QUANTIZEDOWNANDSHRINKRANGE_32TO8},
+ {"QuantizedRelu", SupportedOpType::QUANTIZEDRELU_8},
+ {"QuantizedReluX", SupportedOpType::QUANTIZEDRELUX_8},
+ {"QuantizedMaxPool", SupportedOpType::QUANTIZEDMAXPOOL_8},
+ {"QuantizedAvgPool", SupportedOpType::QUANTIZEDAVGPOOL_8},
+ {"QuantizedConcat", SupportedOpType::QUANTIZEDCONCAT_8},
+ {"QuantizedBiasAdd", SupportedOpType::QUANTIZEDBIASADD_8P8TO32},
+ {"Min", SupportedOpType::MIN_F},
+ {"Max", SupportedOpType::MAX_F},
+ {"QuantizeV2", SupportedOpType::QUANTIZE},
+ {"Dequantize", SupportedOpType::DEQUANTIZE},
+};
/* static */ const IGraphTransferOpsDefinitions&
HexagonOpsDefinitions::getInstance() {
diff --git a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc
index 8e44c680f6..a4f6ec402e 100644
--- a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc
+++ b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc
@@ -21,4 +21,6 @@ namespace tensorflow {
IGraphTransferOpsDefinitions::INPUT_OP_NAME;
/* static */ constexpr const char* const
IGraphTransferOpsDefinitions::OUTPUT_OP_NAME;
+/* static */ constexpr const char* const
+ IGraphTransferOpsDefinitions::FLATTEN_OP_NAME;
}
diff --git a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h
index 039c4376e4..7e733e1f63 100644
--- a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h
+++ b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h
@@ -32,6 +32,8 @@ class IGraphTransferOpsDefinitions {
static constexpr const char* const INPUT_OP_NAME = "INPUT";
// Custom op name for output node
static constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
+ // Custom op name for flatten node
+ static constexpr const char* const FLATTEN_OP_NAME = "FLATTEN";
IGraphTransferOpsDefinitions() = default;
virtual ~IGraphTransferOpsDefinitions() = default;
diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h
index a098179034..683a2f5b41 100644
--- a/tensorflow/core/kernels/quantization_utils.h
+++ b/tensorflow/core/kernels/quantization_utils.h
@@ -112,10 +112,9 @@ void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
// input_array is an eigen Tensor. q2f is a QuantizedToFloatStruct.
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
-#define DEQUANTIZE_WITH_EIGEN(input_array, q2f) \
- (q2f.range_min + \
- (((input_array.template cast<float>() - q2f.lowest_quantized())) * \
- q2f.range_scale));
+#define DEQUANTIZE_WITH_EIGEN(input_array, q2f) \
+ ((q2f.range_min - q2f.lowest_quantized() * q2f.range_scale) + \
+ input_array.template cast<float>() * q2f.range_scale)
// input_array is an eigen Tensor. f2q is a FloatToQuantizedStruct.
// OutputType is the type of output (e.g. quint8).
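The DEQUANTIZE_WITH_EIGEN rewrite is a pure algebraic refactoring: range_min + (x - lowest_quantized) * range_scale equals (range_min - lowest_quantized * range_scale) + x * range_scale, so the constant term can be hoisted and the per-element expression becomes one multiply plus one add. A scalar check of the identity, with values chosen arbitrarily for illustration:

    #include <cassert>
    #include <cmath>

    int main() {
      const float range_min = -1.0f;
      const float range_scale = 2.0f / 255.0f;
      const float lowest_quantized = -128.0f;  // e.g. a signed 8-bit type
      for (int x = -128; x <= 127; ++x) {
        const float old_form = range_min + ((x - lowest_quantized) * range_scale);
        const float new_form =
            (range_min - lowest_quantized * range_scale) + x * range_scale;
        assert(std::fabs(old_form - new_form) < 1e-5f);
      }
    }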
diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc
index 55b5193ce1..8456604740 100644
--- a/tensorflow/core/kernels/quantization_utils_test.cc
+++ b/tensorflow/core/kernels/quantization_utils_test.cc
@@ -252,12 +252,11 @@ class QuantizationUtilsTest : public ::testing::Test {
Eigen::ThreadPoolDevice* eigen_device) {
// These are the float values we're going to test the conversions on.
typedef std::pair<float, float> FPair;
- for (FPair min_and_max : std::vector<FPair>{FPair(-255.0f, 255.0f), //
- FPair(-1.0f, 1.0f), //
- FPair(-1.0f, 255.0f), //
- FPair(0.0f, 1e6), //
- FPair(0.0f, 1.0f), //
- FPair(-31.0f, 13.0f)}) {
+ for (FPair min_and_max : std::vector<FPair>{
+ FPair(-255.0f, 255.0f), FPair(-1.0f, 1.0f), FPair(-1.0f, 255.0f),
+ FPair(0.0f, 1e6), FPair(0.0f, 1.0f), FPair(-31.0f, 13.0f),
+ FPair(-5.89505e+08, 5.89505e+08),
+ }) {
const float f_min = min_and_max.first;
const float f_max = min_and_max.second;
const int values_count = sizeof(T) == 1 ? 256 : 50000;
@@ -272,8 +271,8 @@ class QuantizationUtilsTest : public ::testing::Test {
} else {
int64 offset = static_cast<int64>(q_range / values_count * i);
input_array(i) = static_cast<int32>(
- Eigen::NumTraits<T>::lowest() +
- std::min<int64>(Eigen::NumTraits<T>::highest(), offset));
+ std::min<int64>(Eigen::NumTraits<T>::lowest() + offset,
+ Eigen::NumTraits<T>::highest()));
}
}
@@ -285,7 +284,7 @@ class QuantizationUtilsTest : public ::testing::Test {
for (int i = 0; i < values_count; ++i) {
float expected = QuantizedToFloat<T>(input_array(i), f_min, f_max);
float actual = output_array(i);
- ASSERT_NEAR(expected, actual, range * 1e-6)
+ ASSERT_NEAR(expected, actual, range * 1.1e-7)
<< "expected=" << expected << " actual=" << actual
<< " v=" << input_array(i) << " i=" << i << " f_min=" << f_min
<< " f_max=" << f_max
@@ -340,14 +339,14 @@ TEST_F(QuantizationUtilsTest, QuantizedToFloat) {
const int int32_min = std::numeric_limits<int>::min();
const int int32_max = std::numeric_limits<int>::max();
- EXPECT_LT(
- fabsf(-1.0f - QuantizedToFloat<qint32>(qint32(int32_min), -1.0f, 1.0f)),
- 1e-5f);
- EXPECT_LT(fabsf(0.0f - QuantizedToFloat<qint32>(qint32(0), -1.0f, 1.0f)),
- 1e-5f);
- EXPECT_LT(
- fabsf(1.0f - QuantizedToFloat<qint32>(qint32(int32_max), -1.0f, 1.0f)),
- 1e-5f);
+ EXPECT_NEAR(-1.0f, QuantizedToFloat<qint32>(qint32(int32_min), -1.0f, 1.0f),
+ 1e-5f);
+ EXPECT_NEAR(0.0f, QuantizedToFloat<qint32>(qint32(0), -1.0f, 1.0f), 1e-5f);
+ EXPECT_NEAR(1.0f, QuantizedToFloat<qint32>(qint32(int32_max), -1.0f, 1.0f),
+ 1e-5f);
+
+ EXPECT_NEAR(32.0f, QuantizedToFloat<qint32>(qint32(32), int32_min, int32_max),
+ 1.0);
}
TEST_F(QuantizationUtilsTest, AvoidBias) {
@@ -531,6 +530,32 @@ TEST_F(QuantizationUtilsTest, QuantizedTensorToFloat) {
-103.0f, 115.0f, 116.0f, 117.0f});
Tensor output = QuantizedTensorToFloat<quint8>(input, input_min, input_max);
test::ExpectTensorEqual<float>(expected, output);
+
+ // Test for signed 32 bit.
+ // Note that we cannot use input mins and maxes that match the range because
+ // a float's 24 mantissa bits are 7 too few to represent 2**31-1 exactly.
+ // Also there is no good fraction to use because 2**31-1 is a Mersenne
+ // prime.
+ Tensor input32(DT_QINT32, TensorShape({input_height, input_width}));
+
+ // Use a quantizer centered at 0.
+ float input_range = 1LL << 25;
+ int64 num_levels = (1LL << 32) - 1;
+ float step_size =
+ static_cast<float>(static_cast<double>(input_range) / num_levels);
+ float q_compatible_min_value =
+ roundf(-(input_range / 2.0) / step_size) * step_size;
+ float q_compatible_max_value = q_compatible_min_value + input_range;
+ test::FillValues<qint32>(&input32, {-16384, 0, 16256, -13440, -13312, -13184,
+ 14720, 14848, 14976});
+
+ Tensor output32 = QuantizedTensorToFloat<qint32>(
+ input32, q_compatible_min_value, q_compatible_max_value);
+ test::FillValues<float>(&expected, {-128.0f, 0.0f, 127.0f, -105.0f, -104.0f,
+ -103.0f, 115.0f, 116.0f, 117.0f});
+ // The quantization error in going between 1<<25 and 1<<32 levels.
+ const double kTolerance = .5 / 128.0;
+ test::ExpectTensorNear<float>(expected, output32, kTolerance);
}
// Verify that QuantizedToFloatInPlaceUsingEigen is same result as
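The comment in the test above rests on a float-precision fact worth making concrete: a float carries 24 bits of mantissa, 7 short of the 31 needed for 2**31 - 1, so the qint32 endpoints cannot serve directly as range limits. The rounding is easy to observe:

    #include <cassert>
    #include <cstdint>

    int main() {
      // 2**31 - 1 is not representable as a float; it rounds up to 2**31.
      const int32_t int32_max = 2147483647;
      const float as_float = static_cast<float>(int32_max);
      assert(as_float == 2147483648.0f);
    }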
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 7704c5f65a..55d3ee36da 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -37,7 +37,7 @@ static bool ValidUpdateShape(const TensorShape& params_shape,
const Tensor& indices, const Tensor& updates) {
int64 indices_nd = 1;
if (indices.dims() > 1) {
- indices_nd = indices.dim_size(1);
+ indices_nd = indices.dim_size(indices.dims() - 1);
}
for (int d = indices_nd; d < params_shape.dims(); d++) {
if (updates.dim_size(d - indices_nd + 1) != params_shape.dim_size(d)) {
@@ -71,13 +71,13 @@ static void PrepareAndValidateInputs(OpKernelContext* c,
"The outermost dimension of updates and indices ",
"must match. Got indices.shape ", indices_shape.DebugString(),
", updates.shape ", updates_shape.DebugString()));
- OP_REQUIRES(
- c, ValidUpdateShape(params_shape, indices, updates),
- errors::InvalidArgument(
- "Must have updates.shape = indices.shape[0] + params_shape[IXDIM:], ",
- "got updates.shape ", updates_shape.DebugString(), ", indices.shape ",
- indices_shape.DebugString(), ", params_shape ",
- params_shape.DebugString()));
+ OP_REQUIRES(c, ValidUpdateShape(params_shape, indices, updates),
+ errors::InvalidArgument(
+ "Must have updates.shape = indices.shape[:IXDIM] + ",
+ "params_shape[IXDIM:], got updates.shape ",
+ updates_shape.DebugString(), ", indices.shape ",
+ indices_shape.DebugString(), ", params_shape ",
+ params_shape.DebugString()));
// Check that we have enough index space
const int64 N_big = indices.NumElements();
OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
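The one-line change to ValidUpdateShape matters for indices of rank above 2: the index depth IXDIM is the size of the innermost dimension of indices, not dimension 1, and the expected shape is updates.shape = indices.shape[:IXDIM] + params_shape[IXDIM:], as the reworded error message states. A sketch of the corrected rule over plain shape vectors (ValidUpdateShapeSketch is a hypothetical stand-in, not the kernel's code):

    #include <cstdint>
    #include <vector>

    bool ValidUpdateShapeSketch(const std::vector<int64_t>& params_shape,
                                const std::vector<int64_t>& indices_shape,
                                const std::vector<int64_t>& updates_shape) {
      // The fix: index depth comes from the innermost dimension of indices.
      int64_t indices_nd = 1;
      if (indices_shape.size() > 1) indices_nd = indices_shape.back();
      // Trailing dims of updates must match the sliced-off dims of params.
      for (size_t d = static_cast<size_t>(indices_nd); d < params_shape.size();
           ++d) {
        const size_t u = d - indices_nd + 1;
        if (u >= updates_shape.size() || updates_shape[u] != params_shape[d])
          return false;
      }
      return true;
    }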
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
index d69909b6de..cc4772c001 100644
--- a/tensorflow/core/kernels/scatter_nd_op_test.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -217,8 +217,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
{100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
Status s = RunOpKernel();
EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Must have updates.shape = indices.shape[0] + "
- "params_shape[IXDIM:], got"))
+ .contains("Must have updates.shape = indices.shape[:IXDIM]"))
<< s;
}
@@ -233,9 +232,10 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
AddInputFromArray<float>(TensorShape({2, 3}),
{100, 101, 102, 10000, 10001, 10002});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("The outermost dimension of updates and indices "
- "must match. Got "))
+ EXPECT_TRUE(
+ StringPiece(s.ToString())
+ .contains(
+ "The outermost dimension of updates and indices must match."))
<< s;
}
diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc
index 2d1539fb9d..cc0f86ce05 100644
--- a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc
+++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc
@@ -163,6 +163,10 @@ class SparseDenseBinaryOpShared : public OpKernel {
}
};
+// NOTE(aselle): If Div is extended to non-reals, make sure to use the same
+// separation of operator semantics as done for dense cwise ops. I.e. you
+// should make SparseDenseCwiseRealDiv, SparseDenseCwiseTruncateDiv,
+// SparseDenseCwiseFloorDiv, and then deprecate, SparseDenseCwiseDiv.
// TODO(zongheng): extend to other eligible cwise operations as requested.
#define REGISTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index 4e14f0099a..170d4ec18b 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -209,6 +209,77 @@ EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
#endif
+#ifdef EIGEN_VECTORIZE_AVX512
+template <>
+EIGEN_STRONG_INLINE Packet16f
+pbroadcast_first<Packet16f>(const Packet16f& a_in) {
+ Packet4f a = _mm512_castps512_ps128(a_in);
+ return _mm512_broadcastss_ps(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16f
+pbroadcast_second<Packet16f>(const Packet16f& a_in) {
+ Packet4f a = _mm512_castps512_ps128(a_in);
+ return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet16f
+pbroadcast_third<Packet16f>(const Packet16f& a_in) {
+ Packet4f a = _mm512_castps512_ps128(a_in);
+ return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet16f
+pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
+ Packet4f a = _mm512_castps512_ps128(a_in);
+ return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
+ Packet2d a = _mm512_castpd512_pd128(a_in);
+ return _mm512_broadcastsd_pd(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
+ Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
+ return _mm512_broadcastsd_pd(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
+ Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
+ return _mm512_broadcastsd_pd(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
+ Packet2d a =
+ _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
+ return _mm512_broadcastsd_pd(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16i
+pbroadcast_first<Packet16i>(const Packet16i& a_in) {
+ Packet4i a = _mm512_castsi512_si128(a_in);
+ return _mm512_broadcastd_epi32(a);
+}
+template <>
+EIGEN_STRONG_INLINE Packet16i
+pbroadcast_second<Packet16i>(const Packet16i& a_in) {
+ Packet4i a = _mm512_castsi512_si128(a_in);
+ return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet16i
+pbroadcast_third<Packet16i>(const Packet16i& a_in) {
+ Packet4i a = _mm512_castsi512_si128(a_in);
+ return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
+}
+template <>
+EIGEN_STRONG_INLINE Packet16i
+pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
+ Packet4i a = _mm512_castsi512_si128(a_in);
+ return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
+}
+#endif
+
#ifdef EIGEN_VECTORIZE_AVX
// For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords
template <>
@@ -245,6 +316,25 @@ EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
_mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
+#ifdef EIGEN_VECTORIZE_AVX512
+// Return a Packet with 4 floats loaded from 4 bfloat16 values
+template <>
+EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
+ __m128i zero = _mm_setzero_si128();
+ __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
+ return _mm512_castps128_ps512(
+ _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
+}
+// Return a Packet with 2 floats loaded from 2 bfloat16 values
+template <>
+EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
+ __m128i zero = _mm_setzero_si128();
+ __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
+ return _mm512_castps128_ps512(
+ _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
+}
+#endif
+
// For each 128-bit lane convert 4 bfloat to 4 float values from the lower half
// of the 128-bit lane
template <typename Packet>
@@ -313,6 +403,22 @@ EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
}
#endif
+
+#ifdef EIGEN_VECTORIZE_AVX512
+
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
+ return _mm512_castsi512_ps(_mm512_slli_epi32(
+ _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
+ 16));
+}
+
+template <typename Packet>
+EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
+ return _mm512_castsi512_ps(_mm512_slli_epi32(
+ _mm512_cvtepu16_epi32(
+ _mm512_extracti64x4_epi64(_mm512_castps_si512(from), 1)),
+ 16));
+}
+
+#endif
} // namespace internal
} // namespace Eigen
#endif
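The pexpand helpers work because a bfloat16 is exactly the upper 16 bits of the corresponding IEEE float32 pattern, so widening a 16-bit lane to 32 bits and shifting left by 16 reconstructs the float. A scalar equivalent of one lane:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Scalar model of the bf16 -> f32 expansion: zero-extend the 16-bit
    // value and shift it into the high half of a float's bit pattern.
    float ExpandBf16(uint16_t bf16_bits) {
      const uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
      float result;
      std::memcpy(&result, &f32_bits, sizeof(result));
      return result;
    }

    int main() {
      // 0x3F80 is the upper half of 1.0f (0x3F800000).
      assert(ExpandBf16(0x3F80) == 1.0f);
    }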
diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc
index 45cad2e23b..b155e45187 100644
--- a/tensorflow/core/kernels/sparse_matmul_op_test.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc
@@ -200,7 +200,7 @@ class SparseMatmulOpTest : public ::testing::Test {
// zero out lower 16-bits of mantissa of data3 values
// copy bfloat representation to data3_bfloat16
- for (int i = 0; i < kMaxPacketSize; ++i) {
+ for (int i = 0; i < kMaxPacketSize * 2; ++i) {
uint16_t* data3_p = reinterpret_cast<uint16_t*>(&data3[i]);
uint16_t* data3_bfloat16_p =
reinterpret_cast<uint16_t*>(data3_bfloat16) + i;
@@ -222,7 +222,13 @@ class SparseMatmulOpTest : public ::testing::Test {
return true;
}
+#ifdef EIGEN_VECTORIZE_AVX512
static const int kMaxPacketSize = 16;
+#elif defined EIGEN_VECTORIZE_AVX || defined EIGEN_VECTORIZE_AVX2
+ static const int kMaxPacketSize = 8;
+#else
+ static const int kMaxPacketSize = 4;
+#endif
typedef typename Eigen::internal::packet_traits<float>::type Packet;
const int PacketSize;
// float values
@@ -230,9 +236,9 @@ class SparseMatmulOpTest : public ::testing::Test {
// output of intrinsics
EIGEN_ALIGN_MAX float data2[kMaxPacketSize];
// float values with only 7 mantissa bits (bfloat representable)
- EIGEN_ALIGN_MAX float data3[kMaxPacketSize];
+ EIGEN_ALIGN_MAX float data3[kMaxPacketSize * 2];
// bfloat16 representation of data3
- EIGEN_ALIGN_MAX float data3_bfloat16[kMaxPacketSize / 2];
+ EIGEN_ALIGN_MAX float data3_bfloat16[kMaxPacketSize];
EIGEN_ALIGN_MAX float ref[kMaxPacketSize];
};
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 33705aac6a..e275b63de4 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -130,11 +129,12 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
}
const string& container = Tstack_handle.flat<string>()(0);
const string& stack_name = Tstack_handle.flat<string>()(1);
- ResourceMgr* rm = ctx->step_resource_manager();
+ ResourceMgr* rm = ctx->resource_manager();
if (rm == nullptr) {
- return errors::Internal("No per-step resource manager.");
+ return errors::Internal("No resource manager.");
}
- TF_RETURN_IF_ERROR(rm->Lookup(container, stack_name, stack));
+ TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
+ strings::StrCat(container, stack_name), stack));
return Status::OK();
}
@@ -162,12 +162,13 @@ class StackOp : public OpKernel {
auto handle = stack_handle.flat<string>();
handle(0) = "_stacks";
handle(1) = strings::StrCat(stack_name_, "_", stack_id);
- // Store the handle in a container of the per-step RM.
- ResourceMgr* rm = ctx->step_resource_manager();
- OP_REQUIRES(ctx, rm != nullptr,
- errors::Internal("No per-step resource manager."));
+ // Store the handle in a per-step container.
+ ResourceMgr* rm = ctx->resource_manager();
+ OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
Stack* stack = new Stack(elem_type_, stack_handle);
- OP_REQUIRES_OK(ctx, rm->Create(handle(0), handle(1), stack));
+ OP_REQUIRES_OK(ctx,
+ rm->Create(ctx->step_container()->name(),
+ strings::StrCat(handle(0), handle(1)), stack));
ctx->set_output_ref(0, stack->mu(), stack->handle());
}
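This file, tensor_array_ops.cc, and variable_ops.h below all migrate from the per-step resource manager to the process-wide one, scoping resources by using the step container's name as the outer key and folding the old container into the resource name. A toy model of the key change (the map is a stand-in, not the ResourceMgr API):

    #include <cassert>
    #include <map>
    #include <string>
    #include <utility>

    int main() {
      // Stand-in for ResourceMgr: keyed by (container, resource name).
      std::map<std::pair<std::string, std::string>, int> rm;

      // Hypothetical step container name; the real one comes from
      // ctx->step_container()->name().
      const std::string step_container = "__per_step_7";

      // New scheme: the shared manager holds the entry under the step
      // container, with the old (container, name) pair concatenated, so
      // entries from different steps can no longer collide.
      rm[{step_container, std::string("_stacks") + "mystack_3"}] = 42;
      assert(rm.count({step_container, "_stacksmystack_3"}) == 1);
    }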
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index c4584993fa..3226e5e0f8 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -30,7 +30,7 @@ namespace {
std::vector<string> Split(const string& str, const string& delimiter) {
if (delimiter.size()) {
- return str_util::Split(str, delimiter[0], str_util::SkipEmpty());
+ return str_util::Split(str, delimiter, str_util::SkipEmpty());
}
std::vector<string> char_vector(str.size());
for (size_t i = 0; i < str.size(); ++i) {
@@ -64,10 +64,6 @@ class StringSplitOp : public OpKernel {
const auto delimiter_vec = delimiter_tensor->flat<string>();
const string& delimiter = delimiter_vec(0);
// Empty delimiter means split the input character by character.
- OP_REQUIRES(ctx, delimiter.size() < 2,
- errors::InvalidArgument("Delimiter must be a character, got",
- delimiter));
-
std::vector<string> tokens;
// Guess that we'll be unpacking a handful of tokens per example.
static constexpr int kReserveSize = 4;
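With the single-character restriction dropped, the delimiter string reaches the new str_util::Split overload (declared at the end of this diff), which splits on any of the characters it contains. A minimal reimplementation of that semantic, with empty tokens skipped as SkipEmpty does:

    #include <cassert>
    #include <string>
    #include <vector>

    // Sketch: split on any of the delimiter characters, dropping empty
    // tokens, matching the documented Split("a,b.c,d", ".,") -> {a, b, c, d}.
    std::vector<std::string> SplitAnyOf(const std::string& text,
                                        const std::string& delims) {
      std::vector<std::string> result;
      std::string token;
      for (const char c : text) {
        if (delims.find(c) != std::string::npos) {
          if (!token.empty()) result.push_back(token);
          token.clear();
        } else {
          token.push_back(c);
        }
      }
      if (!token.empty()) result.push_back(token);
      return result;
    }

    int main() {
      const auto parts = SplitAnyOf("a,b.c,d", ".,");
      assert(parts.size() == 4 && parts[0] == "a" && parts[3] == "d");
    }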
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 318e8ba160..fa26232468 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -45,6 +45,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA
+// clang-format on
+
namespace tensorflow {
Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) {
@@ -72,9 +74,10 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) {
string container;
string ta_handle;
TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle));
- ResourceMgr* rm = ctx->step_resource_manager();
- if (rm == nullptr) return errors::Internal("No per-step resource manager.");
- TF_RETURN_IF_ERROR(rm->Lookup(container, ta_handle, tensor_array));
+ ResourceMgr* rm = ctx->resource_manager();
+ if (rm == nullptr) return errors::Internal("No resource manager.");
+ TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
+ container + ta_handle, tensor_array));
return Status::OK();
}
@@ -104,10 +107,9 @@ class TensorArrayCreationOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
tensorflow::DT_STRING, tensorflow::TensorShape({2}),
&tensor_array_output_handle, alloc_attr));
- // Store the handle in a container of the per-step RM.
- ResourceMgr* rm = ctx->step_resource_manager();
- OP_REQUIRES(ctx, rm != nullptr,
- errors::Internal("No per-step resource manager."));
+ // Store the handle in a per-step container of the RM.
+ ResourceMgr* rm = ctx->resource_manager();
+ OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
TensorArray* output_tensor_array;
OP_REQUIRES_OK(ctx, CreateTensorArray(ctx, rm, &tensor_array_output_handle,
@@ -167,8 +169,9 @@ class TensorArrayOp : public TensorArrayCreationOp {
false /* multiple_writes_aggregate */, false /* is_grad */,
-1 /* marked_size */, clear_after_read_);
- TF_RETURN_IF_ERROR(
- rm->Create(handle(0), unique_tensor_array_name, tensor_array));
+ TF_RETURN_IF_ERROR(rm->Create(
+ ctx->step_container()->name(),
+ strings::StrCat(handle(0), unique_tensor_array_name), tensor_array));
*output_tensor_array = tensor_array;
@@ -236,7 +239,9 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
output_handle(1) = strings::StrCat(tensor_array_name, "@", source_);
TensorArray* tensor_array;
- TF_RETURN_IF_ERROR(rm->Lookup(container, tensor_array_name, &tensor_array));
+ TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
+ strings::StrCat(container, tensor_array_name),
+ &tensor_array));
core::ScopedUnref unref(tensor_array);
// Once gradients are being calculated, the forward TensorArray
@@ -268,7 +273,9 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
};
Status s = rm->LookupOrCreate<TensorArray>(
- output_handle(0), output_handle(1), output_tensor_array, creator);
+ ctx->step_container()->name(),
+ strings::StrCat(output_handle(0), output_handle(1)),
+ output_tensor_array, creator);
(*output_tensor_array)->Unref();
return s;
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index f44f94c51b..d8d8831702 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -102,7 +102,7 @@ class TemporaryVariableOp : public OpKernel {
void Compute(OpKernelContext* context) override {
Status s;
- ResourceMgr* rm = context->step_resource_manager();
+ ResourceMgr* rm = context->resource_manager();
OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
auto* tmp_var = new TmpVar;
OP_REQUIRES(context, tmp_var,
@@ -111,7 +111,8 @@ class TemporaryVariableOp : public OpKernel {
s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
if (!s.ok()) tmp_var->Unref();
OP_REQUIRES_OK(context, s);
- OP_REQUIRES_OK(context, rm->Create("tmp_var", var_name_, tmp_var));
+ OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(),
+ var_name_, tmp_var));
context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
}
@@ -149,10 +150,10 @@ class DestroyTemporaryVariableOp : public OpKernel {
CHECK(IsRefType(context->input_dtype(0)));
Tensor tmpvar = context->mutable_input(0, false);
context->set_output(0, tmpvar);
- ResourceMgr* rm = context->step_resource_manager();
+ ResourceMgr* rm = context->resource_manager();
OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
- OP_REQUIRES_OK(
- context, rm->Delete<TemporaryVariableOp::TmpVar>("tmp_var", var_name_));
+ OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>(
+ context->step_container()->name(), var_name_));
}
private:
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
index ee197e54e3..183b18a5c6 100644
--- a/tensorflow/core/lib/strings/str_util.h
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -108,9 +108,12 @@ struct SkipWhitespace {
}
};
-std::vector<string> Split(StringPiece text, char delim);
+// Split strings using any of the supplied delimiters. For example:
+// Split("a,b.c,d", ".,") would return {"a", "b", "c", "d"}.
+std::vector<string> Split(StringPiece text, StringPiece delims);
+
template <typename Predicate>
-std::vector<string> Split(StringPiece text, char delim, Predicate p);
+std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p);
// Split "text" at "delim" characters, and parse each component as
// an integer. If successful, adds the individual numbers in order
@@ -157,17 +160,17 @@ string Join(const T& s, const char* sep, Formatter f) {
return result;
}
-inline std::vector<string> Split(StringPiece text, char delim) {
- return Split(text, delim, AllowEmpty());
+inline std::vector<string> Split(StringPiece text, StringPiece delims) {
+ return Split(text, delims, AllowEmpty());
}
template <typename Predicate>
-std::vector<string> Split(StringPiece text, char delim, Predicate p) {
+std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p) {
std::vector<string> result;
size_t token_start = 0;
if (!text.empty()) {
for (size_t i = 0; i < text.size() + 1; i++) {
- if ((i == text.size()) || (text[i] == delim)) {
+ if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
StringPiece token(text.data() + token_start, i - token_start);
if (p(token)) {
result.push_back(token.ToString());
@@ -179,6 +182,15 @@ std::vector<string> Split(StringPiece text, char delim, Predicate p) {
return result;
}
+inline std::vector<string> Split(StringPiece text, char delim) {
+ return Split(text, StringPiece(&delim, 1));
+}
+
+template <typename Predicate>
+std::vector<string> Split(StringPiece text, char delims, Predicate p) {
+ return Split(text, StringPiece(&delims, 1), p);
+}
+
} // namespace str_util
} // namespace tensorflow
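For reference, a small usage sketch of the widened API; the expected values come from the header comment above and the unit tests below:

```cpp
#include "tensorflow/core/lib/strings/str_util.h"

using namespace tensorflow;

// Any character in the second argument acts as a split point.
std::vector<string> a = str_util::Split("a,b.c,d", ".,");
// a == {"a", "b", "c", "d"}

// Predicates compose as before; SkipEmpty drops the empty tokens produced
// by adjacent delimiters.
std::vector<string> b =
    str_util::Split("a!,!b,!c,", ",!", str_util::SkipEmpty());
// b == {"a", "b", "c"}

// Old single-char call sites keep compiling via the new char overloads,
// which wrap the delimiter in a one-byte StringPiece.
std::vector<string> c = str_util::Split("a,b,c", ',');
```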
diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc
index 055e1e4ac0..afdc5855cc 100644
--- a/tensorflow/core/lib/strings/str_util_test.cc
+++ b/tensorflow/core/lib/strings/str_util_test.cc
@@ -239,6 +239,8 @@ TEST(Split, Basic) {
EXPECT_EQ(str_util::Join(str_util::Split("a,b,c", ','), "|"), "a|b|c");
EXPECT_EQ(str_util::Join(str_util::Split("a,,,b,,c,", ','), "|"),
"a|||b||c|");
+ EXPECT_EQ(str_util::Join(str_util::Split("a!,!b,!c,", ",!"), "|"),
+ "a|||b||c|");
EXPECT_EQ(str_util::Join(
str_util::Split("a,,,b,,c,", ',', str_util::SkipEmpty()), "|"),
"a|b|c");
@@ -246,6 +248,10 @@ TEST(Split, Basic) {
str_util::Join(
str_util::Split("a, ,b,,c,", ',', str_util::SkipWhitespace()), "|"),
"a|b|c");
+ EXPECT_EQ(str_util::Join(str_util::Split("a. !b,;c,", ".,;!",
+ str_util::SkipWhitespace()),
+ "|"),
+ "a|b|c");
}
TEST(SplitAndParseAsInts, Int32) {
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 8c459ed92b..7e8132f689 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -556,24 +556,32 @@ TEST(ArrayOpsTest, Concat_ShapeFn) {
set_n(2);
// Sum dim 0, merge the other two dims.
- concat_dim_t = test::AsScalar(0);
- INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]");
- INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
- "[];[100,2,5];[10,?,3]");
- // concat_dim can't be summed, as one value is unknown.
- INFER_OK(op, "[];[100,2,?];[?,?,3]", "[?,d1_1,d2_2]");
- INFER_OK(op, "[];[?,2,?];[10,?,3]", "[?,d1_1,d2_2]");
+ for (int concat_dim : {0, -3}) {
+ concat_dim_t = test::AsScalar(concat_dim);
+ INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]");
+ INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
+ "[];[100,2,5];[10,?,3]");
+ // concat_dim can't be summed, as one value is unknown.
+ INFER_OK(op, "[];[100,2,?];[?,?,3]", "[?,d1_1,d2_2]");
+ INFER_OK(op, "[];[?,2,?];[10,?,3]", "[?,d1_1,d2_2]");
+ }
// Test with a higher concat_dim.
- concat_dim_t = test::AsScalar(1);
- INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]");
- INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]");
- INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]");
- // concat_dim is too high.
- INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
- "[];[100];[10,?]");
- INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
- "[];[100,5];[10]");
+ for (bool use_negative : {false, true}) {
+ concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
+ INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]");
+ concat_dim_t = test::AsScalar(use_negative ? -1 : 1);
+ INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]");
+ INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]");
+
+ // concat_dim is out of bounds.
+ concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+ "[];[100];[10,?]");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+ "[];[100,5];[10]");
+ }
+
// concat_dim is too low.
concat_dim_t = test::AsScalar(-2);
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 6fbdb86c45..ff46aa2725 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -20029,7 +20029,7 @@ op {
}
input_arg {
name: "delimiter"
- description: "0-D. Delimiter character, or empty string."
+ description: "0-D. Delimiter characters (bytes), or empty string."
type: DT_STRING
}
output_arg {
@@ -20048,7 +20048,7 @@ op {
type: DT_INT64
}
summary: "Split elements of `input` based on `delimiter` into a `SparseTensor`."
- description: "Let N be the size of source (typically N will be the batch size). Split each\nelement of `input` based on `delimiter` and return a `SparseTensor`\ncontaining the splitted tokens. Empty tokens are ignored.\n\n`delimiter` can be empty or a single-byte character. If `delimiter` is an empty\n string, each element of `input` is split into individual single-byte character\n strings, including splitting of UTF-8 multibyte sequences.\n\nFor example:\n N = 2, input[0] is \'hello world\' and input[1] is \'a b c\', then the output\n will be\n\n indices = [0, 0;\n 0, 1;\n 1, 0;\n 1, 1;\n 1, 2]\n shape = [2, 3]\n values = [\'hello\', \'world\', \'a\', \'b\', \'c\']"
+ description: "Let N be the size of source (typically N will be the batch size). Split each\nelement of `input` based on `delimiter` and return a `SparseTensor`\ncontaining the splitted tokens. Empty tokens are ignored.\n\n`delimiter` can be empty, or a string of split characters. If `delimiter` is an\n empty string, each element of `input` is split into individual single-byte\n character strings, including splitting of UTF-8 multibyte sequences. Otherwise\n every character of `delimiter` is a potential split point.\n\nFor example:\n N = 2, input[0] is \'hello world\' and input[1] is \'a b c\', then the output\n will be\n\n indices = [0, 0;\n 0, 1;\n 1, 0;\n 1, 1;\n 1, 2]\n shape = [2, 3]\n values = [\'hello\', \'world\', \'a\', \'b\', \'c\']"
}
op {
name: "StringToHashBucket"
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 53d75e4519..cef40289bf 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -222,9 +222,10 @@ Let N be the size of source (typically N will be the batch size). Split each
element of `input` based on `delimiter` and return a `SparseTensor`
containing the split tokens. Empty tokens are ignored.
-`delimiter` can be empty or a single-byte character. If `delimiter` is an empty
- string, each element of `input` is split into individual single-byte character
- strings, including splitting of UTF-8 multibyte sequences.
+`delimiter` can be empty, or a string of split characters. If `delimiter` is an
+ empty string, each element of `input` is split into individual single-byte
+ character strings, including splitting of UTF-8 multibyte sequences. Otherwise
+ every character of `delimiter` is a potential split point.
For example:
N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output
@@ -239,7 +240,7 @@ For example:
values = ['hello', 'world', 'a', 'b', 'c']
input: 1-D. Strings to split.
-delimiter: 0-D. Delimiter character, or empty string.
+delimiter: 0-D. Delimiter characters (bytes), or empty string.
indices: A dense matrix of int64 representing the indices of the sparse tensor.
values: A vector of strings corresponding to the split values.
shape: a length-2 vector of int64 representing the shape of the sparse
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index bb52e75df3..810675fbcb 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -38,10 +38,17 @@ tf_cuda_library(
)
cc_library(
+ name = "stream_executor_cuda",
+ deps = [
+ "//tensorflow/stream_executor",
+ ] + select({
+ "@local_config_cuda//cuda:darwin": ["IOKit"],
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
name = "stream_executor_no_cuda",
- hdrs = [
- "stream_executor_no_cuda.h",
- ],
deps = [
"//tensorflow/stream_executor",
],
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index eaae673464..961fb8b4ad 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -67,6 +67,8 @@ class LogMessageFatal : public LogMessage {
#define _TF_LOG_FATAL \
::tensorflow::internal::LogMessageFatal(__FILE__, __LINE__)
+#define _TF_LOG_QFATAL _TF_LOG_FATAL
+
#define LOG(severity) _TF_LOG_##severity
// TODO(jeff): Define a proper implementation of VLOG_IS_ON
diff --git a/tensorflow/core/platform/default/stacktrace.h b/tensorflow/core/platform/default/stacktrace.h
index 8dc27b5d63..5f3073262a 100644
--- a/tensorflow/core/platform/default/stacktrace.h
+++ b/tensorflow/core/platform/default/stacktrace.h
@@ -22,12 +22,14 @@ namespace tensorflow {
inline string CurrentStackTrace() { return "No stack trace available"; }
+inline void DebugWriteToString(const char* data, void* arg) {}
+
// A dummy class that does nothing. Someday, add real support.
class SavedStackTrace {
public:
SavedStackTrace() {}
- void CreateCurrent() {}
+ void CreateCurrent(int skip_count) {}
void Reset() {}
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 52ced38ac8..104ad42439 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -63,6 +63,23 @@ class DefaultEnvTest : public ::testing::Test {
Env* env_ = Env::Default();
};
+TEST_F(DefaultEnvTest, IncompleteReadOutOfRange) {
+ const string filename = io::JoinPath(BaseDir(), "out_of_range");
+ const string input = CreateTestFile(env_, filename, 2);
+ std::unique_ptr<RandomAccessFile> f;
+ TF_EXPECT_OK(env_->NewRandomAccessFile(filename, &f));
+
+ // Reading past EOF should give an OUT_OF_RANGE error
+ StringPiece result;
+ char scratch[3];
+ EXPECT_EQ(error::OUT_OF_RANGE, f->Read(0, 3, &result, scratch).code());
+ EXPECT_EQ(input, result);
+
+ // Exact read to EOF works.
+ TF_EXPECT_OK(f->Read(0, 2, &result, scratch));
+ EXPECT_EQ(input, result);
+}
+
TEST_F(DefaultEnvTest, ReadFileToString) {
for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000, (1 << 20) - 1,
1 << 20, (1 << 20) + 1}) {
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 2be35eb455..a3b8b400a3 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -78,6 +78,19 @@ message OptimizerOptions {
}
Level opt_level = 3;
+
+ // Control the use of the compiler/jit. Experimental.
+ enum GlobalJitLevel {
+ DEFAULT = 0; // Default setting ("off" now, but later expected to be "on")
+ OFF = -1;
+ // The following settings turn on compilation, with higher values being
+ // more aggressive. Higher values may reduce opportunities for parallelism
+ // and may use more memory. (At present, there is no distinction, but this
+ // is expected to change.)
+ ON_1 = 1;
+ ON_2 = 2;
+ }
+ GlobalJitLevel global_jit_level = 5;
}
message GraphOptions {
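A hedged sketch of how a client would opt in, using the C++ API generated from this proto; the field and enum names come from the hunk above, while the surrounding setup is illustrative, not part of the commit:

```cpp
#include "tensorflow/core/protobuf/config.pb.h"

int main() {
  tensorflow::ConfigProto config;
  // GlobalJitLevel nests inside OptimizerOptions, which nests inside
  // GraphOptions on the session ConfigProto.
  config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
  // DEFAULT (0) currently leaves compilation off; OFF (-1) keeps it off
  // even if the default later flips to on.
  return 0;
}
```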
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 2ffe186e12..4693b4c005 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -511,12 +511,7 @@ TEST(TensorBundleTest, TruncatedTensorContents) {
BundleReader reader(env, Prefix("end"));
TF_ASSERT_OK(reader.status());
Tensor val(DT_FLOAT, TensorShape({2, 3}));
-#if defined(PLATFORM_GOOGLE)
- EXPECT_EQ("Data loss: Requested 24 bytes but read 23 bytes.",
- reader.Lookup("key", &val).ToString());
-#else
EXPECT_TRUE(errors::IsOutOfRange(reader.Lookup("key", &val)));
-#endif
}
TEST(TensorBundleTest, HeaderEntry) {
diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/examples/android/AndroidManifest.xml
index 0a48d3d50b..e388734564 100644
--- a/tensorflow/examples/android/AndroidManifest.xml
+++ b/tensorflow/examples/android/AndroidManifest.xml
@@ -41,6 +41,15 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
+
+ <activity android:name="org.tensorflow.demo.DetectorActivity"
+ android:screenOrientation="portrait"
+ android:label="@string/activity_name_detection">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
</application>
</manifest>
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index beb8337702..3ba3a494ab 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -5,7 +5,11 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
-load("//tensorflow:tensorflow.bzl", "tf_copts")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_copts",
+ "tf_opts_nortti_if_android",
+)
exports_files(["LICENSE"])
@@ -35,6 +39,7 @@ cc_binary(
"notap",
],
deps = [
+ ":demo_proto_lib_cc",
"//tensorflow/contrib/android:android_tensorflow_inference_jni",
"//tensorflow/core:android_tensorflow_lib",
LINKER_SCRIPT,
@@ -60,6 +65,7 @@ android_binary(
assets = [
"//tensorflow/examples/android/assets:asset_files",
"@inception5h//:model_files",
+ "@mobile_multibox//:model_files",
],
assets_dir = "",
custom_package = "org.tensorflow.demo",
@@ -111,3 +117,20 @@ filegroup(
)
exports_files(["AndroidManifest.xml"])
+
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library",
+)
+
+tf_proto_library(
+ name = "demo_proto_lib",
+ srcs = glob(
+ ["**/*.proto"],
+ ),
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets go here (must be at the end).
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index b0465f7faa..b6556cdef4 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -1,11 +1,24 @@
# TensorFlow Android Camera Demo
-This folder contains a simple camera-based demo application utilizing TensorFlow.
+This folder contains an example application utilizing TensorFlow for Android
+devices.
## Description
-This demo uses a Google Inception model to classify camera frames in real-time,
-displaying the top results in an overlay on the camera image.
+The demos in this folder are designed to give straightforward samples of using
+TensorFlow in mobile applications. Inference is done using the Java JNI API
+exposed by `tensorflow/contrib/android`.
+
+Current samples:
+
+1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java):
+ Uses the [Google Inception](https://arxiv.org/abs/1409.4842)
+ model to classify camera frames in real-time, displaying the top results
+ in an overlay on the camera image.
+2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
+ Demonstrates a model based on [Scalable Object Detection
+ using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to
+ localize and track people in the camera preview in real-time.
## To build/install/run
@@ -19,9 +32,9 @@ installed on your system.
3. The Android SDK and build tools may be obtained from:
https://developer.android.com/tools/revisions/build-tools.html
-The Android entries in [`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L2-L13) must be
-uncommented with the paths filled in appropriately depending on where you
-installed the NDK and SDK. Otherwise an error such as:
+The Android entries in [`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L2-L13)
+must be uncommented with the paths filled in appropriately depending on where
+you installed the NDK and SDK. Otherwise an error such as:
"The external label '//external:android/sdk' is not bound to anything" will
be reported.
@@ -29,19 +42,21 @@ The TensorFlow `GraphDef` that contains the model definition and weights
is not packaged in the repo because of its size. It will be downloaded
automatically via a new_http_archive defined in WORKSPACE.
-**Optional**: If you wish to place the model in your assets manually (E.g. for
-non-Bazel builds), remove the
-`inception_5` entry in `BUILD` and download the archive yourself to the
-`assets` directory in the source tree:
+**Optional**: If you wish to place the models in your assets manually (e.g. for
+non-Bazel builds), remove the `inception_5` and `mobile_multibox` entries in
+`BUILD` and download the archives yourself to the `assets` directory in the
+source tree:
```bash
$ curl -L https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -o /tmp/inception5h.zip
+$ curl -L https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip -o /tmp/mobile_multibox_v1.zip
$ unzip /tmp/inception5h.zip -d tensorflow/examples/android/assets/
+$ unzip /tmp/mobile_multibox_v1.zip -d tensorflow/examples/android/assets/
```
-The labels file describing the possible classification will also be in the
-assets directory.
+The associated label and box prior files for the models will also be extracted
+into the assets directory.
After editing your WORKSPACE file to update the SDK/NDK configuration,
you may build the APK. Run this from your workspace root:
diff --git a/tensorflow/examples/android/jni/box_coder_jni.cc b/tensorflow/examples/android/jni/box_coder_jni.cc
new file mode 100644
index 0000000000..be85414fc1
--- /dev/null
+++ b/tensorflow/examples/android/jni/box_coder_jni.cc
@@ -0,0 +1,92 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file loads the box coder mappings.
+
+#include <android/asset_manager.h>
+#include <android/asset_manager_jni.h>
+#include <android/bitmap.h>
+
+#include <jni.h>
+#include <pthread.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include <map>
+#include <queue>
+#include <sstream>
+#include <string>
+
+#include "tensorflow/contrib/android/jni/jni_utils.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/proto/box_coder.pb.h"
+
+#define TENSORFLOW_METHOD(METHOD_NAME) \
+ Java_org_tensorflow_demo_TensorFlowMultiBoxDetector_##METHOD_NAME // NOLINT
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)(
+ JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location,
+ jfloatArray priors);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)(
+ JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location,
+ jfloatArray priors) {
+ AAssetManager* const asset_manager =
+ AAssetManager_fromJava(env, java_asset_manager);
+ LOG(INFO) << "Acquired AssetManager.";
+
+ const std::string location_str = GetString(env, location);
+
+ org_tensorflow_demo::MultiBoxCoderOptions multi_options;
+
+ LOG(INFO) << "Reading file to proto: " << location_str;
+ ReadFileToProtoOrDie(asset_manager, location_str.c_str(), &multi_options);
+
+ LOG(INFO) << "Read file. " << multi_options.box_coder_size() << " entries.";
+
+ jboolean iCopied = JNI_FALSE;
+ jfloat* values = env->GetFloatArrayElements(priors, &iCopied);
+
+ const int array_length = env->GetArrayLength(priors);
+ LOG(INFO) << "Array length: " << array_length
+ << " (/8 = " << (array_length / 8) << ")";
+ CHECK_EQ(array_length % 8, 0);
+
+ const int num_items =
+ std::min(array_length / 8, multi_options.box_coder_size());
+
+ for (int i = 0; i < num_items; ++i) {
+ const org_tensorflow_demo::BoxCoderOptions& options =
+ multi_options.box_coder(i);
+
+ for (int j = 0; j < 4; ++j) {
+ const org_tensorflow_demo::BoxCoderPrior& prior = options.priors(j);
+ values[i * 8 + j * 2] = prior.mean();
+ values[i * 8 + j * 2 + 1] = prior.stddev();
+ }
+ }
+ env->ReleaseFloatArrayElements(priors, values, 0);
+
+ LOG(INFO) << "Read " << num_items << " options";
+}
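For reference, the layout of the `priors` array that loadCoderOptions fills, derived from the indexing above (the numbers simply make the loop bounds explicit):

```cpp
// 8 floats per box coder: mean/stddev interleaved for each of 4 priors.
//   priors[i*8 + 0] = box i, prior 0 mean     priors[i*8 + 1] = prior 0 stddev
//   priors[i*8 + 2] = box i, prior 1 mean     priors[i*8 + 3] = prior 1 stddev
//   priors[i*8 + 4] = box i, prior 2 mean     priors[i*8 + 5] = prior 2 stddev
//   priors[i*8 + 6] = box i, prior 3 mean     priors[i*8 + 7] = prior 3 stddev
// This is why the code checks CHECK_EQ(array_length % 8, 0) and divides the
// Java-side array length by 8 to bound num_items.
```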
diff --git a/tensorflow/examples/android/jni/object_tracking/config.h b/tensorflow/examples/android/jni/object_tracking/config.h
new file mode 100644
index 0000000000..86e9fc71b6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/config.h
@@ -0,0 +1,300 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
+
+#include <math.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+
+namespace tf_tracking {
+
+// Arbitrary keypoint type ids for labeling the origin of tracked keypoints.
+enum KeypointType {
+ KEYPOINT_TYPE_DEFAULT = 0,
+ KEYPOINT_TYPE_FAST = 1,
+ KEYPOINT_TYPE_INTEREST = 2
+};
+
+// Struct that can be used to more richly store the results of a detection
+// than a single number, while still maintaining comparability.
+struct MatchScore {
+ explicit MatchScore(double val) : value(val) {}
+ MatchScore() { value = 0.0; }
+
+ double value;
+
+ MatchScore& operator+(const MatchScore& rhs) {
+ value += rhs.value;
+ return *this;
+ }
+
+ friend std::ostream& operator<<(std::ostream& stream,
+ const MatchScore& detection) {
+ stream << detection.value;
+ return stream;
+ }
+};
+inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value < cC2.value;
+}
+inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value > cC2.value;
+}
+inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value >= cC2.value;
+}
+inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) {
+ return cC1.value <= cC2.value;
+}
+
+// Fixed seed used for all random number generators.
+static const int kRandomNumberSeed = 11111;
+
+// TODO(andrewharp): Move as many of these settings as possible into a settings
+// object which can be passed in from Java at runtime.
+
+// Whether or not to use ESM instead of LK flow.
+static const bool kUseEsm = false;
+
+// This constant gets added to the diagonal of the Hessian
+// before solving for translation in 2dof ESM.
+// It ensures better behavior especially in the absence of
+// strong texture.
+static const int kEsmRegularizer = 20;
+
+// Do we want to brightness-normalize each keypoint patch when we compute
+// its flow using ESM?
+static const bool kDoBrightnessNormalize = true;
+
+// Whether or not to use fixed-point interpolated pixel lookups in optical flow.
+#define USE_FIXED_POINT_FLOW 1
+
+// Whether to normalize keypoint windows for intensity in LK optical flow.
+// This is a define for now because it helps keep the code streamlined.
+#define NORMALIZE 1
+
+// Number of keypoints to store per frame.
+static const int kMaxKeypoints = 76;
+
+// Keypoint detection.
+static const int kMaxTempKeypoints = 1024;
+
+// Number of floats each keypoint takes up when exporting to an array.
+static const int kKeypointStep = 7;
+
+// Number of frame deltas to keep around in the circular queue.
+static const int kNumFrames = 512;
+
+// Number of iterations to do tracking on each keypoint at each pyramid level.
+static const int kNumIterations = 3;
+
+// The number of bins (on a side) to divide each bin from the previous
+// cache level into. Higher numbers will decrease performance by increasing
+// cache misses, but mean that cache hits are more locally relevant.
+static const int kCacheBranchFactor = 2;
+
+// Number of levels to put in the cache.
+// Each level of the cache is a square grid of bins, length:
+// branch_factor^(level - 1) on each side.
+//
+// This may be greater than kNumPyramidLevels. Setting it to 0 means no
+// caching is enabled.
+static const int kNumCacheLevels = 3;
+
+// The level at which the cache pyramid gets cut off and replaced by a matrix
+// transform if such a matrix has been provided to the cache.
+static const int kCacheCutoff = 1;
+
+static const int kNumPyramidLevels = 4;
+
+// The minimum number of keypoints needed in an object's area.
+static const int kMaxKeypointsForObject = 16;
+
+// Minimum number of pyramid levels to use after getting cached value.
+// This allows fine-scale adjustment from the cached value, which is taken
+// from the center of the corresponding top cache level box.
+// Can be [0, kNumPyramidLevels).
+static const int kMinNumPyramidLevelsToUseForAdjustment = 1;
+
+// Window size to integrate over to find local image derivative.
+static const int kFlowIntegrationWindowSize = 3;
+
+// Total area of integration windows.
+static const int kFlowArraySize =
+ (2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1);
+
+// Error that's considered good enough to early abort tracking.
+static const float kTrackingAbortThreshold = 0.03f;
+
+// Maximum number of deviations a keypoint-correspondence delta can be from the
+// weighted average before being thrown out for region-based queries.
+static const float kNumDeviations = 2.0f;
+
+// The length of the allowed delta between the forward and the backward
+// flow deltas in terms of the length of the forward flow vector.
+static const float kMaxForwardBackwardErrorAllowed = 0.5f;
+
+// Threshold for pixels to be considered different.
+static const int kFastDiffAmount = 10;
+
+// How far from edge of frame to stop looking for FAST keypoints.
+static const int kFastBorderBuffer = 10;
+
+// Determines if non-detected arbitrary keypoints should be added to regions.
+// This will help if no keypoints have been detected in the region yet.
+static const bool kAddArbitraryKeypoints = true;
+
+// How many arbitrary keypoints to add along each axis as candidates for each
+// region?
+static const int kNumToAddAsCandidates = 1;
+
+// In terms of region dimensions, how closely can we place keypoints
+// next to each other?
+static const float kClosestPercent = 0.6f;
+
+// How many FAST qualifying pixels must be connected to a pixel for it to be
+// considered a candidate keypoint for Harris filtering.
+static const int kMinNumConnectedForFastKeypoint = 8;
+
+// Size of the window to integrate over for Harris filtering.
+// Compare to kFlowIntegrationWindowSize.
+static const int kHarrisWindowSize = 2;
+
+
+// DETECTOR PARAMETERS
+
+// Before relocalizing, make sure the new proposed position is better than
+// the existing position by a small amount to prevent thrashing.
+static const MatchScore kMatchScoreBuffer(0.01f);
+
+// Minimum score a tracked object can have and still be considered a match.
+// TODO(andrewharp): Make this a per detector thing.
+static const MatchScore kMinimumMatchScore(0.5f);
+
+static const float kMinimumCorrelationForTracking = 0.4f;
+
+static const MatchScore kMatchScoreForImmediateTermination(0.0f);
+
+// Run the detector every N frames.
+static const int kDetectEveryNFrames = 4;
+
+// How many features does each feature_set contain?
+static const int kFeaturesPerFeatureSet = 10;
+
+// The number of FeatureSets managed by the object detector.
+// More FeatureSets can increase recall at the cost of performance.
+static const int kNumFeatureSets = 7;
+
+// How many FeatureSets must respond affirmatively for a candidate descriptor
+// and position to be given more thorough attention?
+static const int kNumFeatureSetsForCandidate = 2;
+
+// How large the thumbnails used for correlation validation are. Used for both
+// width and height.
+static const int kNormalizedThumbnailSize = 11;
+
+// The area of intersection divided by union for the bounding boxes that tells
+// if this tracking has slipped enough to invalidate all unlocked examples.
+static const float kPositionOverlapThreshold = 0.6f;
+
+// The number of detection failures allowed before an object goes invisible.
+// Tracking will still occur, so if it is actually still being tracked and
+// comes back into a detectable position, it's likely to be found.
+static const int kMaxNumDetectionFailures = 4;
+
+
+// Minimum square size to scan with sliding window.
+static const float kScanMinSquareSize = 16.0f;
+
+// Maximum square size to scan with sliding window.
+static const float kScanMaxSquareSize = 64.0f;
+
+// Scale difference for consecutive scans of the sliding window.
+static const float kScanScaleFactor = sqrtf(2.0f);
+
+// Step size for sliding window.
+static const int kScanStepSize = 10;
+
+
+// How tightly to pack the descriptor boxes for confirmed exemplars.
+static const float kLockedScaleFactor = 1 / sqrtf(2.0f);
+
+// How tightly to pack the descriptor boxes for unconfirmed exemplars.
+static const float kUnlockedScaleFactor = 1 / 2.0f;
+
+// How tightly the boxes to scan centered at the last known position will be
+// packed.
+static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f);
+
+// The bounds on how close a new object example must be to existing object
+// examples for detection to be valid.
+static const float kMinCorrelationForNewExample = 0.75f;
+static const float kMaxCorrelationForNewExample = 0.99f;
+
+
+// The number of safe tries an exemplar has after being created before
+// missed detections count against it.
+static const int kFreeTries = 5;
+
+// A false positive is worth this many missed detections.
+static const int kFalsePositivePenalty = 5;
+
+struct ObjectDetectorConfig {
+ const Size image_size;
+
+ explicit ObjectDetectorConfig(const Size& image_size)
+ : image_size(image_size) {}
+ virtual ~ObjectDetectorConfig() = default;
+};
+
+struct KeypointDetectorConfig {
+ const Size image_size;
+
+ bool detect_skin;
+
+ explicit KeypointDetectorConfig(const Size& image_size)
+ : image_size(image_size),
+ detect_skin(false) {}
+};
+
+
+struct OpticalFlowConfig {
+ const Size image_size;
+
+ explicit OpticalFlowConfig(const Size& image_size)
+ : image_size(image_size) {}
+};
+
+struct TrackerConfig {
+ const Size image_size;
+ KeypointDetectorConfig keypoint_detector_config;
+ OpticalFlowConfig flow_config;
+ bool always_track;
+
+ float object_box_scale_factor_for_features;
+
+ explicit TrackerConfig(const Size& image_size)
+ : image_size(image_size),
+ keypoint_detector_config(image_size),
+ flow_config(image_size),
+ always_track(false),
+ object_box_scale_factor_for_features(1.0f) {}
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
new file mode 100644
index 0000000000..8813ab6d71
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
@@ -0,0 +1,306 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+namespace tf_tracking {
+
+// Class that helps OpticalFlow to speed up flow computation
+// by caching coarse-grained flow.
+class FlowCache {
+ public:
+ explicit FlowCache(const OpticalFlowConfig* const config)
+ : config_(config),
+ image_size_(config->image_size),
+ optical_flow_(config),
+ fullframe_matrix_(NULL) {
+ for (int i = 0; i < kNumCacheLevels; ++i) {
+ const int curr_dims = BlockDimForCacheLevel(i);
+ has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
+ displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
+ }
+ }
+
+ ~FlowCache() {
+ for (int i = 0; i < kNumCacheLevels; ++i) {
+ SAFE_DELETE(has_cache_[i]);
+ SAFE_DELETE(displacements_[i]);
+ }
+ delete[](fullframe_matrix_);
+ fullframe_matrix_ = NULL;
+ }
+
+ void NextFrame(ImageData* const new_frame,
+ const float* const align_matrix23) {
+ ClearCache();
+ SetFullframeAlignmentMatrix(align_matrix23);
+ optical_flow_.NextFrame(new_frame);
+ }
+
+ void ClearCache() {
+ for (int i = 0; i < kNumCacheLevels; ++i) {
+ has_cache_[i]->Clear(false);
+ }
+ delete[](fullframe_matrix_);
+ fullframe_matrix_ = NULL;
+ }
+
+ // Finds the flow at a point, using the cache for performance.
+ bool FindFlowAtPoint(const float u_x, const float u_y,
+ float* const flow_x, float* const flow_y) const {
+ // Get the best guess from the cache.
+ const Point2f guess_from_cache = LookupGuess(u_x, u_y);
+
+ *flow_x = guess_from_cache.x;
+ *flow_y = guess_from_cache.y;
+
+ // Now refine the guess using the image pyramid.
+ for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
+ pyramid_level >= 0; --pyramid_level) {
+ if (!optical_flow_.FindFlowAtPointSingleLevel(
+ pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // Determines the displacement of a point, and uses that to calculate a new
+ // position.
+ // Returns true iff the displacement determination worked and the new position
+ // is in the image.
+ bool FindNewPositionOfPoint(const float u_x, const float u_y,
+ float* final_x, float* final_y) const {
+ float flow_x;
+ float flow_y;
+ if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
+ return false;
+ }
+
+ // Add in the displacement to get the final position.
+ *final_x = u_x + flow_x;
+ *final_y = u_y + flow_y;
+
+ // Assign the best guess, if we're still in the image.
+ if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
+ InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ // Comparison function for qsort.
+ static int Compare(const void* a, const void* b) {
+ return *reinterpret_cast<const float*>(a) -
+ *reinterpret_cast<const float*>(b);
+ }
+
+ // Returns the median flow within the given bounding box as determined
+ // by a grid_width x grid_height grid.
+ Point2f GetMedianFlow(const BoundingBox& bounding_box,
+ const bool filter_by_fb_error,
+ const int grid_width,
+ const int grid_height) const {
+ const int kMaxPoints = 100;
+ SCHECK(grid_width * grid_height <= kMaxPoints,
+ "Too many points for Median flow!");
+
+ const BoundingBox valid_box = bounding_box.Intersect(
+ BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
+
+ if (valid_box.GetArea() <= 0.0f) {
+ return Point2f(0, 0);
+ }
+
+ float x_deltas[kMaxPoints];
+ float y_deltas[kMaxPoints];
+
+ int curr_offset = 0;
+ for (int i = 0; i < grid_width; ++i) {
+ for (int j = 0; j < grid_height; ++j) {
+ const float x_in = valid_box.left_ +
+ (valid_box.GetWidth() * i) / (grid_width - 1);
+
+ const float y_in = valid_box.top_ +
+ (valid_box.GetHeight() * j) / (grid_height - 1);
+
+ float curr_flow_x;
+ float curr_flow_y;
+ const bool success = FindNewPositionOfPoint(x_in, y_in,
+ &curr_flow_x, &curr_flow_y);
+
+ if (success) {
+ x_deltas[curr_offset] = curr_flow_x;
+ y_deltas[curr_offset] = curr_flow_y;
+ ++curr_offset;
+ } else {
+ LOGW("Tracking failure!");
+ }
+ }
+ }
+
+ if (curr_offset > 0) {
+ qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
+ qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
+
+ return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
+ }
+
+ LOGW("No points were valid!");
+ return Point2f(0, 0);
+ }
+
+ void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
+ if (align_matrix23 != NULL) {
+ if (fullframe_matrix_ == NULL) {
+ fullframe_matrix_ = new float[6];
+ }
+
+ memcpy(fullframe_matrix_, align_matrix23,
+ 6 * sizeof(fullframe_matrix_[0]));
+ }
+ }
+
+ private:
+ Point2f LookupGuessFromLevel(
+ const int cache_level, const float x, const float y) const {
+ // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
+
+ // Cutoff at the target level and use the matrix transform instead.
+ if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
+ const float xnew = x * fullframe_matrix_[0] +
+ y * fullframe_matrix_[1] +
+ fullframe_matrix_[2];
+ const float ynew = x * fullframe_matrix_[3] +
+ y * fullframe_matrix_[4] +
+ fullframe_matrix_[5];
+
+ return Point2f(xnew - x, ynew - y);
+ }
+
+ const int level_dim = BlockDimForCacheLevel(cache_level);
+ const int pixels_per_cache_block_x =
+ (image_size_.width + level_dim - 1) / level_dim;
+ const int pixels_per_cache_block_y =
+ (image_size_.height + level_dim - 1) / level_dim;
+ const int index_x = x / pixels_per_cache_block_x;
+ const int index_y = y / pixels_per_cache_block_y;
+
+ Point2f displacement;
+ if (!(*has_cache_[cache_level])[index_y][index_x]) {
+ (*has_cache_[cache_level])[index_y][index_x] = true;
+
+ // Get the lower cache level's best guess, if it exists.
+ displacement = cache_level >= kNumCacheLevels - 1 ?
+ Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
+ // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
+ // best_guess.x, best_guess.y);
+
+ // Find the center of the block.
+ const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
+ const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
+ const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
+
+ // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
+ // "Querying %5.2f, %5.2f at pyramid level %d, ",
+ // cache_level, index_x, index_y,
+ // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
+ // center_x, center_y, pyramid_level);
+
+ // TODO(andrewharp): Turn on FB error filtering.
+ const bool success = optical_flow_.FindFlowAtPointSingleLevel(
+ pyramid_level, center_x, center_y, false,
+ &displacement.x, &displacement.y);
+
+ if (!success) {
+ LOGV("Computation of cached value failed for level %d!", cache_level);
+ }
+
+ // Store the value for later use.
+ (*displacements_[cache_level])[index_y][index_x] = displacement;
+ } else {
+ displacement = (*displacements_[cache_level])[index_y][index_x];
+ }
+
+ // LOGI("Returning %5.2f, %5.2f for level %d",
+ // displacement.x, displacement.y, cache_level);
+ return displacement;
+ }
+
+ Point2f LookupGuess(const float x, const float y) const {
+ if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
+ return Point2f(0, 0);
+ }
+
+ // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
+ if (kNumCacheLevels > 0) {
+ return LookupGuessFromLevel(0, x, y);
+ } else {
+ return Point2f(0, 0);
+ }
+ }
+
+ // Returns the number of cache bins in each dimension for a given level
+ // of the cache.
+ int BlockDimForCacheLevel(const int cache_level) const {
+ // The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
+ // thus if there are 4 cache levels, requesting level 3 (0-based) should
+ // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
+ // and so on.
+ int block_dim = kNumCacheLevels;
+ for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
+ --curr_level) {
+ block_dim *= kCacheBranchFactor;
+ }
+ return block_dim;
+ }
+
+ // Returns the level of the image pyramid that a given cache level maps to.
+ int PyramidLevelForCacheLevel(const int cache_level) const {
+ // Higher cache and pyramid levels have smaller dimensions. The highest
+ // cache level should refer to the highest image pyramid level. The
+ // lower, finer image pyramid levels are uncached (assuming
+ // kNumCacheLevels < kNumPyramidLevels).
+ return cache_level + (kNumPyramidLevels - kNumCacheLevels);
+ }
+
+ const OpticalFlowConfig* const config_;
+
+ const Size image_size_;
+ OpticalFlow optical_flow_;
+
+ float* fullframe_matrix_;
+
+ // Whether this value is currently present in the cache.
+ Image<bool>* has_cache_[kNumCacheLevels];
+
+ // The cached displacement values.
+ Image<Point2f>* displacements_[kNumCacheLevels];
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
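To make the cache geometry concrete, here is a worked evaluation of the two mapping functions exactly as written, using the constants from config.h (kNumCacheLevels = 3, kCacheBranchFactor = 2, kNumPyramidLevels = 4). Note that `block_dim` is seeded with kNumCacheLevels, so the coarsest grid here is 3x3 rather than the kCacheBranchFactor-sided grid the comment suggests:

```cpp
// BlockDimForCacheLevel(2) == 3                 // coarsest cached grid: 3x3 bins
// BlockDimForCacheLevel(1) == 3 * 2 == 6
// BlockDimForCacheLevel(0) == 3 * 2 * 2 == 12   // finest cached grid
// PyramidLevelForCacheLevel(0) == 0 + (4 - 3) == 1  // pyramid level 0 is uncached
// PyramidLevelForCacheLevel(2) == 2 + (4 - 3) == 3  // coarsest maps to pyramid top
```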
diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.cc b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc
new file mode 100644
index 0000000000..fa86e2363c
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc
@@ -0,0 +1,308 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <float.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
+
+namespace tf_tracking {
+
+void FramePair::Init(const int64 start_time, const int64 end_time) {
+ start_time_ = start_time;
+ end_time_ = end_time;
+ memset(optical_flow_found_keypoint_, false,
+ sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
+ number_of_keypoints_ = 0;
+}
+
+void FramePair::AdjustBox(const BoundingBox box,
+ float* const translation_x,
+ float* const translation_y,
+ float* const scale_x,
+ float* const scale_y) const {
+ static float weights[kMaxKeypoints];
+ static Point2f deltas[kMaxKeypoints];
+ memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
+
+ BoundingBox resized_box(box);
+ resized_box.Scale(0.4f, 0.4f);
+ FillWeights(resized_box, weights);
+ FillTranslations(deltas);
+
+ const Point2f translation = GetWeightedMedian(weights, deltas);
+
+ *translation_x = translation.x;
+ *translation_y = translation.y;
+
+ const Point2f old_center = box.GetCenter();
+ const int good_scale_points =
+ FillScales(old_center, translation, weights, deltas);
+
+ // Default scale factor is 1 for x and y.
+ *scale_x = 1.0f;
+ *scale_y = 1.0f;
+
+ // The assumption is that all deltas that make it to this stage with a
+ // corresponding optical_flow_found_keypoint_[i] == true are not in
+ // themselves degenerate.
+ //
+ // The degeneracy with scale arose because if the points are too close to the
+ // center of the objects, the scale ratio determination might be incalculable.
+ //
+ // The check for kMinNumInRange is not a degeneracy check, but merely an
+ // attempt to ensure some sort of stability. The actual degeneracy check is in
+ // the comparison to EPSILON in FillScales (which I've updated to return the
+ // number of good points remaining as well).
+ static const int kMinNumInRange = 5;
+ if (good_scale_points >= kMinNumInRange) {
+ const float scale_factor = GetWeightedMedianScale(weights, deltas);
+
+ if (scale_factor > 0.0f) {
+ *scale_x = scale_factor;
+ *scale_y = scale_factor;
+ }
+ }
+}
+
+int FramePair::FillWeights(const BoundingBox& box,
+ float* const weights) const {
+ // Compute the max score.
+ float max_score = -FLT_MAX;
+ float min_score = FLT_MAX;
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (optical_flow_found_keypoint_[i]) {
+ max_score = MAX(max_score, frame1_keypoints_[i].score_);
+ min_score = MIN(min_score, frame1_keypoints_[i].score_);
+ }
+ }
+
+ int num_in_range = 0;
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (!optical_flow_found_keypoint_[i]) {
+ weights[i] = 0.0f;
+ continue;
+ }
+
+ const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
+ if (in_box) {
+ ++num_in_range;
+ }
+
+ // The weighting based off distance. Anything within the bounding box
+ // has a weight of 1, and everything outside of that is within the range
+ // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
+ float distance_score = 1.0f;
+ if (!in_box) {
+ const Point2f initial = box.GetCenter();
+ const float sq_x_dist =
+ Square(initial.x - frame1_keypoints_[i].pos_.x);
+ const float sq_y_dist =
+ Square(initial.y - frame1_keypoints_[i].pos_.y);
+ const float squared_half_width = Square(box.GetWidth() / 2.0f);
+ const float squared_half_height = Square(box.GetHeight() / 2.0f);
+
+ static const float kOutOfBoxMultiplier = 0.5f;
+ distance_score = kOutOfBoxMultiplier *
+ MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
+ }
+
+ // The weighting based on relative score strength. kBaseScore - 1.0f.
+ float intrinsic_score = 1.0f;
+ if (max_score > min_score) {
+ static const float kBaseScore = 0.5f;
+ intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
+ (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
+ }
+
+ // The final score will be in the range [0, 1].
+ weights[i] = distance_score * intrinsic_score;
+ }
+
+ return num_in_range;
+}
+
+void FramePair::FillTranslations(Point2f* const translations) const {
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (!optical_flow_found_keypoint_[i]) {
+ continue;
+ }
+ translations[i].x =
+ frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
+ translations[i].y =
+ frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
+ }
+}
+
+int FramePair::FillScales(const Point2f& old_center,
+ const Point2f& translation,
+ float* const weights,
+ Point2f* const scales) const {
+ int num_good = 0;
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ if (!optical_flow_found_keypoint_[i]) {
+ continue;
+ }
+
+ const Keypoint keypoint1 = frame1_keypoints_[i];
+ const Keypoint keypoint2 = frame2_keypoints_[i];
+
+ const float dist1_x = keypoint1.pos_.x - old_center.x;
+ const float dist1_y = keypoint1.pos_.y - old_center.y;
+
+ const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
+ const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
+
+ // Make sure that the scale makes sense; points too close to the center
+ // will result in either NaNs or infinite results for scale due to
+ // limited tracking and floating point resolution.
+ // Also check that the parity of the points is the same with respect to
+ // x and y, as we can't really make sense of data that has flipped.
+ if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
+ (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
+ ((dist2_y > EPSILON && dist1_y > EPSILON) ||
+ (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
+ scales[i].x = dist2_x / dist1_x;
+ scales[i].y = dist2_y / dist1_y;
+ ++num_good;
+ } else {
+ weights[i] = 0.0f;
+ scales[i].x = 1.0f;
+ scales[i].y = 1.0f;
+ }
+ }
+ return num_good;
+}
+
+struct WeightedDelta {
+ float weight;
+ float delta;
+};
+
+// Sort by delta, not by weight.
+inline int WeightedDeltaCompare(const void* const a, const void* const b) {
+ return (reinterpret_cast<const WeightedDelta*>(a)->delta -
+ reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1;
+}
+
+// Returns the median delta from a sorted set of weighted deltas.
+static float GetMedian(const int num_items,
+ const WeightedDelta* const weighted_deltas,
+ const float sum) {
+ if (num_items == 0 || sum < EPSILON) {
+ return 0.0f;
+ }
+
+ float current_weight = 0.0f;
+ const float target_weight = sum / 2.0f;
+ for (int i = 0; i < num_items; ++i) {
+ if (weighted_deltas[i].weight > 0.0f) {
+ current_weight += weighted_deltas[i].weight;
+ if (current_weight >= target_weight) {
+ return weighted_deltas[i].delta;
+ }
+ }
+ }
+ LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
+ return 0.0f;
+}
+
+Point2f FramePair::GetWeightedMedian(
+ const float* const weights, const Point2f* const deltas) const {
+ Point2f median_delta;
+
+ // TODO(andrewharp): only sort deltas that could possibly have an effect.
+ static WeightedDelta weighted_deltas[kMaxKeypoints];
+
+ // Compute median X value.
+ {
+ float total_weight = 0.0f;
+
+ // Compute weighted mean and deltas.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ weighted_deltas[i].delta = deltas[i].x;
+ const float weight = weights[i];
+ weighted_deltas[i].weight = weight;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+ qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
+ WeightedDeltaCompare);
+ median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
+ }
+
+ // Compute median Y value.
+ {
+ float total_weight = 0.0f;
+
+ // Compute weighted mean and deltas.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ const float weight = weights[i];
+ weighted_deltas[i].weight = weight;
+ weighted_deltas[i].delta = deltas[i].y;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+ qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
+ WeightedDeltaCompare);
+ median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
+ }
+
+ return median_delta;
+}
+
+float FramePair::GetWeightedMedianScale(
+ const float* const weights, const Point2f* const deltas) const {
+ float median_delta;
+
+ // TODO(andrewharp): only sort deltas that could possibly have an effect.
+ static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
+
+ // Compute median scale value across x and y.
+ {
+ float total_weight = 0.0f;
+
+ // Add X values.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ weighted_deltas[i].delta = deltas[i].x;
+ const float weight = weights[i];
+ weighted_deltas[i].weight = weight;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+
+ // Add Y values.
+ for (int i = 0; i < kMaxKeypoints; ++i) {
+ weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
+ const float weight = weights[i];
+ weighted_deltas[i + kMaxKeypoints].weight = weight;
+ if (weight > 0.0f) {
+ total_weight += weight;
+ }
+ }
+
+ qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
+ WeightedDeltaCompare);
+
+ median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
+ }
+
+ return median_delta;
+}
+
+} // namespace tf_tracking
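A toy trace of the weighted-median logic above (the numbers are illustrative; the names match frame_pair.cc). WeightedDeltaCompare sorts descending by delta, then GetMedian walks the sorted array until the cumulative weight crosses half the total:

```cpp
// deltas after qsort (descending): 5.0, 2.0, 1.0
// weights:                         0.2, 0.5, 0.3
// target_weight = (0.2 + 0.5 + 0.3) / 2 = 0.5
// cumulative:      0.2, 0.7 -> crosses 0.5 at the second entry
// GetMedian(...) == 2.0f  // the weighted median, not the weighted mean
```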
diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.h b/tensorflow/examples/android/jni/object_tracking/frame_pair.h
new file mode 100644
index 0000000000..3f2559a5e0
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.h
@@ -0,0 +1,103 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+
+namespace tf_tracking {
+
+// A class that records keypoint correspondences from pairs of
+// consecutive frames.
+class FramePair {
+ public:
+ FramePair()
+ : start_time_(0),
+ end_time_(0),
+ number_of_keypoints_(0) {}
+
+ // Cleans up the FramePair so that it can be reused.
+ void Init(const int64 start_time, const int64 end_time);
+
+ void AdjustBox(const BoundingBox box,
+ float* const translation_x,
+ float* const translation_y,
+ float* const scale_x,
+ float* const scale_y) const;
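+
+ // A usage sketch (illustrative only; the tracking step that fills the
+ // keypoint correspondence arrays is assumed and not shown here):
+ //   FramePair pair;
+ //   pair.Init(last_timestamp, curr_timestamp);
+ //   // ... optical flow fills the keypoint arrays ...
+ //   float dx, dy, sx, sy;
+ //   pair.AdjustBox(box, &dx, &dy, &sx, &sy);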
+
+ private:
+ // Returns the weighted median of the given deltas, computed independently on
+ // x and y. Returns 0,0 in case of failure: the assumption is that a
+ // translation of 0.0 is the best that can be done in the degenerate case,
+ // and it should not be considered an error.
+ //
+ // In the case of scale, an explicit check for 0.0 is made just to be safe,
+ // though that value should never occur naturally given the non-zero and
+ // parity checks in FillScales.
+ Point2f GetWeightedMedian(const float* const weights,
+ const Point2f* const deltas) const;
+
+ float GetWeightedMedianScale(const float* const weights,
+ const Point2f* const deltas) const;
+
+ // Fills in a weight for each point based on its distance from the given box.
+ int FillWeights(const BoundingBox& box,
+ float* const weights) const;
+
+ // Fills in the array of deltas with the translations of the points
+ // between frames.
+ void FillTranslations(Point2f* const translations) const;
+
+ // Fills in the array of deltas with the scale factor of each point relative
+ // to the given center, and may override a point's weight to 0 if a
+ // degenerate scale is detected. The translation is the amount the center of
+ // the box has moved from one frame to the next.
+ int FillScales(const Point2f& old_center,
+ const Point2f& translation,
+ float* const weights,
+ Point2f* const scales) const;
+
+ // TODO(andrewharp): Make these private.
+ public:
+ // The time at frame1.
+ int64 start_time_;
+
+ // The time at frame2.
+ int64 end_time_;
+
+ // This array will contain the keypoints found in frame 1.
+ Keypoint frame1_keypoints_[kMaxKeypoints];
+
+ // Contains the locations in frame 2 of the keypoints from frame 1.
+ Keypoint frame2_keypoints_[kMaxKeypoints];
+
+ // The number of keypoints in frame 1.
+ int number_of_keypoints_;
+
+ // Keeps track of which keypoint correspondences were actually found from one
+ // frame to another.
+ // The i-th element of this array will be non-zero if and only if the i-th
+ // keypoint of frame 1 was found in frame 2.
+ bool optical_flow_found_keypoint_[kMaxKeypoints];
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(FramePair);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/geom.h b/tensorflow/examples/android/jni/object_tracking/geom.h
new file mode 100644
index 0000000000..5d5249cd97
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/geom.h
@@ -0,0 +1,319 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+struct Size {
+ Size(const int width, const int height) : width(width), height(height) {}
+
+ int width;
+ int height;
+};
+
+
+class Point2f {
+ public:
+ Point2f() : x(0.0f), y(0.0f) {}
+ Point2f(const float x, const float y) : x(x), y(y) {}
+
+ inline Point2f operator- (const Point2f& that) const {
+ return Point2f(this->x - that.x, this->y - that.y);
+ }
+
+ inline Point2f operator+ (const Point2f& that) const {
+ return Point2f(this->x + that.x, this->y + that.y);
+ }
+
+ inline Point2f& operator+= (const Point2f& that) {
+ this->x += that.x;
+ this->y += that.y;
+ return *this;
+ }
+
+ inline Point2f& operator-= (const Point2f& that) {
+ this->x -= that.x;
+ this->y -= that.y;
+ return *this;
+ }
+
+ inline float LengthSquared() const {
+ return Square(this->x) + Square(this->y);
+ }
+
+ inline float Length() const {
+ return sqrtf(LengthSquared());
+ }
+
+ inline float DistanceSquared(const Point2f& that) const {
+ return Square(this->x - that.x) + Square(this->y - that.y);
+ }
+
+ inline float Distance(const Point2f& that) const {
+ return sqrtf(DistanceSquared(that));
+ }
+
+ float x;
+ float y;
+};
+
+inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) {
+ stream << point.x << "," << point.y;
+ return stream;
+}
+
+class BoundingBox {
+ public:
+ BoundingBox()
+ : left_(0),
+ top_(0),
+ right_(0),
+ bottom_(0) {}
+
+ BoundingBox(const BoundingBox& bounding_box)
+ : left_(bounding_box.left_),
+ top_(bounding_box.top_),
+ right_(bounding_box.right_),
+ bottom_(bounding_box.bottom_) {
+ SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
+ SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
+ }
+
+ BoundingBox(const float left,
+ const float top,
+ const float right,
+ const float bottom)
+ : left_(left),
+ top_(top),
+ right_(right),
+ bottom_(bottom) {
+ SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_);
+ SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_);
+ }
+
+ BoundingBox(const Point2f& point1, const Point2f& point2)
+ : left_(MIN(point1.x, point2.x)),
+ top_(MIN(point1.y, point2.y)),
+ right_(MAX(point1.x, point2.x)),
+ bottom_(MAX(point1.y, point2.y)) {}
+
+ inline void CopyToArray(float* const bounds_array) const {
+ bounds_array[0] = left_;
+ bounds_array[1] = top_;
+ bounds_array[2] = right_;
+ bounds_array[3] = bottom_;
+ }
+
+ inline float GetWidth() const {
+ return right_ - left_;
+ }
+
+ inline float GetHeight() const {
+ return bottom_ - top_;
+ }
+
+ inline float GetArea() const {
+ const float width = GetWidth();
+ const float height = GetHeight();
+ if (width <= 0 || height <= 0) {
+ return 0.0f;
+ }
+
+ return width * height;
+ }
+
+ inline Point2f GetCenter() const {
+ return Point2f((left_ + right_) / 2.0f,
+ (top_ + bottom_) / 2.0f);
+ }
+
+ inline bool ValidBox() const {
+ return GetArea() > 0.0f;
+ }
+
+ // Returns a bounding box created from the overlapping area of these two.
+ inline BoundingBox Intersect(const BoundingBox& that) const {
+ const float new_left = MAX(this->left_, that.left_);
+ const float new_right = MIN(this->right_, that.right_);
+
+ if (new_left >= new_right) {
+ return BoundingBox();
+ }
+
+ const float new_top = MAX(this->top_, that.top_);
+ const float new_bottom = MIN(this->bottom_, that.bottom_);
+
+ if (new_top >= new_bottom) {
+ return BoundingBox();
+ }
+
+ return BoundingBox(new_left, new_top, new_right, new_bottom);
+ }
+
+ // Returns a bounding box that can contain both boxes.
+ inline BoundingBox Union(const BoundingBox& that) const {
+ return BoundingBox(MIN(this->left_, that.left_),
+ MIN(this->top_, that.top_),
+ MAX(this->right_, that.right_),
+ MAX(this->bottom_, that.bottom_));
+ }
+
+ inline float PascalScore(const BoundingBox& that) const {
+ SCHECK(GetArea() > 0.0f, "Empty bounding box!");
+ SCHECK(that.GetArea() > 0.0f, "Empty bounding box!");
+
+ const float intersect_area = this->Intersect(that).GetArea();
+
+ if (intersect_area <= 0) {
+ return 0;
+ }
+
+ const float score =
+ intersect_area / (GetArea() + that.GetArea() - intersect_area);
+ SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score);
+ return score;
+ }
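+
+ // For example (illustrative): two unit-area boxes that overlap over half
+ // their area have intersection 0.5 and union 1.0 + 1.0 - 0.5 = 1.5, so the
+ // Pascal score is 0.5 / 1.5, or one third.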
+
+ // Returns true iff the two boxes overlap at all (shared edges count).
+ inline bool Intersects(const BoundingBox& that) const {
+ return !(that.right_ < left_ || that.left_ > right_ ||
+ that.bottom_ < top_ || that.top_ > bottom_);
+ }
+
+ // Returns whether another bounding box is completely inside of this bounding
+ // box. Sharing edges is ok.
+ inline bool Contains(const BoundingBox& that) const {
+ return that.left_ >= left_ &&
+ that.right_ <= right_ &&
+ that.top_ >= top_ &&
+ that.bottom_ <= bottom_;
+ }
+
+ inline bool Contains(const Point2f& point) const {
+ return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_);
+ }
+
+ inline void Shift(const Point2f shift_amount) {
+ left_ += shift_amount.x;
+ top_ += shift_amount.y;
+ right_ += shift_amount.x;
+ bottom_ += shift_amount.y;
+ }
+
+ inline void ScaleOrigin(const float scale_x, const float scale_y) {
+ left_ *= scale_x;
+ right_ *= scale_x;
+ top_ *= scale_y;
+ bottom_ *= scale_y;
+ }
+
+ inline void Scale(const float scale_x, const float scale_y) {
+ const Point2f center = GetCenter();
+ const float half_width = GetWidth() / 2.0f;
+ const float half_height = GetHeight() / 2.0f;
+
+ left_ = center.x - half_width * scale_x;
+ right_ = center.x + half_width * scale_x;
+
+ top_ = center.y - half_height * scale_y;
+ bottom_ = center.y + half_height * scale_y;
+ }
+
+ float left_;
+ float top_;
+ float right_;
+ float bottom_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) {
+ stream << "[" << box.left_ << " - " << box.right_
+ << ", " << box.top_ << " - " << box.bottom_
+ << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]";
+ return stream;
+}
+
+
+class BoundingSquare {
+ public:
+ BoundingSquare(const float x, const float y, const float size)
+ : x_(x), y_(y), size_(size) {}
+
+ explicit BoundingSquare(const BoundingBox& box)
+ : x_(box.left_), y_(box.top_), size_(box.GetWidth()) {
+#ifdef SANITY_CHECKS
+ if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) {
+ LOG(WARNING) << "This is not a square: " << box << std::endl;
+ }
+#endif
+ }
+
+ inline BoundingBox ToBoundingBox() const {
+ return BoundingBox(x_, y_, x_ + size_, y_ + size_);
+ }
+
+ inline bool ValidBox() {
+ return size_ > 0.0f;
+ }
+
+ inline void Shift(const Point2f shift_amount) {
+ x_ += shift_amount.x;
+ y_ += shift_amount.y;
+ }
+
+ inline void Scale(const float scale) {
+ const float new_size = size_ * scale;
+ const float position_diff = (new_size - size_) / 2.0f;
+ x_ -= position_diff;
+ y_ -= position_diff;
+ size_ = new_size;
+ }
+
+ float x_;
+ float y_;
+ float size_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const BoundingSquare& square) {
+ stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]";
+ return stream;
+}
+
+
+inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box,
+ const float size) {
+ const float width_diff = (original_box.GetWidth() - size) / 2.0f;
+ const float height_diff = (original_box.GetHeight() - size) / 2.0f;
+ return BoundingSquare(original_box.left_ + width_diff,
+ original_box.top_ + height_diff,
+ size);
+}
+
+inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
+ return GetCenteredSquare(
+ original_box, MIN(original_box.GetWidth(), original_box.GetHeight()));
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/gl_utils.h b/tensorflow/examples/android/jni/object_tracking/gl_utils.h
new file mode 100755
index 0000000000..bd5c233f4f
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/gl_utils.h
@@ -0,0 +1,55 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
+
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+
+namespace tf_tracking {
+
+// Draws a box at the given position.
+inline static void DrawBox(const BoundingBox& bounding_box) {
+ const GLfloat line[] = {
+ bounding_box.left_, bounding_box.bottom_,
+ bounding_box.left_, bounding_box.top_,
+ bounding_box.left_, bounding_box.top_,
+ bounding_box.right_, bounding_box.top_,
+ bounding_box.right_, bounding_box.top_,
+ bounding_box.right_, bounding_box.bottom_,
+ bounding_box.right_, bounding_box.bottom_,
+ bounding_box.left_, bounding_box.bottom_
+ };
+
+ glVertexPointer(2, GL_FLOAT, 0, line);
+ glEnableClientState(GL_VERTEX_ARRAY);
+
+ glDrawArrays(GL_LINES, 0, 8);
+}
+
+
+// Changes the coordinate system so that the given world-space square can
+// subsequently be drawn to using coordinates in [0, 1] x [0, 1].
+inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) {
+ glScalef(square.size_, square.size_, 1.0f);
+ glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f);
+}
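+
+// A usage sketch (illustrative only; assumes an active GLES1 context and
+// that the caller manages the matrix stack):
+//   glPushMatrix();
+//   MapWorldSquareToUnitSquare(square);
+//   // ... draw using coordinates in [0, 1] x [0, 1] ...
+//   glPopMatrix();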
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image-inl.h b/tensorflow/examples/android/jni/object_tracking/image-inl.h
new file mode 100644
index 0000000000..18123cef01
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image-inl.h
@@ -0,0 +1,642 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+template <typename T>
+Image<T>::Image(const int width, const int height)
+ : width_less_one_(width - 1),
+ height_less_one_(height - 1),
+ data_size_(width * height),
+ own_data_(true),
+ width_(width),
+ height_(height),
+ stride_(width) {
+ Allocate();
+}
+
+template <typename T>
+Image<T>::Image(const Size& size)
+ : width_less_one_(size.width - 1),
+ height_less_one_(size.height - 1),
+ data_size_(size.width * size.height),
+ own_data_(true),
+ width_(size.width),
+ height_(size.height),
+ stride_(size.width) {
+ Allocate();
+}
+
+// Constructor that creates an image from preallocated data.
+// Note: The image takes ownership of the data lifecycle, unless own_data is
+// set to false.
+template <typename T>
+Image<T>::Image(const int width, const int height, T* const image_data,
+ const bool own_data) :
+ width_less_one_(width - 1),
+ height_less_one_(height - 1),
+ data_size_(width * height),
+ own_data_(own_data),
+ width_(width),
+ height_(height),
+ stride_(width) {
+ image_data_ = image_data;
+ SCHECK(image_data_ != NULL, "Can't create image with NULL data!");
+}
+
+template <typename T>
+Image<T>::~Image() {
+ if (own_data_) {
+ delete[] image_data_;
+ }
+ image_data_ = NULL;
+}
+
+template<typename T>
+template<class DstType>
+bool Image<T>::ExtractPatchAtSubpixelFixed1616(const int fp_x,
+ const int fp_y,
+ const int patchwidth,
+ const int patchheight,
+ DstType* to_data) const {
+ // Calculate weights.
+ const int trunc_x = fp_x >> 16;
+ const int trunc_y = fp_y >> 16;
+
+ if (trunc_x < 0 || trunc_y < 0 ||
+ (trunc_x + patchwidth) >= width_less_one_ ||
+ (trunc_y + patchheight) >= height_less_one_) {
+ return false;
+ }
+
+ // Now walk over destination patch and fill from interpolated source image.
+ for (int y = 0; y < patchheight; ++y, to_data += patchwidth) {
+ for (int x = 0; x < patchwidth; ++x) {
+ to_data[x] =
+ static_cast<DstType>(GetPixelInterpFixed1616(fp_x + (x << 16),
+ fp_y + (y << 16)));
+ }
+ }
+
+ return true;
+}
+
+template <typename T>
+Image<T>* Image<T>::Crop(
+ const int left, const int top, const int right, const int bottom) const {
+ SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left);
+ SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right);
+ SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top);
+ SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom);
+
+ SCHECK(left <= right, "mismatch!");
+ SCHECK(top <= bottom, "mismatch!");
+
+ const int new_width = right - left + 1;
+ const int new_height = bottom - top + 1;
+
+ Image<T>* const cropped_image = new Image(new_width, new_height);
+
+ for (int y = 0; y < new_height; ++y) {
+ memcpy((*cropped_image)[y], ((*this)[y + top] + left),
+ new_width * sizeof(T));
+ }
+
+ return cropped_image;
+}
+
+template <typename T>
+inline float Image<T>::GetPixelInterp(const float x, const float y) const {
+ // Do int conversion one time.
+ const int floored_x = static_cast<int>(x);
+ const int floored_y = static_cast<int>(y);
+
+ // Note: it might be the case that the *_[min|max] values are clipped, and
+ // these (the a b c d vals) aren't (for speed purposes), but that doesn't
+ // matter. We'll just be blending the pixel with itself in that case anyway.
+ const float b = x - floored_x;
+ const float a = 1.0f - b;
+
+ const float d = y - floored_y;
+ const float c = 1.0f - d;
+
+ SCHECK(ValidInterpPixel(x, y),
+ "x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)",
+ x, width_less_one_, y, height_less_one_);
+
+ const T* const pix_ptr = (*this)[floored_y] + floored_x;
+
+ // Get the pixel values surrounding this point.
+ const T& p1 = pix_ptr[0];
+ const T& p2 = pix_ptr[1];
+ const T& p3 = pix_ptr[width_];
+ const T& p4 = pix_ptr[width_ + 1];
+
+ // Simple bilinear interpolation between four reference pixels.
+ // If x is the value requested:
+ // a b
+ // -------
+ // c |p1 p2|
+ // | x |
+ // d |p3 p4|
+ // -------
+ return c * ((a * p1) + (b * p2)) +
+ d * ((a * p3) + (b * p4));
+}
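+
+// For example (illustrative): sampling at (1.25, 2.5) floors to (1, 2) with
+// weights b = 0.25, a = 0.75, d = 0.5, c = 0.5, giving
+// 0.5 * (0.75 * p1 + 0.25 * p2) + 0.5 * (0.75 * p3 + 0.25 * p4).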
+
+
+template <typename T>
+inline T Image<T>::GetPixelInterpFixed1616(
+ const int fp_x_whole, const int fp_y_whole) const {
+ static const int kFixedPointOne = 0x00010000;
+ static const int kFixedPointHalf = 0x00008000;
+ static const int kFixedPointTruncateMask = 0xFFFF0000;
+
+ int trunc_x = fp_x_whole & kFixedPointTruncateMask;
+ int trunc_y = fp_y_whole & kFixedPointTruncateMask;
+ const int fp_x = fp_x_whole - trunc_x;
+ const int fp_y = fp_y_whole - trunc_y;
+
+ // Scale the truncated values back to regular ints.
+ trunc_x >>= 16;
+ trunc_y >>= 16;
+
+ const int one_minus_fp_x = kFixedPointOne - fp_x;
+ const int one_minus_fp_y = kFixedPointOne - fp_y;
+
+ const T* trunc_start = (*this)[trunc_y] + trunc_x;
+
+ const T a = trunc_start[0];
+ const T b = trunc_start[1];
+ const T c = trunc_start[stride_];
+ const T d = trunc_start[stride_ + 1];
+
+ return ((one_minus_fp_y * static_cast<int64>(one_minus_fp_x * a + fp_x * b) +
+ fp_y * static_cast<int64>(one_minus_fp_x * c + fp_x * d) +
+ kFixedPointHalf) >> 32);
+}
+
+template <typename T>
+inline bool Image<T>::ValidPixel(const int x, const int y) const {
+ return InRange(x, ZERO, width_less_one_) &&
+ InRange(y, ZERO, height_less_one_);
+}
+
+template <typename T>
+inline BoundingBox Image<T>::GetContainingBox() const {
+ return BoundingBox(
+ 0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON);
+}
+
+template <typename T>
+inline bool Image<T>::Contains(const BoundingBox& bounding_box) const {
+ // TODO(andrewharp): Come up with a more elegant way of ensuring that bounds
+ // are ok.
+ return GetContainingBox().Contains(bounding_box);
+}
+
+template <typename T>
+inline bool Image<T>::ValidInterpPixel(const float x, const float y) const {
+ // Exclusive of max because we can be more efficient if we don't handle
+ // interpolating on or past the last pixel.
+ return (x >= ZERO) && (x < width_less_one_) &&
+ (y >= ZERO) && (y < height_less_one_);
+}
+
+template <typename T>
+void Image<T>::DownsampleAveraged(const T* const original, const int stride,
+ const int factor) {
+#ifdef __ARM_NEON
+ if (factor == 4 || factor == 2) {
+ DownsampleAveragedNeon(original, stride, factor);
+ return;
+ }
+#endif
+
+ // TODO(andrewharp): delete or enable this for non-uint8 downsamples.
+ const int pixels_per_block = factor * factor;
+
+ // For every pixel in resulting image.
+ for (int y = 0; y < height_; ++y) {
+ const int orig_y = y * factor;
+ const int y_bound = orig_y + factor;
+
+ // Sum up the original pixels.
+ for (int x = 0; x < width_; ++x) {
+ const int orig_x = x * factor;
+ const int x_bound = orig_x + factor;
+
+ // Use an int32 accumulator because the sum might overflow type T.
+ int32 pixel_sum = 0;
+
+ // Grab all the pixels that make up this pixel.
+ for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) {
+ const T* p = original + curr_y * stride + orig_x;
+
+ for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) {
+ pixel_sum += *p++;
+ }
+ }
+
+ (*this)[y][x] = pixel_sum / pixels_per_block;
+ }
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleInterpolateNearest(const Image<T>& original) {
+ // Calculating the scaling factors based on target image size.
+ const float factor_x = static_cast<float>(original.GetWidth()) /
+ static_cast<float>(width_);
+ const float factor_y = static_cast<float>(original.GetHeight()) /
+ static_cast<float>(height_);
+
+ // Calculating initial offset in x-axis.
+ const float offset_x = 0.5f * (original.GetWidth() - width_) / width_;
+
+ // Calculating initial offset in y-axis.
+ const float offset_y = 0.5f * (original.GetHeight() - height_) / height_;
+
+ float orig_y = offset_y;
+
+ // For every pixel in resulting image.
+ for (int y = 0; y < height_; ++y) {
+ float orig_x = offset_x;
+
+ // Finding nearest pixel on y-axis.
+ const int nearest_y = static_cast<int>(orig_y + 0.5f);
+ const T* row_data = original[nearest_y];
+
+ T* pixel_ptr = (*this)[y];
+
+ for (int x = 0; x < width_; ++x) {
+ // Finding nearest pixel on x-axis.
+ const int nearest_x = static_cast<int>(orig_x + 0.5f);
+
+ *pixel_ptr++ = row_data[nearest_x];
+
+ orig_x += factor_x;
+ }
+
+ orig_y += factor_y;
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleInterpolateLinear(const Image<T>& original) {
+ // TODO(andrewharp): Turn this into a general compare sizes/bulk
+ // copy method.
+ if (original.GetWidth() == GetWidth() &&
+ original.GetHeight() == GetHeight() &&
+ original.stride() == stride()) {
+ memcpy(image_data_, original.data(), data_size_ * sizeof(T));
+ return;
+ }
+
+ // Calculating the scaling factors based on target image size.
+ const float factor_x = static_cast<float>(original.GetWidth()) /
+ static_cast<float>(width_);
+ const float factor_y = static_cast<float>(original.GetHeight()) /
+ static_cast<float>(height_);
+
+ // Calculating initial offset in x-axis.
+ const float offset_x = 0;
+ const int offset_x_fp = RealToFixed1616(offset_x);
+
+ // Calculating initial offset in y-axis.
+ const float offset_y = 0;
+ const int offset_y_fp = RealToFixed1616(offset_y);
+
+ // Get the fixed point scaling factor value.
+ // Shift by 8 so we can fit everything into a 4 byte int later for speed
+ // reasons. This means the precision is limited to 1 / 256th of a pixel,
+ // but this should be good enough.
+ const int factor_x_fp = RealToFixed1616(factor_x) >> 8;
+ const int factor_y_fp = RealToFixed1616(factor_y) >> 8;
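+ // For example (illustrative): 1.5 in 16:16 format is 0x00018000; shifting
+ // right by 8 yields 0x00000180, which is 1.5 in 24:8 format.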
+
+ int src_y_fp = offset_y_fp >> 8;
+
+ static const int kFixedPointOne8 = 0x00000100;
+ static const int kFixedPointHalf8 = 0x00000080;
+ static const int kFixedPointTruncateMask8 = 0xFFFFFF00;
+
+ // For every pixel in resulting image.
+ for (int y = 0; y < height_; ++y) {
+ int src_x_fp = offset_x_fp >> 8;
+
+ int trunc_y = src_y_fp & kFixedPointTruncateMask8;
+ const int fp_y = src_y_fp - trunc_y;
+
+ // Scale the truncated values back to regular ints.
+ trunc_y >>= 8;
+
+ const int one_minus_fp_y = kFixedPointOne8 - fp_y;
+
+ T* pixel_ptr = (*this)[y];
+
+ // Make sure not to read from an invalid row.
+ const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1);
+ const T* other_top_ptr = original[trunc_y];
+ const T* other_bot_ptr = original[trunc_y_b];
+
+ int last_trunc_x = -1;
+ int trunc_x = -1;
+
+ T a = 0;
+ T b = 0;
+ T c = 0;
+ T d = 0;
+
+ for (int x = 0; x < width_; ++x) {
+ trunc_x = src_x_fp & kFixedPointTruncateMask8;
+
+ // The fractional x component, kept in 8-bit fixed point to match fp_y;
+ // shifting it away would reduce x to nearest-neighbor sampling.
+ const int fp_x = src_x_fp - trunc_x;
+
+ // Scale the truncated values back to regular ints.
+ trunc_x >>= 8;
+
+ // We may be reading from the same source pixels as the last iteration,
+ // in which case the previously loaded values can be reused.
+ if (trunc_x != last_trunc_x) {
+ // Make sure not to read from an invalid column.
+ const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1);
+ a = other_top_ptr[trunc_x];
+ b = other_top_ptr[trunc_x_b];
+ c = other_bot_ptr[trunc_x];
+ d = other_bot_ptr[trunc_x_b];
+ last_trunc_x = trunc_x;
+ }
+
+ const int one_minus_fp_x = kFixedPointOne8 - fp_x;
+
+ // Standard bilinear blend: each row's horizontal blend is weighted by
+ // its vertical weight, so the weights sum to kFixedPointOne8 squared
+ // before the final shift.
+ const int32 value =
+ ((one_minus_fp_y * (one_minus_fp_x * a + fp_x * b) +
+ fp_y * (one_minus_fp_x * c + fp_x * d)) +
+ (kFixedPointHalf8 << 8)) >> 16;
+
+ *pixel_ptr++ = value;
+
+ src_x_fp += factor_x_fp;
+ }
+ src_y_fp += factor_y_fp;
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleSmoothed3x3(const Image<T>& original) {
+ for (int y = 0; y < height_; ++y) {
+ const int orig_y = Clip(2 * y, ZERO, original.height_less_one_);
+ const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_);
+ const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_);
+
+ for (int x = 0; x < width_; ++x) {
+ const int orig_x = Clip(2 * x, ZERO, original.width_less_one_);
+ const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_);
+ const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_);
+
+ // Center.
+ int32 pixel_sum = original[orig_y][orig_x] * 4;
+
+ // Sides.
+ pixel_sum += (original[orig_y][max_x] +
+ original[orig_y][min_x] +
+ original[max_y][orig_x] +
+ original[min_y][orig_x]) * 2;
+
+ // Diagonals.
+ pixel_sum += (original[min_y][max_x] +
+ original[min_y][min_x] +
+ original[max_y][max_x] +
+ original[max_y][min_x]);
+
+ (*this)[y][x] = pixel_sum >> 4; // 16
+ }
+ }
+}
+
+template <typename T>
+void Image<T>::DownsampleSmoothed5x5(const Image<T>& original) {
+ const int max_x = original.width_less_one_;
+ const int max_y = original.height_less_one_;
+
+ // The JY Bouget paper on Lucas-Kanade recommends a
+ // [1/16 1/4 3/8 1/4 1/16]^2 filter.
+ // This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below.
+ static const int window_radius = 2;
+ static const int window_size = window_radius*2 + 1;
+ static const int window_weights[] = {1, 4, 6, 4, 1, // 16 +
+ 4, 16, 24, 16, 4, // 64 +
+ 6, 24, 36, 24, 6, // 96 +
+ 4, 16, 24, 16, 4, // 64 +
+ 1, 4, 6, 4, 1}; // 16 = 256
+
+ // We'll multiply and sum with the whole numbers first, then divide by
+ // the total weight to normalize at the last moment.
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ int32 pixel_sum = 0;
+
+ const int* w = window_weights;
+ const int start_x = Clip((x << 1) - window_radius, ZERO, max_x);
+
+ // Clip the boundaries to the size of the image.
+ for (int window_y = 0; window_y < window_size; ++window_y) {
+ const int start_y =
+ Clip((y << 1) - window_radius + window_y, ZERO, max_y);
+
+ const T* p = original[start_y] + start_x;
+
+ for (int window_x = 0; window_x < window_size; ++window_x) {
+ pixel_sum += *p++ * *w++;
+ }
+ }
+
+ // Conversion to type T will happen here after shifting right 8 bits to
+ // divide by 256.
+ (*this)[y][x] = pixel_sum >> 8;
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+inline T Image<T>::ScharrPixelX(const Image<U>& original,
+ const int center_x, const int center_y) const {
+ const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_);
+ const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_);
+ const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_);
+ const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_);
+
+ // Convolution loop unrolled for performance...
+ return (3 * (original[min_y][max_x]
+ + original[max_y][max_x]
+ - original[min_y][min_x]
+ - original[max_y][min_x])
+ + 10 * (original[center_y][max_x]
+ - original[center_y][min_x])) / 32;
+}
+
+template <typename T>
+template <typename U>
+inline T Image<T>::ScharrPixelY(const Image<U>& original,
+ const int center_x, const int center_y) const {
+ const int min_x = Clip(center_x - 1, 0, original.width_less_one_);
+ const int max_x = Clip(center_x + 1, 0, original.width_less_one_);
+ const int min_y = Clip(center_y - 1, 0, original.height_less_one_);
+ const int max_y = Clip(center_y + 1, 0, original.height_less_one_);
+
+ // Convolution loop unrolled for performance...
+ return (3 * (original[max_y][min_x]
+ + original[max_y][max_x]
+ - original[min_y][min_x]
+ - original[min_y][max_x])
+ + 10 * (original[max_y][center_x]
+ - original[min_y][center_x])) / 32;
+}
+
+template <typename T>
+template <typename U>
+inline void Image<T>::ScharrX(const Image<U>& original) {
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ SetPixel(x, y, ScharrPixelX(original, x, y));
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+inline void Image<T>::ScharrY(const Image<U>& original) {
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ SetPixel(x, y, ScharrPixelY(original, x, y));
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+void Image<T>::DerivativeX(const Image<U>& original) {
+ for (int y = 0; y < height_; ++y) {
+ const U* const source_row = original[y];
+ T* const dest_row = (*this)[y];
+
+ // Compute first pixel. Approximated with forward difference.
+ dest_row[0] = source_row[1] - source_row[0];
+
+ // All the pixels in between. Central difference method.
+ const U* source_prev_pixel = source_row;
+ T* dest_pixel = dest_row + 1;
+ const U* source_next_pixel = source_row + 2;
+ for (int x = 1; x < width_less_one_; ++x) {
+ *dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
+ }
+
+ // Last pixel. Approximated with backward difference.
+ dest_row[width_less_one_] =
+ source_row[width_less_one_] - source_row[width_less_one_ - 1];
+ }
+}
+
+template <typename T>
+template <typename U>
+void Image<T>::DerivativeY(const Image<U>& original) {
+ const int src_stride = original.stride();
+
+ // Compute 1st row. Approximated with forward difference.
+ {
+ const U* const src_row = original[0];
+ T* dest_row = (*this)[0];
+ for (int x = 0; x < width_; ++x) {
+ dest_row[x] = src_row[x + src_stride] - src_row[x];
+ }
+ }
+
+ // Compute all rows in between using central difference.
+ for (int y = 1; y < height_less_one_; ++y) {
+ T* dest_row = (*this)[y];
+
+ const U* source_prev_pixel = original[y - 1];
+ const U* source_next_pixel = original[y + 1];
+ for (int x = 0; x < width_; ++x) {
+ *dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++);
+ }
+ }
+
+ // Compute last row. Approximated with backward difference.
+ {
+ const U* const src_row = original[height_less_one_];
+ T* dest_row = (*this)[height_less_one_];
+ for (int x = 0; x < width_; ++x) {
+ dest_row[x] = src_row[x] - src_row[x - src_stride];
+ }
+ }
+}
+
+template <typename T>
+template <typename U>
+inline T Image<T>::ConvolvePixel3x3(const Image<U>& original,
+ const int* const filter,
+ const int center_x, const int center_y,
+ const int total) const {
+ int32 sum = 0;
+ for (int filter_y = 0; filter_y < 3; ++filter_y) {
+ const int y = Clip(center_y - 1 + filter_y, 0, original.height_less_one_);
+ for (int filter_x = 0; filter_x < 3; ++filter_x) {
+ const int x = Clip(center_x - 1 + filter_x, 0, original.width_less_one_);
+ sum += original[y][x] * filter[filter_y * 3 + filter_x];
+ }
+ }
+ return sum / total;
+}
+
+template <typename T>
+template <typename U>
+inline void Image<T>::Convolve3x3(const Image<U>& original,
+ const int32* const filter) {
+ int32 sum = 0;
+ for (int i = 0; i < 9; ++i) {
+ sum += abs(filter[i]);
+ }
+ for (int y = 0; y < height_; ++y) {
+ for (int x = 0; x < width_; ++x) {
+ SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum));
+ }
+ }
+}
+
+template <typename T>
+inline void Image<T>::FromArray(const T* const pixels, const int stride,
+ const int factor) {
+ if (factor == 1 && stride == width_) {
+ // If not subsampling, memcpy per line should be faster.
+ memcpy(this->image_data_, pixels, data_size_ * sizeof(T));
+ return;
+ }
+
+ DownsampleAveraged(pixels, stride, factor);
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image.h b/tensorflow/examples/android/jni/object_tracking/image.h
new file mode 100644
index 0000000000..29b0adbda8
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image.h
@@ -0,0 +1,346 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+using namespace tensorflow;
+
+// TODO(andrewharp): Make this a cast to uint32 if/when we go unsigned for
+// operations.
+#define ZERO 0
+
+#ifdef SANITY_CHECKS
+ #define CHECK_PIXEL(IMAGE, X, Y) {\
+ SCHECK((IMAGE)->ValidPixel((X), (Y)), \
+ "CHECK_PIXEL(%d,%d) in %dx%d image.", \
+ static_cast<int>(X), static_cast<int>(Y), \
+ (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
+ }
+
+ #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\
+ SCHECK((IMAGE)->ValidInterpPixel((X), (Y)), \
+ "CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \
+ static_cast<float>(X), static_cast<float>(Y), \
+ (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\
+ }
+#else
+ #define CHECK_PIXEL(IMAGE, X, Y) {}
+ #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {}
+#endif
+
+namespace tf_tracking {
+
+#ifdef SANITY_CHECKS
+// Class which exists solely to provide bounds checking for array-style image
+// data access.
+template <typename T>
+class RowData {
+ public:
+ RowData(T* const row_data, const int max_col)
+ : row_data_(row_data), max_col_(max_col) {}
+
+ inline T& operator[](const int col) const {
+ SCHECK(InRange(col, 0, max_col_),
+ "Column out of range: %d (%d max)", col, max_col_);
+ return row_data_[col];
+ }
+
+ inline operator T*() const {
+ return row_data_;
+ }
+
+ private:
+ T* const row_data_;
+ const int max_col_;
+};
+#endif
+
+// Naive templated comparison function, suitable for use with qsort.
+template <typename T>
+int Comp(const void* a, const void* b) {
+ const T val1 = *reinterpret_cast<const T*>(a);
+ const T val2 = *reinterpret_cast<const T*>(b);
+
+ if (val1 == val2) {
+ return 0;
+ } else if (val1 < val2) {
+ return -1;
+ } else {
+ return 1;
+ }
+}
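+
+// A usage sketch with qsort (illustrative only):
+//   float vals[] = {3.0f, 1.0f, 2.0f};
+//   qsort(vals, 3, sizeof(vals[0]), Comp<float>);
+//   // vals is now {1.0f, 2.0f, 3.0f}.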
+
+// TODO(andrewharp): Make explicit which operations support negative numbers or
+// struct/class types in image data (possibly create fast multi-dim array class
+// for data where pixel arithmetic does not make sense).
+
+// Image class optimized for working on numeric arrays as grayscale image data.
+// Supports other data types as a 2D array class, so long as no pixel math
+// operations are called (convolution, downsampling, etc).
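+//
+// A minimal usage sketch (sizes are illustrative assumptions):
+//   Image<uint8> frame(640, 480);
+//   frame.Clear(0);
+//   Image<uint8> half(320, 240);
+//   half.DownsampleAveraged(frame, 2);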
+template <typename T>
+class Image {
+ public:
+ Image(const int width, const int height);
+ explicit Image(const Size& size);
+
+ // Constructor that creates an image from preallocated data.
+ // Note: The image takes ownership of the data lifecycle, unless own_data is
+ // set to false.
+ Image(const int width, const int height, T* const image_data,
+ const bool own_data = true);
+
+ ~Image();
+
+ // Extract a pixel patch from this image, starting at a subpixel location.
+ // Uses 16:16 fixed point format for representing real values and doing the
+ // bilinear interpolation.
+ //
+ // Arguments fp_x and fp_y tell the subpixel position in fixed point format,
+ // patchwidth/patchheight give the size of the patch in pixels and
+ // to_data must be a valid pointer to a *contiguous* destination data array.
+ template<class DstType>
+ bool ExtractPatchAtSubpixelFixed1616(const int fp_x,
+ const int fp_y,
+ const int patchwidth,
+ const int patchheight,
+ DstType* to_data) const;
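+ //
+ // For example (illustrative), extracting an 8x8 patch whose top-left
+ // corner sits at subpixel position (12.5, 3.25):
+ //   uint8 patch[8 * 8];
+ //   img.ExtractPatchAtSubpixelFixed1616(RealToFixed1616(12.5f),
+ //                                       RealToFixed1616(3.25f), 8, 8, patch);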
+
+ Image<T>* Crop(
+ const int left, const int top, const int right, const int bottom) const;
+
+ inline int GetWidth() const { return width_; }
+ inline int GetHeight() const { return height_; }
+
+ // Bilinearly sample a value between pixels. Values must be within the image.
+ inline float GetPixelInterp(const float x, const float y) const;
+
+ // Bilinearly sample a pixel at a subpixel position using fixed-point
+ // arithmetic, avoiding float<->int conversions.
+ // Values must be within the image.
+ // Arguments fp_x and fp_y tell the subpixel position in
+ // 16:16 fixed point format.
+ //
+ // Important: This function only makes sense for integer-valued images, such
+ // as Image<uint8> or Image<int> etc.
+ inline T GetPixelInterpFixed1616(const int fp_x_whole,
+ const int fp_y_whole) const;
+
+ // Returns true iff the pixel is in the image's boundaries.
+ inline bool ValidPixel(const int x, const int y) const;
+
+ inline BoundingBox GetContainingBox() const;
+
+ inline bool Contains(const BoundingBox& bounding_box) const;
+
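+ // Note: this sorts the underlying image data in place, destroying the
+ // current image contents.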
+ inline T GetMedianValue() {
+ qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp<T>);
+ return image_data_[data_size_ >> 1];
+ }
+
+ // Returns true iff the pixel is in the image's boundaries for interpolation
+ // purposes.
+ // TODO(andrewharp): check in interpolation follow-up change.
+ inline bool ValidInterpPixel(const float x, const float y) const;
+
+ // Safe lookup with boundary enforcement.
+ inline T GetPixelClipped(const int x, const int y) const {
+ return (*this)[Clip(y, ZERO, height_less_one_)]
+ [Clip(x, ZERO, width_less_one_)];
+ }
+
+#ifdef SANITY_CHECKS
+ inline RowData<T> operator[](const int row) {
+ SCHECK(InRange(row, 0, height_less_one_),
+ "Row out of range: %d (%d max)", row, height_less_one_);
+ return RowData<T>(image_data_ + row * stride_, width_less_one_);
+ }
+
+ inline const RowData<T> operator[](const int row) const {
+ SCHECK(InRange(row, 0, height_less_one_),
+ "Row out of range: %d (%d max)", row, height_less_one_);
+ return RowData<T>(image_data_ + row * stride_, width_less_one_);
+ }
+#else
+ inline T* operator[](const int row) {
+ return image_data_ + row * stride_;
+ }
+
+ inline const T* operator[](const int row) const {
+ return image_data_ + row * stride_;
+ }
+#endif
+
+ const T* data() const { return image_data_; }
+
+ inline int stride() const { return stride_; }
+
+ // Clears the image to a single value via memset. Note that for multi-byte
+ // T this is only meaningful for values, like 0, whose bytes are identical.
+ inline void Clear(const T& val) {
+ memset(image_data_, val, sizeof(*image_data_) * data_size_);
+ }
+
+#ifdef __ARM_NEON
+ void Downsample2x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x);
+
+ void Downsample4x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x);
+
+ void DownsampleAveragedNeon(const uint8* const original, const int stride,
+ const int factor);
+#endif
+
+ // Naive downsampler that reduces image size by the given factor, averaging
+ // pixels in blocks of size factor x factor.
+ void DownsampleAveraged(const T* const original, const int stride,
+ const int factor);
+
+ // Naive downsampler that reduces image size by the given factor, averaging
+ // pixels in blocks of size factor x factor.
+ inline void DownsampleAveraged(const Image<T>& original, const int factor) {
+ DownsampleAveraged(original.data(), original.GetWidth(), factor);
+ }
+
+ // Naive downsampler that reduces image size using nearest-neighbor
+ // interpolation.
+ void DownsampleInterpolateNearest(const Image<T>& original);
+
+ // Naive downsampler that reduces image size using fixed-point bilinear
+ // interpolation.
+ void DownsampleInterpolateLinear(const Image<T>& original);
+
+ // Relatively efficient downsampling of an image by a factor of two with a
+ // low-pass 3x3 smoothing operation thrown in.
+ void DownsampleSmoothed3x3(const Image<T>& original);
+
+ // Relatively efficient downsampling of an image by a factor of two with a
+ // low-pass 5x5 smoothing operation thrown in.
+ void DownsampleSmoothed5x5(const Image<T>& original);
+
+ // Optimized Scharr filter on a single pixel in the X direction.
+ // Scharr filters are like central-difference operators, but have more
+ // rotational symmetry in their response because they also consider the
+ // diagonal neighbors.
+ template <typename U>
+ inline T ScharrPixelX(const Image<U>& original,
+ const int center_x, const int center_y) const;
+
+ // Optimized Scharr filter on a single pixel in the Y direction.
+ // Scharr filters are like central-difference operators, but have more
+ // rotational symmetry in their response because they also consider the
+ // diagonal neighbors.
+ template <typename U>
+ inline T ScharrPixelY(const Image<U>& original,
+ const int center_x, const int center_y) const;
+
+ // Convolve the image with a Scharr filter in the X direction.
+ // Much faster than an equivalent generic convolution.
+ template <typename U>
+ inline void ScharrX(const Image<U>& original);
+
+ // Convolve the image with a Scharr filter in the Y direction.
+ // Much faster than an equivalent generic convolution.
+ template <typename U>
+ inline void ScharrY(const Image<U>& original);
+
+ static inline T HalfDiff(int32 first, int32 second) {
+ return (second - first) / 2;
+ }
+
+ template <typename U>
+ void DerivativeX(const Image<U>& original);
+
+ template <typename U>
+ void DerivativeY(const Image<U>& original);
+
+ // Generic function for convolving pixel with 3x3 filter.
+ // Filter pixels should be in row major order.
+ template <typename U>
+ inline T ConvolvePixel3x3(const Image<U>& original,
+ const int* const filter,
+ const int center_x, const int center_y,
+ const int total) const;
+
+ // Generic function for convolving an image with a 3x3 filter.
+ // TODO(andrewharp): Generalize this for any size filter.
+ template <typename U>
+ inline void Convolve3x3(const Image<U>& original,
+ const int32* const filter);
+
+ // Load this image's data from a data array. The data at pixels is assumed to
+ // have dimensions equivalent to this image's dimensions * factor.
+ inline void FromArray(const T* const pixels, const int stride,
+ const int factor = 1);
+
+ // Copy the image back out to an appropriately sized data array.
+ inline void ToArray(T* const pixels) const {
+ // If not subsampling, memcpy should be faster.
+ memcpy(pixels, this->image_data_, data_size_ * sizeof(T));
+ }
+
+ // Precompute these for efficiency's sake as they're used by a lot of
+ // clipping code and loop code.
+ // TODO(andrewharp): make these only accessible by other Images.
+ const int width_less_one_;
+ const int height_less_one_;
+
+ // The raw size of the allocated data.
+ const int data_size_;
+
+ private:
+ inline void Allocate() {
+ // Use the nothrow form of new[]: the plain form throws on failure
+ // instead of returning NULL, which would make this check dead code.
+ image_data_ = new (std::nothrow) T[data_size_];
+ if (image_data_ == NULL) {
+ LOGE("Couldn't allocate image data!");
+ }
+ }
+
+ T* image_data_;
+
+ bool own_data_;
+
+ const int width_;
+ const int height_;
+
+ // The image stride (offset to next row).
+ // TODO(andrewharp): Make sure that stride is honored in all code.
+ const int stride_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Image);
+};
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& stream, const Image<T>& image) {
+ for (int y = 0; y < image.GetHeight(); ++y) {
+ for (int x = 0; x < image.GetWidth(); ++x) {
+ stream << image[y][x] << " ";
+ }
+ stream << std::endl;
+ }
+ return stream;
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/examples/android/jni/object_tracking/image_data.h
new file mode 100644
index 0000000000..16b1864ee6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image_data.h
@@ -0,0 +1,270 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+
+#include <memory>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+// Class that encapsulates all bulky processed data for a frame.
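+//
+// A minimal usage sketch (frame pointers, stride and timestamp are assumed;
+// illustrative only):
+//   ImageData image_data(width, height);
+//   image_data.SetData(y_plane, uv_plane, stride, timestamp, 1);
+//   const Image<uint8>* frame = image_data.GetImage();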
+class ImageData {
+ public:
+ explicit ImageData(const int width, const int height)
+ : uv_frame_width_(width << 1),
+ uv_frame_height_(height << 1),
+ timestamp_(0),
+ image_(width, height) {
+ InitPyramid(width, height);
+ ResetComputationCache();
+ }
+
+ private:
+ void ResetComputationCache() {
+ uv_data_computed_ = false;
+ integral_image_computed_ = false;
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ spatial_x_computed_[i] = false;
+ spatial_y_computed_[i] = false;
+ pyramid_sqrt2_computed_[i * 2] = false;
+ pyramid_sqrt2_computed_[i * 2 + 1] = false;
+ }
+ }
+
+ void InitPyramid(const int width, const int height) {
+ int level_width = width;
+ int level_height = height;
+
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ pyramid_sqrt2_[i * 2] = NULL;
+ pyramid_sqrt2_[i * 2 + 1] = NULL;
+ spatial_x_[i] = NULL;
+ spatial_y_[i] = NULL;
+
+ level_width /= 2;
+ level_height /= 2;
+ }
+
+ // Alias the first pyramid level to image_.
+ pyramid_sqrt2_[0] = &image_;
+ }
+
+ public:
+ ~ImageData() {
+ // The first pyramid level is actually an alias to image_,
+ // so make sure it doesn't get deleted here.
+ pyramid_sqrt2_[0] = NULL;
+
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ SAFE_DELETE(pyramid_sqrt2_[i * 2]);
+ SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
+ SAFE_DELETE(spatial_x_[i]);
+ SAFE_DELETE(spatial_y_[i]);
+ }
+ }
+
+ void SetData(const uint8* const new_frame, const int stride,
+ const int64 timestamp, const int downsample_factor) {
+ SetData(new_frame, NULL, stride, timestamp, downsample_factor);
+ }
+
+ void SetData(const uint8* const new_frame,
+ const uint8* const uv_frame,
+ const int stride,
+ const int64 timestamp, const int downsample_factor) {
+ ResetComputationCache();
+
+ timestamp_ = timestamp;
+
+ TimeLog("SetData!");
+
+ pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
+ pyramid_sqrt2_computed_[0] = true;
+ TimeLog("Downsampled image");
+
+ if (uv_frame != NULL) {
+ if (u_data_.get() == NULL) {
+ u_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_));
+ v_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_));
+ }
+
+ GetUV(uv_frame, u_data_.get(), v_data_.get());
+ uv_data_computed_ = true;
+ TimeLog("Copied UV data");
+ } else {
+ LOGV("No uv data!");
+ }
+
+#ifdef LOG_TIME
+ // If profiling is enabled, precompute here to make it easier to distinguish
+ // total costs.
+ Precompute();
+#endif
+ }
+
+ inline int64 GetTimestamp() const {
+ return timestamp_;
+ }
+
+ inline const Image<uint8>* GetImage() const {
+ SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
+ return pyramid_sqrt2_[0];
+ }
+
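+ // Levels are interleaved by half-octaves: even index 2k holds the octave
+ // at scale 1/2^k, and odd index 2k + 1 holds that octave shrunk by a
+ // further factor of sqrt(2), so consecutive levels differ by sqrt(2).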
+ const Image<uint8>* GetPyramidSqrt2Level(const int level) const {
+ if (!pyramid_sqrt2_computed_[level]) {
+ SCHECK(level != 0, "Level equals 0!");
+ if (level == 1) {
+ const Image<uint8>& upper_level = *GetPyramidSqrt2Level(0);
+ if (pyramid_sqrt2_[level] == NULL) {
+ const int new_width =
+ (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
+ const int new_height =
+ (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
+ 2;
+
+ pyramid_sqrt2_[level] = new Image<uint8>(new_width, new_height);
+ }
+ pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
+ } else {
+ const Image<uint8>& upper_level = *GetPyramidSqrt2Level(level - 2);
+ if (pyramid_sqrt2_[level] == NULL) {
+ pyramid_sqrt2_[level] = new Image<uint8>(
+ upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
+ }
+ pyramid_sqrt2_[level]->DownsampleAveraged(
+ upper_level.data(), upper_level.stride(), 2);
+ }
+ pyramid_sqrt2_computed_[level] = true;
+ }
+ return pyramid_sqrt2_[level];
+ }
+
+ inline const Image<int32>* GetSpatialX(const int level) const {
+ if (!spatial_x_computed_[level]) {
+ const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2);
+ if (spatial_x_[level] == NULL) {
+ spatial_x_[level] = new Image<int32>(src.GetWidth(), src.GetHeight());
+ }
+ spatial_x_[level]->DerivativeX(src);
+ spatial_x_computed_[level] = true;
+ }
+ return spatial_x_[level];
+ }
+
+ inline const Image<int32>* GetSpatialY(const int level) const {
+ if (!spatial_y_computed_[level]) {
+ const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2);
+ if (spatial_y_[level] == NULL) {
+ spatial_y_[level] = new Image<int32>(src.GetWidth(), src.GetHeight());
+ }
+ spatial_y_[level]->DerivativeY(src);
+ spatial_y_computed_[level] = true;
+ }
+ return spatial_y_[level];
+ }
+
+ // The integral image is currently only used for object detection, so lazily
+ // initialize it on request.
+ inline const IntegralImage* GetIntegralImage() const {
+ if (integral_image_.get() == NULL) {
+ integral_image_.reset(new IntegralImage(image_));
+ } else if (!integral_image_computed_) {
+ integral_image_->Recompute(image_);
+ }
+ integral_image_computed_ = true;
+ return integral_image_.get();
+ }
+
+ inline const Image<uint8>* GetU() const {
+ SCHECK(uv_data_computed_, "UV data not provided!");
+ return u_data_.get();
+ }
+
+ inline const Image<uint8>* GetV() const {
+ SCHECK(uv_data_computed_, "UV data not provided!");
+ return v_data_.get();
+ }
+
+ private:
+ void Precompute() {
+ // Create the smoothed pyramids.
+ for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
+ (void) GetPyramidSqrt2Level(i);
+ }
+ TimeLog("Created smoothed pyramids");
+
+ // Create the interleaved sqrt(2) pyramid levels.
+ for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
+ (void) GetPyramidSqrt2Level(i);
+ }
+ TimeLog("Created smoothed sqrt pyramids");
+
+ // Create the spatial derivatives for frame 1.
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ (void) GetSpatialX(i);
+ (void) GetSpatialY(i);
+ }
+ TimeLog("Created spatial derivatives");
+
+ (void) GetIntegralImage();
+ TimeLog("Got integral image!");
+ }
+
+ const int uv_frame_width_;
+ const int uv_frame_height_;
+
+ int64 timestamp_;
+
+ Image<uint8> image_;
+
+ bool uv_data_computed_;
+ std::unique_ptr<Image<uint8> > u_data_;
+ std::unique_ptr<Image<uint8> > v_data_;
+
+ mutable bool spatial_x_computed_[kNumPyramidLevels];
+ mutable Image<int32>* spatial_x_[kNumPyramidLevels];
+
+ mutable bool spatial_y_computed_[kNumPyramidLevels];
+ mutable Image<int32>* spatial_y_[kNumPyramidLevels];
+
+ // Mutable so the lazy initialization can work when this class is const.
+ // Whether or not the integral image has been computed for the current image.
+ mutable bool integral_image_computed_;
+ mutable std::unique_ptr<IntegralImage> integral_image_;
+
+ mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
+ mutable Image<uint8>* pyramid_sqrt2_[kNumPyramidLevels * 2];
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image_neon.cc b/tensorflow/examples/android/jni/object_tracking/image_neon.cc
new file mode 100644
index 0000000000..ddd8447bf3
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image_neon.cc
@@ -0,0 +1,270 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NEON implementations of Image methods for compatible devices. Control
+// should never enter this compilation unit on incompatible devices.
+
+#ifdef __ARM_NEON
+
+#include <arm_neon.h>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+// Processes one 32-column-wide strip of the 2x downsample. This function
+// does the bulk of the work.
+template <>
+void Image<uint8>::Downsample2x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x) {
+ // Divide input x offset by 2 to find output offset.
+ const int new_x = orig_x >> 1;
+
+ // Initial offset into top row.
+ const uint8* offset = original + orig_x;
+
+ // This points to the leftmost pixel of our 16 horizontally arranged
+ // pixels in the destination data.
+ uint8* ptr_dst = (*this)[0] + new_x;
+
+ // Sum along vertical columns.
+ // Process 32x2 input pixels and 16x1 output pixels per iteration.
+ for (int new_y = 0; new_y < height_; ++new_y) {
+ uint16x8_t accum1 = vdupq_n_u16(0);
+ uint16x8_t accum2 = vdupq_n_u16(0);
+
+ // Go top to bottom across the two rows of input pixels that make up
+ // this output row.
+ for (int row_num = 0; row_num < 2; ++row_num) {
+ // First 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data1 = vld1q_u8(offset);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum1 = vpadalq_u8(accum1, curr_data1);
+ }
+
+ // Second 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum2 = vpadalq_u8(accum2, curr_data2);
+ }
+
+ // Move offset down one row.
+ offset += stride;
+ }
+
+ // Divide by 4 (number of input pixels per output
+ // pixel) and narrow data from 16 bits per pixel to 8 bpp.
+ const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2);
+ const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2);
+
+ // Concatenate 8x1 pixel strips into 16x1 pixel strip.
+ const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2);
+
+ // Copy all pixels from composite 16x1 vector into output strip.
+ vst1q_u8(ptr_dst, allpixels);
+
+ ptr_dst += stride_;
+ }
+}
+
+// This function does the bulk of the work.
+template <>
+void Image<uint8>::Downsample4x32ColumnsNeon(const uint8* const original,
+ const int stride,
+ const int orig_x) {
+ // Divide input x offset by 4 to find output offset.
+ const int new_x = orig_x >> 2;
+
+ // Initial offset into top row.
+ const uint8* offset = original + orig_x;
+
+ // This points to the leftmost pixel of our 8 horizontally arranged
+ // pixels in the destination data.
+ uint8* ptr_dst = (*this)[0] + new_x;
+
+ // Sum along vertical columns.
+ // Process 32x4 input pixels and 8x1 output pixels per iteration.
+ for (int new_y = 0; new_y < height_; ++new_y) {
+ uint16x8_t accum1 = vdupq_n_u16(0);
+ uint16x8_t accum2 = vdupq_n_u16(0);
+
+ // Go top to bottom across the four rows of input pixels that make up
+ // this output row.
+ for (int row_num = 0; row_num < 4; ++row_num) {
+ // First 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data1 = vld1q_u8(offset);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum1 = vpadalq_u8(accum1, curr_data1);
+ }
+
+ // Second 16 bytes.
+ {
+ // Load 16 bytes of data from current offset.
+ const uint8x16_t curr_data2 = vld1q_u8(offset + 16);
+
+ // Pairwise add and accumulate into accum vectors (16 bit to account
+ // for values above 255).
+ accum2 = vpadalq_u8(accum2, curr_data2);
+ }
+
+ // Move offset down one row.
+ offset += stride;
+ }
+
+ // Add and widen, then divide by 16 (number of input pixels per output
+ // pixel) and narrow data from 32 bits per pixel to 16 bpp.
+ const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4);
+ const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4);
+
+ // Combine 4x1 pixel strips into 8x1 pixel strip and narrow from
+ // 16 bits to 8 bits per pixel.
+ const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2));
+
+ // Copy all pixels from composite 8x1 vector into output strip.
+ vst1_u8(ptr_dst, allpixels);
+
+ ptr_dst += stride_;
+ }
+}
+
+
+// Hardware accelerated downsampling method for supported devices.
+// Requires that image size be a multiple of 16 pixels in each dimension,
+// and that downsampling be by a factor of 2 or 4.
+template <>
+void Image<uint8>::DownsampleAveragedNeon(const uint8* const original,
+ const int stride, const int factor) {
+ // TODO(andrewharp): stride is a bad approximation for the src image's width.
+ // Better to pass that in directly.
+ SCHECK(width_ * factor <= stride,
+ "Width %d x factor %d exceeds stride %d!", width_, factor, stride);
+ const int last_starting_index = width_ * factor - 32;
+
+ // We process 32 input pixels lengthwise at a time.
+ // The output per pass of this loop is an 8 wide by downsampled height tall
+ // pixel strip.
+ int orig_x = 0;
+ for (; orig_x <= last_starting_index; orig_x += 32) {
+ if (factor == 2) {
+ Downsample2x32ColumnsNeon(original, stride, orig_x);
+ } else {
+ Downsample4x32ColumnsNeon(original, stride, orig_x);
+ }
+ }
+
+ // If a last pass is required, push it to the left enough so that it never
+ // goes out of bounds. This will result in some extra computation on devices
+ // whose frame widths are multiples of 16 and not 32.
+ if (orig_x < last_starting_index + 32) {
+ if (factor == 2) {
+ Downsample2x32ColumnsNeon(original, stride, last_starting_index);
+ } else {
+ Downsample4x32ColumnsNeon(original, stride, last_starting_index);
+ }
+ }
+}
+
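+// For example (hypothetical sizes): downsampling a 48-pixel-wide source by a
+// factor of 2 runs one aligned pass at orig_x = 0 (reading input columns
+// 0-31) and one final overlapping pass at orig_x = 16 (reading columns
+// 16-47); input columns 16-31 are processed twice, but no access goes out of
+// bounds.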
+
+// Puts the image gradient matrix about a pixel into the 2x2 float array G.
+// vals_x should be an array of the window x gradient values, whose indices
+// can be in any order but are parallel to the vals_y entries.
+// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
+void CalculateGNeon(const float* const vals_x, const float* const vals_y,
+ const int num_vals, float* const G) {
+ const float32_t* const arm_vals_x = (const float32_t*) vals_x;
+ const float32_t* const arm_vals_y = (const float32_t*) vals_y;
+
+ // Running sums.
+ float32x4_t xx = vdupq_n_f32(0.0f);
+ float32x4_t xy = vdupq_n_f32(0.0f);
+ float32x4_t yy = vdupq_n_f32(0.0f);
+
+ // Maximum index we can load 4 consecutive values from.
+ // e.g. if there are 81 values, our last full pass can be from index 77:
+ // 81-4=>77 (77, 78, 79, 80)
+ const int max_i = num_vals - 4;
+
+ // Defined here because we want to keep track of how many values were
+ // processed by NEON, so that we can finish off the remainder the normal
+ // way.
+ int i = 0;
+
+ // Process values 4 at a time, accumulating the sums of
+ // the pixel-wise x*x, x*y, and y*y values.
+ for (; i <= max_i; i += 4) {
+ // Load xs
+ float32x4_t x = vld1q_f32(arm_vals_x + i);
+
+ // Multiply x*x and accumulate.
+ xx = vmlaq_f32(xx, x, x);
+
+ // Load ys
+ float32x4_t y = vld1q_f32(arm_vals_y + i);
+
+ // Multiply x*y and accumulate.
+ xy = vmlaq_f32(xy, x, y);
+
+ // Multiply y*y and accumulate.
+ yy = vmlaq_f32(yy, y, y);
+ }
+
+ static float32_t xx_vals[4];
+ static float32_t xy_vals[4];
+ static float32_t yy_vals[4];
+
+ vst1q_f32(xx_vals, xx);
+ vst1q_f32(xy_vals, xy);
+ vst1q_f32(yy_vals, yy);
+
+ // Accumulated values are stored in sets of 4, so we have to manually add
+ // them together.
+ for (int j = 0; j < 4; ++j) {
+ G[0] += xx_vals[j];
+ G[1] += xy_vals[j];
+ G[3] += yy_vals[j];
+ }
+
+ // Finishes off last few values (< 4) from above.
+ for (; i < num_vals; ++i) {
+ G[0] += Square(vals_x[i]);
+ G[1] += vals_x[i] * vals_y[i];
+ G[3] += Square(vals_y[i]);
+ }
+
+ // The matrix is symmetric, so this is a given.
+ G[2] = G[1];
+}
+
+} // namespace tf_tracking
+
+#endif
diff --git a/tensorflow/examples/android/jni/object_tracking/image_utils.h b/tensorflow/examples/android/jni/object_tracking/image_utils.h
new file mode 100644
index 0000000000..5357a9352f
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/image_utils.h
@@ -0,0 +1,301 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+inline void GetUV(
+ const uint8* const input, Image<uint8>* const u, Image<uint8>* const v) {
+ const uint8* pUV = input;
+
+ for (int row = 0; row < u->GetHeight(); ++row) {
+ uint8* u_curr = (*u)[row];
+ uint8* v_curr = (*v)[row];
+ for (int col = 0; col < u->GetWidth(); ++col) {
+#ifdef __APPLE__
+ *u_curr++ = *pUV++;
+ *v_curr++ = *pUV++;
+#else
+ *v_curr++ = *pUV++;
+ *u_curr++ = *pUV++;
+#endif
+ }
+ }
+}
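+
+// Note: the platform-dependent read order above matches the interleaved
+// chroma layouts of the usual camera formats (NV12-style UV pairs on Apple
+// platforms, NV21-style VU pairs on Android).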
+
+// Marks every point within a circle of a given radius on the given boolean
+// image true.
+template <typename U>
+inline static void MarkImage(const int x, const int y, const int radius,
+ Image<U>* const img) {
+ SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y);
+
+ // Precomputed for efficiency.
+ const int squared_radius = Square(radius);
+
+ // Mark every row in the circle.
+ for (int d_y = 0; d_y <= radius; ++d_y) {
+ const int squared_y_dist = Square(d_y);
+
+ const int min_y = MAX(y - d_y, 0);
+ const int max_y = MIN(y + d_y, img->height_less_one_);
+
+ // The max d_x of the circle must be greater than or equal to
+ // radius - d_y for any non-negative d_y. Thus, starting from radius - d_y
+ // will reduce the number of iterations required as compared to starting
+ // from either 0 and counting up or radius and counting down.
+ for (int d_x = radius - d_y; d_x <= radius; ++d_x) {
+ // The first time this criterion is met, we know the width of the circle
+ // at this row (without using sqrt).
+ if (squared_y_dist + Square(d_x) >= squared_radius) {
+ const int min_x = MAX(x - d_x, 0);
+ const int max_x = MIN(x + d_x, img->width_less_one_);
+
+ // Mark both above and below the center row.
+ bool* const top_row_start = (*img)[min_y] + min_x;
+ bool* const bottom_row_start = (*img)[max_y] + min_x;
+
+ const int x_width = max_x - min_x + 1;
+ memset(top_row_start, true, sizeof(*top_row_start) * x_width);
+ memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width);
+
+ // This row is marked, time to move on to the next row.
+ break;
+ }
+ }
+ }
+}
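+
+// For instance (a sketch, where interest_map is assumed to be a
+// std::unique_ptr<Image<bool> >): to keep future keypoints at least 5 pixels
+// away from an existing keypoint at (x, y):
+//   MarkImage(x, y, 5, interest_map.get());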
+
+#ifdef __ARM_NEON
+void CalculateGNeon(
+ const float* const vals_x, const float* const vals_y,
+ const int num_vals, float* const G);
+#endif
+
+// Puts the image gradient matrix about a pixel into the 2x2 float array G.
+// vals_x should be an array of the window x gradient values, whose indices
+// can be in any order but are parallel to the vals_y entries.
+// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details.
+inline void CalculateG(const float* const vals_x, const float* const vals_y,
+ const int num_vals, float* const G) {
+#ifdef __ARM_NEON
+ CalculateGNeon(vals_x, vals_y, num_vals, G);
+ return;
+#endif
+
+ // Non-accelerated version.
+ for (int i = 0; i < num_vals; ++i) {
+ G[0] += Square(vals_x[i]);
+ G[1] += vals_x[i] * vals_y[i];
+ G[3] += Square(vals_y[i]);
+ }
+
+ // The matrix is symmetric, so this is a given.
+ G[2] = G[1];
+}
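+
+// In matrix form (a restatement of the loop above):
+//   | G[0] G[1] |   | sum(x*x)  sum(x*y) |
+//   | G[2] G[3] | = | sum(x*y)  sum(y*y) |
+// Note that callers must zero-initialize G, since the sums are accumulated
+// with +=, e.g.:
+//   float G[] = {0.0f, 0.0f, 0.0f, 0.0f};
+//   CalculateG(vals_x, vals_y, num_vals, G);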
+
+
+inline void CalculateGInt16(const int16* const vals_x,
+ const int16* const vals_y,
+ const int num_vals, int* const G) {
+ // Non-accelerated version.
+ for (int i = 0; i < num_vals; ++i) {
+ G[0] += Square(vals_x[i]);
+ G[1] += vals_x[i] * vals_y[i];
+ G[3] += Square(vals_y[i]);
+ }
+
+ // The matrix is symmetric, so this is a given.
+ G[2] = G[1];
+}
+
+
+// Puts the image gradient matrix about a pixel into the 2x2 float array G.
+// Looks up interpolated pixels, then calls above method for implementation.
+inline void CalculateG(const int window_radius,
+ const float center_x, const float center_y,
+ const Image<int32>& I_x, const Image<int32>& I_y,
+ float* const G) {
+ SCHECK(I_x.ValidPixel(center_x, center_y), "Problem in calculateG!");
+
+ // Hardcoded to allow for a max window radius of 5 (11 pixels x 11 pixels).
+ static const int kMaxWindowRadius = 5;
+ SCHECK(window_radius <= kMaxWindowRadius,
+ "Window %d > %d!", window_radius, kMaxWindowRadius);
+
+ // Diameter of window is 2 * radius + 1 for center pixel.
+ static const int kWindowBufferSize =
+ (kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1);
+
+ // Preallocate buffers statically for efficiency.
+ static int16 vals_x[kWindowBufferSize];
+ static int16 vals_y[kWindowBufferSize];
+
+ const int src_left_fixed = RealToFixed1616(center_x - window_radius);
+ const int src_top_fixed = RealToFixed1616(center_y - window_radius);
+
+ int16* vals_x_ptr = vals_x;
+ int16* vals_y_ptr = vals_y;
+
+ const int window_size = 2 * window_radius + 1;
+ for (int y = 0; y < window_size; ++y) {
+ const int fp_y = src_top_fixed + (y << 16);
+
+ for (int x = 0; x < window_size; ++x) {
+ const int fp_x = src_left_fixed + (x << 16);
+
+ *vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
+ *vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
+ }
+ }
+
+ int32 g_temp[] = {0, 0, 0, 0};
+ CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp);
+
+ for (int i = 0; i < 4; ++i) {
+ G[i] = g_temp[i];
+ }
+}
+
+inline float ImageCrossCorrelation(const Image<float>& image1,
+ const Image<float>& image2,
+ const int x_offset, const int y_offset) {
+ SCHECK(image1.GetWidth() == image2.GetWidth() &&
+ image1.GetHeight() == image2.GetHeight(),
+ "Dimension mismatch! %dx%d vs %dx%d",
+ image1.GetWidth(), image1.GetHeight(),
+ image2.GetWidth(), image2.GetHeight());
+
+ const int num_pixels = image1.GetWidth() * image1.GetHeight();
+ const float* data1 = image1.data();
+ const float* data2 = image2.data();
+ return ComputeCrossCorrelation(data1, data2, num_pixels);
+}
+
+// Copies an arbitrary region of an image to another (floating point)
+// image, scaling as it goes using bilinear interpolation.
+inline void CopyArea(const Image<uint8>& image,
+ const BoundingBox& area_to_copy,
+ Image<float>* const patch_image) {
+ VLOG(2) << "Copying from: " << area_to_copy << std::endl;
+
+ const int patch_width = patch_image->GetWidth();
+ const int patch_height = patch_image->GetHeight();
+
+ // Guard against division by zero when the patch is only one pixel wide or
+ // tall.
+ const float x_dist_between_samples = patch_width > 1 ?
+ area_to_copy.GetWidth() / (patch_width - 1) : 0;
+
+ const float y_dist_between_samples = patch_height > 1 ?
+ area_to_copy.GetHeight() / (patch_height - 1) : 0;
+
+ for (int y_index = 0; y_index < patch_height; ++y_index) {
+ const float sample_y =
+ y_index * y_dist_between_samples + area_to_copy.top_;
+
+ for (int x_index = 0; x_index < patch_width; ++x_index) {
+ const float sample_x =
+ x_index * x_dist_between_samples + area_to_copy.left_;
+
+ if (image.ValidInterpPixel(sample_x, sample_y)) {
+ // TODO(andrewharp): Do area averaging when downsampling.
+ (*patch_image)[y_index][x_index] =
+ image.GetPixelInterp(sample_x, sample_y);
+ } else {
+ (*patch_image)[y_index][x_index] = -1.0f;
+ }
+ }
+ }
+}
+
+
+// Takes a floating point image and normalizes it in-place.
+//
+// First, negative values will be set to the mean of the non-negative pixels
+// in the image.
+//
+// Then, the resulting image will be normalized such that it has a mean value
+// of 0.0 and a standard deviation of 1.0.
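+//
+// For example (a hypothetical 3-pixel image): values {2.0, -1.0, 4.0} have a
+// non-negative mean of 3.0; after mean subtraction (with the negative pixel
+// mapped to the new mean, 0.0) the data is {-1.0, 0.0, 1.0}, and dividing by
+// the standard deviation sqrt(2/3) yields roughly {-1.22, 0.0, 1.22}.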
+inline void NormalizeImage(Image<float>* const image) {
+ const float* const data_ptr = image->data();
+
+ // Copy only the non-negative values to some temp memory.
+ float running_sum = 0.0f;
+ int num_data_gte_zero = 0;
+ {
+ float* const curr_data = (*image)[0];
+ for (int i = 0; i < image->data_size_; ++i) {
+ if (curr_data[i] >= 0.0f) {
+ running_sum += curr_data[i];
+ ++num_data_gte_zero;
+ } else {
+ curr_data[i] = -1.0f;
+ }
+ }
+ }
+
+ // If none of the pixels are valid, just set the entire thing to 0.0f.
+ if (num_data_gte_zero == 0) {
+ image->Clear(0.0f);
+ return;
+ }
+
+ const float corrected_mean = running_sum / num_data_gte_zero;
+
+ float* curr_data = (*image)[0];
+ for (int i = 0; i < image->data_size_; ++i) {
+ const float curr_val = *curr_data;
+ *curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean;
+ }
+
+ const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f);
+
+ if (std_dev > 0.0f) {
+ curr_data = (*image)[0];
+ for (int i = 0; i < image->data_size_; ++i) {
+ *curr_data++ /= std_dev;
+ }
+
+#ifdef SANITY_CHECKS
+ LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev);
+ const float correlation =
+ ComputeCrossCorrelation(image->data(),
+ image->data(),
+ image->data_size_);
+
+ if (std::abs(correlation - 1.0f) > EPSILON) {
+ LOG(ERROR) << "Bad image!" << std::endl;
+ LOG(ERROR) << *image << std::endl;
+ }
+
+ SCHECK(std::abs(correlation - 1.0f) < EPSILON,
+ "Correlation wasn't 1.0f: %.10f", correlation);
+#endif
+ }
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/integral_image.h b/tensorflow/examples/android/jni/object_tracking/integral_image.h
new file mode 100755
index 0000000000..28b2045572
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/integral_image.h
@@ -0,0 +1,187 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+typedef uint8 Code;
+
+class IntegralImage: public Image<uint32> {
+ public:
+ explicit IntegralImage(const Image<uint8>& image_base) :
+ Image<uint32>(image_base.GetWidth(), image_base.GetHeight()) {
+ Recompute(image_base);
+ }
+
+ IntegralImage(const int width, const int height) :
+ Image<uint32>(width, height) {}
+
+ void Recompute(const Image<uint8>& image_base) {
+ SCHECK(image_base.GetWidth() == GetWidth() &&
+ image_base.GetHeight() == GetHeight(), "Dimensions don't match!");
+
+ // Sum along first row.
+ {
+ int x_sum = 0;
+ for (int x = 0; x < image_base.GetWidth(); ++x) {
+ x_sum += image_base[0][x];
+ (*this)[0][x] = x_sum;
+ }
+ }
+
+ // Sum everything else.
+ for (int y = 1; y < image_base.GetHeight(); ++y) {
+ uint32* curr_sum = (*this)[y];
+
+ // Previously summed pointers.
+ const uint32* up_one = (*this)[y - 1];
+
+ // Current value pointer.
+ const uint8* curr_delta = image_base[y];
+
+ uint32 row_till_now = 0;
+
+ for (int x = 0; x < GetWidth(); ++x) {
+ // Add the one above and the one to the left.
+ row_till_now += *curr_delta;
+ *curr_sum = *up_one + row_till_now;
+
+ // Scoot everything along.
+ ++curr_sum;
+ ++up_one;
+ ++curr_delta;
+ }
+ }
+
+ SCHECK(VerifyData(image_base), "Images did not match!");
+ }
+
+ bool VerifyData(const Image<uint8>& image_base) {
+ for (int y = 0; y < GetHeight(); ++y) {
+ for (int x = 0; x < GetWidth(); ++x) {
+ uint32 curr_val = (*this)[y][x];
+
+ if (x > 0) {
+ curr_val -= (*this)[y][x - 1];
+ }
+
+ if (y > 0) {
+ curr_val -= (*this)[y - 1][x];
+ }
+
+ if (x > 0 && y > 0) {
+ curr_val += (*this)[y - 1][x - 1];
+ }
+
+ if (curr_val != image_base[y][x]) {
+ LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]);
+ return false;
+ }
+
+ if (GetRegionSum(x, y, x, y) != curr_val) {
+ LOGE("Mismatch!");
+ }
+ }
+ }
+
+ return true;
+ }
+
+ // Returns the sum of all pixels in the specified region.
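+ //
+ // For example (hypothetical coordinates): for the region with corners
+ // (x1, y1) = (2, 3) and (x2, y2) = (5, 7), the pixel sum is recovered with
+ // four lookups into this integral image I:
+ //   sum = I[7][5] - I[7][1] - I[2][5] + I[2][1]
+ // i.e. everything up to the bottom-right corner, minus the strip to the
+ // left and the strip above, plus the doubly-subtracted top-left block.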
+ inline uint32 GetRegionSum(const int x1, const int y1,
+ const int x2, const int y2) const {
+ SCHECK(x1 >= 0 && y1 >= 0 &&
+ x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(),
+ "indices out of bounds! %d-%d / %d, %d-%d / %d, ",
+ x1, x2, GetWidth(), y1, y2, GetHeight());
+
+ const uint32 everything = (*this)[y2][x2];
+
+ uint32 sum = everything;
+ if (x1 > 0 && y1 > 0) {
+ // Most common case.
+ const uint32 left = (*this)[y2][x1 - 1];
+ const uint32 top = (*this)[y1 - 1][x2];
+ const uint32 top_left = (*this)[y1 - 1][x1 - 1];
+
+ sum = everything - left - top + top_left;
+ SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d",
+ everything, left, top, top_left, sum, x1, y1, x2, y2);
+ } else if (x1 > 0) {
+ // Flush against top of image.
+ // Subtract out the region to the left only.
+ const uint32 left = (*this)[y2][x1 - 1];
+ sum = everything - left;
+ SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum);
+ } else if (y1 > 0) {
+ // Flush against left side of image.
+ // Subtract out the region above only.
+ const uint32 top = (*this)[y1 - 1][x2];
+ sum = everything - top;
+ SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum);
+ }
+
+ SCHECK(sum >= 0, "Negative sum!");
+
+ return sum;
+ }
+
+ // Returns the 2-bit code associated with this region, which represents
+ // the overall gradient.
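+ //
+ // For example (a sketch): a region brighter in its top half and in its
+ // left half yields vertical_code = 1 and horizontal_code = 1, i.e. code 3;
+ // a region brighter toward the bottom-right yields code 0.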
+ inline Code GetCode(const BoundingBox& bounding_box) const {
+ return GetCode(bounding_box.left_, bounding_box.top_,
+ bounding_box.right_, bounding_box.bottom_);
+ }
+
+ inline Code GetCode(const int x1, const int y1,
+ const int x2, const int y2) const {
+ SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d",
+ x1, y1, x2, y2);
+
+ // Gradient computed vertically.
+ const int box_height = (y2 - y1) / 2;
+ const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height);
+ const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2);
+ const bool vertical_code = top_sum > bottom_sum;
+
+ // Gradient computed horizontally.
+ const int box_width = (x2 - x1) / 2;
+ const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2);
+ const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2);
+ const bool horizontal_code = left_sum > right_sum;
+
+ const Code final_code = (vertical_code << 1) | horizontal_code;
+
+ SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)),
+ "Invalid code! %d", final_code);
+
+ // Returns a value 0-3.
+ return final_code;
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
new file mode 100644
index 0000000000..92458536b6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
@@ -0,0 +1,62 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
+
+#include <android/log.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+// The JniIntField class is used to access Java fields from native code. This
+// technique of hiding pointers to native objects in opaque Java fields is how
+// the Android hardware libraries work. This reduces the number of static
+// native methods and makes it easier to manage the lifetime of native objects.
+class JniIntField {
+ public:
+ JniIntField(const char* field_name) : field_name_(field_name), field_ID_(0) {}
+
+ int get(JNIEnv* env, jobject thiz) {
+ if (field_ID_ == 0) {
+ jclass cls = env->GetObjectClass(thiz);
+ CHECK_ALWAYS(cls != 0, "Unable to find class");
+ field_ID_ = env->GetFieldID(cls, field_name_, "I");
+ CHECK_ALWAYS(field_ID_ != 0,
+ "Unable to find field %s. (Check proguard cfg)", field_name_);
+ }
+
+ return env->GetIntField(thiz, field_ID_);
+ }
+
+ void set(JNIEnv* env, jobject thiz, int value) {
+ if (field_ID_ == 0) {
+ jclass cls = env->GetObjectClass(thiz);
+ CHECK_ALWAYS(cls != 0, "Unable to find class");
+ field_ID_ = env->GetFieldID(cls, field_name_, "I");
+ CHECK_ALWAYS(field_ID_ != 0,
+ "Unable to find field %s (Check proguard cfg)", field_name_);
+ }
+
+ env->SetIntField(thiz, field_ID_, value);
+ }
+
+ private:
+ const char* const field_name_;
+
+ // This is just a cache of the field ID looked up on first access.
+ jfieldID field_ID_;
+};
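+
+// Example usage from a JNI method (a sketch; "nativeHandle" is a
+// hypothetical int field declared on the Java side and listed in the
+// proguard config):
+//   static JniIntField native_handle_field("nativeHandle");
+//   const int handle = native_handle_field.get(env, thiz);
+//   native_handle_field.set(env, thiz, handle + 1);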
+
+#endif
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint.h b/tensorflow/examples/android/jni/object_tracking/keypoint.h
new file mode 100644
index 0000000000..82917261cb
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint.h
@@ -0,0 +1,48 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+
+namespace tf_tracking {
+
+// For keeping track of keypoints.
+struct Keypoint {
+ Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {}
+ Keypoint(const float x, const float y)
+ : pos_(x, y), score_(0.0f), type_(0) {}
+
+ Point2f pos_;
+ float score_;
+ uint8 type_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const Keypoint& keypoint) {
+ return stream << "[" << keypoint.pos_ << ", "
+ << keypoint.score_ << ", " << keypoint.type_ << "]";
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc
new file mode 100644
index 0000000000..6cc6b4e73f
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc
@@ -0,0 +1,549 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Various keypoint detecting functions.
+
+#include <float.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+
+namespace tf_tracking {
+
+static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) {
+ return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]);
+}
+
+void KeypointDetector::ScoreKeypoints(const ImageData& image_data,
+ const int num_candidates,
+ Keypoint* const candidate_keypoints) {
+ const Image<int>& I_x = *image_data.GetSpatialX(0);
+ const Image<int>& I_y = *image_data.GetSpatialY(0);
+
+ if (config_->detect_skin) {
+ const Image<uint8>& u_data = *image_data.GetU();
+ const Image<uint8>& v_data = *image_data.GetV();
+
+ static const int reference[] = {111, 155};
+
+ // Score all the keypoints.
+ for (int i = 0; i < num_candidates; ++i) {
+ Keypoint* const keypoint = candidate_keypoints + i;
+
+ const int x_pos = keypoint->pos_.x * 2;
+ const int y_pos = keypoint->pos_.y * 2;
+
+ const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]};
+ keypoint->score_ =
+ HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) /
+ GetDistSquaredBetween(reference, curr_color);
+ }
+ } else {
+ // Score all the keypoints.
+ for (int i = 0; i < num_candidates; ++i) {
+ Keypoint* const keypoint = candidate_keypoints + i;
+ keypoint->score_ =
+ HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y);
+ }
+ }
+}
+
+
+inline int KeypointCompare(const void* const a, const void* const b) {
+ return (reinterpret_cast<const Keypoint*>(a)->score_ -
+ reinterpret_cast<const Keypoint*>(b)->score_) <= 0 ? 1 : -1;
+}
+
+
+// Quicksorts detected keypoints by score.
+void KeypointDetector::SortKeypoints(const int num_candidates,
+ Keypoint* const candidate_keypoints) const {
+ qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare);
+
+#ifdef SANITY_CHECKS
+ // Verify that the array got sorted.
+ float last_score = FLT_MAX;
+ for (int i = 0; i < num_candidates; ++i) {
+ const float curr_score = candidate_keypoints[i].score_;
+
+ // Scores should be monotonically decreasing.
+ SCHECK(last_score >= curr_score,
+ "Quicksort failure! %d: %.5f > %d: %.5f (%d total)",
+ i - 1, last_score, i, curr_score, num_candidates);
+
+ last_score = curr_score;
+ }
+#endif
+}
+
+
+int KeypointDetector::SelectKeypointsInBox(
+ const BoundingBox& box,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ const int max_keypoints,
+ const int num_existing_keypoints,
+ const Keypoint* const existing_keypoints,
+ Keypoint* const final_keypoints) const {
+ if (max_keypoints <= 0) {
+ return 0;
+ }
+
+ // This is the minimum distance that keypoints within this box must keep
+ // from one another, roughly based on the box dimensions.
+ const int distance =
+ MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f);
+
+ // First, mark keypoints that already happen to be inside this region. Ignore
+ // keypoints that are outside it, however close they might be.
+ interest_map_->Clear(false);
+ for (int i = 0; i < num_existing_keypoints; ++i) {
+ const Keypoint& candidate = existing_keypoints[i];
+
+ const int x_pos = candidate.pos_.x;
+ const int y_pos = candidate.pos_.y;
+ if (box.Contains(candidate.pos_)) {
+ MarkImage(x_pos, y_pos, distance, interest_map_.get());
+ }
+ }
+
+ // Now, go through and check which keypoints will still fit in the box.
+ int num_keypoints_selected = 0;
+ for (int i = 0; i < num_candidates; ++i) {
+ const Keypoint& candidate = candidate_keypoints[i];
+
+ const int x_pos = candidate.pos_.x;
+ const int y_pos = candidate.pos_.y;
+
+ if (!box.Contains(candidate.pos_) ||
+ !interest_map_->ValidPixel(x_pos, y_pos)) {
+ continue;
+ }
+
+ if (!(*interest_map_)[y_pos][x_pos]) {
+ final_keypoints[num_keypoints_selected++] = candidate;
+ if (num_keypoints_selected >= max_keypoints) {
+ break;
+ }
+ MarkImage(x_pos, y_pos, distance, interest_map_.get());
+ }
+ }
+ return num_keypoints_selected;
+}
+
+
+void KeypointDetector::SelectKeypoints(
+ const std::vector<BoundingBox>& boxes,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ FramePair* const curr_change) const {
+ // Now select all the interesting keypoints that fall inside our boxes.
+ curr_change->number_of_keypoints_ = 0;
+ for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
+ iter != boxes.end(); ++iter) {
+ const BoundingBox bounding_box = *iter;
+
+ // Count up keypoints that have already been selected, and fall within our
+ // box.
+ int num_keypoints_already_in_box = 0;
+ for (int i = 0; i < curr_change->number_of_keypoints_; ++i) {
+ if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) {
+ ++num_keypoints_already_in_box;
+ }
+ }
+
+ const int max_keypoints_to_find_in_box =
+ MIN(kMaxKeypointsForObject - num_keypoints_already_in_box,
+ kMaxKeypoints - curr_change->number_of_keypoints_);
+
+ const int num_new_keypoints_in_box = SelectKeypointsInBox(
+ bounding_box,
+ candidate_keypoints,
+ num_candidates,
+ max_keypoints_to_find_in_box,
+ curr_change->number_of_keypoints_,
+ curr_change->frame1_keypoints_,
+ curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_);
+
+ curr_change->number_of_keypoints_ += num_new_keypoints_in_box;
+
+ LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_);
+ }
+}
+
+
+// Walks along the given circle checking for pixels above or below the center.
+// Returns a score, or 0 if the keypoint did not pass the criteria.
+//
+// Parameters:
+// circle_perimeter: the circumference in pixels of the circle.
+// threshold: the minimum number of contiguous pixels that must be above or
+// below the center value.
+// center_ptr: the location of the center pixel in memory
+// offsets: the relative offsets from the center pixel of the edge pixels.
+inline int TestCircle(const int circle_perimeter, const int threshold,
+ const uint8* const center_ptr,
+ const int* offsets) {
+ // Get the actual value of the center pixel for easier reference later on.
+ const int center_value = static_cast<int>(*center_ptr);
+
+ // Number of total pixels to check. Have to wrap around some in case
+ // the contiguous section is split by the array edges.
+ const int num_total = circle_perimeter + threshold - 1;
+
+ int num_above = 0;
+ int above_diff = 0;
+
+ int num_below = 0;
+ int below_diff = 0;
+
+ // Used to tell when this is definitely not going to meet the threshold so we
+ // can early abort.
+ int minimum_by_now = threshold - num_total + 1;
+
+ // Go through every pixel along the perimeter of the circle, and then around
+ // again a little bit.
+ for (int i = 0; i < num_total; ++i) {
+ // This should be faster than mod.
+ const int perim_index = i < circle_perimeter ? i : i - circle_perimeter;
+
+ // This gets the value of the current pixel along the perimeter by using
+ // a precomputed offset.
+ const int curr_value =
+ static_cast<int>(center_ptr[offsets[perim_index]]);
+
+ const int difference = curr_value - center_value;
+
+ if (difference > kFastDiffAmount) {
+ above_diff += difference;
+ ++num_above;
+
+ num_below = 0;
+ below_diff = 0;
+
+ if (num_above >= threshold) {
+ return above_diff;
+ }
+ } else if (difference < -kFastDiffAmount) {
+ below_diff += difference;
+ ++num_below;
+
+ num_above = 0;
+ above_diff = 0;
+
+ if (num_below >= threshold) {
+ return below_diff;
+ }
+ } else {
+ num_above = 0;
+ num_below = 0;
+ above_diff = 0;
+ below_diff = 0;
+ }
+
+ // See if there's any chance of making the threshold.
+ if (MAX(num_above, num_below) < minimum_by_now) {
+ // Didn't pass.
+ return 0;
+ }
+ ++minimum_by_now;
+ }
+
+ // Didn't pass.
+ return 0;
+}
+
+
+// Returns a score in the range [0.0, positive infinity) which represents the
+// relative likelihood of a point being a corner.
+float KeypointDetector::HarrisFilter(const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const float x, const float y) const {
+ if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) &&
+ I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) {
+ // Image gradient matrix.
+ float G[] = { 0, 0, 0, 0 };
+ CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G);
+
+ const float dx = G[0];
+ const float dy = G[3];
+ const float dxy = G[1];
+
+ // Harris-Nobel corner score: det(G) / trace(G), with FLT_MIN guarding
+ // against division by zero.
+ return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN);
+ }
+
+ return 0.0f;
+}
+
+
+int KeypointDetector::AddExtraCandidatesForBoxes(
+ const std::vector<BoundingBox>& boxes,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) const {
+ int num_keypoints_added = 0;
+
+ for (std::vector<BoundingBox>::const_iterator iter = boxes.begin();
+ iter != boxes.end(); ++iter) {
+ const BoundingBox box = *iter;
+
+ for (int i = 0; i < kNumToAddAsCandidates; ++i) {
+ for (int j = 0; j < kNumToAddAsCandidates; ++j) {
+ if (num_keypoints_added >= max_num_keypoints) {
+ LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints);
+ return num_keypoints_added;
+ }
+
+ // Take a reference so the updates below land in the output array.
+ Keypoint& curr_keypoint = keypoints[num_keypoints_added++];
+ curr_keypoint.pos_ = Point2f(
+ box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates,
+ box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates);
+ curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST;
+ }
+ }
+ }
+
+ return num_keypoints_added;
+}
+
+
+void KeypointDetector::FindKeypoints(const ImageData& image_data,
+ const std::vector<BoundingBox>& rois,
+ const FramePair& prev_change,
+ FramePair* const curr_change) {
+ // Copy keypoints from second frame of last pass to temp keypoints of this
+ // pass.
+ int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_);
+
+ const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints;
+ number_of_tmp_keypoints +=
+ FindFastKeypoints(image_data, max_num_fast,
+ tmp_keypoints_ + number_of_tmp_keypoints);
+
+ TimeLog("Found FAST keypoints");
+
+ if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
+ LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints",
+ kMaxTempKeypoints, number_of_tmp_keypoints);
+ }
+
+ if (kAddArbitraryKeypoints) {
+ // Add some for each object prior to scoring.
+ const int max_num_box_keypoints =
+ kMaxTempKeypoints - number_of_tmp_keypoints;
+ number_of_tmp_keypoints +=
+ AddExtraCandidatesForBoxes(rois, max_num_box_keypoints,
+ tmp_keypoints_ + number_of_tmp_keypoints);
+ TimeLog("Added box keypoints");
+
+ if (number_of_tmp_keypoints >= kMaxTempKeypoints) {
+ LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints",
+ kMaxTempKeypoints, number_of_tmp_keypoints);
+ }
+ }
+
+ // Score them...
+ LOGV("Scoring %d keypoints!", number_of_tmp_keypoints);
+ ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_);
+ TimeLog("Scored keypoints");
+
+ // Now pare it down a bit.
+ SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_);
+ TimeLog("Sorted keypoints");
+
+ LOGV("%d keypoints to select from!", number_of_tmp_keypoints);
+
+ SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change);
+ TimeLog("Selected keypoints");
+
+ LOGV("Picked %d (%d max) final keypoints out of %d potential.",
+ curr_change->number_of_keypoints_,
+ kMaxKeypoints, number_of_tmp_keypoints);
+}
+
+
+int KeypointDetector::CopyKeypoints(const FramePair& prev_change,
+ Keypoint* const new_keypoints) {
+ int number_of_keypoints = 0;
+
+ // Caching values from last pass, just copy and compact.
+ for (int i = 0; i < prev_change.number_of_keypoints_; ++i) {
+ if (prev_change.optical_flow_found_keypoint_[i]) {
+ new_keypoints[number_of_keypoints] =
+ prev_change.frame2_keypoints_[i];
+
+ new_keypoints[number_of_keypoints].score_ =
+ prev_change.frame1_keypoints_[i].score_;
+
+ ++number_of_keypoints;
+ }
+ }
+
+ TimeLog("Copied keypoints");
+ return number_of_keypoints;
+}
+
+
+// FAST keypoint detector.
+int KeypointDetector::FindFastKeypoints(const Image<uint8>& frame,
+ const int quadrant,
+ const int downsample_factor,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) {
+ /*
+ // Reference for a circle of diameter 7.
+ const int circle[] = {0, 0, 1, 1, 1, 0, 0,
+ 0, 1, 0, 0, 0, 1, 0,
+ 1, 0, 0, 0, 0, 0, 1,
+ 1, 0, 0, 0, 0, 0, 1,
+ 1, 0, 0, 0, 0, 0, 1,
+ 0, 1, 0, 0, 0, 1, 0,
+ 0, 0, 1, 1, 1, 0, 0};
+ const int circle_offset[] =
+ {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46};
+ */
+
+ // Quick test of compass directions. Any length 16 circle with a break of up
+ // to 4 pixels will have at least 3 of these 4 pixels active.
+ static const int short_circle_perimeter = 4;
+ static const int short_threshold = 3;
+ static const int short_circle_x[] = { -3, 0, +3, 0 };
+ static const int short_circle_y[] = { 0, -3, 0, +3 };
+
+ // Precompute image offsets.
+ int short_offsets[short_circle_perimeter];
+ for (int i = 0; i < short_circle_perimeter; ++i) {
+ short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth();
+ }
+
+ // Large circle values.
+ static const int full_circle_perimeter = 16;
+ static const int full_threshold = 12;
+ static const int full_circle_x[] =
+ { -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 };
+ static const int full_circle_y[] =
+ { -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 };
+
+ // Precompute image offsets.
+ int full_offsets[full_circle_perimeter];
+ for (int i = 0; i < full_circle_perimeter; ++i) {
+ full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth();
+ }
+
+ const int scratch_stride = frame.stride();
+
+ keypoint_scratch_->Clear(0);
+
+ // Set up the bounds on the region to test based on the passed-in quadrant.
+ const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer;
+ const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer;
+ const int start_x =
+ kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width);
+ const int start_y =
+ kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height);
+ const int end_x = start_x + quadrant_width;
+ const int end_y = start_y + quadrant_height;
+
+ // Loop through once to find FAST keypoint clumps.
+ for (int img_y = start_y; img_y < end_y; ++img_y) {
+ const uint8* curr_pixel_ptr = frame[img_y] + start_x;
+
+ for (int img_x = start_x; img_x < end_x; ++img_x) {
+ // Only insert it if it meets the quick minimum requirements test.
+ if (TestCircle(short_circle_perimeter, short_threshold,
+ curr_pixel_ptr, short_offsets) != 0) {
+ // Longer test for actual keypoint score.
+ const int fast_score = TestCircle(full_circle_perimeter,
+ full_threshold,
+ curr_pixel_ptr,
+ full_offsets);
+
+ // Non-zero score means the keypoint was found.
+ if (fast_score != 0) {
+ uint8* const center_ptr = (*keypoint_scratch_)[img_y] + img_x;
+
+ // Increase the keypoint count on this pixel and the pixels in all
+ // 4 cardinal directions.
+ *center_ptr += 5;
+ *(center_ptr - 1) += 1;
+ *(center_ptr + 1) += 1;
+ *(center_ptr - scratch_stride) += 1;
+ *(center_ptr + scratch_stride) += 1;
+ }
+ }
+
+ ++curr_pixel_ptr;
+ } // x
+ } // y
+
+ TimeLog("Found FAST keypoints.");
+
+ int num_keypoints = 0;
+ // Loop through again and Harris filter pixels in the center of clumps.
+ // We can shrink the window by 1 pixel on every side.
+ for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) {
+ const uint8* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x;
+
+ for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) {
+ if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) {
+ Keypoint* const keypoint = keypoints + num_keypoints;
+ keypoint->pos_ = Point2f(
+ img_x * downsample_factor, img_y * downsample_factor);
+ keypoint->score_ = 0;
+ keypoint->type_ = KEYPOINT_TYPE_FAST;
+
+ ++num_keypoints;
+ if (num_keypoints >= max_num_keypoints) {
+ return num_keypoints;
+ }
+ }
+
+ ++curr_pixel_ptr;
+ } // x
+ } // y
+
+ TimeLog("Picked FAST keypoints.");
+
+ return num_keypoints;
+}
+
+int KeypointDetector::FindFastKeypoints(const ImageData& image_data,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) {
+ int downsample_factor = 1;
+ int num_found = 0;
+
+ // TODO(andrewharp): Get this working for multiple image scales.
+ for (int i = 0; i < 1; ++i) {
+ const Image<uint8>& frame = *image_data.GetPyramidSqrt2Level(i);
+ num_found += FindFastKeypoints(
+ frame, fast_quadrant_,
+ downsample_factor, max_num_keypoints, keypoints + num_found);
+ downsample_factor *= 2;
+ }
+
+ // Increment the current quadrant.
+ fast_quadrant_ = (fast_quadrant_ + 1) % 4;
+
+ return num_found;
+}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h
new file mode 100644
index 0000000000..6cdd5dde11
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h
@@ -0,0 +1,133 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+struct Keypoint;
+
+class KeypointDetector {
+ public:
+ explicit KeypointDetector(const KeypointDetectorConfig* const config)
+ : config_(config),
+ keypoint_scratch_(new Image<uint8>(config_->image_size)),
+ interest_map_(new Image<bool>(config_->image_size)),
+ fast_quadrant_(0) {
+ interest_map_->Clear(false);
+ }
+
+ ~KeypointDetector() {}
+
+ // Finds a new set of keypoints for the current frame, picked from the current
+ // set of keypoints and also from a set discovered via a keypoint detector.
+ // Special attention is applied to make sure that keypoints are distributed
+ // within the supplied ROIs.
+ void FindKeypoints(const ImageData& image_data,
+ const std::vector<BoundingBox>& rois,
+ const FramePair& prev_change,
+ FramePair* const curr_change);
+
+ private:
+ // Compute the corneriness of a point in the image.
+ float HarrisFilter(const Image<int32>& I_x, const Image<int32>& I_y,
+ const float x, const float y) const;
+
+ // Adds a grid of candidate keypoints to the given box, up to
+ // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower.
+ int AddExtraCandidatesForBoxes(
+ const std::vector<BoundingBox>& boxes,
+ const int max_num_keypoints,
+ Keypoint* const keypoints) const;
+
+ // Scan the frame for potential keypoints using the FAST keypoint detector.
+ // Quadrant is an argument 0-3 which refers to the quadrant of the image in
+ // which to detect keypoints.
+ int FindFastKeypoints(const Image<uint8>& frame,
+ const int quadrant,
+ const int downsample_factor,
+ const int max_num_keypoints,
+ Keypoint* const keypoints);
+
+ int FindFastKeypoints(const ImageData& image_data,
+ const int max_num_keypoints,
+ Keypoint* const keypoints);
+
+ // Score a bunch of candidate keypoints. Assigns the scores to the input
+ // candidate_keypoints array entries.
+ void ScoreKeypoints(const ImageData& image_data,
+ const int num_candidates,
+ Keypoint* const candidate_keypoints);
+
+ void SortKeypoints(const int num_candidates,
+ Keypoint* const candidate_keypoints) const;
+
+ // Selects a set of keypoints falling within the supplied box such that the
+ // most highly rated keypoints are picked first, and so that none of them are
+ // too close together.
+ int SelectKeypointsInBox(
+ const BoundingBox& box,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ const int max_keypoints,
+ const int num_existing_keypoints,
+ const Keypoint* const existing_keypoints,
+ Keypoint* const final_keypoints) const;
+
+ // Selects from the supplied sorted keypoint pool a set of keypoints that will
+ // best cover the given set of boxes, such that each box is covered at a
+ // resolution proportional to its size.
+ void SelectKeypoints(
+ const std::vector<BoundingBox>& boxes,
+ const Keypoint* const candidate_keypoints,
+ const int num_candidates,
+ FramePair* const frame_change) const;
+
+ // Copies and compacts the found keypoints in the second frame of prev_change
+ // into the array at new_keypoints.
+ static int CopyKeypoints(const FramePair& prev_change,
+ Keypoint* const new_keypoints);
+
+ const KeypointDetectorConfig* const config_;
+
+ // Scratch memory for keypoint candidacy detection and non-max suppression.
+ std::unique_ptr<Image<uint8> > keypoint_scratch_;
+
+ // Regions of the image to pay special attention to.
+ std::unique_ptr<Image<bool> > interest_map_;
+
+ // The current quadrant of the image to detect FAST keypoints in.
+ // Keypoint detection is staggered for performance reasons. Every four frames
+ // a full scan of the frame will have been performed.
+ int fast_quadrant_;
+
+ Keypoint tmp_keypoints_[kMaxTempKeypoints];
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/log_streaming.h b/tensorflow/examples/android/jni/object_tracking/log_streaming.h
new file mode 100644
index 0000000000..e68945cc72
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/log_streaming.h
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+
+#include <string.h>
+#include <string>
+
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+#define LOGV(...)
+#define LOGD(...)
+#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__);
+#define LOGW(...) LOG(WARNING) << tensorflow::strings::Printf(__VA_ARGS__);
+#define LOGE(...) LOG(ERROR) << tensorflow::strings::Printf(__VA_ARGS__);
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.cc b/tensorflow/examples/android/jni/object_tracking/object_detector.cc
new file mode 100644
index 0000000000..7f65716fdf
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_detector.cc
@@ -0,0 +1,27 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NOTE: no native object detectors are currently provided or used by the code
+// in this directory. This class remains mainly for historical reasons.
+// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
+
+#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
+
+namespace tf_tracking {
+
+// This is here so that the vtable gets created properly.
+ObjectDetectorBase::~ObjectDetectorBase() {}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/examples/android/jni/object_tracking/object_detector.h
new file mode 100644
index 0000000000..043f606e1d
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_detector.h
@@ -0,0 +1,232 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NOTE: no native object detectors are currently provided or used by the code
+// in this directory. This class remains mainly for historical reasons.
+// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
+
+// Defines the ObjectDetector class that is the main interface for detecting
+// ObjectModelBases in frames.
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
+
+#include <float.h>
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#ifdef __RENDER_OPENGL__
+#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
+#endif
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
+
+namespace tf_tracking {
+
+// Adds BoundingSquares to a vector such that the first square added is
+// centered at the given position with side length starting_square_size, and
+// the remaining squares are added concentrically, scaling down by
+// scale_factor until the minimum threshold size is passed.
+// Squares that do not fall completely within image_bounds will not be added.
+static inline void FillWithSquares(
+ const BoundingBox& image_bounds,
+ const BoundingBox& position,
+ const float starting_square_size,
+ const float smallest_square_size,
+ const float scale_factor,
+ std::vector<BoundingSquare>* const squares) {
+ BoundingSquare descriptor_area =
+ GetCenteredSquare(position, starting_square_size);
+
+ SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor);
+
+ // Use a do/while loop to ensure that at least one descriptor is created.
+ do {
+ if (image_bounds.Contains(descriptor_area.ToBoundingBox())) {
+ squares->push_back(descriptor_area);
+ }
+ descriptor_area.Scale(scale_factor);
+ } while (descriptor_area.size_ >= smallest_square_size - EPSILON);
+ LOGV("Created %zu squares starting from size %.2f to min size %.2f "
+ "using scale factor: %.2f",
+ squares->size(), starting_square_size, smallest_square_size,
+ scale_factor);
+}
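+
+// Illustrative only (not part of the original source): with
+// starting_square_size = 64, smallest_square_size = 16 and scale_factor =
+// 0.5, the do/while loop above emits squares of size 64, 32 and 16, each
+// centered on `position`, then stops once the next size (8) drops below the
+// minimum:
+//
+//   std::vector<BoundingSquare> squares;
+//   FillWithSquares(image_bounds, position, 64.0f, 16.0f, 0.5f, &squares);
+//   // squares.size() == 3, assuming all three fit within image_bounds.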
+
+
+// Represents a potential detection of a specific ObjectExemplar and Descriptor
+// at a specific position in the image.
+class Detection {
+ public:
+ explicit Detection(const ObjectModelBase* const object_model,
+ const MatchScore match_score,
+ const BoundingBox& bounding_box)
+ : object_model_(object_model),
+ match_score_(match_score),
+ bounding_box_(bounding_box) {}
+
+ Detection(const Detection& other)
+ : object_model_(other.object_model_),
+ match_score_(other.match_score_),
+ bounding_box_(other.bounding_box_) {}
+
+ virtual ~Detection() {}
+
+ inline BoundingBox GetObjectBoundingBox() const {
+ return bounding_box_;
+ }
+
+ inline MatchScore GetMatchScore() const {
+ return match_score_;
+ }
+
+ inline const ObjectModelBase* GetObjectModel() const {
+ return object_model_;
+ }
+
+  inline bool Intersects(const Detection& other) const {
+    // Two detections intersect iff no axis separates their bounding boxes;
+    // the separating-axis check is delegated to BoundingBox::Intersects.
+    return bounding_box_.Intersects(other.bounding_box_);
+  }
+
+ struct Comp {
+ inline bool operator()(const Detection& a, const Detection& b) const {
+ return a.match_score_ > b.match_score_;
+ }
+ };
+
+ // TODO(andrewharp): add accessors to update these instead.
+ const ObjectModelBase* object_model_;
+ MatchScore match_score_;
+ BoundingBox bounding_box_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const Detection& detection) {
+ const BoundingBox actual_area = detection.GetObjectBoundingBox();
+ stream << actual_area;
+ return stream;
+}
+
+class ObjectDetectorBase {
+ public:
+ explicit ObjectDetectorBase(const ObjectDetectorConfig* const config)
+ : config_(config),
+ image_data_(NULL) {}
+
+ virtual ~ObjectDetectorBase();
+
+ // Sets the current image data. All calls to ObjectDetector other than
+ // FillDescriptors use the image data last set.
+ inline void SetImageData(const ImageData* const image_data) {
+ image_data_ = image_data;
+ }
+
+ // Main entry point into the detection algorithm.
+ // Scans the frame for candidates, tweaks them, and fills in the
+ // given std::vector of Detection objects with acceptable matches.
+ virtual void Detect(const std::vector<BoundingSquare>& positions,
+ std::vector<Detection>* const detections) const = 0;
+
+ virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0;
+
+ virtual void DeleteObjectModel(const std::string& name) = 0;
+
+ virtual void GetObjectModels(
+ std::vector<const ObjectModelBase*>* models) const = 0;
+
+  // Updates the given model with a new exemplar taken from bounding_box in
+  // the context of the last frame passed to NextFrame.
+  // No descriptor is created in the case that there's no room for one in
+  // the example area, or the example area is not completely contained
+  // within the frame.
+ virtual void UpdateModel(
+ const Image<uint8>& base_image,
+ const IntegralImage& integral_image,
+ const BoundingBox& bounding_box,
+ const bool locked,
+ ObjectModelBase* model) const = 0;
+
+ virtual void Draw() const = 0;
+
+ virtual bool AllowSpontaneousDetections() = 0;
+
+ protected:
+ const std::unique_ptr<const ObjectDetectorConfig> config_;
+
+ // The latest frame data, upon which all detections will be performed.
+ // Not owned by this object, just provided for reference by ObjectTracker
+ // via SetImageData().
+ const ImageData* image_data_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase);
+};
+
+template <typename ModelType>
+class ObjectDetector : public ObjectDetectorBase {
+ public:
+ explicit ObjectDetector(const ObjectDetectorConfig* const config)
+ : ObjectDetectorBase(config) {}
+
+ virtual ~ObjectDetector() {
+ typename std::map<std::string, ModelType*>::const_iterator it =
+ object_models_.begin();
+ for (; it != object_models_.end(); ++it) {
+ ModelType* model = it->second;
+ delete model;
+ }
+ }
+
+  virtual void DeleteObjectModel(const std::string& name) {
+    // Look the model up with find() so that a missing name is not
+    // default-inserted as NULL before the check below.
+    typename std::map<std::string, ModelType*>::iterator it =
+        object_models_.find(name);
+    CHECK_ALWAYS(it != object_models_.end(), "Model not found!");
+    ModelType* model = it->second;
+    object_models_.erase(it);
+    SAFE_DELETE(model);
+  }
+
+ virtual void GetObjectModels(
+ std::vector<const ObjectModelBase*>* models) const {
+ typename std::map<std::string, ModelType*>::const_iterator it =
+ object_models_.begin();
+ for (; it != object_models_.end(); ++it) {
+ models->push_back(it->second);
+ }
+ }
+
+ virtual bool AllowSpontaneousDetections() {
+ return false;
+ }
+
+ protected:
+ std::map<std::string, ModelType*> object_models_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector);
+};
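+
+// A minimal sketch (hypothetical, not part of this change) of how a concrete
+// detector would plug into the template above. `MyDetector` and `MyModel`
+// are placeholder names; MyModel would derive from ObjectModel<MyDetector>:
+//
+//   class MyDetector : public ObjectDetector<MyModel> {
+//    public:
+//     explicit MyDetector(const ObjectDetectorConfig* const config)
+//         : ObjectDetector<MyModel>(config) {}
+//
+//     ObjectModelBase* CreateObjectModel(const std::string& name) override {
+//       MyModel* const model = new MyModel(this, name);
+//       object_models_[name] = model;  // Owned (and deleted) by the base.
+//       return model;
+//     }
+//
+//     void Detect(const std::vector<BoundingSquare>& positions,
+//                 std::vector<Detection>* const detections) const override {
+//       // Score each candidate square against each registered model and
+//       // push acceptable matches into *detections.
+//     }
+//
+//     // ... UpdateModel() and Draw() overrides omitted for brevity.
+//   };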
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/examples/android/jni/object_tracking/object_model.h
new file mode 100644
index 0000000000..2d359668b2
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_model.h
@@ -0,0 +1,101 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NOTE: no native object detectors are currently provided or used by the code
+// in this directory. This class remains mainly for historical reasons.
+// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
+
+// Contains ObjectModelBase declaration.
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+
+#ifdef __RENDER_OPENGL__
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+#endif
+
+#include <vector>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#ifdef __RENDER_OPENGL__
+#include "tensorflow/examples/android/jni/object_tracking/sprite.h"
+#endif
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+
+namespace tf_tracking {
+
+// The ObjectModelBase class represents all the known appearance information
+// for an object. It is not a specific instance of the object in the world,
+// but just the general appearance information that enables detection. An
+// ObjectModelBase can be reused across multiple instances of TrackedObject.
+class ObjectModelBase {
+ public:
+  explicit ObjectModelBase(const std::string& name) : name_(name) {}
+
+ virtual ~ObjectModelBase() {}
+
+ // Called when the next step in an ongoing track occurs.
+ virtual void TrackStep(
+ const BoundingBox& position, const Image<uint8>& image,
+ const IntegralImage& integral_image, const bool authoritative) {}
+
+ // Called when an object track is lost.
+ virtual void TrackLost() {}
+
+ // Called when an object track is confirmed as legitimate.
+ virtual void TrackConfirmed() {}
+
+ virtual float GetMaxCorrelation(const Image<float>& patch_image) const = 0;
+
+ virtual MatchScore GetMatchScore(
+ const BoundingBox& position, const ImageData& image_data) const = 0;
+
+ virtual void Draw(float* const depth) const = 0;
+
+ inline const std::string& GetName() const {
+ return name_;
+ }
+
+ protected:
+ const std::string name_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase);
+};
+
+template <typename DetectorType>
+class ObjectModel : public ObjectModelBase {
+ public:
+  ObjectModel(const DetectorType* const detector,
+              const std::string& name)
+      : ObjectModelBase(name), detector_(detector) {}
+
+ protected:
+ const DetectorType* const detector_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel<DetectorType>);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc
new file mode 100644
index 0000000000..1d867b934b
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc
@@ -0,0 +1,690 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef __RENDER_OPENGL__
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+#endif
+
+#include <string>
+#include <map>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+namespace tf_tracking {
+
+ObjectTracker::ObjectTracker(const TrackerConfig* const config,
+ ObjectDetectorBase* const detector)
+ : config_(config),
+ frame_width_(config->image_size.width),
+ frame_height_(config->image_size.height),
+ curr_time_(0),
+ num_frames_(0),
+ flow_cache_(&config->flow_config),
+ keypoint_detector_(&config->keypoint_detector_config),
+ curr_num_frame_pairs_(0),
+ first_frame_index_(0),
+ frame1_(new ImageData(frame_width_, frame_height_)),
+ frame2_(new ImageData(frame_width_, frame_height_)),
+ detector_(detector),
+ num_detected_(0) {
+ for (int i = 0; i < kNumFrames; ++i) {
+ frame_pairs_[i].Init(-1, -1);
+ }
+}
+
+
+ObjectTracker::~ObjectTracker() {
+ for (TrackedObjectMap::iterator iter = objects_.begin();
+ iter != objects_.end(); iter++) {
+ TrackedObject* object = iter->second;
+ SAFE_DELETE(object);
+ }
+}
+
+
+// Finds the correspondences for all the points in the current pair of frames.
+// Stores the results in the given FramePair.
+void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const {
+  // Mark all keypoints as not yet found; FindNewPositionOfPoint below will
+  // flip the flag for each correspondence it finds.
+ memset(frame_pair->optical_flow_found_keypoint_, false,
+ sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints);
+ TimeLog("Cleared old found keypoints");
+
+ int num_keypoints_found = 0;
+
+ // For every keypoint...
+ for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) {
+ Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat;
+ Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat;
+
+ if (flow_cache_.FindNewPositionOfPoint(
+ keypoint1->pos_.x, keypoint1->pos_.y,
+ &keypoint2->pos_.x, &keypoint2->pos_.y)) {
+ frame_pair->optical_flow_found_keypoint_[i_feat] = true;
+ ++num_keypoints_found;
+ }
+ }
+
+ TimeLog("Found correspondences");
+
+ LOGV("Found %d of %d keypoint correspondences",
+ num_keypoints_found, frame_pair->number_of_keypoints_);
+}
+
+
+void ObjectTracker::NextFrame(const uint8* const new_frame,
+ const uint8* const uv_frame,
+ const int64 timestamp,
+ const float* const alignment_matrix_2x3) {
+ IncrementFrameIndex();
+ LOGV("Received frame %d", num_frames_);
+
+ FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0);
+ curr_change->Init(curr_time_, timestamp);
+
+ CHECK_ALWAYS(curr_time_ < timestamp,
+ "Timestamp must monotonically increase! Went from %lld to %lld"
+ " on frame %d.",
+ curr_time_, timestamp, num_frames_);
+ curr_time_ = timestamp;
+
+ // Swap the frames.
+ frame1_.swap(frame2_);
+
+ frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1);
+
+ if (detector_.get() != NULL) {
+ detector_->SetImageData(frame2_.get());
+ }
+
+ flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3);
+
+ if (num_frames_ == 1) {
+    // This is the first frame, so there is no previous frame to track
+    // against yet; return early.
+ return;
+ }
+
+ if (config_->always_track || objects_.size() > 0) {
+ LOGV("Tracking %zu targets", objects_.size());
+ ComputeKeypoints(true);
+ TimeLog("Keypoints computed!");
+
+ FindCorrespondences(curr_change);
+ TimeLog("Flow computed!");
+
+ TrackObjects();
+ }
+ TimeLog("Targets tracked!");
+
+ if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) {
+ DetectTargets();
+ }
+ TimeLog("Detected objects.");
+}
+
+
+TrackedObject* ObjectTracker::MaybeAddObject(
+ const std::string& id,
+ const Image<uint8>& source_image,
+ const BoundingBox& bounding_box,
+ const ObjectModelBase* object_model) {
+ // Train the detector if this is a new object.
+ if (objects_.find(id) != objects_.end()) {
+ return objects_[id];
+ }
+
+ // Need to get a non-const version of the model, or create a new one if it
+ // wasn't given.
+ ObjectModelBase* model = NULL;
+ if (detector_ != NULL) {
+ // If a detector is registered, then this new object must have a model.
+ CHECK_ALWAYS(object_model != NULL, "No model given!");
+ model = detector_->CreateObjectModel(object_model->GetName());
+ }
+ TrackedObject* const object =
+ new TrackedObject(id, source_image, bounding_box, model);
+
+ objects_[id] = object;
+ return object;
+}
+
+
+void ObjectTracker::RegisterNewObjectWithAppearance(
+ const std::string& id, const uint8* const new_frame,
+ const BoundingBox& bounding_box) {
+ ObjectModelBase* object_model = NULL;
+
+ Image<uint8> image(frame_width_, frame_height_);
+ image.FromArray(new_frame, frame_width_, 1);
+
+ if (detector_ != NULL) {
+ object_model = detector_->CreateObjectModel(id);
+ CHECK_ALWAYS(object_model != NULL, "Null object model!");
+
+ const IntegralImage integral_image(image);
+ object_model->TrackStep(bounding_box, image, integral_image, true);
+ }
+
+ // Create an object at this position.
+ CHECK_ALWAYS(!HaveObject(id), "Already have this object!");
+ if (objects_.find(id) == objects_.end()) {
+ TrackedObject* const object =
+ MaybeAddObject(id, image, bounding_box, object_model);
+ CHECK_ALWAYS(object != NULL, "Object not created!");
+ }
+}
+
+
+void ObjectTracker::SetPreviousPositionOfObject(const std::string& id,
+ const BoundingBox& bounding_box,
+ const int64 timestamp) {
+ CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp);
+ CHECK_ALWAYS(timestamp <= curr_time_,
+ "Timestamp too great! %lld vs %lld", timestamp, curr_time_);
+
+ TrackedObject* const object = GetObject(id);
+
+ // Track this bounding box from the past to the current time.
+ const BoundingBox current_position = TrackBox(bounding_box, timestamp);
+
+ object->UpdatePosition(current_position, curr_time_, *frame2_, false);
+
+ VLOG(2) << "Set tracked position for " << id << " to " << bounding_box
+ << std::endl;
+}
+
+
+void ObjectTracker::SetCurrentPositionOfObject(
+ const std::string& id, const BoundingBox& bounding_box) {
+ SetPreviousPositionOfObject(id, bounding_box, curr_time_);
+}
+
+
+void ObjectTracker::ForgetTarget(const std::string& id) {
+ LOGV("Forgetting object %s", id.c_str());
+ TrackedObject* const object = GetObject(id);
+ delete object;
+ objects_.erase(id);
+
+ if (detector_ != NULL) {
+ detector_->DeleteObjectModel(id);
+ }
+}
+
+
+int ObjectTracker::GetKeypointsPacked(uint16* const out_data,
+ const float scale) const {
+ const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
+ uint16* curr_data = out_data;
+ int num_keypoints = 0;
+
+ for (int i = 0; i < change.number_of_keypoints_; ++i) {
+ if (change.optical_flow_found_keypoint_[i]) {
+ ++num_keypoints;
+ const Point2f& point1 = change.frame1_keypoints_[i].pos_;
+ *curr_data++ = RealToFixed115(point1.x * scale);
+ *curr_data++ = RealToFixed115(point1.y * scale);
+
+ const Point2f& point2 = change.frame2_keypoints_[i].pos_;
+ *curr_data++ = RealToFixed115(point2.x * scale);
+ *curr_data++ = RealToFixed115(point2.y * scale);
+ }
+ }
+
+ return num_keypoints;
+}
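+
+// Layout note (an assumption based on the function's name: RealToFixed115
+// presumably emits 16-bit 11.5 fixed point, i.e. roughly round(value * 32)):
+// each found correspondence above occupies four consecutive uint16 values,
+// x1 y1 x2 y2, so out_data must hold kMaxKeypoints * 4 uint16s (8 bytes per
+// keypoint, matching bytes_per_keypoint in object_tracker_jni.cc).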
+
+
+int ObjectTracker::GetKeypoints(const bool only_found,
+ float* const out_data) const {
+ int curr_keypoint = 0;
+ const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
+
+ for (int i = 0; i < change.number_of_keypoints_; ++i) {
+ if (!only_found || change.optical_flow_found_keypoint_[i]) {
+ const int base = curr_keypoint * kKeypointStep;
+ out_data[base + 0] = change.frame1_keypoints_[i].pos_.x;
+ out_data[base + 1] = change.frame1_keypoints_[i].pos_.y;
+
+ out_data[base + 2] =
+ change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f;
+ out_data[base + 3] = change.frame2_keypoints_[i].pos_.x;
+ out_data[base + 4] = change.frame2_keypoints_[i].pos_.y;
+
+ out_data[base + 5] = change.frame1_keypoints_[i].score_;
+ out_data[base + 6] = change.frame1_keypoints_[i].type_;
+ ++curr_keypoint;
+ }
+ }
+
+ LOGV("Got %d keypoints.", curr_keypoint);
+
+ return curr_keypoint;
+}
+
+
+BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
+ const FramePair& frame_pair) const {
+ float translation_x;
+ float translation_y;
+
+ float scale_x;
+ float scale_y;
+
+ BoundingBox tracked_box(region);
+ frame_pair.AdjustBox(
+ tracked_box, &translation_x, &translation_y, &scale_x, &scale_y);
+
+ tracked_box.Shift(Point2f(translation_x, translation_y));
+
+ if (scale_x > 0 && scale_y > 0) {
+ tracked_box.Scale(scale_x, scale_y);
+ }
+ return tracked_box;
+}
+
+
+BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
+ const int64 timestamp) const {
+ CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp);
+ CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!");
+
+ // Anything that ended before the requested timestamp is of no concern to us.
+ bool found_it = false;
+ int num_frames_back = -1;
+ for (int i = 0; i < curr_num_frame_pairs_; ++i) {
+ const FramePair& frame_pair =
+ frame_pairs_[GetNthIndexFromEnd(i)];
+
+ if (frame_pair.end_time_ <= timestamp) {
+ num_frames_back = i - 1;
+
+ if (num_frames_back > 0) {
+ LOGV("Went %d out of %d frames before finding frame. (index: %d)",
+ num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i));
+ }
+
+ found_it = true;
+ break;
+ }
+ }
+
+ if (!found_it) {
+ LOGW("History did not go back far enough! %lld vs %lld",
+ frame_pairs_[GetNthIndexFromEnd(0)].end_time_ -
+ frame_pairs_[GetNthIndexFromStart(0)].end_time_,
+ frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp);
+ }
+
+ // Loop over all the frames in the queue, tracking the accumulated delta
+ // of the point from frame to frame. It's possible the point could
+ // go out of frame, but keep tracking as best we can, using points near
+ // the edge of the screen where it went out of bounds.
+ BoundingBox tracked_box(region);
+ for (int i = num_frames_back; i >= 0; --i) {
+ const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)];
+ SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!");
+ tracked_box = TrackBox(tracked_box, frame_pair);
+ }
+ return tracked_box;
+}
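+
+// Worked example (hypothetical values): with frame pairs ending at times
+// 10, 20 and 30 in the queue and a query timestamp of 15, the search above
+// scans back from the newest pair (end time 30) until it hits the pair with
+// end time 10 <= 15, setting num_frames_back to the pair ending at 20; the
+// box is then re-tracked forward through the pairs ending at 20 and 30 to
+// arrive at its current position.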
+
+
+// Converts a row-major 2x3 affine matrix (the top two rows of a 3x3 2D
+// transformation matrix whose implicit last row is [0 0 1]) to a
+// column-major 4x4 3D transformation matrix.
+inline void Convert3x3To4x4(
+ const float* const in_matrix, float* const out_matrix) {
+ // X
+ out_matrix[0] = in_matrix[0];
+ out_matrix[1] = in_matrix[3];
+ out_matrix[2] = 0.0f;
+ out_matrix[3] = 0.0f;
+
+ // Y
+ out_matrix[4] = in_matrix[1];
+ out_matrix[5] = in_matrix[4];
+ out_matrix[6] = 0.0f;
+ out_matrix[7] = 0.0f;
+
+ // Z
+ out_matrix[8] = 0.0f;
+ out_matrix[9] = 0.0f;
+ out_matrix[10] = 1.0f;
+ out_matrix[11] = 0.0f;
+
+ // Translation
+ out_matrix[12] = in_matrix[2];
+ out_matrix[13] = in_matrix[5];
+ out_matrix[14] = 0.0f;
+ out_matrix[15] = 1.0f;
+}
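+
+// Illustration of the mapping above: the row-major 2D affine input
+//
+//   [ a  b  tx ]
+//   [ c  d  ty ]   (implicit third row [0 0 1])
+//
+// becomes the column-major, OpenGL-style 4x4 matrix
+//
+//   [ a  b  0  tx ]
+//   [ c  d  0  ty ]
+//   [ 0  0  1  0  ]
+//   [ 0  0  0  1  ]
+//
+// stored as { a, c, 0, 0,  b, d, 0, 0,  0, 0, 1, 0,  tx, ty, 0, 1 }.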
+
+
+void ObjectTracker::Draw(const int canvas_width, const int canvas_height,
+ const float* const frame_to_canvas) const {
+#ifdef __RENDER_OPENGL__
+ glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
+
+ glMatrixMode(GL_PROJECTION);
+ glLoadIdentity();
+
+ glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f);
+
+ // To make Y go the right direction (0 at top of frame).
+ glScalef(1.0f, -1.0f, 1.0f);
+ glTranslatef(0.0f, -canvas_height, 0.0f);
+
+ glMatrixMode(GL_MODELVIEW);
+ glLoadIdentity();
+
+ glPushMatrix();
+
+ // Apply the frame to canvas transformation.
+ static GLfloat transformation[16];
+ Convert3x3To4x4(frame_to_canvas, transformation);
+ glMultMatrixf(transformation);
+
+ // Draw tracked object bounding boxes.
+ for (TrackedObjectMap::const_iterator iter = objects_.begin();
+ iter != objects_.end(); ++iter) {
+ TrackedObject* tracked_object = iter->second;
+ tracked_object->Draw();
+ }
+
+ static const bool kRenderDebugPyramid = false;
+ if (kRenderDebugPyramid) {
+ glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
+ for (int i = 0; i < kNumPyramidLevels * 2; ++i) {
+ Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw();
+ }
+ }
+
+ static const bool kRenderDebugDerivative = false;
+ if (kRenderDebugDerivative) {
+ glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
+ for (int i = 0; i < kNumPyramidLevels; ++i) {
+ const Image<int32>& dx = *frame1_->GetSpatialX(i);
+ Image<uint8> render_image(dx.GetWidth(), dx.GetHeight());
+ for (int y = 0; y < dx.GetHeight(); ++y) {
+ const int32* dx_ptr = dx[y];
+ uint8* dst_ptr = render_image[y];
+ for (int x = 0; x < dx.GetWidth(); ++x) {
+ *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255);
+ }
+ }
+
+ Sprite(render_image).Draw();
+ }
+ }
+
+ if (detector_ != NULL) {
+ glDisable(GL_CULL_FACE);
+ detector_->Draw();
+ }
+ glPopMatrix();
+#endif
+}
+
+static void AddQuadrants(const BoundingBox& box,
+ std::vector<BoundingBox>* boxes) {
+ const Point2f center = box.GetCenter();
+
+ float x1 = box.left_;
+ float x2 = center.x;
+ float x3 = box.right_;
+
+ float y1 = box.top_;
+ float y2 = center.y;
+ float y3 = box.bottom_;
+
+ // Upper left.
+ boxes->push_back(BoundingBox(x1, y1, x2, y2));
+
+ // Upper right.
+ boxes->push_back(BoundingBox(x2, y1, x3, y2));
+
+ // Bottom left.
+ boxes->push_back(BoundingBox(x1, y2, x2, y3));
+
+ // Bottom right.
+ boxes->push_back(BoundingBox(x2, y2, x3, y3));
+
+ // Whole thing.
+ boxes->push_back(box);
+}
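+
+// For a concrete picture: given the box (0, 0)-(100, 100), the center is
+// (50, 50) and the five boxes appended above are (0, 0)-(50, 50),
+// (50, 0)-(100, 50), (0, 50)-(50, 100), (50, 50)-(100, 100), and the
+// original box itself.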
+
+void ObjectTracker::ComputeKeypoints(const bool cached_ok) {
+ const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)];
+ FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)];
+
+ std::vector<BoundingBox> boxes;
+
+ for (TrackedObjectMap::iterator object_iter = objects_.begin();
+ object_iter != objects_.end(); ++object_iter) {
+ BoundingBox box = object_iter->second->GetPosition();
+ box.Scale(config_->object_box_scale_factor_for_features,
+ config_->object_box_scale_factor_for_features);
+ AddQuadrants(box, &boxes);
+ }
+
+ AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes);
+
+ keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change);
+}
+
+
+// Given a detection, finds the tracked object that it best matches, writing
+// the result to *match (or NULL if no object qualifies). Returns false if
+// the detection collides with an object it cannot be matched to, and true
+// otherwise.
+bool ObjectTracker::GetBestObjectForDetection(
+ const Detection& detection, TrackedObject** match) const {
+ TrackedObject* best_match = NULL;
+ float best_overlap = -FLT_MAX;
+
+ LOGV("Looking for matches in %zu objects!", objects_.size());
+ for (TrackedObjectMap::const_iterator object_iter = objects_.begin();
+ object_iter != objects_.end(); ++object_iter) {
+ TrackedObject* const tracked_object = object_iter->second;
+
+ const float overlap = tracked_object->GetPosition().PascalScore(
+ detection.GetObjectBoundingBox());
+
+ if (!detector_->AllowSpontaneousDetections() &&
+ (detection.GetObjectModel() != tracked_object->GetModel())) {
+ if (overlap > 0.0f) {
+ return false;
+ }
+ continue;
+ }
+
+ const float jump_distance =
+ (tracked_object->GetPosition().GetCenter() -
+ detection.GetObjectBoundingBox().GetCenter()).LengthSquared();
+
+ const float allowed_distance =
+ tracked_object->GetAllowableDistanceSquared();
+
+ LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f",
+ jump_distance, allowed_distance, overlap);
+
+ // TODO(andrewharp): No need to do this verification twice, eliminate
+ // one of the score checks (the other being in OnDetection).
+ if (jump_distance < allowed_distance &&
+ overlap > best_overlap &&
+ tracked_object->GetMatchScore() + kMatchScoreBuffer <
+ detection.GetMatchScore()) {
+ best_match = tracked_object;
+ best_overlap = overlap;
+ } else if (overlap > 0.0f) {
+ return false;
+ }
+ }
+
+ *match = best_match;
+ return true;
+}
+
+
+void ObjectTracker::ProcessDetections(
+ std::vector<Detection>* const detections) {
+ LOGV("Initial detection done, iterating over %zu detections now.",
+ detections->size());
+
+ const bool spontaneous_detections_allowed =
+ detector_->AllowSpontaneousDetections();
+ for (std::vector<Detection>::const_iterator it = detections->begin();
+ it != detections->end(); ++it) {
+ const Detection& detection = *it;
+ SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()),
+ "Frame does not contain bounding box!");
+
+ TrackedObject* best_match = NULL;
+
+ const bool no_collisions =
+ GetBestObjectForDetection(detection, &best_match);
+
+ // Need to get a non-const version of the model, or create a new one if it
+ // wasn't given.
+ ObjectModelBase* model =
+ const_cast<ObjectModelBase*>(detection.GetObjectModel());
+
+ if (best_match != NULL) {
+ if (model != best_match->GetModel()) {
+ CHECK_ALWAYS(detector_->AllowSpontaneousDetections(),
+ "Model for object changed but spontaneous detections not allowed!");
+ }
+ best_match->OnDetection(model,
+ detection.GetObjectBoundingBox(),
+ detection.GetMatchScore(),
+ curr_time_, *frame2_);
+ } else if (no_collisions && spontaneous_detections_allowed) {
+ if (detection.GetMatchScore() > kMinimumMatchScore) {
+ LOGV("No match, adding it!");
+ const ObjectModelBase* model = detection.GetObjectModel();
+ std::ostringstream ss;
+ // TODO(andrewharp): Generate this in a more general fashion.
+ ss << "hand_" << num_detected_++;
+ std::string object_name = ss.str();
+ MaybeAddObject(object_name, *frame2_->GetImage(),
+ detection.GetObjectBoundingBox(), model);
+ }
+ }
+ }
+}
+
+
+void ObjectTracker::DetectTargets() {
+ // Detect all object model types that we're currently tracking.
+ std::vector<const ObjectModelBase*> object_models;
+ detector_->GetObjectModels(&object_models);
+ if (object_models.size() == 0) {
+ LOGV("No objects to search for, aborting.");
+ return;
+ }
+
+ LOGV("Trying to detect %zu models", object_models.size());
+
+ LOGV("Creating test vector!");
+ std::vector<BoundingSquare> positions;
+
+ for (TrackedObjectMap::iterator object_iter = objects_.begin();
+ object_iter != objects_.end(); ++object_iter) {
+ TrackedObject* const tracked_object = object_iter->second;
+
+#if DEBUG_PREDATOR
+ positions.push_back(GetCenteredSquare(
+ frame2_->GetImage()->GetContainingBox(), 32.0f));
+#else
+ const BoundingBox& position = tracked_object->GetPosition();
+
+ const float square_size = MAX(
+ kScanMinSquareSize / (kLastKnownPositionScaleFactor *
+ kLastKnownPositionScaleFactor),
+ MIN(position.GetWidth(),
+ position.GetHeight())) / kLastKnownPositionScaleFactor;
+
+ FillWithSquares(frame2_->GetImage()->GetContainingBox(),
+ tracked_object->GetPosition(),
+ square_size,
+ kScanMinSquareSize,
+ kLastKnownPositionScaleFactor,
+ &positions);
+#endif
+  }
+
+ LOGV("Created test vector!");
+
+ std::vector<Detection> detections;
+ LOGV("Detecting!");
+ detector_->Detect(positions, &detections);
+ LOGV("Found %zu detections", detections.size());
+
+ TimeLog("Finished detection.");
+
+ ProcessDetections(&detections);
+
+ TimeLog("iterated over detections");
+
+ LOGV("Done detecting!");
+}
+
+
+void ObjectTracker::TrackObjects() {
+ // TODO(andrewharp): Correlation should be allowed to remove objects too.
+ const bool automatic_removal_allowed = detector_.get() != NULL ?
+ detector_->AllowSpontaneousDetections() : false;
+
+ LOGV("Tracking %zu objects!", objects_.size());
+ std::vector<std::string> dead_objects;
+ for (TrackedObjectMap::iterator iter = objects_.begin();
+ iter != objects_.end(); iter++) {
+ TrackedObject* object = iter->second;
+ const BoundingBox tracked_position = TrackBox(
+ object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]);
+ object->UpdatePosition(tracked_position, curr_time_, *frame2_, false);
+
+ if (automatic_removal_allowed &&
+ object->GetNumConsecutiveFramesBelowThreshold() >
+ kMaxNumDetectionFailures * 5) {
+ dead_objects.push_back(iter->first);
+ }
+ }
+
+ if (detector_ != NULL && automatic_removal_allowed) {
+ for (std::vector<std::string>::iterator iter = dead_objects.begin();
+ iter != dead_objects.end(); iter++) {
+ LOGE("Removing object! %s", iter->c_str());
+ ForgetTarget(*iter);
+ }
+ }
+ TimeLog("Tracked all objects.");
+
+ LOGV("%zu objects tracked!", objects_.size());
+}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
new file mode 100644
index 0000000000..3d2a9af360
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
@@ -0,0 +1,271 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+
+#include <map>
+#include <string>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_model.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
+
+namespace tf_tracking {
+
+typedef std::map<std::string, TrackedObject*> TrackedObjectMap;
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TrackedObjectMap& map) {
+ for (TrackedObjectMap::const_iterator iter = map.begin();
+ iter != map.end(); ++iter) {
+ const TrackedObject& tracked_object = *iter->second;
+ const std::string& key = iter->first;
+ stream << key << ": " << tracked_object;
+ }
+ return stream;
+}
+
+
+// ObjectTracker is the highest-level class in the tracking/detection framework.
+// It handles basic image processing, keypoint detection, keypoint tracking,
+// object tracking, and object detection/relocalization.
+class ObjectTracker {
+ public:
+ ObjectTracker(const TrackerConfig* const config,
+ ObjectDetectorBase* const detector);
+ virtual ~ObjectTracker();
+
+ virtual void NextFrame(const uint8* const new_frame,
+ const int64 timestamp,
+ const float* const alignment_matrix_2x3) {
+ NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
+ }
+
+ // Called upon the arrival of a new frame of raw data.
+ // Does all image processing, keypoint detection, and object
+ // tracking/detection for registered objects.
+ // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
+ // represents the main transformation that has happened between the last
+ // and the current frame.
+ virtual void NextFrame(const uint8* const new_frame,
+ const uint8* const uv_frame,
+ const int64 timestamp,
+ const float* const alignment_matrix_2x3);
+
+ virtual void RegisterNewObjectWithAppearance(
+ const std::string& id, const uint8* const new_frame,
+ const BoundingBox& bounding_box);
+
+ // Updates the position of a tracked object, given that it was known to be at
+ // a certain position at some point in the past.
+ virtual void SetPreviousPositionOfObject(const std::string& id,
+ const BoundingBox& bounding_box,
+ const int64 timestamp);
+
+ // Sets the current position of the object in the most recent frame provided.
+ virtual void SetCurrentPositionOfObject(const std::string& id,
+ const BoundingBox& bounding_box);
+
+ // Tells the ObjectTracker to stop tracking a target.
+ void ForgetTarget(const std::string& id);
+
+ // Fills the given out_data buffer with the latest detected keypoint
+ // correspondences, first scaled by scale_factor (to adjust for downsampling
+ // that may have occurred elsewhere), then packed in a fixed-point format.
+ int GetKeypointsPacked(uint16* const out_data,
+ const float scale_factor) const;
+
+ // Copy the keypoint arrays after computeFlow is called.
+ // out_data should be at least kMaxKeypoints * kKeypointStep long.
+  // Currently, its format is [x1 y1 found x2 y2 score type] repeated N
+  // times, where N is the number of keypoints tracked. N is returned as the
+  // result.
+ int GetKeypoints(const bool only_found, float* const out_data) const;
+
+ // Returns the current position of a box, given that it was at a certain
+ // position at the given time.
+ BoundingBox TrackBox(const BoundingBox& region,
+ const int64 timestamp) const;
+
+ // Returns the number of frames that have been passed to NextFrame().
+ inline int GetNumFrames() const {
+ return num_frames_;
+ }
+
+ inline bool HaveObject(const std::string& id) const {
+ return objects_.find(id) != objects_.end();
+ }
+
+ // Returns the TrackedObject associated with the given id.
+ inline const TrackedObject* GetObject(const std::string& id) const {
+ TrackedObjectMap::const_iterator iter = objects_.find(id);
+ CHECK_ALWAYS(iter != objects_.end(),
+ "Unknown object key! \"%s\"", id.c_str());
+ TrackedObject* const object = iter->second;
+ return object;
+ }
+
+ // Returns the TrackedObject associated with the given id.
+ inline TrackedObject* GetObject(const std::string& id) {
+ TrackedObjectMap::iterator iter = objects_.find(id);
+ CHECK_ALWAYS(iter != objects_.end(),
+ "Unknown object key! \"%s\"", id.c_str());
+ TrackedObject* const object = iter->second;
+ return object;
+ }
+
+ bool IsObjectVisible(const std::string& id) const {
+ SCHECK(HaveObject(id), "Don't have this object.");
+
+ const TrackedObject* object = GetObject(id);
+ return object->IsVisible();
+ }
+
+ virtual void Draw(const int canvas_width, const int canvas_height,
+ const float* const frame_to_canvas) const;
+
+ protected:
+ // Creates a new tracked object at the given position.
+ // If an object model is provided, then that model will be associated with the
+ // object. If not, a new model may be created from the appearance at the
+ // initial position and registered with the object detector.
+ virtual TrackedObject* MaybeAddObject(const std::string& id,
+ const Image<uint8>& image,
+ const BoundingBox& bounding_box,
+ const ObjectModelBase* object_model);
+
+ // Find the keypoints in the frame before the current frame.
+ // If only one frame exists, keypoints will be found in that frame.
+ void ComputeKeypoints(const bool cached_ok = false);
+
+ // Finds the correspondences for all the points in the current pair of frames.
+ // Stores the results in the given FramePair.
+ void FindCorrespondences(FramePair* const curr_change) const;
+
+ inline int GetNthIndexFromEnd(const int offset) const {
+ return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
+ }
+
+ BoundingBox TrackBox(const BoundingBox& region,
+ const FramePair& frame_pair) const;
+
+ inline void IncrementFrameIndex() {
+ // Move the current framechange index up.
+ ++num_frames_;
+ ++curr_num_frame_pairs_;
+
+ // If we've got too many, push up the start of the queue.
+ if (curr_num_frame_pairs_ > kNumFrames) {
+ first_frame_index_ = GetNthIndexFromStart(1);
+ --curr_num_frame_pairs_;
+ }
+ }
+
+ inline int GetNthIndexFromStart(const int offset) const {
+ SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
+ "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_);
+ return (first_frame_index_ + offset) % kNumFrames;
+ }
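+
+  // Example of the circular indexing above (illustrative; suppose
+  // kNumFrames == 4): after six calls to IncrementFrameIndex(),
+  // curr_num_frame_pairs_ == 4 and first_frame_index_ == 2, so
+  // GetNthIndexFromStart(0..3) yields slots 2, 3, 0, 1 and
+  // GetNthIndexFromEnd(0) yields slot 1, the most recent pair.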
+
+ void TrackObjects();
+
+ const std::unique_ptr<const TrackerConfig> config_;
+
+ const int frame_width_;
+ const int frame_height_;
+
+ int64 curr_time_;
+
+ int num_frames_;
+
+ TrackedObjectMap objects_;
+
+ FlowCache flow_cache_;
+
+ KeypointDetector keypoint_detector_;
+
+ int curr_num_frame_pairs_;
+ int first_frame_index_;
+
+ std::unique_ptr<ImageData> frame1_;
+ std::unique_ptr<ImageData> frame2_;
+
+ FramePair frame_pairs_[kNumFrames];
+
+ std::unique_ptr<ObjectDetectorBase> detector_;
+
+ int num_detected_;
+
+ private:
+ void TrackTarget(TrackedObject* const object);
+
+ bool GetBestObjectForDetection(
+ const Detection& detection, TrackedObject** match) const;
+
+ void ProcessDetections(std::vector<Detection>* const detections);
+
+ void DetectTargets();
+
+ // Temp object used in ObjectTracker::CreateNewExample.
+ mutable std::vector<BoundingSquare> squares;
+
+ friend std::ostream& operator<<(std::ostream& stream,
+ const ObjectTracker& tracker);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const ObjectTracker& tracker) {
+ stream << "Frame size: " << tracker.frame_width_ << "x"
+ << tracker.frame_height_ << std::endl;
+
+ stream << "Num frames: " << tracker.num_frames_ << std::endl;
+
+ stream << "Curr time: " << tracker.curr_time_ << std::endl;
+
+ const int first_frame_index = tracker.GetNthIndexFromStart(0);
+ const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];
+
+ const int last_frame_index = tracker.GetNthIndexFromEnd(0);
+ const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];
+
+ stream << "first frame: " << first_frame_index << ","
+ << first_frame_pair.end_time_ << " "
+ << "last frame: " << last_frame_index << ","
+ << last_frame_pair.end_time_ << " diff: "
+ << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
+ << std::endl;
+
+ stream << "Tracked targets:";
+ stream << tracker.objects_;
+
+ return stream;
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc
new file mode 100644
index 0000000000..30c5974654
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc
@@ -0,0 +1,463 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <android/log.h>
+#include <jni.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <cstdint>
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
+ Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT
+
+JniIntField object_tracker_field("nativeObjectTracker");
+
+ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
+ ObjectTracker* const object_tracker =
+ reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
+ CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
+ return object_tracker;
+}
+
+void set_object_tracker(JNIEnv* env, jobject thiz,
+ const ObjectTracker* object_tracker) {
+ object_tracker_field.set(env, thiz,
+ reinterpret_cast<intptr_t>(object_tracker));
+}
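+
+// The two helpers above implement the usual JNI "native peer" pattern: the
+// C++ ObjectTracker's address is stored in the Java object's
+// "nativeObjectTracker" field via JniIntField, initNative() below allocates
+// the tracker, every other native method retrieves it through
+// get_object_tracker(), and releaseMemoryNative() deletes it. Nothing is
+// garbage collected here, so Java callers must pair the init and release
+// calls themselves.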
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
+ jint width, jint height,
+ jboolean always_track);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
+ jobject thiz);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jbyteArray frame_data);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jlong timestamp);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2);
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
+ jbyteArray y_data,
+ jbyteArray uv_data,
+ jlong timestamp,
+ jfloatArray vg_matrix_2x3);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
+ jstring object_id);
+
+JNIEXPORT
+jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
+ JNIEnv* env, jobject thiz, jfloat scale_factor);
+
+JNIEXPORT
+jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
+ JNIEnv* env, jobject thiz, jboolean only_found_);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
+ jfloat position_y1, jfloat position_x2, jfloat position_y2,
+ jfloatArray delta);
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
+ jint view_width,
+ jint view_height,
+ jfloatArray delta);
+
+JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
+ JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
+ jbyteArray input, jint factor, jbyteArray output);
+
+#ifdef __cplusplus
+}
+#endif
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
+ jint width, jint height,
+ jboolean always_track) {
+ LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
+ const Size image_size(width, height);
+ TrackerConfig* const tracker_config = new TrackerConfig(image_size);
+ tracker_config->always_track = always_track;
+
+  // No native detector is currently wired up (see the NOTE in
+  // object_detector.h), so pass NULL for the detector.
+ ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
+ set_object_tracker(env, thiz, tracker);
+ LOGI("Initialized!");
+
+  CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
+               "Failure to set object tracker!");
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
+ jobject thiz) {
+ delete get_object_tracker(env, thiz);
+ set_object_tracker(env, thiz, NULL);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jbyteArray frame_data) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
+ x2, y2);
+
+ jboolean iCopied = JNI_FALSE;
+
+ // Copy image into currFrame.
+ jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
+
+ BoundingBox bounding_box(x1, y1, x2, y2);
+ get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
+ id_str, reinterpret_cast<const uint8*>(pixels), bounding_box);
+
+ env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2, jlong timestamp) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ LOGI(
+ "Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
+ " at time %lld",
+ id_str, x1, y1, x2, y2, static_cast<int64>(timestamp));
+
+ get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
+ id_str, BoundingBox(x1, y1, x2, y2), timestamp);
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
+ jfloat x2, jfloat y2) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
+ x2, y2);
+
+ get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
+ id_str, BoundingBox(x1, y1, x2, y2));
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return haveObject;
+}
+
+JNIEXPORT
+jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return visible;
+}
+
+JNIEXPORT
+jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+ const TrackedObject* const object =
+ get_object_tracker(env, thiz)->GetObject(id_str);
+ env->ReleaseStringUTFChars(object_id, id_str);
+ jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
+ return model_name;
+}
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
+ jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const float correlation =
+ get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return correlation;
+}
+
+JNIEXPORT
+jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const float match_score =
+ get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
+ env->ReleaseStringUTFChars(object_id, id_str);
+ return match_score;
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
+ JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
+ jboolean iCopied = JNI_FALSE;
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ const BoundingBox bounding_box =
+ get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
+ env->ReleaseStringUTFChars(object_id, id_str);
+
+ jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
+ bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
+ env->ReleaseFloatArrayElements(rect_array, rect, 0);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
+ jbyteArray y_data,
+ jbyteArray uv_data,
+ jlong timestamp,
+ jfloatArray vg_matrix_2x3) {
+ TimeLog("Starting object tracker");
+
+ jboolean iCopied = JNI_FALSE;
+
+ float vision_gyro_matrix_array[6];
+ jfloat* jmat = NULL;
+
+ if (vg_matrix_2x3 != NULL) {
+ // Copy the alignment matrix into a float array.
+ jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
+ for (int i = 0; i < 6; ++i) {
+ vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
+ }
+ }
+ // Copy image into currFrame.
+ jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
+ jbyte* uv_pixels =
+ uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
+
+ TimeLog("Got elements");
+
+ // Add the frame to the object tracker object.
+ get_object_tracker(env, thiz)->NextFrame(
+ reinterpret_cast<uint8*>(pixels), reinterpret_cast<uint8*>(uv_pixels),
+ timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
+
+ env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
+
+ if (uv_data != NULL) {
+ env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
+ }
+
+ if (vg_matrix_2x3 != NULL) {
+ env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
+ }
+
+ TimeLog("Released elements");
+
+ PrintTimeLog();
+ ResetTimeLog();
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
+ jstring object_id) {
+ const char* const id_str = env->GetStringUTFChars(object_id, 0);
+
+ get_object_tracker(env, thiz)->ForgetTarget(id_str);
+
+ env->ReleaseStringUTFChars(object_id, id_str);
+}
+
+JNIEXPORT
+jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
+ JNIEnv* env, jobject thiz, jboolean only_found) {
+ jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
+
+ const int number_of_keypoints =
+ get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
+
+ // Create and return the array that will be passed back to Java.
+ jfloatArray keypoints =
+ env->NewFloatArray(number_of_keypoints * kKeypointStep);
+ if (keypoints == NULL) {
+ LOGE("null array!");
+ return NULL;
+ }
+ env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
+ keypoint_arr);
+
+ return keypoints;
+}
+
+JNIEXPORT
+jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
+ JNIEnv* env, jobject thiz, jfloat scale_factor) {
+  // Each keypoint is packed as two (x, y) pairs of uint16: 2 bytes per
+  // value, four values per keypoint.
+ const int bytes_per_keypoint = sizeof(uint16) * 2 * 2;
+ jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
+
+ const int number_of_keypoints =
+ get_object_tracker(env, thiz)->GetKeypointsPacked(
+ reinterpret_cast<uint16*>(keypoint_arr), scale_factor);
+
+ // Create and return the array that will be passed back to Java.
+ jbyteArray keypoints =
+ env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
+
+ if (keypoints == NULL) {
+ LOGE("null array!");
+ return NULL;
+ }
+
+ env->SetByteArrayRegion(
+ keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
+
+ return keypoints;
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
+ JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
+ jfloat position_y1, jfloat position_x2, jfloat position_y2,
+ jfloatArray delta) {
+ jfloat point_arr[4];
+
+ const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
+ BoundingBox(position_x1, position_y1, position_x2, position_y2),
+ timestamp);
+
+ new_position.CopyToArray(point_arr);
+ env->SetFloatArrayRegion(delta, 0, 4, point_arr);
+}
+
+JNIEXPORT
+void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
+ JNIEnv* env, jobject thiz, jint view_width, jint view_height,
+ jfloatArray frame_to_canvas_arr) {
+ ObjectTracker* object_tracker = get_object_tracker(env, thiz);
+ if (object_tracker != NULL) {
+ jfloat* frame_to_canvas =
+ env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
+
+ object_tracker->Draw(view_width, view_height, frame_to_canvas);
+ env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
+ JNI_ABORT);
+ }
+}
+
+JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
+ JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
+ jbyteArray input, jint factor, jbyteArray output) {
+ if (input == NULL || output == NULL) {
+ LOGW("Received null arrays, hopefully this is a test!");
+ return;
+ }
+
+ jbyte* const input_array = env->GetByteArrayElements(input, 0);
+ jbyte* const output_array = env->GetByteArrayElements(output, 0);
+
+ {
+ tf_tracking::Image<uint8> full_image(
+ width, height, reinterpret_cast<uint8*>(input_array), false);
+
+ const int new_width = (width + factor - 1) / factor;
+ const int new_height = (height + factor - 1) / factor;
+
+ tf_tracking::Image<uint8> downsampled_image(
+ new_width, new_height, reinterpret_cast<uint8*>(output_array), false);
+
+ downsampled_image.DownsampleAveraged(reinterpret_cast<uint8*>(input_array),
+ row_stride, factor);
+ }
+
+ env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
+ env->ReleaseByteArrayElements(output, output_array, 0);
+}
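+
+// Sizing note for callers, following the ceil division above: the output
+// buffer must hold ((width + factor - 1) / factor) *
+// ((height + factor - 1) / factor) bytes; e.g. a 641x481 input with
+// factor == 2 downsamples to 321x241.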
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.cc b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc
new file mode 100644
index 0000000000..fab0a3155d
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc
@@ -0,0 +1,490 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
+#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
+#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
+
+namespace tf_tracking {
+
+OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config)
+ : config_(config),
+ frame1_(NULL),
+ frame2_(NULL),
+ working_size_(config->image_size) {}
+
+
+void OpticalFlow::NextFrame(const ImageData* const image_data) {
+ // Special case for the first frame: make sure the image ends up in
+ // frame1_ so that keypoint detection can be done on it if desired.
+ frame1_ = (frame1_ == NULL) ? image_data : frame2_;
+ frame2_ = image_data;
+}
+
+
+// Static heart of the optical flow computation: the Lucas-Kanade
+// algorithm.
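+// As a sketch of the math: each iteration solves the 2x2 normal equations
+//   G * d = b
+// where G accumulates [I_x^2, I_x*I_y; I_x*I_y, I_y^2] and b accumulates
+// dI * [I_x; I_y] over the integration window (dI being the frame
+// difference), then updates the flow estimate by d = G^-1 * b.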
+bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y) {
+ float g_x = *out_g_x;
+ float g_y = *out_g_y;
+ // Get values for frame 1. They remain constant through the inner
+ // iteration loop.
+ float vals_I[kFlowArraySize];
+ float vals_I_x[kFlowArraySize];
+ float vals_I_y[kFlowArraySize];
+
+ const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1;
+ const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize);
+
+#if USE_FIXED_POINT_FLOW
+ const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1;
+ const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1;
+#else
+ const float real_x_max = I_x.width_less_one_ - EPSILON;
+ const float real_y_max = I_x.height_less_one_ - EPSILON;
+#endif
+
+ // Get the window around the original point.
+ const float src_left_real = p_x - kWindowSizeFloat;
+ const float src_top_real = p_y - kWindowSizeFloat;
+ float* vals_I_ptr = vals_I;
+ float* vals_I_x_ptr = vals_I_x;
+ float* vals_I_y_ptr = vals_I_y;
+#if USE_FIXED_POINT_FLOW
+ // Source integer coordinates.
+ const int src_left_fixed = RealToFixed1616(src_left_real);
+ const int src_top_fixed = RealToFixed1616(src_top_real);
+
+ for (int y = 0; y < kPatchSize; ++y) {
+ const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max);
+
+ for (int x = 0; x < kPatchSize; ++x) {
+ const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max);
+
+ *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y);
+ *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
+ *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
+ }
+ }
+#else
+ for (int y = 0; y < kPatchSize; ++y) {
+ const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max);
+
+ for (int x = 0; x < kPatchSize; ++x) {
+ const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max);
+
+ *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos);
+ *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos);
+ *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos);
+ }
+ }
+#endif
+
+ // Compute the spatial gradient matrix about point p.
+ float G[] = { 0, 0, 0, 0 };
+ CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G);
+
+ // Find the inverse of G.
+ float G_inv[4];
+ if (!Invert2x2(G, G_inv)) {
+ return false;
+ }
+
+#if NORMALIZE
+ const float mean_I = ComputeMean(vals_I, kFlowArraySize);
+ const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I);
+#endif
+
+ // Iterate kNumIterations times or until we converge.
+ for (int iteration = 0; iteration < kNumIterations; ++iteration) {
+ // Get values for frame 2.
+ float vals_J[kFlowArraySize];
+
+ // Get the window around the destination point.
+ const float left_real = p_x + g_x - kWindowSizeFloat;
+ const float top_real = p_y + g_y - kWindowSizeFloat;
+ float* vals_J_ptr = vals_J;
+#if USE_FIXED_POINT_FLOW
+ // The top-left sub-pixel is set for the current iteration (in 16:16
+ // fixed). This is constant over one iteration.
+ const int left_fixed = RealToFixed1616(left_real);
+ const int top_fixed = RealToFixed1616(top_real);
+
+ for (int win_y = 0; win_y < kPatchSize; ++win_y) {
+ const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max);
+ for (int win_x = 0; win_x < kPatchSize; ++win_x) {
+ const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max);
+ *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y);
+ }
+ }
+#else
+ for (int win_y = 0; win_y < kPatchSize; ++win_y) {
+ const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max);
+ for (int win_x = 0; win_x < kPatchSize; ++win_x) {
+ const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max);
+ *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos);
+ }
+ }
+#endif
+
+#if NORMALIZE
+ const float mean_J = ComputeMean(vals_J, kFlowArraySize);
+ const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J);
+
+ // TODO(andrewharp): Probably better to completely detect and handle the
+ // "corner case" where the patch is fully outside the image diagonally.
+ const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f;
+#endif
+
+ // Compute image mismatch vector.
+ float b_x = 0.0f;
+ float b_y = 0.0f;
+
+ vals_I_ptr = vals_I;
+ vals_J_ptr = vals_J;
+ vals_I_x_ptr = vals_I_x;
+ vals_I_y_ptr = vals_I_y;
+
+ for (int win_y = 0; win_y < kPatchSize; ++win_y) {
+ for (int win_x = 0; win_x < kPatchSize; ++win_x) {
+#if NORMALIZE
+ // Normalized Image difference.
+ const float dI =
+ (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio;
+#else
+ const float dI = *vals_I_ptr++ - *vals_J_ptr++;
+#endif
+ b_x += dI * *vals_I_x_ptr++;
+ b_y += dI * *vals_I_y_ptr++;
+ }
+ }
+
+ // Optical flow... solve n = G^-1 * b
+ const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y);
+ const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y);
+
+ // Update best guess with residual displacement from this level and
+ // iteration.
+ g_x += n_x;
+ g_y += n_y;
+
+ // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y);
+
+ // Abort early if we're already below the threshold.
+ if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) {
+ break;
+ }
+ } // Iteration.
+
+ // Copy value back into output.
+ *out_g_x = g_x;
+ *out_g_y = g_y;
+ return true;
+}
+
+
+// Pointwise flow using translational 2dof ESM.
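+// ESM (Efficient Second-order Minimization) averages the template and
+// target gradients (the ">> 1" in the inner loop below) and applies a
+// Gauss-Newton style update to the translation (g_x, g_y), optionally
+// compensating for a global brightness offset between the patches.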
+bool OpticalFlow::FindFlowAtPoint_ESM(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const Image<int32>& J_x,
+ const Image<int32>& J_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y) {
+ float g_x = *out_g_x;
+ float g_y = *out_g_y;
+ const float area_inv = 1.0f / static_cast<float>(kFlowArraySize);
+
+ // Get values for frame 1. They remain constant through the inner
+ // iteration loop.
+ uint8 vals_I[kFlowArraySize];
+ uint8 vals_J[kFlowArraySize];
+ int16 src_gradient_x[kFlowArraySize];
+ int16 src_gradient_y[kFlowArraySize];
+
+ // TODO(rspring): try out the IntegerPatchAlign() method once
+ // the code for that is in ../common.
+ const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize);
+ const int src_left_fixed = RealToFixed1616(p_x - wsize_float);
+ const int src_top_fixed = RealToFixed1616(p_y - wsize_float);
+ const int patch_size = 2 * kFlowIntegrationWindowSize + 1;
+
+ // Create the keypoint template patch from a subpixel location.
+ if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
+ patch_size, patch_size, vals_I) ||
+ !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
+ patch_size, patch_size,
+ src_gradient_x) ||
+ !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
+ patch_size, patch_size,
+ src_gradient_y)) {
+ return false;
+ }
+
+ int bright_offset = 0;
+ int sum_diff = 0;
+
+ // The top-left sub-pixel is set for the current iteration (in 16:16 fixed).
+ // This is constant over one iteration.
+ int left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
+ int top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
+
+ // Truncating gives the top-left-most whole pixel that will be used.
+ int left_trunc = left_fixed >> 16;
+ int top_trunc = top_fixed >> 16;
+
+ // Compute an initial brightness offset.
+ if (kDoBrightnessNormalize &&
+ left_trunc >= 0 && top_trunc >= 0 &&
+ (left_trunc + patch_size) < img_J.width_less_one_ &&
+ (top_trunc + patch_size) < img_J.height_less_one_) {
+ int templ_index = 0;
+ const uint8* j_row = img_J[top_trunc] + left_trunc;
+
+ const int j_stride = img_J.stride();
+
+ for (int y = 0; y < patch_size; ++y, j_row += j_stride) {
+ for (int x = 0; x < patch_size; ++x) {
+ sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++];
+ }
+ }
+
+ bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv);
+ }
+
+ // Iterate kNumIterations times or until the patch leaves the image.
+ for (int iteration = 0; iteration < kNumIterations; ++iteration) {
+ int jtj[3] = { 0, 0, 0 };
+ int jtr[2] = { 0, 0 };
+ sum_diff = 0;
+
+ // Extract the target image values.
+ // Extract the gradient from the target image patch and accumulate to
+ // the gradient of the source image patch.
+ if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed,
+ patch_size, patch_size,
+ vals_J)) {
+ break;
+ }
+
+ const uint8* templ_row = vals_I;
+ const uint8* extract_row = vals_J;
+ const int16* src_dx_row = src_gradient_x;
+ const int16* src_dy_row = src_gradient_y;
+
+ for (int y = 0; y < patch_size; ++y, templ_row += patch_size,
+ src_dx_row += patch_size, src_dy_row += patch_size,
+ extract_row += patch_size) {
+ const int fp_y = top_fixed + (y << 16);
+ for (int x = 0; x < patch_size; ++x) {
+ const int fp_x = left_fixed + (x << 16);
+ int32 target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y);
+ int32 target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y);
+
+ // Combine the two Jacobians.
+ // Right-shift by one to account for the fact that we add
+ // two Jacobians.
+ int32 dx = (src_dx_row[x] + target_dx) >> 1;
+ int32 dy = (src_dy_row[x] + target_dy) >> 1;
+
+ // The current residual b - h(q) == extracted - (template + offset)
+ int32 diff = static_cast<int32>(extract_row[x]) -
+ static_cast<int32>(templ_row[x]) -
+ bright_offset;
+
+ jtj[0] += dx * dx;
+ jtj[1] += dx * dy;
+ jtj[2] += dy * dy;
+
+ jtr[0] += dx * diff;
+ jtr[1] += dy * diff;
+
+ sum_diff += diff;
+ }
+ }
+
+ const float jtr1_float = static_cast<float>(jtr[0]);
+ const float jtr2_float = static_cast<float>(jtr[1]);
+
+ // Add some baseline stability to the system.
+ jtj[0] += kEsmRegularizer;
+ jtj[2] += kEsmRegularizer;
+
+ const int64 prod1 = static_cast<int64>(jtj[0]) * jtj[2];
+ const int64 prod2 = static_cast<int64>(jtj[1]) * jtj[1];
+
+ // One ESM step.
+ const float jtj_1[4] = { static_cast<float>(jtj[2]),
+ static_cast<float>(-jtj[1]),
+ static_cast<float>(-jtj[1]),
+ static_cast<float>(jtj[0]) };
+ const double det_inv = 1.0 / static_cast<double>(prod1 - prod2);
+
+ g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float);
+ g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float);
+
+ if (kDoBrightnessNormalize) {
+ bright_offset +=
+ static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f);
+ }
+
+ // Update top left position.
+ left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
+ top_fixed = RealToFixed1616(p_y + g_y - wsize_float);
+
+ left_trunc = left_fixed >> 16;
+ top_trunc = top_fixed >> 16;
+
+ // Abort iterations if the patch goes out of bounds.
+ if (left_trunc < 0 || top_trunc < 0 ||
+ (left_trunc + patch_size) >= J_x.width_less_one_ ||
+ (top_trunc + patch_size) >= J_y.height_less_one_) {
+ break;
+ }
+ } // Iteration.
+
+ // Copy value back into output.
+ *out_g_x = g_x;
+ *out_g_y = g_y;
+ return true;
+}
+
+
+bool OpticalFlow::FindFlowAtPointReversible(
+ const int level, const float u_x, const float u_y,
+ const bool reverse_flow,
+ float* flow_x, float* flow_y) const {
+ const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_;
+ const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_;
+
+ // Images I (prev) and J (next).
+ const Image<uint8>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2);
+ const Image<uint8>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2);
+
+ // Computed gradients.
+ const Image<int32>& I_x = *frame_a.GetSpatialX(level);
+ const Image<int32>& I_y = *frame_a.GetSpatialY(level);
+ const Image<int32>& J_x = *frame_b.GetSpatialX(level);
+ const Image<int32>& J_y = *frame_b.GetSpatialY(level);
+
+ // Shrink factor from original.
+ const float shrink_factor = (1 << level);
+
+ // Image position vector (p := u^l), scaled for this level.
+ const float scaled_p_x = u_x / shrink_factor;
+ const float scaled_p_y = u_y / shrink_factor;
+
+ float scaled_flow_x = *flow_x / shrink_factor;
+ float scaled_flow_y = *flow_y / shrink_factor;
+
+ // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level,
+ // scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y);
+
+ const bool success = kUseEsm ?
+ FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y,
+ scaled_p_x, scaled_p_y,
+ &scaled_flow_x, &scaled_flow_y) :
+ FindFlowAtPoint_LK(img_I, img_J, I_x, I_y,
+ scaled_p_x, scaled_p_y,
+ &scaled_flow_x, &scaled_flow_y);
+
+ *flow_x = scaled_flow_x * shrink_factor;
+ *flow_y = scaled_flow_y * shrink_factor;
+
+ return success;
+}
+
+
+bool OpticalFlow::FindFlowAtPointSingleLevel(
+ const int level,
+ const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const {
+ if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) {
+ return false;
+ }
+
+ if (filter_by_fb_error) {
+ const float new_position_x = u_x + *flow_x;
+ const float new_position_y = u_y + *flow_y;
+
+ float reverse_flow_x = 0.0f;
+ float reverse_flow_y = 0.0f;
+
+ // Now find the backwards flow and confirm it lines up with the original
+ // starting point.
+ if (!FindFlowAtPointReversible(level, new_position_x, new_position_y,
+ true,
+ &reverse_flow_x, &reverse_flow_y)) {
+ LOGE("Backward error!");
+ return false;
+ }
+
+ const float discrepancy_length =
+ sqrtf(Square(*flow_x + reverse_flow_x) +
+ Square(*flow_y + reverse_flow_y));
+
+ const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y));
+
+ return discrepancy_length <
+ (kMaxForwardBackwardErrorAllowed * flow_length);
+ }
+
+ return true;
+}
+
+
+// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm.
+// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details.
+bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const {
+ const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment,
+ kNumPyramidLevels - kNumCacheLevels);
+
+ // For every level in the pyramid, update the coordinates of the best match.
+ for (int l = max_level - 1; l >= 0; --l) {
+ if (!FindFlowAtPointSingleLevel(l, u_x, u_y,
+ filter_by_fb_error, flow_x, flow_y)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.h b/tensorflow/examples/android/jni/object_tracking/optical_flow.h
new file mode 100644
index 0000000000..1329927b99
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.h
@@ -0,0 +1,111 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/config.h"
+#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
+#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
+#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
+
+using namespace tensorflow;
+
+namespace tf_tracking {
+
+class FlowCache;
+
+// Class encapsulating all the data and logic necessary for performing optical
+// flow.
+class OpticalFlow {
+ public:
+ explicit OpticalFlow(const OpticalFlowConfig* const config);
+
+ // Add a new frame to the optical flow. Will update all the non-keypoint
+ // related member variables.
+ //
+ // image_data should wrap the new grayscale frame, together with the
+ // image pyramids and spatial gradients derived from it, at the working
+ // size used to initialize this OpticalFlow object.
+ void NextFrame(const ImageData* const image_data);
+
+ // An implementation of the Lucas-Kanade Optical Flow algorithm.
+ static bool FindFlowAtPoint_LK(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y);
+
+ // Pointwise flow using translational 2dof ESM.
+ static bool FindFlowAtPoint_ESM(const Image<uint8>& img_I,
+ const Image<uint8>& img_J,
+ const Image<int32>& I_x,
+ const Image<int32>& I_y,
+ const Image<int32>& J_x,
+ const Image<int32>& J_y,
+ const float p_x,
+ const float p_y,
+ float* out_g_x,
+ float* out_g_y);
+
+ // Finds the flow using a specific level, in either direction.
+ // If reversed, the coordinates are in the context of the latest
+ // frame, not the frame before it.
+ // All coordinates used in parameters are global, not scaled.
+ bool FindFlowAtPointReversible(
+ const int level, const float u_x, const float u_y,
+ const bool reverse_flow,
+ float* final_x, float* final_y) const;
+
+ // Finds the flow using a specific level, filterable by forward-backward
+ // error. All coordinates used in parameters are global, not scaled.
+ bool FindFlowAtPointSingleLevel(const int level,
+ const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const;
+
+ // Pyramidal optical-flow using all levels.
+ bool FindFlowAtPointPyramidal(const float u_x, const float u_y,
+ const bool filter_by_fb_error,
+ float* flow_x, float* flow_y) const;
+
+ private:
+ const OpticalFlowConfig* const config_;
+
+ const ImageData* frame1_;
+ const ImageData* frame2_;
+
+ // Size of the internally allocated images (after original is downsampled).
+ const Size working_size_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/examples/android/jni/object_tracking/sprite.h
new file mode 100755
index 0000000000..6240591cf2
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/sprite.h
@@ -0,0 +1,205 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
+
+#include <GLES/gl.h>
+#include <GLES/glext.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+
+#ifndef __RENDER_OPENGL__
+#error sprite.h should not be included if OpenGL is not enabled by platform.h
+#endif
+
+namespace tf_tracking {
+
+// This class encapsulates the logic necessary to load and render image data
+// at the same aspect ratio as the original source.
+class Sprite {
+ public:
+ // Only create Sprites when you have an OpenGL context.
+ explicit Sprite(const Image<uint8>& image) {
+ LoadTexture(image, NULL);
+ }
+
+ Sprite(const Image<uint8>& image, const BoundingBox* const area) {
+ LoadTexture(image, area);
+ }
+
+ // Also, try to delete a Sprite only while holding an OpenGL context.
+ ~Sprite() {
+ glDeleteTextures(1, &texture_);
+ }
+
+ inline int GetWidth() const {
+ return actual_width_;
+ }
+
+ inline int GetHeight() const {
+ return actual_height_;
+ }
+
+ // Draw the sprite at 0,0 - original width/height in the current reference
+ // frame. Any transformations desired must be applied before calling this
+ // function.
+ void Draw() const {
+ const float float_width = static_cast<float>(actual_width_);
+ const float float_height = static_cast<float>(actual_height_);
+
+ // Where it gets rendered to.
+ const float vertices[] = { 0.0f, 0.0f, 0.0f,
+ 0.0f, float_height, 0.0f,
+ float_width, 0.0f, 0.0f,
+ float_width, float_height, 0.0f,
+ };
+
+ // The coordinates the texture gets drawn from.
+ const float max_x = float_width / texture_width_;
+ const float max_y = float_height / texture_height_;
+ const float textureVertices[] = {
+ 0, 0,
+ 0, max_y,
+ max_x, 0,
+ max_x, max_y,
+ };
+
+ glEnable(GL_TEXTURE_2D);
+ glBindTexture(GL_TEXTURE_2D, texture_);
+
+ glEnableClientState(GL_VERTEX_ARRAY);
+ glEnableClientState(GL_TEXTURE_COORD_ARRAY);
+
+ glVertexPointer(3, GL_FLOAT, 0, vertices);
+ glTexCoordPointer(2, GL_FLOAT, 0, textureVertices);
+
+ glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);
+
+ glDisableClientState(GL_VERTEX_ARRAY);
+ glDisableClientState(GL_TEXTURE_COORD_ARRAY);
+ }
+
+ private:
+ inline int GetNextPowerOfTwo(const int number) const {
+ int power_of_two = 1;
+ while (power_of_two < number) {
+ power_of_two *= 2;
+ }
+ return power_of_two;
+ }
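+
+ // Illustrative examples: GetNextPowerOfTwo(100) == 128 and
+ // GetNextPowerOfTwo(128) == 128.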
+
+ // TODO(andrewharp): Allow sprites to have their textures reloaded.
+ void LoadTexture(const Image<uint8>& texture_source,
+ const BoundingBox* const area) {
+ glEnable(GL_TEXTURE_2D);
+
+ glGenTextures(1, &texture_);
+
+ glBindTexture(GL_TEXTURE_2D, texture_);
+
+ int left = 0;
+ int top = 0;
+
+ if (area != NULL) {
+ // If a sub-region was provided to pull the texture from, use that.
+ left = area->left_;
+ top = area->top_;
+ actual_width_ = area->GetWidth();
+ actual_height_ = area->GetHeight();
+ } else {
+ actual_width_ = texture_source.GetWidth();
+ actual_height_ = texture_source.GetHeight();
+ }
+
+ // Texture dimensions must be powers of two, so find the smallest sizes
+ // that are large enough to contain the image data.
+ texture_width_ = GetNextPowerOfTwo(actual_width_);
+ texture_height_ = GetNextPowerOfTwo(actual_height_);
+
+ bool allocated_data = false;
+ uint8* texture_data;
+
+ // Except in the lucky case where we're not using a sub-region of the
+ // original image AND the source data has dimensions that are power of two,
+ // care must be taken to copy data at the appropriate source and destination
+ // strides so that the final block can be copied directly into texture
+ // memory.
+ // TODO(andrewharp): Figure out if data can be pulled directly from the
+ // source image with some alignment modifications.
+ if (left != 0 || top != 0 ||
+ actual_width_ != texture_source.GetWidth() ||
+ actual_height_ != texture_source.GetHeight()) {
+ texture_data = new uint8[actual_width_ * actual_height_];
+
+ for (int y = 0; y < actual_height_; ++y) {
+ memcpy(texture_data + actual_width_ * y,
+ texture_source[top + y] + left,
+ actual_width_ * sizeof(uint8));
+ }
+ allocated_data = true;
+ } else {
+ // Cast away const-ness because for some reason glTexSubImage2D wants
+ // a non-const data pointer.
+ texture_data = const_cast<uint8*>(texture_source.data());
+ }
+
+ glTexImage2D(GL_TEXTURE_2D,
+ 0,
+ GL_LUMINANCE,
+ texture_width_,
+ texture_height_,
+ 0,
+ GL_LUMINANCE,
+ GL_UNSIGNED_BYTE,
+ NULL);
+
+ glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
+ glTexSubImage2D(GL_TEXTURE_2D,
+ 0,
+ 0,
+ 0,
+ actual_width_,
+ actual_height_,
+ GL_LUMINANCE,
+ GL_UNSIGNED_BYTE,
+ texture_data);
+
+ if (allocated_data) {
+ delete[] texture_data;
+ }
+
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
+ }
+
+ // The id for the texture on the GPU.
+ GLuint texture_;
+
+ // The width and height to be used for display purposes, referring to the
+ // dimensions of the original texture.
+ int actual_width_;
+ int actual_height_;
+
+ // The allocated dimensions of the texture data, which must be powers of 2.
+ int texture_width_;
+ int texture_height_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Sprite);
+};
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.cc b/tensorflow/examples/android/jni/object_tracking/time_log.cc
new file mode 100644
index 0000000000..cb1f3c23c8
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/time_log.cc
@@ -0,0 +1,29 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
+
+using namespace tensorflow;
+
+#ifdef LOG_TIME
+// Storage for logging functionality.
+int num_time_logs = 0;
+LogEntry time_logs[NUM_LOGS];
+
+int num_avg_entries = 0;
+AverageEntry avg_entries[NUM_LOGS];
+#endif
diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.h b/tensorflow/examples/android/jni/object_tracking/time_log.h
new file mode 100644
index 0000000000..ec539a1b3b
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/time_log.h
@@ -0,0 +1,138 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utility functions for performance profiling.
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+#ifdef LOG_TIME
+
+// Blend constant for running average.
+#define ALPHA 0.98f
+#define NUM_LOGS 100
+
+struct LogEntry {
+ const char* id;
+ int64 time_stamp;
+};
+
+struct AverageEntry {
+ const char* id;
+ float average_duration;
+};
+
+// Storage for keeping track of this frame's values.
+extern int num_time_logs;
+extern LogEntry time_logs[NUM_LOGS];
+
+// Storage for keeping track of average values (each entry may not be printed
+// out each frame).
+extern AverageEntry avg_entries[NUM_LOGS];
+extern int num_avg_entries;
+
+// Call this at the start of a logging phase.
+inline static void ResetTimeLog() {
+ num_time_logs = 0;
+}
+
+
+// Log a message to be printed out when PrintTimeLog is called, along with
+// the amount of time in ms that has passed since the last call to this
+// function.
+inline static void TimeLog(const char* const str) {
+ LOGV("%s", str);
+ if (num_time_logs >= NUM_LOGS) {
+ LOGE("Out of log entries!");
+ return;
+ }
+
+ time_logs[num_time_logs].id = str;
+ time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos();
+ ++num_time_logs;
+}
+
+
+inline static float Blend(float old_val, float new_val) {
+ return ALPHA * old_val + (1.0f - ALPHA) * new_val;
+}
+
+
+inline static float UpdateAverage(const char* str, const float new_val) {
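+ // Note: ids are expected to be string literals (as passed by TimeLog),
+ // so the pointer comparison below is an intentional fast identity check
+ // rather than a string comparison.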
+ for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) {
+ AverageEntry* const entry = avg_entries + entry_num;
+ if (str == entry->id) {
+ entry->average_duration = Blend(entry->average_duration, new_val);
+ return entry->average_duration;
+ }
+ }
+
+ if (num_avg_entries >= NUM_LOGS) {
+ LOGE("Too many log entries!");
+ return new_val;
+ }
+
+ // If it wasn't there already, add it.
+ avg_entries[num_avg_entries].id = str;
+ avg_entries[num_avg_entries].average_duration = new_val;
+ ++num_avg_entries;
+
+ return new_val;
+}
+
+
+// Prints out all the TimeLog statements in chronological order with the
+// interval that passed between subsequent statements. The total time between
+// the first and last statements is printed last.
+inline static void PrintTimeLog() {
+ LogEntry* last_time = time_logs;
+
+ float average_running_total = 0.0f;
+
+ for (int i = 0; i < num_time_logs; ++i) {
+ LogEntry* const this_time = time_logs + i;
+
+ const float curr_time =
+ (this_time->time_stamp - last_time->time_stamp) / 1000000.0f;
+
+ const float avg_time = UpdateAverage(this_time->id, curr_time);
+ average_running_total += avg_time;
+
+ LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time);
+ last_time = this_time;
+ }
+
+ const float total_time =
+ (last_time->time_stamp - time_logs->time_stamp) / 1000000.0f;
+
+ LOGD("TOTAL TIME: %6.3fms %6.4fms\n",
+ total_time, average_running_total);
+ LOGD(" ");
+}
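+
+// Illustrative usage sketch for a typical frame loop (DoExpensiveWork is
+// a hypothetical stand-in for real per-frame processing):
+//
+// ResetTimeLog();
+// TimeLog("Frame start");
+// DoExpensiveWork();
+// TimeLog("Work done");
+// PrintTimeLog(); // Logs per-interval times and blended averages.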
+#else
+inline static void ResetTimeLog() {}
+
+inline static void TimeLog(const char* const str) {
+ LOGV("%s", str);
+}
+
+inline static void PrintTimeLog() {}
+#endif
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.cc b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc
new file mode 100644
index 0000000000..823fb3a90e
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc
@@ -0,0 +1,163 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h"
+
+namespace tf_tracking {
+
+static const float kInitialDistance = 20.0f;
+
+static void InitNormalized(const Image<uint8>& src_image,
+ const BoundingBox& position,
+ Image<float>* const dst_image) {
+ BoundingBox scaled_box(position);
+ CopyArea(src_image, scaled_box, dst_image);
+ NormalizeImage(dst_image);
+}
+
+TrackedObject::TrackedObject(const std::string& id,
+ const Image<uint8>& image,
+ const BoundingBox& bounding_box,
+ ObjectModelBase* const model)
+ : id_(id),
+ last_known_position_(bounding_box),
+ last_detection_position_(bounding_box),
+ position_last_computed_time_(-1),
+ object_model_(model),
+ last_detection_thumbnail_(kNormalizedThumbnailSize,
+ kNormalizedThumbnailSize),
+ last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize),
+ tracked_correlation_(0.0f),
+ tracked_match_score_(0.0),
+ num_consecutive_frames_below_threshold_(0),
+ allowable_detection_distance_(Square(kInitialDistance)) {
+ InitNormalized(image, bounding_box, &last_detection_thumbnail_);
+}
+
+TrackedObject::~TrackedObject() {}
+
+void TrackedObject::UpdatePosition(const BoundingBox& new_position,
+ const int64 timestamp,
+ const ImageData& image_data,
+ const bool authoritative) {
+ last_known_position_ = new_position;
+ position_last_computed_time_ = timestamp;
+
+ InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_);
+
+ const float last_localization_correlation = ComputeCrossCorrelation(
+ last_detection_thumbnail_.data(),
+ last_frame_thumbnail_.data(),
+ last_frame_thumbnail_.data_size_);
+ LOGV("Tracked correlation to last localization: %.6f",
+ last_localization_correlation);
+
+ // Correlation to object model, if it exists.
+ if (object_model_ != NULL) {
+ tracked_correlation_ =
+ object_model_->GetMaxCorrelation(last_frame_thumbnail_);
+ LOGV("Tracked correlation to model: %.6f",
+ tracked_correlation_);
+
+ tracked_match_score_ =
+ object_model_->GetMatchScore(new_position, image_data);
+ LOGV("Tracked match score with model: %.6f",
+ tracked_match_score_.value);
+ } else {
+ // If there's no model to check against, set the tracked correlation to
+ // simply be the correlation to the last set position.
+ tracked_correlation_ = last_localization_correlation;
+ tracked_match_score_ = MatchScore(0.0f);
+ }
+
+ // Determine if it's still being tracked.
+ if (tracked_correlation_ >= kMinimumCorrelationForTracking &&
+ tracked_match_score_ >= kMinimumMatchScore) {
+ num_consecutive_frames_below_threshold_ = 0;
+
+ if (object_model_ != NULL) {
+ object_model_->TrackStep(last_known_position_, *image_data.GetImage(),
+ *image_data.GetIntegralImage(), authoritative);
+ }
+ } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) {
+ if (num_consecutive_frames_below_threshold_ < 1000) {
+ LOGD("Tracked match score is way too low (%.6f), aborting track.",
+ tracked_match_score_.value);
+ }
+
+ // Add an absurd amount of missed frames so that all heuristics will
+ // consider it a lost track.
+ num_consecutive_frames_below_threshold_ += 1000;
+
+ if (object_model_ != NULL) {
+ object_model_->TrackLost();
+ }
+ } else {
+ ++num_consecutive_frames_below_threshold_;
+ allowable_detection_distance_ *= 1.1f;
+ }
+}
+
+void TrackedObject::OnDetection(ObjectModelBase* const model,
+ const BoundingBox& detection_position,
+ const MatchScore match_score,
+ const int64 timestamp,
+ const ImageData& image_data) {
+ const float overlap = detection_position.PascalScore(last_known_position_);
+ if (overlap > kPositionOverlapThreshold) {
+ // If the position agreement with the current tracked position is good
+ // enough, lock all the current unlocked examples.
+ object_model_->TrackConfirmed();
+ num_consecutive_frames_below_threshold_ = 0;
+ }
+
+ // Before relocalizing, make sure the new proposed position is better than
+ // the existing position by a small amount to prevent thrashing.
+ if (match_score <= tracked_match_score_ + kMatchScoreBuffer) {
+ LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f",
+ match_score.value, tracked_match_score_.value,
+ kMatchScoreBuffer.value);
+ return;
+ }
+
+ LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to "
+ "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f",
+ last_known_position_.left_, last_known_position_.top_,
+ last_known_position_.GetWidth(), last_known_position_.GetHeight(),
+ detection_position.left_, detection_position.top_,
+ detection_position.GetWidth(), detection_position.GetHeight(),
+ match_score.value, tracked_match_score_.value);
+
+ if (overlap < kPositionOverlapThreshold) {
+ // The path might be good, it might be bad, but it's no longer a path
+ // since we're moving the box to a new position, so just nuke it from
+ // orbit to be safe.
+ object_model_->TrackLost();
+ }
+
+ object_model_ = model;
+
+ // Reset the last detected appearance.
+ InitNormalized(
+ *image_data.GetImage(), detection_position, &last_detection_thumbnail_);
+
+ num_consecutive_frames_below_threshold_ = 0;
+ last_detection_position_ = detection_position;
+
+ UpdatePosition(detection_position, timestamp, image_data, false);
+ allowable_detection_distance_ = Square(kInitialDistance);
+}
+
+} // namespace tf_tracking
diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.h b/tensorflow/examples/android/jni/object_tracking/tracked_object.h
new file mode 100644
index 0000000000..5580cd2b89
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.h
@@ -0,0 +1,191 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
+
+#ifdef __RENDER_OPENGL__
+#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
+#endif
+#include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
+
+namespace tf_tracking {
+
+// A TrackedObject is a specific instance of an ObjectModel, with a known
+// position in the world.
+// It provides the last known position and number of recent detection failures,
+// in addition to the more general appearance data associated with the object
+// class (which is in ObjectModel).
+// TODO(andrewharp): Make getters/setters follow styleguide.
+class TrackedObject {
+ public:
+ TrackedObject(const std::string& id,
+ const Image<uint8>& image,
+ const BoundingBox& bounding_box,
+ ObjectModelBase* const model);
+
+ ~TrackedObject();
+
+ void UpdatePosition(const BoundingBox& new_position,
+ const int64 timestamp,
+ const ImageData& image_data,
+ const bool authoritative);
+
+ // This method is called when the tracked object is detected at a
+ // given position, and allows the associated Model to grow and/or prune
+ // itself based on where the detection occurred.
+ void OnDetection(ObjectModelBase* const model,
+ const BoundingBox& detection_position,
+ const MatchScore match_score,
+ const int64 timestamp,
+ const ImageData& image_data);
+
+ // Called when there's no detection of the tracked object. This will cause
+ // a tracking failure after enough consecutive failures if the area under
+ // the current bounding box also doesn't meet a minimum correlation threshold
+ // with the model.
+ void OnDetectionFailure() {}
+
+ inline bool IsVisible() const {
+ return tracked_correlation_ >= kMinimumCorrelationForTracking ||
+ num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures;
+ }
+
+ inline float GetCorrelation() {
+ return tracked_correlation_;
+ }
+
+ inline MatchScore GetMatchScore() {
+ return tracked_match_score_;
+ }
+
+ inline BoundingBox GetPosition() const {
+ return last_known_position_;
+ }
+
+ inline BoundingBox GetLastDetectionPosition() const {
+ return last_detection_position_;
+ }
+
+ inline const ObjectModelBase* GetModel() const {
+ return object_model_;
+ }
+
+ inline const std::string& GetName() const {
+ return id_;
+ }
+
+ inline void Draw() const {
+#ifdef __RENDER_OPENGL__
+ if (tracked_correlation_ < kMinimumCorrelationForTracking) {
+ glColor4f(MAX(0.0f, -tracked_correlation_),
+ MAX(0.0f, tracked_correlation_),
+ 0.0f,
+ 1.0f);
+ } else {
+ glColor4f(MAX(0.0f, -tracked_correlation_),
+ MAX(0.0f, tracked_correlation_),
+ 1.0f,
+ 1.0f);
+ }
+
+ // Render the box itself.
+ BoundingBox temp_box(last_known_position_);
+ DrawBox(temp_box);
+
+ // Render a box inside this one (in case the actual box is hidden).
+ const float kBufferSize = 1.0f;
+ temp_box.left_ -= kBufferSize;
+ temp_box.top_ -= kBufferSize;
+ temp_box.right_ += kBufferSize;
+ temp_box.bottom_ += kBufferSize;
+ DrawBox(temp_box);
+
+ // Render one outside as well.
+ temp_box.left_ -= -2.0f * kBufferSize;
+ temp_box.top_ -= -2.0f * kBufferSize;
+ temp_box.right_ += -2.0f * kBufferSize;
+ temp_box.bottom_ += -2.0f * kBufferSize;
+ DrawBox(temp_box);
+#endif
+ }
+
+ // Get current object's num_consecutive_frames_below_threshold_.
+ inline int64 GetNumConsecutiveFramesBelowThreshold() {
+ return num_consecutive_frames_below_threshold_;
+ }
+
+ // Reset num_consecutive_frames_below_threshold_ to 0.
+ inline void resetNumConsecutiveFramesBelowThreshold() {
+ num_consecutive_frames_below_threshold_ = 0;
+ }
+
+ inline float GetAllowableDistanceSquared() const {
+ return allowable_detection_distance_;
+ }
+
+ private:
+ // The unique id used throughout the system to identify this
+ // tracked object.
+ const std::string id_;
+
+ // The last known position of the object.
+ BoundingBox last_known_position_;
+
+ // The position at which the object was last detected.
+ BoundingBox last_detection_position_;
+
+ // When the position was last computed.
+ int64 position_last_computed_time_;
+
+ // The object model this tracked object is representative of.
+ ObjectModelBase* object_model_;
+
+ Image<float> last_detection_thumbnail_;
+
+ Image<float> last_frame_thumbnail_;
+
+ // The correlation of the object model with the preview frame at its last
+ // tracked position.
+ float tracked_correlation_;
+
+ MatchScore tracked_match_score_;
+
+ // The number of consecutive frames that the tracked position for this object
+ // has been under the correlation threshold.
+ int num_consecutive_frames_below_threshold_;
+
+ float allowable_detection_distance_;
+
+ friend std::ostream& operator<<(std::ostream& stream,
+ const TrackedObject& tracked_object);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject);
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TrackedObject& tracked_object) {
+ stream << tracked_object.id_
+ << " " << tracked_object.last_known_position_
+ << " " << tracked_object.position_last_computed_time_
+ << " " << tracked_object.num_consecutive_frames_below_threshold_
+ << " " << tracked_object.object_model_
+ << " " << tracked_object.tracked_correlation_;
+ return stream;
+}
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/utils.h b/tensorflow/examples/android/jni/object_tracking/utils.h
new file mode 100644
index 0000000000..cbdfc408c6
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/utils.h
@@ -0,0 +1,386 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
+
+#include <math.h>
+#include <stdlib.h>
+#include <time.h>
+
+#include <cmath> // for std::abs(float)
+
+#ifndef HAVE_CLOCK_GETTIME
+// Use gettimeofday() instead of clock_gettime().
+#include <sys/time.h>
+#endif // HAVE_CLOCK_GETTIME
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+using namespace tensorflow;
+
+// TODO(andrewharp): clean up these macros to use the codebase standard.
+
+// A very small number, generally used as the tolerance for accumulated
+// floating point errors in bounds-checks.
+#define EPSILON 0.00001f
+
+#define SAFE_DELETE(pointer) {\
+ if ((pointer) != NULL) {\
+ LOGV("Safe deleting pointer: %s", #pointer);\
+ delete (pointer);\
+ (pointer) = NULL;\
+ } else {\
+ LOGV("Pointer already null: %s", #pointer);\
+ }\
+}
+
+
+#ifdef __GOOGLE__
+
+#define CHECK_ALWAYS(condition, format, ...) {\
+ CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
+}
+
+#define SCHECK(condition, format, ...) {\
+ DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\
+}
+
+#else
+
+#define CHECK_ALWAYS(condition, format, ...) {\
+ if (!(condition)) {\
+ LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\
+ abort();\
+ }\
+}
+
+#ifdef SANITY_CHECKS
+#define SCHECK(condition, format, ...) {\
+ CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\
+}
+#else
+#define SCHECK(condition, format, ...) {}
+#endif // SANITY_CHECKS
+
+#endif // __GOOGLE__
+
+
+#ifndef MAX
+#define MAX(a, b) (((a) > (b)) ? (a) : (b))
+#endif
+#ifndef MIN
+#define MIN(a, b) (((a) > (b)) ? (b) : (a))
+#endif
+
+
+
+inline static int64 CurrentThreadTimeNanos() {
+#ifdef HAVE_CLOCK_GETTIME
+ struct timespec tm;
+ clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm);
+ return tm.tv_sec * 1000000000LL + tm.tv_nsec;
+#else
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+ return tv.tv_sec * 1000000000LL + tv.tv_usec * 1000LL;
+#endif
+}
+
+
+inline static int64 CurrentRealTimeMillis() {
+#ifdef HAVE_CLOCK_GETTIME
+ struct timespec tm;
+ clock_gettime(CLOCK_MONOTONIC, &tm);
+ return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL;
+#else
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+ return tv.tv_sec * 1000LL + tv.tv_usec / 1000;
+#endif
+}
+
+
+template<typename T>
+inline static T Square(const T a) {
+ return a * a;
+}
+
+
+template<typename T>
+inline static T Clip(const T a, const T floor, const T ceil) {
+ SCHECK(ceil >= floor, "Bounds mismatch!");
+ return (a <= floor) ? floor : ((a >= ceil) ? ceil : a);
+}
+
+
+template<typename T>
+inline static int Floor(const T a) {
+ return static_cast<int>(a);
+}
+
+
+template<typename T>
+inline static int Ceil(const T a) {
+ return Floor(a) + 1;
+}
+
+
+template<typename T>
+inline static bool InRange(const T a, const T min, const T max) {
+ return (a >= min) && (a <= max);
+}
+
+
+inline static bool ValidIndex(const int a, const int max) {
+ return (a >= 0) && (a < max);
+}
+
+
+inline bool NearlyEqual(const float a, const float b, const float tolerance) {
+ return std::abs(a - b) < tolerance;
+}
+
+
+inline bool NearlyEqual(const float a, const float b) {
+ return NearlyEqual(a, b, EPSILON);
+}
+
+
+template<typename T>
+inline static int Round(const float a) {
+ // Round to the nearest integer: round up when the fractional part
+ // exceeds 0.5.
+ return static_cast<int>(
+ (a - static_cast<float>(floor(a)) > 0.5f) ? ceil(a) : floor(a));
+}
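+
+// Illustrative examples: Round<float>(2.4f) == 2 and Round<float>(2.6f) == 3.
+// (The template parameter is unused but must be supplied explicitly.)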
+
+
+template<typename T>
+inline static void Swap(T* const a, T* const b) {
+ // Cache out the VALUE of what's at a.
+ T tmp = *a;
+ *a = *b;
+
+ *b = tmp;
+}
+
+
+static inline float randf() {
+ return rand() / static_cast<float>(RAND_MAX);
+}
+
+static inline float randf(const float min_value, const float max_value) {
+ return randf() * (max_value - min_value) + min_value;
+}
+
+static inline uint16 RealToFixed115(const float real_number) {
+ SCHECK(InRange(real_number, 0.0f, 2048.0f),
+ "Value out of range! %.2f", real_number);
+
+ static const float kMult = 32.0f;
+ const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
+ return static_cast<uint16>(real_number * kMult + round_add);
+}
+
+static inline float FixedToFloat115(const uint16 fp_number) {
+ const float kDiv = 32.0f;
+ return (static_cast<float>(fp_number) / kDiv);
+}
+
+static inline int RealToFixed1616(const float real_number) {
+ static const float kMult = 65536.0f;
+ SCHECK(InRange(real_number, -kMult, kMult),
+ "Value out of range! %.2f", real_number);
+
+ const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f;
+ return static_cast<int>(real_number * kMult + round_add);
+}
+
+static inline float FixedToFloat1616(const int fp_number) {
+ const float kDiv = 65536.0f;
+ return (static_cast<float>(fp_number) / kDiv);
+}
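+
+// Illustrative round trip: RealToFixed1616(1.5f) == 0x18000
+// (1.5 * 65536 = 98304), and FixedToFloat1616(0x18000) == 1.5f.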
+
+// Produces numbers in the range [0, 2*M_PI] (rather than [-PI, PI]).
+template<typename T>
+inline T FastAtan2(const T y, const T x) {
+ static const T coeff_1 = (T)(M_PI / 4.0);
+ static const T coeff_2 = (T)(3.0 * coeff_1);
+ const T abs_y = fabs(y);
+ T angle;
+ if (x >= 0) {
+ T r = (x - abs_y) / (x + abs_y);
+ angle = coeff_1 - coeff_1 * r;
+ } else {
+ T r = (x + abs_y) / (abs_y - x);
+ angle = coeff_2 - coeff_1 * r;
+ }
+ static const T PI_2 = 2.0 * M_PI;
+ return y < 0 ? PI_2 - angle : angle;
+}
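+
+// Illustrative example: FastAtan2(1.0f, 1.0f) takes the x >= 0 branch with
+// r == 0 and returns exactly M_PI / 4; this is a fast linear approximation,
+// not a bit-exact replacement for atan2.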
+
+#define NELEMS(X) (sizeof(X) / sizeof(X[0]))
+
+namespace tf_tracking {
+
+#ifdef __ARM_NEON
+float ComputeMeanNeon(const float* const values, const int num_vals);
+
+float ComputeStdDevNeon(const float* const values, const int num_vals,
+ const float mean);
+
+float ComputeWeightedMeanNeon(const float* const values,
+ const float* const weights, const int num_vals);
+
+float ComputeCrossCorrelationNeon(const float* const values1,
+ const float* const values2,
+ const int num_vals);
+#endif
+
+inline float ComputeMeanCpu(const float* const values, const int num_vals) {
+ // Get mean.
+ float sum = values[0];
+ for (int i = 1; i < num_vals; ++i) {
+ sum += values[i];
+ }
+ return sum / static_cast<float>(num_vals);
+}
+
+
+inline float ComputeMean(const float* const values, const int num_vals) {
+ return
+#ifdef __ARM_NEON
+ (num_vals >= 8) ? ComputeMeanNeon(values, num_vals) :
+#endif
+ ComputeMeanCpu(values, num_vals);
+}
+
+
+inline float ComputeStdDevCpu(const float* const values,
+ const int num_vals,
+ const float mean) {
+ // Get Std dev.
+ float squared_sum = 0.0f;
+ for (int i = 0; i < num_vals; ++i) {
+ squared_sum += Square(values[i] - mean);
+ }
+ return sqrt(squared_sum / static_cast<float>(num_vals));
+}
+
+
+inline float ComputeStdDev(const float* const values,
+ const int num_vals,
+ const float mean) {
+ return
+#ifdef __ARM_NEON
+ (num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) :
+#endif
+ ComputeStdDevCpu(values, num_vals, mean);
+}
+
+
+// TODO(andrewharp): Accelerate with NEON.
+inline float ComputeWeightedMean(const float* const values,
+ const float* const weights,
+ const int num_vals) {
+ float sum = 0.0f;
+ float total_weight = 0.0f;
+ for (int i = 0; i < num_vals; ++i) {
+ sum += values[i] * weights[i];
+ total_weight += weights[i];
+ }
+ return sum / total_weight;
+}
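+
+// Illustrative example: values {1, 3} with weights {1, 3} give
+// (1*1 + 3*3) / (1 + 3) = 2.5.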
+
+
+inline float ComputeCrossCorrelationCpu(const float* const values1,
+ const float* const values2,
+ const int num_vals) {
+ float sxy = 0.0f;
+ for (int offset = 0; offset < num_vals; ++offset) {
+ sxy += values1[offset] * values2[offset];
+ }
+
+ const float cross_correlation = sxy / num_vals;
+
+ return cross_correlation;
+}
+
+
+inline float ComputeCrossCorrelation(const float* const values1,
+ const float* const values2,
+ const int num_vals) {
+ return
+#ifdef __ARM_NEON
+ (num_vals >= 8) ? ComputeCrossCorrelationNeon(values1, values2, num_vals)
+ :
+#endif
+ ComputeCrossCorrelationCpu(values1, values2, num_vals);
+}
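+
+// Note: this is the raw dot product divided by num_vals; it matches the
+// Pearson correlation only when both inputs have already been normalized
+// to zero mean and unit standard deviation (see NormalizeNumbers below).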
+
+
+inline void NormalizeNumbers(float* const values, const int num_vals) {
+ // Find the mean and then subtract so that the new mean is 0.0.
+ const float mean = ComputeMean(values, num_vals);
+ VLOG(2) << "Mean is " << mean;
+ float* curr_data = values;
+ for (int i = 0; i < num_vals; ++i) {
+ *curr_data -= mean;
+ curr_data++;
+ }
+
+ // Now divide by the std deviation so the new standard deviation is 1.0.
+ // The numbers might all be identical (and thus shifted to 0.0 now),
+ // so only scale by the standard deviation if this is not the case.
+ const float std_dev = ComputeStdDev(values, num_vals, 0.0f);
+ if (std_dev > 0.0f) {
+ VLOG(2) << "Std dev is " << std_dev;
+ curr_data = values;
+ for (int i = 0; i < num_vals; ++i) {
+ *curr_data /= std_dev;
+ curr_data++;
+ }
+ }
+}
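+
+// Illustrative example: {0, 2} has mean 1 and becomes {-1, 1}; the std dev
+// of {-1, 1} is 1, so the scaling step then leaves the values unchanged.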
+
+
+// Returns the determinant of a 2x2 matrix.
+template<class T>
+inline T FindDeterminant2x2(const T* const a) {
+ // Determinant: (ad - bc)
+ return a[0] * a[3] - a[1] * a[2];
+}
+
+
+// Finds the inverse of a 2x2 matrix.
+// Returns true upon success, false if the matrix is not invertible.
+template<class T>
+inline bool Invert2x2(const T* const a, float* const a_inv) {
+ const float det = static_cast<float>(FindDeterminant2x2(a));
+ if (fabs(det) < EPSILON) {
+ return false;
+ }
+ const float inv_det = 1.0f / det;
+
+ a_inv[0] = inv_det * static_cast<float>(a[3]); // d
+ a_inv[1] = inv_det * static_cast<float>(-a[1]); // -b
+ a_inv[2] = inv_det * static_cast<float>(-a[2]); // -c
+ a_inv[3] = inv_det * static_cast<float>(a[0]); // a
+
+ return true;
+}
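+
+// Illustrative example: a = {2, 1, 1, 2} has determinant 3, so a_inv is
+// (1/3) * {2, -1, -1, 2}.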
+
+} // namespace tf_tracking
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/utils_neon.cc b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc
new file mode 100755
index 0000000000..5a5250e32e
--- /dev/null
+++ b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc
@@ -0,0 +1,151 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// NEON implementations of Image methods for compatible devices. Control
+// should never enter this compilation unit on incompatible devices.
+
+#ifdef __ARM_NEON
+
+#include <arm_neon.h>
+
+#include "tensorflow/examples/android/jni/object_tracking/geom.h"
+#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
+#include "tensorflow/examples/android/jni/object_tracking/image.h"
+#include "tensorflow/examples/android/jni/object_tracking/utils.h"
+
+namespace tf_tracking {
+
+inline static float GetSum(const float32x4_t& values) {
+  // Use a local (non-static) buffer; a static buffer here would not be
+  // thread-safe.
+  float32_t summed_values[4];
+ vst1q_f32(summed_values, values);
+ return summed_values[0]
+ + summed_values[1]
+ + summed_values[2]
+ + summed_values[3];
+}
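+
+// Note: an equivalent horizontal add that avoids the store/reload round trip
+// (an illustrative alternative, not the code path used here) would be:
+//   float32x2_t sum2 = vadd_f32(vget_low_f32(values), vget_high_f32(values));
+//   sum2 = vpadd_f32(sum2, sum2);
+//   return vget_lane_f32(sum2, 0);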
+
+
+float ComputeMeanNeon(const float* const values, const int num_vals) {
+ SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
+
+ const float32_t* const arm_vals = (const float32_t* const) values;
+ float32x4_t accum = vdupq_n_f32(0.0f);
+
+ int offset = 0;
+ for (; offset <= num_vals - 4; offset += 4) {
+ accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset]));
+ }
+
+ // Pull the accumulated values into a single variable.
+ float sum = GetSum(accum);
+
+ // Get the remaining 1 to 3 values.
+ for (; offset < num_vals; ++offset) {
+ sum += values[offset];
+ }
+
+ const float mean_neon = sum / static_cast<float>(num_vals);
+
+#ifdef SANITY_CHECKS
+ const float mean_cpu = ComputeMeanCpu(values, num_vals);
+ SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals),
+ "Neon mismatch with CPU mean! %.10f vs %.10f",
+ mean_neon, mean_cpu);
+#endif
+
+ return mean_neon;
+}
+
+
+float ComputeStdDevNeon(const float* const values,
+ const int num_vals, const float mean) {
+ SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
+
+ const float32_t* const arm_vals = (const float32_t* const) values;
+ const float32x4_t mean_vec = vdupq_n_f32(-mean);
+
+ float32x4_t accum = vdupq_n_f32(0.0f);
+
+ int offset = 0;
+ for (; offset <= num_vals - 4; offset += 4) {
+ const float32x4_t deltas =
+ vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset]));
+
+ accum = vmlaq_f32(accum, deltas, deltas);
+ }
+
+ // Pull the accumulated values into a single variable.
+ float squared_sum = GetSum(accum);
+
+ // Get the remaining 1 to 3 values.
+ for (; offset < num_vals; ++offset) {
+ squared_sum += Square(values[offset] - mean);
+ }
+
+ const float std_dev_neon = sqrt(squared_sum / static_cast<float>(num_vals));
+
+#ifdef SANITY_CHECKS
+ const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean);
+ SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals),
+ "Neon mismatch with CPU std dev! %.10f vs %.10f",
+ std_dev_neon, std_dev_cpu);
+#endif
+
+ return std_dev_neon;
+}
+
+
+float ComputeCrossCorrelationNeon(const float* const values1,
+ const float* const values2,
+ const int num_vals) {
+ SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals);
+
+ const float32_t* const arm_vals1 = (const float32_t* const) values1;
+ const float32_t* const arm_vals2 = (const float32_t* const) values2;
+
+ float32x4_t accum = vdupq_n_f32(0.0f);
+
+ int offset = 0;
+ for (; offset <= num_vals - 4; offset += 4) {
+ accum = vmlaq_f32(accum,
+ vld1q_f32(&arm_vals1[offset]),
+ vld1q_f32(&arm_vals2[offset]));
+ }
+
+ // Pull the accumulated values into a single variable.
+ float sxy = GetSum(accum);
+
+ // Get the remaining 1 to 3 values.
+ for (; offset < num_vals; ++offset) {
+ sxy += values1[offset] * values2[offset];
+ }
+
+ const float cross_correlation_neon = sxy / num_vals;
+
+#ifdef SANITY_CHECKS
+ const float cross_correlation_cpu =
+ ComputeCrossCorrelationCpu(values1, values2, num_vals);
+ SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu,
+ EPSILON * num_vals),
+ "Neon mismatch with CPU cross correlation! %.10f vs %.10f",
+ cross_correlation_neon, cross_correlation_cpu);
+#endif
+
+ return cross_correlation_neon;
+}
+
+} // namespace tf_tracking
+
+#endif // __ARM_NEON
diff --git a/tensorflow/examples/android/proto/box_coder.proto b/tensorflow/examples/android/proto/box_coder.proto
new file mode 100644
index 0000000000..8576294110
--- /dev/null
+++ b/tensorflow/examples/android/proto/box_coder.proto
@@ -0,0 +1,42 @@
+syntax = "proto2";
+
+package org_tensorflow_demo;
+
+// Prior for a single feature (like minimum x coordinate, width, area, etc.)
+message BoxCoderPrior {
+ optional float mean = 1 [default = 0.0];
+ optional float stddev = 2 [default = 1.0];
+};
+
+// Box encoding/decoding configuration for a single box.
+message BoxCoderOptions {
+  // The number of priors must match the number of values used to encode
+  // boxes, which is derived from the use_... flags below.
+ repeated BoxCoderPrior priors = 1;
+
+ // Minimum/maximum X/Y of the four corners are used as features.
+ // Order: MinX, MinY, MaxX, MaxY.
+ // Number of values: 4.
+ optional bool use_corners = 2 [default = true];
+
+ // Width and height of the box in this order.
+ // Number of values: 2.
+ optional bool use_width_height = 3 [default = false];
+
+ // Coordinates of the center of the box.
+ // Order: X, Y.
+ // Number of values: 2.
+ optional bool use_center = 4 [default = false];
+
+ // Area of the box.
+ // Number of values: 1.
+ optional bool use_area = 5 [default = false];
+};
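+
+// Example BoxCoderOptions in text format (illustrative values only). The
+// default corner encoding produces 4 values per box, so exactly 4 priors
+// are required:
+//   priors { mean: 0.1 stddev: 0.2 }
+//   priors { mean: 0.1 stddev: 0.2 }
+//   priors { mean: 0.9 stddev: 0.2 }
+//   priors { mean: 0.9 stddev: 0.2 }
+//   use_corners: true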
+
+// Options for MultiBoxCoder, which is an encoder/decoder for a fixed number of
+// boxes.
+// A list of BoxCoderOptions that allows for storing multiple box coder options
+// in a single file.
+message MultiBoxCoderOptions {
+ repeated BoxCoderOptions box_coder = 1;
+};
diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml
new file mode 100644
index 0000000000..674f25785a
--- /dev/null
+++ b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent">
+
+ <org.tensorflow.demo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"/>
+
+ <org.tensorflow.demo.OverlayView
+ android:id="@+id/overlay"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"/>
+
+</FrameLayout>
diff --git a/tensorflow/examples/android/res/values/base-strings.xml b/tensorflow/examples/android/res/values/base-strings.xml
index 93cfe0dac2..f6c57d5030 100644
--- a/tensorflow/examples/android/res/values/base-strings.xml
+++ b/tensorflow/examples/android/res/values/base-strings.xml
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
- Copyright 2013 The TensorFlow Authors. All Rights Reserved.
+ Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,5 +17,6 @@
<resources>
<string name="app_name">TensorFlow Demo</string>
- <string name="activity_name_classification">TF Classification</string>
+ <string name="activity_name_classification">TF Classify</string>
+ <string name="activity_name_detection">TF Detect</string>
</resources>
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java
index e498c9e28f..2f16ded6c2 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java
@@ -17,7 +17,6 @@ package org.tensorflow.demo;
import android.graphics.Bitmap;
import android.graphics.RectF;
-
import java.util.List;
/**
@@ -44,10 +43,8 @@ public interface Classifier {
*/
private final Float confidence;
- /**
- * Optional location within the source image for the location of the recognized object.
- */
- private final RectF location;
+ /** Optional location within the source image for the location of the recognized object. */
+ private RectF location;
public Recognition(
final String id, final String title, final Float confidence, final RectF location) {
@@ -73,6 +70,10 @@ public interface Classifier {
return new RectF(location);
}
+ public void setLocation(RectF location) {
+ this.location = location;
+ }
+
@Override
public String toString() {
String resultString = "";
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
new file mode 100644
index 0000000000..d75136485a
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
@@ -0,0 +1,317 @@
+/*
+ * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tensorflow.demo;
+
+import android.graphics.Bitmap;
+import android.graphics.Bitmap.Config;
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Paint.Style;
+import android.graphics.RectF;
+import android.media.Image;
+import android.media.Image.Plane;
+import android.media.ImageReader;
+import android.media.ImageReader.OnImageAvailableListener;
+import android.os.SystemClock;
+import android.os.Trace;
+import android.util.Size;
+import android.util.TypedValue;
+import android.view.Display;
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Vector;
+import org.tensorflow.demo.OverlayView.DrawCallback;
+import org.tensorflow.demo.env.BorderedText;
+import org.tensorflow.demo.env.ImageUtils;
+import org.tensorflow.demo.env.Logger;
+import org.tensorflow.demo.tracking.MultiBoxTracker;
+
+/**
+ * An activity that uses a TensorFlowMultiboxDetector and ObjectTracker to detect and then track
+ * objects.
+ */
+public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
+ private static final Logger LOGGER = new Logger();
+
+ private static final int NUM_LOCATIONS = 784;
+ private static final int INPUT_SIZE = 224;
+ private static final int IMAGE_MEAN = 128;
+ private static final float IMAGE_STD = 128;
+ private static final String INPUT_NAME = "ResizeBilinear";
+ private static final String OUTPUT_NAMES = "output_locations/Reshape,output_scores/Reshape";
+
+ private static final String MODEL_FILE = "file:///android_asset/multibox_model.pb";
+ private static final String LOCATION_FILE = "file:///android_asset/multibox_location_priors.pb";
+
+ // Minimum detection confidence to track a detection.
+ private static final float MINIMUM_CONFIDENCE = 0.1f;
+
+ private static final boolean SAVE_PREVIEW_BITMAP = false;
+
+ private static final boolean MAINTAIN_ASPECT = false;
+
+ private static final float TEXT_SIZE_DIP = 18;
+
+ private Integer sensorOrientation;
+
+ private TensorFlowMultiBoxDetector detector;
+
+ private int previewWidth = 0;
+ private int previewHeight = 0;
+ private byte[][] yuvBytes;
+ private int[] rgbBytes = null;
+ private Bitmap rgbFrameBitmap = null;
+ private Bitmap croppedBitmap = null;
+
+ private boolean computing = false;
+
+ private long timestamp = 0;
+
+ private Matrix frameToCropTransform;
+ private Matrix cropToFrameTransform;
+
+ private Bitmap cropCopyBitmap;
+
+ private MultiBoxTracker tracker;
+
+ private byte[] luminance;
+
+ private BorderedText borderedText;
+
+ private long lastProcessingTimeMs;
+
+ @Override
+ public void onPreviewSizeChosen(final Size size, final int rotation) {
+ final float textSizePx =
+ TypedValue.applyDimension(
+ TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
+ borderedText = new BorderedText(textSizePx);
+
+ tracker = new MultiBoxTracker(getResources().getDisplayMetrics());
+
+ detector = new TensorFlowMultiBoxDetector();
+ try {
+ detector.initializeTensorFlow(
+ getAssets(),
+ MODEL_FILE,
+ LOCATION_FILE,
+ NUM_LOCATIONS,
+ INPUT_SIZE,
+ IMAGE_MEAN,
+ IMAGE_STD,
+ INPUT_NAME,
+ OUTPUT_NAMES);
+ } catch (final IOException e) {
+ LOGGER.e(e, "Exception!");
+ }
+
+ previewWidth = size.getWidth();
+ previewHeight = size.getHeight();
+
+ final Display display = getWindowManager().getDefaultDisplay();
+ final int screenOrientation = display.getRotation();
+
+ LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
+
+ sensorOrientation = rotation + screenOrientation;
+
+ LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
+ rgbBytes = new int[previewWidth * previewHeight];
+ rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
+ croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
+
+ frameToCropTransform =
+ ImageUtils.getTransformationMatrix(
+ previewWidth, previewHeight,
+ INPUT_SIZE, INPUT_SIZE,
+ sensorOrientation, MAINTAIN_ASPECT);
+
+ cropToFrameTransform = new Matrix();
+ frameToCropTransform.invert(cropToFrameTransform);
+ yuvBytes = new byte[3][];
+
+ addCallback(
+ new DrawCallback() {
+ @Override
+ public void drawCallback(final Canvas canvas) {
+ final Bitmap copy = cropCopyBitmap;
+
+ tracker.draw(canvas);
+
+ if (!isDebug()) {
+ return;
+ }
+
+ tracker.drawDebug(canvas);
+
+ if (copy != null) {
+ final Matrix matrix = new Matrix();
+ final float scaleFactor = 2;
+ matrix.postScale(scaleFactor, scaleFactor);
+ matrix.postTranslate(
+ canvas.getWidth() - copy.getWidth() * scaleFactor,
+ canvas.getHeight() - copy.getHeight() * scaleFactor);
+ canvas.drawBitmap(copy, matrix, new Paint());
+
+ final Vector<String> lines = new Vector<String>();
+ lines.add("Frame: " + previewWidth + "x" + previewHeight);
+ lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
+ lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
+ lines.add("Rotation: " + sensorOrientation);
+ lines.add("Inference time: " + lastProcessingTimeMs + "ms");
+
+ int lineNum = 0;
+ for (final String line : lines) {
+ borderedText.drawText(
+ canvas,
+ 10,
+ canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum,
+ line);
+ ++lineNum;
+ }
+ }
+ }
+ });
+ }
+
+ @Override
+ public void onImageAvailable(final ImageReader reader) {
+ Image image = null;
+
+ ++timestamp;
+ final long currTimestamp = timestamp;
+
+ try {
+ image = reader.acquireLatestImage();
+
+ if (image == null) {
+ return;
+ }
+
+ Trace.beginSection("imageAvailable");
+
+ final Plane[] planes = image.getPlanes();
+ fillBytes(planes, yuvBytes);
+
+ tracker.onFrame(
+ previewWidth,
+ previewHeight,
+ planes[0].getRowStride(),
+ sensorOrientation,
+ yuvBytes[0],
+ timestamp);
+
+ requestRender();
+
+ // No mutex needed as this method is not reentrant.
+ if (computing) {
+ image.close();
+ return;
+ }
+ computing = true;
+
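+      // YUV_420_888 planes may be row-padded, so pass the actual row and
+      // pixel strides to the converter instead of assuming the image width.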
+ final int yRowStride = planes[0].getRowStride();
+ final int uvRowStride = planes[1].getRowStride();
+ final int uvPixelStride = planes[1].getPixelStride();
+ ImageUtils.convertYUV420ToARGB8888(
+ yuvBytes[0],
+ yuvBytes[1],
+ yuvBytes[2],
+ rgbBytes,
+ previewWidth,
+ previewHeight,
+ yRowStride,
+ uvRowStride,
+ uvPixelStride,
+ false);
+
+ image.close();
+ } catch (final Exception e) {
+ if (image != null) {
+ image.close();
+ }
+ LOGGER.e(e, "Exception!");
+ Trace.endSection();
+ return;
+ }
+
+ rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
+ final Canvas canvas = new Canvas(croppedBitmap);
+ canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
+
+ // For examining the actual TF input.
+ if (SAVE_PREVIEW_BITMAP) {
+ ImageUtils.saveBitmap(croppedBitmap);
+ }
+
+ if (luminance == null) {
+ luminance = new byte[yuvBytes[0].length];
+ }
+ System.arraycopy(yuvBytes[0], 0, luminance, 0, luminance.length);
+
+ runInBackground(
+ new Runnable() {
+ @Override
+ public void run() {
+ final long startTime = SystemClock.uptimeMillis();
+ final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
+ lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
+
+ cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
+ final Canvas canvas = new Canvas(cropCopyBitmap);
+ final Paint paint = new Paint();
+ paint.setColor(Color.RED);
+ paint.setStyle(Style.STROKE);
+ paint.setStrokeWidth(2.0f);
+
+ final List<Classifier.Recognition> mappedRecognitions =
+ new LinkedList<Classifier.Recognition>();
+
+ for (final Classifier.Recognition result : results) {
+ final RectF location = result.getLocation();
+ if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE) {
+ canvas.drawRect(location, paint);
+
+ cropToFrameTransform.mapRect(location);
+ result.setLocation(location);
+ mappedRecognitions.add(result);
+ }
+ }
+
+ tracker.trackResults(mappedRecognitions, luminance, currTimestamp);
+
+ requestRender();
+ computing = false;
+ }
+ });
+
+ Trace.endSection();
+ }
+
+ @Override
+ protected int getLayoutId() {
+ return R.layout.camera_connection_fragment_tracking;
+ }
+
+ @Override
+ protected int getDesiredPreviewFrameSize() {
+ return INPUT_SIZE;
+ }
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
new file mode 100644
index 0000000000..66e25304d3
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
@@ -0,0 +1,218 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo;
+
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+import android.graphics.RectF;
+import android.os.Trace;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.PriorityQueue;
+import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
+import org.tensorflow.demo.env.Logger;
+
+/**
+ * A detector for general purpose object detection as described in Scalable Object Detection using
+ * Deep Neural Networks (https://arxiv.org/abs/1312.2249).
+ */
+public class TensorFlowMultiBoxDetector implements Classifier {
+ private static final Logger LOGGER = new Logger();
+
+ static {
+ System.loadLibrary("tensorflow_demo");
+ }
+
+  // Only return up to this many results.
+ private static final int MAX_RESULTS = Integer.MAX_VALUE;
+
+ // Config values.
+ private String inputName;
+ private int inputSize;
+ private int imageMean;
+ private float imageStd;
+
+ // Pre-allocated buffers.
+ private int[] intValues;
+ private float[] floatValues;
+ private float[] outputLocations;
+ private float[] outputScores;
+ private String[] outputNames;
+ private int numLocations;
+
+ private TensorFlowInferenceInterface inferenceInterface;
+
+ private float[] boxPriors;
+
+ /**
+ * Initializes a native TensorFlow session for classifying images.
+ *
+ * @param assetManager The asset manager to be used to load assets.
+ * @param modelFilename The filepath of the model GraphDef protocol buffer.
+   * @param locationFilename The filepath of the protocol buffer containing the box location
+   *     priors (means and stddevs) used to decode the outputs.
+   * @param numLocations The number of box locations predicted by the model.
+   * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
+   * @param imageMean The assumed mean of the image values.
+   * @param imageStd The assumed std of the image values.
+   * @param inputName The label of the image input node.
+   * @param outputName Comma-separated labels of the output nodes (locations, then scores).
+ * @return The native return value, 0 indicating success.
+ * @throws IOException
+ */
+ public int initializeTensorFlow(
+ final AssetManager assetManager,
+ final String modelFilename,
+ final String locationFilename,
+ final int numLocations,
+ final int inputSize,
+ final int imageMean,
+ final float imageStd,
+ final String inputName,
+ final String outputName)
+ throws IOException {
+ this.inputName = inputName;
+ this.inputSize = inputSize;
+ this.imageMean = imageMean;
+ this.imageStd = imageStd;
+ this.numLocations = numLocations;
+
+ this.boxPriors = new float[numLocations * 8];
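+    // Layout (see decodeLocationsEncoding): 8 floats per location, i.e. a
+    // {mean, stddev} pair for each of the 4 box coordinates.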
+
+ loadCoderOptions(assetManager, locationFilename, boxPriors);
+
+ // Pre-allocate buffers.
+ outputNames = outputName.split(",");
+ intValues = new int[inputSize * inputSize];
+ floatValues = new float[inputSize * inputSize * 3];
+ outputScores = new float[numLocations];
+ outputLocations = new float[numLocations * 4];
+
+ inferenceInterface = new TensorFlowInferenceInterface();
+
+ return inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
+ }
+
+ // Load BoxCoderOptions from native code.
+ private native void loadCoderOptions(
+ AssetManager assetManager, String locationFilename, float[] boxPriors);
+
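+  // Decodes each coordinate from its encoding using the matching prior:
+  //   location = encoding * stddev + mean, clamped to [0, 1].
+  // E.g. with illustrative numbers mean = 0.5 and stddev = 0.1, an encoding
+  // of 1.2 decodes to 0.5 + 1.2 * 0.1 = 0.62.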
+ private float[] decodeLocationsEncoding(final float[] locationEncoding) {
+ final float[] locations = new float[locationEncoding.length];
+ boolean nonZero = false;
+ for (int i = 0; i < numLocations; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ final float currEncoding = locationEncoding[4 * i + j];
+ nonZero = nonZero || currEncoding != 0.0f;
+
+ final float mean = boxPriors[i * 8 + j * 2];
+ final float stdDev = boxPriors[i * 8 + j * 2 + 1];
+ float currentLocation = currEncoding * stdDev + mean;
+ currentLocation = Math.max(currentLocation, 0.0f);
+ currentLocation = Math.min(currentLocation, 1.0f);
+ locations[4 * i + j] = currentLocation;
+ }
+ }
+
+ if (!nonZero) {
+ LOGGER.w("No non-zero encodings; check log for inference errors.");
+ }
+ return locations;
+ }
+
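+  // Converts raw logits to scores with the logistic sigmoid 1 / (1 + e^-x),
+  // so a logit of 0.0 maps to a score of 0.5.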
+ private float[] decodeScoresEncoding(final float[] scoresEncoding) {
+ final float[] scores = new float[scoresEncoding.length];
+ for (int i = 0; i < scoresEncoding.length; ++i) {
+ scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i])));
+ }
+ return scores;
+ }
+
+ @Override
+ public List<Recognition> recognizeImage(final Bitmap bitmap) {
+ // Log this method so that it can be analyzed with systrace.
+ Trace.beginSection("recognizeImage");
+
+ Trace.beginSection("preprocessBitmap");
+ // Preprocess the image data from 0-255 int to normalized float based
+ // on the provided parameters.
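+    // getPixels() returns packed ARGB ints, so below (value & 0xFF) extracts
+    // blue, (value >> 8) & 0xFF green, and (value >> 16) & 0xFF red.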
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+
+ for (int i = 0; i < intValues.length; ++i) {
+ floatValues[i * 3 + 0] = ((intValues[i] & 0xFF) - imageMean) / imageStd;
+ floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd;
+ floatValues[i * 3 + 2] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd;
+ }
+ Trace.endSection(); // preprocessBitmap
+
+ // Copy the input data into TensorFlow.
+ Trace.beginSection("fillNodeFloat");
+ inferenceInterface.fillNodeFloat(
+ inputName, new int[] {1, inputSize, inputSize, 3}, floatValues);
+ Trace.endSection();
+
+ // Run the inference call.
+ Trace.beginSection("runInference");
+ inferenceInterface.runInference(outputNames);
+ Trace.endSection();
+
+ // Copy the output Tensor back into the output array.
+ Trace.beginSection("readNodeFloat");
+ final float[] outputScoresEncoding = new float[numLocations];
+ final float[] outputLocationsEncoding = new float[numLocations * 4];
+ inferenceInterface.readNodeFloat(outputNames[0], outputLocationsEncoding);
+ inferenceInterface.readNodeFloat(outputNames[1], outputScoresEncoding);
+ Trace.endSection();
+
+ outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
+ outputScores = decodeScoresEncoding(outputScoresEncoding);
+
+ // Find the best detections.
+ final PriorityQueue<Recognition> pq =
+ new PriorityQueue<Recognition>(
+ 1,
+ new Comparator<Recognition>() {
+ @Override
+ public int compare(final Recognition lhs, final Recognition rhs) {
+ // Intentionally reversed to put high confidence at the head of the queue.
+ return Float.compare(rhs.getConfidence(), lhs.getConfidence());
+ }
+ });
+
+ // Scale them back to the input size.
+ for (int i = 0; i < outputScores.length; ++i) {
+ final RectF detection =
+ new RectF(
+ outputLocations[4 * i] * inputSize,
+ outputLocations[4 * i + 1] * inputSize,
+ outputLocations[4 * i + 2] * inputSize,
+ outputLocations[4 * i + 3] * inputSize);
+ pq.add(new Recognition("" + i, "" + i, outputScores[i], detection));
+ }
+
+    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
+    // Evaluate the limit once up front: pq.size() shrinks with every poll(),
+    // so re-checking it in the loop condition would drop half the results.
+    final int numResults = Math.min(pq.size(), MAX_RESULTS);
+    for (int i = 0; i < numResults; ++i) {
+      recognitions.add(pq.poll());
+    }
+ Trace.endSection(); // "recognizeImage"
+ return recognitions;
+ }
+
+ @Override
+ public void close() {
+ inferenceInterface.close();
+ }
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
new file mode 100644
index 0000000000..24e5cb57df
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
@@ -0,0 +1,381 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo.tracking;
+
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Paint.Cap;
+import android.graphics.Paint.Join;
+import android.graphics.Paint.Style;
+import android.graphics.RectF;
+import android.util.DisplayMetrics;
+import android.util.Pair;
+import android.util.TypedValue;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+
+import org.tensorflow.demo.Classifier.Recognition;
+import org.tensorflow.demo.env.BorderedText;
+import org.tensorflow.demo.env.ImageUtils;
+import org.tensorflow.demo.env.Logger;
+
+/**
+ * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing
+ * objects to new detections.
+ */
+public class MultiBoxTracker {
+ private final Logger logger = new Logger();
+
+ private static final float TEXT_SIZE_DIP = 18;
+
+  // Maximum fraction of a box that can be overlapped by another box at detection time. Otherwise
+  // the lower scored box (new or old) will be removed.
+ private static final float MAX_OVERLAP = 0.35f;
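+  // E.g. (illustrative): a new 100x100 detection overlapping an existing
+  // 100x100 tracked box over a 50x80 region gives 4000 / 10000 = 0.4, which
+  // exceeds MAX_OVERLAP, so one of the two boxes must be dropped.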
+
+ private static final float MIN_SIZE = 16.0f;
+
+ // Allow replacement of the tracked box with new results if
+ // correlation has dropped below this level.
+ private static final float MARGINAL_CORRELATION = 0.75f;
+
+ // Consider object to be lost if correlation falls below this threshold.
+ private static final float MIN_CORRELATION = 0.3f;
+
+ private static final int[] COLORS = {
+ Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA
+ };
+
+ private final Queue<Integer> availableColors = new LinkedList<Integer>();
+
+ public ObjectTracker objectTracker;
+
+ final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>();
+
+ private static class TrackedRecognition {
+ ObjectTracker.TrackedObject trackedObject;
+ float detectionConfidence;
+ int color;
+ }
+
+ private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>();
+
+ private final Paint boxPaint = new Paint();
+
+ private final float textSizePx;
+ private final BorderedText borderedText;
+
+ private Matrix frameToCanvasMatrix;
+
+ private int frameWidth;
+ private int frameHeight;
+
+ private int sensorOrientation;
+
+ public MultiBoxTracker(final DisplayMetrics metrics) {
+ for (final int color : COLORS) {
+ availableColors.add(color);
+ }
+
+ boxPaint.setColor(Color.RED);
+ boxPaint.setStyle(Style.STROKE);
+ boxPaint.setStrokeWidth(12.0f);
+ boxPaint.setStrokeCap(Cap.ROUND);
+ boxPaint.setStrokeJoin(Join.ROUND);
+ boxPaint.setStrokeMiter(100);
+
+ textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics);
+ borderedText = new BorderedText(textSizePx);
+ }
+
+ private Matrix getFrameToCanvasMatrix() {
+ return frameToCanvasMatrix;
+ }
+
+ public synchronized void drawDebug(final Canvas canvas) {
+ final Paint textPaint = new Paint();
+ textPaint.setColor(Color.WHITE);
+ textPaint.setTextSize(60.0f);
+
+ final Paint boxPaint = new Paint();
+ boxPaint.setColor(Color.RED);
+ boxPaint.setAlpha(200);
+ boxPaint.setStyle(Style.STROKE);
+
+ for (final Pair<Float, RectF> detection : screenRects) {
+ final RectF rect = detection.second;
+ canvas.drawRect(rect, boxPaint);
+ canvas.drawText("" + detection.first, rect.left, rect.top, textPaint);
+ borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first);
+ }
+
+ if (objectTracker == null) {
+ return;
+ }
+
+ // Draw correlations.
+ for (final TrackedRecognition recognition : trackedObjects) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+
+ final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
+
+ if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
+ final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation());
+ borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString);
+ }
+ }
+
+ final Matrix matrix = getFrameToCanvasMatrix();
+ objectTracker.drawDebug(canvas, matrix);
+ }
+
+ public synchronized void trackResults(
+ final List<Recognition> results, final byte[] frame, final long timestamp) {
+ logger.i("Processing %d results from %d", results.size(), timestamp);
+ processResults(timestamp, results, frame);
+ }
+
+ public synchronized void draw(final Canvas canvas) {
+ if (objectTracker == null) {
+ return;
+ }
+
+ // TODO(andrewharp): This may not work for non-90 deg rotations.
+ final float multiplier =
+ Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth);
+ frameToCanvasMatrix =
+ ImageUtils.getTransformationMatrix(
+ frameWidth,
+ frameHeight,
+ (int) (multiplier * frameHeight),
+ (int) (multiplier * frameWidth),
+ sensorOrientation,
+ false);
+
+ for (final TrackedRecognition recognition : trackedObjects) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+
+ final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame();
+
+ if (getFrameToCanvasMatrix().mapRect(trackedPos)) {
+ boxPaint.setColor(recognition.color);
+
+ final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f;
+ canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint);
+
+ final String labelString = String.format("%.2f", recognition.detectionConfidence);
+ borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString);
+ }
+ }
+ }
+
+ public synchronized void onFrame(
+ final int w,
+ final int h,
+ final int rowStride,
+      final int sensorOrientation,
+ final byte[] frame,
+ final long timestamp) {
+ if (objectTracker == null) {
+ ObjectTracker.clearInstance();
+
+ logger.i("Initializing ObjectTracker: %dx%d", w, h);
+ objectTracker = ObjectTracker.getInstance(w, h, rowStride, true);
+ frameWidth = w;
+ frameHeight = h;
+      this.sensorOrientation = sensorOrientation;
+ }
+
+ objectTracker.nextFrame(frame, null, timestamp, null, true);
+
+ // Clean up any objects not worth tracking any more.
+ final LinkedList<TrackedRecognition> copyList =
+ new LinkedList<TrackedRecognition>(trackedObjects);
+ for (final TrackedRecognition recognition : copyList) {
+ final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject;
+ final float correlation = trackedObject.getCurrentCorrelation();
+ if (correlation < MIN_CORRELATION) {
+ logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation);
+ trackedObject.stopTracking();
+ trackedObjects.remove(recognition);
+
+ availableColors.add(recognition.color);
+ }
+ }
+ }
+
+ private void processResults(
+ final long timestamp, final List<Recognition> results, final byte[] originalFrame) {
+ final List<Pair<Float, RectF>> rectsToTrack = new LinkedList<Pair<Float, RectF>>();
+
+ screenRects.clear();
+ final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix());
+
+ for (final Recognition result : results) {
+ if (result.getLocation() == null) {
+ continue;
+ }
+ final RectF detectionFrameRect = new RectF(result.getLocation());
+
+ final RectF detectionScreenRect = new RectF();
+ rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect);
+
+ logger.v(
+        "Result! Frame: " + result.getLocation() + " mapped to screen: " + detectionScreenRect);
+
+ screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect));
+
+ if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) {
+ logger.w("Degenerate rectangle! " + detectionFrameRect);
+ continue;
+ }
+
+ rectsToTrack.add(new Pair<Float, RectF>(result.getConfidence(), detectionFrameRect));
+ }
+
+ if (rectsToTrack.isEmpty()) {
+ logger.v("Nothing to track, aborting.");
+ return;
+ }
+
+ if (objectTracker == null) {
+ logger.w("No ObjectTracker, can't track anything!");
+ return;
+ }
+
+ logger.i("%d rects to track", rectsToTrack.size());
+ for (final Pair<Float, RectF> potential : rectsToTrack) {
+ handleDetection(originalFrame, timestamp, potential);
+ }
+ }
+
+ private void handleDetection(
+ final byte[] frameCopy, final long timestamp, final Pair<Float, RectF> potential) {
+ final ObjectTracker.TrackedObject potentialObject =
+ objectTracker.trackObject(potential.second, timestamp, frameCopy);
+
+ final float potentialCorrelation = potentialObject.getCurrentCorrelation();
+ logger.v(
+ "Tracked object went from %s to %s with correlation %.2f",
+ potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation);
+
+ if (potentialCorrelation < MARGINAL_CORRELATION) {
+ logger.v("Correlation too low to begin tracking %s.", potentialObject);
+ potentialObject.stopTracking();
+ return;
+ }
+
+ final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>();
+
+ float maxIntersect = 0.0f;
+
+ // This is the current tracked object whose color we will take. If left null we'll take the
+ // first one from the color queue.
+ TrackedRecognition recogToReplace = null;
+
+ // Look for intersections that will be overridden by this object or an intersection that would
+ // prevent this one from being placed.
+ for (final TrackedRecognition trackedRecognition : trackedObjects) {
+ final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame();
+ final RectF b = potentialObject.getTrackedPositionInPreviewFrame();
+ final RectF intersection = new RectF();
+ final boolean intersects = intersection.setIntersect(a, b);
+
+ final float intersectAmount =
+ intersection.width()
+ * intersection.height()
+ / Math.min(a.width() * a.height(), b.width() * b.height());
+
+ // If there is an intersection with this currently tracked box above the maximum overlap
+ // percentage allowed, either the new recognition needs to be dismissed or the old
+ // recognition needs to be removed and possibly replaced with the new one.
+ if (intersects && intersectAmount > MAX_OVERLAP) {
+ if (potential.first < trackedRecognition.detectionConfidence
+ && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) {
+ // If track for the existing object is still going strong and the detection score was
+ // good, reject this new object.
+ potentialObject.stopTracking();
+ return;
+ } else {
+ removeList.add(trackedRecognition);
+
+ // Let the previously tracked object with max intersection amount donate its color to
+ // the new object.
+ if (intersectAmount > maxIntersect) {
+ maxIntersect = intersectAmount;
+ recogToReplace = trackedRecognition;
+ }
+ }
+ }
+ }
+
+    // If we're already tracking the maximum number of objects and no intersections were found to
+    // bump off, pick the worst current tracked object to remove, if it's also worse than this
+    // candidate object.
+ if (availableColors.isEmpty() && removeList.isEmpty()) {
+ for (final TrackedRecognition candidate : trackedObjects) {
+ if (candidate.detectionConfidence < potential.first) {
+ if (recogToReplace == null
+ || candidate.detectionConfidence < recogToReplace.detectionConfidence) {
+ // Save it so that we use this color for the new object.
+ recogToReplace = candidate;
+ }
+ }
+ }
+ if (recogToReplace != null) {
+ logger.v("Found non-intersecting object to remove.");
+ removeList.add(recogToReplace);
+ } else {
+ logger.v("No non-intersecting object found to remove");
+ }
+ }
+
+ // Remove everything that got intersected.
+ for (final TrackedRecognition trackedRecognition : removeList) {
+ logger.v(
+ "Removing tracked object %s with detection confidence %.2f, correlation %.2f",
+ trackedRecognition.trackedObject,
+ trackedRecognition.detectionConfidence,
+ trackedRecognition.trackedObject.getCurrentCorrelation());
+ trackedRecognition.trackedObject.stopTracking();
+ trackedObjects.remove(trackedRecognition);
+ if (trackedRecognition != recogToReplace) {
+ availableColors.add(trackedRecognition.color);
+ }
+ }
+
+ if (recogToReplace == null && availableColors.isEmpty()) {
+ logger.e("No room to track this object, aborting.");
+ potentialObject.stopTracking();
+ return;
+ }
+
+ // Finally safe to say we can track this object.
+ logger.v(
+ "Tracking object %s with detection confidence %.2f at position %s",
+ potentialObject, potential.first, potential.second);
+ final TrackedRecognition trackedRecognition = new TrackedRecognition();
+ trackedRecognition.detectionConfidence = potential.first;
+ trackedRecognition.trackedObject = potentialObject;
+
+ // Use the color from a replaced object before taking one from the color queue.
+ trackedRecognition.color =
+ recogToReplace != null ? recogToReplace.color : availableColors.poll();
+ trackedObjects.add(trackedRecognition);
+ }
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
new file mode 100644
index 0000000000..211d8077a3
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
@@ -0,0 +1,649 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo.tracking;
+
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.PointF;
+import android.graphics.RectF;
+import android.graphics.Typeface;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Vector;
+import javax.microedition.khronos.opengles.GL10;
+import org.tensorflow.demo.env.Logger;
+import org.tensorflow.demo.env.Size;
+
+/**
+ * True object detector/tracker class that tracks objects across consecutive preview frames.
+ * It provides a simplified Java interface to the analogous native object defined by
+ * jni/client_vision/tracking/object_tracker.*.
+ *
+ * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must
+ * be allocated by ObjectTracker.getInstance(). In addition, release() should be called
+ * as soon as the ObjectTracker is no longer needed, and before a new one is created.
+ *
+ * nextFrame() should be called as new frames become available, preferably as often as possible.
+ *
+ * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects
+ * are associated with the ObjectTracker that created them, and are only valid while that
+ * ObjectTracker still exists.
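+ *
+ * A typical lifecycle (illustrative sketch; variable names are placeholders):
+ * <pre>
+ *   ObjectTracker tracker = ObjectTracker.getInstance(width, height, rowStride, true);
+ *   tracker.nextFrame(frameData, null, timestamp, null, false);
+ *   TrackedObject obj = tracker.trackObject(boundingBox, timestamp, frameData);
+ *   ...
+ *   obj.stopTracking();
+ *   tracker.release();
+ * </pre>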
+ */
+public class ObjectTracker {
+ private final Logger logger = new Logger();
+
+ private static final boolean DRAW_TEXT = false;
+
+ /**
+ * How many history points to keep track of and draw in the red history line.
+ */
+ private static final int MAX_DEBUG_HISTORY_SIZE = 30;
+
+ /**
+ * How many frames of optical flow deltas to record.
+ * TODO(andrewharp): Push this down to the native level so it can be polled
+   * efficiently into an array for upload, instead of keeping a duplicate
+ * copy in Java.
+ */
+ private static final int MAX_FRAME_HISTORY_SIZE = 200;
+
+ private static final int DOWNSAMPLE_FACTOR = 2;
+
+ private final byte[] downsampledFrame;
+
+ protected static ObjectTracker instance;
+
+ private final Map<String, TrackedObject> trackedObjects;
+
+ private long lastTimestamp;
+
+ private FrameChange lastKeypoints;
+
+ private final Vector<PointF> debugHistory;
+
+ private final LinkedList<TimestampedDeltas> timestampedDeltas;
+
+ protected final int frameWidth;
+ protected final int frameHeight;
+ private final int rowStride;
+ protected final boolean alwaysTrack;
+
+ private static class TimestampedDeltas {
+ final long timestamp;
+ final byte[] deltas;
+
+ public TimestampedDeltas(final long timestamp, final byte[] deltas) {
+ this.timestamp = timestamp;
+ this.deltas = deltas;
+ }
+ }
+
+ /**
+ * A simple class that records keypoint information, which includes
+ * local location, score and type. This will be used in calculating
+ * FrameChange.
+ */
+ public static class Keypoint {
+ public final float x;
+ public final float y;
+ public final float score;
+ public final int type;
+
+ public Keypoint(final float x, final float y) {
+ this.x = x;
+ this.y = y;
+ this.score = 0;
+ this.type = -1;
+ }
+
+ public Keypoint(final float x, final float y, final float score, final int type) {
+ this.x = x;
+ this.y = y;
+ this.score = score;
+ this.type = type;
+ }
+
+ Keypoint delta(final Keypoint other) {
+ return new Keypoint(this.x - other.x, this.y - other.y);
+ }
+ }
+
+ /**
+   * A simple class that holds a pair of keypoints and computes the delta
+   * between them. This class is used in calculating the frame translation delta
+ * for optical flow.
+ */
+ public static class PointChange {
+ public final Keypoint keypointA;
+ public final Keypoint keypointB;
+ Keypoint pointDelta;
+ private final boolean wasFound;
+
+ public PointChange(final float x1, final float y1,
+ final float x2, final float y2,
+ final float score, final int type,
+ final boolean wasFound) {
+ this.wasFound = wasFound;
+
+ keypointA = new Keypoint(x1, y1, score, type);
+ keypointB = new Keypoint(x2, y2);
+ }
+
+ public Keypoint getDelta() {
+ if (pointDelta == null) {
+ pointDelta = keypointB.delta(keypointA);
+ }
+ return pointDelta;
+ }
+ }
+
+ /** A class that records a timestamped frame translation delta for optical flow. */
+ public static class FrameChange {
+ public static final int KEYPOINT_STEP = 7;
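+
+    // Packed layout per keypoint (see the parsing loop below):
+    //   {x1, y1, found_flag (> 0 if tracked), x2, y2, score, type}.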
+
+ public final Vector<PointChange> pointDeltas;
+
+ private final float minScore;
+ private final float maxScore;
+
+ public FrameChange(final float[] framePoints) {
+ float minScore = 100.0f;
+ float maxScore = -100.0f;
+
+ pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP);
+
+ for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) {
+ final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR;
+ final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR;
+
+ final boolean wasFound = framePoints[i + 2] > 0.0f;
+
+ final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR;
+ final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR;
+ final float score = framePoints[i + 5];
+ final int type = (int) framePoints[i + 6];
+
+ minScore = Math.min(minScore, score);
+ maxScore = Math.max(maxScore, score);
+
+ pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound));
+ }
+
+ this.minScore = minScore;
+ this.maxScore = maxScore;
+ }
+ }
+
+ public static synchronized ObjectTracker getInstance(
+ final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
+ if (instance == null) {
+ instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack);
+ instance.init();
+ } else {
+ throw new RuntimeException(
+          "Tried to create a new ObjectTracker before releasing the old one!");
+ }
+ return instance;
+ }
+
+ public static synchronized void clearInstance() {
+ if (instance != null) {
+ instance.release();
+ }
+ }
+
+ protected ObjectTracker(
+ final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) {
+ this.frameWidth = frameWidth;
+ this.frameHeight = frameHeight;
+ this.rowStride = rowStride;
+ this.alwaysTrack = alwaysTrack;
+ this.timestampedDeltas = new LinkedList<TimestampedDeltas>();
+
+ trackedObjects = new HashMap<String, TrackedObject>();
+
+ debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE);
+
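+    // Size the downsampled buffer with ceiling division so odd frame
+    // dimensions still fit; e.g. a 641x481 frame at factor 2 needs 321x241.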
+    downsampledFrame =
+        new byte
+            [(frameWidth + DOWNSAMPLE_FACTOR - 1)
+                / DOWNSAMPLE_FACTOR
+                * (frameHeight + DOWNSAMPLE_FACTOR - 1)
+                / DOWNSAMPLE_FACTOR];
+ }
+
+ protected void init() {
+ // The native tracker never sees the full frame, so pre-scale dimensions
+ // by the downsample factor.
+ initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack);
+ }
+
+ private final float[] matrixValues = new float[9];
+
+ private long downsampledTimestamp;
+
+ @SuppressWarnings("unused")
+ public synchronized void drawOverlay(final GL10 gl,
+ final Size cameraViewSize, final Matrix matrix) {
+ final Matrix tempMatrix = new Matrix(matrix);
+ tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR);
+ tempMatrix.getValues(matrixValues);
+ drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues);
+ }
+
+ public synchronized void nextFrame(
+ final byte[] frameData, final byte[] uvData,
+ final long timestamp, final float[] transformationMatrix,
+ final boolean updateDebugInfo) {
+ if (downsampledTimestamp != timestamp) {
+ ObjectTracker.downsampleImageNative(
+ frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
+ downsampledTimestamp = timestamp;
+ }
+
+    // Do Lucas-Kanade optical flow using the full-frame initializer.
+ nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix);
+
+ timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR)));
+ while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) {
+ timestampedDeltas.removeFirst();
+ }
+
+ for (final TrackedObject trackedObject : trackedObjects.values()) {
+ trackedObject.updateTrackedPosition();
+ }
+
+ if (updateDebugInfo) {
+ updateDebugHistory();
+ }
+
+ lastTimestamp = timestamp;
+ }
+
+ public synchronized void release() {
+ releaseMemoryNative();
+ synchronized (ObjectTracker.class) {
+ instance = null;
+ }
+ }
+
+ private void drawHistoryDebug(final Canvas canvas) {
+ drawHistoryPoint(
+ canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2);
+ }
+
+ private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) {
+ final Paint p = new Paint();
+ p.setAntiAlias(false);
+ p.setTypeface(Typeface.SERIF);
+
+ p.setColor(Color.RED);
+ p.setStrokeWidth(2.0f);
+
+ // Draw the center circle.
+ p.setColor(Color.GREEN);
+ canvas.drawCircle(startX, startY, 3.0f, p);
+
+ p.setColor(Color.RED);
+
+ // Iterate through in backwards order.
+ synchronized (debugHistory) {
+ final int numPoints = debugHistory.size();
+ float lastX = startX;
+ float lastY = startY;
+ for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) {
+ final PointF delta = debugHistory.get(numPoints - keypointNum - 1);
+ final float newX = lastX + delta.x;
+ final float newY = lastY + delta.y;
+ canvas.drawLine(lastX, lastY, newX, newY, p);
+ lastX = newX;
+ lastY = newY;
+ }
+ }
+ }
+
+ private static int floatToChar(final float value) {
+ return Math.max(0, Math.min((int) (value * 255.999f), 255));
+ }
+
+ private void drawKeypointsDebug(final Canvas canvas) {
+ final Paint p = new Paint();
+ if (lastKeypoints == null) {
+ return;
+ }
+ final int keypointSize = 3;
+
+ final float minScore = lastKeypoints.minScore;
+ final float maxScore = lastKeypoints.maxScore;
+
+ for (final PointChange keypoint : lastKeypoints.pointDeltas) {
+ if (keypoint.wasFound) {
+ final int r =
+ floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore));
+ final int b =
+ floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore));
+
+ final int color = 0xFF000000 | (r << 16) | b;
+ p.setColor(color);
+
+ final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y,
+ keypoint.keypointB.x, keypoint.keypointB.y};
+ canvas.drawRect(screenPoints[2] - keypointSize,
+ screenPoints[3] - keypointSize,
+ screenPoints[2] + keypointSize,
+ screenPoints[3] + keypointSize, p);
+ p.setColor(Color.CYAN);
+ canvas.drawLine(screenPoints[2], screenPoints[3],
+ screenPoints[0], screenPoints[1], p);
+
+ if (DRAW_TEXT) {
+ p.setColor(Color.WHITE);
+ canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score,
+ keypoint.keypointA.x, keypoint.keypointA.y, p);
+ }
+ } else {
+ p.setColor(Color.YELLOW);
+ final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y};
+ canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p);
+ }
+ }
+ }
+
+ private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX,
+ final float positionY, final float radius) {
+ final RectF currPosition = getCurrentPosition(timestamp,
+ new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius));
+ return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY);
+ }
+
+ private synchronized RectF getCurrentPosition(final long timestamp, final RectF
+ oldPosition) {
+ final RectF downscaledFrameRect = downscaleRect(oldPosition);
+
+ final float[] delta = new float[4];
+ getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top,
+ downscaledFrameRect.right, downscaledFrameRect.bottom, delta);
+
+ final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
+
+ return upscaleRect(newPosition);
+ }
+
+ private void updateDebugHistory() {
+ lastKeypoints = new FrameChange(getKeypointsNative(false));
+
+ if (lastTimestamp == 0) {
+ return;
+ }
+
+ final PointF delta =
+ getAccumulatedDelta(
+ lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100);
+
+ synchronized (debugHistory) {
+ debugHistory.add(delta);
+
+ while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) {
+ debugHistory.remove(0);
+ }
+ }
+ }
+
+ public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) {
+ canvas.save();
+ canvas.setMatrix(frameToCanvas);
+
+ drawHistoryDebug(canvas);
+ drawKeypointsDebug(canvas);
+
+ canvas.restore();
+ }
+
+ public Vector<String> getDebugText() {
+ final Vector<String> lines = new Vector<String>();
+
+ if (lastKeypoints != null) {
+ lines.add("Num keypoints " + lastKeypoints.pointDeltas.size());
+ lines.add("Min score: " + lastKeypoints.minScore);
+ lines.add("Max score: " + lastKeypoints.maxScore);
+ }
+
+ return lines;
+ }
+
+ public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) {
+ final List<byte[]> frameDeltas = new ArrayList<byte[]>();
+ while (timestampedDeltas.size() > 0) {
+ final TimestampedDeltas currentDeltas = timestampedDeltas.peek();
+ if (currentDeltas.timestamp <= endFrameTime) {
+ frameDeltas.add(currentDeltas.deltas);
+ timestampedDeltas.removeFirst();
+ } else {
+ break;
+ }
+ }
+
+ return frameDeltas;
+ }
+
+ private RectF downscaleRect(final RectF fullFrameRect) {
+ return new RectF(
+ fullFrameRect.left / DOWNSAMPLE_FACTOR,
+ fullFrameRect.top / DOWNSAMPLE_FACTOR,
+ fullFrameRect.right / DOWNSAMPLE_FACTOR,
+ fullFrameRect.bottom / DOWNSAMPLE_FACTOR);
+ }
+
+ private RectF upscaleRect(final RectF downsampledFrameRect) {
+ return new RectF(
+ downsampledFrameRect.left * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.top * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.right * DOWNSAMPLE_FACTOR,
+ downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR);
+ }
+
+ /**
+   * A TrackedObject represents a native TrackedObject, and provides access to the
+   * relevant native tracking information available after every frame update. They may
+   * be safely passed around and accessed externally, but will become invalid after
+   * stopTracking() is called or the ObjectTracker that created them is deactivated.
+ *
+ * @author andrewharp@google.com (Andrew Harp)
+ */
+ public class TrackedObject {
+ private final String id;
+
+ private long lastExternalPositionTime;
+
+ private RectF lastTrackedPosition;
+ private boolean visibleInLastFrame;
+
+ private boolean isDead;
+
+ TrackedObject(final RectF position, final long timestamp, final byte[] data) {
+ isDead = false;
+
+ id = Integer.toString(this.hashCode());
+
+ lastExternalPositionTime = timestamp;
+
+ synchronized (ObjectTracker.this) {
+ registerInitialAppearance(position, data);
+ setPreviousPosition(position, timestamp);
+ trackedObjects.put(id, this);
+ }
+ }
+
+ public void stopTracking() {
+ checkValidObject();
+
+ synchronized (ObjectTracker.this) {
+ isDead = true;
+ forgetNative(id);
+ trackedObjects.remove(id);
+ }
+ }
+
+ public float getCurrentCorrelation() {
+ checkValidObject();
+ return ObjectTracker.this.getCurrentCorrelation(id);
+ }
+
+ void registerInitialAppearance(final RectF position, final byte[] data) {
+ final RectF externalPosition = downscaleRect(position);
+ registerNewObjectWithAppearanceNative(id,
+ externalPosition.left, externalPosition.top,
+ externalPosition.right, externalPosition.bottom,
+ data);
+ }
+
+ synchronized void setPreviousPosition(final RectF position, final long timestamp) {
+ checkValidObject();
+ synchronized (ObjectTracker.this) {
+ if (lastExternalPositionTime > timestamp) {
+ logger.w("Tried to use older position time!");
+ return;
+ }
+ final RectF externalPosition = downscaleRect(position);
+ lastExternalPositionTime = timestamp;
+
+ setPreviousPositionNative(id,
+ externalPosition.left, externalPosition.top,
+ externalPosition.right, externalPosition.bottom,
+ lastExternalPositionTime);
+
+ updateTrackedPosition();
+ }
+ }
+
+ void setCurrentPosition(final RectF position) {
+ checkValidObject();
+ final RectF downsampledPosition = downscaleRect(position);
+ synchronized (ObjectTracker.this) {
+ setCurrentPositionNative(id,
+ downsampledPosition.left, downsampledPosition.top,
+ downsampledPosition.right, downsampledPosition.bottom);
+ }
+ }
+
+ private synchronized void updateTrackedPosition() {
+ checkValidObject();
+
+ final float[] delta = new float[4];
+ getTrackedPositionNative(id, delta);
+ lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]);
+
+ visibleInLastFrame = isObjectVisible(id);
+ }
+
+ public synchronized RectF getTrackedPositionInPreviewFrame() {
+ checkValidObject();
+
+ if (lastTrackedPosition == null) {
+ return null;
+ }
+ return upscaleRect(lastTrackedPosition);
+ }
+
+ synchronized long getLastExternalPositionTime() {
+ return lastExternalPositionTime;
+ }
+
+ public synchronized boolean visibleInLastPreviewFrame() {
+ return visibleInLastFrame;
+ }
+
+ private void checkValidObject() {
+ if (isDead) {
+ throw new RuntimeException("TrackedObject already removed from tracking!");
+ } else if (ObjectTracker.this != instance) {
+ throw new RuntimeException("TrackedObject created with another ObjectTracker!");
+ }
+ }
+ }
+
+ public synchronized TrackedObject trackObject(
+ final RectF position, final long timestamp, final byte[] frameData) {
+ if (downsampledTimestamp != timestamp) {
+ ObjectTracker.downsampleImageNative(
+ frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame);
+ downsampledTimestamp = timestamp;
+ }
+ return new TrackedObject(position, timestamp, downsampledFrame);
+ }
+
+ public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) {
+ return new TrackedObject(position, lastTimestamp, frameData);
+ }
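+
+ // Typical usage (an illustrative sketch, not part of the original source):
+ //
+ //   TrackedObject obj = tracker.trackObject(box, frameTimestamp, frameBytes);
+ //   ... // after later nextFrame() calls:
+ //   RectF where = obj.getTrackedPositionInPreviewFrame();
+ //   boolean visible = obj.visibleInLastPreviewFrame();
+ //   obj.stopTracking(); // obj is invalid from this point on.
+ //
+ // Note that the two-argument trackObject() overload passes frameData through
+ // without downsampling, so it appears to assume the caller supplies data that
+ // is already in downsampled coordinates.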
+
+ /*********************** NATIVE CODE *************************************/
+
+ /**
+ * An opaque handle (pointer) to the native ObjectTracker instance.
+ */
+ private int nativeObjectTracker;
+
+ private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack);
+
+ protected native void registerNewObjectWithAppearanceNative(
+ String objectId, float x1, float y1, float x2, float y2, byte[] data);
+
+ protected native void setPreviousPositionNative(
+ String objectId, float x1, float y1, float x2, float y2, long timestamp);
+
+ protected native void setCurrentPositionNative(
+ String objectId, float x1, float y1, float x2, float y2);
+
+ protected native void forgetNative(String key);
+
+ protected native String getModelIdNative(String key);
+
+ protected native boolean haveObject(String key);
+ protected native boolean isObjectVisible(String key);
+ protected native float getCurrentCorrelation(String key);
+
+ protected native float getMatchScore(String key);
+
+ protected native void getTrackedPositionNative(String key, float[] points);
+
+ protected native void nextFrameNative(
+ byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix);
+
+ protected native void releaseMemoryNative();
+
+ protected native void getCurrentPositionNative(long timestamp,
+ final float positionX1, final float positionY1,
+ final float positionX2, final float positionY2,
+ final float[] delta);
+
+ protected native byte[] getKeypointsPacked(float scaleFactor);
+
+ protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints);
+
+ protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas);
+
+ protected static native void downsampleImageNative(
+ int width, int height, int rowStride, byte[] input, int factor, byte[] output);
+
+ static {
+ System.loadLibrary("tensorflow_demo");
+ }
+}
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py
index e9ca4e5520..ef3d21767a 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py
@@ -91,7 +91,7 @@ def run_training():
sess.run(init_op)
# Instantiate a SummaryWriter to output summaries and the Graph.
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+ summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
# Start input enqueue threads.
coord = tf.train.Coordinator()
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
index 4e41ab18e3..392309d543 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
@@ -101,7 +101,7 @@ def run_training():
feed_dict={labels_initializer: data_sets.train.labels})
# Instantiate a SummaryWriter to output summaries and the Graph.
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+ summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
# Start input enqueue threads.
coord = tf.train.Coordinator()
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 6b34e57b8f..ca8c9358b3 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -85,14 +85,6 @@ from tensorflow.python.util import compat
FLAGS = None
-# Input and output file flags.
-
-# Details of the training configuration.
-
-# File-system cache locations.
-
-# Controls the distortions used during training.
-
# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and their
# sizes. If you want to adapt this script to work with another model, you will
@@ -455,7 +447,8 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
Args:
sess: Current TensorFlow Session.
image_lists: Dictionary of training images for each label.
- how_many: The number of bottleneck values to return.
+ how_many: If positive, a random sample of this size will be chosen.
+ If negative, all bottlenecks will be retrieved.
category: Name string of which set to pull from - training, testing, or
validation.
bottleneck_dir: Folder string holding cached files of bottleneck values.
@@ -465,24 +458,47 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
bottleneck_tensor: The bottleneck output layer of the CNN graph.
Returns:
- List of bottleneck arrays and their corresponding ground truths.
+ List of bottleneck arrays, their corresponding ground truths, and the
+ relevant filenames.
"""
class_count = len(image_lists.keys())
bottlenecks = []
ground_truths = []
- for unused_i in range(how_many):
- label_index = random.randrange(class_count)
- label_name = list(image_lists.keys())[label_index]
- image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
- bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
- image_index, image_dir, category,
- bottleneck_dir, jpeg_data_tensor,
- bottleneck_tensor)
- ground_truth = np.zeros(class_count, dtype=np.float32)
- ground_truth[label_index] = 1.0
- bottlenecks.append(bottleneck)
- ground_truths.append(ground_truth)
- return bottlenecks, ground_truths
+ filenames = []
+ if how_many >= 0:
+ # Retrieve a random sample of bottlenecks.
+ for unused_i in range(how_many):
+ label_index = random.randrange(class_count)
+ label_name = list(image_lists.keys())[label_index]
+ image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
+ image_name = get_image_path(image_lists, label_name, image_index,
+ image_dir, category)
+ bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
+ image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor,
+ bottleneck_tensor)
+ ground_truth = np.zeros(class_count, dtype=np.float32)
+ ground_truth[label_index] = 1.0
+ bottlenecks.append(bottleneck)
+ ground_truths.append(ground_truth)
+ filenames.append(image_name)
+ else:
+ # Retrieve all bottlenecks.
+ for label_index, label_name in enumerate(image_lists.keys()):
+ for image_index, image_name in enumerate(
+ image_lists[label_name][category]):
+ image_name = get_image_path(image_lists, label_name, image_index,
+ image_dir, category)
+ bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
+ image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor,
+ bottleneck_tensor)
+ ground_truth = np.zeros(class_count, dtype=np.float32)
+ ground_truth[label_index] = 1.0
+ bottlenecks.append(bottleneck)
+ ground_truths.append(ground_truth)
+ filenames.append(image_name)
+ return bottlenecks, ground_truths, filenames
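+
+# Example call (illustrative; mirrors the uses in main() below). With
+# how_many == -1, every cached bottleneck in the category is returned:
+#   bottlenecks, truths, names = get_random_cached_bottlenecks(
+#       sess, image_lists, -1, 'testing', FLAGS.bottleneck_dir,
+#       FLAGS.image_dir, jpeg_data_tensor, bottleneck_tensor)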
def get_random_distorted_bottlenecks(
@@ -729,16 +745,17 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
into.
Returns:
- Nothing.
+ Tuple of (evaluation step, prediction).
"""
with tf.name_scope('accuracy'):
with tf.name_scope('correct_prediction'):
- correct_prediction = tf.equal(tf.argmax(result_tensor, 1), \
- tf.argmax(ground_truth_tensor, 1))
+ prediction = tf.argmax(result_tensor, 1)
+ correct_prediction = tf.equal(
+ prediction, tf.argmax(ground_truth_tensor, 1))
with tf.name_scope('accuracy'):
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', evaluation_step)
- return evaluation_step
+ return evaluation_step, prediction
def main(_):
@@ -788,13 +805,14 @@ def main(_):
bottleneck_tensor)
# Create the operations we need to evaluate the accuracy of our new layer.
- evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)
+ evaluation_step, prediction = add_evaluation_step(
+ final_tensor, ground_truth_input)
# Merge all the summaries and write them out to /tmp/retrain_logs (by default)
merged = tf.summary.merge_all()
- train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train',
- sess.graph)
- validation_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/validation')
+ train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
+ sess.graph)
+ validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
# Set up all our weights to their initial default values.
init = tf.global_variables_initializer()
@@ -810,7 +828,7 @@ def main(_):
FLAGS.image_dir, distorted_jpeg_data_tensor,
distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
else:
- train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks(
+ train_bottlenecks, train_ground_truth, _ = get_random_cached_bottlenecks(
sess, image_lists, FLAGS.train_batch_size, 'training',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
bottleneck_tensor)
@@ -832,7 +850,7 @@ def main(_):
train_accuracy * 100))
print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
cross_entropy_value))
- validation_bottlenecks, validation_ground_truth = (
+ validation_bottlenecks, validation_ground_truth, _ = (
get_random_cached_bottlenecks(
sess, image_lists, FLAGS.validation_batch_size, 'validation',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
@@ -844,20 +862,29 @@ def main(_):
feed_dict={bottleneck_input: validation_bottlenecks,
ground_truth_input: validation_ground_truth})
validation_writer.add_summary(validation_summary, i)
- print('%s: Step %d: Validation accuracy = %.1f%%' %
- (datetime.now(), i, validation_accuracy * 100))
+ print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
+ (datetime.now(), i, validation_accuracy * 100,
+ len(validation_bottlenecks)))
# We've completed all our training, so run a final test evaluation on
# some new images we haven't used before.
- test_bottlenecks, test_ground_truth = get_random_cached_bottlenecks(
- sess, image_lists, FLAGS.test_batch_size, 'testing',
- FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor)
- test_accuracy = sess.run(
- evaluation_step,
+ test_bottlenecks, test_ground_truth, test_filenames = (
+ get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
+ 'testing', FLAGS.bottleneck_dir,
+ FLAGS.image_dir, jpeg_data_tensor,
+ bottleneck_tensor))
+ test_accuracy, predictions = sess.run(
+ [evaluation_step, prediction],
feed_dict={bottleneck_input: test_bottlenecks,
ground_truth_input: test_ground_truth})
- print('Final test accuracy = %.1f%%' % (test_accuracy * 100))
+ print('Final test accuracy = %.1f%% (N=%d)' % (
+ test_accuracy * 100, len(test_bottlenecks)))
+
+ if FLAGS.print_misclassified_test_images:
+ print('=== MISCLASSIFIED TEST IMAGES ===')
+ for i, test_filename in enumerate(test_filenames):
+ if predictions[i] != test_ground_truth[i].argmax():
+ print('%70s %s' % (test_filename, list(image_lists.keys())[predictions[i]]))
# Write out the trained graph and labels with the weights stored as constants.
output_graph_def = graph_util.convert_variables_to_constants(
@@ -933,10 +960,12 @@ if __name__ == '__main__':
parser.add_argument(
'--test_batch_size',
type=int,
- default=500,
+ default=-1,
help="""\
- How many images to test on at a time. This test set is only used
- infrequently to verify the overall accuracy of the model.\
+ How many images to test on. This test set is only used once, to evaluate
+ the final accuracy of the model after training completes.
+ A value of -1 causes the entire test set to be used, which leads to more
+ stable results across runs.\
"""
)
parser.add_argument(
@@ -946,10 +975,21 @@ if __name__ == '__main__':
help="""\
How many images to use in an evaluation batch. This validation set is
used much more often than the test set, and is an early indicator of how
- accurate the model is during training.\
+ accurate the model is during training.
+ A value of -1 causes the entire validation set to be used, which leads to
+ more stable results across training iterations, but may be slower on large
+ training sets.\
"""
)
parser.add_argument(
+ '--print_misclassified_test_images',
+ default=False,
+ help="""\
+ Whether to print out a list of all misclassified test images.\
+ """,
+ action='store_true'
+ )
+ parser.add_argument(
'--model_dir',
type=str,
default='/tmp/imagenet',
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py
index 888bf33b48..1fe12f7b76 100644
--- a/tensorflow/examples/learn/wide_n_deep_tutorial.py
+++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py
@@ -56,7 +56,7 @@ def maybe_download():
train_file_name = FLAGS.train_data
else:
train_file = tempfile.NamedTemporaryFile(delete=False)
- urllib.request.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long
+ urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long
train_file_name = train_file.name
train_file.close()
print("Training data is downloaded to %s" % train_file_name)
@@ -65,7 +65,7 @@ def maybe_download():
test_file_name = FLAGS.test_data
else:
test_file = tempfile.NamedTemporaryFile(delete=False)
- urllib.request.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long
+ urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long
test_file_name = test_file.name
test_file.close()
print("Test data is downloaded to %s" % test_file_name)
diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
index f24a558dc2..be50f4529f 100644
--- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
+++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
@@ -152,7 +152,7 @@ def run_training():
sess = tf.Session()
# Instantiate a SummaryWriter to output summaries and the Graph.
- summary_writer = tf.train.SummaryWriter(FLAGS.log_dir, sess.graph)
+ summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
# And then after everything is built:
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index 0e12a6571b..83879d0807 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -137,9 +137,8 @@ def train():
# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
merged = tf.summary.merge_all()
- train_writer = tf.train.SummaryWriter(FLAGS.log_dir + '/train',
- sess.graph)
- test_writer = tf.train.SummaryWriter(FLAGS.log_dir + '/test')
+ train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
+ test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
tf.global_variables_initializer().run()
# Train the model, and also write summaries.
diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md
index 76571a0aff..fbd1bf5808 100644
--- a/tensorflow/g3doc/api_docs/python/client.md
+++ b/tensorflow/g3doc/api_docs/python/client.md
@@ -103,8 +103,8 @@ and evaluate every `Tensor` in `fetches`, substituting the values in
`feed_dict` for the corresponding input values.
The `fetches` argument may be a single graph element, or an arbitrarily
-nested list, tuple, namedtuple, or dict containing graph elements at its
-leaves. A graph element can be one of the following types:
+nested list, tuple, namedtuple, dict, or OrderedDict containing graph
+elements at its leaves. A graph element can be one of the following types:
* An [`Operation`](../../api_docs/python/framework.md#Operation).
The corresponding fetched value will be `None`.
diff --git a/tensorflow/g3doc/api_docs/python/contrib.distributions.md b/tensorflow/g3doc/api_docs/python/contrib.distributions.md
index bd1a9db7bb..da86d2cad1 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.distributions.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.distributions.md
@@ -20891,7 +20891,7 @@ log_normal = ds.TransformedDistribution(
forward_fn=tf.exp,
inverse_fn=tf.log,
inverse_log_det_jacobian_fn=(
- lambda y: -tf.reduce_sum(tf.log(x), reduction_indices=-1)),
+ lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)),
name="LogNormalTransformedDistribution")
```
@@ -20913,7 +20913,7 @@ Construct a Transformed Distribution.
##### Args:
-* <b>`distribution`</b>: The base distribution class to transform. Typically an
+* <b>`distribution`</b>: The base distribution instance to transform. Typically an
instance of `Distribution`.
* <b>`bijector`</b>: The object responsible for calculating the transformation.
Typically an instance of `Bijector`.
@@ -20987,10 +20987,10 @@ cdf(x) := P[X <= x]
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -21128,10 +21128,10 @@ a more accurate answer than simply taking the logarithm of the `cdf` when
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -21212,10 +21212,10 @@ Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`,
Also raises a `ValueError` if `inverse` was not provided to the
distribution and `y` was not returned from `sample`.
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -21251,10 +21251,10 @@ survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`.
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -21404,10 +21404,10 @@ Implements `p(g^{-1}(y)) det|J(g^{-1}(y))|`, where `g^{-1}` is the
Also raises a `ValueError` if `inverse` was not provided to the
distribution and `y` was not returned from `sample`.
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -21458,10 +21458,10 @@ Additional documentation from `TransformedDistribution`:
Samples from the base distribution and then passes through
the bijector's forward transform.
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -21507,10 +21507,10 @@ survival_function(x) = P[X > x]
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
diff --git a/tensorflow/g3doc/api_docs/python/contrib.framework.md b/tensorflow/g3doc/api_docs/python/contrib.framework.md
index 0b8690cf2d..49892fdcaf 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.framework.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.framework.md
@@ -37,14 +37,15 @@ be `dtypes.float32` or `dtypes.float64`. If neither `tensors` nor
- - -
-### `tf.contrib.framework.assert_scalar_int(tensor)` {#assert_scalar_int}
+### `tf.contrib.framework.assert_scalar_int(tensor, name=None)` {#assert_scalar_int}
Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`.
##### Args:
-* <b>`tensor`</b>: Tensor to test.
+* <b>`tensor`</b>: `Tensor` to test.
+* <b>`name`</b>: Name of the op and of the new `Tensor` if one is created.
##### Returns:
@@ -309,7 +310,7 @@ to the rest of the docstring.
- - -
-### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names)` {#deprecated_args}
+### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names_or_tuples)` {#deprecated_args}
Decorator for marking specific function arguments as deprecated.
@@ -333,7 +334,10 @@ prepended to the rest of the docstring.
ISO 8601 (YYYY-MM-DD).
* <b>`instructions`</b>: String. Instructions on how to update code using the
deprecated function.
-* <b>`*deprecated_arg_names`</b>: String. The deprecated arguments.
+* <b>`*deprecated_arg_names_or_tuples`</b>: String, or 2-tuple (String,
+ [ok_vals]). The string is the deprecated argument name.
+ Optionally, a list of ok-values may be provided. If the user-provided
+ argument equals one of these values, the warning is suppressed.
##### Returns:
@@ -342,8 +346,10 @@ prepended to the rest of the docstring.
##### Raises:
-* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are empty, or
- the deprecated arguments are not present in the function signature.
+* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are
+ empty, the deprecated arguments are not present in the function
+ signature, or the second element of a deprecated_tuple is not a
+ list.
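+
+For example, a sketch (the function and argument names here are illustrative):
+
+```python
+@deprecated_args("2017-01-01", "Use `new_arg` instead.",
+                 "old_arg", ("other_arg", [None]))
+def my_fn(new_arg, old_arg=None, other_arg=None):
+  ...
+```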
- - -
@@ -865,6 +871,11 @@ Gets an existing model variable with these parameters or creates a new one.
device.
* <b>`device`</b>: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
+* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+* <b>`custom_getter`</b>: Callable that allows overwriting the internal
+ get_variable method and has to have the same signature.
##### Returns:
@@ -896,6 +907,11 @@ Gets an existing variable with these parameters or creates a new one.
device.
* <b>`device`</b>: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
+* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+* <b>`custom_getter`</b>: Callable that allows overwriting the internal
+ get_variable method and has to have the same signature.
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/contrib.layers.md b/tensorflow/g3doc/api_docs/python/contrib.layers.md
index a6fcf6f270..3babfa0cab 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.layers.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.layers.md
@@ -635,53 +635,6 @@ to produce the end result.
- - -
-### `tf.stack(values, axis=0, name='stack')` {#stack}
-
-Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
-
-Packs the list of tensors in `values` into a tensor with rank one higher than
-each tensor in `values`, by packing them along the `axis` dimension.
-Given a list of length `N` of tensors of shape `(A, B, C)`;
-
-if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
-if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
-Etc.
-
-For example:
-
-```prettyprint
-# 'x' is [1, 4]
-# 'y' is [2, 5]
-# 'z' is [3, 6]
-stack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
-stack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
-```
-
-This is the opposite of unstack. The numpy equivalent is
-
- tf.stack([x, y, z]) = np.asarray([x, y, z])
-
-##### Args:
-
-
-* <b>`values`</b>: A list of `Tensor` objects with the same shape and type.
-* <b>`axis`</b>: An `int`. The axis to stack along. Defaults to the first dimension.
- Supports negative indexes.
-* <b>`name`</b>: A name for this operation (optional).
-
-##### Returns:
-
-
-* <b>`output`</b>: A stacked `Tensor` with the same type as `values`.
-
-##### Raises:
-
-
-* <b>`ValueError`</b>: If `axis` is out of the range [-(R+1), R+1).
-
-
-- - -
-
### `tf.contrib.layers.unit_norm(*args, **kwargs)` {#unit_norm}
Normalizes the given input across the specified dimension to unit length.
@@ -710,6 +663,9 @@ Note that the rank of `input` must be known.
Aliases for fully_connected which set a default activation function are
available: `relu`, `relu6` and `linear`.
+The `stack` operation is also available. It builds a stack of layers by
+applying a layer repeatedly.
+
## Regularizers
Regularization can help prevent overfitting. These have the signature
@@ -1230,7 +1186,7 @@ Creates a _CrossedColumn for performing feature crosses.
- - -
-### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#embedding_column}
+### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#embedding_column}
Creates an `_EmbeddingColumn` for feeding sparse data into a DNN.
@@ -1258,6 +1214,8 @@ Creates an `_EmbeddingColumn` for feeding sparse data into a DNN.
* <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
+* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to
+ the value of max_norm.
##### Returns:
@@ -1582,7 +1540,7 @@ Creates a `_RealValuedColumn` for dense numeric data.
- - -
-### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#shared_embedding_columns}
+### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#shared_embedding_columns}
Creates a list of `_EmbeddingColumn` sharing the same embedding.
@@ -1613,6 +1571,8 @@ Creates a list of `_EmbeddingColumn` sharing the same embedding.
* <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
+* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to
+ the value of max_norm.
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.md b/tensorflow/g3doc/api_docs/python/contrib.learn.md
index 682d6ff930..1808bb94e2 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.learn.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.learn.md
@@ -459,6 +459,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi
- - -
+#### `tf.contrib.learn.Estimator.export_savedmodel(*args, **kwargs)` {#Estimator.export_savedmodel}
+
+Exports inference graph as a SavedModel into given dir. (experimental)
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ input_fn: A function that takes no argument and
+ returns an `InputFnOps`.
+ default_output_alternative_key: the name of the head to serve when none is
+ specified.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel. Each key should give the destination
+ path (including the filename) relative to the assets.extra directory.
+ The corresponding value gives the full path of the source file to be
+ copied. For example, the simple case of copying a single file without
+ renaming it is specified as
+ `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+ as_text: whether to write the SavedModel proto in text format.
+ exports_to_keep: Number of exports to keep.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if an unrecognized export_type is requested.
+
+
+- - -
+
#### `tf.contrib.learn.Estimator.fit(*args, **kwargs)` {#Estimator.fit}
See `Trainable`. (deprecated arguments)
@@ -842,7 +875,7 @@ Input of `fit` and `evaluate` should have following features,
whose `value` is a `Tensor`.
- - -
-#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None)` {#DNNClassifier.__init__}
+#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None, embedding_lr_multipliers=None)` {#DNNClassifier.__init__}
Initializes a DNNClassifier instance.
@@ -882,6 +915,9 @@ Initializes a DNNClassifier instance.
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
+* <b>`embedding_lr_multipliers`</b>: Optional. A dictionary from `EmbeddingColumn` to
+ a `float` multiplier. The multiplier is applied to the learning rate used
+ for the embedding variables.
##### Returns:
@@ -927,6 +963,15 @@ See BaseEstimator.export.
- - -
+#### `tf.contrib.learn.DNNClassifier.export_savedmodel(*args, **kwargs)` {#DNNClassifier.export_savedmodel}
+
+EXPERIMENTAL FUNCTION
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+- - -
+
#### `tf.contrib.learn.DNNClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#DNNClassifier.fit}
See trainable.Trainable. Note: Labels must be integer class indices.
@@ -1556,6 +1601,15 @@ See BaseEstimator.export.
- - -
+#### `tf.contrib.learn.LinearClassifier.export_savedmodel(*args, **kwargs)` {#LinearClassifier.export_savedmodel}
+
+EXPERIMENTAL FUNCTION
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+- - -
+
#### `tf.contrib.learn.LinearClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearClassifier.fit}
See trainable.Trainable. Note: Labels must be integer class indices.
@@ -1746,6 +1800,15 @@ See BaseEstimator.export.
- - -
+#### `tf.contrib.learn.LinearRegressor.export_savedmodel(*args, **kwargs)` {#LinearRegressor.export_savedmodel}
+
+EXPERIMENTAL FUNCTION
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+- - -
+
#### `tf.contrib.learn.LinearRegressor.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearRegressor.fit}
See trainable.Trainable.
@@ -1928,6 +1991,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi
- - -
+#### `tf.contrib.learn.LogisticRegressor.export_savedmodel(*args, **kwargs)` {#LogisticRegressor.export_savedmodel}
+
+Exports inference graph as a SavedModel into given dir. (experimental)
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ input_fn: A function that takes no argument and
+ returns an `InputFnOps`.
+ default_output_alternative_key: the name of the head to serve when none is
+ specified.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel. Each key should give the destination
+ path (including the filename) relative to the assets.extra directory.
+ The corresponding value gives the full path of the source file to be
+ copied. For example, the simple case of copying a single file without
+ renaming it is specified as
+ `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+ as_text: whether to write the SavedModel proto in text format.
+ exports_to_keep: Number of exports to keep.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if an unrecognized export_type is requested.
+
+
+- - -
+
#### `tf.contrib.learn.LogisticRegressor.fit(*args, **kwargs)` {#LogisticRegressor.fit}
See `Trainable`. (deprecated arguments)
diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md b/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md
index 122c4e3551..d4d200399e 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md
@@ -887,9 +887,9 @@ The signature of the input_fn accepted by export is changing to be consistent wi
`None`).
input_feature_key: String key into the features dict returned by
`input_fn` that corresponds to the raw `Example` strings `Tensor` that
- the exported model will take as input. Can only be `None` if you're
- using a custom `signature_fn` that does not use the first arg
- (examples).
+ the exported model will take as input. Should be `None` if and only if
+ you're passing in a `signature_fn` that does not use the first arg
+ (`Tensor` of `Example` strings).
exports_to_keep: int, number of exports to keep.
signature_fn: Function that returns a default signature and a named
signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s
diff --git a/tensorflow/g3doc/api_docs/python/contrib.linalg.md b/tensorflow/g3doc/api_docs/python/contrib.linalg.md
index e678edfe72..d5edaa3e82 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.linalg.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.linalg.md
@@ -106,6 +106,19 @@ FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE
### Performance
FILL THIS IN
+
+### Matrix property hints
+
+This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+for `X = non_singular, self_adjoint`, etc.
+These have the following meaning:
+* If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+* If `is_X == False`, callers should expect the operator to not have `X`.
+* If `is_X == None` (the default), callers should have no expectation either
+ way.
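+
+For example, a minimal sketch (using the `LinearOperatorDiag` subclass
+documented below):
+
+```python
+# Promise non-singularity; make no promise either way about
+# positive-definiteness.
+operator = LinearOperatorDiag([1., -1.], is_non_singular=True)
+```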
- - -
#### `tf.contrib.linalg.LinearOperator.__init__(dtype, graph_parents=None, is_non_singular=None, is_self_adjoint=None, is_positive_definite=None, name=None)` {#LinearOperator.__init__}
@@ -115,16 +128,6 @@ Initialize the `LinearOperator`.
**This is a private method for subclass use.**
**Subclasses should copy-paste this `__init__` documentation.**
-For `X = non_singular, self_adjoint` etc...
-`is_X` is a Python `bool` initialization argument with the following meaning
-* If `is_X == True`, callers should expect the operator to have the
- attribute `X`. This is a promise that should be fulfilled, but is *not* a
- runtime assert. Issues, such as floating point error, could mean the
- operator violates this promise.
-* If `is_X == False`, callers should expect the operator to not have `X`.
-* If `is_X == None` (the default), callers should have no expectation either
- way.
-
##### Args:
@@ -135,8 +138,12 @@ For `X = non_singular, self_adjoint` etc...
* <b>`is_non_singular`</b>: Expect that this operator is non-singular.
* <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian
transpose. If `dtype` is real, this is equivalent to being symmetric.
-* <b>`is_positive_definite`</b>: Expect that this operator is positive definite.
-* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name.
+* <b>`is_positive_definite`</b>: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non_symmetric_matrices
+* <b>`name`</b>: A name for this `LinearOperator`.
##### Raises:
@@ -146,6 +153,23 @@ For `X = non_singular, self_adjoint` etc...
- - -
+#### `tf.contrib.linalg.LinearOperator.add_to_tensor(x, name='add_to_tensor')` {#LinearOperator.add_to_tensor}
+
+Add matrix represented by this operator to `x`. Equivalent to `A + x`.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ A `Tensor` with broadcast shape and same `dtype` as `self`.
+
+
+- - -
+
#### `tf.contrib.linalg.LinearOperator.apply(x, adjoint=False, name='apply')` {#LinearOperator.apply}
Transform `x` with left multiplication: `x --> Ax`.
@@ -176,6 +200,25 @@ Returns an `Op` that asserts this operator is non singular.
Returns an `Op` that asserts this operator is positive definite.
+Here, positive definite means the real part of all eigenvalues is positive.
+We do not require the operator to be self-adjoint.
+
+##### Args:
+
+
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ An `Op` that asserts this operator is positive definite.
+
+
+- - -
+
+#### `tf.contrib.linalg.LinearOperator.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperator.assert_self_adjoint}
+
+Returns an `Op` that asserts this operator is self-adjoint.
+
- - -
@@ -493,7 +536,7 @@ Return a dense (batch) matrix representing this operator.
This operator acts like a [batch] matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
-an `m x n` matrix. Again, this matrix `A` may not be materialized, but for
+an `N x N` matrix. This matrix `A` is not materialized, but for
purposes of broadcasting this shape will be relevant.
`LinearOperatorDiag` is initialized with a (batch) vector.
@@ -507,7 +550,7 @@ operator.to_dense()
==> [[1., 0.]
[0., -1.]]
-operator.shape()
+operator.shape
==> [2, 2]
operator.log_determinant()
@@ -542,7 +585,7 @@ and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
### Performance
-Suppose `operator` is a `LinearOperatorDiag` is of shape `[N, N]`,
+Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
and `x.shape = [N, R]`. Then
* `operator.apply(x)` involves `N*R` multiplications.
@@ -551,43 +594,68 @@ and `x.shape = [N, R]`. Then
If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
`[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
-- - -
-#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__}
+### Matrix property hints
-Initialize a `LinearOperatorDiag`.
-
-For `X = non_singular, self_adjoint` etc...
-`is_X` is a Python `bool` initialization argument with the following meaning
+This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+for `X = non_singular, self_adjoint`, etc.
+These have the following meaning:
* If `is_X == True`, callers should expect the operator to have the
- attribute `X`. This is a promise that should be fulfilled, but is *not* a
- runtime assert. Issues, such as floating point error, could mean the
- operator violates this promise.
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
+- - -
+
+#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__}
+
+Initialize a `LinearOperatorDiag`.
##### Args:
-* <b>`diag`</b>: Shape `[B1,...,Bb, N]` real float type `Tensor` with `b >= 0`,
- `N >= 0`. The diagonal of the operator.
+* <b>`diag`</b>: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0`, `N >= 0`.
+ The diagonal of the operator. Allowed dtypes: `float32`, `float64`,
+ `complex64`, `complex128`.
* <b>`is_non_singular`</b>: Expect that this operator is non-singular.
* <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian
transpose. Since this is a real (not complex) diagonal operator, it is
always self adjoint.
-* <b>`is_positive_definite`</b>: Expect that this operator is positive definite.
-* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name.
+* <b>`is_positive_definite`</b>: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non_symmetric_matrices
+* <b>`name`</b>: A name for this `LinearOperator`.
##### Raises:
-* <b>`ValueError`</b>: If `diag.dtype` is not floating point.
+* <b>`TypeError`</b>: If `diag.dtype` is not an allowed type.
* <b>`ValueError`</b>: If `is_self_adjoint` is not `True`.
- - -
+#### `tf.contrib.linalg.LinearOperatorDiag.add_to_tensor(x, name='add_to_tensor')` {#LinearOperatorDiag.add_to_tensor}
+
+Add matrix represented by this operator to `x`. Equivalent to `A + x`.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ A `Tensor` with broadcast shape and same `dtype` as `self`.
+
+
+- - -
+
#### `tf.contrib.linalg.LinearOperatorDiag.apply(x, adjoint=False, name='apply')` {#LinearOperatorDiag.apply}
Transform `x` with left multiplication: `x --> Ax`.
@@ -618,6 +686,25 @@ Returns an `Op` that asserts this operator is non singular.
Returns an `Op` that asserts this operator is positive definite.
+Here, positive definite means the real part of all eigenvalues is positive.
+We do not require the operator to be self-adjoint.
+
+##### Args:
+
+
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ An `Op` that asserts this operator is positive definite.
+
+
+- - -
+
+#### `tf.contrib.linalg.LinearOperatorDiag.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperatorDiag.assert_self_adjoint}
+
+Returns an `Op` that asserts this operator is self-adjoint.
+
- - -
diff --git a/tensorflow/g3doc/api_docs/python/contrib.losses.md b/tensorflow/g3doc/api_docs/python/contrib.losses.md
index cc6a14f891..e6b0e136a9 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.losses.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.losses.md
@@ -67,6 +67,7 @@ Instructions for updating:
Args:
losses: A tensor of size [batch_size, d1, ... dN].
weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
+ scope: the scope for the operations performed in computing the loss.
weight: Deprecated alias for `weights`.
Returns:
diff --git a/tensorflow/g3doc/api_docs/python/contrib.training.md b/tensorflow/g3doc/api_docs/python/contrib.training.md
index 935c163e06..67ce73a347 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.training.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.training.md
@@ -900,7 +900,7 @@ batch.
- - -
-### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, warmup=10, seed=None)` {#weighted_resample}
+### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, seed=None)` {#weighted_resample}
Performs an approximate weighted resampling of `inputs`.
@@ -917,9 +917,6 @@ rate of selection across all inputs (and many invocations!) is
* <b>`overall_rate`</b>: Desired overall rate of resampling.
* <b>`scope`</b>: Scope to use for the op.
* <b>`mean_decay`</b>: How quickly to decay the running estimate of the mean weight.
-* <b>`warmup`</b>: Until the resulting tensor has been evaluated `warmup`
- times, the resampling menthod uses the true mean over all calls
- as its weight estimate, rather than a decayed mean.
* <b>`seed`</b>: Random seed.
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
index 2c9b4bed34..31435d5ec3 100644
--- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md
+++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
@@ -596,7 +596,7 @@ Returns the truth value of (x >= y) element-wise.
- - -
-### `tf.select(condition, t, e, name=None)` {#select}
+### `tf.select(*args, **kwargs)` {#select}
Selects elements from `t` or `e`, depending on `condition`.
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 2601bd99ff..7c799283fd 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1438,8 +1438,19 @@ dynamic condition of the `Tensor`.
#### `tf.Tensor.__div__(x, y)` {#Tensor.__div__}
+Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+##### Args:
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
+
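+
+For example, a sketch of the semantics described above:
+
+```python
+tf.constant(7) / tf.constant(2)      # int32 inputs stay integral: ==> 3
+tf.constant(7.0) / tf.constant(2.0)  # float inputs: ==> 3.5
+```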
- - -
@@ -1847,7 +1858,18 @@ Returns the truth value of x AND y element-wise.
#### `tf.Tensor.__rdiv__(y, x)` {#Tensor.__rdiv__}
+Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
- - -
@@ -1998,34 +2020,7 @@ Returns x - y element-wise.
#### `tf.Tensor.__rtruediv__(y, x)` {#Tensor.__rtruediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
@@ -2067,34 +2062,7 @@ Returns x - y element-wise.
#### `tf.Tensor.__truediv__(x, y)` {#Tensor.__truediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
-
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
@@ -2827,7 +2795,7 @@ The following standard keys are defined:
for more details.
* `SUMMARIES`: the summary `Tensor` objects that have been created in the
graph. See
- [`tf.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries)
+ [`tf.contrib.deprecated.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries)
for more details.
* `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
produce input for a computation. See
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md
index 4dd9bbc0f8..3f81ac9fc1 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md
@@ -1,4 +1,4 @@
-### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names)` {#deprecated_args}
+### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names_or_tuples)` {#deprecated_args}
Decorator for marking specific function arguments as deprecated.
@@ -22,7 +22,10 @@ prepended to the rest of the docstring.
ISO 8601 (YYYY-MM-DD).
* <b>`instructions`</b>: String. Instructions on how to update code using the
deprecated function.
-* <b>`*deprecated_arg_names`</b>: String. The deprecated arguments.
+* <b>`*deprecated_arg_names_or_tuples`</b>: String, or 2-tuple (String,
+ [ok_vals]). The string is the deprecated argument name.
+ Optionally, a list of ok-values may be provided. If the user-provided
+ argument equals one of these values, the warning is suppressed.
##### Returns:
@@ -31,6 +34,8 @@ prepended to the rest of the docstring.
##### Raises:
-* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are empty, or
- the deprecated arguments are not present in the function signature.
+* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are
+ empty, the deprecated arguments are not present in the function
+ signature, or the second element of a deprecated_tuple is not a
+ list.
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md
index 02cf9a8674..81405d1ab5 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md
@@ -113,6 +113,15 @@ See BaseEstimator.export.
- - -
+#### `tf.contrib.learn.LinearRegressor.export_savedmodel(*args, **kwargs)` {#LinearRegressor.export_savedmodel}
+
+EXPERIMENTAL FUNCTION
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+- - -
+
#### `tf.contrib.learn.LinearRegressor.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearRegressor.fit}
See trainable.Trainable.
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md
index ebe0d9a3c2..f4360dc46c 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md
@@ -3,7 +3,7 @@
This operator acts like a [batch] matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
-an `m x n` matrix. Again, this matrix `A` may not be materialized, but for
+an `N x N` matrix. This matrix `A` is not materialized, but for
purposes of broadcasting this shape will be relevant.
`LinearOperatorDiag` is initialized with a (batch) vector.
@@ -17,7 +17,7 @@ operator.to_dense()
==> [[1., 0.]
[0., -1.]]
-operator.shape()
+operator.shape
==> [2, 2]
operator.log_determinant()
@@ -52,7 +52,7 @@ and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
### Performance
-Suppose `operator` is a `LinearOperatorDiag` is of shape `[N, N]`,
+Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`,
and `x.shape = [N, R]`. Then
* `operator.apply(x)` involves `N*R` multiplications.
@@ -61,43 +61,68 @@ and `x.shape = [N, R]`. Then
If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
`[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
-- - -
-#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__}
+### Matrix property hints
-Initialize a `LinearOperatorDiag`.
-
-For `X = non_singular, self_adjoint` etc...
-`is_X` is a Python `bool` initialization argument with the following meaning
+This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+for `X = non_singular, self_adjoint`, etc.
+These have the following meaning:
* If `is_X == True`, callers should expect the operator to have the
- attribute `X`. This is a promise that should be fulfilled, but is *not* a
- runtime assert. Issues, such as floating point error, could mean the
- operator violates this promise.
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
+- - -
+
+#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__}
+
+Initialize a `LinearOperatorDiag`.
##### Args:
-* <b>`diag`</b>: Shape `[B1,...,Bb, N]` real float type `Tensor` with `b >= 0`,
- `N >= 0`. The diagonal of the operator.
+* <b>`diag`</b>:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0`, `N >= 0`.
+ The diagonal of the operator. Allowed dtypes: `float32`, `float64`,
+ `complex64`, `complex128`.
* <b>`is_non_singular`</b>: Expect that this operator is non-singular.
* <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian
transpose. Since this is a real (not complex) diagonal operator, it is
always self adjoint.
-* <b>`is_positive_definite`</b>: Expect that this operator is positive definite.
-* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name.
+* <b>`is_positive_definite`</b>: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See:
+    https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non_symmetric_matrices
+* <b>`name`</b>: A name for this `LinearOperator`.
##### Raises:
-* <b>`ValueError`</b>: If `diag.dtype` is not floating point.
+* <b>`TypeError`</b>: If `diag.dtype` is not an allowed type.
* <b>`ValueError`</b>: If `is_self_adjoint` is not `True`.
- - -
+#### `tf.contrib.linalg.LinearOperatorDiag.add_to_tensor(x, name='add_to_tensor')` {#LinearOperatorDiag.add_to_tensor}
+
+Add the matrix represented by this operator to `x`. Equivalent to `A + x`.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ A `Tensor` with broadcast shape and same `dtype` as `self`.
+
+
+- - -
+
#### `tf.contrib.linalg.LinearOperatorDiag.apply(x, adjoint=False, name='apply')` {#LinearOperatorDiag.apply}
Transform `x` with left multiplication: `x --> Ax`.
@@ -128,6 +153,25 @@ Returns an `Op` that asserts this operator is non singular.
Returns an `Op` that asserts this operator is positive definite.
+Here, positive definite means the real part of all eigenvalues is positive.
+We do not require the operator to be self-adjoint.
+
+##### Args:
+
+
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ An `Op` that asserts this operator is positive definite.
+
+
+- - -
+
+#### `tf.contrib.linalg.LinearOperatorDiag.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperatorDiag.assert_self_adjoint}
+
+Returns an `Op` that asserts this operator is self-adjoint.
+
- - -
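Pulling the revised `LinearOperatorDiag` docs together, here is a minimal usage sketch of the constructor, `apply`, and the newly documented `add_to_tensor`; it assumes the `tf.contrib.linalg` module path shown in this diff:

```python
import tensorflow as tf

# The 2 x 2 diagonal operator from the example in the docs above.
operator = tf.contrib.linalg.LinearOperatorDiag(
    diag=[1., -1.], is_non_singular=True)

x = tf.ones([2, 2])            # x.shape = [N, R] with N = R = 2
dense = operator.to_dense()    # [[1., 0.], [0., -1.]]
y = operator.apply(x)          # N*R multiplications, per the Performance note
z = operator.add_to_tensor(x)  # equivalent to A + x

with tf.Session() as sess:
    print(sess.run([dense, y, z]))
```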
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md
index 4977793e37..903cad838b 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md
@@ -1,4 +1,4 @@
-### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, warmup=10, seed=None)` {#weighted_resample}
+### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, seed=None)` {#weighted_resample}
Performs an approximate weighted resampling of `inputs`.
@@ -15,9 +15,6 @@ rate of selection across all inputs (and many invocations!) is
* <b>`overall_rate`</b>: Desired overall rate of resampling.
* <b>`scope`</b>: Scope to use for the op.
* <b>`mean_decay`</b>: How quickly to decay the running estimate of the mean weight.
-* <b>`warmup`</b>: Until the resulting tensor has been evaluated `warmup`
- times, the resampling menthod uses the true mean over all calls
- as its weight estimate, rather than a decayed mean.
* <b>`seed`</b>: Random seed.
##### Returns:
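A hedged usage sketch of the updated signature (the `warmup` argument is gone); the tensors are hypothetical, and the structure of the return value follows the Returns section, which is elided in this hunk:

```python
import tensorflow as tf

inputs = [tf.constant([[1.0], [2.0], [3.0], [4.0]])]
weights = tf.constant([0.1, 0.1, 0.4, 0.4])  # one weight per row

# Rows with larger weights are proportionally more likely to be kept;
# the long-run selection rate across all rows is overall_rate.
resampled = tf.contrib.training.weighted_resample(
    inputs, weights, overall_rate=0.5, mean_decay=0.999, seed=42)
```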
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.FileWriterCache.clear.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.FileWriterCache.clear.md
new file mode 100644
index 0000000000..e3c7027813
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.FileWriterCache.clear.md
@@ -0,0 +1,4 @@
+#### `tf.summary.FileWriterCache.clear()` {#FileWriterCache.clear}
+
+Clear cached summary writers. Currently only used for unit tests.
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md
index 3c8fd4c447..f1f89fff93 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md
@@ -80,7 +80,7 @@ saver.restore(...checkpoint filename...)
- - -
-#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__}
+#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, zero_debias=False, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__}
Creates a new ExponentialMovingAverage object.
@@ -100,6 +100,8 @@ move faster. If passed, the actual decay rate used is:
* <b>`decay`</b>: Float. The decay to use.
* <b>`num_updates`</b>: Optional count of number of updates applied to variables.
+* <b>`zero_debias`</b>: If `True`, zero debias moving-averages that are initialized
+ with tensors.
* <b>`name`</b>: String. Optional prefix name to use for the name of ops added in
`apply()`.
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md
index 2ce825eb7b..aac10e4396 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md
@@ -299,8 +299,19 @@ dynamic condition of the `Tensor`.
#### `tf.Tensor.__div__(x, y)` {#Tensor.__div__}
+Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
+
- - -
@@ -708,7 +719,18 @@ Returns the truth value of x AND y element-wise.
#### `tf.Tensor.__rdiv__(y, x)` {#Tensor.__rdiv__}
+Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
- - -
@@ -859,34 +881,7 @@ Returns x - y element-wise.
#### `tf.Tensor.__rtruediv__(y, x)` {#Tensor.__rtruediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
@@ -928,34 +923,7 @@ Returns x - y element-wise.
#### `tf.Tensor.__truediv__(x, y)` {#Tensor.__truediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
-
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md
index 4b4f4413b5..746339a662 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md
@@ -87,7 +87,7 @@ log_normal = ds.TransformedDistribution(
forward_fn=tf.exp,
inverse_fn=tf.log,
inverse_log_det_jacobian_fn=(
- lambda y: -tf.reduce_sum(tf.log(x), reduction_indices=-1)),
+ lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)),
name="LogNormalTransformedDistribution")
```
@@ -109,7 +109,7 @@ Construct a Transformed Distribution.
##### Args:
-* <b>`distribution`</b>: The base distribution class to transform. Typically an
+* <b>`distribution`</b>: The base distribution instance to transform. Typically an
instance of `Distribution`.
* <b>`bijector`</b>: The object responsible for calculating the transformation.
Typically an instance of `Bijector`.
@@ -183,10 +183,10 @@ cdf(x) := P[X <= x]
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -324,10 +324,10 @@ a more accurate answer than simply taking the logarithm of the `cdf` when
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -408,10 +408,10 @@ Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`,
Also raises a `ValueError` if `inverse` was not provided to the
distribution and `y` was not returned from `sample`.
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -447,10 +447,10 @@ survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`.
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -600,10 +600,10 @@ Implements `p(g^{-1}(y)) det|J(g^{-1}(y))|`, where `g^{-1}` is the
Also raises a `ValueError` if `inverse` was not provided to the
distribution and `y` was not returned from `sample`.
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -654,10 +654,10 @@ Additional documentation from `TransformedDistribution`:
Samples from the base distribution and then passes through
the bijector's forward transform.
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
@@ -703,10 +703,10 @@ survival_function(x) = P[X > x]
Additional documentation from `TransformedDistribution`:
-##### <b>`condition_kwargs`</b>:
+##### `condition_kwargs`:
-* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector.
-* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution.
+* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector.
+* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution.
##### Args:
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md
index e4457e4be4..09a78073d6 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md
@@ -1,4 +1,4 @@
-### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#embedding_column}
+### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#embedding_column}
Creates an `_EmbeddingColumn` for feeding sparse data into a DNN.
@@ -26,6 +26,8 @@ Creates an `_EmbeddingColumn` for feeding sparse data into a DNN.
* <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
+* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to
+ the value of max_norm.
##### Returns:
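A hedged sketch of the new `max_norm` argument; the sparse column definition is a hypothetical example:

```python
from tensorflow.contrib import layers

country = layers.sparse_column_with_hash_bucket(
    "country", hash_bucket_size=100)

# Embedding values are l2-normalized to max_norm before use.
country_emb = layers.embedding_column(country, dimension=8, max_norm=1.0)
```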
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md
index e5c7d7edf3..fbbab399f3 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md
@@ -141,6 +141,15 @@ See BaseEstimator.export.
- - -
+#### `tf.contrib.learn.LinearClassifier.export_savedmodel(*args, **kwargs)` {#LinearClassifier.export_savedmodel}
+
+EXPERIMENTAL FUNCTION
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+- - -
+
#### `tf.contrib.learn.LinearClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearClassifier.fit}
See trainable.Trainable. Note: Labels must be integer class indices.
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md
index 9cbade4389..8cd3b0d69f 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md
@@ -9,6 +9,7 @@ Instructions for updating:
Args:
losses: A tensor of size [batch_size, d1, ... dN].
weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
+ scope: the scope for the operations performed in computing the loss.
weight: Deprecated alias for `weights`.
Returns:
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md
index 94d5355e10..469566f7b8 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md
@@ -1,11 +1,12 @@
-### `tf.contrib.framework.assert_scalar_int(tensor)` {#assert_scalar_int}
+### `tf.contrib.framework.assert_scalar_int(tensor, name=None)` {#assert_scalar_int}
Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`.
##### Args:
-* <b>`tensor`</b>: Tensor to test.
+* <b>`tensor`</b>: `Tensor` to test.
+* <b>`name`</b>: Name of the op and of the new `Tensor` if one is created.
##### Returns:
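A minimal sketch of the updated signature with the new `name` argument:

```python
import tensorflow as tf

step = tf.constant(7, dtype=tf.int64)
# Passes: `step` is 0-D and of an allowed integer dtype.
checked_step = tf.contrib.framework.assert_scalar_int(step, name="check_step")
```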
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md
index 830f1bd352..0550580a0e 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md
@@ -1,4 +1,4 @@
-### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#shared_embedding_columns}
+### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#shared_embedding_columns}
Creates a list of `_EmbeddingColumn` sharing the same embedding.
@@ -29,6 +29,8 @@ Creates a list of `_EmbeddingColumn` sharing the same embedding.
* <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided
checkpoint from which to restore the column weights. Required if
`ckpt_to_load_from` is not None.
+* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to
+ the value of max_norm.
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md
index 647ef7e955..53992bdf4f 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md
@@ -18,9 +18,9 @@ The signature of the input_fn accepted by export is changing to be consistent wi
`None`).
input_feature_key: String key into the features dict returned by
`input_fn` that corresponds to the raw `Example` strings `Tensor` that
- the exported model will take as input. Can only be `None` if you're
- using a custom `signature_fn` that does not use the first arg
- (examples).
+ the exported model will take as input. Should be `None` if and only if
+ you're passing in a `signature_fn` that does not use the first arg
+ (`Tensor` of `Example` strings).
exports_to_keep: int, number of exports to keep.
signature_fn: Function that returns a default signature and a named
signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md
index 25607d1619..08ccc5f104 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md
@@ -8,7 +8,8 @@ containing the splitted tokens. Empty tokens are ignored.
If `delimiter` is an empty string, each element of the `source` is split
into individual strings, each containing one byte. (This includes splitting
-multibyte sequences of UTF-8.)
+multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
+treated as a set of delimiters with each considered a potential split point.
For example:
N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
@@ -29,14 +30,14 @@ st.values = ['hello', 'world', 'a', 'b', 'c']
* <b>`delimiter`</b>: `0-D` string `Tensor`, the delimiter character, the string should
be length 0 or 1.
+##### Raises:
+
+
+* <b>`ValueError`</b>: If delimiter is not a string.
+
##### Returns:
A `SparseTensor` of rank `2`, the strings split according to the delimiter.
The first column of the indices corresponds to the row in `source` and the
second column corresponds to the index of the split component in this row.
-##### Raises:
-
-
-* <b>`ValueError`</b>: If delimiter is not a single-byte character.
-
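A hedged sketch of the multi-byte delimiter behavior described above, where each character of `delimiter` is an independent split point:

```python
import tensorflow as tf

source = tf.constant(["a,b;c", "d,e"])
st = tf.string_split(source, delimiter=",;")  # split on ',' or ';'

with tf.Session() as sess:
    # Tokens: a, b, c, d, e; st.indices gives (row, token position).
    print(sess.run(st.values))
```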
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md
index aa3c101dbf..0f3006d9ca 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md
@@ -140,6 +140,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi
- - -
+#### `tf.contrib.learn.Estimator.export_savedmodel(*args, **kwargs)` {#Estimator.export_savedmodel}
+
+Exports the inference graph as a SavedModel into the given dir. (experimental)
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ input_fn: A function that takes no argument and
+ returns an `InputFnOps`.
+ default_output_alternative_key: the name of the head to serve when none is
+ specified.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel. Each key should give the destination
+ path (including the filename) relative to the assets.extra directory.
+ The corresponding value gives the full path of the source file to be
+ copied. For example, the simple case of copying a single file without
+ renaming it is specified as
+ `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+ as_text: whether to write the SavedModel proto in text format.
+ exports_to_keep: Number of exports to keep.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if an unrecognized export_type is requested.
+
+
+- - -
+
#### `tf.contrib.learn.Estimator.fit(*args, **kwargs)` {#Estimator.fit}
See `Trainable`. (deprecated arguments)
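A hedged sketch of the `export_savedmodel` call documented above; `estimator` and `serving_input_fn` are hypothetical stand-ins, and the API is marked experimental:

```python
# Assumes `estimator` is an already-constructed tf.contrib.learn.Estimator
# and `serving_input_fn` takes no arguments and returns an `InputFnOps`.
export_dir = estimator.export_savedmodel(
    export_dir_base="/tmp/exports",
    input_fn=serving_input_fn,
    # Copy one extra asset into assets.extra/ inside the SavedModel.
    assets_extra={"vocab.txt": "/path/to/vocab.txt"},
    as_text=False)
```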
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md
index e9034e9115..f9fa7c70cb 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md
@@ -51,7 +51,7 @@ Input of `fit` and `evaluate` should have following features,
whose `value` is a `Tensor`.
- - -
-#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None)` {#DNNClassifier.__init__}
+#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None, embedding_lr_multipliers=None)` {#DNNClassifier.__init__}
Initializes a DNNClassifier instance.
@@ -91,6 +91,9 @@ Initializes a DNNClassifier instance.
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
+* <b>`embedding_lr_multipliers`</b>: Optional. A dictionary from `EmbeddingColumn` to
+ a `float` multiplier. Multiplier will be used to multiply with
+ learning rate for the embedding variables.
##### Returns:
@@ -136,6 +139,15 @@ See BaseEstimator.export.
- - -
+#### `tf.contrib.learn.DNNClassifier.export_savedmodel(*args, **kwargs)` {#DNNClassifier.export_savedmodel}
+
+EXPERIMENTAL FUNCTION
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+- - -
+
#### `tf.contrib.learn.DNNClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#DNNClassifier.fit}
See trainable.Trainable. Note: Labels must be integer class indices.
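A hedged sketch of the new `embedding_lr_multipliers` argument documented above; the feature columns are hypothetical:

```python
from tensorflow.contrib import layers, learn

tokens = layers.sparse_column_with_hash_bucket("tokens", hash_bucket_size=1000)
tokens_emb = layers.embedding_column(tokens, dimension=16)

classifier = learn.DNNClassifier(
    hidden_units=[64, 32],
    feature_columns=[tokens_emb],
    # Train the embedding variables at half the base learning rate.
    embedding_lr_multipliers={tokens_emb: 0.5})
```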
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md
index af9d8dbb76..594855e8a8 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md
@@ -1,4 +1,4 @@
-### `tf.select(condition, t, e, name=None)` {#select}
+### `tf.select(*args, **kwargs)` {#select}
Selects elements from `t` or `e`, depending on `condition`.
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md
index fba1f9071c..79081d4e9f 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md
@@ -21,6 +21,11 @@ Gets an existing variable with these parameters or creates a new one.
device.
* <b>`device`</b>: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
+* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+* <b>`custom_getter`</b>: Callable that allows overriding the internal
+    `get_variable` method; it must have the same signature.
##### Returns:
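A hedged sketch of the new `partitioner` and `custom_getter` arguments; the getter simply forwards to the internal implementation, per the matching-signature rule, and the partitioner is assumed to be called with `shape` and `dtype` keywords:

```python
from tensorflow.contrib import framework

def forwarding_getter(getter, *args, **kwargs):
    # Same signature contract as the internal get_variable.
    return getter(*args, **kwargs)

# Split the first axis into two partitions; only one axis may be partitioned.
weights = framework.variable(
    "weights", shape=[10, 4],
    partitioner=lambda shape, dtype: [2, 1],
    custom_getter=forwarding_getter)
```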
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md
index 0c0913dc09..8c25e24373 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md
@@ -1,4 +1,23 @@
### `tf.div(x, y, name=None)` {#div}
+Divides x / y elementwise (using Python 2 division operator semantics).
+NOTE: Prefer using the Tensor division operator or tf.divide which obey Python
+division operator semantics.
+
+This function divides `x` and `y`, forcing Python 2.7 semantics. That is,
+if one of `x` or `y` is a float, then the result will be a float.
+Otherwise, the output will be an integer type. Flooring semantics are used
+for integer division.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
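A brief sketch of the Python 2 semantics described above: integer inputs keep an integer dtype and the division floors, while float inputs divide exactly:

```python
import tensorflow as tf

with tf.Session() as sess:
    print(sess.run(tf.div(7, 2)))      # 3   (flooring integer division)
    print(sess.run(tf.div(7.0, 2.0)))  # 3.5 (float division)
```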
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md
index 1f1e960f48..8e628e6067 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md
@@ -1,4 +1,4 @@
-### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, **kwargs)` {#make_template}
+### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, custom_getter_=None, **kwargs)` {#make_template}
Given an arbitrary function, wrap it so that it does variable sharing.
@@ -89,6 +89,9 @@ reduce the likelihood of collisions with kwargs.
* <b>`unique_name_`</b>: When used, it overrides name_ and is not made unique. If a
template of the same scope/unique_name already exists and reuse is false,
an error is raised. Defaults to None.
+* <b>`custom_getter_`</b>: Optional custom getter for variables used in `func_`. See
+ the [`get_variable`](#get_variable) `custom_getter` documentation for
+ more information.
* <b>`**kwargs`</b>: Keyword arguments to apply to `func_`.
##### Returns:
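A hedged sketch of `custom_getter_`; the getter follows the `get_variable` `custom_getter` signature referenced above and simply forwards the call:

```python
import tensorflow as tf

def forwarding_getter(getter, name, *args, **kwargs):
    return getter(name, *args, **kwargs)

def scale(x):
    w = tf.get_variable("w", shape=[],
                        initializer=tf.constant_initializer(2.0))
    return w * x

scale_fn = tf.make_template("scale", scale, custom_getter_=forwarding_getter)
y1 = scale_fn(tf.constant(3.0))  # creates scale/w
y2 = scale_fn(tf.constant(4.0))  # reuses scale/w
```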
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.FileWriterCache.get.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.FileWriterCache.get.md
new file mode 100644
index 0000000000..0f416a5909
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.FileWriterCache.get.md
@@ -0,0 +1,13 @@
+#### `tf.summary.FileWriterCache.get(logdir)` {#FileWriterCache.get}
+
+Returns the FileWriter for the specified directory.
+
+##### Args:
+
+
+* <b>`logdir`</b>: str, name of the directory.
+
+##### Returns:
+
+ A `FileWriter`.
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md
index a7bb580232..d5fee8f7b4 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md
@@ -1,6 +1,6 @@
#### `tf.train.SummaryWriterCache.get(logdir)` {#SummaryWriterCache.get}
-Returns the SummaryWriter for the specified directory.
+Returns the FileWriter for the specified directory.
##### Args:
@@ -9,5 +9,5 @@ Returns the SummaryWriter for the specified directory.
##### Returns:
- A `SummaryWriter`.
+ A `FileWriter`.
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md
index 1ecf1822c9..6e80a4a562 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md
@@ -29,7 +29,7 @@ the graph from the session in which you launched it:
# Launch the graph in a session.
sess = tf.Session()
# Create a summary writer, add the 'graph' to the event file.
-writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
+writer = tf.summary.FileWriter(<some-directory>, sess.graph)
```
The other arguments to the constructor control the asynchronous writes to
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriterCache.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriterCache.md
new file mode 100644
index 0000000000..3c6c8773b3
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriterCache.md
@@ -0,0 +1,26 @@
+Cache for file writers.
+
+This class caches file writers, one per directory.
+- - -
+
+#### `tf.summary.FileWriterCache.clear()` {#FileWriterCache.clear}
+
+Clear cached summary writers. Currently only used for unit tests.
+
+
+- - -
+
+#### `tf.summary.FileWriterCache.get(logdir)` {#FileWriterCache.get}
+
+Returns the FileWriter for the specified directory.
+
+##### Args:
+
+
+* <b>`logdir`</b>: str, name of the directory.
+
+##### Returns:
+
+ A `FileWriter`.
+
+
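A minimal sketch of the cache semantics: `get` returns one writer per directory, and `clear` drops the cached instances:

```python
import tensorflow as tf

w1 = tf.summary.FileWriterCache.get("/tmp/logs")
w2 = tf.summary.FileWriterCache.get("/tmp/logs")
assert w1 is w2  # same cached FileWriter for the same directory

tf.summary.FileWriterCache.clear()  # mainly for unit tests, per the docs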
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md
index cb96358aa1..e9bdda200f 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md
@@ -1,131 +1,109 @@
-Writes `Summary` protocol buffers to event files.
-
-The `FileWriter` class provides a mechanism to create an event file in a
-given directory and add summaries and events to it. The class updates the
-file contents asynchronously. This allows a training program to call methods
-to add data to the file directly from the training loop, without slowing down
-training.
- - -
-#### `tf.train.SummaryWriter.__init__(logdir, graph=None, max_queue=10, flush_secs=120, graph_def=None)` {#SummaryWriter.__init__}
-
-Creates a `FileWriter` and an event file.
+#### `tf.train.SummaryWriter.__init__(*args, **kwargs)` {#SummaryWriter.__init__}
-On construction the summary writer creates a new event file in `logdir`.
-This event file will contain `Event` protocol buffers constructed when you
-call one of the following functions: `add_summary()`, `add_session_log()`,
-`add_event()`, or `add_graph()`.
+Creates a `SummaryWriter` and an event file. (deprecated)
-If you pass a `Graph` to the constructor it is added to
-the event file. (This is equivalent to calling `add_graph()` later).
+THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
+Instructions for updating:
+Please switch to tf.summary.FileWriter. The interface and behavior are the same; this is just a rename.
-TensorBoard will pick the graph from the file and display it graphically so
-you can interactively explore the graph you built. You will usually pass
-the graph from the session in which you launched it:
+ This class is deprecated, and should be replaced with tf.summary.FileWriter.
-```python
-...create a graph...
-# Launch the graph in a session.
-sess = tf.Session()
-# Create a summary writer, add the 'graph' to the event file.
-writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
-```
+ On construction the summary writer creates a new event file in `logdir`.
+ This event file will contain `Event` protocol buffers constructed when you
+ call one of the following functions: `add_summary()`, `add_session_log()`,
+ `add_event()`, or `add_graph()`.
-The other arguments to the constructor control the asynchronous writes to
-the event file:
+ If you pass a `Graph` to the constructor it is added to
+ the event file. (This is equivalent to calling `add_graph()` later).
-* `flush_secs`: How often, in seconds, to flush the added summaries
- and events to disk.
-* `max_queue`: Maximum number of summaries or events pending to be
- written to disk before one of the 'add' calls block.
+ TensorBoard will pick the graph from the file and display it graphically so
+ you can interactively explore the graph you built. You will usually pass
+ the graph from the session in which you launched it:
-##### Args:
+ ```python
+ ...create a graph...
+ # Launch the graph in a session.
+ sess = tf.Session()
+ # Create a summary writer, add the 'graph' to the event file.
+ writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
+ ```
+ The other arguments to the constructor control the asynchronous writes to
+ the event file:
-* <b>`logdir`</b>: A string. Directory where event file will be written.
-* <b>`graph`</b>: A `Graph` object, such as `sess.graph`.
-* <b>`max_queue`</b>: Integer. Size of the queue for pending events and summaries.
-* <b>`flush_secs`</b>: Number. How often, in seconds, to flush the
- pending events and summaries to disk.
-* <b>`graph_def`</b>: DEPRECATED: Use the `graph` argument instead.
+ * `flush_secs`: How often, in seconds, to flush the added summaries
+ and events to disk.
+ * `max_queue`: Maximum number of summaries or events pending to be
+ written to disk before one of the 'add' calls block.
+ Args:
+ logdir: A string. Directory where event file will be written.
+ graph: A `Graph` object, such as `sess.graph`.
+ max_queue: Integer. Size of the queue for pending events and summaries.
+ flush_secs: Number. How often, in seconds, to flush the
+ pending events and summaries to disk.
+ graph_def: DEPRECATED: Use the `graph` argument instead.
- - -
-#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary}
-
-Adds a `Summary` protocol buffer to the event file.
-
-This method wraps the provided summary in an `Event` protocol buffer
-and adds it to the event file.
+#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event}
-You can pass the result of evaluating any summary op, using
-[`Session.run()`](client.md#Session.run) or
-[`Tensor.eval()`](framework.md#Tensor.eval), to this
-function. Alternatively, you can pass a `tf.Summary` protocol
-buffer that you populate with your own data. The latter is
-commonly done to report evaluation results in event files.
+Adds an event to the event file.
##### Args:
-* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string.
-* <b>`global_step`</b>: Number. Optional global step value to record with the
- summary.
+* <b>`event`</b>: An `Event` protocol buffer.
- - -
-#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log}
+#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph}
-Adds a `SessionLog` protocol buffer to the event file.
+Adds a `Graph` to the event file.
-This method wraps the provided session in an `Event` protocol buffer
-and adds it to the event file.
+The graph described by the protocol buffer will be displayed by
+TensorBoard. Most users pass a graph in the constructor instead.
##### Args:
-* <b>`session_log`</b>: A `SessionLog` protocol buffer.
-* <b>`global_step`</b>: Number. Optional global step value to record with the
- summary.
-
-
-- - -
-
-#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event}
-
-Adds an event to the event file.
+* <b>`graph`</b>: A `Graph` object, such as `sess.graph`.
+* <b>`global_step`</b>: Number. Optional global step counter to record with the
+ graph.
+* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead.
-##### Args:
+##### Raises:
-* <b>`event`</b>: An `Event` protocol buffer.
+* <b>`ValueError`</b>: If both graph and graph_def are passed to the method.
- - -
-#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph}
+#### `tf.train.SummaryWriter.add_meta_graph(meta_graph_def, global_step=None)` {#SummaryWriter.add_meta_graph}
-Adds a `Graph` to the event file.
+Adds a `MetaGraphDef` to the event file.
-The graph described by the protocol buffer will be displayed by
-TensorBoard. Most users pass a graph in the constructor instead.
+The `MetaGraphDef` allows running the given graph via
+`saver.import_meta_graph()`.
##### Args:
-* <b>`graph`</b>: A `Graph` object, such as `sess.graph`.
+* <b>`meta_graph_def`</b>: A `MetaGraphDef` object, often returned by
+ `saver.export_meta_graph()`.
* <b>`global_step`</b>: Number. Optional global step counter to record with the
graph.
-* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead.
##### Raises:
-* <b>`ValueError`</b>: If both graph and graph_def are passed to the method.
+* <b>`TypeError`</b>: If `meta_graph_def` is not an instance of `MetaGraphDef`.
- - -
@@ -150,20 +128,43 @@ Adds a metadata information for a single session.run() call.
- - -
-#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir}
+#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log}
-Returns the directory where event file will be written.
+Adds a `SessionLog` protocol buffer to the event file.
+
+This method wraps the provided session in an `Event` protocol buffer
+and adds it to the event file.
+
+##### Args:
+* <b>`session_log`</b>: A `SessionLog` protocol buffer.
+* <b>`global_step`</b>: Number. Optional global step value to record with the
+ summary.
+
- - -
-#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush}
+#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary}
-Flushes the event file to disk.
+Adds a `Summary` protocol buffer to the event file.
-Call this method to make sure that all pending events have been written to
-disk.
+This method wraps the provided summary in an `Event` protocol buffer
+and adds it to the event file.
+
+You can pass the result of evaluating any summary op, using
+[`Session.run()`](client.md#Session.run) or
+[`Tensor.eval()`](framework.md#Tensor.eval), to this
+function. Alternatively, you can pass a `tf.Summary` protocol
+buffer that you populate with your own data. The latter is
+commonly done to report evaluation results in event files.
+
+##### Args:
+
+
+* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string.
+* <b>`global_step`</b>: Number. Optional global step value to record with the
+ summary.
- - -
@@ -175,8 +176,23 @@ Flushes the event file to disk and close the file.
Call this method when you do not need the summary writer anymore.
+- - -
+
+#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush}
+
+Flushes the event file to disk.
+
+Call this method to make sure that all pending events have been written to
+disk.
+
+
+- - -
+
+#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir}
+
+Returns the directory where event file will be written.
+
-#### Other Methods
- - -
#### `tf.train.SummaryWriter.reopen()` {#SummaryWriter.reopen}
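Since `tf.train.SummaryWriter` is deprecated in favor of `tf.summary.FileWriter` (a pure rename, per the notice above), here is a hedged migration sketch; `tf.summary.scalar` is assumed available as the summary op:

```python
import tensorflow as tf

x = tf.constant(3.0)
summary_op = tf.summary.scalar("x", x)

with tf.Session() as sess:
    # Drop-in replacement for tf.train.SummaryWriter(...).
    writer = tf.summary.FileWriter("/tmp/logs", sess.graph)
    writer.add_summary(sess.run(summary_op), global_step=0)
    writer.flush()
    writer.close()
```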
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md
index 4782bfac68..01136ac630 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md
@@ -1,6 +1,6 @@
-Cache for summary writers.
+Cache for file writers.
-This class caches summary writers, one per directory.
+This class caches file writers, one per directory.
- - -
#### `tf.train.SummaryWriterCache.clear()` {#SummaryWriterCache.clear}
@@ -12,7 +12,7 @@ Clear cached summary writers. Currently only used for unit tests.
#### `tf.train.SummaryWriterCache.get(logdir)` {#SummaryWriterCache.get}
-Returns the SummaryWriter for the specified directory.
+Returns the FileWriter for the specified directory.
##### Args:
@@ -21,6 +21,6 @@ Returns the SummaryWriter for the specified directory.
##### Returns:
- A `SummaryWriter`.
+ A `FileWriter`.
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md
index 0ccb1b2217..7a0c7a4aac 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md
@@ -1,12 +1,15 @@
### `tf.truediv(x, y, name=None)` {#truediv}
-Divides x / y elementwise, always producing floating point results.
+Divides x / y elementwise (using Python 3 division operator semantics).
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
+NOTE: Prefer using the Tensor operator or tf.divide which obey Python
+division operator semantics.
+
+This function forces Python 3 division operator semantics where all integer
+arguments are cast to floating types first. This op is generated by normal
+`x / y` division in Python 3 and in Python 2.7 with
+`from __future__ import division`. If you want integer division that rounds
+down, use `x // y` or `tf.floordiv`.
`x` and `y` must have the same numeric type. If the inputs are floating
point, the output will have the same type. If the inputs are integral, the
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md
index f647b524d4..7ec18e834c 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md
@@ -27,7 +27,7 @@ The following standard keys are defined:
for more details.
* `SUMMARIES`: the summary `Tensor` objects that have been created in the
graph. See
- [`tf.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries)
+ [`tf.contrib.deprecated.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries)
for more details.
* `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
produce input for a computation. See
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md
index d9de06d5d0..1c183cb120 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md
@@ -87,8 +87,8 @@ and evaluate every `Tensor` in `fetches`, substituting the values in
`feed_dict` for the corresponding input values.
The `fetches` argument may be a single graph element, or an arbitrarily
-nested list, tuple, namedtuple, or dict containing graph elements at its
-leaves. A graph element can be one of the following types:
+nested list, tuple, namedtuple, dict, or OrderedDict containing graph
+elements at its leaves. A graph element can be one of the following types:
* An [`Operation`](../../api_docs/python/framework.md#Operation).
The corresponding fetched value will be `None`.
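A short sketch of nested fetches, including the newly documented `OrderedDict` support:

```python
import collections
import tensorflow as tf

a = tf.constant(1.0)
b = tf.constant(2.0)

with tf.Session() as sess:
    fetched = sess.run(collections.OrderedDict([("a", a), ("sum", a + b)]))
    # fetched preserves key order: OrderedDict([('a', 1.0), ('sum', 3.0)])
```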
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md
index 648606c3db..9b58355a38 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md
@@ -56,7 +56,7 @@ all the variables. You then run that Op after launching the graph.
```python
# Add an Op to initialize global variables.
-init_op = tf.global_variable_initializers()
+init_op = tf.global_variables_initializer()
# Launch the graph in a session.
with tf.Session() as sess:
@@ -154,6 +154,10 @@ Returns the value of the initialized variable.
You should use this instead of the variable itself to initialize another
variable with a value that depends on the value of this variable.
+Beware of using `initialized_value` except during initialization:
+`initialized_value` causes the Variable's initializer op to be run, so
+evaluating it later resets the variable to its initial value.
+
```python
# Initialize 'v' with a random tensor.
v = tf.Variable(tf.truncated_normal([10, 40]))
@@ -455,7 +459,18 @@ Returns the truth value of x AND y element-wise.
#### `tf.Variable.__div__(a, *args)` {#Variable.__div__}
+Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+  `x / y` returns the quotient of x and y.
- - -
@@ -807,7 +822,18 @@ Returns the truth value of x AND y element-wise.
#### `tf.Variable.__rdiv__(a, *args)` {#Variable.__rdiv__}
+Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
- - -
@@ -951,34 +977,7 @@ Returns x - y element-wise.
#### `tf.Variable.__rtruediv__(a, *args)` {#Variable.__rtruediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
-
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
@@ -1020,34 +1019,7 @@ Returns x - y element-wise.
#### `tf.Variable.__truediv__(a, *args)` {#Variable.__truediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
-
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md
index 2bbd4d5077..daa96911d9 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md
@@ -22,6 +22,11 @@ Gets an existing model variable with these parameters or creates a new one.
device.
* <b>`device`</b>: Optional device to place the variable. It can be an string or a
function that is called to get the device for the variable.
+* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape`
+ and dtype of the `Variable` to be created, and returns a list of
+ partitions for each axis (currently only one axis can be partitioned).
+* <b>`custom_getter`</b>: Callable that allows overriding the internal
+    `get_variable` method; it must have the same signature.
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md
index dcbf0fbb1c..82a42aaf22 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md
@@ -125,6 +125,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi
- - -
+#### `tf.contrib.learn.LogisticRegressor.export_savedmodel(*args, **kwargs)` {#LogisticRegressor.export_savedmodel}
+
+Exports the inference graph as a SavedModel into the given dir. (experimental)
+
+THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning.
+
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ input_fn: A function that takes no argument and
+ returns an `InputFnOps`.
+ default_output_alternative_key: the name of the head to serve when none is
+ specified.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel. Each key should give the destination
+ path (including the filename) relative to the assets.extra directory.
+ The corresponding value gives the full path of the source file to be
+ copied. For example, the simple case of copying a single file without
+ renaming it is specified as
+ `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+ as_text: whether to write the SavedModel proto in text format.
+ exports_to_keep: Number of exports to keep.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if an unrecognized export_type is requested.
+
+
+- - -
+
#### `tf.contrib.learn.LogisticRegressor.fit(*args, **kwargs)` {#LogisticRegressor.fit}
See `Trainable`. (deprecated arguments)
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md
index 6624ad3ce9..eb8a200558 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md
@@ -84,6 +84,19 @@ FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE
### Performance
FILL THIS IN
+
+### Matrix property hints
+
+This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+for `X = non_singular, self_adjoint` etc...
+These have the following meaning:
+* If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+* If `is_X == False`, callers should expect the operator to not have `X`.
+* If `is_X == None` (the default), callers should have no expectation either
+ way.
- - -
#### `tf.contrib.linalg.LinearOperator.__init__(dtype, graph_parents=None, is_non_singular=None, is_self_adjoint=None, is_positive_definite=None, name=None)` {#LinearOperator.__init__}
@@ -93,16 +106,6 @@ Initialize the `LinearOperator`.
**This is a private method for subclass use.**
**Subclasses should copy-paste this `__init__` documentation.**
-For `X = non_singular, self_adjoint` etc...
-`is_X` is a Python `bool` initialization argument with the following meaning
-* If `is_X == True`, callers should expect the operator to have the
- attribute `X`. This is a promise that should be fulfilled, but is *not* a
- runtime assert. Issues, such as floating point error, could mean the
- operator violates this promise.
-* If `is_X == False`, callers should expect the operator to not have `X`.
-* If `is_X == None` (the default), callers should have no expectation either
- way.
-
##### Args:
@@ -113,8 +116,12 @@ For `X = non_singular, self_adjoint` etc...
* <b>`is_non_singular`</b>: Expect that this operator is non-singular.
* <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian
transpose. If `dtype` is real, this is equivalent to being symmetric.
-* <b>`is_positive_definite`</b>: Expect that this operator is positive definite.
-* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name.
+* <b>`is_positive_definite`</b>: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See
+ https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non_symmetric_matrices
+* <b>`name`</b>: A name for this `LinearOperator`.
##### Raises:
@@ -124,6 +131,23 @@ For `X = non_singular, self_adjoint` etc...
- - -
+#### `tf.contrib.linalg.LinearOperator.add_to_tensor(x, name='add_to_tensor')` {#LinearOperator.add_to_tensor}
+
+Add the matrix represented by this operator to `x`. Equivalent to `A + x`.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ A `Tensor` with broadcast shape and same `dtype` as `self`.
+
+
+- - -
+
#### `tf.contrib.linalg.LinearOperator.apply(x, adjoint=False, name='apply')` {#LinearOperator.apply}
Transform `x` with left multiplication: `x --> Ax`.
@@ -154,6 +178,25 @@ Returns an `Op` that asserts this operator is non singular.
Returns an `Op` that asserts this operator is positive definite.
+Here, positive definite means the real part of all eigenvalues is positive.
+We do not require the operator to be self-adjoint.
+
+##### Args:
+
+
+* <b>`name`</b>: A name to give this `Op`.
+
+##### Returns:
+
+ An `Op` that asserts this operator is positive definite.
+
+
+- - -
+
+#### `tf.contrib.linalg.LinearOperator.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperator.assert_self_adjoint}
+
+Returns an `Op` that asserts this operator is self-adjoint.
+
- - -
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md
index 625ff785cb..766341a73f 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md
@@ -8,7 +8,7 @@ This is useful in summaries to measure and report sparsity. For example,
```python
z = tf.Relu(...)
- summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z))
+ summ = tf.contrib.deprecated.scalar_summary('sparsity', tf.nn.zero_fraction(z))
```
##### Args:
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md
index 5702571441..f998e62046 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md
@@ -18,7 +18,7 @@ Example: Print selected summary values.
# This example supposes that the events file contains summaries with a
# summary value tag 'loss'. These could have been added by calling
# `add_summary()`, passing the output of a scalar summary op created with
-# with: `tf.scalar_summary(['loss'], loss_tensor)`.
+# `tf.summary.scalar('loss', loss_tensor)`.
for e in tf.train.summary_iterator(path to events file):
for v in e.summary.value:
if v.tag == 'loss':
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index a4490f22db..f3127013bf 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -687,6 +687,7 @@
* **[Summary Operations](../../api_docs/python/summary.md)**:
* [`audio`](../../api_docs/python/summary.md#audio)
* [`FileWriter`](../../api_docs/python/summary.md#FileWriter)
+ * [`FileWriterCache`](../../api_docs/python/summary.md#FileWriterCache)
* [`get_summary_description`](../../api_docs/python/summary.md#get_summary_description)
* [`histogram`](../../api_docs/python/summary.md#histogram)
* [`image`](../../api_docs/python/summary.md#image)
@@ -957,7 +958,6 @@
* [`sparse_column_with_hash_bucket`](../../api_docs/python/contrib.layers.md#sparse_column_with_hash_bucket)
* [`sparse_column_with_integerized_feature`](../../api_docs/python/contrib.layers.md#sparse_column_with_integerized_feature)
* [`sparse_column_with_keys`](../../api_docs/python/contrib.layers.md#sparse_column_with_keys)
- * [`stack`](../../api_docs/python/contrib.layers.md#stack)
* [`sum_regularizer`](../../api_docs/python/contrib.layers.md#sum_regularizer)
* [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation)
* [`summarize_activations`](../../api_docs/python/contrib.layers.md#summarize_activations)
diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md
index 27e8f7ed82..994ff48c7a 100644
--- a/tensorflow/g3doc/api_docs/python/math_ops.md
+++ b/tensorflow/g3doc/api_docs/python/math_ops.md
@@ -108,7 +108,26 @@ multiply with arbitrary tensors.
### `tf.div(x, y, name=None)` {#div}
+Divides x / y elementwise (using Python 2 division operator semantics).
+
+NOTE: Prefer using the `Tensor` division operator or `tf.divide`, which obey
+Python division operator semantics.
+
+This function divides `x` and `y`, forcing Python 2.7 semantics. That is,
+if one of `x` or `y` is a float, then the result will be a float.
+Otherwise, the output will be an integer type. Flooring semantics are used
+for integer division.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
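+
+A minimal sketch of the difference in semantics (illustrative values; assumes
+the APIs documented here):
+
+```python
+import tensorflow as tf
+
+a = tf.constant(7)
+b = tf.constant(2)
+with tf.Session() as sess:
+    print(sess.run(tf.div(a, b)))      # 3: integer inputs floor-divide
+    print(sess.run(tf.truediv(a, b)))  # 3.5: integers are cast to float first
+```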
- - -
@@ -122,13 +141,16 @@ Computes Python style division of `x` by `y`.
### `tf.truediv(x, y, name=None)` {#truediv}
-Divides x / y elementwise, always producing floating point results.
+Divides x / y elementwise (using Python 3 division operator semantics).
+
+NOTE: Prefer using the `Tensor` division operator or `tf.divide`, which obey
+Python division operator semantics.
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
+This function forces Python 3 division operator semantics where all integer
+arguments are cast to floating types first. This op is generated by normal
+`x / y` division in Python 3 and in Python 2.7 with
+`from __future__ import division`. If you want integer division that rounds
+down, use `x // y` or `tf.floordiv`.
`x` and `y` must have the same numeric type. If the inputs are floating
point, the output will have the same type. If the inputs are integral, the
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 7e440b887d..d5173bdb19 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -71,7 +71,7 @@ all the variables. You then run that Op after launching the graph.
```python
# Add an Op to initialize global variables.
-init_op = tf.global_variable_initializers()
+init_op = tf.global_variables_initializer()
# Launch the graph in a session.
with tf.Session() as sess:
@@ -169,6 +169,10 @@ Returns the value of the initialized variable.
You should use this instead of the variable itself to initialize another
variable with a value that depends on the value of this variable.
+Beware of using `initialized_value` except during initialization: evaluating
+it runs the variable's initializer op, so running this op resets the variable
+to its initial value.
+
```python
# Initialize 'v' with a random tensor.
v = tf.Variable(tf.truncated_normal([10, 40]))
@@ -470,7 +474,18 @@ Returns the truth value of x AND y element-wise.
#### `tf.Variable.__div__(a, *args)` {#Variable.__div__}
+Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+ `x / y` returns the quotient of x and y.
- - -
@@ -822,7 +837,18 @@ Returns the truth value of x AND y element-wise.
#### `tf.Variable.__rdiv__(a, *args)` {#Variable.__rdiv__}
+Divide two values using Python 2 semantics. Used for Tensor.__rdiv__.
+
+##### Args:
+
+
+* <b>`x`</b>: `Tensor` numerator of real numeric type.
+* <b>`y`</b>: `Tensor` denominator of real numeric type.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ `x / y` returns the quotient of x and y.
- - -
@@ -966,34 +992,7 @@ Returns x - y element-wise.
#### `tf.Variable.__rtruediv__(a, *args)` {#Variable.__rtruediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
-
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
@@ -1035,34 +1034,7 @@ Returns x - y element-wise.
#### `tf.Variable.__truediv__(a, *args)` {#Variable.__truediv__}
-Divides x / y elementwise, always producing floating point results.
-
-The same as `tf.div` for floating point arguments, but casts integer arguments
-to floating point before dividing so that the result is always floating point.
-This op is generated by normal `x / y` division in Python 3 and in Python 2.7
-with `from __future__ import division`. If you want integer division that
-rounds down, use `x // y` or `tf.floordiv`.
-
-`x` and `y` must have the same numeric type. If the inputs are floating
-point, the output will have the same type. If the inputs are integral, the
-inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
-and `int64` (matching the behavior of Numpy).
-
-##### Args:
-
-
-* <b>`x`</b>: `Tensor` numerator of numeric type.
-* <b>`y`</b>: `Tensor` denominator of numeric type.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
-
- `x / y` evaluated in floating point.
-
-##### Raises:
-
-* <b>`TypeError`</b>: If `x` and `y` have different dtypes.
- - -
@@ -2268,7 +2240,7 @@ Returns the current variable scope.
- - -
-### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, **kwargs)` {#make_template}
+### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, custom_getter_=None, **kwargs)` {#make_template}
Given an arbitrary function, wrap it so that it does variable sharing.
@@ -2359,6 +2331,9 @@ reduce the likelihood of collisions with kwargs.
* <b>`unique_name_`</b>: When used, it overrides name_ and is not made unique. If a
template of the same scope/unique_name already exists and reuse is false,
an error is raised. Defaults to None.
+* <b>`custom_getter_`</b>: Optional custom getter for variables used in `func_`. See
+ the [`get_variable`](#get_variable) `custom_getter` documentation for
+ more information.
* <b>`**kwargs`</b>: Keyword arguments to apply to `func_`.
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/string_ops.md b/tensorflow/g3doc/api_docs/python/string_ops.md
index 86878ca664..7e75148891 100644
--- a/tensorflow/g3doc/api_docs/python/string_ops.md
+++ b/tensorflow/g3doc/api_docs/python/string_ops.md
containing the split tokens. Empty tokens are ignored.
If `delimiter` is an empty string, each element of the `source` is split
into individual strings, each containing one byte. (This includes splitting
-multibyte sequences of UTF-8.)
+multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
+treated as a set of delimiters with each considered a potential split point.
For example:
N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
@@ -215,17 +216,17 @@ st.values = ['hello', 'world', 'a', 'b', 'c']
* <b>`delimiter`</b>: `0-D` string `Tensor`, the delimiter string. It may be empty, a
  single byte, or multiple bytes (each byte is then treated as a delimiter).
+##### Raises:
+
+
+* <b>`ValueError`</b>: If delimiter is not a string.
+
##### Returns:
A `SparseTensor` of rank `2`, the strings split according to the delimiter.
The first column of the indices corresponds to the row in `source` and the
second column corresponds to the index of the split component in this row.
-##### Raises:
-
-
-* <b>`ValueError`</b>: If delimiter is not a single-byte character.
-
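+
+As a code sketch of the example above (the values in the comments mirror the
+expected output described in this doc):
+
+```python
+st = tf.string_split(["hello world", "a b c"], delimiter=" ")
+# st.indices: [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
+# st.values:  ['hello', 'world', 'a', 'b', 'c']
+```
+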
- - -
diff --git a/tensorflow/g3doc/api_docs/python/summary.md b/tensorflow/g3doc/api_docs/python/summary.md
index 90598f7e8c..f20f876ca3 100644
--- a/tensorflow/g3doc/api_docs/python/summary.md
+++ b/tensorflow/g3doc/api_docs/python/summary.md
@@ -41,7 +41,7 @@ the graph from the session in which you launched it:
# Launch the graph in a session.
sess = tf.Session()
# Create a summary writer, add the 'graph' to the event file.
-writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
+writer = tf.summary.FileWriter(<some-directory>, sess.graph)
```
The other arguments to the constructor control the asynchronous writes to
@@ -202,6 +202,37 @@ Does nothing if the EventFileWriter was not closed.
+- - -
+
+### `class tf.summary.FileWriterCache` {#FileWriterCache}
+
+Cache for file writers.
+
+This class caches file writers, one per directory.
+- - -
+
+#### `tf.summary.FileWriterCache.clear()` {#FileWriterCache.clear}
+
+Clear the cached file writers. Currently only used for unit tests.
+
+
+- - -
+
+#### `tf.summary.FileWriterCache.get(logdir)` {#FileWriterCache.get}
+
+Returns the FileWriter for the specified directory.
+
+##### Args:
+
+
+* <b>`logdir`</b>: str, name of the directory.
+
+##### Returns:
+
+ A `FileWriter`.
+
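+A minimal usage sketch (the log directory and summary objects are
+illustrative):
+
+```python
+writer = tf.summary.FileWriterCache.get('/tmp/logs')
+writer.add_summary(summary_proto, global_step=step)  # assumed to exist
+```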
+
+
### Summary Ops
- - -
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index b336b8cfc9..fb4ff94f4f 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -1372,7 +1372,7 @@ saver.restore(...checkpoint filename...)
- - -
-#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__}
+#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, zero_debias=False, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__}
Creates a new ExponentialMovingAverage object.
@@ -1392,6 +1392,8 @@ move faster. If passed, the actual decay rate used is:
* <b>`decay`</b>: Float. The decay to use.
* <b>`num_updates`</b>: Optional count of number of updates applied to variables.
+* <b>`zero_debias`</b>: If `True`, zero debias moving-averages that are initialized
+ with tensors.
* <b>`name`</b>: String. Optional prefix name to use for the name of ops added in
`apply()`.
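+
+A minimal sketch of constructing a zero-debiased average (`my_var` is
+illustrative):
+
+```python
+ema = tf.train.ExponentialMovingAverage(decay=0.999, zero_debias=True)
+maintain_averages_op = ema.apply([my_var])  # assumes `my_var` is a Variable
+```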
@@ -4048,7 +4050,7 @@ This is useful in summaries to measure and report sparsity. For example,
```python
z = tf.Relu(...)
- summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z))
+ summ = tf.contrib.deprecated.scalar_summary('sparsity', tf.nn.zero_fraction(z))
```
##### Args:
@@ -4122,134 +4124,112 @@ overview of summaries, event files, and visualization in TensorBoard.
### `class tf.train.SummaryWriter` {#SummaryWriter}
-Writes `Summary` protocol buffers to event files.
-
-The `FileWriter` class provides a mechanism to create an event file in a
-given directory and add summaries and events to it. The class updates the
-file contents asynchronously. This allows a training program to call methods
-to add data to the file directly from the training loop, without slowing down
-training.
- - -
-#### `tf.train.SummaryWriter.__init__(logdir, graph=None, max_queue=10, flush_secs=120, graph_def=None)` {#SummaryWriter.__init__}
+#### `tf.train.SummaryWriter.__init__(*args, **kwargs)` {#SummaryWriter.__init__}
-Creates a `FileWriter` and an event file.
+Creates a `SummaryWriter` and an event file. (deprecated)
-On construction the summary writer creates a new event file in `logdir`.
-This event file will contain `Event` protocol buffers constructed when you
-call one of the following functions: `add_summary()`, `add_session_log()`,
-`add_event()`, or `add_graph()`.
+THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30.
+Instructions for updating:
+Please switch to tf.summary.FileWriter. The interface and behavior are the same; this is just a rename.
-If you pass a `Graph` to the constructor it is added to
-the event file. (This is equivalent to calling `add_graph()` later).
+ This class is deprecated and should be replaced with tf.summary.FileWriter.
-TensorBoard will pick the graph from the file and display it graphically so
-you can interactively explore the graph you built. You will usually pass
-the graph from the session in which you launched it:
+ On construction the summary writer creates a new event file in `logdir`.
+ This event file will contain `Event` protocol buffers constructed when you
+ call one of the following functions: `add_summary()`, `add_session_log()`,
+ `add_event()`, or `add_graph()`.
-```python
-...create a graph...
-# Launch the graph in a session.
-sess = tf.Session()
-# Create a summary writer, add the 'graph' to the event file.
-writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
-```
+ If you pass a `Graph` to the constructor it is added to
+ the event file. (This is equivalent to calling `add_graph()` later).
-The other arguments to the constructor control the asynchronous writes to
-the event file:
+ TensorBoard will pick the graph from the file and display it graphically so
+ you can interactively explore the graph you built. You will usually pass
+ the graph from the session in which you launched it:
-* `flush_secs`: How often, in seconds, to flush the added summaries
- and events to disk.
-* `max_queue`: Maximum number of summaries or events pending to be
- written to disk before one of the 'add' calls block.
-
-##### Args:
+ ```python
+ ...create a graph...
+ # Launch the graph in a session.
+ sess = tf.Session()
+ # Create a summary writer, add the 'graph' to the event file.
+ writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
+ ```
+ The other arguments to the constructor control the asynchronous writes to
+ the event file:
-* <b>`logdir`</b>: A string. Directory where event file will be written.
-* <b>`graph`</b>: A `Graph` object, such as `sess.graph`.
-* <b>`max_queue`</b>: Integer. Size of the queue for pending events and summaries.
-* <b>`flush_secs`</b>: Number. How often, in seconds, to flush the
- pending events and summaries to disk.
-* <b>`graph_def`</b>: DEPRECATED: Use the `graph` argument instead.
+ * `flush_secs`: How often, in seconds, to flush the added summaries
+ and events to disk.
+ * `max_queue`: Maximum number of summaries or events pending to be
+ written to disk before one of the 'add' calls blocks.
+ Args:
+ logdir: A string. Directory where the event file will be written.
+ graph: A `Graph` object, such as `sess.graph`.
+ max_queue: Integer. Size of the queue for pending events and summaries.
+ flush_secs: Number. How often, in seconds, to flush the
+ pending events and summaries to disk.
+ graph_def: DEPRECATED: Use the `graph` argument instead.
- - -
-#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary}
-
-Adds a `Summary` protocol buffer to the event file.
-
-This method wraps the provided summary in an `Event` protocol buffer
-and adds it to the event file.
+#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event}
-You can pass the result of evaluating any summary op, using
-[`Session.run()`](client.md#Session.run) or
-[`Tensor.eval()`](framework.md#Tensor.eval), to this
-function. Alternatively, you can pass a `tf.Summary` protocol
-buffer that you populate with your own data. The latter is
-commonly done to report evaluation results in event files.
+Adds an event to the event file.
##### Args:
-* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string.
-* <b>`global_step`</b>: Number. Optional global step value to record with the
- summary.
+* <b>`event`</b>: An `Event` protocol buffer.
- - -
-#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log}
+#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph}
-Adds a `SessionLog` protocol buffer to the event file.
+Adds a `Graph` to the event file.
-This method wraps the provided session in an `Event` protocol buffer
-and adds it to the event file.
+The graph described by the protocol buffer will be displayed by
+TensorBoard. Most users pass a graph in the constructor instead.
##### Args:
-* <b>`session_log`</b>: A `SessionLog` protocol buffer.
-* <b>`global_step`</b>: Number. Optional global step value to record with the
- summary.
-
-
-- - -
-
-#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event}
-
-Adds an event to the event file.
+* <b>`graph`</b>: A `Graph` object, such as `sess.graph`.
+* <b>`global_step`</b>: Number. Optional global step counter to record with the
+ graph.
+* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead.
-##### Args:
+##### Raises:
-* <b>`event`</b>: An `Event` protocol buffer.
+* <b>`ValueError`</b>: If both graph and graph_def are passed to the method.
- - -
-#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph}
+#### `tf.train.SummaryWriter.add_meta_graph(meta_graph_def, global_step=None)` {#SummaryWriter.add_meta_graph}
-Adds a `Graph` to the event file.
+Adds a `MetaGraphDef` to the event file.
-The graph described by the protocol buffer will be displayed by
-TensorBoard. Most users pass a graph in the constructor instead.
+The `MetaGraphDef` allows running the given graph via
+`saver.import_meta_graph()`.
##### Args:
-* <b>`graph`</b>: A `Graph` object, such as `sess.graph`.
+* <b>`meta_graph_def`</b>: A `MetaGraphDef` object, as returned by
+ `saver.export_meta_graph()`.
* <b>`global_step`</b>: Number. Optional global step counter to record with the
graph.
-* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead.
##### Raises:
-* <b>`ValueError`</b>: If both graph and graph_def are passed to the method.
+* <b>`TypeError`</b>: If `meta_graph_def` is not an instance of `MetaGraphDef`.
- - -
@@ -4274,20 +4254,43 @@ Adds a metadata information for a single session.run() call.
- - -
-#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir}
+#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log}
-Returns the directory where event file will be written.
+Adds a `SessionLog` protocol buffer to the event file.
+This method wraps the provided `SessionLog` in an `Event` protocol buffer
+and adds it to the event file.
+
+##### Args:
+
+
+* <b>`session_log`</b>: A `SessionLog` protocol buffer.
+* <b>`global_step`</b>: Number. Optional global step value to record with the
+ summary.
- - -
-#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush}
+#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary}
-Flushes the event file to disk.
+Adds a `Summary` protocol buffer to the event file.
-Call this method to make sure that all pending events have been written to
-disk.
+This method wraps the provided summary in an `Event` protocol buffer
+and adds it to the event file.
+
+You can pass the result of evaluating any summary op, using
+[`Session.run()`](client.md#Session.run) or
+[`Tensor.eval()`](framework.md#Tensor.eval), to this
+function. Alternatively, you can pass a `tf.Summary` protocol
+buffer that you populate with your own data. The latter is
+commonly done to report evaluation results in event files.
+
+##### Args:
+
+
+* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string.
+* <b>`global_step`</b>: Number. Optional global step value to record with the
+ summary.
- - -
@@ -4299,8 +4302,23 @@ Flushes the event file to disk and close the file.
Call this method when you do not need the summary writer anymore.
+- - -
+
+#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush}
+
+Flushes the event file to disk.
+
+Call this method to make sure that all pending events have been written to
+disk.
+
+
+- - -
+
+#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir}
+
+Returns the directory where the event file will be written.
+
-#### Other Methods
- - -
#### `tf.train.SummaryWriter.reopen()` {#SummaryWriter.reopen}
@@ -4318,9 +4336,9 @@ Does nothing if the EventFileWriter was not closed.
### `class tf.train.SummaryWriterCache` {#SummaryWriterCache}
-Cache for summary writers.
+Cache for file writers.
-This class caches summary writers, one per directory.
+This class caches file writers, one per directory.
- - -
#### `tf.train.SummaryWriterCache.clear()` {#SummaryWriterCache.clear}
@@ -4332,7 +4350,7 @@ Clear cached summary writers. Currently only used for unit tests.
#### `tf.train.SummaryWriterCache.get(logdir)` {#SummaryWriterCache.get}
-Returns the SummaryWriter for the specified directory.
+Returns the FileWriter for the specified directory.
##### Args:
@@ -4341,7 +4359,7 @@ Returns the SummaryWriter for the specified directory.
##### Returns:
- A `SummaryWriter`.
+ A `FileWriter`.
@@ -4367,7 +4385,7 @@ Example: Print selected summary values.
# This example supposes that the events file contains summaries with a
# summary value tag 'loss'. These could have been added by calling
# `add_summary()`, passing the output of a scalar summary op created with
-# with: `tf.scalar_summary(['loss'], loss_tensor)`.
+# `tf.summary.scalar('loss', loss_tensor)`.
for e in tf.train.summary_iterator(path to events file):
for v in e.summary.value:
if v.tag == 'loss':
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 875f19be74..f4177dc47a 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -293,7 +293,7 @@ packages needed by TensorFlow.
* Activate the conda environment and install TensorFlow in it.
* After the install you will activate the conda environment each time you
want to use TensorFlow.
-* Optionally install ipython and other packages into the conda environment
+* Optionally install ipython and other packages into the conda environment.
Install Anaconda:
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md
index 15a1e68d5f..88d0cf9e1c 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/index.md
+++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md
@@ -37,7 +37,7 @@ any [attrs](#attrs) the Op might require.
To see how this works, suppose you'd like to create an Op that takes a tensor of
`int32`s and outputs a copy of the tensor, with all but the first element set to
-zero. Create file [`tensorflow/core/user_ops`][user_ops]`/zero_out.cc` and
+zero. Create file `tensorflow/core/user_ops/zero_out.cc` and
add a call to the `REGISTER_OP` macro that defines the interface for such an Op:
```c++
@@ -321,11 +321,10 @@ using the `Attr` method, which expects a spec of the form:
where `<name>` begins with a letter and can be composed of alphanumeric
characters and underscores, and `<attr-type-expr>` is a type expression of the
-form [described below](#attr-types)
+form [described below](#attr-types).
For example, if you'd like the `ZeroOut` Op to preserve a user-specified index,
instead of only the 0th element, you can register the Op like so:
-
<code class="lang-c++"><pre>
REGISTER\_OP("ZeroOut")
<b>.Attr("preserve\_index: int")</b>
@@ -335,7 +334,6 @@ REGISTER\_OP("ZeroOut")
Your kernel can then access this attr in its constructor via the `context`
parameter:
-
<code class="lang-c++"><pre>
class ZeroOutOp : public OpKernel {
public:
@@ -357,7 +355,6 @@ class ZeroOutOp : public OpKernel {
</pre></code>
which can then be used in the `Compute` method:
-
<code class="lang-c++"><pre>
void Compute(OpKernelContext\* context) override {
// ...
@@ -512,7 +509,6 @@ you would then register an `OpKernel` for each supported type.
For instance, if you'd like the `ZeroOut` Op to work on `float`s
in addition to `int32`s, your Op registration might look like:
-
<code class="lang-c++"><pre>
REGISTER\_OP("ZeroOut")
<b>.Attr("T: {float, int32}")</b>
@@ -632,7 +628,6 @@ REGISTER\_KERNEL\_BUILDER(
> </pre></code>
Let's say you wanted to add more types, say `double`:
-
<code class="lang-c++"><pre>
REGISTER\_OP("ZeroOut")
<b>.Attr("T: {float, <b>double,</b> int32}")</b>
@@ -643,7 +638,6 @@ REGISTER\_OP("ZeroOut")
Instead of writing another `OpKernel` with redundant code as above, often you
will be able to use a C++ template instead. You will still have one kernel
registration (`REGISTER_KERNEL_BUILDER` call) per overload.
-
<code class="lang-c++"><pre>
<b>template &lt;typename T&gt;</b>
class ZeroOutOp : public OpKernel {
diff --git a/tensorflow/g3doc/how_tos/graph_viz/index.md b/tensorflow/g3doc/how_tos/graph_viz/index.md
index c94afd70b5..d09769e274 100644
--- a/tensorflow/g3doc/how_tos/graph_viz/index.md
+++ b/tensorflow/g3doc/how_tos/graph_viz/index.md
@@ -33,9 +33,9 @@ with tf.name_scope('hidden') as scope:
This results in the following three op names:
-* *hidden*/alpha
-* *hidden*/weights
-* *hidden*/biases
+* `hidden/alpha`
+* `hidden/weights`
+* `hidden/biases`
By default, the visualization will collapse all three into a node labeled `hidden`.
The extra detail isn't lost. You can double-click, or click
@@ -253,7 +253,7 @@ The images below show the CIFAR-10 model with tensor shape information:
Often it is useful to collect runtime metadata for a run, such as total memory
usage, total compute time, and tensor shapes for nodes. The code example below
is a snippet from the train and test section of a modification of the
-[simple MNIST tutorial](http://tensorflow.org/tutorials/mnist/beginners/index.md),
+[simple MNIST tutorial](../../tutorials/mnist/beginners/index.md),
in which we have recorded summaries and runtime statistics. See the [Summaries Tutorial](../../how_tos/summaries_and_tensorboard/index.md#serializing-the-data)
for details on how to record summaries.
Full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).
diff --git a/tensorflow/g3doc/how_tos/hadoop/index.md b/tensorflow/g3doc/how_tos/hadoop/index.md
index f55d6d182f..a2dd67babd 100644
--- a/tensorflow/g3doc/how_tos/hadoop/index.md
+++ b/tensorflow/g3doc/how_tos/hadoop/index.md
@@ -29,7 +29,7 @@ be set:
set this environment variable by running:
```shell
-source $HADOOP_HOME/libexec/hadoop-config.sh
+source ${HADOOP_HOME}/libexec/hadoop-config.sh
```
* **LD_LIBRARY_PATH**: To include the path to libjvm.so, and optionally the path
@@ -37,16 +37,16 @@ source $HADOOP_HOME/libexec/hadoop-config.sh
`$HADOOP_HDFS_HOME/lib/native`. On Linux:
```shell
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$JAVA_HOME/jre/lib/amd64/server
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${JAVA_HOME}/jre/lib/amd64/server
```
* **CLASSPATH**: The Hadoop jars must be added prior to running your
TensorFlow program. The CLASSPATH set by
- `$HADOOP_HOME/libexec/hadoop-config.sh` is insufficient. Globs must be
+ `${HADOOP_HOME}/libexec/hadoop-config.sh` is insufficient. Globs must be
expanded as described in the libhdfs documentation:
```shell
-CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) python your_script.py
+CLASSPATH=$(${HADOOP_HDFS_HOME}/bin/hadoop classpath --glob) python your_script.py
```
If you are running [Distributed TensorFlow](../distributed/index.md), then all
diff --git a/tensorflow/g3doc/how_tos/image_retraining/index.md b/tensorflow/g3doc/how_tos/image_retraining/index.md
index c6a0467fb3..d721f61810 100644
--- a/tensorflow/g3doc/how_tos/image_retraining/index.md
+++ b/tensorflow/g3doc/how_tos/image_retraining/index.md
@@ -290,11 +290,32 @@ usual split is to put 80% of the images into the main training set, keep 10%
aside to run as validation frequently during training, and then have a final 10%
that are used less often as a testing set to predict the real-world performance
of the classifier. These ratios can be controlled using the
-`--testing_percentage` and `--validation_percentage` flags. One subtle thing
-that the script does is it uses the filename of the image to determine which set
-it is put into. This is designed to ensure that images don't get moved between
-training and testing sets on different runs, since that could be a problem if
-images that had been used for training a model were subsequently used in a
-validation set. In general you should be able to leave these values at their
-defaults, since you won't usually find any advantage to training to adjusting
-them.
+`--testing_percentage` and `--validation_percentage` flags. In general
+you should be able to leave these values at their defaults, since adjusting
+them rarely yields any training advantage.
+
+Note that the script uses the image filenames (rather than a completely random
+function) to divide the images among the training, validation, and test sets.
+This is done to ensure that images don't get moved between training and testing
+sets on different runs, since that could be a problem if images that had been
+used for training a model were subsequently used in a validation set.
+
+You might notice that the validation accuracy fluctuates among iterations. Much
+of this fluctuation arises from the fact that a random subset of the validation
+set is chosen for each validation accuracy measurement. The fluctuations can be
+greatly reduced, at the cost of some increase in training time, by choosing
+`--validation_batch_size=-1`, which uses the entire validation set for each
+accuracy computation.
+
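+As a concrete invocation (a sketch: it assumes the `retrain` binary and the
+`~/flower_photos` image directory used earlier in this guide):
+
+```shell
+bazel-bin/tensorflow/examples/image_retraining/retrain \
+    --image_dir ~/flower_photos \
+    --validation_batch_size=-1
+```
+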
+Once training is complete, you may find it insightful to examine misclassified
+images in the test set. This can be done by adding the flag
+`--print_misclassified_test_images`. This may help you get a feeling for which
+types of images were most confusing for the model, and which categories were
+most difficult to distinguish. For instance, you might discover that some
+subtype of a particular category, or some unusual photo angle, is particularly
+difficult to identify, which may encourage you to add more training images of
+that subtype. Oftentimes, examining misclassified images can also point to
+errors in the input data set, such as mislabeled, low-quality, or ambiguous
+images. However, one should generally avoid point-fixing individual errors in
+the test set, since they are likely to merely reflect more general problems in
+the (much larger) training set.
diff --git a/tensorflow/g3doc/how_tos/threading_and_queues/index.md b/tensorflow/g3doc/how_tos/threading_and_queues/index.md
index 46444a02db..639ad116c9 100644
--- a/tensorflow/g3doc/how_tos/threading_and_queues/index.md
+++ b/tensorflow/g3doc/how_tos/threading_and_queues/index.md
@@ -28,7 +28,7 @@ creating these operations.
Now that you have a bit of a feel for queues, let's dive into the details...
-## Queue Use Overview
+## Queue use overview
Queues, such as `FIFOQueue` and `RandomShuffleQueue`, are important TensorFlow
objects for computing tensors asynchronously in a graph.
@@ -149,7 +149,7 @@ coord.request_stop()
coord.join(enqueue_threads)
```
-## Handling Exceptions
+## Handling exceptions
Threads started by queue runners do more than just run the enqueue ops. They
also catch and handle exceptions generated by queues, including
diff --git a/tensorflow/g3doc/how_tos/variable_scope/index.md b/tensorflow/g3doc/how_tos/variable_scope/index.md
index 4e01ce1259..bb1b3e53f4 100644
--- a/tensorflow/g3doc/how_tos/variable_scope/index.md
+++ b/tensorflow/g3doc/how_tos/variable_scope/index.md
@@ -69,7 +69,7 @@ def my_image_filter(input_images, variables_dict):
strides=[1, 1, 1, 1], padding='SAME')
return tf.nn.relu(conv2 + variables_dict["conv2_biases"])
-# The 2 calls to my_image_filter() now use the same variables
+# Both calls to my_image_filter() now use the same variables
result1 = my_image_filter(image1, variables_dict)
result2 = my_image_filter(image2, variables_dict)
```
@@ -90,7 +90,7 @@ while constructing a graph.
## Variable Scope Example
-Variable Scope mechanism in TensorFlow consists of 2 main functions:
+Variable Scope mechanism in TensorFlow consists of two main functions:
* `tf.get_variable(<name>, <shape>, <initializer>)`:
Creates or returns a variable with a given name.
@@ -280,9 +280,9 @@ when opening a new variable scope.
```python
with tf.variable_scope("foo") as foo_scope:
v = tf.get_variable("v", [1])
-with tf.variable_scope(foo_scope)
+with tf.variable_scope(foo_scope):
w = tf.get_variable("w", [1])
-with tf.variable_scope(foo_scope, reuse=True)
+with tf.variable_scope(foo_scope, reuse=True):
v1 = tf.get_variable("v", [1])
w1 = tf.get_variable("w", [1])
assert v1 is v
@@ -296,7 +296,7 @@ different one. This is fully independent of where we do it.
```python
with tf.variable_scope("foo") as foo_scope:
assert foo_scope.name == "foo"
-with tf.variable_scope("bar")
+with tf.variable_scope("bar"):
with tf.variable_scope("baz") as other_scope:
assert other_scope.name == "bar/baz"
with tf.variable_scope(foo_scope) as foo_scope2:
diff --git a/tensorflow/g3doc/tutorials/seq2seq/index.md b/tensorflow/g3doc/tutorials/seq2seq/index.md
index 7e8c3cb929..4cfcc56b29 100644
--- a/tensorflow/g3doc/tutorials/seq2seq/index.md
+++ b/tensorflow/g3doc/tutorials/seq2seq/index.md
@@ -35,7 +35,7 @@ File | What's in it?
`models/rnn/translate/translate.py` | Binary that trains and runs the translation model.
-## Sequence-to-Sequence Basics
+## Sequence-to-sequence basics
A basic sequence-to-sequence model, as introduced in
[Cho et al., 2014](http://arxiv.org/abs/1406.1078)
@@ -69,7 +69,7 @@ attention mechanism in the decoder looks like this.
<img style="width:100%" src="../../images/attention_seq2seq.png" />
</div>
-## TensorFlow seq2seq Library
+## TensorFlow seq2seq library
As you can see above, there are many different sequence-to-sequence
models. Each of these models can use different RNN cells, but all
@@ -148,7 +148,7 @@ more sequence-to-sequence models in `seq2seq.py`, take a look there. They all
have similar interfaces, so we will not describe them in detail. We will use
`embedding_attention_seq2seq` for our translation model below.
-## Neural Translation Model
+## Neural translation model
While the core of the sequence-to-sequence model is constructed by
the functions in `python/ops/seq2seq.py`, there are still a few tricks
@@ -238,7 +238,7 @@ with encoder inputs representing `[PAD PAD "." "go" "I"]` and decoder
inputs `[GO "Je" "vais" "." EOS PAD PAD PAD PAD PAD]`.
-## Let's Run It
+## Let's run it
To train the model described above, we need a large English-French corpus.
We will use the *10^9-French-English corpus* from the
@@ -312,7 +312,7 @@ Reading model parameters from /tmp/translate.ckpt-340000
Qui est le président des États-Unis ?
```
-## What Next?
+## What next?
The example above shows how you can build your own English-to-French
translator, end-to-end. Run it and see how the model performs for yourself.
diff --git a/tensorflow/g3doc/tutorials/wide/index.md b/tensorflow/g3doc/tutorials/wide/index.md
index 4d76f85628..d30ad11374 100644
--- a/tensorflow/g3doc/tutorials/wide/index.md
+++ b/tensorflow/g3doc/tutorials/wide/index.md
@@ -63,8 +63,8 @@ import tempfile
import urllib
train_file = tempfile.NamedTemporaryFile()
test_file = tempfile.NamedTemporaryFile()
-urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)
-urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)
+urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)
+urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)
```
Once the CSV files are downloaded, let's read them into
diff --git a/tensorflow/g3doc/tutorials/wide_and_deep/index.md b/tensorflow/g3doc/tutorials/wide_and_deep/index.md
index f1928bdca4..4928dd41a3 100644
--- a/tensorflow/g3doc/tutorials/wide_and_deep/index.md
+++ b/tensorflow/g3doc/tutorials/wide_and_deep/index.md
@@ -215,8 +215,8 @@ CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss",
# test_file to your own paths.
train_file = tempfile.NamedTemporaryFile()
test_file = tempfile.NamedTemporaryFile()
-urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)
-urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)
+urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)
+urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)
# Read the training and test data sets into Pandas dataframe.
df_train = pd.read_csv(train_file, names=COLUMNS, skipinitialspace=True)
diff --git a/tensorflow/g3doc/tutorials/word2vec/index.md b/tensorflow/g3doc/tutorials/word2vec/index.md
index 15653474df..936cb24a23 100644
--- a/tensorflow/g3doc/tutorials/word2vec/index.md
+++ b/tensorflow/g3doc/tutorials/word2vec/index.md
@@ -102,7 +102,7 @@ $$
\begin{align}
P(w_t | h) &= \text{softmax}(\text{score}(w_t, h)) \\
&= \frac{\exp \{ \text{score}(w_t, h) \} }
- {\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} }.
+ {\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} }
\end{align}
$$
@@ -115,7 +115,7 @@ $$
\begin{align}
J_\text{ML} &= \log P(w_t | h) \\
&= \text{score}(w_t, h) -
- \log \left( \sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} \right)
+ \log \left( \sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} \right).
\end{align}
$$
diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md
index 9817469006..5e08372cf9 100644
--- a/tensorflow/go/README.md
+++ b/tensorflow/go/README.md
@@ -76,5 +76,4 @@ go test -v github.com/tensorflow/tensorflow/tensorflow/go
This API has been built on top of the [C
API](https://www.tensorflow.org/code/tensorflow/c/c_api.h),
which is intended for building language bindings for TensorFlow functionality.
-However, this is far from complete. Contributions are welcome. To monitor
-progress follow [issue 10](https://github.com/tensorflow/tensorflow/issues/10).
+However, this is far from complete. Contributions are welcome.
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
new file mode 100644
index 0000000000..2aa077c2b7
--- /dev/null
+++ b/tensorflow/java/BUILD
@@ -0,0 +1,61 @@
+# Description:
+# TensorFlow Java API.
+
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+java_library(
+ name = "tensorflow",
+ srcs = glob(["src/main/java/org/tensorflow/*.java"]),
+ data = [":libtensorflow-jni"],
+ visibility = ["//visibility:public"],
+)
+
+java_test(
+ name = "TensorFlowTest",
+ srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"],
+ test_class = "org.tensorflow.TensorFlowTest",
+ deps = [
+ ":tensorflow",
+ "//external:junit",
+ ],
+)
+
+filegroup(
+ name = "libtensorflow-jni",
+ srcs = select({
+ "//tensorflow:darwin": [":libtensorflow-jni.dylib"],
+ "//conditions:default": [":libtensorflow-jni.so"],
+ }),
+)
+
+cc_binary(
+ name = "libtensorflow-jni.so",
+ linkshared = 1,
+ linkstatic = 1,
+ deps = ["//tensorflow/java/src/main/native"],
+)
+
+# System.loadLibrary() on OS X looks for ".dylib" or ".jnilib"
+# and no ".so". If and when https://github.com/bazelbuild/bazel/issues/914
+# is resolved, perhaps this workaround rule can be removed.
+genrule(
+ name = "darwin-compat",
+ srcs = [":libtensorflow-jni.so"],
+ outs = ["libtensorflow-jni.dylib"],
+ cmd = "cp $< $@",
+ output_to_bindir = 1,
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
new file mode 100644
index 0000000000..d9bee5e342
--- /dev/null
+++ b/tensorflow/java/README.md
@@ -0,0 +1,72 @@
+# TensorFlow for Java
+
+Java bindings for TensorFlow.
+
+> *WARNING*: The TensorFlow Java API is incomplete and experimental and can
+> change without notice. Progress can be followed in
+> [issue #5](https://github.com/tensorflow/tensorflow/issues/5).
+>
+> Until then, for using TensorFlow on Android refer to
+> [contrib/android](https://www.tensorflow.org/code/tensorflow/contrib/android),
+> [makefile](https://www.tensorflow.org/code/tensorflow/contrib/makefile#android)
+> and/or the [Android camera
+> demo](https://www.tensorflow.org/code/tensorflow/examples/android).
+
+## Requirements
+
+- [bazel](https://www.bazel.build/versions/master/docs/install.html)
+- Environment to build TensorFlow from source code
+ ([Linux](https://www.tensorflow.org/versions/master/get_started/os_setup.html#prepare-environment-for-linux)
+ or [Mac OS
+ X](https://www.tensorflow.org/versions/master/get_started/os_setup.html#prepare-environment-for-mac-os-x)).
+ If you'd like to skip reading those details and do not care about GPU
+ support, try the following:
+
+ ```sh
+ # On Linux
+ sudo apt-get install python swig python-numpy
+
+ # On Mac OS X with homebrew
+ brew install swig
+ ```
+
+## Installation
+
+Build the Java Archive and native library:
+
+```sh
+bazel build -c opt \
+ //tensorflow/java:libtensorflow.jar \
+ //tensorflow/java:libtensorflow-jni
+```
+
+## Example Usage
+
+### With bazel
+
+Add a dependency on `//tensorflow/java:tensorflow` to the `java_binary` or
+`java_library` rule. For example:
+
+```sh
+bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
+```
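+
+The dependency itself looks like the following sketch (the target, source
+file, and class names are illustrative):
+
+```python
+java_binary(
+    name = "my_app",
+    srcs = ["MyApp.java"],
+    main_class = "org.myorg.MyApp",
+    deps = ["//tensorflow/java:tensorflow"],
+)
+```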
+
+### With `javac`
+
+- Add `libtensorflow.jar` to classpath for compilation. For example:
+
+ ```sh
+ javac \
+ -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar \
+ ./src/main/java/org/tensorflow/examples/Example.java
+ ```
+
+- Make `libtensorflow.jar` and `libtensorflow-jni.so`
+ (`libtensorflow-jni.dylib` on OS X) available during execution. For example:
+
+ ```sh
+ java \
+ -Djava.library.path=../../bazel-bin/tensorflow/java \
+ -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar:./src/main/java \
+ org.tensorflow.examples.Example
+ ```
diff --git a/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java
new file mode 100644
index 0000000000..dc7f87b928
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java
@@ -0,0 +1,28 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+/** Static utility methods describing the TensorFlow runtime. */
+public final class TensorFlow {
+ private TensorFlow() {}
+
+ static {
+ System.loadLibrary("tensorflow-jni");
+ }
+
+ /** Returns the version of the underlying TensorFlow runtime. */
+ public static native String getVersion();
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/.gitignore b/tensorflow/java/src/main/java/org/tensorflow/examples/.gitignore
new file mode 100644
index 0000000000..8dc1579ef5
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/.gitignore
@@ -0,0 +1,3 @@
+# .class files generated when building examples using javac
+# as described in README.md
+*.class
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
new file mode 100644
index 0000000000..529287a038
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
@@ -0,0 +1,25 @@
+# Description:
+# TensorFlow Java examples.
+
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+java_binary(
+ name = "example",
+ srcs = ["Example.java"],
+ main_class = "org.tensorflow.examples.Example",
+ deps = ["//tensorflow/java:tensorflow"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java b/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java
new file mode 100644
index 0000000000..f61c44b4ab
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java
@@ -0,0 +1,29 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.examples;
+
+import org.tensorflow.TensorFlow;
+
+/**
+ * Sample usage of the TensorFlow Java library.
+ *
+ * <p>This sample should become more useful as functionality is added to the API.
+ */
+public class Example {
+ public static void main(String[] args) {
+ System.out.println("TensorFlow version: " + TensorFlow.getVersion());
+ }
+}
diff --git a/tensorflow/java/src/main/native/BUILD b/tensorflow/java/src/main/native/BUILD
new file mode 100644
index 0000000000..3a2d0cbbfb
--- /dev/null
+++ b/tensorflow/java/src/main/native/BUILD
@@ -0,0 +1,66 @@
+# Description:
+# Java Native Interface (JNI) library intended for implementing the
+# TensorFlow Java API using the TensorFlow C library.
+
+package(default_visibility = ["//tensorflow/java:__pkg__"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
+
+tf_cuda_library(
+ name = "native",
+ srcs = [
+ "tensorflow.cc",
+ ":jni.h",
+ ":jni_md.h",
+ ],
+ hdrs = ["tensorflow.h"],
+ includes = ["."],
+ deps = [
+ "//tensorflow/c:c_api",
+ ],
+ alwayslink = 1,
+)
+
+# Silly rules to make
+# #include <jni.h>
+# in the source headers work
+# (in combination with the "includes" attribute of the tf_cuda_library rule
+# above).
+#
+# Inspired from:
+# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
+# but hopefully there is a simpler alternative to this.
+#
+# TODO(ashankar): This should not be necessary for Android builds as the
+# toolchain makes <jni.h> available. Perhaps remove ":jni.h" and ":jni_md.h"
+# from "srcs" and make these genrules a no-op when building for Android?
+genrule(
+ name = "copy_jni_h",
+ srcs = ["@bazel_tools//tools/jdk:jni_header"],
+ outs = ["jni.h"],
+ cmd = "cp -f $< $@",
+)
+
+genrule(
+ name = "copy_jni_md_h",
+ srcs = select({
+ "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
+ "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
+ }),
+ outs = ["jni_md.h"],
+ cmd = "cp -f $< $@",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/python/client/net_lib.i b/tensorflow/java/src/main/native/tensorflow.cc
index 333e2abbc5..55de5771dd 100644
--- a/tensorflow/python/client/net_lib.i
+++ b/tensorflow/java/src/main/native/tensorflow.cc
@@ -13,18 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-%include "tensorflow/python/platform/base.i"
+#include "tensorflow/java/src/main/native/tensorflow.h"
+#include "tensorflow/c/c_api.h"
-%{
-#include "tensorflow/core/platform/net.h"
-%}
-
-%ignoreall
-
-%unignore tensorflow;
-%unignore tensorflow::internal;
-%unignore tensorflow::internal::PickUnusedPortOrDie;
-
-%include "tensorflow/core/platform/net.h"
-
-%unignoreall
+JNIEXPORT jstring JNICALL
+Java_org_tensorflow_TensorFlow_getVersion(JNIEnv* env, jclass clazz) {
+ return env->NewStringUTF(TF_Version());
+}
diff --git a/tensorflow/java/src/main/native/tensorflow.h b/tensorflow/java/src/main/native/tensorflow.h
new file mode 100644
index 0000000000..897a000ac0
--- /dev/null
+++ b/tensorflow/java/src/main/native/tensorflow.h
@@ -0,0 +1,36 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_JNI_H_
+#define TENSORFLOW_JAVA_JNI_H_
+
+#include <jni.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+/*
+ * Class: TensorFlow
+ * Method: getVersion
+ * Signature: ()Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_getVersion(JNIEnv*,
+ jclass);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_JAVA_JNI_H_
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java
new file mode 100644
index 0000000000..94fd0582c1
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java
@@ -0,0 +1,31 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link org.tensorflow.TensorFlow}. */
+@RunWith(JUnit4.class)
+public class TensorFlowTest {
+ @Test
+ public void version() {
+ assertTrue(TensorFlow.getVersion().length() > 0);
+ }
+}
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py
index e463e300c1..4b36554716 100644
--- a/tensorflow/models/embedding/word2vec.py
+++ b/tensorflow/models/embedding/word2vec.py
@@ -365,7 +365,7 @@ class Word2Vec(object):
self._word2id[w] = i
true_logits, sampled_logits = self.forward(examples, labels)
loss = self.nce_loss(true_logits, sampled_logits)
- tf.scalar_summary("NCE loss", loss)
+ tf.contrib.deprecated.scalar_summary("NCE loss", loss)
self._loss = loss
self.optimize(loss)
@@ -396,8 +396,8 @@ class Word2Vec(object):
initial_epoch, initial_words = self._session.run([self._epoch, self._words])
- summary_op = tf.merge_all_summaries()
- summary_writer = tf.train.SummaryWriter(opts.save_path, self._session.graph)
+ summary_op = tf.contrib.deprecated.merge_all_summaries()
+ summary_writer = tf.summary.FileWriter(opts.save_path, self._session.graph)
workers = []
for _ in xrange(opts.concurrent_steps):
t = threading.Thread(target=self._train_thread_body)
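The same mechanical summary-API migration recurs throughout the model files below (cifar10, PTB, the test suites). As a quick reference, a minimal sketch of the new spellings this commit standardizes on; the loss tensor here is a made-up stand-in, and the APIs are those visible in these hunks (tf.contrib.deprecated.*, tf.summary.FileWriter):

    import tensorflow as tf

    loss = tf.constant(0.5)  # hypothetical scalar, for illustration only
    # was: tf.scalar_summary("loss", loss)
    tf.contrib.deprecated.scalar_summary("loss", loss)
    # was: tf.merge_all_summaries()
    summary_op = tf.contrib.deprecated.merge_all_summaries()
    with tf.Session() as sess:
        # was: tf.train.SummaryWriter(logdir, graph)
        writer = tf.summary.FileWriter("/tmp/logs", sess.graph)
        writer.add_summary(sess.run(summary_op), global_step=0)
        writer.close()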
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 1c51b76f09..55c34ba84b 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -91,8 +91,9 @@ def _activation_summary(x):
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
# session. This helps the clarity of presentation on tensorboard.
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
- tf.histogram_summary(tensor_name + '/activations', x)
- tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+ tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
+ tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity',
+ tf.nn.zero_fraction(x))
def _variable_on_cpu(name, shape, initializer):
@@ -316,8 +317,8 @@ def _add_loss_summaries(total_loss):
for l in losses + [total_loss]:
# Name each loss as '(raw)' and name the moving average version of the loss
# as the original loss name.
- tf.scalar_summary(l.op.name +' (raw)', l)
- tf.scalar_summary(l.op.name, loss_averages.average(l))
+ tf.contrib.deprecated.scalar_summary(l.op.name + ' (raw)', l)
+ tf.contrib.deprecated.scalar_summary(l.op.name, loss_averages.average(l))
return loss_averages_op
@@ -345,7 +346,7 @@ def train(total_loss, global_step):
decay_steps,
LEARNING_RATE_DECAY_FACTOR,
staircase=True)
- tf.scalar_summary('learning_rate', lr)
+ tf.contrib.deprecated.scalar_summary('learning_rate', lr)
# Generate moving averages of all losses and associated summaries.
loss_averages_op = _add_loss_summaries(total_loss)
@@ -360,12 +361,12 @@ def train(total_loss, global_step):
# Add histograms for trainable variables.
for var in tf.trainable_variables():
- tf.histogram_summary(var.op.name, var)
+ tf.contrib.deprecated.histogram_summary(var.op.name, var)
# Add histograms for gradients.
for grad, var in grads:
if grad is not None:
- tf.histogram_summary(var.op.name + '/gradients', grad)
+ tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients', grad)
# Track the moving averages of all trainable variables.
variable_averages = tf.train.ExponentialMovingAverage(
@@ -394,5 +395,5 @@ def maybe_download_and_extract():
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
-
+
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py
index 19bf74c477..c2329380d6 100644
--- a/tensorflow/models/image/cifar10/cifar10_eval.py
+++ b/tensorflow/models/image/cifar10/cifar10_eval.py
@@ -134,9 +134,9 @@ def evaluate():
saver = tf.train.Saver(variables_to_restore)
# Build the summary operation based on the TF collection of Summaries.
- summary_op = tf.merge_all_summaries()
+ summary_op = tf.contrib.deprecated.merge_all_summaries()
- summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g)
+ summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
while True:
eval_once(saver, summary_writer, top_k_op, summary_op)
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py
index 14ea94d72f..b00859b262 100644
--- a/tensorflow/models/image/cifar10/cifar10_input.py
+++ b/tensorflow/models/image/cifar10/cifar10_input.py
@@ -130,7 +130,7 @@ def _generate_image_and_label_batch(image, label, min_queue_examples,
capacity=min_queue_examples + 3 * batch_size)
# Display the training images in the visualizer.
- tf.image_summary('images', images)
+ tf.contrib.deprecated.image_summary('images', images)
return images, tf.reshape(label_batch, [batch_size])
diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
index 53ae7d5c74..a59e13d5e3 100644
--- a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
@@ -93,7 +93,7 @@ def tower_loss(scope):
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
# session. This helps the clarity of presentation on tensorboard.
loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name)
- tf.scalar_summary(loss_name, l)
+ tf.contrib.deprecated.scalar_summary(loss_name, l)
return total_loss
@@ -187,20 +187,22 @@ def train():
grads = average_gradients(tower_grads)
# Add a summary to track the learning rate.
- summaries.append(tf.scalar_summary('learning_rate', lr))
+ summaries.append(tf.contrib.deprecated.scalar_summary('learning_rate', lr))
# Add histograms for gradients.
for grad, var in grads:
if grad is not None:
summaries.append(
- tf.histogram_summary(var.op.name + '/gradients', grad))
+ tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients',
+ grad))
# Apply the gradients to adjust the shared variables.
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
# Add histograms for trainable variables.
for var in tf.trainable_variables():
- summaries.append(tf.histogram_summary(var.op.name, var))
+ summaries.append(
+ tf.contrib.deprecated.histogram_summary(var.op.name, var))
# Track the moving averages of all trainable variables.
variable_averages = tf.train.ExponentialMovingAverage(
@@ -214,7 +216,7 @@ def train():
saver = tf.train.Saver(tf.all_variables())
# Build the summary operation from the last tower summaries.
- summary_op = tf.merge_summary(summaries)
+ summary_op = tf.contrib.deprecated.merge_summary(summaries)
# Build an initialization operation to run below.
init = tf.global_variables_initializer()
@@ -230,7 +232,7 @@ def train():
# Start the queue runners.
tf.train.start_queue_runners(sess=sess)
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
+ summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
for step in xrange(FLAGS.max_steps):
start_time = time.time()
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
index 45c0bbd9f0..eab499fc3e 100644
--- a/tensorflow/models/image/cifar10/cifar10_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -118,3 +118,4 @@ def main(argv=None): # pylint: disable=unused-argument
if __name__ == '__main__':
tf.app.run()
+
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
index 020edbd5e5..45fb1e774a 100644
--- a/tensorflow/models/rnn/ptb/ptb_word_lm.py
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -328,14 +328,14 @@ def main(_):
train_input = PTBInput(config=config, data=train_data, name="TrainInput")
with tf.variable_scope("Model", reuse=None, initializer=initializer):
m = PTBModel(is_training=True, config=config, input_=train_input)
- tf.scalar_summary("Training Loss", m.cost)
- tf.scalar_summary("Learning Rate", m.lr)
+ tf.contrib.deprecated.scalar_summary("Training Loss", m.cost)
+ tf.contrib.deprecated.scalar_summary("Learning Rate", m.lr)
with tf.name_scope("Valid"):
valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
with tf.variable_scope("Model", reuse=True, initializer=initializer):
mvalid = PTBModel(is_training=False, config=config, input_=valid_input)
- tf.scalar_summary("Validation Loss", mvalid.cost)
+ tf.contrib.deprecated.scalar_summary("Validation Loss", mvalid.cost)
with tf.name_scope("Test"):
test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e1dda0674b..f1fae16bb0 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1204,7 +1204,10 @@ py_library(
py_library(
name = "rnn_cell",
- srcs = ["ops/rnn_cell.py"],
+ srcs = [
+ "ops/rnn_cell.py",
+ "ops/rnn_cell_impl.py",
+ ],
srcs_version = "PY2AND3",
deps = [
":array_ops",
@@ -1906,28 +1909,6 @@ cuda_py_tests(
],
)
-py_library(
- name = "net_lib",
- testonly = 1,
- srcs = ["util/net_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":pywrap_tensorflow",
- ],
-)
-
-py_tests(
- name = "net_lib_test",
- size = "small",
- srcs = [
- "util/net_lib_test.py",
- ],
- additional_deps = [
- ":net_lib",
- "//tensorflow:tensorflow_py",
- ],
-)
-
tf_cuda_library(
name = "tf_session_helper",
srcs = ["client/tf_session_helper.cc"],
@@ -1954,7 +1935,6 @@ tf_py_wrap_cc(
swig_includes = [
"client/device_lib.i",
"client/events_writer.i",
- "client/net_lib.i",
"client/quantize_training.i",
"client/tf_session.i",
"framework/cpp_shape_inference.i",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index ae8d0e02f1..2a7a76c396 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -82,7 +82,7 @@ from tensorflow.python.ops.standard_ops import *
# pylint: enable=wildcard-import
# Bring in subpackages.
-from tensorflow.python import layers
+from tensorflow.python.layers import layers
from tensorflow.python.ops import nn
from tensorflow.python.ops import resources
from tensorflow.python.ops import sdca_ops as sdca
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 71c931037e..591cc5afbc 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -364,6 +364,7 @@ class _DictFetchMapper(_FetchMapper):
Args:
fetches: Dict of fetches.
"""
+ self._fetch_type = type(fetches)
self._keys = fetches.keys()
self._mappers = [_FetchMapper.for_fetch(fetch)
for fetch in fetches.values()]
@@ -373,7 +374,7 @@ class _DictFetchMapper(_FetchMapper):
return self._unique_fetches
def build_results(self, values):
- results = {}
+ results = self._fetch_type()
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
results[k] = m.build_results([values[j] for j in vi])
return results
@@ -661,8 +662,8 @@ class BaseSession(SessionInterface):
`feed_dict` for the corresponding input values.
The `fetches` argument may be a single graph element, or an arbitrarily
- nested list, tuple, namedtuple, or dict containing graph elements at its
- leaves. A graph element can be one of the following types:
+ nested list, tuple, namedtuple, dict, or OrderedDict containing graph
+ elements at its leaves. A graph element can be one of the following types:
* An [`Operation`](../../api_docs/python/framework.md#Operation).
The corresponding fetched value will be `None`.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index a20376b91d..0c602a9014 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -254,6 +254,18 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(None, res['b'])
self.assertEqual(44.0, res['c'])
+ def testFetchOrderedDict(self):
+ with session.Session() as sess:
+ a = constant_op.constant(42.0)
+ b = control_flow_ops.no_op() # An op, not a tensor.
+ c = constant_op.constant(44.0)
+ res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
+ self.assertTrue(isinstance(res, collections.OrderedDict))
+ self.assertEqual([3, 2, 1], list(res.keys()))
+ self.assertEqual(42.0, res[3])
+ self.assertEqual(None, res[2])
+ self.assertEqual(44.0, res[1])
+
def testFetchNestingEmptyOneLevel(self):
with session.Session() as sess:
a_val = 11.0
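A minimal sketch of the behavior the session.py change above enables: _DictFetchMapper now builds its results with type(fetches), so an OrderedDict fetch comes back as an OrderedDict in the original key order. The constants are illustrative only:

    import collections
    import tensorflow as tf

    a = tf.constant(42.0)
    b = tf.constant(44.0)
    with tf.Session() as sess:
        # Results mirror the fetch's dict subtype and key order.
        res = sess.run(collections.OrderedDict([("b", b), ("a", a)]))
    assert isinstance(res, collections.OrderedDict)
    assert list(res.keys()) == ["b", "a"]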
diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py
index ea64d74d6d..f3ba4244ce 100644
--- a/tensorflow/python/client/timeline.py
+++ b/tensorflow/python/client/timeline.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import copy
import json
+import re
# The timeline target is usually imported as part of BUILD target
# "platform_test", which includes also includes the "platform"
@@ -384,12 +385,15 @@ class Timeline(object):
def _parse_op_label(self, label):
"""Parses the fields in a node timeline label."""
- nn, rest = label.split(' = ')
- op, rest = rest.split('(')
- if rest == ')':
+ # Expects labels of the form: name = op(arg, arg, ...).
+ match = re.match(r'(.*) = (.*)\((.*)\)', label)
+ if match is None:
+ return 'unknown', 'unknown', []
+ nn, op, inputs = match.groups()
+ if not inputs:
inputs = []
else:
- inputs = rest[:-1].split(', ')
+ inputs = inputs.split(', ')
return nn, op, inputs
def _assign_lanes(self):
@@ -421,11 +425,14 @@ class Timeline(object):
start = nodestats.all_start_micros
duration = nodestats.all_end_rel_micros
tid = nodestats.thread_id
+ inputs = []
if is_gputrace:
# Node names should always have the form 'name:op'.
fields = node_name.split(':') + ['unknown']
node_name, op = fields[:2]
- inputs = []
+ elif node_name == 'RecvTensor':
+ # RPC tracing does not use the standard timeline_label format.
+ op = 'RecvTensor'
else:
_, op, inputs = self._parse_op_label(nodestats.timeline_label)
args = {'name': node_name, 'op': op}
@@ -518,7 +525,7 @@ class Timeline(object):
end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
self._emit_op(node_stats, device_pid, is_gputrace)
- if is_gputrace:
+ if is_gputrace or node_stats.node_name == 'RecvTensor':
continue
_, _, inputs = self._parse_op_label(node_stats.timeline_label)
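A standalone mirror of the label parsing introduced above, to make the regex's behavior concrete; the function name and sample labels are illustrative only. Labels of the form "name = op(arg, ...)" split into their parts, and anything else (such as RecvTensor RPC labels) falls back to 'unknown':

    import re

    def parse_op_label(label):
        # Same pattern as the new Timeline._parse_op_label.
        match = re.match(r'(.*) = (.*)\((.*)\)', label)
        if match is None:
            return 'unknown', 'unknown', []
        name, op, inputs = match.groups()
        return name, op, inputs.split(', ') if inputs else []

    assert parse_op_label('add = Add(x, y)') == ('add', 'Add', ['x', 'y'])
    assert parse_op_label('noop = NoOp()') == ('noop', 'NoOp', [])
    assert parse_op_label('[1024B] edge_160 from /job:ps')[0] == 'unknown'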
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index 7c9d7847b7..46984f694f 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -109,6 +109,23 @@ class TimelineTest(tf.test.TestCase):
show_dataflow=False)
self._validateTrace(ctf)
+ def testTimelineWithRPCs(self):
+ """Tests that Timeline can handle RPC tracing."""
+ metadata = tf.RunMetadata()
+ step_stats = metadata.step_stats
+ dev_stats = step_stats.dev_stats.add()
+ dev_stats.device = '/job:worker/replica:0/task:0/cpu:0'
+ node_stats = dev_stats.node_stats.add()
+ node_stats.node_name = 'RecvTensor'
+ node_stats.all_start_micros = 12345
+ node_stats.op_end_rel_micros = 42
+ node_stats.timeline_label = ('[1024B] edge_160_conv2/biases/read from '
+ '/job:ps/replica:0/task:3/cpu:0 to '
+ '/job:worker/replica:0/task:0/cpu:0')
+ tl = timeline.Timeline(step_stats)
+ ctf = tl.generate_chrome_trace_format()
+ self._validateTrace(ctf)
+
def testAnalysisAndAllocations(self):
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
diff --git a/tensorflow/python/framework/graph_util.py b/tensorflow/python/framework/graph_util.py
index 402a5ebf0e..a666630e44 100644
--- a/tensorflow/python/framework/graph_util.py
+++ b/tensorflow/python/framework/graph_util.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework.graph_util_impl import convert_variables_to_con
from tensorflow.python.framework.graph_util_impl import extract_sub_graph
from tensorflow.python.framework.graph_util_impl import must_run_on_cpu
from tensorflow.python.framework.graph_util_impl import remove_training_nodes
-from tensorflow.python.framework.graph_util_impl import set_cpu0
from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -36,7 +35,6 @@ _allowed_symbols = [
"convert_variables_to_constants",
"extract_sub_graph",
"must_run_on_cpu",
- "set_cpu0",
"tensor_shape_from_node_def_name",
"remove_training_nodes",
]
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index ba693503c2..587f883260 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -25,7 +25,6 @@ import re
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -49,23 +48,6 @@ def _is_variable_op(op):
return op in _VARIABLE_OPS
-def set_cpu0(device_string):
- """Creates a new device string based on `device_string' but using /CPU:0.
-
- If the device is already on /CPU:0, this is a no-op.
-
- Args:
- device_string: A device string.
-
- Returns:
- A device string.
- """
- parsed_device = pydev.DeviceSpec.from_string(device_string)
- parsed_device.device_type = "CPU"
- parsed_device.device_index = 0
- return parsed_device.to_string()
-
-
def must_run_on_cpu(node, pin_variables_on_cpu=False):
"""Returns True if the given node_def must run on CPU, otherwise False.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 24169b57db..d1edf43193 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3980,7 +3980,7 @@ class GraphKeys(object):
for more details.
* `SUMMARIES`: the summary `Tensor` objects that have been created in the
graph. See
- [`tf.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries)
+ [`tf.contrib.deprecated.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries)
for more details.
* `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
produce input for a computation. See
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index 02c71d3032..e779dc7c69 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -461,11 +461,17 @@ class ConcatOpTest(tf.test.TestCase):
with self.test_session(use_gpu=True):
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
- output = tf.concat(-2, [t1, t2]).eval()
+
+ c = tf.concat(-2, [t1, t2])
+ output = c.eval()
+ self.assertEqual([4, 3], c.get_shape().as_list())
self.assertAllEqual(
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
output)
- output = tf.concat(-1, [t1, t2]).eval()
+
+ c = tf.concat(-1, [t1, t2])
+ self.assertEqual([2, 6], c.get_shape().as_list())
+ output = c.eval()
self.assertAllEqual(
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]],
output)
@@ -488,11 +494,17 @@ class ConcatOpTest(tf.test.TestCase):
with self.test_session(use_gpu=True):
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
- output = gen_array_ops._concat_v2([t1, t2], -2).eval()
+
+ c = gen_array_ops._concat_v2([t1, t2], -2)
+ self.assertEqual([4, 3], c.get_shape().as_list())
+ output = c.eval()
self.assertAllEqual(
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
output)
- output = gen_array_ops._concat_v2([t1, t2], -1).eval()
+
+ c = gen_array_ops._concat_v2([t1, t2], -1)
+ self.assertEqual([2, 6], c.get_shape().as_list())
+ output = c.eval()
self.assertAllEqual(
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]],
output)
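A minimal sketch of the shape-inference behavior these test changes pin down, assuming the pre-1.0 tf.concat(concat_dim, values) argument order used throughout this file: a negative concat dimension counts from the end and is now resolved statically.

    import tensorflow as tf

    t1 = [[1, 2, 3], [4, 5, 6]]
    t2 = [[7, 8, 9], [10, 11, 12]]
    c = tf.concat(-1, [t1, t2])  # pre-1.0 order: (concat_dim, values)
    # Negative dims are now handled by static shape inference.
    assert c.get_shape().as_list() == [2, 6]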
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 8fc4be8e6e..732a604dc2 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1269,6 +1269,31 @@ class ControlFlowTest(tf.test.TestCase):
tf.global_variables_initializer().run()
self.assertAllClose(216.0, r[0].eval())
+ def testWhileGradInCond(self):
+ with self.test_session():
+ n = tf.convert_to_tensor(1.0, name="n")
+ x = tf.placeholder(tf.float32, shape=None)
+ c = lambda n: tf.less(n, 10.0)
+ b = lambda n: tf.add(n, x)
+ def fn1():
+ r = tf.while_loop(c, b, [n], [tensor_shape.unknown_shape()])
+ return tf.gradients(r, x)
+ r = tf.cond(tf.less(1, 2), fn1, lambda: x)
+ self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+
+ def testWhileGradInWhile(self):
+ with self.test_session():
+ n = tf.convert_to_tensor(1.0, name="n")
+ x = tf.placeholder(tf.float32, shape=None)
+ c = lambda n: tf.less(n, 10.0)
+ b = lambda n: tf.add(n, x)
+ def b1(n):
+ r = tf.while_loop(c, b, [n], [tensor_shape.unknown_shape()])
+ return tf.gradients(r, x)
+ r = tf.while_loop(lambda n: n < 6.0, b1, [n],
+ [tensor_shape.unknown_shape()])
+ self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+
def testWhile_NestedInput(self):
with self.test_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index f6397226af..aa31c03e19 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -1297,7 +1297,7 @@ class SelectOpTest(tf.test.TestCase):
def _compare(self, c, x, y, use_gpu):
np_ans = np.where(c, x, y)
with self.test_session(use_gpu=use_gpu):
- out = tf.select(c, x, y)
+ out = tf.where(c, x, y)
tf_ans = out.eval()
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, out)
@@ -1306,7 +1306,7 @@ class SelectOpTest(tf.test.TestCase):
with self.test_session():
inx = tf.convert_to_tensor(x)
iny = tf.convert_to_tensor(y)
- out = tf.select(c, inx, iny)
+ out = tf.where(c, inx, iny)
s = list(np.shape(c))
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
@@ -1318,7 +1318,7 @@ class SelectOpTest(tf.test.TestCase):
yf = y.astype(numeric_gradient_type)
inxf = tf.convert_to_tensor(xf)
inyf = tf.convert_to_tensor(yf)
- outf = tf.select(c, inxf, inyf)
+ outf = tf.where(c, inxf, inyf)
_, jacob_n = tf.test.compute_gradient(inxf,
s,
outf,
@@ -1336,7 +1336,7 @@ class SelectOpTest(tf.test.TestCase):
with self.test_session():
inx = tf.convert_to_tensor(x)
iny = tf.convert_to_tensor(y)
- out = tf.select(c, inx, iny)
+ out = tf.where(c, inx, iny)
s = list(np.shape(c))
jacob_t, jacob_n = tf.test.compute_gradient(iny,
s,
@@ -1349,7 +1349,7 @@ class SelectOpTest(tf.test.TestCase):
yf = y.astype(numeric_gradient_type)
inxf = tf.convert_to_tensor(xf)
inyf = tf.convert_to_tensor(yf)
- outf = tf.select(c, inxf, inyf)
+ outf = tf.where(c, inxf, inyf)
_, jacob_n = tf.test.compute_gradient(inyf,
s,
outf,
@@ -1415,7 +1415,7 @@ class SelectOpTest(tf.test.TestCase):
xt = x.astype(t)
yt = y.astype(t)
with self.assertRaises(ValueError):
- tf.select(c, xt, yt)
+ tf.where(c, xt, yt)
def testEmptyTensor(self):
c = np.random.randint(0, 3, 0).astype(np.bool).reshape(1, 3, 0)
@@ -1425,7 +1425,7 @@ class SelectOpTest(tf.test.TestCase):
with self.test_session():
xt = x.astype(np.float32)
yt = y.astype(np.float32)
- z = tf.select(c, xt, yt).eval()
+ z = tf.where(c, xt, yt).eval()
self.assertAllEqual(z_expected, z)
def testNan(self):
@@ -1434,7 +1434,7 @@ class SelectOpTest(tf.test.TestCase):
for c in False, True:
for a in 7.0, np.nan:
for b in 5.0, np.nan:
- x = tf.select(c, a, b).eval()
+ x = tf.where(c, a, b).eval()
y = a if c else b
self.assertEqual(np.isnan(x), np.isnan(y))
@@ -1447,7 +1447,7 @@ class BatchSelectOpTest(tf.test.TestCase):
[x_i if c_i else y_i for c_i, x_i, y_i in zip(c, x, y)]).transpose(
[2, 0, 1])
with self.test_session(use_gpu=use_gpu):
- out = tf.select(c, x, y)
+ out = tf.where(c, x, y)
tf_ans = out.eval()
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, out)
@@ -1456,7 +1456,7 @@ class BatchSelectOpTest(tf.test.TestCase):
with self.test_session():
inx = tf.convert_to_tensor(x)
iny = tf.convert_to_tensor(y)
- out = tf.select(c, inx, iny)
+ out = tf.where(c, inx, iny)
s = list(np.shape(x))
jacob_t, jacob_n = tf.test.compute_gradient(inx,
s,
@@ -1468,7 +1468,7 @@ class BatchSelectOpTest(tf.test.TestCase):
yf = y.astype(numeric_gradient_type)
inxf = tf.convert_to_tensor(xf)
inyf = tf.convert_to_tensor(yf)
- outf = tf.select(c, inxf, inyf)
+ outf = tf.where(c, inxf, inyf)
_, jacob_n = tf.test.compute_gradient(inxf,
s,
outf,
@@ -1486,7 +1486,7 @@ class BatchSelectOpTest(tf.test.TestCase):
with self.test_session():
inx = tf.convert_to_tensor(x)
iny = tf.convert_to_tensor(y)
- out = tf.select(c, inx, iny)
+ out = tf.where(c, inx, iny)
s = list(np.shape(x))
jacob_t, jacob_n = tf.test.compute_gradient(iny,
s,
@@ -1498,7 +1498,7 @@ class BatchSelectOpTest(tf.test.TestCase):
yf = y.astype(numeric_gradient_type)
inxf = tf.convert_to_tensor(xf)
inyf = tf.convert_to_tensor(yf)
- outf = tf.select(c, inxf, inyf)
+ outf = tf.where(c, inxf, inyf)
_, jacob_n = tf.test.compute_gradient(inyf,
s,
outf,
@@ -1552,7 +1552,7 @@ class BatchSelectOpTest(tf.test.TestCase):
xt = x.astype(t)
yt = y.astype(t)
with self.assertRaises(ValueError):
- tf.select(c, xt, yt)
+ tf.where(c, xt, yt)
class MinMaxOpTest(tf.test.TestCase):
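For readers tracking the tf.select to tf.where rename applied throughout this file: a minimal sketch of the three-argument form, with made-up inputs. Elements are taken from x where the condition is True and from y elsewhere:

    import numpy as np
    import tensorflow as tf

    c = np.array([True, False, True])
    x = np.array([1., 2., 3.], dtype=np.float32)
    y = np.array([10., 20., 30.], dtype=np.float32)
    with tf.Session() as sess:
        # tf.where(c, x, y) replaces the old tf.select(c, x, y).
        out = sess.run(tf.where(c, x, y))
    assert np.array_equal(out, np.array([1., 20., 3.], dtype=np.float32))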
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 889f14cd53..cbc5ee278e 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -258,6 +258,82 @@ class ParseExampleTest(tf.test.TestCase):
}
}, expected_output)
+ def testSerializedContainingSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx": int64_feature([0, 9, 3]) # unsorted
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = ( # indices, values, shape
+ np.array([[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ np.array([4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ expected_output = {
+ "sp": expected_sp,
+ }
+
+ self._test({
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ "sp": tf.SparseFeature("idx", "val", tf.float32, 13)
+ }
+ }, expected_output)
+
+ def testSerializedContainingSparseFeatureReuse(self):
+ original = [
+ example(features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp1 = ( # indices, values, shape
+ np.array([[0, 5], [0, 10]], dtype=np.int64),
+ np.array([3.0, 4.0], dtype=np.float32),
+ np.array([2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_sp2 = ( # indices, values, shape
+ np.array([[0, 5], [0, 10]], dtype=np.int64),
+ np.array([5.0, 6.0], dtype=np.float32),
+ np.array([2, 7], dtype=np.int64)) # batch == 2, max_elems = 7
+
+ expected_output = {
+ "sp1": expected_sp1,
+ "sp2": expected_sp2,
+ }
+
+ self._test({
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ "sp1": tf.SparseFeature("idx", "val1", tf.float32, 13),
+ "sp2": tf.SparseFeature("idx", "val2", tf.float32, 7)
+ }
+ }, expected_output)
+
def testSerializedContainingDense(self):
aname = "a"
bname = "b*has+a:tricky_name"
@@ -400,7 +476,7 @@ class ParseExampleTest(tf.test.TestCase):
},
expected_output)
- def testSerializedContainingSparseAndDenseWithNoDefault(self):
+ def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_st_a = ( # indices, values, shape
np.empty(
(0, 2), dtype=np.int64), # indices
@@ -408,12 +484,20 @@ class ParseExampleTest(tf.test.TestCase):
(0,), dtype=np.int64), # sp_a is DT_INT64
np.array(
[2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ expected_sp = ( # indices, values, shape
+ np.array([[0, 0], [0, 3], [1, 7]], dtype=np.int64),
+ np.array(["a", "b", "c"], dtype="|S"),
+ np.array([2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
original = [
example(features=features({
- "c": float_feature([3, 4])
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
})), example(features=features({
- "c": float_feature([1, 2])
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
}))
]
@@ -424,6 +508,7 @@ class ParseExampleTest(tf.test.TestCase):
b_default = np.random.rand(3, 3).astype(bytes)
expected_output = {
"st_a": expected_st_a,
+ "sp": expected_sp,
"a": np.array(2 * [[a_default]]),
"b": np.array(2 * [b_default]),
"c": np.array(
@@ -436,6 +521,7 @@ class ParseExampleTest(tf.test.TestCase):
"serialized": tf.convert_to_tensor(serialized),
"features": {
"st_a": tf.VarLenFeature(tf.int64),
+ "sp": tf.SparseFeature("idx", "val", tf.string, 13),
"a": tf.FixedLenFeature(
(1, 3), tf.int64, default_value=a_default),
"b": tf.FixedLenFeature(
@@ -446,6 +532,46 @@ class ParseExampleTest(tf.test.TestCase):
},
expected_output)
+ def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
+ expected_idx = ( # indices, values, shape
+ np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]),
+ np.array([2, 2], dtype=np.int64)) # batch == 2, max_elems = 2
+
+ expected_sp = ( # indices, values, shape
+ np.array([[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64),
+ np.array(["a", "b", "d", "c"], dtype="|S"),
+ np.array([2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ original = [
+ example(features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
+ ]
+
+ names = ["in1", "in2"]
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "idx": expected_idx,
+ "sp": expected_sp,
+ }
+
+ self._test(
+ {
+ "example_names": names,
+ "serialized": tf.convert_to_tensor(serialized),
+ "features": {
+ "idx": tf.VarLenFeature(tf.int64),
+ "sp": tf.SparseFeature("idx", "val", tf.string, 13),
+ }
+ },
+ expected_output)
+
class ParseSingleExampleTest(tf.test.TestCase):
@@ -473,8 +599,10 @@ class ParseSingleExampleTest(tf.test.TestCase):
self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,))
self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (1,))
- def testSingleExampleWithSparseAndDense(self):
+ def testSingleExampleWithSparseAndSparseFeatureAndDense(self):
original = example(features=features({"c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3]),
"st_a": float_feature([3.0, 4.0])}))
serialized = original.SerializeToString()
@@ -486,10 +614,16 @@ class ParseSingleExampleTest(tf.test.TestCase):
np.array(
[2], dtype=np.int64)) # shape: max_values = 2
+ expected_sp = ( # indices, values, shape
+ np.array([[0], [3]], dtype=np.int64),
+ np.array(["a", "b"], dtype="|S"),
+ np.array([13], dtype=np.int64)) # max_values = 13
+
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
expected_output = {
"st_a": expected_st_a,
+ "sp": expected_sp,
"a": [a_default],
"b": b_default,
"c": np.array(
@@ -502,6 +636,7 @@ class ParseSingleExampleTest(tf.test.TestCase):
"serialized": tf.convert_to_tensor(serialized),
"features": {
"st_a": tf.VarLenFeature(tf.float32),
+ "sp": tf.SparseFeature("idx", "val", tf.string, 13),
"a": tf.FixedLenFeature(
(1, 3), tf.int64, default_value=a_default),
"b": tf.FixedLenFeature(
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 9a1f96b6fe..6fe112b6be 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -62,8 +62,8 @@ def GetTestConfigs():
all the valid test configs as tuples of data_format and use_gpu.
"""
test_configs = [("NHWC", False), ("NHWC", True)]
- if tf.test.is_gpu_available():
- # "NCHW" format is not currently supported on CPU.
+ if tf.test.is_gpu_available(cuda_only=True):
+ # "NCHW" format is currently supported exclusively on CUDA GPUs.
test_configs += [("NCHW", True)]
return test_configs
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index bfb3b3a56b..4af5c3c8a2 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -741,6 +741,21 @@ class TFRecordIteratorTest(tf.test.TestCase):
actual.append(r)
self.assertEqual(actual, original)
+ def testBadFile(self):
+ """Verify that tf_record_iterator throws an exception on bad TFRecords."""
+ fn = os.path.join(self.get_temp_dir(), "bad_file")
+ with tf.python_io.TFRecordWriter(fn) as writer:
+ writer.write(b"123")
+ fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated")
+ with open(fn, "rb") as f:
+ with open(fn_truncated, "wb") as f2:
+ # DataLossError requires that we've written the header, so this must
+ # be at least 12 bytes.
+ f2.write(f.read(14))
+ with self.assertRaises(tf.errors.DataLossError):
+ for _ in tf.python_io.tf_record_iterator(fn_truncated):
+ pass
+
class AsyncReaderTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index bea631f038..776d9b6665 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -300,6 +300,36 @@ class EluTest(tf.test.TestCase):
print("elu (float64) gradient of gradient err = ", err)
self.assertLess(err, 1e-6)
-
+
+class CreluTest(tf.test.TestCase):
+
+ def testCreluShape(self):
+ f = tf.random_normal([50, 5, 7, 10])
+ t = tf.nn.crelu(f)
+ self.assertEqual([50, 5, 7, 20], t.get_shape())
+
+ def _testCrelu(self, np_features, use_gpu=False):
+ np_relu = np.maximum(np_features, np.zeros_like(np_features))
+ np_neg_relu = np.maximum(-np_features, np.zeros_like(np_features))
+ np_crelu = np.concatenate(
+ (np_relu, np_neg_relu), len(np_features.shape) - 1)
+
+ with self.test_session(use_gpu=use_gpu):
+ crelu = tf.nn.crelu(np_features)
+ tf_relu = crelu.eval()
+
+ self.assertAllClose(np_crelu, tf_relu)
+ self.assertShapeEqual(np_crelu, crelu)
+
+ def testNumbers(self):
+ for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ self._testCrelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=False)
+ if t in [np.float16, np.float32, np.float64]:
+ self._testCrelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ use_gpu=True)
+
if __name__ == "__main__":
tf.test.main()
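A minimal sketch of what the new CreluTest verifies, with a tiny made-up input: crelu concatenates relu(x) and relu(-x) along the last axis, so the channel dimension doubles and both signs of each activation survive.

    import numpy as np
    import tensorflow as tf

    x = np.array([[-2., 3.]], dtype=np.float32)
    with tf.Session() as sess:
        # relu(x) = [[0, 3]], relu(-x) = [[2, 0]], concatenated on axis -1.
        out = sess.run(tf.nn.crelu(x))
    assert np.array_equal(out, np.array([[0., 3., 2., 0.]], dtype=np.float32))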
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index e4e239169a..cc60e796ba 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -23,9 +23,10 @@ import functools
import numpy as np
import tensorflow as tf
+from tensorflow.python.ops import rnn_cell_impl
# TODO(ebrevdo): Remove once _linear is fully deprecated.
# pylint: disable=protected-access
-from tensorflow.python.ops.rnn_cell import _linear as linear
+from tensorflow.python.ops.rnn_cell_impl import _linear as linear
# pylint: enable=protected-access
@@ -367,7 +368,7 @@ class SlimRNNCellTest(tf.test.TestCase):
m = tf.zeros([1, 2])
my_cell = functools.partial(basic_rnn_cell, num_units=2)
# pylint: disable=protected-access
- g, _ = tf.nn.rnn_cell._SlimRNNCell(my_cell)(x, m)
+ g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
# pylint: enable=protected-access
sess.run([tf.global_variables_initializer()])
res = sess.run([g], {x.name: np.array([[1., 1.]]),
@@ -384,7 +385,7 @@ class SlimRNNCellTest(tf.test.TestCase):
_, initial_state = basic_rnn_cell(inputs, None, num_units)
my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
# pylint: disable=protected-access
- slim_cell = tf.nn.rnn_cell._SlimRNNCell(my_cell)
+ slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
# pylint: enable=protected-access
slim_outputs, slim_state = slim_cell(inputs, initial_state)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 6ec5274873..1b1810e175 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -227,6 +227,24 @@ class ScatterNdTest(tf.test.TestCase):
tf.scatter_nd_update(ref, indices, updates).get_shape().as_list(),
shape)
+ def testExtraIndicesDimensions(self):
+ indices = tf.zeros([1, 1, 2], tf.int32)
+ updates = tf.zeros([1, 1], tf.int32)
+ shape = np.array([2, 2])
+ scatter = tf.scatter_nd(indices, updates, shape)
+ self.assertAllEqual(scatter.get_shape().as_list(), shape)
+ expected_result = np.zeros([2, 2], dtype=np.int32)
+ with self.test_session():
+ self.assertAllEqual(expected_result, scatter.eval())
+
+ ref = tf.Variable(tf.zeros(shape, tf.int32))
+ scatter_update = tf.scatter_nd_update(ref, indices, updates)
+ self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
+
+ with self.test_session():
+ ref.initializer.run()
+ self.assertAllEqual(expected_result, scatter_update.eval())
+
def testUndefinedIndicesShape(self):
indices = tf.placeholder(tf.int32, shape=None)
updates = tf.placeholder(tf.int32, shape=[2, 2, 2])
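For context on testExtraIndicesDimensions above: tf.scatter_nd writes `updates` into a zero tensor of the given shape at `indices`, where the innermost indices dimension selects positions and any extra leading dimensions are treated as batch dimensions. A minimal sketch with made-up values:

    import tensorflow as tf

    indices = tf.constant([[0], [2]])  # innermost dim indexes into the shape
    updates = tf.constant([5, 7])
    scatter = tf.scatter_nd(indices, updates, [4])
    with tf.Session() as sess:
        print(sess.run(scatter))       # [5 0 7 0]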
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index 227832b18e..5aa1390a9a 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -63,9 +63,6 @@ class StringSplitOpTest(tf.test.TestCase):
with self.test_session() as sess:
self.assertRaises(
- ValueError, tf.string_split, strings, delimiter="delimiter")
-
- self.assertRaises(
ValueError, tf.string_split, strings, delimiter=["|", ""])
self.assertRaises(ValueError, tf.string_split, strings, delimiter=["a"])
@@ -76,6 +73,12 @@ class StringSplitOpTest(tf.test.TestCase):
self.assertAllEqual(values, [b"hello", b"world", b"hello world"])
self.assertAllEqual(shape, [2, 2])
+ tokens = tf.string_split(strings, delimiter="| ")
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]])
+ self.assertAllEqual(values, [b"hello", b"world", b"hello", b"world"])
+ self.assertAllEqual(shape, [2, 2])
+
def testStringSplitWithDelimiterTensor(self):
strings = ["hello|world", "hello world"]
@@ -88,14 +91,31 @@ class StringSplitOpTest(tf.test.TestCase):
sess.run(tokens, feed_dict={delimiter: ["a", "b"]})
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(tokens, feed_dict={delimiter: ["a"]})
- with self.assertRaises(tf.errors.InvalidArgumentError):
- sess.run(tokens, feed_dict={delimiter: "abc"})
indices, values, shape = sess.run(tokens, feed_dict={delimiter: "|"})
self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0]])
self.assertAllEqual(values, [b"hello", b"world", b"hello world"])
self.assertAllEqual(shape, [2, 2])
+ def testStringSplitWithDelimitersTensor(self):
+ strings = ["hello.cruel,world", "hello cruel world"]
+
+ with self.test_session() as sess:
+ delimiter = tf.placeholder(tf.string)
+
+ tokens = tf.string_split(strings, delimiter=delimiter)
+
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ sess.run(tokens, feed_dict={delimiter: ["a", "b"]})
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ sess.run(tokens, feed_dict={delimiter: ["a"]})
+ indices, values, shape = sess.run(tokens, feed_dict={delimiter: ".,"})
+
+ self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [1, 0]])
+ self.assertAllEqual(values, [b"hello", b"cruel", b"world",
+ b"hello cruel world"])
+ self.assertAllEqual(shape, [2, 3])
+
if __name__ == "__main__":
tf.test.main()
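A minimal sketch of the delimiter behavior these test changes establish: a multi-character delimiter string is no longer rejected, and every character in it acts as a separator.

    import tensorflow as tf

    with tf.Session() as sess:
        # ".," splits on both '.' and ','.
        tokens = tf.string_split(["hello.cruel,world"], delimiter=".,")
        print(sess.run(tokens).values)  # [b'hello' b'cruel' b'world']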
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 7c434f4561..06148eefa4 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -30,7 +30,8 @@ class SummaryOpsTest(tf.test.TestCase):
def testScalarSummary(self):
with self.test_session() as sess:
const = tf.constant([10.0, 20.0])
- summ = tf.scalar_summary(["c1", "c2"], const, name="mysumm")
+ summ = tf.contrib.deprecated.scalar_summary(
+ ["c1", "c2"], const, name="mysumm")
value = sess.run(summ)
self.assertEqual([], summ.get_shape())
self.assertProtoEquals("""
@@ -41,7 +42,7 @@ class SummaryOpsTest(tf.test.TestCase):
def testScalarSummaryDefaultName(self):
with self.test_session() as sess:
const = tf.constant([10.0, 20.0])
- summ = tf.scalar_summary(["c1", "c2"], const)
+ summ = tf.contrib.deprecated.scalar_summary(["c1", "c2"], const)
value = sess.run(summ)
self.assertEqual([], summ.get_shape())
self.assertProtoEquals("""
@@ -53,7 +54,7 @@ class SummaryOpsTest(tf.test.TestCase):
with self.test_session() as sess:
const = tf.constant(10.0)
summ1 = tf.summary.histogram("h", const)
- summ2 = tf.scalar_summary("c", const)
+ summ2 = tf.contrib.deprecated.scalar_summary("c", const)
merge = tf.summary.merge([summ1, summ2])
value = sess.run(merge)
self.assertEqual([], merge.get_shape())
@@ -88,11 +89,12 @@ class SummaryOpsTest(tf.test.TestCase):
self.assertEqual(2, len(merge.op.inputs))
self.assertEqual(summ1, merge.op.inputs[0])
self.assertEqual(summ3, merge.op.inputs[1])
- merge = tf.merge_all_summaries("foo_key")
+ merge = tf.contrib.deprecated.merge_all_summaries("foo_key")
self.assertEqual("MergeSummary", merge.op.type)
self.assertEqual(1, len(merge.op.inputs))
self.assertEqual(summ2, merge.op.inputs[0])
- self.assertTrue(tf.merge_all_summaries("bar_key") is None)
+ self.assertTrue(
+ tf.contrib.deprecated.merge_all_summaries("bar_key") is None)
def testHistogramSummaryTypes(self):
with tf.Graph().as_default():
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py
index 17acbeffc4..f9508d4709 100644
--- a/tensorflow/python/kernel_tests/template_test.py
+++ b/tensorflow/python/kernel_tests/template_test.py
@@ -272,5 +272,34 @@ class TemplateTest(tf.test.TestCase):
# Template is called at the top level, so there is no preceding "foo_2".
self.assertEqual(tc.var_scope.name, "blah")
+ def test_custom_getter(self):
+ # Custom getter that maintains call count and forwards to the true getter.
+ custom_getter_count = [0]
+ def custom_getter(getter, name, *args, **kwargs):
+ custom_getter_count[0] += 1
+ return getter(name, *args, **kwargs)
+
+ # Test that custom getter is called both when variables are created and
+ # subsequently accessed
+ tmpl1 = template.make_template("s1", var_scoped_function,
+ custom_getter_=custom_getter)
+ self.assertEqual(custom_getter_count[0], 0)
+ tmpl1()
+ self.assertEqual(custom_getter_count[0], 1)
+ tmpl1()
+ self.assertEqual(custom_getter_count[0], 2)
+
+ # Test that custom getter is called when the variable scope is created
+ # during construction
+ custom_getter_count[0] = 0
+ tmpl2 = template.make_template("s2", var_scoped_function,
+ custom_getter_=custom_getter,
+ create_scope_now_=True)
+ self.assertEqual(custom_getter_count[0], 0)
+ tmpl2()
+ self.assertEqual(custom_getter_count[0], 1)
+ tmpl2()
+ self.assertEqual(custom_getter_count[0], 2)
+
if __name__ == "__main__":
tf.test.main()
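A minimal sketch of the custom_getter_ hook the new template test exercises, assuming tf.make_template forwards it the same way template.make_template does here; the counting getter and scale function are illustrative only:

    import tensorflow as tf

    calls = [0]
    def counting_getter(getter, name, *args, **kwargs):
        # Forwards to the true getter while counting every variable access.
        calls[0] += 1
        return getter(name, *args, **kwargs)

    def scale(x):
        w = tf.get_variable("w", [], initializer=tf.constant_initializer(2.0))
        return w * x

    tmpl = tf.make_template("scale", scale, custom_getter_=counting_getter)
    y0 = tmpl(tf.constant(3.0))  # creates scale/w through the custom getter
    y1 = tmpl(tf.constant(4.0))  # reuses scale/w, again through the getter
    assert calls[0] == 2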
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index b21dcaf8e8..16f2585fec 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -367,9 +367,9 @@ class TensorArrayTest(tf.test.TestCase):
# Test reading wrong datatype
r0_bad = gen_data_flow_ops._tensor_array_read_v2(
- handle=w0.handle, index=0, dtype=tf.int64, flow_in=w0.flow)
+ handle=w0.handle, index=0, dtype=tf.float64, flow_in=w0.flow)
with self.assertRaisesOpError(
- "TensorArray dtype is float but Op requested dtype int64."):
+ "TensorArray dtype is float but Op requested dtype double."):
r0_bad.eval()
# Test reading from a different index than the one we wrote to
diff --git a/tensorflow/python/layers/__init__.py b/tensorflow/python/layers/__init__.py
index e0e7658513..e69de29bb2 100644
--- a/tensorflow/python/layers/__init__.py
+++ b/tensorflow/python/layers/__init__.py
@@ -1,32 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# pylint: disable=line-too-long
-"""This library provides a set of high-level neural networks layers.
-
-## Core layers
-
-@@fully_connected
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=g-bad-import-order,unused-import
-
-# Core layers.
-from tensorflow.python.layers.core import fully_connected
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index df538a5bd0..8d875477f6 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import functools
+import inspect
import re
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
@@ -84,18 +85,33 @@ class _Layer(object):
self._reuse_weights = kwargs.get('_reuse_weights')
self._dtype = dtype
- # Determine name.
+ # Determine base name (non-unique).
+ base_name = name
if not name:
- prefix = _to_snake_case(self.__class__.__name__)
- name = ops.get_default_graph().unique_name(prefix, mark_as_used=False)
- self.name = name
+ base_name = _to_snake_case(self.__class__.__name__)
# Determine variable scope.
scope = kwargs.get('_scope')
if scope:
- self._scope = scope
+ self._scope = next(vs.variable_scope(scope).gen)
else:
- self._scope = next(vs.variable_scope(None, default_name=self.name).gen)
+ self._scope = next(vs.variable_scope(None, default_name=base_name).gen)
+
+ # Unique name is borrowed from scope to match variable names.
+ self._name = self._scope.name
+
+ def __setattr__(self, name, value):
+ if hasattr(self, name):
+ # Only allow self to update its own attributes
+ stack_0_locals = inspect.stack()[1][0].f_locals
+ called_from_layer = stack_0_locals.get('self', None) is self
+ if not called_from_layer:
+ raise AttributeError('Read-only property cannot be set: %s' % name)
+ super(_Layer, self).__setattr__(name, value)
+
+ @property
+ def name(self):
+ return self._name
@property
def trainable_weights(self):
@@ -135,16 +151,17 @@ class _Layer(object):
"""
self._built = True
- def call(self, inputs):
+ def call(self, inputs, **kwargs):
"""The logic of the layer lives here.
Arguments:
inputs: input tensor(s).
+ **kwargs: additional keyword arguments.
Returns:
Output tensor(s).
"""
- return inputs
+ raise NotImplementedError
def _add_weight(self, name, shape, dtype=None,
initializer=None, regularizer=None, trainable=True,
@@ -186,18 +203,23 @@ class _Layer(object):
regularization, ops.GraphKeys.REGULARIZATION_LOSSES)
return variable
- def __call__(self, inputs):
+ def __call__(self, inputs, **kwargs):
"""Wraps `call`, applying pre- and post-processing steps.
Arguments:
inputs: input tensor(s).
+ **kwargs: additional keyword arguments to be passed to `self.call`.
Returns:
Output tensor(s).
"""
- # Define a custom to override tf.get_variable when creating layer weights.
+ # Define a custom getter to override tf.get_variable when creating layer
+ # weights. We respect current custom getter, if one is set.
+ current_custom_getter = vs.get_variable_scope().custom_getter
def variable_getter(getter, name, shape, dtype=None, initializer=None,
regularizer=None, trainable=True, **kwargs):
+ if current_custom_getter is not None:
+ getter = functools.partial(current_custom_getter, getter)
return self._add_weight(
name, shape, initializer=initializer, regularizer=regularizer,
dtype=dtype, trainable=trainable,
@@ -215,7 +237,7 @@ class _Layer(object):
else:
self.build(input_shapes)
self._built = True
- outputs = self.call(inputs)
+ outputs = self.call(inputs, **kwargs)
# Apply activity regularization.
# Note that it should be applied every time the layer creates a new
@@ -233,23 +255,29 @@ class _Layer(object):
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
return outputs
- def apply(self, inputs):
+ def apply(self, inputs, **kwargs):
"""Apply the layer on a input.
This simply wraps `self.__call__`.
Arguments:
- inputs: input tensor(s).
+ inputs: Input tensor(s).
+ **kwargs: additional keyword arguments to be passed to `self.call`.
Returns:
Output tensor(s).
"""
- return self.__call__(inputs)
+ return self.__call__(inputs, **kwargs)
def _to_snake_case(name):
intermediate = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
- return re.sub('([a-z0-9])([A-Z])', r'\1_\2', intermediate).lower()
+ insecure = re.sub('([a-z0-9])([A-Z])', r'\1_\2', intermediate).lower()
+ # If the class is private the name starts with "_" which is not secure
+ # for creating scopes. We prefix the name with "private" in this case.
+ if insecure[0] != '_':
+ return insecure
+ return 'private' + insecure
def _to_list(x):
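A standalone mirror of the _to_snake_case change above, showing why the "private" prefix exists: a leading "_" (from a private class name) may not start a scope name. The function name here is illustrative; the regexes and expected outputs match the new code and the naming test below.

    import re

    def to_snake_case(name):
        # CamelCase -> snake_case, as in layers.base._to_snake_case.
        intermediate = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        insecure = re.sub('([a-z0-9])([A-Z])', r'\1_\2', intermediate).lower()
        # Private class names start with "_", which cannot open a scope,
        # so they get a "private" prefix.
        return insecure if insecure[0] != '_' else 'private' + insecure

    assert to_snake_case('FullyConnected') == 'fully_connected'
    assert to_snake_case('_Layer') == 'private__layer'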
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index fd9ebd33d1..9262db2fc7 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -80,6 +80,9 @@ class BaseLayerTest(tf.test.TestCase):
self.w = tf.get_variable('my_var', [2, 2],
initializer=tf.zeros_initializer)
+ def call(self, inputs):
+ return inputs
+
layer = MyLayer(name='my_layer')
inputs = tf.random_uniform((5,), seed=1)
_ = layer.apply(inputs)
@@ -98,6 +101,38 @@ class BaseLayerTest(tf.test.TestCase):
self.assertEqual(layer.built, True)
self.assertEqual(outputs.op.name, 'my_layer/Square')
+ def testNaming(self):
+ default_layer = base_layers._Layer()
+ self.assertEqual(default_layer.name, 'private__layer')
+ default_layer1 = base_layers._Layer()
+ self.assertEqual(default_layer1.name, 'private__layer_1')
+ my_layer = base_layers._Layer(name='my_layer')
+ self.assertEqual(my_layer.name, 'my_layer')
+ my_layer1 = base_layers._Layer(name='my_layer')
+ self.assertEqual(my_layer1.name, 'my_layer_1')
+ # New graph has fully orthogonal names.
+ with tf.Graph().as_default():
+ my_layer_other_graph = base_layers._Layer(name='my_layer')
+ self.assertEqual(my_layer_other_graph.name, 'my_layer')
+ my_layer2 = base_layers._Layer(name='my_layer')
+ self.assertEqual(my_layer2.name, 'my_layer_2')
+ # Name scope shouldn't affect names.
+ with tf.name_scope('some_name_scope'):
+ default_layer2 = base_layers._Layer()
+ self.assertEqual(default_layer2.name, 'private__layer_2')
+ my_layer3 = base_layers._Layer(name='my_layer')
+ self.assertEqual(my_layer3.name, 'my_layer_3')
+ other_layer = base_layers._Layer(name='other_layer')
+ self.assertEqual(other_layer.name, 'other_layer')
+ # Variable scope gets added to names.
+ with tf.variable_scope('var_scope'):
+ default_layer_scoped = base_layers._Layer()
+ self.assertEqual(default_layer_scoped.name, 'var_scope/private__layer')
+ my_layer_scoped = base_layers._Layer(name='my_layer')
+ self.assertEqual(my_layer_scoped.name, 'var_scope/my_layer')
+ my_layer_scoped1 = base_layers._Layer(name='my_layer')
+ self.assertEqual(my_layer_scoped1.name, 'var_scope/my_layer_1')
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index b0c17a46af..f3ffbf33b9 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -27,11 +27,14 @@ from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.layers import base
@@ -51,40 +54,41 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access
flattened prior to the initial matrix multiply by `w`.
Arguments:
- output_dim: Integer or Long, dimensionality of the output space.
+ units: Integer or Long, dimensionality of the output space.
activation: Activation function (callable). Set it to None to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
- w_initializer: Initializer function for the weight matrix.
+ weights_initializer: Initializer function for the weight matrix.
bias_initializer: Initializer function for the bias.
- w_regularizer: Regularizer function for the weight matrix.
+ weights_regularizer: Regularizer function for the weight matrix.
bias_regularizer: Regularizer function for the bias.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
- name: String, the name of the layer.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such cases.
reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
Properties:
- output_dim: Integer or Long, dimensionality of the output space.
+ units: Python integer, dimensionality of the output space.
activation: Activation function (callable).
use_bias: Boolean, whether the layer uses a bias.
- w_initializer: Initializer instance (or name) for the weight matrix.
+ weights_initializer: Initializer instance (or name) for the weight matrix.
bias_initializer: Initializer instance (or name) for the bias.
- w_regularizer: Regularizer instance for the weight matrix (callable)
+ weights_regularizer: Regularizer instance for the weight matrix (callable)
bias_regularizer: Regularizer instance for the bias (callable).
activity_regularizer: Regularizer instance for the output (callable)
- w: Weight matrix (TensorFlow variable or tensor).
+ weights: Weight matrix (TensorFlow variable or tensor).
bias: Bias vector, if applicable (TensorFlow variable or tensor).
"""
- def __init__(self, output_dim,
+ def __init__(self, units,
activation=None,
use_bias=True,
- w_initializer=None,
+ weights_initializer=None,
bias_initializer=init_ops.zeros_initializer,
- w_regularizer=None,
+ weights_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
@@ -92,19 +96,22 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access
**kwargs):
super(FullyConnected, self).__init__(trainable=trainable, name=name,
**kwargs)
- self.output_dim = output_dim
+ self.units = units
self.activation = activation
self.use_bias = use_bias
- self.w_initializer = w_initializer
+ self.weights_initializer = weights_initializer
self.bias_initializer = bias_initializer
- self.w_regularizer = w_regularizer
+ self.weights_regularizer = weights_regularizer
self.bias_regularizer = bias_regularizer
self.activity_regularizer = activity_regularizer
def build(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ if input_shape.ndims is None:
+ raise ValueError('Inputs to `FullyConnected` should have known rank.')
if len(input_shape) < 2:
raise ValueError('Inputs to `FullyConnected` should have rank >= 2.')
- if input_shape[-1] is None:
+ if input_shape[-1].value is None:
raise ValueError('The last dimension of the inputs to `FullyConnected` '
'should be defined. Found `None`.')
# Note that we set `trainable=True` because this is a trainable
@@ -112,14 +119,14 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access
# (self.trainable = False), the variable will not be added to
# tf.trainable_variables(), and self.trainable_weights will be empty.
self.w = vs.get_variable('weights',
- shape=[input_shape[-1], self.output_dim],
- initializer=self.w_initializer,
- regularizer=self.w_regularizer,
+ shape=[input_shape[-1].value, self.units],
+ initializer=self.weights_initializer,
+ regularizer=self.weights_regularizer,
dtype=self._dtype,
trainable=True)
if self.use_bias:
self.bias = vs.get_variable('biases',
- shape=[self.output_dim,],
+ shape=[self.units,],
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
dtype=self._dtype,
@@ -130,11 +137,11 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access
def call(self, inputs):
shape = inputs.get_shape().as_list()
input_dim = shape[-1]
- output_shape = shape[:-1] + [self.output_dim]
+ output_shape = shape[:-1] + [self.units]
if len(output_shape) > 2:
# Reshape the input to 2D.
output_shape_tensors = array_ops.unpack(array_ops.shape(inputs))
- output_shape_tensors[-1] = self.output_dim
+ output_shape_tensors[-1] = self.units
output_shape_tensor = array_ops.pack(output_shape_tensors)
inputs = array_ops.reshape(inputs, [-1, input_dim])
@@ -148,17 +155,17 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access
outputs.set_shape(output_shape)
if self.activation is not None:
- return self.activation(outputs)
+ return self.activation(outputs) # pylint: disable=not-callable
return outputs
def fully_connected(
- inputs, output_dim,
+ inputs, units,
activation=None,
use_bias=True,
- w_initializer=None,
+ weights_initializer=None,
bias_initializer=init_ops.zeros_initializer,
- w_regularizer=None,
+ weights_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
trainable=True,
@@ -176,13 +183,13 @@ def fully_connected(
Arguments:
inputs: Tensor input.
- output_dim: Integer or Long, dimensionality of the output space.
+ units: Integer or Long, dimensionality of the output space.
activation: Activation function (callable). Set it to None to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
- w_initializer: Initializer function for the weight matrix.
+ weights_initializer: Initializer function for the weight matrix.
bias_initializer: Initializer function for the bias.
- w_regularizer: Regularizer function for the weight matrix.
+ weights_regularizer: Regularizer function for the weight matrix.
bias_regularizer: Regularizer function for the bias.
activity_regularizer: Regularizer function for the output.
trainable: Boolean, if `True` also add variables to the graph collection
@@ -194,15 +201,105 @@ def fully_connected(
Returns:
Output tensor.
"""
- layer = FullyConnected(output_dim,
+ layer = FullyConnected(units,
activation=activation,
use_bias=use_bias,
- w_initializer=w_initializer,
+ weights_initializer=weights_initializer,
bias_initializer=bias_initializer,
- w_regularizer=w_regularizer,
+ weights_regularizer=weights_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
trainable=trainable,
name=name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=name,
_reuse_weights=reuse)
return layer.apply(inputs)
+
+
+class Dropout(base._Layer): # pylint: disable=protected-access
+ """Applies Dropout to the input.
+
+ Dropout consists of randomly setting a fraction `rate` of input units to 0
+ at each update during training time, which helps prevent overfitting.
+ The units that are kept are scaled by `1 / (1 - rate)`, so that their
+ sum is unchanged at training time and inference time.
+
+ Arguments:
+ rate: The dropout rate, between 0 and 1. E.g. "rate=0.1" would drop out
+ 10% of input units.
+ noise_shape: 1D tensor of type `int32` representing the shape of the
+ binary dropout mask that will be multiplied with the input.
+ For instance, if your inputs have shape
+ `(batch_size, timesteps, features)`, and you want the dropout mask
+ to be the same for all timesteps, you can use
+ `noise_shape=[batch_size, 1, features]`.
+ seed: A Python integer. Used to create random seeds. See
+ [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
+ for behavior.
+ name: The name of the layer (string).
+ """
+
+ def __init__(self, rate=0.5,
+ noise_shape=None,
+ seed=None,
+ name=None,
+ **kwargs):
+ super(Dropout, self).__init__(name=name, **kwargs)
+ self.rate = rate
+ self.noise_shape = noise_shape
+ self.seed = seed
+
+ def call(self, inputs, training=False):
+ if isinstance(training, bool):
+ training_bool = training
+ else:
+ training_bool = tensor_util.constant_value(training)
+ if training_bool is False:
+ return array_ops.identity(inputs)
+ dropped_inputs = nn.dropout(inputs, 1 - self.rate,
+ noise_shape=self.noise_shape,
+ seed=self.seed)
+ if training_bool is True:
+ return dropped_inputs
+ return control_flow_ops.cond(training,
+ lambda: dropped_inputs,
+ lambda: inputs)
+
+
+def dropout(inputs,
+ rate=0.5,
+ noise_shape=None,
+ seed=None,
+ training=False,
+ name=None):
+ """Applies Dropout to the input.
+
+ Dropout consists of randomly setting a fraction `rate` of input units to 0
+ at each update during training time, which helps prevent overfitting.
+ The units that are kept are scaled by `1 / (1 - rate)`, so that their
+ sum is unchanged at training time and inference time.
+
+ Arguments:
+ inputs: Tensor input.
+ rate: The dropout rate, between 0 and 1. E.g. "rate=0.1" would drop out
+ 10% of input units.
+ noise_shape: 1D tensor of type `int32` representing the shape of the
+ binary dropout mask that will be multiplied with the input.
+ For instance, if your inputs have shape
+ `(batch_size, timesteps, features)`, and you want the dropout mask
+ to be the same for all timesteps, you can use
+ `noise_shape=[batch_size, 1, features]`.
+ seed: A Python integer. Used to create random seeds. See
+ [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
+ for behavior.
+ training: Either a Python boolean, or a TensorFlow boolean scalar tensor
+ (e.g. a placeholder). Whether to return the output in training mode
+ (apply dropout) or in inference mode (return the input untouched).
+ name: The name of the layer (string).
+
+ Returns:
+ Output tensor.
+ """
+ layer = Dropout(rate, noise_shape=noise_shape, seed=seed, name=name)
+ return layer.apply(inputs, training=training)
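
A minimal usage sketch of the two functional endpoints added above, mirroring the patterns in the new tests (the `training` placeholder drives the `control_flow_ops.cond` branch in `Dropout.call`):

```python
import tensorflow as tf
from tensorflow.python.layers import core as core_layers

inputs = tf.random_uniform((5, 3), seed=1)
training = tf.placeholder(dtype='bool')  # toggles dropout at run time

hidden = core_layers.fully_connected(inputs, 2, activation=tf.nn.relu,
                                     name='fc')
outputs = core_layers.dropout(hidden, rate=0.5, training=training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(outputs, feed_dict={training: True})   # dropout applied
    sess.run(outputs, feed_dict={training: False})  # input passed through
```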
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 588887ab16..710fd37fd0 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
import tensorflow as tf
from tensorflow.python.layers import core as core_layers
@@ -27,9 +28,9 @@ class FullyConnectedTest(tf.test.TestCase):
def testFCProperties(self):
fc = core_layers.FullyConnected(2, activation=tf.nn.relu, name='fc')
- self.assertEqual(fc.output_dim, 2)
+ self.assertEqual(fc.units, 2)
self.assertEqual(fc.activation, tf.nn.relu)
- self.assertEqual(fc.w_regularizer, None)
+ self.assertEqual(fc.weights_regularizer, None)
self.assertEqual(fc.bias_regularizer, None)
self.assertEqual(fc.activity_regularizer, None)
self.assertEqual(fc.use_bias, True)
@@ -141,7 +142,7 @@ class FullyConnectedTest(tf.test.TestCase):
def testWeightsRegularizer(self):
regularizer = lambda x: tf.reduce_sum(x) * 1e-3
fc = core_layers.FullyConnected(2, name='fc',
- w_regularizer=regularizer)
+ weights_regularizer=regularizer)
inputs = tf.random_uniform((5, 3), seed=1)
_ = fc(inputs)
loss_keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
@@ -167,6 +168,107 @@ class FullyConnectedTest(tf.test.TestCase):
self.assertEqual(outputs.op.name, 'fc/Relu')
self.assertEqual(outputs.get_shape().as_list(), [5, 2])
+ def testFunctionalFCTwice(self):
+ inputs = tf.random_uniform((5, 3), seed=1)
+ core_layers.fully_connected(inputs, 2)
+ vars1 = tf.trainable_variables()
+ core_layers.fully_connected(inputs, 2)
+ vars2 = tf.trainable_variables()
+ self.assertEqual(len(vars1), 2)
+ self.assertEqual(len(vars2), 4)
+
+ def testFunctionalFCTwiceReuse(self):
+ inputs = tf.random_uniform((5, 3), seed=1)
+ core_layers.fully_connected(inputs, 2, name='fc')
+ vars1 = tf.trainable_variables()
+ core_layers.fully_connected(inputs, 2, name='fc', reuse=True)
+ vars2 = tf.trainable_variables()
+ self.assertEqual(vars1, vars2)
+
+ def testFunctionalFCWithCustomGetter(self):
+ called = [0]
+ def custom_getter(getter, *args, **kwargs):
+ called[0] += 1
+ return getter(*args, **kwargs)
+ with tf.variable_scope('test', custom_getter=custom_getter):
+ inputs = tf.random_uniform((5, 3), seed=1)
+ core_layers.fully_connected(inputs, 2)
+ self.assertEqual(called[0], 2)
+
+ def testFunctionalFCInScope(self):
+ with tf.variable_scope('test'):
+ inputs = tf.random_uniform((5, 3), seed=1)
+ core_layers.fully_connected(inputs, 2, name='fc')
+ var = tf.trainable_variables()[0]
+ self.assertEqual(var.name, 'test/fc/weights:0')
+ with tf.variable_scope('test1') as scope:
+ inputs = tf.random_uniform((5, 3), seed=1)
+ core_layers.fully_connected(inputs, 2, name=scope)
+ var = tf.trainable_variables()[2]
+ self.assertEqual(var.name, 'test1/weights:0')
+ with tf.variable_scope('test2'):
+ inputs = tf.random_uniform((5, 3), seed=1)
+ core_layers.fully_connected(inputs, 2)
+ var = tf.trainable_variables()[4]
+ self.assertEqual(var.name, 'test2/fully_connected/weights:0')
+
+
+class DropoutTest(tf.test.TestCase):
+
+ def testDropoutProperties(self):
+ dp = core_layers.Dropout(0.5)
+ self.assertEqual(dp.rate, 0.5)
+ self.assertEqual(dp.name, 'dropout')
+ self.assertEqual(dp.noise_shape, None)
+
+ def testBooleanLearningPhase(self):
+ with self.test_session() as sess:
+ dp = core_layers.Dropout(0.5)
+ inputs = tf.ones((5, 3))
+ dropped = dp.apply(inputs, training=True)
+ sess.run(tf.global_variables_initializer())
+ np_output = sess.run(dropped)
+ self.assertAlmostEqual(0., np_output.min())
+ dropped = dp.apply(inputs, training=False)
+ np_output = sess.run(dropped)
+ self.assertAllClose(np.ones((5, 3)), np_output)
+
+ def testDynamicLearningPhase(self):
+ with self.test_session() as sess:
+ dp = core_layers.Dropout(0.5, seed=1)
+ inputs = tf.ones((5, 5))
+ training = tf.placeholder(dtype='bool')
+ dropped = dp.apply(inputs, training=training)
+ sess.run(tf.global_variables_initializer())
+ np_output = sess.run(dropped, feed_dict={training: True})
+ self.assertAlmostEqual(0., np_output.min())
+ np_output = sess.run(dropped, feed_dict={training: False})
+ self.assertAllClose(np.ones((5, 5)), np_output)
+
+ def testCustomNoiseShape(self):
+ with self.test_session() as sess:
+ inputs = tf.ones((5, 3, 2))
+ noise_shape = [5, 1, 2]
+ dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1)
+ dropped = dp.apply(inputs, training=True)
+ sess.run(tf.global_variables_initializer())
+ np_output = sess.run(dropped)
+ self.assertAlmostEqual(0., np_output.min())
+ self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :])
+
+ def testFunctionalDropout(self):
+ with self.test_session() as sess:
+ inputs = tf.ones((5, 5))
+ training = tf.placeholder(dtype='bool')
+ dropped = core_layers.dropout(inputs, 0.5, training=training, seed=1)
+ self.assertEqual(dropped.op.name, 'dropout/cond/Merge')
+
+ sess.run(tf.global_variables_initializer())
+ np_output = sess.run(dropped, feed_dict={training: True})
+ self.assertAlmostEqual(0., np_output.min())
+ np_output = sess.run(dropped, feed_dict={training: False})
+ self.assertAllClose(np.ones((5, 5)), np_output)
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py
new file mode 100644
index 0000000000..1466487164
--- /dev/null
+++ b/tensorflow/python/layers/layers.py
@@ -0,0 +1,39 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+"""This library provides a set of high-level neural networks layers.
+
+## Core layers
+
+@@fully_connected
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+# pylint: disable=g-bad-import-order,unused-import
+
+# Core layers.
+from tensorflow.python.layers.core import fully_connected
+# pylint: enable=g-bad-import-order,unused-import
+
+_allowed_symbols = []
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index 2653219c5a..d3b06188ea 100644
--- a/tensorflow/python/lib/io/file_io.py
+++ b/tensorflow/python/lib/io/file_io.py
@@ -287,7 +287,9 @@ def create_dir(dirname):
def recursive_create_dir(dirname):
- """Create a directory and all parent/intermediate directories.
+ """Creates a directory and all parent/intermediate directories.
+
+ It succeeds if dirname already exists and is writable.
Args:
dirname: string, name of the directory to be created
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index d3f557506e..5fcb51b3b2 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -55,10 +55,14 @@ PyRecordReader::~PyRecordReader() {
delete file_;
}
-bool PyRecordReader::GetNext() {
- if (reader_ == nullptr) return false;
+void PyRecordReader::GetNext(TF_Status* status) {
+ if (reader_ == nullptr) {
+ Set_TF_Status_from_Status(status,
+ errors::FailedPrecondition("Reader is closed."));
+ return;
+ }
Status s = reader_->ReadRecord(&offset_, &record_);
- return s.ok();
+ Set_TF_Status_from_Status(status, s);
}
void PyRecordReader::Close() {
diff --git a/tensorflow/python/lib/io/py_record_reader.h b/tensorflow/python/lib/io/py_record_reader.h
index 0da74ee948..b7ecc928d2 100644
--- a/tensorflow/python/lib/io/py_record_reader.h
+++ b/tensorflow/python/lib/io/py_record_reader.h
@@ -42,10 +42,12 @@ class PyRecordReader {
~PyRecordReader();
- // Attempt to get the next record at "current_offset()". If
- // successful, returns true, and the record contents can be retrieved
- // with "this->record()". Otherwise, returns false.
- bool GetNext();
+ // Attempt to get the next record at "current_offset()". Populates status
+ // with OK on success, OUT_OF_RANGE for end of file, DATA_LOSS for some
+ // kinds of truncated reads, or another code for other errors
+ // (e.g., filesystem errors).
+ void GetNext(TF_Status* status);
+
// Return the current record contents. Only valid after the preceding call
// to GetNext() set an OK status.
string record() const { return record_; }
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 9dc3ac52c2..d02baeb6cd 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -71,7 +71,12 @@ def tf_record_iterator(path, options=None):
if reader is None:
raise IOError("Could not open %s." % path)
- while reader.GetNext():
+ while True:
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ reader.GetNext(status)
+ except errors.OutOfRangeError:
+ break
yield reader.record()
reader.Close()
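
A sketch of the user-visible effect of the reworked loop above, under the assumption of a hypothetical scratch path: `OutOfRangeError` is still caught inside the iterator and simply ends iteration, while other status codes now surface as exceptions instead of silently terminating the loop:

```python
import tensorflow as tf

path = "/tmp/example.tfrecords"  # hypothetical scratch path
with tf.python_io.TFRecordWriter(path) as writer:
    for payload in [b"alpha", b"beta"]:
        writer.write(payload)

# End of file ends the loop; other codes (e.g. DataLossError for a
# truncated file) now propagate to the caller.
try:
    for record in tf.python_io.tf_record_iterator(path):
        print(record)
except tf.errors.DataLossError as e:
    print("corrupt TFRecord file:", e)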
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 5ea35e8e04..1d7827bb98 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2423,6 +2423,6 @@ def where(condition, x=None, y=None, name=None):
if x is None and y is None:
return gen_array_ops.where(input=condition, name=name)
elif x is not None and y is not None:
- return gen_math_ops.select(condition=condition, t=x, e=y, name=name)
+ return gen_math_ops._select(condition=condition, t=x, e=y, name=name)
else:
raise ValueError("x and y must both be non-None or both be None.")
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 8d29de1f89..ce22ffccba 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2236,7 +2236,9 @@ class WhileContext(ControlFlowContext):
if self.outer_context: self.outer_context.Exit()
else:
value = op.inputs[0]
- if self.outer_context:
+ if (isinstance(self.outer_context, WhileContext) and
+ self.outer_context.grad_state is not None):
+ # We are in a nested while loop.
forward_ctxt = self.grad_state.forward_context
forward_ctxt.outer_context.Enter()
zeros_shape = array_ops.shape_internal(value, optimize=False)
@@ -2250,8 +2252,10 @@ class WhileContext(ControlFlowContext):
acc = array_ops.zeros(real_shape, grad.dtype)
self.outer_context.Exit()
else:
+ if self.outer_context: self.outer_context.Enter()
zeros_shape = array_ops.shape_internal(value, optimize=False)
acc = array_ops.zeros(zeros_shape, grad.dtype)
+ if self.outer_context: self.outer_context.Exit()
acc._shape = grad.get_shape() # pylint: disable=protected-access
self.Enter()
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 5afe22e32e..ab30c8cf19 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -184,12 +184,16 @@ BatchIFFT2D
BatchIFFT3D
Complex
Conj
+FloorDiv
+FloorMod
Max
Mean
Min
Pow
Prod
Range
+RealDiv
+Select
SparseMatMul
Sum
MatMul
@@ -201,6 +205,8 @@ InvGrad
ReciprocalGrad
SqrtGrad
RsqrtGrad
+TruncateDiv
+TruncateMod
# nn_ops
AvgPoolGrad # "*Grad" accessible through nn_grad instead of nn_ops.
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index a901106e85..99f992ff5f 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -63,6 +63,7 @@ here](https://www.tensorflow.org/code/tensorflow/core/example/feature.proto).
@@VarLenFeature
@@FixedLenFeature
@@FixedLenSequenceFeature
+@@SparseFeature
@@parse_example
@@parse_single_example
@@parse_tensor
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 3502f11892..5a490b5395 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -211,7 +211,7 @@ def _SegmentMinOrMaxGrad(op, grad):
weighted_grads = math_ops.div(grad, num_selected)
gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
- return math_ops.select(is_selected, gathered_grads, zeros), None
+ return array_ops.where(is_selected, gathered_grads, zeros), None
@ops.RegisterGradient("SegmentMin")
@@ -674,11 +674,11 @@ def _PowGrad(op, grad):
# Avoid false singularity at x = 0
if x.dtype.is_complex:
# real(x) < 0 is fine for the complex case
- log_x = math_ops.select(
+ log_x = array_ops.where(
math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x))
else:
# There's no sensible real value to return if x < 0, so return 0
- log_x = math_ops.select(x > 0, math_ops.log(x), array_ops.zeros_like(x))
+ log_x = array_ops.where(x > 0, math_ops.log(x), array_ops.zeros_like(x))
gy = array_ops.reshape(
math_ops.reduce_sum(grad * z * log_x, ry), sy)
return gx, gy
@@ -695,8 +695,8 @@ def _MaximumMinimumGrad(op, grad, selector_op):
zeros = array_ops.zeros(gradshape, gdtype)
xmask = selector_op(x, y)
rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
- xgrad = math_ops.select(xmask, grad, zeros)
- ygrad = math_ops.select(math_ops.logical_not(xmask), grad, zeros)
+ xgrad = array_ops.where(xmask, grad, zeros)
+ ygrad = array_ops.where(math_ops.logical_not(xmask), grad, zeros)
gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
return (gx, gy)
@@ -750,8 +750,8 @@ def _SelectGrad(op, grad):
c = op.inputs[0]
x = op.inputs[1]
zeros = array_ops.zeros_like(x)
- return (None, math_ops.select(c, grad, zeros),
- math_ops.select(c, zeros, grad))
+ return (None, array_ops.where(c, grad, zeros),
+ array_ops.where(c, zeros, grad))
@ops.RegisterGradient("MatMul")
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9cf092edd7..21bfe205ef 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -235,6 +235,7 @@ from tensorflow.python.ops import state_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.util.deprecation import deprecated
# Aliases for some automatically-generated names.
linspace = gen_math_ops.lin_space
@@ -889,31 +890,7 @@ def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None):
sp_indices, sp_values, sp_shape, y, name=name)
-def truediv(x, y, name=None):
- """Divides x / y elementwise, always producing floating point results.
-
- The same as `tf.div` for floating point arguments, but casts integer arguments
- to floating point before dividing so that the result is always floating point.
- This op is generated by normal `x / y` division in Python 3 and in Python 2.7
- with `from __future__ import division`. If you want integer division that
- rounds down, use `x // y` or `tf.floordiv`.
-
- `x` and `y` must have the same numeric type. If the inputs are floating
- point, the output will have the same type. If the inputs are integral, the
- inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
- and `int64` (matching the behavior of Numpy).
-
- Args:
- x: `Tensor` numerator of numeric type.
- y: `Tensor` denominator of numeric type.
- name: A name for the operation (optional).
-
- Returns:
- `x / y` evaluated in floating point.
-
- Raises:
- TypeError: If `x` and `y` have different dtypes.
- """
+def _truediv_python3(x, y, name=None):
with ops.name_scope(name, "truediv", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y")
@@ -929,11 +906,21 @@ def truediv(x, y, name=None):
if dtype is not None:
x = cast(x, dtype)
y = cast(y, dtype)
- return gen_math_ops.real_div(x, y, name=name)
+ return gen_math_ops._real_div(x, y, name=name)
-def div(x, y, name=None):
- with ops.name_scope(name, "truediv", [x, y]) as name:
+def _div_python2(x, y, name=None):
+ """Divide two values using Python 2 semantics. Used for Tensor.__div__.
+
+ Args:
+ x: `Tensor` numerator of real numeric type.
+ y: `Tensor` denominator of real numeric type.
+ name: A name for the operation (optional).
+ Returns:
+ The quotient of `x` and `y`, computed with Python 2 division semantics.
+ """
+
+ with ops.name_scope(name, "div", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
x_dtype = x.dtype.base_dtype
@@ -942,20 +929,65 @@ def div(x, y, name=None):
raise TypeError("x and y must have the same dtype, got %r != %r" %
(x_dtype, y_dtype))
if x_dtype.is_floating or x_dtype.is_complex:
- return gen_math_ops.real_div(x, y, name=name)
+ return gen_math_ops._real_div(x, y, name=name)
else:
- return gen_math_ops.floor_div(x, y, name=name)
+ return gen_math_ops._floor_div(x, y, name=name)
+
+
+def truediv(x, y, name=None):
+ """Divides x / y elementwise (using Python 3 division operator semantics).
+
+ NOTE: Prefer using the Tensor division operator or tf.divide, which obey
+ Python division operator semantics.
+
+ This function forces Python 3 division operator semantics where all integer
+ arguments are cast to floating types first. This op is generated by normal
+ `x / y` division in Python 3 and in Python 2.7 with
+ `from __future__ import division`. If you want integer division that rounds
+ down, use `x // y` or `tf.floordiv`.
+
+ `x` and `y` must have the same numeric type. If the inputs are floating
+ point, the output will have the same type. If the inputs are integral, the
+ inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
+ and `int64` (matching the behavior of Numpy).
+
+ Args:
+ x: `Tensor` numerator of numeric type.
+ y: `Tensor` denominator of numeric type.
+ name: A name for the operation (optional).
+
+ Returns:
+ `x / y` evaluated in floating point.
+
+ Raises:
+ TypeError: If `x` and `y` have different dtypes.
+ """
+ return _truediv_python3(x, y, name)
-def div_deprecated(x, y, name=None):
- return gen_math_ops.div(x, y, name)
+def div(x, y, name=None):
+ """Divides x / y elementwise (using Python 2 division operator semantics).
+ NOTE: Prefer using the Tensor division operator or tf.divide, which obey
+ Python division operator semantics.
-mod = gen_math_ops.floor_mod
+ This function divides `x` and `y`, forcing Python 2.7 semantics. That is,
+ if one of `x` or `y` is a float, then the result will be a float.
+ Otherwise, the output will be an integer type. Flooring semantics are used
+ for integer division.
+ Args:
+ x: `Tensor` numerator of real numeric type.
+ y: `Tensor` denominator of real numeric type.
+ name: A name for the operation (optional).
+ Returns:
+ The quotient of `x` and `y`, computed with Python 2 division semantics.
+ """
+ return _div_python2(x, y, name)
-def mod_deprecated(x, y, name=None):
- return gen_math_ops.mod(x, y, name)
+
+# TODO(aselle): This should be removed
+mod = gen_math_ops._floor_mod
# TODO(aselle): Deprecate this once all internal functionality uses
@@ -987,29 +1019,15 @@ def floordiv(x, y, name=None):
TypeError: If the inputs are complex.
"""
with ops.name_scope(name, "floordiv", [x, y]) as name:
- return gen_math_ops.floor_div(x, y, name=name)
+ return gen_math_ops._floor_div(x, y, name=name)
-def floordiv_deprecated(x, y, name=None):
- with ops.name_scope(name, "floordiv", [x, y]) as name:
- x = ops.convert_to_tensor(x, name="x")
- dtype = x.dtype
- if dtype.is_floating:
- return gen_math_ops.floor(gen_math_ops.div(x, y), name=name)
- else:
- if not dtype.is_integer:
- raise TypeError("Expected floating point or integer, got %r" % dtype)
- # TODO(aselle): Switch to math_ops.floor_div() when ready
- # return gen_math_ops.floor_div(x, y, name=name)
- return gen_math_ops.div(x, y, name=name)
-
-
-realdiv = gen_math_ops.real_div
-truncatediv = gen_math_ops.truncate_div
+realdiv = gen_math_ops._real_div
+truncatediv = gen_math_ops._truncate_div
# TODO(aselle): Rename this to floordiv when we can.
-floor_div = gen_math_ops.floor_div
-truncatemod = gen_math_ops.truncate_mod
-floormod = gen_math_ops.floor_mod
+floor_div = gen_math_ops._floor_div
+truncatemod = gen_math_ops._truncate_mod
+floormod = gen_math_ops._floor_mod
def _mul_dispatch(x, y, name=None):
@@ -1023,7 +1041,9 @@ def _mul_dispatch(x, y, name=None):
y.shape, x, name)
return sparse_tensor.SparseTensor(y.indices, new_vals, y.shape)
-
+# NOTE(aselle): When integer division is added for sparse_dense_cwise,
+# div, truediv, and floordiv should be delegated appropriately for
+# Python semantics, analogous to dense cwise tensor operations.
_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div",
sparse_tensor.SparseTensor)
_OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv",
@@ -1034,12 +1054,12 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul",
_OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub")
_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
-_OverrideBinaryOperatorHelper(div, "div")
-_OverrideBinaryOperatorHelper(truediv, "truediv")
+_OverrideBinaryOperatorHelper(_div_python2, "div")
+_OverrideBinaryOperatorHelper(_truediv_python3, "truediv")
_OverrideBinaryOperatorHelper(floordiv, "floordiv")
# TODO(aselle): Switch mod to floor_mod when ready
# _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
-_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
+_OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod")
_OverrideBinaryOperatorHelper(pow, "pow")
@@ -2146,3 +2166,13 @@ def reduced_shape(input_shape, axes):
input_shape, # [2, 3, 5, 7]
array_ops.fill(axes_shape, 1)
]) # [1, 1]
+
+
+@deprecated(
+ "2016-12-07",
+ "This op will be removed after the deprecation date. "
+ "Please switch to tf.where().")
+def select(condition, x, y, name=None):
+ return gen_math_ops._select(condition, x, y, name)
+select.__doc__ = gen_math_ops._select.__doc__
+
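A quick sketch of the split codified above: `tf.div` keeps Python 2 semantics (flooring for integer inputs), while `tf.truediv` and the Python 3 `/` operator always produce floating point:

```python
import tensorflow as tf

a = tf.constant(7)
b = tf.constant(2)

with tf.Session() as sess:
    print(sess.run(tf.div(a, b)))      # 3   (integer inputs use _floor_div)
    print(sess.run(tf.truediv(a, b)))  # 3.5 (int32 inputs cast to float64)
    print(sess.run(a / b))             # 3.5 under Python 3 (__truediv__)
    print(sess.run(a // b))            # 3   (__floordiv__)
```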
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 197ddb6a75..b2fc3a84d4 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -309,7 +309,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
def testComplexDiv(self):
foo = array_ops.constant([1.+3.j])
with self.test_session():
- _ = math_ops.div_deprecated(foo, 1.).eval()
+ _ = math_ops.divide(foo, 1.).eval()
_ = math_ops.div(foo, 2.).eval()
def testFloorDivGrad(self):
@@ -318,7 +318,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
b = variables.Variable(4.)
with self.test_session() as sess:
sess.run(variables.initialize_all_variables())
- c_grad = gradients.gradients(math_ops.div_deprecated(a, b), [a, b])
+ c_grad = gradients.gradients(math_ops.divide(a, b), [a, b])
self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
c_grad = gradients.gradients(math_ops.div(a, b), [a, b])
self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
@@ -330,7 +330,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
nums, divs = self.intTestData()
with self.test_session():
tf_result = (
- math_ops.floor_div(nums, divs) * divs + math_ops.floor_mod(nums, divs)
+ math_ops.floor_div(nums, divs) * divs + math_ops.floormod(nums, divs)
).eval()
tf_nums = array_ops.constant(nums)
tf_divs = array_ops.constant(divs)
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 149bde451a..bfa15f9401 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -25,7 +25,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import gen_nn_ops
-from tensorflow.python.ops import gen_math_ops
+
@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
@@ -271,9 +271,10 @@ def _ReluGrad(op, grad):
@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
x = op.inputs[1]
- return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
- gen_math_ops.select(x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + 1),
- array_ops.zeros(shape = array_ops.shape(x), dtype = x.dtype)))
+ return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
+ array_ops.where(
+ x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + 1),
+ array_ops.zeros(shape = array_ops.shape(x), dtype = x.dtype)))
@ops.RegisterGradient("Relu6")
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 4ef95c1146..afacef7acd 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -92,7 +92,7 @@ def log_poisson_loss(log_input, targets, compute_full_loss=False, name=None):
zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
ones = array_ops.ones_like(targets, dtype=targets.dtype)
cond = math_ops.logical_and(targets >= zeros, targets <= ones)
- result += math_ops.select(cond, zeros, stirling_approx)
+ result += array_ops.where(cond, zeros, stirling_approx)
return result
@@ -157,8 +157,8 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
# abs functions.
zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
cond = (logits >= zeros)
- relu_logits = math_ops.select(cond, logits, zeros)
- neg_abs_logits = math_ops.select(cond, -logits, logits)
+ relu_logits = array_ops.where(cond, logits, zeros)
+ neg_abs_logits = array_ops.where(cond, -logits, logits)
return math_ops.add(relu_logits - logits * targets,
math_ops.log1p(math_ops.exp(neg_abs_logits)),
name=name)
@@ -292,7 +292,7 @@ def zero_fraction(value, name=None):
```python
z = tf.Relu(...)
- summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z))
+ summ = tf.contrib.deprecated.scalar_summary('sparsity', tf.nn.zero_fraction(z))
```
Args:
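
The select-to-where swap above sits inside the numerically stable form of the sigmoid cross-entropy loss, `max(x, 0) - x * z + log(1 + exp(-|x|))`. A small NumPy sketch checking that form against the naive formula on inputs where the naive version does not overflow:

```python
import numpy as np

x = np.array([-3.0, 0.5, 4.0])   # logits
z = np.array([0.0, 1.0, 1.0])    # targets

sig = 1.0 / (1.0 + np.exp(-x))
naive = -(z * np.log(sig) + (1 - z) * np.log(1 - sig))
stable = np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))
print(np.allclose(naive, stable))  # True
```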
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 31db4e9d56..35610cc554 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1218,8 +1218,8 @@ def crelu(features, name=None):
"""
with ops.name_scope(name, "CRelu", [features]) as name:
features = ops.convert_to_tensor(features, name="features")
- return gen_nn_ops.relu(array_ops.concat(array_ops.rank(features) - 1,
- [features, -features], name=name))
+ c = array_ops.concat(-1, [features, -features], name=name)
+ return gen_nn_ops.relu(c)
def relu6(features, name=None):
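
Numerically, the fixed `crelu` concatenates along the last axis, so positive and negative parts land in separate channels and the output doubles the input depth; a tiny sketch:

```python
import tensorflow as tf

x = tf.constant([[-1.0, 2.0]])
with tf.Session() as sess:
    print(sess.run(tf.nn.crelu(x)))  # [[0., 2., 1., 0.]]
```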
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 21b957380a..fa99e3a49b 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -22,6 +22,7 @@ import collections
import re
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
@@ -29,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_parsing_ops import *
@@ -49,6 +51,28 @@ class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
pass
+class SparseFeature(
+ collections.namedtuple(
+ "SparseFeature",
+ ["index_key", "value_key", "dtype", "size", "already_sorted"])):
+ """Configuration for parsing a sparse input feature.
+
+ Fields:
+ index_key: Name of index feature. The underlying feature's type must
+ be `int64` and its length must always match that of the `value_key`
+ feature.
+ value_key: Name of value feature. The underlying feature's type must
+ be `dtype` and its length must always match that of the `index_key`
+ feature.
+ dtype: Data type of the `value_key` feature.
+ size: Each value in the `index_key` feature must be in `[0, size)`.
+ already_sorted: A boolean to specify whether the values in `index_key` are
+ already sorted. If so, sorting is skipped; defaults to False (optional).
+ """
+ pass
+SparseFeature.__new__.__defaults__ = (False,)
+
+
class FixedLenFeature(collections.namedtuple(
"FixedLenFeature", ["shape", "dtype", "default_value"])):
"""Configuration for parsing a fixed-length input feature.
@@ -91,7 +115,7 @@ def _features_to_raw_params(features, types):
Args:
features: A `dict` mapping feature keys to objects of a type in `types`.
types: Type of features to allow, among `FixedLenFeature`, `VarLenFeature`,
- and `FixedLenSequenceFeature`.
+ `SparseFeature`, and `FixedLenSequenceFeature`.
Returns:
Tuple of `sparse_keys`, `sparse_types`, `dense_keys`, `dense_types`,
@@ -118,6 +142,34 @@ def _features_to_raw_params(features, types):
raise ValueError("Missing type for feature %s." % key)
sparse_keys.append(key)
sparse_types.append(feature.dtype)
+ elif isinstance(feature, SparseFeature):
+ if SparseFeature not in types:
+ raise ValueError("Unsupported SparseFeature %s.", feature)
+ if not feature.index_key:
+ raise ValueError(
+ "Missing index_key for SparseFeature %s.", feature)
+ if not feature.value_key:
+ raise ValueError(
+ "Missing value_key for SparseFeature %s.", feature)
+ if not feature.dtype:
+ raise ValueError("Missing type for feature %s." % key)
+ if feature.index_key in sparse_keys:
+ dtype = sparse_types[sparse_keys.index(feature.index_key)]
+ if dtype != dtypes.int64:
+ raise ValueError("Conflicting type %s vs int64 for feature %s." % (
+ dtype, feature.index_key))
+ else:
+ sparse_keys.append(feature.index_key)
+ sparse_types.append(dtypes.int64)
+
+ if feature.value_key in sparse_keys:
+ dtype = sparse_types[sparse_keys.index(feature.value_key)]
+ if dtype != feature.dtype:
+ raise ValueError("Conflicting type %s vs %s for feature %s." % (
+ dtype, feature.dtype, feature.value_key))
+ else:
+ sparse_keys.append(feature.value_key)
+ sparse_types.append(feature.dtype)
elif isinstance(feature, FixedLenFeature):
if FixedLenFeature not in types:
raise ValueError("Unsupported FixedLenFeature %s.", feature)
@@ -149,6 +201,38 @@ def _features_to_raw_params(features, types):
dense_shapes)
+def _construct_sparse_tensors_for_sparse_features(features, tensor_dict):
+ """Merges SparseTensors of indices and values of SparseFeatures.
+
+ Updates `tensor_dict`. For each `SparseFeature` in the values of `features`,
+ expects its `index_key` and `value_key` to be present in `tensor_dict`,
+ mapping to `SparseTensor`s. Removes those, constructs a single `SparseTensor`
+ from them, and adds it to `tensor_dict` under the key from `features`.
+
+ Args:
+ features: A `dict` mapping feature keys to `SparseFeature` values.
+ Values of other types will be ignored.
+ tensor_dict: A `dict` mapping feature keys to `Tensor` and `SparseTensor`
+ values. Expected to contain the `SparseFeature`s' `index_key`s and
+ `value_key`s, mapped to `SparseTensor`s.
+ """
+ # Construct SparseTensors for SparseFeatures.
+ for key in sorted(features.keys()):
+ feature = features[key]
+ if isinstance(feature, SparseFeature):
+ sp_ids = tensor_dict[feature.index_key]
+ sp_values = tensor_dict[feature.value_key]
+ tensor_dict[key] = sparse_ops.sparse_merge(
+ sp_ids,
+ sp_values,
+ feature.size,
+ feature.already_sorted)
+ # Remove tensors from dictionary that were only used to construct
+ # SparseTensors for SparseFeature.
+ for key in set(tensor_dict.keys()) - set(features.keys()):
+ del tensor_dict[key]
+
+
def parse_example(serialized, features, name=None, example_names=None):
# pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors.
@@ -158,18 +242,27 @@ def parse_example(serialized, features, name=None, example_names=None):
`example_names` may contain descriptive names for the corresponding serialized
protos. These may be useful for debugging purposes, but they have no effect on
- the output. If not `None`, `example_names` must be the same length as `serialized`.
+ the output. If not `None`, `example_names` must be the same length as
+ `serialized`.
This op parses serialized examples into a dictionary mapping keys to `Tensor`
- and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`
- and `FixedLenFeature` objects. Each `VarLenFeature` is mapped to a
- `SparseTensor`, and each `FixedLenFeature` is mapped to a `Tensor`.
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`.
Each `VarLenFeature` maps to a `SparseTensor` of the specified type
representing a ragged matrix. Its indices are `[batch, index]` where `batch`
is the batch entry the value is from in `serialized`, and `index` is the
value's index in the list of values associated with that feature and example.
+ Each `SparseFeature` maps to a `SparseTensor` of the specified type
+ representing a sparse matrix of shape
+ `(serialized.size(), SparseFeature.size)`. Its indices are `[batch, index]`
+ where `batch` is the batch entry the value is from in `serialized`, and
+ `index` is the value's index, given by the values in the
+ `SparseFeature.index_key` feature column.
+
Each `FixedLenFeature` `df` maps to a `Tensor` of the specified type (or
`tf.float32` if not specified) and shape `(serialized.size(),) + df.shape`.
@@ -281,11 +374,46 @@ def parse_example(serialized, features, name=None, example_names=None):
}
```
+ Given two `Example` input protos in `serialized`:
+
+ ```
+ [
+ features {
+ feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
+ feature { key: "ix" value { int64_list { value: [ 3, 20 ] } } }
+ },
+ features {
+ feature { key: "val" value { float_list { value: [ 0.0 ] } } }
+ feature { key: "ix" value { int64_list { value: [ 42 ] } } }
+ }
+ ]
+ ```
+
+ And arguments
+
+ ```
+ example_names: ["input0", "input1"],
+ features: {
+ "sparse": SparseFeature("ix", "val", tf.float32, 100),
+ }
+ ```
+
+ Then the output is a dictionary:
+
+ ```python
+ {
+ "sparse": SparseTensor(
+ indices=[[0, 3], [0, 20], [1, 42]],
+ values=[0.5, -1.0, 0.0],
+ shape=[2, 100]),
+ }
+ ```
+
Args:
serialized: A vector (1-D Tensor) of strings, a batch of binary
serialized `Example` protos.
- features: A `dict` mapping feature keys to `FixedLenFeature` or
- `VarLenFeature` values.
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
name: A name for this operation (optional).
example_names: A vector (1-D Tensor) of strings (optional), the names of
the serialized protos in the batch.
@@ -300,10 +428,12 @@ def parse_example(serialized, features, name=None, example_names=None):
raise ValueError("Missing: features was %s." % features)
(sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
dense_shapes) = _features_to_raw_params(
- features, [VarLenFeature, FixedLenFeature])
- return _parse_example_raw(
+ features, [VarLenFeature, SparseFeature, FixedLenFeature])
+ outputs = _parse_example_raw(
serialized, example_names, sparse_keys, sparse_types, dense_keys,
dense_types, dense_defaults, dense_shapes, name)
+ _construct_sparse_tensors_for_sparse_features(features, outputs)
+ return outputs
def _parse_example_raw(serialized,
@@ -410,8 +540,7 @@ def _parse_example_raw(serialized,
sparse_tensor.SparseTensor(ix, val, shape) for (ix, val, shape)
in zip(sparse_indices, sparse_values, sparse_shapes)]
- return dict(
- zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
+ return dict(zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
def parse_single_example(serialized, features, name=None, example_names=None):
@@ -447,10 +576,12 @@ def parse_single_example(serialized, features, name=None, example_names=None):
raise ValueError("Missing features.")
(sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
dense_shapes) = _features_to_raw_params(
- features, [VarLenFeature, FixedLenFeature])
- return _parse_single_example_raw(
+ features, [VarLenFeature, FixedLenFeature, SparseFeature])
+ outputs = _parse_single_example_raw(
serialized, example_names, sparse_keys, sparse_types, dense_keys,
dense_types, dense_defaults, dense_shapes, name)
+ _construct_sparse_tensors_for_sparse_features(features, outputs)
+ return outputs
def _parse_single_example_raw(serialized,
@@ -514,15 +645,16 @@ def _parse_single_example_raw(serialized,
name="NamesDependencies")
names = array_ops.expand_dims(names, 0)
- outputs = _parse_example_raw(serialized,
- names=names,
- sparse_keys=sparse_keys,
- sparse_types=sparse_types,
- dense_keys=dense_keys,
- dense_types=dense_types,
- dense_defaults=dense_defaults,
- dense_shapes=dense_shapes,
- name=name)
+ outputs = _parse_example_raw(
+ serialized,
+ names=names,
+ sparse_keys=sparse_keys,
+ sparse_types=sparse_types,
+ dense_keys=dense_keys,
+ dense_types=dense_types,
+ dense_defaults=dense_defaults,
+ dense_shapes=dense_shapes,
+ name=name)
if dense_keys is not None:
for d in dense_keys:
d_name = re.sub("[^A-Za-z0-9_.\\-/]", "_", d)
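
An end-to-end sketch of the new `SparseFeature` path, reusing the 'ix'/'val' layout from the docstring example above; it assumes `SparseFeature` is exported at the top level, as the `io_ops` doc change in this patch suggests, and relies on the `sparse_merge`-based helper to fold the two columns into one `SparseTensor`:

```python
import tensorflow as tf

def make_example(ix, val):
    # Builds one of the 'ix'/'val' protos from the docstring example above.
    return tf.train.Example(features=tf.train.Features(feature={
        'ix': tf.train.Feature(int64_list=tf.train.Int64List(value=ix)),
        'val': tf.train.Feature(float_list=tf.train.FloatList(value=val)),
    })).SerializeToString()

serialized = [make_example([3, 20], [0.5, -1.0]), make_example([42], [0.0])]
features = {'sparse': tf.SparseFeature('ix', 'val', tf.float32, 100)}
parsed = tf.parse_example(serialized, features)

with tf.Session() as sess:
    sp = sess.run(parsed['sparse'])
    print(sp.indices)  # [[0 3] [0 20] [1 42]]
    print(sp.values)   # [ 0.5 -1.   0. ]
```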
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index b1270a1937..61536ab4a0 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -27,13 +27,14 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
# pylint: disable=protected-access
-_state_size_with_prefix = rnn_cell._state_size_with_prefix
+_state_size_with_prefix = rnn_cell_impl._state_size_with_prefix
# pylint: enable=protected-access
@@ -365,7 +366,7 @@ def _rnn_step(
def _copy_one_through(output, new_output):
copy_cond = (time >= sequence_length)
- return math_ops.select(copy_cond, output, new_output)
+ return array_ops.where(copy_cond, output, new_output)
def _copy_some_through(flat_new_output, flat_new_state):
# Use broadcasting select to determine which values should get
@@ -1298,7 +1299,7 @@ def raw_rnn(cell, loop_fn,
current_flat = nest.flatten(current)
candidate_flat = nest.flatten(candidate)
result_flat = [
- math_ops.select(elements_finished, current_i, candidate_i)
+ array_ops.where(elements_finished, current_i, candidate_i)
for (current_i, candidate_i) in zip(current_flat, candidate_flat)]
return nest.pack_sequence_as(
structure=current, flat_sequence=result_flat)
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index d620177e90..b6da265ae0 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -42,854 +42,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-import math
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.rnn_cell_impl import *
+# pylint: enable=wildcard-import
+# TODO(drpng): remove this once internal use has been eradicated.
+# pylint: disable=unused-import
+from tensorflow.python.ops.rnn_cell_impl import _linear
+# pylint: enable=unused-import
+from tensorflow.python.util.all_util import remove_undocumented
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import clip_ops
-from tensorflow.python.ops import embedding_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import partitioned_variables
-from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops.math_ops import sigmoid
-from tensorflow.python.ops.math_ops import tanh
+_allowed_symbols = []
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import nest
-
-
-def _state_size_with_prefix(state_size, prefix=None):
- """Helper function that enables int or TensorShape shape specification.
-
- This function takes a size specification, which can be an integer or a
- TensorShape, and converts it into a list of integers. One may specify any
- additional dimensions that precede the final state size specification.
-
- Args:
- state_size: TensorShape or int that specifies the size of a tensor.
- prefix: optional additional list of dimensions to prepend.
-
- Returns:
- result_state_size: list of dimensions the resulting tensor size.
- """
- result_state_size = tensor_shape.as_shape(state_size).as_list()
- if prefix is not None:
- if not isinstance(prefix, list):
- raise TypeError("prefix of _state_size_with_prefix should be a list.")
- result_state_size = prefix + result_state_size
- return result_state_size
-
-
-class RNNCell(object):
- """Abstract object representing an RNN cell.
-
- The definition of cell in this package differs from the definition used in the
- literature. In the literature, cell refers to an object with a single scalar
- output. The definition in this package refers to a horizontal array of such
- units.
-
- An RNN cell, in the most abstract setting, is anything that has
- a state and performs some operation that takes a matrix of inputs.
- This operation results in an output matrix with `self.output_size` columns.
- If `self.state_size` is an integer, this operation also results in a new
- state matrix with `self.state_size` columns. If `self.state_size` is a
- tuple of integers, then it results in a tuple of `len(state_size)` state
- matrices, each with a column size corresponding to values in `state_size`.
-
- This module provides a number of basic commonly used RNN cells, such as
- LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number
- of operators that allow add dropouts, projections, or embeddings for inputs.
- Constructing multi-layer cells is supported by the class `MultiRNNCell`,
- or by calling the `rnn` ops several times. Every `RNNCell` must have the
- properties below and and implement `__call__` with the following signature.
- """
-
- def __call__(self, inputs, state, scope=None):
- """Run this RNN cell on inputs, starting from the given state.
-
- Args:
- inputs: `2-D` tensor with shape `[batch_size x input_size]`.
- state: if `self.state_size` is an integer, this should be a `2-D Tensor`
- with shape `[batch_size x self.state_size]`. Otherwise, if
- `self.state_size` is a tuple of integers, this should be a tuple
- with shapes `[batch_size x s] for s in self.state_size`.
- scope: VariableScope for the created subgraph; defaults to class name.
-
- Returns:
- A pair containing:
-
- - Output: A `2-D` tensor with shape `[batch_size x self.output_size]`.
- - New state: Either a single `2-D` tensor, or a tuple of tensors matching
- the arity and shapes of `state`.
- """
- raise NotImplementedError("Abstract method")
-
- @property
- def state_size(self):
- """size(s) of state(s) used by this cell.
-
- It can be represented by an Integer, a TensorShape or a tuple of Integers
- or TensorShapes.
- """
- raise NotImplementedError("Abstract method")
-
- @property
- def output_size(self):
- """Integer or TensorShape: size of outputs produced by this cell."""
- raise NotImplementedError("Abstract method")
-
- def zero_state(self, batch_size, dtype):
- """Return zero-filled state tensor(s).
-
- Args:
- batch_size: int, float, or unit Tensor representing the batch size.
- dtype: the data type to use for the state.
-
- Returns:
- If `state_size` is an int or TensorShape, then the return value is a
- `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
-
- If `state_size` is a nested list or tuple, then the return value is
- a nested list or tuple (of the same structure) of `2-D` tensors with
- the shapes `[batch_size x s]` for each s in `state_size`.
- """
- state_size = self.state_size
- if nest.is_sequence(state_size):
- state_size_flat = nest.flatten(state_size)
- zeros_flat = [
- array_ops.zeros(
- array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
- dtype=dtype)
- for s in state_size_flat]
- for s, z in zip(state_size_flat, zeros_flat):
- z.set_shape(_state_size_with_prefix(s, prefix=[None]))
- zeros = nest.pack_sequence_as(structure=state_size,
- flat_sequence=zeros_flat)
- else:
- zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
- zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
- zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
-
- return zeros
-
-
-class BasicRNNCell(RNNCell):
- """The most basic RNN cell."""
-
- def __init__(self, num_units, input_size=None, activation=tanh):
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- self._num_units = num_units
- self._activation = activation
-
- @property
- def state_size(self):
- return self._num_units
-
- @property
- def output_size(self):
- return self._num_units
-
- def __call__(self, inputs, state, scope=None):
- """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
- with vs.variable_scope(scope or "basic_rnn_cell"):
- output = self._activation(
- _linear([inputs, state], self._num_units, True, scope=scope))
- return output, output
-
-
-class GRUCell(RNNCell):
- """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
-
- def __init__(self, num_units, input_size=None, activation=tanh):
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- self._num_units = num_units
- self._activation = activation
-
- @property
- def state_size(self):
- return self._num_units
-
- @property
- def output_size(self):
- return self._num_units
-
- def __call__(self, inputs, state, scope=None):
- """Gated recurrent unit (GRU) with nunits cells."""
- with vs.variable_scope(scope or "gru_cell"):
- with vs.variable_scope("gates"): # Reset gate and update gate.
- # We start with bias of 1.0 to not reset and not update.
- r, u = array_ops.split(
- 1, 2, _linear([inputs, state], 2 * self._num_units, True, 1.0,
- scope=scope))
- r, u = sigmoid(r), sigmoid(u)
- with vs.variable_scope("candidate"):
- c = self._activation(_linear([inputs, r * state],
- self._num_units, True,
- scope=scope))
- new_h = u * state + (1 - u) * c
- return new_h, new_h
-
-
-_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
-
-
-class LSTMStateTuple(_LSTMStateTuple):
- """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
-
- Stores two elements: `(c, h)`, in that order.
-
- Only used when `state_is_tuple=True`.
- """
- __slots__ = ()
-
- @property
- def dtype(self):
- (c, h) = self
- if not c.dtype == h.dtype:
- raise TypeError("Inconsistent internal state: %s vs %s" %
- (str(c.dtype), str(h.dtype)))
- return c.dtype
-
-
-class BasicLSTMCell(RNNCell):
- """Basic LSTM recurrent network cell.
-
- The implementation is based on: http://arxiv.org/abs/1409.2329.
-
- We add forget_bias (default: 1) to the biases of the forget gate in order to
- reduce the scale of forgetting in the beginning of the training.
-
- It does not allow cell clipping, a projection layer, and does not
- use peep-hole connections: it is the basic baseline.
-
- For advanced models, please use the full LSTMCell that follows.
- """
-
- def __init__(self, num_units, forget_bias=1.0, input_size=None,
- state_is_tuple=True, activation=tanh):
- """Initialize the basic LSTM cell.
-
- Args:
- num_units: int, The number of units in the LSTM cell.
- forget_bias: float, The bias added to forget gates (see above).
- input_size: Deprecated and unused.
- state_is_tuple: If True, accepted and returned states are 2-tuples of
- the `c_state` and `m_state`. If False, they are concatenated
- along the column axis. The latter behavior will soon be deprecated.
- activation: Activation function of the inner states.
- """
- if not state_is_tuple:
- logging.warn("%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- self._num_units = num_units
- self._forget_bias = forget_bias
- self._state_is_tuple = state_is_tuple
- self._activation = activation
-
- @property
- def state_size(self):
- return (LSTMStateTuple(self._num_units, self._num_units)
- if self._state_is_tuple else 2 * self._num_units)
-
- @property
- def output_size(self):
- return self._num_units
-
- def __call__(self, inputs, state, scope=None):
- """Long short-term memory cell (LSTM)."""
- with vs.variable_scope(scope or "basic_lstm_cell"):
- # Parameters of gates are concatenated into one multiply for efficiency.
- if self._state_is_tuple:
- c, h = state
- else:
- c, h = array_ops.split(1, 2, state)
- concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
-
- # i = input_gate, j = new_input, f = forget_gate, o = output_gate
- i, j, f, o = array_ops.split(1, 4, concat)
-
- new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
- self._activation(j))
- new_h = self._activation(new_c) * sigmoid(o)
-
- if self._state_is_tuple:
- new_state = LSTMStateTuple(new_c, new_h)
- else:
- new_state = array_ops.concat(1, [new_c, new_h])
- return new_h, new_state
-
-
-class LSTMCell(RNNCell):
- """Long short-term memory unit (LSTM) recurrent network cell.
-
- The default non-peephole implementation is based on:
-
- http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
-
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
-
- The peephole implementation is based on:
-
- https://research.google.com/pubs/archive/43905.pdf
-
- Hasim Sak, Andrew Senior, and Francoise Beaufays.
- "Long short-term memory recurrent neural network architectures for
- large scale acoustic modeling." INTERSPEECH, 2014.
-
- The class uses optional peep-hole connections, optional cell clipping, and
- an optional projection layer.
- """
-
- def __init__(self, num_units, input_size=None,
- use_peepholes=False, cell_clip=None,
- initializer=None, num_proj=None, proj_clip=None,
- num_unit_shards=None, num_proj_shards=None,
- forget_bias=1.0, state_is_tuple=True,
- activation=tanh):
- """Initialize the parameters for an LSTM cell.
-
- Args:
- num_units: int, The number of units in the LSTM cell
- input_size: Deprecated and unused.
- use_peepholes: bool, set True to enable diagonal/peephole connections.
- cell_clip: (optional) A float value, if provided the cell state is clipped
- by this value prior to the cell output activation.
- initializer: (optional) The initializer to use for the weight and
- projection matrices.
- num_proj: (optional) int, The output dimensionality for the projection
- matrices. If None, no projection is performed.
- proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
- provided, then the projected values are clipped elementwise to within
- `[-proj_clip, proj_clip]`.
- num_unit_shards: Deprecated, will be removed by Jan. 2017.
- Use a variable_scope partitioner instead.
- num_proj_shards: Deprecated, will be removed by Jan. 2017.
- Use a variable_scope partitioner instead.
- forget_bias: Biases of the forget gate are initialized by default to 1
- in order to reduce the scale of forgetting at the beginning of
- the training.
- state_is_tuple: If True, accepted and returned states are 2-tuples of
- the `c_state` and `m_state`. If False, they are concatenated
- along the column axis. This latter behavior will soon be deprecated.
- activation: Activation function of the inner states.
- """
- if not state_is_tuple:
- logging.warn("%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- if num_unit_shards is not None or num_proj_shards is not None:
- logging.warn(
- "%s: The num_unit_shards and proj_unit_shards parameters are "
- "deprecated and will be removed in Jan 2017. "
- "Use a variable scope with a partitioner instead.", self)
-
- self._num_units = num_units
- self._use_peepholes = use_peepholes
- self._cell_clip = cell_clip
- self._initializer = initializer
- self._num_proj = num_proj
- self._proj_clip = proj_clip
- self._num_unit_shards = num_unit_shards
- self._num_proj_shards = num_proj_shards
- self._forget_bias = forget_bias
- self._state_is_tuple = state_is_tuple
- self._activation = activation
-
- if num_proj:
- self._state_size = (
- LSTMStateTuple(num_units, num_proj)
- if state_is_tuple else num_units + num_proj)
- self._output_size = num_proj
- else:
- self._state_size = (
- LSTMStateTuple(num_units, num_units)
- if state_is_tuple else 2 * num_units)
- self._output_size = num_units
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def __call__(self, inputs, state, scope=None):
- """Run one step of LSTM.
-
- Args:
- inputs: input Tensor, 2D, batch x num_units.
- state: if `state_is_tuple` is False, this must be a state Tensor,
- `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
- tuple of state Tensors, both `2-D`, with column sizes `c_state` and
- `m_state`.
- scope: VariableScope for the created subgraph; defaults to "lstm_cell".
-
- Returns:
- A tuple containing:
-
- - A `2-D, [batch x output_dim]`, Tensor representing the output of the
- LSTM after reading `inputs` when previous state was `state`.
- Here output_dim is:
- num_proj if num_proj was set,
- num_units otherwise.
- - Tensor(s) representing the new state of LSTM after reading `inputs` when
- the previous state was `state`. Same type and shape(s) as `state`.
-
- Raises:
- ValueError: If input size cannot be inferred from inputs via
- static shape inference.
- """
- num_proj = self._num_units if self._num_proj is None else self._num_proj
-
- if self._state_is_tuple:
- (c_prev, m_prev) = state
- else:
- c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
- m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
-
- dtype = inputs.dtype
- input_size = inputs.get_shape().with_rank(2)[1]
- if input_size.value is None:
- raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(scope or "lstm_cell",
- initializer=self._initializer) as unit_scope:
- if self._num_unit_shards is not None:
- unit_scope.set_partitioner(
- partitioned_variables.fixed_size_partitioner(
- self._num_unit_shards))
- # i = input_gate, j = new_input, f = forget_gate, o = output_gate
- lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
- scope=scope)
- i, j, f, o = array_ops.split(1, 4, lstm_matrix)
-
- # Diagonal connections
- if self._use_peepholes:
- with vs.variable_scope(unit_scope) as projection_scope:
- if self._num_unit_shards is not None:
- projection_scope.set_partitioner(None)
- w_f_diag = vs.get_variable(
- "w_f_diag", shape=[self._num_units], dtype=dtype)
- w_i_diag = vs.get_variable(
- "w_i_diag", shape=[self._num_units], dtype=dtype)
- w_o_diag = vs.get_variable(
- "w_o_diag", shape=[self._num_units], dtype=dtype)
-
- if self._use_peepholes:
- c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
- sigmoid(i + w_i_diag * c_prev) * self._activation(j))
- else:
- c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
- self._activation(j))
-
- if self._cell_clip is not None:
- # pylint: disable=invalid-unary-operand-type
- c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
- # pylint: enable=invalid-unary-operand-type
-
- if self._use_peepholes:
- m = sigmoid(o + w_o_diag * c) * self._activation(c)
- else:
- m = sigmoid(o) * self._activation(c)
-
- if self._num_proj is not None:
- with vs.variable_scope("projection") as proj_scope:
- if self._num_proj_shards is not None:
- proj_scope.set_partitioner(
- partitioned_variables.fixed_size_partitioner(
- self._num_proj_shards))
- m = _linear(m, self._num_proj, bias=False, scope=scope)
-
- if self._proj_clip is not None:
- # pylint: disable=invalid-unary-operand-type
- m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
- # pylint: enable=invalid-unary-operand-type
-
- new_state = (LSTMStateTuple(c, m) if self._state_is_tuple
- else array_ops.concat(1, [c, m]))
- return m, new_state
-
-
-class OutputProjectionWrapper(RNNCell):
- """Operator adding an output projection to the given cell.
-
- Note: in many cases it may be more efficient to not use this wrapper,
- but instead concatenate the whole sequence of your outputs in time,
- do the projection on this batch-concatenated sequence, then split it
- if needed or directly feed into a softmax.
- """
-
- def __init__(self, cell, output_size):
- """Create a cell with output projection.
-
- Args:
- cell: an RNNCell, a projection to output_size is added to it.
- output_size: integer, the size of the output after projection.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- ValueError: if output_size is not positive.
- """
- if not isinstance(cell, RNNCell):
- raise TypeError("The parameter cell is not RNNCell.")
- if output_size < 1:
- raise ValueError("Parameter output_size must be > 0: %d." % output_size)
- self._cell = cell
- self._output_size = output_size
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def __call__(self, inputs, state, scope=None):
- """Run the cell and output projection on inputs, starting from state."""
- output, res_state = self._cell(inputs, state)
- # Default scope: "OutputProjectionWrapper"
- with vs.variable_scope(scope or "output_projection_wrapper"):
- projected = _linear(output, self._output_size, True, scope=scope)
- return projected, res_state
-
-
-class InputProjectionWrapper(RNNCell):
- """Operator adding an input projection to the given cell.
-
- Note: in many cases it may be more efficient to not use this wrapper,
- but instead concatenate the whole sequence of your inputs in time,
- do the projection on this batch-concatenated sequence, then split it.
- """
-
- def __init__(self, cell, num_proj, input_size=None):
- """Create a cell with input projection.
-
- Args:
- cell: an RNNCell, a projection of inputs is added before it.
- num_proj: Python integer. The dimension to project to.
- input_size: Deprecated and unused.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- """
- if input_size is not None:
- logging.warn("%s: The input_size parameter is deprecated.", self)
- if not isinstance(cell, RNNCell):
- raise TypeError("The parameter cell is not RNNCell.")
- self._cell = cell
- self._num_proj = num_proj
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def __call__(self, inputs, state, scope=None):
- """Run the input projection and then the cell."""
- # Default scope: "InputProjectionWrapper"
- with vs.variable_scope(scope or "input_projection_wrapper"):
- projected = _linear(inputs, self._num_proj, True, scope=scope)
- return self._cell(projected, state)
-
-
-class DropoutWrapper(RNNCell):
- """Operator adding dropout to inputs and outputs of the given cell."""
-
- def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
- seed=None):
- """Create a cell with added input and/or output dropout.
-
- Dropout is never used on the state.
-
- Args:
- cell: an RNNCell, a projection to output_size is added to it.
- input_keep_prob: unit Tensor or float between 0 and 1, input keep
- probability; if it is float and 1, no input dropout will be added.
- output_keep_prob: unit Tensor or float between 0 and 1, output keep
- probability; if it is float and 1, no output dropout will be added.
- seed: (optional) integer, the randomness seed.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- ValueError: if keep_prob is not between 0 and 1.
- """
- if not isinstance(cell, RNNCell):
- raise TypeError("The parameter cell is not a RNNCell.")
- if (isinstance(input_keep_prob, float) and
- not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)):
- raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
- % input_keep_prob)
- if (isinstance(output_keep_prob, float) and
- not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)):
- raise ValueError("Parameter output_keep_prob must be between 0 and 1: %d"
- % output_keep_prob)
- self._cell = cell
- self._input_keep_prob = input_keep_prob
- self._output_keep_prob = output_keep_prob
- self._seed = seed
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def __call__(self, inputs, state, scope=None):
- """Run the cell with the declared dropouts."""
- if (not isinstance(self._input_keep_prob, float) or
- self._input_keep_prob < 1):
- inputs = nn_ops.dropout(inputs, self._input_keep_prob, seed=self._seed)
- output, new_state = self._cell(inputs, state, scope)
- if (not isinstance(self._output_keep_prob, float) or
- self._output_keep_prob < 1):
- output = nn_ops.dropout(output, self._output_keep_prob, seed=self._seed)
- return output, new_state
-
-
-class EmbeddingWrapper(RNNCell):
- """Operator adding input embedding to the given cell.
-
- Note: in many cases it may be more efficient to not use this wrapper,
- but instead concatenate the whole sequence of your inputs in time,
- do the embedding on this batch-concatenated sequence, then split it and
- feed into your RNN.
- """
-
- def __init__(self, cell, embedding_classes, embedding_size, initializer=None):
- """Create a cell with an added input embedding.
-
- Args:
- cell: an RNNCell, an embedding will be put before its inputs.
- embedding_classes: integer, how many symbols will be embedded.
- embedding_size: integer, the size of the vectors we embed into.
- initializer: an initializer to use when creating the embedding;
- if None, the initializer from variable scope or a default one is used.
-
- Raises:
- TypeError: if cell is not an RNNCell.
- ValueError: if embedding_classes is not positive.
- """
- if not isinstance(cell, RNNCell):
- raise TypeError("The parameter cell is not RNNCell.")
- if embedding_classes <= 0 or embedding_size <= 0:
- raise ValueError("Both embedding_classes and embedding_size must be > 0: "
- "%d, %d." % (embedding_classes, embedding_size))
- self._cell = cell
- self._embedding_classes = embedding_classes
- self._embedding_size = embedding_size
- self._initializer = initializer
-
- @property
- def state_size(self):
- return self._cell.state_size
-
- @property
- def output_size(self):
- return self._cell.output_size
-
- def __call__(self, inputs, state, scope=None):
- """Run the cell on embedded inputs."""
- with vs.variable_scope(scope or "embedding_wrapper"): # "EmbeddingWrapper"
- with ops.device("/cpu:0"):
- if self._initializer:
- initializer = self._initializer
- elif vs.get_variable_scope().initializer:
- initializer = vs.get_variable_scope().initializer
- else:
- # Default initializer for embeddings should have variance=1.
- sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
- initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
-
- if type(state) is tuple:
- data_type = state[0].dtype
- else:
- data_type = state.dtype
-
- embedding = vs.get_variable(
- "embedding", [self._embedding_classes, self._embedding_size],
- initializer=initializer,
- dtype=data_type)
- embedded = embedding_ops.embedding_lookup(
- embedding, array_ops.reshape(inputs, [-1]))
- return self._cell(embedded, state)
-
-
-class MultiRNNCell(RNNCell):
- """RNN cell composed sequentially of multiple simple cells."""
-
- def __init__(self, cells, state_is_tuple=True):
- """Create a RNN cell composed sequentially of a number of RNNCells.
-
- Args:
- cells: list of RNNCells that will be composed in this order.
- state_is_tuple: If True, accepted and returned states are n-tuples, where
- `n = len(cells)`. If False, the states are all
- concatenated along the column axis. This latter behavior will soon be
- deprecated.
-
- Raises:
- ValueError: if cells is empty (not allowed), or at least one of the cells
- returns a state tuple but the flag `state_is_tuple` is `False`.
- """
- if not cells:
- raise ValueError("Must specify at least one cell for MultiRNNCell.")
- self._cells = cells
- self._state_is_tuple = state_is_tuple
- if not state_is_tuple:
- if any(nest.is_sequence(c.state_size) for c in self._cells):
- raise ValueError("Some cells return tuples of states, but the flag "
- "state_is_tuple is not set. State sizes are: %s"
- % str([c.state_size for c in self._cells]))
-
- @property
- def state_size(self):
- if self._state_is_tuple:
- return tuple(cell.state_size for cell in self._cells)
- else:
- return sum([cell.state_size for cell in self._cells])
-
- @property
- def output_size(self):
- return self._cells[-1].output_size
-
- def __call__(self, inputs, state, scope=None):
- """Run this multi-layer cell on inputs, starting from state."""
- with vs.variable_scope(scope or "multi_rnn_cell"):
- cur_state_pos = 0
- cur_inp = inputs
- new_states = []
- for i, cell in enumerate(self._cells):
- with vs.variable_scope("cell_%d" % i):
- if self._state_is_tuple:
- if not nest.is_sequence(state):
- raise ValueError(
- "Expected state to be a tuple of length %d, but received: %s"
- % (len(self.state_size), state))
- cur_state = state[i]
- else:
- cur_state = array_ops.slice(
- state, [0, cur_state_pos], [-1, cell.state_size])
- cur_state_pos += cell.state_size
- cur_inp, new_state = cell(cur_inp, cur_state)
- new_states.append(new_state)
- new_states = (tuple(new_states) if self._state_is_tuple
- else array_ops.concat(1, new_states))
- return cur_inp, new_states
-
-
-class _SlimRNNCell(RNNCell):
- """A simple wrapper for slim.rnn_cells."""
-
- def __init__(self, cell_fn):
- """Create a SlimRNNCell from a cell_fn.
-
- Args:
- cell_fn: a function which takes (inputs, state, scope) and produces the
- outputs and the new_state. Additionally when called with inputs=None and
- state=None it should return (initial_outputs, initial_state).
-
- Raises:
- TypeError: if cell_fn is not callable
- ValueError: if cell_fn cannot produce a valid initial state.
- """
- if not callable(cell_fn):
- raise TypeError("cell_fn %s needs to be callable", cell_fn)
- self._cell_fn = cell_fn
- self._cell_name = cell_fn.func.__name__
- init_output, init_state = self._cell_fn(None, None)
- output_shape = init_output.get_shape()
- state_shape = init_state.get_shape()
- self._output_size = output_shape.with_rank(2)[1].value
- self._state_size = state_shape.with_rank(2)[1].value
- if self._output_size is None:
- raise ValueError("Initial output created by %s has invalid shape %s" %
- (self._cell_name, output_shape))
- if self._state_size is None:
- raise ValueError("Initial state created by %s has invalid shape %s" %
- (self._cell_name, state_shape))
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def __call__(self, inputs, state, scope=None):
- scope = scope or self._cell_name
- output, state = self._cell_fn(inputs, state, scope=scope)
- return output, state
-
-
-def _linear(args, output_size, bias, bias_start=0.0, scope=None):
- """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
-
- Args:
- args: a 2D Tensor or a list of 2D, batch x n, Tensors.
- output_size: int, second dimension of W[i].
- bias: boolean, whether to add a bias term or not.
- bias_start: starting value to initialize the bias; 0 by default.
- scope: (optional) Variable scope to create parameters in.
-
- Returns:
- A 2D Tensor with shape [batch x output_size] equal to
- sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
-
- Raises:
- ValueError: if some of the arguments has unspecified or wrong shape.
- """
- if args is None or (nest.is_sequence(args) and not args):
- raise ValueError("`args` must be specified")
- if not nest.is_sequence(args):
- args = [args]
-
- # Calculate the total size of arguments on dimension 1.
- total_arg_size = 0
- shapes = [a.get_shape() for a in args]
- for shape in shapes:
- if shape.ndims != 2:
- raise ValueError("linear is expecting 2D arguments: %s" % shapes)
- if shape[1].value is None:
- raise ValueError("linear expects shape[1] to be provided for shape %s, "
- "but saw %d" % (shape, shape[1]))
- else:
- total_arg_size += shape[1].value
-
- dtype = [a.dtype for a in args][0]
-
- # Now the computation.
- scope = vs.get_variable_scope()
- with vs.variable_scope(scope) as outer_scope:
- weights = vs.get_variable(
- "weights", [total_arg_size, output_size], dtype=dtype)
- if len(args) == 1:
- res = math_ops.matmul(args[0], weights)
- else:
- res = math_ops.matmul(array_ops.concat(1, args), weights)
- if not bias:
- return res
- with vs.variable_scope(outer_scope) as inner_scope:
- inner_scope.set_partitioner(None)
- biases = vs.get_variable(
- "biases", [output_size],
- dtype=dtype,
- initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
- return nn_ops.bias_add(res, biases)
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
new file mode 100644
index 0000000000..81d510de28
--- /dev/null
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -0,0 +1,872 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Module implementing RNN Cells."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import math
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope as vs
+
+from tensorflow.python.ops.math_ops import sigmoid
+from tensorflow.python.ops.math_ops import tanh
+
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
+
+
+def _state_size_with_prefix(state_size, prefix=None):
+ """Helper function that enables int or TensorShape shape specification.
+
+ This function takes a size specification, which can be an integer or a
+ TensorShape, and converts it into a list of integers. One may specify any
+ additional dimensions that precede the final state size specification.
+
+ Args:
+ state_size: TensorShape or int that specifies the size of a tensor.
+ prefix: optional additional list of dimensions to prepend.
+
+ Returns:
+ result_state_size: list of dimensions of the resulting tensor size.
+ """
+ result_state_size = tensor_shape.as_shape(state_size).as_list()
+ if prefix is not None:
+ if not isinstance(prefix, list):
+ raise TypeError("prefix of _state_size_with_prefix should be a list.")
+ result_state_size = prefix + result_state_size
+ return result_state_size
+
+
+class RNNCell(object):
+ """Abstract object representing an RNN cell.
+
+ The definition of cell in this package differs from the definition used in the
+ literature. In the literature, cell refers to an object with a single scalar
+ output. The definition in this package refers to a horizontal array of such
+ units.
+
+ An RNN cell, in the most abstract setting, is anything that has
+ a state and performs some operation that takes a matrix of inputs.
+ This operation results in an output matrix with `self.output_size` columns.
+ If `self.state_size` is an integer, this operation also results in a new
+ state matrix with `self.state_size` columns. If `self.state_size` is a
+ tuple of integers, then it results in a tuple of `len(state_size)` state
+ matrices, each with a column size corresponding to values in `state_size`.
+
+ This module provides a number of basic commonly used RNN cells, such as
+ LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number
+ of operators that allow adding dropouts, projections, or embeddings for inputs.
+ Constructing multi-layer cells is supported by the class `MultiRNNCell`,
+ or by calling the `rnn` ops several times. Every `RNNCell` must have the
+ properties below and and implement `__call__` with the following signature.
+ """
+
+ def __call__(self, inputs, state, scope=None):
+ """Run this RNN cell on inputs, starting from the given state.
+
+ Args:
+ inputs: `2-D` tensor with shape `[batch_size x input_size]`.
+ state: if `self.state_size` is an integer, this should be a `2-D Tensor`
+ with shape `[batch_size x self.state_size]`. Otherwise, if
+ `self.state_size` is a tuple of integers, this should be a tuple
+ with shapes `[batch_size x s] for s in self.state_size`.
+ scope: VariableScope for the created subgraph; defaults to class name.
+
+ Returns:
+ A pair containing:
+
+ - Output: A `2-D` tensor with shape `[batch_size x self.output_size]`.
+ - New state: Either a single `2-D` tensor, or a tuple of tensors matching
+ the arity and shapes of `state`.
+ """
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def state_size(self):
+ """size(s) of state(s) used by this cell.
+
+ It can be represented by an Integer, a TensorShape or a tuple of Integers
+ or TensorShapes.
+ """
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def output_size(self):
+ """Integer or TensorShape: size of outputs produced by this cell."""
+ raise NotImplementedError("Abstract method")
+
+ def zero_state(self, batch_size, dtype):
+ """Return zero-filled state tensor(s).
+
+ Args:
+ batch_size: int, float, or unit Tensor representing the batch size.
+ dtype: the data type to use for the state.
+
+ Returns:
+ If `state_size` is an int or TensorShape, then the return value is a
+ `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+ If `state_size` is a nested list or tuple, then the return value is
+ a nested list or tuple (of the same structure) of `2-D` tensors with
+ the shapes `[batch_size x s]` for each s in `state_size`.
+ """
+ state_size = self.state_size
+ if nest.is_sequence(state_size):
+ state_size_flat = nest.flatten(state_size)
+ zeros_flat = [
+ array_ops.zeros(
+ array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
+ dtype=dtype)
+ for s in state_size_flat]
+ for s, z in zip(state_size_flat, zeros_flat):
+ z.set_shape(_state_size_with_prefix(s, prefix=[None]))
+ zeros = nest.pack_sequence_as(structure=state_size,
+ flat_sequence=zeros_flat)
+ else:
+ zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
+ zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
+ zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
+
+ return zeros
+
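Taken together, `state_size`, `output_size`, and `zero_state` define the whole calling contract, so a single recurrence step looks the same for every cell. A minimal sketch, assuming the cells are exported under `tf.nn.rnn_cell` as in contemporary releases:

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
inputs = tf.placeholder(tf.float32, [32, 100])  # [batch_size, input_size]
state = cell.zero_state(32, tf.float32)         # [32, 128] tensor of zeros
output, new_state = cell(inputs, state)         # one recurrence step
```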
+
+class BasicRNNCell(RNNCell):
+ """The most basic RNN cell."""
+
+ def __init__(self, num_units, input_size=None, activation=tanh):
+ if input_size is not None:
+ logging.warn("%s: The input_size parameter is deprecated.", self)
+ self._num_units = num_units
+ self._activation = activation
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
+ with vs.variable_scope(scope or "basic_rnn_cell"):
+ output = self._activation(
+ _linear([inputs, state], self._num_units, True, scope=scope))
+ return output, output
+
+
+class GRUCell(RNNCell):
+ """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
+
+ def __init__(self, num_units, input_size=None, activation=tanh):
+ if input_size is not None:
+ logging.warn("%s: The input_size parameter is deprecated.", self)
+ self._num_units = num_units
+ self._activation = activation
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Gated recurrent unit (GRU) with nunits cells."""
+ with vs.variable_scope(scope or "gru_cell"):
+ with vs.variable_scope("gates"): # Reset gate and update gate.
+ # We start with bias of 1.0 to not reset and not update.
+ r, u = array_ops.split(
+ 1, 2, _linear([inputs, state], 2 * self._num_units, True, 1.0,
+ scope=scope))
+ r, u = sigmoid(r), sigmoid(u)
+ with vs.variable_scope("candidate"):
+ c = self._activation(_linear([inputs, r * state],
+ self._num_units, True,
+ scope=scope))
+ new_h = u * state + (1 - u) * c
+ return new_h, new_h
+
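Because each call creates (or reuses) variables under `gru_cell/gates` and `gru_cell/candidate`, manually unrolling over time requires reusing the enclosing scope after the first step. A sketch under the same `tf.nn.rnn_cell` assumption:

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.GRUCell(num_units=64)
xs = [tf.placeholder(tf.float32, [16, 32]) for _ in range(5)]
state = cell.zero_state(16, tf.float32)
outputs = []
with tf.variable_scope("rnn"):
    for t, x in enumerate(xs):
        if t > 0:
            tf.get_variable_scope().reuse_variables()  # share gate weights
        output, state = cell(x, state)
        outputs.append(output)
```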
+
+_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
+
+
+class LSTMStateTuple(_LSTMStateTuple):
+ """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
+
+ Stores two elements: `(c, h)`, in that order.
+
+ Only used when `state_is_tuple=True`.
+ """
+ __slots__ = ()
+
+ @property
+ def dtype(self):
+ (c, h) = self
+ if not c.dtype == h.dtype:
+ raise TypeError("Inconsistent internal state: %s vs %s" %
+ (str(c.dtype), str(h.dtype)))
+ return c.dtype
+
+
+class BasicLSTMCell(RNNCell):
+ """Basic LSTM recurrent network cell.
+
+ The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+ reduce the scale of forgetting in the beginning of the training.
+
+ It does not allow cell clipping or a projection layer, and it does not
+ use peephole connections: it is the basic baseline.
+
+ For advanced models, please use the full LSTMCell that follows.
+ """
+
+ def __init__(self, num_units, forget_bias=1.0, input_size=None,
+ state_is_tuple=True, activation=tanh):
+ """Initialize the basic LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell.
+ forget_bias: float, The bias added to forget gates (see above).
+ input_size: Deprecated and unused.
+ state_is_tuple: If True, accepted and returned states are 2-tuples of
+ the `c_state` and `m_state`. If False, they are concatenated
+ along the column axis. The latter behavior will soon be deprecated.
+ activation: Activation function of the inner states.
+ """
+ if not state_is_tuple:
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
+ if input_size is not None:
+ logging.warn("%s: The input_size parameter is deprecated.", self)
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._state_is_tuple = state_is_tuple
+ self._activation = activation
+
+ @property
+ def state_size(self):
+ return (LSTMStateTuple(self._num_units, self._num_units)
+ if self._state_is_tuple else 2 * self._num_units)
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Long short-term memory cell (LSTM)."""
+ with vs.variable_scope(scope or "basic_lstm_cell"):
+ # Parameters of gates are concatenated into one multiply for efficiency.
+ if self._state_is_tuple:
+ c, h = state
+ else:
+ c, h = array_ops.split(1, 2, state)
+ concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = array_ops.split(1, 4, concat)
+
+ new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
+ self._activation(j))
+ new_h = self._activation(new_c) * sigmoid(o)
+
+ if self._state_is_tuple:
+ new_state = LSTMStateTuple(new_c, new_h)
+ else:
+ new_state = array_ops.concat(1, [new_c, new_h])
+ return new_h, new_state
+
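With `state_is_tuple=True` the state is an `LSTMStateTuple`, so the cell and hidden states can be unpacked directly. A minimal sketch, again assuming the `tf.nn.rnn_cell` export:

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128, state_is_tuple=True)
x = tf.placeholder(tf.float32, [8, 50])
state = cell.zero_state(8, tf.float32)  # LSTMStateTuple(c=[8,128], h=[8,128])
output, new_state = cell(x, state)
# output is new_state.h; new_state.c is the updated internal cell state.
```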
+
+class LSTMCell(RNNCell):
+ """Long short-term memory unit (LSTM) recurrent network cell.
+
+ The default non-peephole implementation is based on:
+
+ http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+
+ S. Hochreiter and J. Schmidhuber.
+ "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ The peephole implementation is based on:
+
+ https://research.google.com/pubs/archive/43905.pdf
+
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
+ "Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+ The class uses optional peep-hole connections, optional cell clipping, and
+ an optional projection layer.
+ """
+
+ def __init__(self, num_units, input_size=None,
+ use_peepholes=False, cell_clip=None,
+ initializer=None, num_proj=None, proj_clip=None,
+ num_unit_shards=None, num_proj_shards=None,
+ forget_bias=1.0, state_is_tuple=True,
+ activation=tanh):
+ """Initialize the parameters for an LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell
+ input_size: Deprecated and unused.
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
+ cell_clip: (optional) A float value; if provided, the cell state is clipped
+ by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
+ provided, then the projected values are clipped elementwise to within
+ `[-proj_clip, proj_clip]`.
+ num_unit_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
+ num_proj_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
+ forget_bias: Biases of the forget gate are initialized by default to 1
+ in order to reduce the scale of forgetting at the beginning of
+ the training.
+ state_is_tuple: If True, accepted and returned states are 2-tuples of
+ the `c_state` and `m_state`. If False, they are concatenated
+ along the column axis. This latter behavior will soon be deprecated.
+ activation: Activation function of the inner states.
+ """
+ if not state_is_tuple:
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
+ if input_size is not None:
+ logging.warn("%s: The input_size parameter is deprecated.", self)
+ if num_unit_shards is not None or num_proj_shards is not None:
+ logging.warn(
+ "%s: The num_unit_shards and proj_unit_shards parameters are "
+ "deprecated and will be removed in Jan 2017. "
+ "Use a variable scope with a partitioner instead.", self)
+
+ self._num_units = num_units
+ self._use_peepholes = use_peepholes
+ self._cell_clip = cell_clip
+ self._initializer = initializer
+ self._num_proj = num_proj
+ self._proj_clip = proj_clip
+ self._num_unit_shards = num_unit_shards
+ self._num_proj_shards = num_proj_shards
+ self._forget_bias = forget_bias
+ self._state_is_tuple = state_is_tuple
+ self._activation = activation
+
+ if num_proj:
+ self._state_size = (
+ LSTMStateTuple(num_units, num_proj)
+ if state_is_tuple else num_units + num_proj)
+ self._output_size = num_proj
+ else:
+ self._state_size = (
+ LSTMStateTuple(num_units, num_units)
+ if state_is_tuple else 2 * num_units)
+ self._output_size = num_units
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run one step of LSTM.
+
+ Args:
+ inputs: input Tensor, 2D, batch x input_size.
+ state: if `state_is_tuple` is False, this must be a state Tensor,
+ `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
+ tuple of state Tensors, both `2-D`, with column sizes `c_state` and
+ `m_state`.
+ scope: VariableScope for the created subgraph; defaults to "lstm_cell".
+
+ Returns:
+ A tuple containing:
+
+ - A `2-D, [batch x output_dim]`, Tensor representing the output of the
+ LSTM after reading `inputs` when previous state was `state`.
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - Tensor(s) representing the new state of LSTM after reading `inputs` when
+ the previous state was `state`. Same type and shape(s) as `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+
+ if self._state_is_tuple:
+ (c_prev, m_prev) = state
+ else:
+ c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+ m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+ dtype = inputs.dtype
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+ with vs.variable_scope(scope or "lstm_cell",
+ initializer=self._initializer) as unit_scope:
+ if self._num_unit_shards is not None:
+ unit_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_unit_shards))
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
+ scope=scope)
+ i, j, f, o = array_ops.split(1, 4, lstm_matrix)
+
+ # Diagonal connections
+ if self._use_peepholes:
+ with vs.variable_scope(unit_scope) as projection_scope:
+ if self._num_unit_shards is not None:
+ projection_scope.set_partitioner(None)
+ w_f_diag = vs.get_variable(
+ "w_f_diag", shape=[self._num_units], dtype=dtype)
+ w_i_diag = vs.get_variable(
+ "w_i_diag", shape=[self._num_units], dtype=dtype)
+ w_o_diag = vs.get_variable(
+ "w_o_diag", shape=[self._num_units], dtype=dtype)
+
+ if self._use_peepholes:
+ c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+ sigmoid(i + w_i_diag * c_prev) * self._activation(j))
+ else:
+ c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+ self._activation(j))
+
+ if self._cell_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+ # pylint: enable=invalid-unary-operand-type
+
+ if self._use_peepholes:
+ m = sigmoid(o + w_o_diag * c) * self._activation(c)
+ else:
+ m = sigmoid(o) * self._activation(c)
+
+ if self._num_proj is not None:
+ with vs.variable_scope("projection") as proj_scope:
+ if self._num_proj_shards is not None:
+ proj_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_proj_shards))
+ m = _linear(m, self._num_proj, bias=False, scope=scope)
+
+ if self._proj_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+ # pylint: enable=invalid-unary-operand-type
+
+ new_state = (LSTMStateTuple(c, m) if self._state_is_tuple
+ else array_ops.concat(1, [c, m]))
+ return m, new_state
+
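The constructor options compose: with a projection, the recurrent state `m` and the output both live in the `num_proj` space while `c` keeps `num_units` columns. A configuration sketch (the values are illustrative, not recommendations):

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.LSTMCell(
    num_units=512,
    use_peepholes=True,  # adds the diagonal w_{f,i,o}_diag connections
    cell_clip=50.0,      # clip c before the output activation
    num_proj=256,        # project m (and the output) down to 256
    proj_clip=5.0)       # clip the projected m to [-5, 5]

x = tf.placeholder(tf.float32, [4, 80])
state = cell.zero_state(4, tf.float32)  # LSTMStateTuple(c=[4,512], m=[4,256])
m, new_state = cell(x, state)           # m has shape [4, 256]
```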
+
+class OutputProjectionWrapper(RNNCell):
+ """Operator adding an output projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your outputs in time,
+ do the projection on this batch-concatenated sequence, then split it
+ if needed or directly feed into a softmax.
+ """
+
+ def __init__(self, cell, output_size):
+ """Create a cell with output projection.
+
+ Args:
+ cell: an RNNCell, a projection to output_size is added to it.
+ output_size: integer, the size of the output after projection.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if output_size is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if output_size < 1:
+ raise ValueError("Parameter output_size must be > 0: %d." % output_size)
+ self._cell = cell
+ self._output_size = output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell and output projection on inputs, starting from state."""
+ output, res_state = self._cell(inputs, state)
+ # Default scope: "OutputProjectionWrapper"
+ with vs.variable_scope(scope or "output_projection_wrapper"):
+ projected = _linear(output, self._output_size, True, scope=scope)
+ return projected, res_state
+
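The alternative recommended in the note above, projecting the time-concatenated outputs with one matmul instead of once per step, looks roughly like this (a sketch; shapes and variable names are illustrative):

```python
import tensorflow as tf

steps, batch, units, vocab = 10, 16, 128, 1000
outputs = [tf.placeholder(tf.float32, [batch, units]) for _ in range(steps)]

stacked = tf.concat(0, outputs)               # [steps * batch, units]
w = tf.get_variable("proj_w", [units, vocab])
b = tf.get_variable("proj_b", [vocab])
logits = tf.nn.xw_plus_b(stacked, w, b)       # one large matmul
per_step = tf.split(0, steps, logits)         # back to steps x [batch, vocab]
```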
+
+class InputProjectionWrapper(RNNCell):
+ """Operator adding an input projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the projection on this batch-concatenated sequence, then split it.
+ """
+
+ def __init__(self, cell, num_proj, input_size=None):
+ """Create a cell with input projection.
+
+ Args:
+ cell: an RNNCell, a projection of inputs is added before it.
+ num_proj: Python integer. The dimension to project to.
+ input_size: Deprecated and unused.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ """
+ if input_size is not None:
+ logging.warn("%s: The input_size parameter is deprecated.", self)
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ self._cell = cell
+ self._num_proj = num_proj
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the input projection and then the cell."""
+ # Default scope: "InputProjectionWrapper"
+ with vs.variable_scope(scope or "input_projection_wrapper"):
+ projected = _linear(inputs, self._num_proj, True, scope=scope)
+ return self._cell(projected, state)
+
+
+class DropoutWrapper(RNNCell):
+ """Operator adding dropout to inputs and outputs of the given cell."""
+
+ def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
+ seed=None):
+ """Create a cell with added input and/or output dropout.
+
+ Dropout is never used on the state.
+
+ Args:
+ cell: an RNNCell, a projection to output_size is added to it.
+ input_keep_prob: unit Tensor or float between 0 and 1, input keep
+ probability; if it is float and 1, no input dropout will be added.
+ output_keep_prob: unit Tensor or float between 0 and 1, output keep
+ probability; if it is float and 1, no output dropout will be added.
+ seed: (optional) integer, the randomness seed.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if keep_prob is not between 0 and 1.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not a RNNCell.")
+ if (isinstance(input_keep_prob, float) and
+ not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)):
+ raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
+ % input_keep_prob)
+ if (isinstance(output_keep_prob, float) and
+ not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)):
+ raise ValueError("Parameter output_keep_prob must be between 0 and 1: %d"
+ % output_keep_prob)
+ self._cell = cell
+ self._input_keep_prob = input_keep_prob
+ self._output_keep_prob = output_keep_prob
+ self._seed = seed
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell with the declared dropouts."""
+ if (not isinstance(self._input_keep_prob, float) or
+ self._input_keep_prob < 1):
+ inputs = nn_ops.dropout(inputs, self._input_keep_prob, seed=self._seed)
+ output, new_state = self._cell(inputs, state, scope)
+ if (not isinstance(self._output_keep_prob, float) or
+ self._output_keep_prob < 1):
+ output = nn_ops.dropout(output, self._output_keep_prob, seed=self._seed)
+ return output, new_state
+
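Since a non-float `keep_prob` is always applied, passing a placeholder lets one graph run with dropout during training and without it at eval time. A sketch:

```python
import tensorflow as tf

keep_prob = tf.placeholder(tf.float32)  # feed 0.5 when training, 1.0 at eval
cell = tf.nn.rnn_cell.DropoutWrapper(
    tf.nn.rnn_cell.GRUCell(256),
    input_keep_prob=keep_prob,
    output_keep_prob=keep_prob)

x = tf.placeholder(tf.float32, [32, 128])
state = cell.zero_state(32, tf.float32)
output, state = cell(x, state)  # dropout on x and output, never on state
```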
+
+class EmbeddingWrapper(RNNCell):
+ """Operator adding input embedding to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the embedding on this batch-concatenated sequence, then split it and
+ feed into your RNN.
+ """
+
+ def __init__(self, cell, embedding_classes, embedding_size, initializer=None):
+ """Create a cell with an added input embedding.
+
+ Args:
+ cell: an RNNCell, an embedding will be put before its inputs.
+ embedding_classes: integer, how many symbols will be embedded.
+ embedding_size: integer, the size of the vectors we embed into.
+ initializer: an initializer to use when creating the embedding;
+ if None, the initializer from variable scope or a default one is used.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if embedding_classes is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if embedding_classes <= 0 or embedding_size <= 0:
+ raise ValueError("Both embedding_classes and embedding_size must be > 0: "
+ "%d, %d." % (embedding_classes, embedding_size))
+ self._cell = cell
+ self._embedding_classes = embedding_classes
+ self._embedding_size = embedding_size
+ self._initializer = initializer
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell on embedded inputs."""
+ with vs.variable_scope(scope or "embedding_wrapper"): # "EmbeddingWrapper"
+ with ops.device("/cpu:0"):
+ if self._initializer:
+ initializer = self._initializer
+ elif vs.get_variable_scope().initializer:
+ initializer = vs.get_variable_scope().initializer
+ else:
+ # Default initializer for embeddings should have variance=1.
+ sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
+ initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
+
+ if type(state) is tuple:
+ data_type = state[0].dtype
+ else:
+ data_type = state.dtype
+
+ embedding = vs.get_variable(
+ "embedding", [self._embedding_classes, self._embedding_size],
+ initializer=initializer,
+ dtype=data_type)
+ embedded = embedding_ops.embedding_lookup(
+ embedding, array_ops.reshape(inputs, [-1]))
+ return self._cell(embedded, state)
+
+
+class MultiRNNCell(RNNCell):
+ """RNN cell composed sequentially of multiple simple cells."""
+
+ def __init__(self, cells, state_is_tuple=True):
+ """Create a RNN cell composed sequentially of a number of RNNCells.
+
+ Args:
+ cells: list of RNNCells that will be composed in this order.
+ state_is_tuple: If True, accepted and returned states are n-tuples, where
+ `n = len(cells)`. If False, the states are all
+ concatenated along the column axis. This latter behavior will soon be
+ deprecated.
+
+ Raises:
+ ValueError: if cells is empty (not allowed), or at least one of the cells
+ returns a state tuple but the flag `state_is_tuple` is `False`.
+ """
+ if not cells:
+ raise ValueError("Must specify at least one cell for MultiRNNCell.")
+ self._cells = cells
+ self._state_is_tuple = state_is_tuple
+ if not state_is_tuple:
+ if any(nest.is_sequence(c.state_size) for c in self._cells):
+ raise ValueError("Some cells return tuples of states, but the flag "
+ "state_is_tuple is not set. State sizes are: %s"
+ % str([c.state_size for c in self._cells]))
+
+ @property
+ def state_size(self):
+ if self._state_is_tuple:
+ return tuple(cell.state_size for cell in self._cells)
+ else:
+ return sum([cell.state_size for cell in self._cells])
+
+ @property
+ def output_size(self):
+ return self._cells[-1].output_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run this multi-layer cell on inputs, starting from state."""
+ with vs.variable_scope(scope or "multi_rnn_cell"):
+ cur_state_pos = 0
+ cur_inp = inputs
+ new_states = []
+ for i, cell in enumerate(self._cells):
+ with vs.variable_scope("cell_%d" % i):
+ if self._state_is_tuple:
+ if not nest.is_sequence(state):
+ raise ValueError(
+ "Expected state to be a tuple of length %d, but received: %s"
+ % (len(self.state_size), state))
+ cur_state = state[i]
+ else:
+ cur_state = array_ops.slice(
+ state, [0, cur_state_pos], [-1, cell.state_size])
+ cur_state_pos += cell.state_size
+ cur_inp, new_state = cell(cur_inp, cur_state)
+ new_states.append(new_state)
+ new_states = (tuple(new_states) if self._state_is_tuple
+ else array_ops.concat(1, new_states))
+ return cur_inp, new_states
+
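Stacking is a one-liner; with `state_is_tuple=True` the combined state is an n-tuple of the per-layer states. A sketch:

```python
import tensorflow as tf

cells = [tf.nn.rnn_cell.BasicLSTMCell(128) for _ in range(3)]
stacked = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

x = tf.placeholder(tf.float32, [8, 64])
state = stacked.zero_state(8, tf.float32)  # 3-tuple of LSTMStateTuples
output, state = stacked(x, state)          # output comes from the top layer
```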
+
+class _SlimRNNCell(RNNCell):
+ """A simple wrapper for slim.rnn_cells."""
+
+ def __init__(self, cell_fn):
+ """Create a SlimRNNCell from a cell_fn.
+
+ Args:
+ cell_fn: a function which takes (inputs, state, scope) and produces the
+ outputs and the new_state. Additionally when called with inputs=None and
+ state=None it should return (initial_outputs, initial_state).
+
+ Raises:
+ TypeError: if cell_fn is not callable
+ ValueError: if cell_fn cannot produce a valid initial state.
+ """
+ if not callable(cell_fn):
+ raise TypeError("cell_fn %s needs to be callable", cell_fn)
+ self._cell_fn = cell_fn
+ self._cell_name = cell_fn.func.__name__
+ init_output, init_state = self._cell_fn(None, None)
+ output_shape = init_output.get_shape()
+ state_shape = init_state.get_shape()
+ self._output_size = output_shape.with_rank(2)[1].value
+ self._state_size = state_shape.with_rank(2)[1].value
+ if self._output_size is None:
+ raise ValueError("Initial output created by %s has invalid shape %s" %
+ (self._cell_name, output_shape))
+ if self._state_size is None:
+ raise ValueError("Initial state created by %s has invalid shape %s" %
+ (self._cell_name, state_shape))
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def __call__(self, inputs, state, scope=None):
+ scope = scope or self._cell_name
+ output, state = self._cell_fn(inputs, state, scope=scope)
+ return output, state
+
+
+def _linear(args, output_size, bias, bias_start=0.0, scope=None):
+ """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
+
+ Args:
+ args: a 2D Tensor or a list of 2D, batch x n, Tensors.
+ output_size: int, second dimension of W[i].
+ bias: boolean, whether to add a bias term or not.
+ bias_start: starting value to initialize the bias; 0 by default.
+ scope: (optional) Variable scope to create parameters in.
+
+ Returns:
+ A 2D Tensor with shape [batch x output_size] equal to
+ sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
+
+ Raises:
+ ValueError: if some of the arguments has unspecified or wrong shape.
+ """
+ if args is None or (nest.is_sequence(args) and not args):
+ raise ValueError("`args` must be specified")
+ if not nest.is_sequence(args):
+ args = [args]
+
+ # Calculate the total size of arguments on dimension 1.
+ total_arg_size = 0
+ shapes = [a.get_shape() for a in args]
+ for shape in shapes:
+ if shape.ndims != 2:
+ raise ValueError("linear is expecting 2D arguments: %s" % shapes)
+ if shape[1].value is None:
+ raise ValueError("linear expects shape[1] to be provided for shape %s, "
+ "but saw %d" % (shape, shape[1]))
+ else:
+ total_arg_size += shape[1].value
+
+ dtype = [a.dtype for a in args][0]
+
+ # Now the computation.
+ scope = vs.get_variable_scope()
+ with vs.variable_scope(scope) as outer_scope:
+ weights = vs.get_variable(
+ "weights", [total_arg_size, output_size], dtype=dtype)
+ if len(args) == 1:
+ res = math_ops.matmul(args[0], weights)
+ else:
+ res = math_ops.matmul(array_ops.concat(1, args), weights)
+ if not bias:
+ return res
+ with vs.variable_scope(outer_scope) as inner_scope:
+ inner_scope.set_partitioner(None)
+ biases = vs.get_variable(
+ "biases", [output_size],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
+ return nn_ops.bias_add(res, biases)
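For reference, the map `_linear` computes is equivalent to concatenating the arguments along dimension 1 and applying a single weight matrix; a standalone sketch of the same shape algebra, written against the public API:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [32, 100])
h = tf.placeholder(tf.float32, [32, 128])

w = tf.get_variable("weights", [100 + 128, 256])
b = tf.get_variable("biases", [256],
                    initializer=tf.constant_initializer(0.0))
# sum_i(args[i] * W[i]) == concat(args) * W, computed as one matmul.
y = tf.nn.bias_add(tf.matmul(tf.concat(1, [x, h]), w), b)  # [32, 256]
```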
diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py
index 9ec12583de..5bda634aee 100644
--- a/tensorflow/python/ops/seq2seq.py
+++ b/tensorflow/python/ops/seq2seq.py
@@ -71,11 +71,12 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
# TODO(ebrevdo): Remove once _linear is fully deprecated.
-linear = rnn_cell._linear # pylint: disable=protected-access
+linear = rnn_cell_impl._linear # pylint: disable=protected-access
def _extract_argmax_and_embed(embedding, output_projection=None,
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index c46c24af9a..57e7742355 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -46,8 +46,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import six
-
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -70,7 +68,8 @@ def string_split(source, delimiter=" "): # pylint: disable=invalid-name
If `delimiter` is an empty string, each element of the `source` is split
into individual strings, each containing one byte. (This includes splitting
- multibyte sequences of UTF-8.)
+ multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
+ treated as a set of delimiters with each considered a potential split point.
For example:
N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
@@ -89,17 +88,14 @@ def string_split(source, delimiter=" "): # pylint: disable=invalid-name
delimiter: `0-D` string `Tensor`, the delimiter character, the string should
be length 0 or 1.
+ Raises:
+ ValueError: If delimiter is not a string.
+
Returns:
A `SparseTensor` of rank `2`, the strings split according to the delimiter.
The first column of the indices corresponds to the row in `source` and the
second column corresponds to the index of the split component in this row.
-
- Raises:
- ValueError: If delimiter is not a single-byte character.
"""
- if isinstance(delimiter, six.string_types) and len(delimiter) > 1:
- raise ValueError("delimiter must be a single byte-character, got %s" %
- delimiter)
delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string)
source = ops.convert_to_tensor(source, dtype=dtypes.string)
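Under the relaxed check above, a multi-byte delimiter behaves as a set of single-byte split points. A sketch, assuming the op is exported as `tf.string_split`:

```python
import tensorflow as tf

source = tf.constant(["hello,brave world"])
tokens = tf.string_split(source, delimiter=", ")  # split on ',' and ' '

with tf.Session() as sess:
    print(sess.run(tokens).values)  # [b'hello', b'brave', b'world']
```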
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index fca39e0ad5..09955e690c 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -30,7 +30,7 @@ __all__ = ["make_template"]
def make_template(name_, func_, create_scope_now_=False, unique_name_=None,
- **kwargs):
+ custom_getter_=None, **kwargs):
"""Given an arbitrary function, wrap it so that it does variable sharing.
This wraps `func_` in a Template and partially evaluates it. Templates are
@@ -118,6 +118,9 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None,
unique_name_: When used, it overrides name_ and is not made unique. If a
template of the same scope/unique_name already exists and reuse is false,
an error is raised. Defaults to None.
+ custom_getter_: Optional custom getter for variables used in `func_`. See
+ the [`get_variable`](#get_variable) `custom_getter` documentation for
+ more information.
**kwargs: Keyword arguments to apply to `func_`.
Returns:
@@ -136,7 +139,7 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None,
func_ = functools.partial(func_, **kwargs)
return Template(
name_, func_, create_scope_now=create_scope_now_,
- unique_name=unique_name_)
+ unique_name=unique_name_, custom_getter=custom_getter_)
def _skip_common_stack_elements(stacktrace, base_case):
@@ -159,7 +162,8 @@ class Template(object):
call.
"""
- def __init__(self, name, func, create_scope_now=False, unique_name=None):
+ def __init__(self, name, func, create_scope_now=False, unique_name=None,
+ custom_getter=None):
"""Creates a template for the given function.
Args:
@@ -179,6 +183,7 @@ class Template(object):
unique_name: When used, it overrides name_ and is not made unique. If a
template of the same scope/unique_name already exists and reuse is
false, an error is raised. Defaults to None.
+ custom_getter: optional custom getter to pass to variable_scope()
Raises:
ValueError: if the name is None.
@@ -187,11 +192,13 @@ class Template(object):
self._stacktrace = traceback.format_stack()[:-2]
self._name = name
self._unique_name = unique_name
+ self._custom_getter = custom_getter
if name is None:
raise ValueError("name cannot be None.")
if create_scope_now:
with variable_scope.variable_scope(
- self._unique_name, self._name) as vs:
+ self._unique_name, self._name,
+ custom_getter=self._custom_getter) as vs:
self._var_scope = vs
else:
self._var_scope = None
@@ -262,7 +269,8 @@ class Template(object):
# Subsequent calls should reuse variables.
self._variables_created = True
with variable_scope.variable_scope(
- self._unique_name, self._name) as vs:
+ self._unique_name, self._name,
+ custom_getter=self._custom_getter) as vs:
self._var_scope = vs
return self._call_func(args, kwargs, check_for_new_variables=False)
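A sketch of the new parameter in use; `log_getter` here is a hypothetical custom getter that simply forwards to the default one:

```python
import tensorflow as tf

def log_getter(getter, name, *args, **kwargs):
    print("getting variable:", name)  # observe creation and reuse
    return getter(name, *args, **kwargs)

def my_net(x):
    w = tf.get_variable("w", [10, 10])
    return tf.matmul(x, w)

net = tf.make_template("net", my_net, custom_getter_=log_getter)
x = tf.placeholder(tf.float32, [None, 10])
y1 = net(x)  # creates net/w through log_getter
y2 = net(x)  # reuses net/w, still through log_getter
```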
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 40c90dfba8..9f03ae6264 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -87,7 +87,7 @@ class Variable(object):
```python
# Add an Op to initialize global variables.
- init_op = tf.global_variable_initializers()
+ init_op = tf.global_variables_initializer()
# Launch the graph in a session.
with tf.Session() as sess:
@@ -518,6 +518,10 @@ class Variable(object):
You should use this instead of the variable itself to initialize another
variable with a value that depends on the value of this variable.
+    Beware of using `initialized_value` except during initialization: it causes
+    the variable's initializer op to be run, so running this op resets the
+    variable to its initial value.
+
```python
# Initialize 'v' with a random tensor.
v = tf.Variable(tf.truncated_normal([10, 40]))
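
A minimal sketch of that caveat, assuming the documented behavior that `initialized_value()` re-runs the initializer:

    import tensorflow as tf

    v = tf.Variable(0.0)
    bump = v.assign_add(5.0)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(bump)                   # v is now 5.0.
        sess.run(v.initialized_value())  # Re-runs v's initializer...
        print(sess.run(v))               # ...so v is back to 0.0.
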
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py
index bd58db7b45..a47d183e60 100644
--- a/tensorflow/python/platform/app.py
+++ b/tensorflow/python/platform/app.py
@@ -18,9 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import sys
+import sys as _sys
from tensorflow.python.platform import flags
+from tensorflow.python.util.all_util import remove_undocumented
def run(main=None, argv=None):
@@ -36,8 +37,17 @@ def run(main=None, argv=None):
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
- main = main or sys.modules['__main__'].main
+ main = main or _sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
- sys.exit(main(sys.argv[:1] + flags_passthrough))
+ _sys.exit(main(_sys.argv[:1] + flags_passthrough))
+
+
+_allowed_symbols = [
+ 'run',
+ # Allowed submodule.
+ 'flags',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
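
After `remove_undocumented`, the module's public surface is just `run` and `flags`; a minimal sketch of the typical entry-point pattern:

    import tensorflow as tf

    tf.app.flags.DEFINE_string("name", "world", "Who to greet.")
    FLAGS = tf.app.flags.FLAGS

    def main(unused_argv):
        print("Hello, %s!" % FLAGS.name)

    if __name__ == "__main__":
        tf.app.run()  # Parses flags, then calls main(argv).
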
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index db5768acb8..1663a1f251 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -2,7 +2,10 @@
# TensorFlow SavedModel.
package(
- default_visibility = ["//tensorflow/python/saved_model:__subpackages__"],
+ default_visibility = [
+ "//tensorflow/contrib/learn:__subpackages__",
+ "//tensorflow/python/saved_model:__subpackages__",
+ ],
)
licenses(["notice"]) # Apache 2.0
@@ -33,7 +36,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":constants",
- "//tensorflow:tensorflow_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
@@ -48,7 +50,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":constants",
- "//tensorflow:tensorflow_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
"//tensorflow/python:training",
@@ -94,7 +95,9 @@ py_library(
name = "utils",
srcs = ["utils.py"],
srcs_version = "PY2AND3",
- deps = ["//tensorflow/core:protos_all_py"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ ],
)
py_test(
@@ -111,6 +114,31 @@ py_test(
],
)
+py_library(
+ name = "signature_def_utils",
+ srcs = ["signature_def_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":signature_constants",
+ ":utils",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "signature_def_utils_test",
+ size = "small",
+ srcs = [
+ "signature_def_utils_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":signature_def_utils",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/python/saved_model/example/BUILD b/tensorflow/python/saved_model/example/BUILD
index 8198312109..5f4785676e 100644
--- a/tensorflow/python/saved_model/example/BUILD
+++ b/tensorflow/python/saved_model/example/BUILD
@@ -41,11 +41,13 @@ py_binary(
],
)
-# TODO(b/32248363): change saved_model_half_plus_two.py to accept output
-# location so that we can avoid writing to /tmp/ and copying the files from
-# /tmp/.
+# Genrule for SavedModel half-plus-two test data. Specifically, this genrule
+# exports the test SavedModel to a versioned directory to satisfy the
+# TensorFlow Serving model server's requirement of a versioned subdirectory.
+# Note that SavedModel itself accepts any valid directory as the save location
+# and does not perform any versioning.
genrule(
- name = "versioned_saved_model_half_plus_two_data",
+ name = "saved_model_half_plus_two_data",
outs = [
"saved_model_half_plus_two/00000123/saved_model.pb",
"saved_model_half_plus_two/00000123/assets/foo.txt",
@@ -57,10 +59,8 @@ genrule(
"saved_model_half_plus_two_pbtxt/00000123/variables/variables.index",
],
cmd =
- "rm -rf /tmp/saved_model; " +
- "./$(locations :saved_model_half_plus_two); " +
- "cp -r /tmp/saved_model/half_plus_two/* $(@D)/saved_model_half_plus_two/00000123; " +
- "cp -r /tmp/saved_model/half_plus_two_pbtxt/* $(@D)/saved_model_half_plus_two_pbtxt/00000123",
+ "rm -rf $(@D)/saved_model_half_plus_two $(@D)/saved_model_half_plus_two_pbtxt; " +
+ "./$(locations :saved_model_half_plus_two) --output_dir=$(@D)/saved_model_half_plus_two/00000123 --output_dir_pbtxt=$(@D)/saved_model_half_plus_two_pbtxt/00000123",
tools = [
":saved_model_half_plus_two",
],
diff --git a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
index 65eb0f2fd7..d0b7b80674 100644
--- a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
+++ b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
@@ -36,10 +36,18 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
-from tensorflow.python.saved_model import utils
from tensorflow.python.util import compat
+tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two",
+                           "Directory where to output SavedModel.")
+tf.app.flags.DEFINE_string("output_dir_pbtxt",
+ "/tmp/saved_model_half_plus_two_pbtxt",
+                           "Directory where to output the text format of "
+ "SavedModel.")
+FLAGS = tf.app.flags.FLAGS
+
def _write_assets(assets_directory, assets_filename):
"""Writes asset files to be used with SavedModel for half plus two.
@@ -113,16 +121,31 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
output_tensor = meta_graph_pb2.TensorInfo()
output_tensor.name = tf.identity(y).name
signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor}
- signature_def = utils.build_signature_def(
+ signature_def = signature_def_utils.build_signature_def(
signature_inputs, signature_outputs,
signature_constants.REGRESS_METHOD_NAME)
+ # Set up the signature for Predict with input and output tensor
+ # specification.
+ predict_input_tensor = meta_graph_pb2.TensorInfo()
+ predict_input_tensor.name = x.name
+  predict_signature_inputs = {
+      "x": predict_input_tensor
+  }
+  predict_signature_def = signature_def_utils.build_signature_def(
+      predict_signature_inputs,
+      {"y": output_tensor},
+      signature_constants.PREDICT_METHOD_NAME)
+
# Initialize all variables and then save the SavedModel.
sess.run(tf.global_variables_initializer())
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map={
- signature_constants.REGRESS_METHOD_NAME: signature_def
+ signature_constants.REGRESS_METHOD_NAME:
+ signature_def,
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ predict_signature_def
},
assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
legacy_init_op=tf.group(assign_filename_op))
@@ -130,13 +153,11 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
def main(_):
- export_dir_pb = "/tmp/saved_model/half_plus_two"
- _generate_saved_model_for_half_plus_two(export_dir_pb)
- print("SavedModel generated at: %s" % export_dir_pb)
+ _generate_saved_model_for_half_plus_two(FLAGS.output_dir)
+ print("SavedModel generated at: %s" % FLAGS.output_dir)
- export_dir_pbtxt = "/tmp/saved_model/half_plus_two_pbtxt"
- _generate_saved_model_for_half_plus_two(export_dir_pbtxt, as_text=True)
- print("SavedModel generated at: %s" % export_dir_pbtxt)
+ _generate_saved_model_for_half_plus_two(FLAGS.output_dir_pbtxt, as_text=True)
+ print("SavedModel generated at: %s" % FLAGS.output_dir_pbtxt)
if __name__ == "__main__":
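
Once exported, the SavedModel can be inspected without knowing tensor names; a sketch, assuming the default `--output_dir` above:

    import tensorflow as tf
    from tensorflow.python.saved_model import loader
    from tensorflow.python.saved_model import tag_constants

    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph_def = loader.load(
            sess, [tag_constants.SERVING], "/tmp/saved_model_half_plus_two")
        for key, sig in meta_graph_def.signature_def.items():
            print(key, sig.method_name)  # Regress and default Predict entries.
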
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 0f8ddfc65b..bf5b186b80 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -27,8 +27,8 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
-from tensorflow.python.saved_model import utils
from tensorflow.python.util import compat
@@ -315,7 +315,8 @@ class SavedModelTest(tf.test.TestCase):
with self.test_session(graph=tf.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build and populate an empty SignatureDef for testing.
- foo_signature = utils.build_signature_def(dict(), dict(), "foo")
+ foo_signature = signature_def_utils.build_signature_def(
+ dict(), dict(), "foo")
builder.add_meta_graph_and_variables(
sess, ["foo"], signature_def_map={"foo_key": foo_signature})
@@ -324,10 +325,12 @@ class SavedModelTest(tf.test.TestCase):
with self.test_session(graph=tf.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
# Build and populate a different SignatureDef for testing.
- bar_signature = utils.build_signature_def(dict(), dict(), "bar")
+ bar_signature = signature_def_utils.build_signature_def(
+ dict(), dict(), "bar")
# Also, build a different SignatureDef corresponding to "foo_key" defined
# in the previous graph.
- foo_new_signature = utils.build_signature_def(dict(), dict(), "foo_new")
+ foo_new_signature = signature_def_utils.build_signature_def(
+ dict(), dict(), "foo_new")
builder.add_meta_graph(
["bar"],
signature_def_map={"bar_key": bar_signature,
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
new file mode 100644
index 0000000000..23e844adb2
--- /dev/null
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -0,0 +1,158 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SignatureDef utility functions.
+
+Utility functions for constructing SignatureDef protos.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import utils
+
+
+def build_signature_def(inputs=None, outputs=None, method_name=None):
+ """Utility function to build a SignatureDef protocol buffer.
+
+ Args:
+ inputs: Inputs of the SignatureDef defined as a proto map of string to
+ tensor info.
+ outputs: Outputs of the SignatureDef defined as a proto map of string to
+ tensor info.
+ method_name: Method name of the SignatureDef as a string.
+
+ Returns:
+ A SignatureDef protocol buffer constructed based on the supplied arguments.
+ """
+ signature_def = meta_graph_pb2.SignatureDef()
+ if inputs is not None:
+ for item in inputs:
+ signature_def.inputs[item].CopyFrom(inputs[item])
+ if outputs is not None:
+ for item in outputs:
+ signature_def.outputs[item].CopyFrom(outputs[item])
+ if method_name is not None:
+ signature_def.method_name = method_name
+ return signature_def
+
+
+def regression_signature_def(examples, predictions):
+ """Creates regression signature from given examples and predictions.
+
+ Args:
+ examples: `Tensor`.
+ predictions: `Tensor`.
+
+ Returns:
+ A regression-flavored signature_def.
+
+ Raises:
+    ValueError: If examples or predictions is `None`.
+ """
+ if examples is None:
+ raise ValueError('examples cannot be None for regression.')
+ if predictions is None:
+ raise ValueError('predictions cannot be None for regression.')
+
+ input_tensor_info = utils.build_tensor_info(examples)
+ signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}
+
+ output_tensor_info = utils.build_tensor_info(predictions)
+ signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs,
+ signature_constants.REGRESS_METHOD_NAME)
+
+ return signature_def
+
+
+def classification_signature_def(examples, classes, scores):
+ """Creates classification signature from given examples and predictions.
+
+ Args:
+ examples: `Tensor`.
+ classes: `Tensor`.
+ scores: `Tensor`.
+
+ Returns:
+ A classification-flavored signature_def.
+
+ Raises:
+    ValueError: If examples is `None` or if classes and scores are both `None`.
+ """
+ if examples is None:
+ raise ValueError('examples cannot be None for classification.')
+ if classes is None and scores is None:
+ raise ValueError('classes and scores cannot both be None for '
+ 'classification.')
+
+ input_tensor_info = utils.build_tensor_info(examples)
+ signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}
+
+ signature_outputs = {}
+ if classes is not None:
+ classes_tensor_info = utils.build_tensor_info(classes)
+ signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
+ classes_tensor_info)
+ if scores is not None:
+ scores_tensor_info = utils.build_tensor_info(scores)
+ signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
+ scores_tensor_info)
+
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs,
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ return signature_def
+
+
+def predict_signature_def(inputs, outputs):
+ """Creates prediction signature from given inputs and outputs.
+
+ Args:
+ inputs: dict of string to `Tensor`.
+ outputs: dict of string to `Tensor`.
+
+ Returns:
+ A prediction-flavored signature_def.
+
+ Raises:
+ ValueError: If inputs or outputs is `None`.
+ """
+ if inputs is None or not inputs:
+ raise ValueError('inputs cannot be None or empty for prediction.')
+  if outputs is None or not outputs:
+ raise ValueError('outputs cannot be None or empty for prediction.')
+
+ # If there's only one input or output, we can standardize keys
+ if len(inputs) == 1:
+ (_, value), = inputs.items()
+ inputs = {signature_constants.PREDICT_INPUTS: value}
+ if len(outputs) == 1:
+ (_, value), = outputs.items()
+ outputs = {signature_constants.PREDICT_OUTPUTS: value}
+
+ signature_inputs = {key: utils.build_tensor_info(tensor)
+ for key, tensor in inputs.items()}
+ signature_outputs = {key: utils.build_tensor_info(tensor)
+ for key, tensor in outputs.items()}
+
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs,
+ signature_constants.PREDICT_METHOD_NAME)
+
+ return signature_def
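
A minimal sketch of the new helpers (the placeholders are illustrative):

    import tensorflow as tf
    from tensorflow.python.saved_model import signature_def_utils

    x = tf.placeholder(tf.float32, shape=[None], name="x")
    y = tf.placeholder(tf.float32, shape=[None], name="y")

    # With exactly one input and one output, the keys are standardized to
    # signature_constants.PREDICT_INPUTS / PREDICT_OUTPUTS.
    sig = signature_def_utils.predict_signature_def({"in": x}, {"out": y})
    print(sig.method_name)  # tensorflow/serving/predict
    print(list(sig.inputs.keys()), list(sig.outputs.keys()))
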
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
new file mode 100644
index 0000000000..6dfc4b2cd6
--- /dev/null
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -0,0 +1,156 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SignatureDef utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import utils
+
+
+class SignatureDefUtilsTest(tf.test.TestCase):
+
+ def testBuildSignatureDef(self):
+ x = tf.placeholder(tf.float32, 1, name="x")
+ x_tensor_info = utils.build_tensor_info(x)
+ inputs = dict()
+ inputs["foo-input"] = x_tensor_info
+
+ y = tf.placeholder(tf.float32, name="y")
+ y_tensor_info = utils.build_tensor_info(y)
+ outputs = dict()
+ outputs["foo-output"] = y_tensor_info
+
+ signature_def = signature_def_utils.build_signature_def(
+ inputs, outputs, "foo-method-name")
+ self.assertEqual("foo-method-name", signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ x_tensor_info_actual = signature_def.inputs["foo-input"]
+ self.assertEqual("x:0", x_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype)
+ self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim))
+ self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(signature_def.outputs))
+ y_tensor_info_actual = signature_def.outputs["foo-output"]
+ self.assertEqual("y:0", y_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
+ self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
+
+ def testRegressionSignatureDef(self):
+ input1 = tf.constant("a", name="input-1")
+ output1 = tf.constant("b", name="output-1")
+ signature_def = signature_def_utils.regression_signature_def(
+ input1, output1)
+
+ self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ x_tensor_info_actual = (
+ signature_def.inputs[signature_constants.REGRESS_INPUTS])
+ self.assertEqual("input-1:0", x_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
+ self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(signature_def.outputs))
+ y_tensor_info_actual = (
+ signature_def.outputs[signature_constants.REGRESS_OUTPUTS])
+ self.assertEqual("output-1:0", y_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, y_tensor_info_actual.dtype)
+ self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
+
+ def testClassificationSignatureDef(self):
+ input1 = tf.constant("a", name="input-1")
+ output1 = tf.constant("b", name="output-1")
+ output2 = tf.constant("c", name="output-2")
+ signature_def = signature_def_utils.classification_signature_def(
+ input1, output1, output2)
+
+ self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ x_tensor_info_actual = (
+ signature_def.inputs[signature_constants.CLASSIFY_INPUTS])
+ self.assertEqual("input-1:0", x_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
+ self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(2, len(signature_def.outputs))
+ classes_tensor_info_actual = (
+ signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES])
+ self.assertEqual("output-1:0", classes_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, classes_tensor_info_actual.dtype)
+ self.assertEqual(0, len(classes_tensor_info_actual.tensor_shape.dim))
+ scores_tensor_info_actual = (
+ signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES])
+ self.assertEqual("output-2:0", scores_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, scores_tensor_info_actual.dtype)
+ self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim))
+
+ def testPredictionSignatureDef(self):
+ input1 = tf.constant("a", name="input-1")
+ input2 = tf.constant("b", name="input-2")
+ output1 = tf.constant("c", name="output-1")
+ output2 = tf.constant("d", name="output-2")
+ signature_def = signature_def_utils.predict_signature_def(
+ {"input-1": input1, "input-2": input2},
+ {"output-1": output1, "output-2": output2})
+
+ self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(2, len(signature_def.inputs))
+ input1_tensor_info_actual = (
+ signature_def.inputs["input-1"])
+ self.assertEqual("input-1:0", input1_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
+ input2_tensor_info_actual = (
+ signature_def.inputs["input-2"])
+ self.assertEqual("input-2:0", input2_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(2, len(signature_def.outputs))
+ output1_tensor_info_actual = (
+ signature_def.outputs["output-1"])
+ self.assertEqual("output-1:0", output1_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, output1_tensor_info_actual.dtype)
+ self.assertEqual(0, len(output1_tensor_info_actual.tensor_shape.dim))
+ output2_tensor_info_actual = (
+ signature_def.outputs["output-2"])
+ self.assertEqual("output-2:0", output2_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
+ self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py
index 550eed0fcc..ecc58fbc7a 100644
--- a/tensorflow/python/saved_model/utils.py
+++ b/tensorflow/python/saved_model/utils.py
@@ -23,6 +23,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
+
# TensorInfo helpers.
@@ -40,30 +41,3 @@ def build_tensor_info(tensor):
name=tensor.name,
dtype=dtype_enum,
tensor_shape=tensor.get_shape().as_proto())
-
-# SignatureDef helpers.
-
-
-def build_signature_def(inputs=None, outputs=None, method_name=None):
- """Utility function to build a SignatureDef protocol buffer.
-
- Args:
- inputs: Inputs of the SignatureDef defined as a proto map of string to
- tensor info.
- outputs: Outputs of the SignatureDef defined as a proto map of string to
- tensor info.
- method_name: Method name of the SignatureDef as a string.
-
- Returns:
- A SignatureDef protocol buffer constructed based on the supplied arguments.
- """
- signature_def = meta_graph_pb2.SignatureDef()
- if inputs is not None:
- for item in inputs:
- signature_def.inputs[item].CopyFrom(inputs[item])
- if outputs is not None:
- for item in outputs:
- signature_def.outputs[item].CopyFrom(outputs[item])
- if method_name is not None:
- signature_def.method_name = method_name
- return signature_def
diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py
index 8ce7d1dea1..74f2624773 100644
--- a/tensorflow/python/saved_model/utils_test.py
+++ b/tensorflow/python/saved_model/utils_test.py
@@ -34,36 +34,6 @@ class UtilsTest(tf.test.TestCase):
self.assertEqual(1, len(x_tensor_info.tensor_shape.dim))
self.assertEqual(1, x_tensor_info.tensor_shape.dim[0].size)
- def testBuildSignatureDef(self):
- x = tf.placeholder(tf.float32, 1, name="x")
- x_tensor_info = utils.build_tensor_info(x)
- inputs = dict()
- inputs["foo-input"] = x_tensor_info
-
- y = tf.placeholder(tf.float32, name="y")
- y_tensor_info = utils.build_tensor_info(y)
- outputs = dict()
- outputs["foo-output"] = y_tensor_info
-
- signature_def = utils.build_signature_def(inputs, outputs,
- "foo-method-name")
- self.assertEqual("foo-method-name", signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(signature_def.inputs))
- x_tensor_info_actual = signature_def.inputs["foo-input"]
- self.assertEqual("x:0", x_tensor_info_actual.name)
- self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype)
- self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim))
- self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)
-
- # Check outputs in signature def.
- self.assertEqual(1, len(signature_def.outputs))
- y_tensor_info_actual = signature_def.outputs["foo-output"]
- self.assertEqual("y:0", y_tensor_info_actual.name)
- self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
- self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
-
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
index aa85ea56ab..6d659e27e3 100644
--- a/tensorflow/python/summary/event_accumulator_test.py
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -645,7 +645,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
ipt = tf.placeholder(tf.float32)
tf.summary.scalar('scalar1', ipt)
tf.summary.scalar('scalar2', ipt * ipt)
- merged = tf.merge_all_summaries()
+ merged = tf.contrib.deprecated.merge_all_summaries()
writer.add_graph(sess.graph)
for i in xrange(10):
summ = sess.run(merged, feed_dict={ipt: i})
@@ -692,7 +692,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
tf.summary.image('images', ipt, max_outputs=2)
with tf.name_scope('3'):
tf.summary.image('images', ipt, max_outputs=3)
- merged = tf.merge_all_summaries()
+ merged = tf.contrib.deprecated.merge_all_summaries()
writer.add_graph(sess.graph)
for i in xrange(10):
summ = sess.run(merged)
@@ -736,7 +736,7 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest):
gfile.DeleteRecursively(directory)
gfile.MkDir(directory)
- writer = tf.train.SummaryWriter(directory, max_queue=100)
+ writer = tf.summary.FileWriter(directory, max_queue=100)
with tf.Graph().as_default() as graph:
_ = tf.constant([2.0, 1.0])
@@ -814,7 +814,7 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest):
gfile.DeleteRecursively(directory)
gfile.MkDir(directory)
- writer = tf.train.SummaryWriter(directory, max_queue=100)
+ writer = tf.summary.FileWriter(directory, max_queue=100)
with tf.Graph().as_default() as graph:
_ = tf.constant([2.0, 1.0])
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
index dedebe5484..ccc61d4564 100644
--- a/tensorflow/python/summary/impl/event_file_loader.py
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -52,7 +52,15 @@ class EventFileLoader(object):
Yields:
All values that were written to disk that have not been yielded yet.
"""
- while self._reader.GetNext():
+ while True:
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ self._reader.GetNext(status)
+ except (errors.DataLossError, errors.OutOfRangeError):
+ # We ignore partial read exceptions, because a record may be truncated.
+ # PyRecordReader holds the offset prior to the failed read, so retrying
+ # will succeed.
+ break
event = event_pb2.Event()
event.ParseFromString(self._reader.record())
yield event
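
The practical effect, sketched with the internal loader (the path is illustrative): `Load()` now stops cleanly at a truncated tail, and a later call resumes from the held offset.

    from tensorflow.python.summary.impl import event_file_loader

    loader = event_file_loader.EventFileLoader("/tmp/events.out.tfevents.test")
    events = list(loader.Load())   # Stops at EOF or a half-written record.
    # ... a writer appends the rest of the record ...
    events += list(loader.Load())  # Retries from the saved offset and succeeds.
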
diff --git a/tensorflow/python/summary/impl/event_file_loader_test.py b/tensorflow/python/summary/impl/event_file_loader_test.py
index f4d7cf218e..0b354d553d 100644
--- a/tensorflow/python/summary/impl/event_file_loader_test.py
+++ b/tensorflow/python/summary/impl/event_file_loader_test.py
@@ -78,6 +78,15 @@ class EventFileLoaderTest(test_util.TensorFlowTestCase):
loader = self._LoaderForTestFile(filename)
self.assertEqual(len(list(loader.Load())), 2)
+ def testMultipleWritesWithBadWrite(self):
+ filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ self._WriteToFile(filename, EventFileLoaderTest.RECORD)
+ # Test that we ignore partial record writes at the end of the file.
+ self._WriteToFile(filename, b'123')
+ loader = self._LoaderForTestFile(filename)
+ self.assertEqual(len(list(loader.Load())), 2)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index 4e29bbd88d..2e653106f4 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -17,6 +17,7 @@
### Class for writing Summaries
@@FileWriter
+@@FileWriterCache
### Summary Ops
@@tensor_summary
@@ -56,9 +57,10 @@ from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops
from tensorflow.python.ops.summary_ops import tensor_summary
# pylint: enable=unused-import
from tensorflow.python.platform import tf_logging as _logging
-# exports FileWriter
+# exports FileWriter, FileWriterCache
# pylint: disable=unused-import
from tensorflow.python.summary.writer.writer import FileWriter
+from tensorflow.python.summary.writer.writer_cache import FileWriterCache
# pylint: enable=unused-import
from tensorflow.python.util import compat as _compat
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/python/summary/summary_iterator.py b/tensorflow/python/summary/summary_iterator.py
index 9c3e8fcf4e..490ce141f1 100644
--- a/tensorflow/python/summary/summary_iterator.py
+++ b/tensorflow/python/summary/summary_iterator.py
@@ -79,7 +79,7 @@ class SummaryWriter(object):
# Launch the graph in a session.
sess = tf.Session()
# Create a summary writer, add the 'graph' to the event file.
- writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
+ writer = tf.summary.FileWriter(<some-directory>, sess.graph)
```
The other arguments to the constructor control the asynchronous writes to
@@ -342,7 +342,7 @@ def summary_iterator(path):
# This example supposes that the events file contains summaries with a
# summary value tag 'loss'. These could have been added by calling
# `add_summary()`, passing the output of a scalar summary op created with
- # with: `tf.scalar_summary(['loss'], loss_tensor)`.
+ # with: `tf.summary.scalar('loss', loss_tensor)`.
for e in tf.train.summary_iterator(path to events file):
for v in e.summary.value:
if v.tag == 'loss':
diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py
index fa1715bbcb..fc90d547dc 100644
--- a/tensorflow/python/summary/writer/writer.py
+++ b/tensorflow/python/summary/writer/writer.py
@@ -66,7 +66,7 @@ class SummaryToEventTransformer(object):
# Launch the graph in a session.
sess = tf.Session()
# Create a summary writer, add the 'graph' to the event file.
- writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
+ writer = tf.summary.FileWriter(<some-directory>, sess.graph)
```
@@ -286,7 +286,7 @@ class FileWriter(SummaryToEventTransformer):
# Launch the graph in a session.
sess = tf.Session()
# Create a summary writer, add the 'graph' to the event file.
- writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
+ writer = tf.summary.FileWriter(<some-directory>, sess.graph)
```
The other arguments to the constructor control the asynchronous writes to
diff --git a/tensorflow/python/summary/writer/writer_cache.py b/tensorflow/python/summary/writer/writer_cache.py
index 7655fc5ba4..21870e788e 100644
--- a/tensorflow/python/summary/writer/writer_cache.py
+++ b/tensorflow/python/summary/writer/writer_cache.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Reads Summaries from and writes Summaries to event files."""
+"""A cache for FileWriters."""
from __future__ import absolute_import
from __future__ import division
@@ -21,38 +21,38 @@ from __future__ import print_function
import threading
from tensorflow.python.framework import ops
-from tensorflow.python.summary.writer.writer import FileWriter as SummaryWriter
+from tensorflow.python.summary.writer.writer import FileWriter
-class SummaryWriterCache(object):
- """Cache for summary writers.
+class FileWriterCache(object):
+ """Cache for file writers.
- This class caches summary writers, one per directory.
+ This class caches file writers, one per directory.
"""
# Cache, keyed by directory.
_cache = {}
- # Lock protecting _SUMMARY_WRITERS.
+ # Lock protecting _FILE_WRITERS.
_lock = threading.RLock()
@staticmethod
def clear():
"""Clear cached summary writers. Currently only used for unit tests."""
- with SummaryWriterCache._lock:
- SummaryWriterCache._cache = {}
+ with FileWriterCache._lock:
+ FileWriterCache._cache = {}
@staticmethod
def get(logdir):
- """Returns the SummaryWriter for the specified directory.
+ """Returns the FileWriter for the specified directory.
Args:
logdir: str, name of the directory.
Returns:
- A `SummaryWriter`.
+ A `FileWriter`.
"""
- with SummaryWriterCache._lock:
- if logdir not in SummaryWriterCache._cache:
- SummaryWriterCache._cache[logdir] = SummaryWriter(
+ with FileWriterCache._lock:
+ if logdir not in FileWriterCache._cache:
+ FileWriterCache._cache[logdir] = FileWriter(
logdir, graph=ops.get_default_graph())
- return SummaryWriterCache._cache[logdir]
+ return FileWriterCache._cache[logdir]
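
A minimal sketch of the renamed cache through the public `tf.summary` surface:

    import tensorflow as tf

    w1 = tf.summary.FileWriterCache.get("/tmp/logs")
    w2 = tf.summary.FileWriterCache.get("/tmp/logs")
    assert w1 is w2  # One cached FileWriter per directory.
    tf.summary.FileWriterCache.clear()
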
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index aeaebc2092..466f691dd6 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -83,7 +83,7 @@ class SummaryWriterTestCase(tf.test.TestCase):
def testAddingSummaryGraphAndRunMetadata(self):
test_dir = self._CleanTestDir("basics")
- sw = tf.train.SummaryWriter(test_dir)
+ sw = tf.summary.FileWriter(test_dir)
sw.add_session_log(tf.SessionLog(status=SessionLog.START), 1)
sw.add_summary(
@@ -154,7 +154,7 @@ class SummaryWriterTestCase(tf.test.TestCase):
test_dir = self._CleanTestDir("basics_named_graph")
with tf.Graph().as_default() as g:
tf.constant([12], name="douze")
- sw = tf.train.SummaryWriter(test_dir, graph=g)
+ sw = tf.summary.FileWriter(test_dir, graph=g)
sw.close()
self._assertEventsWithGraph(test_dir, g, True)
@@ -162,7 +162,7 @@ class SummaryWriterTestCase(tf.test.TestCase):
test_dir = self._CleanTestDir("basics_positional_graph")
with tf.Graph().as_default() as g:
tf.constant([12], name="douze")
- sw = tf.train.SummaryWriter(test_dir, g)
+ sw = tf.summary.FileWriter(test_dir, g)
sw.close()
self._assertEventsWithGraph(test_dir, g, True)
@@ -171,7 +171,7 @@ class SummaryWriterTestCase(tf.test.TestCase):
with tf.Graph().as_default() as g:
tf.constant([12], name="douze")
gd = g.as_graph_def()
- sw = tf.train.SummaryWriter(test_dir, graph_def=gd)
+ sw = tf.summary.FileWriter(test_dir, graph_def=gd)
sw.close()
self._assertEventsWithGraph(test_dir, g, False)
@@ -180,7 +180,7 @@ class SummaryWriterTestCase(tf.test.TestCase):
with tf.Graph().as_default() as g:
tf.constant([12], name="douze")
gd = g.as_graph_def()
- sw = tf.train.SummaryWriter(test_dir, gd)
+ sw = tf.summary.FileWriter(test_dir, gd)
sw.close()
self._assertEventsWithGraph(test_dir, g, False)
@@ -190,18 +190,18 @@ class SummaryWriterTestCase(tf.test.TestCase):
with tf.Graph().as_default() as g:
tf.constant([12], name="douze")
gd = g.as_graph_def()
- sw = tf.train.SummaryWriter(test_dir, graph=g, graph_def=gd)
+ sw = tf.summary.FileWriter(test_dir, graph=g, graph_def=gd)
sw.close()
def testNeitherGraphNorGraphDef(self):
with self.assertRaises(TypeError):
test_dir = self._CleanTestDir("basics_string_instead_of_graph")
- sw = tf.train.SummaryWriter(test_dir, "string instead of graph object")
+ sw = tf.summary.FileWriter(test_dir, "string instead of graph object")
sw.close()
def testCloseAndReopen(self):
test_dir = self._CleanTestDir("close_and_reopen")
- sw = tf.train.SummaryWriter(test_dir)
+ sw = tf.summary.FileWriter(test_dir)
sw.add_session_log(tf.SessionLog(status=SessionLog.START), 1)
sw.close()
# Sleep at least one second to make sure we get a new event file name.
@@ -247,7 +247,7 @@ class SummaryWriterTestCase(tf.test.TestCase):
# protocol buffers correctly.
def testAddingSummariesFromSessionRunCalls(self):
test_dir = self._CleanTestDir("global_step")
- sw = tf.train.SummaryWriter(test_dir)
+ sw = tf.summary.FileWriter(test_dir)
with self.test_session():
i = tf.constant(1, dtype=tf.int32, shape=[])
l = tf.constant(2, dtype=tf.int64, shape=[])
@@ -314,9 +314,9 @@ class SummaryWriterCacheTest(tf.test.TestCase):
with tf.Graph().as_default():
dir1 = self._test_dir("test_cache_1")
dir2 = self._test_dir("test_cache_2")
- sw1 = tf.train.SummaryWriterCache.get(dir1)
- sw2 = tf.train.SummaryWriterCache.get(dir2)
- sw3 = tf.train.SummaryWriterCache.get(dir1)
+ sw1 = tf.summary.FileWriterCache.get(dir1)
+ sw2 = tf.summary.FileWriterCache.get(dir2)
+ sw3 = tf.summary.FileWriterCache.get(dir1)
self.assertEqual(sw1, sw3)
self.assertFalse(sw1 == sw2)
sw1.close()
@@ -331,9 +331,9 @@ class SummaryWriterCacheTest(tf.test.TestCase):
def test_clear(self):
with tf.Graph().as_default():
dir1 = self._test_dir("test_clear")
- sw1 = tf.train.SummaryWriterCache.get(dir1)
- tf.train.SummaryWriterCache.clear()
- sw2 = tf.train.SummaryWriterCache.get(dir1)
+ sw1 = tf.summary.FileWriterCache.get(dir1)
+ tf.summary.FileWriterCache.clear()
+ sw2 = tf.summary.FileWriterCache.get(dir1)
self.assertFalse(sw1 == sw2)
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 97c942320a..0f7deb7827 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -29,7 +29,6 @@ limitations under the License.
%include "tensorflow/python/client/tf_session.i"
%include "tensorflow/python/client/device_lib.i"
-%include "tensorflow/python/client/net_lib.i"
%include "tensorflow/python/client/quantize_training.i"
%include "tensorflow/python/lib/io/file_io.i"
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index ca8c537d55..542396003c 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -314,7 +314,6 @@ class StepCounterHook(session_run_hook.SessionRunHook):
every_n_secs=None,
output_dir=None,
summary_writer=None):
- self._summary_tag = "global_step/sec"
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError(
@@ -328,6 +327,7 @@ class StepCounterHook(session_run_hook.SessionRunHook):
def begin(self):
self._global_step_tensor = training_util.get_global_step()
+ self._summary_tag = self._global_step_tensor.op.name + "/sec"
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index f7eab4a3b5..1b8ebd11f3 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -419,6 +419,34 @@ class StepCounterHookTest(tf.test.TestCase):
self.assertEqual('global_step/sec', summary_value.tag)
self.assertGreater(summary_value.simple_value, 0)
+ def test_global_step_name(self):
+ with tf.Graph().as_default() as g, tf.Session() as sess:
+ with tf.variable_scope('bar'):
+ foo_step = tf.get_variable('foo', initializer=0, trainable=False,
+ collections=[tf.GraphKeys.GLOBAL_STEP,
+ tf.GraphKeys.GLOBAL_VARIABLES])
+ train_op = tf.assign_add(foo_step, 1)
+ summary_writer = testing.FakeSummaryWriter(self.log_dir, g)
+ hook = tf.train.StepCounterHook(
+ summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)
+
+ hook.begin()
+ sess.run(tf.global_variables_initializer())
+ mon_sess = monitored_session._HookedSession(sess, [hook])
+ mon_sess.run(train_op)
+ mon_sess.run(train_op)
+ hook.end(sess)
+
+ summary_writer.assert_summaries(
+ test_case=self,
+ expected_logdir=self.log_dir,
+ expected_graph=g,
+ expected_summaries={})
+ self.assertTrue(summary_writer.summaries, 'No summaries were created.')
+ self.assertItemsEqual([2], summary_writer.summaries.keys())
+ summary_value = summary_writer.summaries[2][0].value[0]
+ self.assertEqual('bar/foo/sec', summary_value.tag)
+
class SummarySaverHookTest(tf.test.TestCase):
@@ -581,7 +609,7 @@ class GlobalStepWaiterHookTest(tf.test.TestCase):
hook = tf.train.GlobalStepWaiterHook(wait_until_step=1000)
hook.begin()
with tf.Session() as sess:
- sess.run(tf.initialize_all_variables())
+ sess.run(tf.global_variables_initializer())
waiter = threading.Thread(
target=hook.before_run,
args=(tf.train.SessionRunContext(
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index bac38ee689..45438b1342 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -288,7 +288,8 @@ class ExponentialMovingAverage(object):
@@variables_to_restore
"""
- def __init__(self, decay, num_updates=None, name="ExponentialMovingAverage"):
+ def __init__(self, decay, num_updates=None, zero_debias=False,
+ name="ExponentialMovingAverage"):
"""Creates a new ExponentialMovingAverage object.
The `apply()` method has to be called to create shadow variables and add
@@ -305,11 +306,14 @@ class ExponentialMovingAverage(object):
Args:
decay: Float. The decay to use.
num_updates: Optional count of number of updates applied to variables.
+      zero_debias: If `True`, zero-debias moving averages that are initialized
+        with tensors.
name: String. Optional prefix name to use for the name of ops added in
`apply()`.
"""
self._decay = decay
self._num_updates = num_updates
+ self._zero_debias = zero_debias
self._name = name
self._averages = {}
@@ -373,7 +377,8 @@ class ExponentialMovingAverage(object):
var,
self._name,
colocate_with_primary=(var.op.type == "Variable"))
- zero_debias_true.add(avg)
+ if self._zero_debias:
+ zero_debias_true.add(avg)
self._averages[var] = avg
with ops.name_scope(self._name) as scope:
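
A minimal sketch of the new opt-in flag; with `zero_debias=True`, moving averages of plain tensors are bias-corrected rather than starting from zero:

    import tensorflow as tf

    x = tf.constant(10.0) + tf.constant(30.0)  # A Tensor, not a Variable.
    ema = tf.train.ExponentialMovingAverage(0.25, zero_debias=True)
    maintain_op = ema.apply([x])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(maintain_op)
        # The debiased average is pulled toward 40.0; without zero_debias it
        # would be the biased (1 - 0.25) * 40.0 = 30.0 after one update.
        print(sess.run(ema.average(x)))
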
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index a892912cc8..dae89fbefe 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -89,6 +89,11 @@ def _Repeat(value, dim):
class ExponentialMovingAverageTest(tf.test.TestCase):
def _CheckDecay(self, ema, actual_decay, dim):
+ def _Scale(dk, steps):
+ if ema._zero_debias:
+ return 1 - dk ** (steps + 1)
+ else:
+ return 1
tens = _Repeat(10.0, dim)
thirties = _Repeat(30.0, dim)
var0 = tf.Variable(tens, name="v0")
@@ -133,7 +138,7 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
self.assertAllClose(expected, avg0.eval())
expected = _Repeat(30.0 * dk + 30.0 * (1 - dk), dim)
self.assertAllClose(expected, avg1.eval())
- expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk) / (1 - dk ** 2), dim)
+ expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk) / _Scale(dk, 1), dim)
self.assertAllClose(expected, avg2.eval())
# Again, update the averages and check.
@@ -145,7 +150,7 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
dim)
self.assertAllClose(expected, avg1.eval())
expected = _Repeat(((0.0 * dk + (10.0 + 30.0) * (1 - dk)) * dk +
- (10.0 + 30.0) * (1 - dk)) / (1 - dk ** 3),
+ (10.0 + 30.0) * (1 - dk)) / _Scale(dk, 2),
dim)
self.assertAllClose(expected, avg2.eval())
@@ -154,23 +159,47 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
ema = tf.train.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=1)
+ def testAverageVariablesNoNumUpdates_Scalar_Debias(self):
+ with self.test_session():
+ ema = tf.train.ExponentialMovingAverage(0.25, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.25, dim=1)
+
def testAverageVariablesNoNumUpdates_Vector(self):
with self.test_session():
ema = tf.train.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=5)
+ def testAverageVariablesNoNumUpdates_Vector_Debias(self):
+ with self.test_session():
+ ema = tf.train.ExponentialMovingAverage(0.25, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.25, dim=5)
+
def testAverageVariablesNumUpdates_Scalar(self):
with self.test_session():
# With num_updates 1, the decay applied is 0.1818
ema = tf.train.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=1)
+ def testAverageVariablesNumUpdates_Scalar_Debias(self):
+ with self.test_session():
+ # With num_updates 1, the decay applied is 0.1818
+ ema = tf.train.ExponentialMovingAverage(
+ 0.25, num_updates=1, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=1)
+
def testAverageVariablesNumUpdates_Vector(self):
with self.test_session():
# With num_updates 1, the decay applied is 0.1818
ema = tf.train.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=5)
+ def testAverageVariablesNumUpdates_Vector_Debias(self):
+ with self.test_session():
+ # With num_updates 1, the decay applied is 0.1818
+ ema = tf.train.ExponentialMovingAverage(
+ 0.25, num_updates=1, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=5)
+
def testAverageVariablesWithControlDeps(self):
with self.test_session() as sess:
v0 = tf.Variable(0, name="v0")
@@ -195,14 +224,15 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
self.assertEqual(1, sess.run(v0))
self.assertEqual([17.5], sess.run(v1_avg))
- def testAverageVariablesNames(self):
+ def averageVariablesNamesHelper(self, zero_debias):
with self.test_session():
v0 = tf.Variable(10.0, name="v0")
v1 = tf.Variable(30.0, name="v1")
# Add a non-trainable variable.
v2 = tf.Variable(20.0, name="v2", trainable=False)
tensor2 = v0 + v1
- ema = tf.train.ExponentialMovingAverage(0.25, name="foo")
+ ema = tf.train.ExponentialMovingAverage(
+ 0.25, zero_debias=zero_debias, name="foo")
self.assertEqual("v0/foo", ema.average_name(v0))
self.assertEqual("v1/foo", ema.average_name(v1))
self.assertEqual("add/foo", ema.average_name(tensor2))
@@ -212,21 +242,30 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
# {v0/foo : v0,
# v1/foo : v1,
# add/foo : add/foo,
- # add/foo/biased: add/foo/biased,
- # add/foo/local_step: add/foo/local_step,
# v2 : v2}
+ expected_names = [ema.average_name(v0),
+ ema.average_name(v1),
+ ema.average_name(tensor2),
+ v2.op.name]
+ if zero_debias:
+ # vars_to_restore should also contain the following:
+ # {add/foo/biased: add/foo/biased,
+ # add/foo/local_step: add/foo/local_step}
+ expected_names += [ema.average_name(tensor2) + "/biased",
+ ema.average_name(tensor2) + "/local_step"]
self.assertEqual(sorted(vars_to_restore.keys()),
- sorted([ema.average_name(v0),
- ema.average_name(v1),
- ema.average_name(tensor2),
- ema.average_name(tensor2) + "/biased",
- ema.average_name(tensor2) + "/local_step",
- v2.op.name]))
+ sorted(expected_names))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
- def testAverageVariablesNamesRespectScope(self):
+ def testAverageVariablesNames(self):
+ self.averageVariablesNamesHelper(zero_debias=True)
+
+ def testAverageVariablesNamesNoDebias(self):
+ self.averageVariablesNamesHelper(zero_debias=False)
+
+ def averageVariablesNamesRespectScopeHelper(self, zero_debias):
# See discussion on #2740.
with self.test_session():
with tf.variable_scope("scope1"):
@@ -236,7 +275,8 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
v2 = tf.Variable(20.0, name="v2", trainable=False)
tensor2 = v0 + v1
with tf.variable_scope("scope2"):
- ema = tf.train.ExponentialMovingAverage(0.25, name="foo")
+ ema = tf.train.ExponentialMovingAverage(
+ 0.25, zero_debias=zero_debias, name="foo")
self.assertEqual("scope2/scope1/v0/foo", ema.average_name(v0))
self.assertEqual("scope2/scope1/v1/foo", ema.average_name(v1))
self.assertEqual("scope2/scope1/add/foo", ema.average_name(tensor2))
@@ -246,22 +286,32 @@ class ExponentialMovingAverageTest(tf.test.TestCase):
# {scope2/scope1/v0/foo : v0,
# scope2/scope1/v1/foo : v1,
# scope2/scope1/add/foo : add/foo,
- # scope2/scope2/scope1/add/foo/biased: add/foo/biased,
- # scope2/scope2/scope1/add/foo/local_step: add/foo/local_step,
# scope1/v2 : v2}
- sc = "scope2/"
+ expected_names = [ema.average_name(v0),
+ ema.average_name(v1),
+ ema.average_name(tensor2),
+ v2.op.name]
+ if zero_debias:
+ # vars_to_restore should also contain the following:
+ # {scope2/scope2/scope1/add/foo/biased: add/foo/biased,
+ # scope2/scope2/scope1/add/foo/local_step: add/foo/local_step}
+ sc = "scope2/"
+ expected_names += [sc + ema.average_name(tensor2) + "/biased",
+ sc + ema.average_name(tensor2) + "/local_step"]
+
self.assertEqual(sorted(vars_to_restore.keys()),
- sorted([ema.average_name(v0),
- ema.average_name(v1),
- ema.average_name(tensor2),
- sc + ema.average_name(tensor2) + "/biased",
- sc + ema.average_name(tensor2) + "/local_step",
- v2.op.name]))
+ sorted(expected_names))
self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2),
ema.average(tensor2).op.name)
+ def testAverageVariablesNamesRespectScope(self):
+ self.averageVariablesNamesRespectScopeHelper(zero_debias=True)
+
+ def testAverageVariablesNamesRespectScopeNoDebias(self):
+ self.averageVariablesNamesRespectScopeHelper(zero_debias=False)
+
def testSubsetAverageVariablesNames(self):
with self.test_session():
v0 = tf.Variable(10.0, name="v0")
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 68e4bfd0f8..cb4e1de235 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -36,7 +36,6 @@ from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
-from tensorflow.python.framework import graph_util
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
@@ -62,6 +61,23 @@ _VARIABLE_OPS = set(["Variable",
"ResourceGather"])
+def _set_cpu0(device_string):
+ """Creates a new device string based on `device_string` but using /CPU:0.
+
+ If the device is already on /CPU:0, this is a no-op.
+
+ Args:
+ device_string: A device string.
+
+ Returns:
+ A device string.
+ """
+ parsed_device = pydev.DeviceSpec.from_string(device_string)
+ parsed_device.device_type = "CPU"
+ parsed_device.device_index = 0
+ return parsed_device.to_string()
+
+
class BaseSaverBuilder(object):
"""Base class for Savers.
@@ -380,8 +396,7 @@ class BaseSaverBuilder(object):
# available on the GPU.
# TODO(touts): Re-enable restore on GPU when we can support annotating
# string tensors as "HostMemory" inputs.
- with ops.device(
- graph_util.set_cpu0(saveable.device) if saveable.device else None):
+ with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
with ops.control_dependencies(restore_control_inputs):
tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
shapes = None
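
A sketch of what the inlined helper does, using a hypothetical device string (the exact `to_string` formatting may vary across versions):

    from tensorflow.python.framework import device as pydev

    parsed = pydev.DeviceSpec.from_string("/job:worker/task:3/gpu:1")
    parsed.device_type = "CPU"
    parsed.device_index = 0
    print(parsed.to_string())  # e.g. "/job:worker/task:3/device:CPU:0"
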
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
index cbbc527d85..4d1c3e7954 100644
--- a/tensorflow/python/training/summary_io.py
+++ b/tensorflow/python/training/summary_io.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Reads Summaries from and writes Summaries to event files."""
from __future__ import absolute_import
@@ -22,7 +21,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.summary.summary_iterator import summary_iterator
from tensorflow.python.summary.writer.writer import FileWriter as _FileWriter
-from tensorflow.python.summary.writer.writer_cache import SummaryWriterCache
+from tensorflow.python.summary.writer.writer_cache import FileWriterCache as SummaryWriterCache
# pylint: enable=unused-import
from tensorflow.python.util.deprecation import deprecated
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 474d417947..dda0166aa6 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -418,7 +418,7 @@ class SupervisorTest(tf.test.TestCase):
tf.summary.scalar("c2", tf.constant(2))
tf.summary.scalar("c3", tf.constant(3))
summ = tf.summary.merge_all()
- sw = tf.train.SummaryWriter(logdir)
+ sw = tf.summary.FileWriter(logdir)
sv = tf.train.Supervisor(logdir="", summary_op=None, summary_writer=sw)
meta_graph_def = meta_graph.create_meta_graph_def()
sess = sv.prepare_or_wait_for_session("")
diff --git a/tensorflow/python/training/tensorboard_logging_test.py b/tensorflow/python/training/tensorboard_logging_test.py
index dd0ee372f9..286062cab7 100644
--- a/tensorflow/python/training/tensorboard_logging_test.py
+++ b/tensorflow/python/training/tensorboard_logging_test.py
@@ -35,7 +35,7 @@ class EventLoggingTest(tf.test.TestCase):
def setUp(self):
self._work_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
- self._sw = tf.train.SummaryWriter(self._work_dir)
+ self._sw = tf.summary.FileWriter(self._work_dir)
tensorboard_logging.set_summary_writer(self._sw)
self.addCleanup(shutil.rmtree, self._work_dir)
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 561a58d594..8e0d9bbb06 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import functools
import inspect
import re
@@ -114,7 +115,11 @@ def deprecated(date, instructions):
return deprecated_wrapper
-def deprecated_args(date, instructions, *deprecated_arg_names):
+DeprecatedArgSpec = collections.namedtuple(
+ 'DeprecatedArgSpec', ['position', 'has_ok_value', 'ok_value'])
+
+
+def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
"""Decorator for marking specific function arguments as deprecated.
This decorator logs a deprecation warning whenever the decorated function is
@@ -135,32 +140,77 @@ def deprecated_args(date, instructions, *deprecated_arg_names):
ISO 8601 (YYYY-MM-DD).
instructions: String. Instructions on how to update code using the
deprecated function.
- *deprecated_arg_names: String. The deprecated arguments.
+    *deprecated_arg_names_or_tuples: String or 2-tuple (String,
+      ok_val).  The string is the deprecated argument name.
+      Optionally, an ok-value may be provided.  If the user-provided
+      argument equals this value, the warning is suppressed.
Returns:
Decorated function or method.
Raises:
- ValueError: If date is not in ISO 8601 format, instructions are empty, or
- the deprecated arguments are not present in the function signature.
+    ValueError: If date is not in ISO 8601 format, instructions are
+      empty, or the deprecated arguments are not present in the
+      function signature.
"""
_validate_deprecation_args(date, instructions)
- if not deprecated_arg_names:
+ if not deprecated_arg_names_or_tuples:
raise ValueError('Specify which argument is deprecated.')
+ def _get_arg_names_to_ok_vals():
+ """Returns a dict mapping arg_name to DeprecatedArgSpec w/o position."""
+ d = {}
+ for name_or_tuple in deprecated_arg_names_or_tuples:
+ if isinstance(name_or_tuple, tuple):
+ d[name_or_tuple[0]] = DeprecatedArgSpec(-1, True, name_or_tuple[1])
+ else:
+ d[name_or_tuple] = DeprecatedArgSpec(-1, False, None)
+ return d
+
+ def _get_deprecated_positional_arguments(names_to_ok_vals, arg_spec):
+ """Builds a dictionary from deprecated arguments to thier spec.
+
+ Returned dict is keyed by argument name.
+ Each value is a DeprecatedArgSpec with the following fields:
+ position: The zero-based argument position of the argument
+ within the signature. Arguments that do not appear in the
+ signature are omitted from the returned dict.
+ has_ok_value: Whether an ok-value was supplied for this argument.
+ ok_value: Value of this argument for which the warning will be
+ suppressed.
+
+ Args:
+ names_to_ok_vals: dict from string arg_name to a DeprecatedArgSpec
+ (position unset) describing the optional ok-value that should
+ not elicit a warning.
+ arg_spec: Output from inspect.getargspec on the called function.
+
+ Returns:
+ Dictionary from arg_name to DeprecatedArgSpec.
+ """
+ arg_name_to_pos = dict(
+ (name, pos) for (pos, name) in enumerate(arg_spec.args))
+ deprecated_positional_args = {}
+ for arg_name, spec in iter(names_to_ok_vals.items()):
+ if arg_name in arg_name_to_pos:
+ pos = arg_name_to_pos[arg_name]
+ deprecated_positional_args[arg_name] = DeprecatedArgSpec(
+ pos, spec.has_ok_value, spec.ok_value)
+ return deprecated_positional_args
+
def deprecated_wrapper(func):
"""Deprecation decorator."""
decorator_utils.validate_callable(func, 'deprecated_args')
+ deprecated_arg_names = _get_arg_names_to_ok_vals()
arg_spec = inspect.getargspec(func)
- deprecated_positions = [
- (i, arg_name) for (i, arg_name) in enumerate(arg_spec.args)
- if arg_name in deprecated_arg_names]
+ deprecated_positions = _get_deprecated_positional_arguments(
+ deprecated_arg_names, arg_spec)
+
is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names
is_kwargs_deprecated = arg_spec.keywords in deprecated_arg_names
if (len(deprecated_positions) + is_varargs_deprecated + is_kwargs_deprecated
- != len(deprecated_arg_names)):
+ != len(deprecated_arg_names_or_tuples)):
known_args = arg_spec.args + [arg_spec.varargs, arg_spec.keywords]
missing_args = [arg_name for arg_name in deprecated_arg_names
if arg_name not in known_args]
@@ -172,15 +222,21 @@ def deprecated_args(date, instructions, *deprecated_arg_names):
def new_func(*args, **kwargs):
"""Deprecation wrapper."""
invalid_args = []
- for (i, arg_name) in deprecated_positions:
- if i < len(args):
+ named_args = inspect.getcallargs(func, *args, **kwargs)
+ for arg_name, spec in iter(deprecated_positions.items()):
+ if (spec.position < len(args) and
+ not (spec.has_ok_value and
+ named_args[arg_name] == spec.ok_value)):
invalid_args.append(arg_name)
if is_varargs_deprecated and len(args) > len(arg_spec.args):
invalid_args.append(arg_spec.varargs)
if is_kwargs_deprecated and kwargs:
invalid_args.append(arg_spec.keywords)
for arg_name in deprecated_arg_names:
- if arg_name in kwargs:
+ if (arg_name in kwargs and
+ not (deprecated_positions[arg_name].has_ok_value and
+ (named_args[arg_name] ==
+ deprecated_positions[arg_name].ok_value))):
invalid_args.append(arg_name)
for arg_name in invalid_args:
logging.warning(
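A minimal usage sketch of the new tuple form accepted above; the function, argument names, and values are illustrative. A plain string marks an argument as always warned about, while a (name, ok_value) tuple suppresses the warning when the caller passes exactly that value:

    from tensorflow.python.util.deprecation import deprecated_args

    @deprecated_args("2016-11-30", "Use `width` instead.",
                     "w", ("h", None))
    def resize(image, w=None, h=None, width=None):
      return image  # illustrative body

    image = "fake image"
    resize(image, w=3)      # warns: `w` is deprecated with no ok-value
    resize(image, h=None)   # silent: `h` equals its ok-value, None
    resize(image, width=3)  # silent: no deprecated argument passed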
diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py
index 791593189f..75bd054d7f 100644
--- a/tensorflow/python/util/deprecation_test.py
+++ b/tensorflow/python/util/deprecation_test.py
@@ -538,6 +538,39 @@ class DeprecatedArgsTest(tf.test.TestCase):
self.assertRegexpMatches(args1[0], r"deprecated and will be removed after")
self._assert_subset(set([date, instructions, "d2"]), set(args2[1:]))
+ @tf.test.mock.patch.object(logging, "warning", autospec=True)
+ def test_positional_and_named_with_ok_vals(self, mock_warning):
+ date = "2016-07-04"
+ instructions = "This is how you update..."
+
+ @deprecation.deprecated_args(
+ date,
+ instructions,
+ ("d1", None),
+ ("d2", "my_ok_val"))
+ def _fn(arg0, d1=None, arg1=2, d2=None):
+ return arg0 + arg1 if d1 else arg1 + arg0 if d2 else arg0 * arg1
+
+ # Assert calls without the deprecated arguments log nothing.
+ self.assertEqual(2, _fn(1, arg1=2))
+ self.assertEqual(0, mock_warning.call_count)
+
+ # Assert calls with the deprecated arguments log warnings.
+ self.assertEqual(2, _fn(1, False, 2, d2=False))
+ self.assertEqual(2, mock_warning.call_count)
+ (args1, _) = mock_warning.call_args_list[0]
+ self.assertRegexpMatches(args1[0], r"deprecated and will be removed after")
+ self._assert_subset(set([date, instructions, "d1"]), set(args1[1:]))
+ (args2, _) = mock_warning.call_args_list[1]
+ self.assertRegexpMatches(args2[0], r"deprecated and will be removed after")
+ self._assert_subset(set([date, instructions, "d2"]), set(args2[1:]))
+
+ # Assert calls with the deprecated arguments don't log warnings if
+ # the value matches the ok-value.
+ mock_warning.reset_mock()
+ self.assertEqual(3, _fn(1, None, 2, d2="my_ok_val"))
+ self.assertEqual(0, mock_warning.call_count)
+
class DeprecatedArgValuesTest(tf.test.TestCase):
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index 322d33ae26..6b31325694 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -349,31 +349,12 @@ bool CUDAExecutor::GetKernelMetadata(CUDAKernel *cuda_kernel,
bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &kernel,
- const std::vector<KernelArg> &args) {
- CHECK_EQ(kernel.Arity(), args.size());
+ const KernelArgsArrayBase &args) {
+ CHECK_EQ(kernel.Arity(), args.number_of_arguments());
CUstream custream = AsCUDAStreamValue(stream);
const CUDAKernel *cuda_kernel = AsCUDAKernel(&kernel);
CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue();
- std::vector<void *> addrs;
- addrs.reserve(args.size());
- int shmem_bytes = 0;
- for (size_t i = 0; i < args.size(); i++) {
- switch (args[i].type) {
- case KernelArg::kNormal:
- addrs.push_back(const_cast<void *>(
- static_cast<const void *>(args[i].data.begin())));
- break;
- case KernelArg::kSharedMemory:
- shmem_bytes += args[i].bytes;
- break;
- default:
- LOG(ERROR) << "Invalid kernel arg type passed (" << args[i].type
- << ") for arg " << i;
- return false;
- }
- }
-
// Only perform/print the occupancy check 1x.
launched_kernels_mu_.lock();
if (launched_kernels_.find(cufunc) == launched_kernels_.end()) {
@@ -389,11 +370,15 @@ bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
CUDADriver::FuncSetCacheConfig(cufunc, cuda_kernel->GetCUDACacheConfig());
}
- if (!CUDADriver::LaunchKernel(
- GetCudaContext(stream), cufunc, block_dims.x, block_dims.y,
- block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
- shmem_bytes, custream, addrs.data(), nullptr /* = extra */)) {
- LOG(ERROR) << "failed to launch CUDA kernel with args: " << args.size()
+ void **kernel_params = const_cast<void **>(args.argument_addresses().data());
+
+ if (!CUDADriver::LaunchKernel(GetCudaContext(stream), cufunc, block_dims.x,
+ block_dims.y, block_dims.z, thread_dims.x,
+ thread_dims.y, thread_dims.z,
+ args.number_of_shared_bytes(), custream,
+ kernel_params, nullptr /* = extra */)) {
+ LOG(ERROR) << "failed to launch CUDA kernel with args: "
+ << args.number_of_arguments()
<< "; thread dim: " << thread_dims.ToString()
<< "; block dim: " << block_dims.ToString();
return false;
@@ -849,18 +834,6 @@ bool CUDAExecutor::FillBlockDimLimit(BlockDim *block_dim_limit) const {
return true;
}
-KernelArg CUDAExecutor::DeviceMemoryToKernelArg(
- const DeviceMemoryBase &gpu_mem) const {
- const void* arg = gpu_mem.opaque();
- const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg);
-
- KernelArg kernel_arg;
- kernel_arg.type = KernelArg::kNormal;
- kernel_arg.data = port::InlinedVector<uint8, 4>(arg_ptr, arg_ptr + sizeof(arg));
- kernel_arg.bytes = sizeof(arg);
- return kernel_arg;
-}
-
bool CUDAExecutor::SupportsBlas() const { return true; }
bool CUDAExecutor::SupportsFft() const { return true; }
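The CUDA path above needs only the flat address list plus the shared-byte total, but a backend that must place each shared-memory argument individually can walk the arguments one by one. A hedged sketch using the iterator added in kernel.h below; the function name and LOG calls are illustrative, not part of this patch:

    // Sketch: enumerate a KernelArgsArrayBase the way an OpenCL-style
    // executor might, distinguishing shared-memory slots from normal
    // packed arguments.
    void DumpKernelArgs(const perftools::gputools::KernelArgsArrayBase &args) {
      perftools::gputools::KernelArgIterator it = args.arg_iterator();
      while (it.has_next()) {
        perftools::gputools::KernelArg arg = it.next();
        if (arg.is_shared) {
          LOG(INFO) << "shared-memory arg, size = " << arg.size;
        } else {
          LOG(INFO) << "normal arg at " << arg.address
                    << ", size = " << arg.size;
        }
      }
    }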
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 9e01f48781..3959d04439 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -76,7 +76,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &k,
- const std::vector<KernelArg> &args) override;
+ const KernelArgsArrayBase &args) override;
void *Allocate(uint64 size) override;
@@ -186,9 +186,6 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
// will be only partially populated as a result, and an error will be logged.
bool FillBlockDimLimit(BlockDim *block_dim_limit) const;
- KernelArg DeviceMemoryToKernelArg(
- const DeviceMemoryBase &gpu_mem) const override;
-
bool SupportsBlas() const override;
blas::BlasSupport *CreateBlas() override;
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
index 7742e066c7..3e5453e4c9 100644
--- a/tensorflow/stream_executor/kernel.h
+++ b/tensorflow/stream_executor/kernel.h
@@ -76,9 +76,10 @@ limitations under the License.
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
+#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/port.h"
-#include "tensorflow/stream_executor/lib/inlined_vector.h"
namespace perftools {
namespace gputools {
@@ -265,24 +266,220 @@ struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
static constexpr bool value = true;
};
-// KernelArg encapsulates the information necessary for a back-end executor to
-// configure a kernel to launch using the given argument.
+// Basic data about a kernel argument.
struct KernelArg {
- // Indicates the type of an argument: normal, to be passed to the kernel
- // in the standard manner, or shared memory, which has distinct
- // rules for specification per backend.
- enum Type {
- kNormal,
- kSharedMemory,
- } type;
-
- // The data to pass to the kernel - either a pointer to device memory, or the
- // argument value. compact_array is used to prevent smaller args (ex. u8, u64)
- // from requiring heap allocation.
- port::InlinedVector<uint8, 4> data;
-
- // The size of this argument in bytes.
- uint64 bytes;
+ bool is_shared;
+ const void *address;
+ size_t size;
+};
+
+// An iterator for traversing all the arguments of a KernelArgsArray.
+class KernelArgIterator {
+ public:
+ KernelArgIterator(int number_of_argument_addresses,
+ int number_of_shared_memory_arguments,
+ const void *const *arg_addresses_data,
+ const size_t *arg_sizes_data,
+ const size_t *shmem_bytes_data,
+ const size_t *shmem_indices_data)
+ : arg_index_(0),
+ number_of_arguments_(number_of_argument_addresses +
+ number_of_shared_memory_arguments),
+ arg_address_iter_(arg_addresses_data),
+ arg_size_iter_(arg_sizes_data),
+ shmem_bytes_iter_(shmem_bytes_data),
+ shmem_indices_iter_(shmem_indices_data),
+ shmem_indices_end_(shmem_indices_data +
+ number_of_shared_memory_arguments) {}
+
+ // Returns true if another argument is present in the iterator.
+ bool has_next() { return arg_index_ < number_of_arguments_; }
+
+ // Returns the next argument in the iterator.
+ //
+ // Returns a default-constructed KernelArg if there is no next argument.
+ KernelArg next() {
+ KernelArg result = {};
+ if (!has_next()) {
+ return result;
+ } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
+ (arg_index_ == *shmem_indices_iter_)) {
+ result.is_shared = true;
+ result.address = nullptr;
+ result.size = *shmem_bytes_iter_;
+ ++shmem_indices_iter_;
+ ++shmem_bytes_iter_;
+ } else {
+ result.is_shared = false;
+ result.address = *arg_address_iter_;
+ result.size = *arg_size_iter_;
+ ++arg_address_iter_;
+ ++arg_size_iter_;
+ }
+ ++arg_index_;
+ return result;
+ }
+
+ private:
+ int arg_index_;
+ int number_of_arguments_;
+ const void *const *arg_address_iter_;
+ const size_t *arg_size_iter_;
+ const size_t *shmem_bytes_iter_;
+ const size_t *shmem_indices_iter_;
+ const size_t *const shmem_indices_end_;
+};
+
+// Base class for KernelArgsArray.
+//
+// Supports all the getter methods that do not depend on the compile-time number
+// of arguments template parameter.
+//
+// This class exists as a way to pass kernel arguments to
+// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
+// be templated to accept any KernelArgsArray type, therefore a reference to this
+// base type is passed instead.
+//
+// Performance is not a concern here because each of these methods will be
+// called at most once per kernel launch. Past performance concerns with
+// KernelArgsArray have been in reference to the argument packing routines which
+// are called once per kernel argument. Those packing routines are now handled
+// by the templated KernelArgsArray subclass of this class where they can take
+// advantage of compile-time knowledge of the number of arguments in order to be
+// very efficient.
+class KernelArgsArrayBase {
+ public:
+ virtual ~KernelArgsArrayBase() = default;
+
+ // Gets the number of arguments added so far, including shared memory
+ // arguments.
+ virtual size_t number_of_arguments() const = 0;
+
+ // Gets the total number of shared memory bytes added so far.
+ virtual uint64 number_of_shared_bytes() const = 0;
+
+ // Gets the list of argument addresses.
+ virtual port::ArraySlice<const void *> argument_addresses() const = 0;
+
+ // Gets an iterator to the arguments in the array.
+ virtual KernelArgIterator arg_iterator() const = 0;
+};
+
+// A list of arguments for a kernel call.
+//
+// The template parameter kNumArgs is the maximum number of arguments which can
+// be stored in the list.
+//
+// Contains a list of addresses for non-shared-memory arguments and a list of
+// sizes for shared-memory arguments. Since the shared-memory arguments may be
+// interspersed with the non-shared-memory arguments, it also stores a list of
+// the indices at which the shared-memory arguments appeared.
+//
+// For example, if the argument address list contains {a, b, c, d, e}, the
+// shared-memory arguments list contains the sizes of {A, B, C}, and the
+// shared-memory indices list contains {0, 3, 5}, then the original list of
+// arguments was {A, a, b, B, c, C, d, e}.
+//
+// This way of storing the arguments makes CUDA kernel calls efficient because
+// they only require the argument address list and the total number of shared
+// bytes, but it also supports OpenCL kernel calls, which depend on the
+// location and size of each shared-memory argument.
+//
+// Note that the code for adding arguments has been identified as a performance
+// hotspot in some real-world applications so this structure has been optimized
+// for the performance of argument adding.
+template <size_t kNumArgs>
+class KernelArgsArray : public KernelArgsArrayBase {
+ public:
+ explicit KernelArgsArray()
+ : total_shared_memory_bytes_(0),
+ number_of_argument_addresses_(0),
+ number_of_shared_memory_arguments_(0) {}
+
+ // Adds an argument to the list.
+ //
+ // Note that the address of the argument is stored, so the input must not go
+ // out of scope before the KernelArgsArray instance it was added to does.
+ template <typename T>
+ void add_argument(const T &arg) {
+ argument_addresses_[number_of_argument_addresses_] =
+ static_cast<const void *>(&arg);
+ argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
+ ++number_of_argument_addresses_;
+ }
+
+ // Adds a device memory argument to the list.
+ void add_device_memory_argument(const DeviceMemoryBase &arg) {
+ const void **copy_ptr =
+ &device_memory_opaque_pointers_[number_of_argument_addresses_];
+ *copy_ptr = arg.opaque();
+ argument_addresses_[number_of_argument_addresses_] = copy_ptr;
+ argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
+ ++number_of_argument_addresses_;
+ }
+
+ // Adds a shared memory argument to the list.
+ //
+ // The only significant information about a shared argument is its size, so
+ // that is the only parameter in this function.
+ void add_shared_bytes(size_t number_of_bytes) {
+ shared_memory_indices_[number_of_shared_memory_arguments_] =
+ number_of_argument_addresses_ + number_of_shared_memory_arguments_;
+ shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
+ ++number_of_shared_memory_arguments_;
+ total_shared_memory_bytes_ += number_of_bytes;
+ }
+
+ // Gets the number of arguments added so far, including shared memory
+ // arguments.
+ size_t number_of_arguments() const override {
+ return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
+ }
+
+ // Gets the total number of shared memory bytes added so far.
+ uint64 number_of_shared_bytes() const override {
+ return total_shared_memory_bytes_;
+ }
+
+ // Gets the list of argument addresses.
+ port::ArraySlice<const void *> argument_addresses() const override {
+ return port::ArraySlice<const void *>(argument_addresses_.data(),
+ number_of_argument_addresses_);
+ }
+
+ // Gets an iterator to the arguments in the array.
+ KernelArgIterator arg_iterator() const override {
+ return KernelArgIterator(
+ number_of_argument_addresses_, number_of_shared_memory_arguments_,
+ argument_addresses_.data(), argument_sizes_.data(),
+ shared_memory_bytes_.data(), shared_memory_indices_.data());
+ }
+
+ private:
+ // A place to store copies of opaque pointers from device memory arguments.
+ std::array<const void *, kNumArgs> device_memory_opaque_pointers_;
+
+ // Addresses for non-shared-memory arguments.
+ std::array<const void *, kNumArgs> argument_addresses_;
+
+ // Sizes for non-shared-memory arguments.
+ std::array<size_t, kNumArgs> argument_sizes_;
+
+ // Size in bytes for each shared memory argument.
+ std::array<size_t, kNumArgs> shared_memory_bytes_;
+
+ // Indices in the arguments array for shared memory arguments.
+ std::array<size_t, kNumArgs> shared_memory_indices_;
+
+ // Total of all shared memory sizes.
+ size_t total_shared_memory_bytes_;
+
+ // Number of significant entries in argument_addresses_ and argument_sizes_.
+ size_t number_of_argument_addresses_;
+
+ // Number of significant entries in shared_memory_bytes_ and
+ // shared_memory_indices_.
+ size_t number_of_shared_memory_arguments_;
};
// Typed variant of KernelBase, like a typed device function pointer. See the
@@ -298,6 +495,8 @@ struct KernelArg {
template <typename... Params>
class TypedKernel : public KernelBase {
public:
+ static constexpr size_t kNumberOfParameters = sizeof...(Params);
+
// Delegates to KernelBase::KernelBase(), see that constructor.
explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}
@@ -318,13 +517,19 @@ class TypedKernel : public KernelBase {
//
// Const refs are taken as parameters on all of the handlers to avoid
// implicit type promotion of integers.
- void PackParams(std::vector<KernelArg> *args, Params... params) const {
+ //
+ // WARNING: as a performance optimization this method may store pointers to
+ // some of the input parameters in the kernel args structure, so any params
+ // passed into this method must live at least as long as the kernel args
+ // structure.
+ void PackParams(KernelArgsArray<kNumberOfParameters> *args,
+ Params &... params) const {
PackOneParam(args, params...);
}
template <typename T, typename... RestOfParams>
- void PackOneParam(std::vector<KernelArg> *args, const T &arg,
- const RestOfParams... rest) const {
+ void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg,
+ const RestOfParams &... rest) const {
PackOneParam(args, arg);
PackOneParam(args, rest...);
}
@@ -334,7 +539,7 @@ class TypedKernel : public KernelBase {
// separate implementation below.
template <typename T>
void PackOneParam(
- std::vector<KernelArg> *args, const T &arg,
+ KernelArgsArray<kNumberOfParameters> *args, const T &arg,
typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
!IsDeviceMemoryPointer<T>::value &&
!IsSharedDeviceMemory<T>::value>::type * =
@@ -343,44 +548,40 @@ class TypedKernel : public KernelBase {
"cannot pass raw pointer to the device");
static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
"cannot pass device memory as a normal value");
- const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg);
- args->emplace_back(KernelArg{
- KernelArg::kNormal,
- port::InlinedVector<uint8, 4>{arg_ptr, arg_ptr + sizeof(arg)}, sizeof(arg)});
+ args->add_argument(arg);
}
// DeviceMemoryBase family reference override.
template <typename T>
void PackOneParam(
- std::vector<KernelArg> *args, const T &arg,
+ KernelArgsArray<kNumberOfParameters> *args, const T &arg,
typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
nullptr) const {
- args->emplace_back(parent()->DeviceMemoryToKernelArg(arg));
+ args->add_device_memory_argument(arg);
}
// DeviceMemoryBase family pointer override.
template <typename T>
void PackOneParam(
- std::vector<KernelArg> *args, T arg,
+ KernelArgsArray<kNumberOfParameters> *args, T arg,
typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
nullptr) const {
DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
- args->emplace_back(parent()->DeviceMemoryToKernelArg(*ptr));
+ args->add_device_memory_argument(*ptr);
}
// Dynamic shared device memory has a size, but no associated allocation on
// the host; internally, the device will allocate storage.
template <typename T>
void PackOneParam(
- std::vector<KernelArg> *args, T arg,
+ KernelArgsArray<kNumberOfParameters> *args, T arg,
typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
nullptr) const {
- args->emplace_back(KernelArg{KernelArg::kSharedMemory,
- port::InlinedVector<uint8, 4>(), arg.size()});
+ args->add_shared_bytes(arg.size());
}
// Base case for variadic template expansion - nothing to do!
- void PackOneParam(std::vector<KernelArg> *args) const {}
+ void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {}
SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};
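A small sketch of what the packing path above produces. PackParams fills a KernelArgsArray whose capacity equals the kernel's parameter count; the same container can also be filled by hand. The values here are purely illustrative:

    #include <cassert>

    namespace gpu = perftools::gputools;

    void PackingSketch() {
      float scale = 0.5f;
      int count = 128;
      gpu::KernelArgsArray<3> args;   // capacity for three entries
      args.add_argument(scale);       // stores &scale and sizeof(float)
      args.add_argument(count);       // stores &count and sizeof(int)
      args.add_shared_bytes(1024);    // one dynamic shared-memory slot
      assert(args.number_of_arguments() == 3);       // shared slots count
      assert(args.number_of_shared_bytes() == 1024);
      assert(args.argument_addresses().size() == 2); // non-shared only
    }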
diff --git a/tensorflow/stream_executor/stream_executor.h b/tensorflow/stream_executor/stream_executor.h
index dd4664849d..2995dccf46 100644
--- a/tensorflow/stream_executor/stream_executor.h
+++ b/tensorflow/stream_executor/stream_executor.h
@@ -18,34 +18,6 @@ limitations under the License.
// * Loading/launching data-parallel-kernels
// * Invoking pre-canned high-performance library routines (like matrix
// multiply)
-//
-// The appropriately-typed kernel and "loader spec" are automatically generated
-// for the user within a namespace by the gcudacc compiler output, so typical
-// use looks like so:
-//
-// namespace gpu = ::perftools::gputools;
-// namespace gcudacc = ::platforms::gpus::gcudacc;
-//
-// gpu::StreamExecutor stream_exec{PlatformKind::kCuda};
-// gcudacc::kernel::MyKernel my_kernel{&stream_exec};
-// bool ok = stream_exec.GetKernel(gcudacc::spec::MyKernelSpec(),
-// &my_kernel);
-// if (!ok) { ... }
-// gpu::DeviceMemory<int> result = stream_exec.AllocateZeroed<int>();
-// if (result == nullptr) { ... }
-// int host_result;
-// gpu::Stream my_stream{&stream_exec};
-// my_stream
-// .Init()
-// .ThenLaunch(ThreadDim{1024}, BlockDim{1}, my_kernel, result)
-// .ThenMemcpy(&host_result, result, sizeof(host_result))
-// .BlockHostUntilDone()
-// if (!my_stream.ok()) { ... }
-// printf("%d\n", host_result);
-//
-// Since the device may operate asynchronously to the host, the
-// Stream::BlockHostUntilDone() call forces the calling host thread to wait for
-// the chain of commands specified for the Stream to complete execution.
#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index acdbb07cb7..57db7775a6 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -184,7 +184,7 @@ class StreamExecutorInterface {
}
virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &k,
- const std::vector<KernelArg> &args) {
+ const KernelArgsArrayBase &args) {
return false;
}
virtual void *Allocate(uint64 size) = 0;
@@ -258,9 +258,6 @@ class StreamExecutorInterface {
// caller.
virtual DeviceDescription *PopulateDeviceDescription() const = 0;
- virtual KernelArg DeviceMemoryToKernelArg(
- const DeviceMemoryBase &gpu_mem) const = 0;
-
// Attempts to register the provided TraceListener with the device-specific
// Executor implementation. When this is called, the PIMPL interface has
// already taken ownership of the object and is managing the generic tracing
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 2fdd1e4b49..7739d31662 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -394,7 +394,7 @@ rng::RngSupport *StreamExecutor::AsRng() {
bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims,
const KernelBase &kernel,
- const std::vector<KernelArg> &args) {
+ const KernelArgsArrayBase &args) {
SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
kernel, args);
@@ -659,11 +659,6 @@ bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
return implementation_->DeviceMemoryUsage(free, total);
}
-KernelArg StreamExecutor::DeviceMemoryToKernelArg(
- const DeviceMemoryBase &gpu_mem) const {
- return implementation_->DeviceMemoryToKernelArg(gpu_mem);
-}
-
void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
background_threads_->Schedule(task);
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 2b5a70f807..83fd27599e 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -392,7 +392,7 @@ class StreamExecutor {
// implementation in StreamExecutorInterface::Launch().
bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &kernel,
- const std::vector<KernelArg> &args);
+ const KernelArgsArrayBase &args);
// Gets-or-creates (creates with memoization) a FftSupport datatype that can
// be used to execute FFT routines on the current platform.
@@ -427,10 +427,6 @@ class StreamExecutor {
// previously registered.
bool UnregisterTraceListener(TraceListener* listener);
- // Converts a DeviceMemory object into a KernelArg object for passing to the
- // device driver for kernel launch.
- KernelArg DeviceMemoryToKernelArg(const DeviceMemoryBase &gpu_mem) const;
-
private:
template <typename BeginCallT, typename CompleteCallT,
typename ReturnT, typename... BeginArgsT>
@@ -758,9 +754,9 @@ inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
// we pack the variadic parameters passed as ...args into the desired
// tuple form and pass that packed form to the StreamExecutor::Launch()
// implementation.
- std::vector<KernelArg> kernel_args;
- kernel_args.reserve(kernel.Arity());
+ KernelArgsArray<sizeof...(args)> kernel_args;
kernel.PackParams(&kernel_args, args...);
+ DCHECK(parent_ != nullptr);
bool ok =
parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
if (!ok) {
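Callers of Stream::ThenLaunch are unchanged by the edit above; the packed arguments now live in a fixed-size stack array instead of a heap-backed vector. A hedged usage sketch in the spirit of the doc comment removed from stream_executor.h; the kernel type, dimensions, and function name are illustrative:

    namespace gpu = perftools::gputools;

    // Sketch: ThenLaunch internally builds a KernelArgsArray<2> (one
    // slot per kernel parameter) and packs `data` and `n` into it.
    void LaunchSketch(gpu::StreamExecutor *executor,
                      const gpu::TypedKernel<gpu::DeviceMemory<float>, int> &kernel,
                      gpu::DeviceMemory<float> data, int n) {
      gpu::Stream stream(executor);
      stream.Init()
          .ThenLaunch(gpu::ThreadDim{256}, gpu::BlockDim{1}, kernel, data, n)
          .BlockHostUntilDone();
      if (!stream.ok()) {
        LOG(ERROR) << "kernel launch failed";
      }
    }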
diff --git a/tensorflow/stream_executor/trace_listener.h b/tensorflow/stream_executor/trace_listener.h
index 804c6ee8fa..88c54f982b 100644
--- a/tensorflow/stream_executor/trace_listener.h
+++ b/tensorflow/stream_executor/trace_listener.h
@@ -50,7 +50,7 @@ class TraceListener {
virtual void LaunchSubmit(Stream* stream, const ThreadDim& thread_dims,
const BlockDim& block_dims,
const KernelBase& kernel,
- const std::vector<KernelArg>& args) {}
+ const KernelArgsArrayBase& args) {}
virtual void SynchronousMemcpyH2DBegin(int64 correlation_id,
const void* host_src, int64 size,
diff --git a/tensorflow/tensorboard/backend/server_test.py b/tensorflow/tensorboard/backend/server_test.py
index 3dd7843e66..596fff2864 100644
--- a/tensorflow/tensorboard/backend/server_test.py
+++ b/tensorflow/tensorboard/backend/server_test.py
@@ -333,7 +333,7 @@ class TensorboardServerTest(tf.test.TestCase):
self.addCleanup(shutil.rmtree, temp_dir)
run1_path = os.path.join(temp_dir, 'run1')
os.makedirs(run1_path)
- writer = tf.train.SummaryWriter(run1_path)
+ writer = tf.summary.FileWriter(run1_path)
histogram_value = tf.HistogramProto(min=0,
max=2,
diff --git a/tensorflow/tensorboard/components/vz_projector/data.ts b/tensorflow/tensorboard/components/vz_projector/data.ts
index 34f275f546..4b5cb7a687 100644
--- a/tensorflow/tensorboard/components/vz_projector/data.ts
+++ b/tensorflow/tensorboard/components/vz_projector/data.ts
@@ -22,8 +22,7 @@ import {getSearchPredicate, runAsyncTask, shuffle} from './util';
import * as vector from './vector';
export type DistanceFunction = (a: number[], b: number[]) => number;
-export type PointAccessor = (index: number) => number;
-export type PointAccessors3D = [PointAccessor, PointAccessor, PointAccessor];
+export type ProjectionComponents3D = [string, string, string];
export interface PointMetadata { [key: string]: number|string; }
@@ -187,25 +186,6 @@ export class DataSet {
return traces;
}
- getPointAccessors(projection: ProjectionType, components: (number|string)[]):
- [PointAccessor, PointAccessor, PointAccessor] {
- if (components.length > 3) {
- throw new RangeError('components length must be <= 3');
- }
- const accessors: [PointAccessor, PointAccessor, PointAccessor] =
- [null, null, null];
- const prefix = (projection === 'custom') ? 'linear' : projection;
- for (let i = 0; i < components.length; ++i) {
- if (components[i] == null) {
- continue;
- }
- accessors[i] =
- (index =>
- this.points[index].projections[prefix + '-' + components[i]]);
- }
- return accessors;
- }
-
projectionCanBeRendered(projection: ProjectionType): boolean {
if (projection !== 'tsne') {
return true;
@@ -222,8 +202,9 @@ export class DataSet {
* @return A subset of the original dataset.
*/
getSubset(subset?: number[]): DataSet {
- let pointsSubset =
- subset && subset.length ? subset.map(i => this.points[i]) : this.points;
+ const pointsSubset = ((subset != null) && (subset.length > 0)) ?
+ subset.map(i => this.points[i]) :
+ this.points;
let points = pointsSubset.map(dp => {
return {
metadata: dp.metadata,
@@ -302,12 +283,13 @@ export class DataSet {
}
return newV;
});
- for (let j = 0; j < NUM_PCA_COMPONENTS; j++) {
- let label = 'pca-' + j;
+ for (let d = 0; d < NUM_PCA_COMPONENTS; d++) {
+ let label = 'pca-' + d;
this.projections.add(label);
- this.points.forEach((d, i) => {
- d.projections[label] = pcaVectors[i][j];
- });
+ for (let i = 0; i < pcaVectors.length; i++) {
+ let pointIndex = this.shuffledDataIndices[i];
+ this.points[pointIndex].projections[label] = pcaVectors[i][d];
+ }
}
});
}
@@ -418,8 +400,8 @@ export type ProjectionType = 'tsne' | 'pca' | 'custom';
export class Projection {
constructor(
public projectionType: ProjectionType,
- public pointAccessors: PointAccessors3D, public dimensionality: number,
- public dataSet: DataSet) {}
+ public projectionComponents: ProjectionComponents3D,
+ public dimensionality: number, public dataSet: DataSet) {}
}
export interface ColorOption {
@@ -489,6 +471,23 @@ export class State {
selectedLabelOption: string;
}
+export function getProjectionComponents(
+ projection: ProjectionType,
+ components: (number|string)[]): ProjectionComponents3D {
+ if (components.length > 3) {
+ throw new RangeError('components length must be <= 3');
+ }
+ const projectionComponents: [string, string, string] = [null, null, null];
+ const prefix = (projection === 'custom') ? 'linear' : projection;
+ for (let i = 0; i < components.length; ++i) {
+ if (components[i] == null) {
+ continue;
+ }
+ projectionComponents[i] = prefix + '-' + components[i];
+ }
+ return projectionComponents;
+}
+
export function stateGetAccessorDimensions(state: State): Array<number|string> {
let dimensions: Array<number|string>;
switch (state.selectedProjection) {
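The free function above replaces DataSet.getPointAccessors: instead of closures it returns the string keys under which each component is stored in DataPoint.projections. A brief sketch of the mapping; the component values are illustrative:

    import {getProjectionComponents} from './data';

    // Three PCA components map to the 'pca-' + component keys that
    // the PCA code writes into each point's projections.
    const pca = getProjectionComponents('pca', [0, 1, 2]);
    // -> ['pca-0', 'pca-1', 'pca-2']

    // A 2D projection leaves the z slot null.
    const tsne2d = getProjectionComponents('tsne', [0, 1]);
    // -> ['tsne-0', 'tsne-1', null]

    // 'custom' projections are stored under the 'linear' prefix.
    const custom = getProjectionComponents('custom', ['x', 'y']);
    // -> ['linear-x', 'linear-y', null]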
diff --git a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts
index 7fa9924813..d00973935c 100644
--- a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts
+++ b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-import {DataSet, DistanceFunction, PointAccessors3D, Projection, State} from './data';
+import {DataSet, DistanceFunction, Projection, ProjectionComponents3D, State} from './data';
import {NearestEntry} from './knn';
import {ProjectorEventContext} from './projectorEventContext';
import {LabelRenderParams} from './renderContext';
@@ -82,11 +82,14 @@ export class ProjectorScatterPlotAdapter {
private selectedPointIndices: number[];
private neighborsOfFirstSelectedPoint: NearestEntry[];
private renderLabelsIn3D: boolean = false;
- private labelPointAccessor: (index: number) => string;
- private legendPointColorer: (index: number) => string;
+ private labelPointAccessor: (ds: DataSet, index: number) => string;
+ private legendPointColorer: (ds: DataSet, index: number) => string;
private distanceMetric: DistanceFunction;
+ private spriteVisualizer: ScatterPlotVisualizerSprites;
private labels3DVisualizer: ScatterPlotVisualizer3DLabels;
+ private canvasLabelsVisualizer: ScatterPlotVisualizerCanvasLabels;
+ private traceVisualizer: ScatterPlotVisualizerTraces;
constructor(
scatterPlotContainer: d3.Selection<any>,
@@ -102,6 +105,7 @@ export class ProjectorScatterPlotAdapter {
(selectedPointIndices, neighbors) => {
this.selectedPointIndices = selectedPointIndices;
this.neighborsOfFirstSelectedPoint = neighbors;
+ this.updateScatterPlotPositions();
this.updateScatterPlotAttributes();
this.scatterPlot.render();
});
@@ -124,6 +128,42 @@ export class ProjectorScatterPlotAdapter {
this.scatterPlot.render();
}
+ setDataSet(dataSet: DataSet) {
+ if (this.projection != null) {
+ // TODO(nicholsonc): setDataSet needs to go away, the projection is the
+ // atomic unit of update.
+ this.projection.dataSet = dataSet;
+ }
+ if (this.traceVisualizer != null) {
+ this.traceVisualizer.setDataSet(dataSet);
+ }
+ if (this.canvasLabelsVisualizer != null) {
+ this.canvasLabelsVisualizer.setDataSet(dataSet);
+ }
+ if (this.labels3DVisualizer != null) {
+ this.labels3DVisualizer.setDataSet(dataSet);
+ }
+ if (this.spriteVisualizer == null) {
+ return;
+ }
+ this.spriteVisualizer.clearSpriteAtlas();
+ if ((dataSet == null) || (dataSet.spriteAndMetadataInfo == null)) {
+ return;
+ }
+ const metadata = dataSet.spriteAndMetadataInfo;
+ if ((metadata.spriteImage == null) || (metadata.spriteMetadata == null)) {
+ return;
+ }
+ const n = dataSet.points.length;
+ const spriteIndices = new Float32Array(n);
+ for (let i = 0; i < n; ++i) {
+ spriteIndices[i] = dataSet.points[i].index;
+ }
+ this.spriteVisualizer.setSpriteAtlas(
+ metadata.spriteImage, metadata.spriteMetadata.singleImageDim,
+ spriteIndices);
+ }
+
set3DLabelMode(renderLabelsIn3D: boolean) {
this.renderLabelsIn3D = renderLabelsIn3D;
this.createVisualizers(renderLabelsIn3D);
@@ -131,14 +171,17 @@ export class ProjectorScatterPlotAdapter {
this.scatterPlot.render();
}
- setLegendPointColorer(legendPointColorer: (index: number) => string) {
+ setLegendPointColorer(
+ legendPointColorer: (ds: DataSet, index: number) => string) {
this.legendPointColorer = legendPointColorer;
}
- setLabelPointAccessor(labelPointAccessor: (index: number) => string) {
+ setLabelPointAccessor(
+ labelPointAccessor: (ds: DataSet, index: number) => string) {
this.labelPointAccessor = labelPointAccessor;
if (this.labels3DVisualizer != null) {
- this.labels3DVisualizer.setLabelAccessor(labelPointAccessor);
+ this.labels3DVisualizer.setLabelStrings(this.generate3DLabelsArray(
+ this.projection.dataSet, labelPointAccessor));
}
}
@@ -157,10 +200,11 @@ export class ProjectorScatterPlotAdapter {
updateScatterPlotPositions() {
const ds = (this.projection == null) ? null : this.projection.dataSet;
- const accessors =
- (this.projection == null) ? null : this.projection.pointAccessors;
- const newPositions = this.generatePointPositionArray(ds, accessors);
- this.scatterPlot.setPointPositions(ds, newPositions);
+ const projectionComponents =
+ (this.projection == null) ? null : this.projection.projectionComponents;
+ const newPositions =
+ this.generatePointPositionArray(ds, projectionComponents);
+ this.scatterPlot.setPointPositions(newPositions);
}
updateScatterPlotAttributes() {
@@ -198,10 +242,10 @@ export class ProjectorScatterPlotAdapter {
this.scatterPlot.render();
}
- generatePointPositionArray(ds: DataSet, pointAccessors: PointAccessors3D):
- Float32Array {
+ generatePointPositionArray(
+ ds: DataSet, projectionComponents: ProjectionComponents3D): Float32Array {
if (ds == null) {
- return new Float32Array(0);
+ return null;
}
const xScaler: d3.scale.Linear<number, number> = d3.scale.linear();
@@ -209,8 +253,12 @@ export class ProjectorScatterPlotAdapter {
let zScaler: d3.scale.Linear<number, number> = null;
{
// Determine max and min of each axis of our data.
- const xExtent = d3.extent(ds.points, (p, i) => pointAccessors[0](i));
- const yExtent = d3.extent(ds.points, (p, i) => pointAccessors[1](i));
+ const xExtent = d3.extent(
+ ds.points,
+ (p, i) => ds.points[i].projections[projectionComponents[0]]);
+ const yExtent = d3.extent(
+ ds.points,
+ (p, i) => ds.points[i].projections[projectionComponents[1]]);
const range =
[-SCATTER_PLOT_CUBE_LENGTH / 2, SCATTER_PLOT_CUBE_LENGTH / 2];
@@ -218,8 +266,10 @@ export class ProjectorScatterPlotAdapter {
xScaler.domain(xExtent).range(range);
yScaler.domain(yExtent).range(range);
- if (pointAccessors[2] != null) {
- const zExtent = d3.extent(ds.points, (p, i) => pointAccessors[2](i));
+ if (projectionComponents[2] != null) {
+ const zExtent = d3.extent(
+ ds.points,
+ (p, i) => ds.points[i].projections[projectionComponents[2]]);
zScaler = d3.scale.linear();
zScaler.domain(zExtent).range(range);
}
@@ -229,15 +279,18 @@ export class ProjectorScatterPlotAdapter {
let dst = 0;
ds.points.forEach((d, i) => {
- positions[dst++] = xScaler(pointAccessors[0](i));
- positions[dst++] = yScaler(pointAccessors[1](i));
+ positions[dst++] =
+ xScaler(ds.points[i].projections[projectionComponents[0]]);
+ positions[dst++] =
+ yScaler(ds.points[i].projections[projectionComponents[1]]);
positions[dst++] = 0.0;
});
if (zScaler) {
dst = 2;
ds.points.forEach((d, i) => {
- positions[dst] = zScaler(pointAccessors[2](i));
+ positions[dst] =
+ zScaler(ds.points[i].projections[projectionComponents[2]]);
dst += 3;
});
}
@@ -253,7 +306,11 @@ export class ProjectorScatterPlotAdapter {
return null;
}
- const n = selectedPointIndices.length + neighborsOfFirstPoint.length +
+ const selectedPointCount =
+ (selectedPointIndices == null) ? 0 : selectedPointIndices.length;
+ const neighborCount =
+ (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length;
+ const n = selectedPointCount + neighborCount +
((hoverPointIndex != null) ? 1 : 0);
const visibleLabels = new Uint32Array(n);
@@ -261,6 +318,7 @@ export class ProjectorScatterPlotAdapter {
const opacityFlags = new Int8Array(n);
const fillColors = new Uint8Array(n * 3);
const strokeColors = new Uint8Array(n * 3);
+ const labelStrings: string[] = [];
scale.fill(LABEL_SCALE_DEFAULT);
opacityFlags.fill(1);
@@ -268,6 +326,7 @@ export class ProjectorScatterPlotAdapter {
let dst = 0;
if (hoverPointIndex != null) {
+ labelStrings.push(this.labelPointAccessor(ds, hoverPointIndex));
visibleLabels[dst] = hoverPointIndex;
scale[dst] = LABEL_SCALE_LARGE;
opacityFlags[dst] = 0;
@@ -282,11 +341,13 @@ export class ProjectorScatterPlotAdapter {
// Selected points
{
- const n = selectedPointIndices.length;
+ const n = selectedPointCount;
const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED);
const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED);
for (let i = 0; i < n; ++i) {
- visibleLabels[dst] = selectedPointIndices[i];
+ const labelIndex = selectedPointIndices[i];
+ labelStrings.push(this.labelPointAccessor(ds, labelIndex));
+ visibleLabels[dst] = labelIndex;
scale[dst] = LABEL_SCALE_LARGE;
opacityFlags[dst] = (n === 1) ? 0 : 1;
packRgbIntoUint8Array(
@@ -299,11 +360,13 @@ export class ProjectorScatterPlotAdapter {
// Neighbors
{
- const n = neighborsOfFirstPoint.length;
+ const n = neighborCount;
const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR);
const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR);
for (let i = 0; i < n; ++i) {
- visibleLabels[dst] = neighborsOfFirstPoint[i].index;
+ const labelIndex = neighborsOfFirstPoint[i].index;
+ labelStrings.push(this.labelPointAccessor(ds, labelIndex));
+ visibleLabels[dst] = labelIndex;
packRgbIntoUint8Array(
fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
packRgbIntoUint8Array(
@@ -313,8 +376,8 @@ export class ProjectorScatterPlotAdapter {
}
return new LabelRenderParams(
- this.labelPointAccessor, visibleLabels, scale, opacityFlags,
- LABEL_FONT_SIZE, fillColors, strokeColors);
+ visibleLabels, labelStrings, scale, opacityFlags, LABEL_FONT_SIZE,
+ fillColors, strokeColors);
}
generatePointScaleFactorArray(
@@ -328,9 +391,14 @@ export class ProjectorScatterPlotAdapter {
const scale = new Float32Array(ds.points.length);
scale.fill(POINT_SCALE_DEFAULT);
+ const selectedPointCount =
+ (selectedPointIndices == null) ? 0 : selectedPointIndices.length;
+ const neighborCount =
+ (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length;
+
// Scale up all selected points.
{
- const n = selectedPointIndices.length;
+ const n = selectedPointCount;
for (let i = 0; i < n; ++i) {
const p = selectedPointIndices[i];
scale[p] = POINT_SCALE_SELECTED;
@@ -339,7 +407,7 @@ export class ProjectorScatterPlotAdapter {
// Scale up the neighbor points.
{
- const n = neighborsOfFirstPoint.length;
+ const n = neighborCount;
for (let i = 0; i < n; ++i) {
const p = neighborsOfFirstPoint[i].index;
scale[p] = POINT_SCALE_NEIGHBOR;
@@ -355,7 +423,7 @@ export class ProjectorScatterPlotAdapter {
}
generateLineSegmentColorMap(
- ds: DataSet, legendPointColorer: (index: number) => string):
+ ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string):
{[trace: number]: Float32Array} {
let traceColorArrayMap: {[trace: number]: Float32Array} = {};
if (ds == null) {
@@ -370,10 +438,10 @@ export class ProjectorScatterPlotAdapter {
if (legendPointColorer) {
for (let j = 0; j < dataTrace.pointIndices.length - 1; j++) {
- const c1 =
- new THREE.Color(legendPointColorer(dataTrace.pointIndices[j]));
+ const c1 = new THREE.Color(
+ legendPointColorer(ds, dataTrace.pointIndices[j]));
const c2 = new THREE.Color(
- legendPointColorer(dataTrace.pointIndices[j + 1]));
+ legendPointColorer(ds, dataTrace.pointIndices[j + 1]));
colors[colorIndex++] = c1.r;
colors[colorIndex++] = c1.g;
colors[colorIndex++] = c1.b;
@@ -408,7 +476,9 @@ export class ProjectorScatterPlotAdapter {
return new Float32Array(0);
}
const opacities = new Float32Array(ds.traces.length);
- if (selectedPoints.length > 0) {
+ const selectedPointCount =
+ (selectedPoints == null) ? 0 : selectedPoints.length;
+ if (selectedPointCount > 0) {
opacities.fill(TRACE_DESELECTED_OPACITY);
const i = ds.points[selectedPoints[0]].traceIndex;
opacities[i] = TRACE_SELECTED_OPACITY;
@@ -425,7 +495,9 @@ export class ProjectorScatterPlotAdapter {
}
const widths = new Float32Array(ds.traces.length);
widths.fill(TRACE_DEFAULT_LINEWIDTH);
- if (selectedPoints.length > 0) {
+ const selectedPointCount =
+ (selectedPoints == null) ? 0 : selectedPoints.length;
+ if (selectedPointCount > 0) {
const i = ds.points[selectedPoints[0]].traceIndex;
widths[i] = TRACE_SELECTED_LINEWIDTH;
}
@@ -433,7 +505,7 @@ export class ProjectorScatterPlotAdapter {
}
generatePointColorArray(
- ds: DataSet, legendPointColorer: (index: number) => string,
+ ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string,
distFunc: DistanceFunction, selectedPointIndices: number[],
neighborsOfFirstPoint: NearestEntry[], hoverPointIndex: number,
label3dMode: boolean, spriteImageMode: boolean): Float32Array {
@@ -441,6 +513,10 @@ export class ProjectorScatterPlotAdapter {
return new Float32Array(0);
}
+ const selectedPointCount =
+ (selectedPointIndices == null) ? 0 : selectedPointIndices.length;
+ const neighborCount =
+ (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length;
const colors = new Float32Array(ds.points.length * 3);
let unselectedColor = POINT_COLOR_UNSELECTED;
@@ -460,7 +536,7 @@ export class ProjectorScatterPlotAdapter {
{
const n = ds.points.length;
let dst = 0;
- if (selectedPointIndices.length > 0) {
+ if (selectedPointCount > 0) {
const c = new THREE.Color(unselectedColor);
for (let i = 0; i < n; ++i) {
colors[dst++] = c.r;
@@ -470,7 +546,7 @@ export class ProjectorScatterPlotAdapter {
} else {
if (legendPointColorer != null) {
for (let i = 0; i < n; ++i) {
- const c = new THREE.Color(legendPointColorer(i));
+ const c = new THREE.Color(legendPointColorer(ds, i));
colors[dst++] = c.r;
colors[dst++] = c.g;
colors[dst++] = c.b;
@@ -488,7 +564,7 @@ export class ProjectorScatterPlotAdapter {
// Color the selected points.
{
- const n = selectedPointIndices.length;
+ const n = selectedPointCount;
const c = new THREE.Color(POINT_COLOR_SELECTED);
for (let i = 0; i < n; ++i) {
let dst = selectedPointIndices[i] * 3;
@@ -500,7 +576,7 @@ export class ProjectorScatterPlotAdapter {
// Color the neighbors.
{
- const n = neighborsOfFirstPoint.length;
+ const n = neighborCount;
let minDist = n > 0 ? neighborsOfFirstPoint[0].dist : 0;
for (let i = 0; i < n; ++i) {
const c = new THREE.Color(
@@ -524,6 +600,19 @@ export class ProjectorScatterPlotAdapter {
return colors;
}
+ generate3DLabelsArray(
+ ds: DataSet, accessor: (ds: DataSet, i: number) => string) {
+ if ((ds == null) || (accessor == null)) {
+ return null;
+ }
+ let labels: string[] = [];
+ const n = ds.points.length;
+ for (let i = 0; i < n; ++i) {
+ labels.push(accessor(ds, i).toString());
+ }
+ return labels;
+ }
+
private updateScatterPlotWithNewProjection(projection: Projection) {
if (projection != null) {
this.scatterPlot.setDimensions(projection.dimensionality);
@@ -543,16 +632,32 @@ export class ProjectorScatterPlotAdapter {
const scatterPlot = this.scatterPlot;
scatterPlot.removeAllVisualizers();
this.labels3DVisualizer = null;
+ this.canvasLabelsVisualizer = null;
+ this.spriteVisualizer = null;
+ this.traceVisualizer = null;
if (inLabels3DMode) {
this.labels3DVisualizer = new ScatterPlotVisualizer3DLabels();
- this.labels3DVisualizer.setLabelAccessor(this.labelPointAccessor);
- scatterPlot.addVisualizer(this.labels3DVisualizer);
+ this.labels3DVisualizer.setLabelStrings(this.generate3DLabelsArray(
+ this.projection.dataSet, this.labelPointAccessor));
} else {
- scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites());
- scatterPlot.addVisualizer(
- new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer));
+ this.spriteVisualizer = new ScatterPlotVisualizerSprites();
+ scatterPlot.addVisualizer(this.spriteVisualizer);
+ this.canvasLabelsVisualizer =
+ new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer);
+ }
+ this.traceVisualizer = new ScatterPlotVisualizerTraces();
+ const dataSet = (this.projection == null) ? null : this.projection.dataSet;
+ this.setDataSet(dataSet);
+ if (this.spriteVisualizer) {
+ scatterPlot.addVisualizer(this.spriteVisualizer);
+ }
+ if (this.labels3DVisualizer) {
+ scatterPlot.addVisualizer(this.labels3DVisualizer);
+ }
+ if (this.canvasLabelsVisualizer) {
+ scatterPlot.addVisualizer(this.canvasLabelsVisualizer);
}
- scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces());
+ scatterPlot.addVisualizer(this.traceVisualizer);
}
private getSpriteImageMode(): boolean {
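Since both callbacks now receive the DataSet explicitly, the adapter can switch datasets without re-binding closures. A minimal wiring sketch, assuming an existing ProjectorScatterPlotAdapter instance named `adapter`; the metadata key and colors are illustrative:

    // Label and color callbacks take (ds, index) instead of closing
    // over one particular DataSet instance.
    adapter.setLabelPointAccessor(
        (ds: DataSet, i: number) => String(ds.points[i].metadata['label']));
    adapter.setLegendPointColorer(
        (ds: DataSet, i: number) =>
            ds.points[i].metadata['label'] === 'cat' ? '#ff0000' : '#4285f4');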
diff --git a/tensorflow/tensorboard/components/vz_projector/renderContext.ts b/tensorflow/tensorboard/components/vz_projector/renderContext.ts
index 27c1310992..2e7e254596 100644
--- a/tensorflow/tensorboard/components/vz_projector/renderContext.ts
+++ b/tensorflow/tensorboard/components/vz_projector/renderContext.ts
@@ -19,10 +19,10 @@ limitations under the License.
*/
export class LabelRenderParams {
constructor(
- public labelAccessor: (index: number) => string,
- public pointIndices: Float32Array, public scaleFactors: Float32Array,
- public useSceneOpacityFlags: Int8Array, public defaultFontSize: number,
- public fillColors: Uint8Array, public strokeColors: Uint8Array) {}
+ public pointIndices: Float32Array, public labelStrings: string[],
+ public scaleFactors: Float32Array, public useSceneOpacityFlags: Int8Array,
+ public defaultFontSize: number, public fillColors: Uint8Array,
+ public strokeColors: Uint8Array) {}
}
/** Details about the camera projection being used to render the scene. */
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts
index 3a30a74503..9d9b0b5aff 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-import {DataSet} from './data';
import {ProjectorEventContext} from './projectorEventContext';
import {CameraType, LabelRenderParams, RenderContext} from './renderContext';
import {BoundingBox, ScatterPlotRectangleSelector} from './scatterPlotRectangleSelector';
@@ -73,7 +72,6 @@ export class CameraDef {
* array of visualizers and dispatches application events to them.
*/
export class ScatterPlot {
- private dataSet: DataSet;
private projectorEventContext: ProjectorEventContext;
private containerNode: HTMLElement;
@@ -104,7 +102,6 @@ export class ScatterPlot {
private pointColors: Float32Array;
private pointScaleFactors: Float32Array;
private labels: LabelRenderParams;
-
private traceColors: {[trace: number]: Float32Array};
private traceOpacities: Float32Array;
private traceWidths: Float32Array;
@@ -337,9 +334,6 @@ export class ScatterPlot {
* hoverlisteners (usually called from embedding.ts)
*/
private onMouseMove(e: MouseEvent) {
- if (!this.dataSet) {
- return;
- }
this.isDragSequence = this.mouseIsDown;
// Depending if we're selecting or just navigating, handle accordingly.
if (this.selecting && this.mouseIsDown) {
@@ -390,6 +384,10 @@ export class ScatterPlot {
*/
private getPointIndicesFromPickingTexture(boundingBox: BoundingBox):
number[] {
+ if (this.worldSpacePointPositions == null) {
+ return null;
+ }
+ const pointCount = this.worldSpacePointPositions.length / 3;
const dpr = window.devicePixelRatio || 1;
const x = Math.floor(boundingBox.x * dpr);
const y = Math.floor(boundingBox.y * dpr);
@@ -411,7 +409,7 @@ export class ScatterPlot {
for (let i = 0; i < width * height; i++) {
const id = (pixelBuffer[i * 4] << 16) | (pixelBuffer[i * 4 + 1] << 8) |
pixelBuffer[i * 4 + 2];
- if (id !== 0xffffff && (id < this.dataSet.points.length)) {
+ if (id !== 0xffffff && (id < pointCount)) {
pointIndicesSelection[id] = 1;
}
}
@@ -436,12 +434,10 @@ export class ScatterPlot {
this.nearestPoint = null;
return;
}
-
- let boundingBox:
+ const boundingBox:
BoundingBox = {x: e.offsetX, y: e.offsetY, width: 1, height: 1};
-
- let pointIndices = this.getPointIndicesFromPickingTexture(boundingBox);
- this.nearestPoint = pointIndices[0];
+ const pointIndices = this.getPointIndicesFromPickingTexture(boundingBox);
+ this.nearestPoint = (pointIndices != null) ? pointIndices[0] : null;
}
private getLayoutValues(): Point2D {
@@ -560,10 +556,7 @@ export class ScatterPlot {
visualizer.setScene(this.scene);
}
visualizer.onResize(this.width, this.height);
- if (this.dataSet) {
- visualizer.onPointPositionsChanged(
- this.worldSpacePointPositions, this.dataSet);
- }
+ visualizer.onPointPositionsChanged(this.worldSpacePointPositions);
this.visualizers.push(visualizer);
}
@@ -574,19 +567,13 @@ export class ScatterPlot {
}
/** Update scatter plot with a new array of packed xyz point positions. */
- setPointPositions(dataSet: DataSet, worldSpacePointPositions: Float32Array) {
- this.dataSet = dataSet;
+ setPointPositions(worldSpacePointPositions: Float32Array) {
this.worldSpacePointPositions = worldSpacePointPositions;
- this.visualizers.forEach(v => {
- v.onPointPositionsChanged(worldSpacePointPositions, this.dataSet);
- });
+ this.visualizers.forEach(
+ v => v.onPointPositionsChanged(worldSpacePointPositions));
}
render() {
- if (this.dataSet == null) {
- return;
- }
-
{
const lightPos = this.camera.position.clone();
lightPos.x += 1;
@@ -598,9 +585,12 @@ export class ScatterPlot {
CameraType.Perspective :
CameraType.Orthographic;
- const cameraSpacePointExtents: [number, number] = util.getNearFarPoints(
- this.worldSpacePointPositions, this.camera.position,
- this.orbitCameraControls.target);
+ let cameraSpacePointExtents: [number, number] = [0, 0];
+ if (this.worldSpacePointPositions != null) {
+ cameraSpacePointExtents = util.getNearFarPoints(
+ this.worldSpacePointPositions, this.camera.position,
+ this.orbitCameraControls.target);
+ }
const rc = new RenderContext(
this.camera, cameraType, this.orbitCameraControls.target, this.width,
@@ -612,9 +602,7 @@ export class ScatterPlot {
// with colors that are actually point ids, so that sampling the texture at
// the mouse's current x,y coordinates will reveal the data point that the
// mouse is over.
- this.visualizers.forEach(v => {
- v.onPickingRender(rc);
- });
+ this.visualizers.forEach(v => v.onPickingRender(rc));
{
const axes = this.remove3dAxisFromScene();
@@ -625,9 +613,7 @@ export class ScatterPlot {
}
// Render second pass to color buffer, to be displayed on the canvas.
- this.visualizers.forEach(v => {
- v.onRender(rc);
- });
+ this.visualizers.forEach(v => v.onRender(rc));
this.renderer.render(this.scene, this.camera);
}
@@ -723,9 +709,7 @@ export class ScatterPlot {
this.pickingTexture.texture.minFilter = THREE.LinearFilter;
}
- this.visualizers.forEach(v => {
- v.onResize(newW, newH);
- });
+ this.visualizers.forEach(v => v.onResize(newW, newH));
if (render) {
this.render();
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts
index 2d7f5cd640..b0974a2053 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-import {DataSet} from './data';
import {RenderContext} from './renderContext';
/**
@@ -33,8 +32,7 @@ export interface ScatterPlotVisualizer {
/**
* Called when the positions of the scatter plot points have changed.
*/
- onPointPositionsChanged(
- newWorldSpacePointPositions: Float32Array, dataSet: DataSet);
+ onPointPositionsChanged(newWorldSpacePointPositions: Float32Array);
/**
* Called immediately before the main scatter plot performs a picking
* (selection) render. Set up render state for any geometry to use picking IDs
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts
index 3811e10c57..ecd2e21403 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts
@@ -98,7 +98,7 @@ type GlyphTexture = {
export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
private dataSet: DataSet;
private scene: THREE.Scene;
- private labelAccessor: (index: number) => string;
+ private labelStrings: string[];
private geometry: THREE.BufferGeometry;
private worldSpacePointPositions: Float32Array;
private pickingColors: Float32Array;
@@ -111,6 +111,10 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
private labelVertexMap: number[][];
private glyphTexture: GlyphTexture;
+ setDataSet(ds: DataSet) {
+ this.dataSet = ds;
+ }
+
private createGlyphTexture(): GlyphTexture {
let canvas = document.createElement('canvas');
canvas.width = MAX_CANVAS_DIMENSION;
@@ -139,11 +143,11 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
return {texture: tex, lengths: glyphLengths, offsets: glyphOffset};
}
- private processLabelVerts() {
+ private processLabelVerts(pointCount: number) {
let numTotalLetters = 0;
this.labelVertexMap = [];
- for (let i = 0; i < this.dataSet.points.length; i++) {
- let label: string = this.labelAccessor(i).toString();
+ for (let i = 0; i < pointCount; i++) {
+ const label = this.labelStrings[i];
let vertsArray: number[] = [];
for (let j = 0; j < label.length; j++) {
for (let k = 0; k < VERTICES_PER_GLYPH; k++) {
@@ -156,13 +160,12 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
this.totalVertexCount = numTotalLetters * VERTICES_PER_GLYPH;
}
- private createColorBuffers() {
- let numPoints = this.dataSet.points.length;
+ private createColorBuffers(pointCount: number) {
this.pickingColors =
new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY);
this.renderColors =
new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY);
- for (let i = 0; i < numPoints; i++) {
+ for (let i = 0; i < pointCount; i++) {
let color = new THREE.Color(i);
this.labelVertexMap[i].forEach((j) => {
this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j] = color.r;
@@ -175,7 +178,16 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
}
}
- private createLabels(dataSet: DataSet) {
+ private createLabels() {
+ if ((this.labelStrings == null) ||
+ (this.worldSpacePointPositions == null)) {
+ return;
+ }
+ const pointCount =
+ this.worldSpacePointPositions.length / XYZ_ELEMENTS_PER_ENTRY;
+ if (pointCount !== this.labelStrings.length) {
+ return;
+ }
this.glyphTexture = this.createGlyphTexture();
this.uniforms = {
@@ -190,8 +202,8 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
fragmentShader: FRAGMENT_SHADER,
});
- this.processLabelVerts();
- this.createColorBuffers();
+ this.processLabelVerts(pointCount);
+ this.createColorBuffers(pointCount);
let positionArray =
new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY);
@@ -215,8 +227,8 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
this.geometry.addAttribute('color', colors);
let lettersSoFar = 0;
- for (let i = 0; i < dataSet.points.length; i++) {
- let label: string = this.labelAccessor(i).toString();
+ for (let i = 0; i < pointCount; i++) {
+ const label = this.labelStrings[i];
let leftOffset = 0;
// Determine length of word in pixels.
for (let j = 0; j < label.length; j++) {
@@ -262,8 +274,7 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
}
}
- const n = dataSet.points.length;
- for (let i = 0; i < n; i++) {
+ for (let i = 0; i < pointCount; i++) {
const p = util.vector3FromPackedArray(this.worldSpacePointPositions, i);
this.labelVertexMap[i].forEach((j) => {
this.positions.setXYZ(j, p.x, p.y, p.z);
@@ -276,7 +287,7 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
}
private colorLabels(pointColors: Float32Array) {
- if (this.labelAccessor == null || this.geometry == null ||
+ if (this.labelStrings == null || this.geometry == null ||
this.dataSet == null || pointColors == null) {
return;
}
@@ -319,40 +330,43 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
}
}
- setLabelAccessor(labelAccessor: (index: number) => string) {
- this.labelAccessor = labelAccessor;
- this.dispose();
- this.onPointPositionsChanged(this.worldSpacePointPositions, this.dataSet);
- }
-
onPickingRender(rc: RenderContext) {
+ if (this.geometry == null) {
+ this.createLabels();
+ }
+ if (this.geometry == null) {
+ return;
+ }
this.material.uniforms.texture.value = this.glyphTexture.texture;
this.material.uniforms.picking.value = true;
-
- let colors = this.geometry.getAttribute('color') as THREE.BufferAttribute;
+ const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute;
colors.array = this.pickingColors;
colors.needsUpdate = true;
}
onRender(rc: RenderContext) {
+ if (this.geometry == null) {
+ this.createLabels();
+ }
+ if (this.geometry == null) {
+ return;
+ }
this.colorLabels(rc.pointColors);
-
this.material.uniforms.texture.value = this.glyphTexture.texture;
this.material.uniforms.picking.value = false;
-
const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute;
colors.array = this.renderColors;
colors.needsUpdate = true;
}
- onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) {
+ onPointPositionsChanged(newPositions: Float32Array) {
this.worldSpacePointPositions = newPositions;
- this.dataSet = dataSet;
this.dispose();
- if ((this.dataSet != null) && (this.labelAccessor != null) &&
- (this.worldSpacePointPositions != null)) {
- this.createLabels(this.dataSet);
- }
+ }
+
+ setLabelStrings(labelStrings: string[]) {
+ this.labelStrings = labelStrings;
+ this.dispose();
}
onResize(newWidth: number, newHeight: number) {}
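
The rewritten 3D-labels visualizer rebuilds lazily: setLabelStrings and onPointPositionsChanged only invalidate state, and createLabels runs on the next render pass once both inputs exist and agree in length. The invalidate-then-rebuild shape, reduced to a sketch (names are illustrative, not the class above):

// Sketch: lazy rebuild keyed on two independently supplied inputs.
class LazyLabels {
  private labels: string[]|null = null;
  private positions: Float32Array|null = null;
  private geometry: object|null = null;

  setLabelStrings(labels: string[]) {
    this.labels = labels;
    this.geometry = null;  // invalidate; rebuild happens at render time
  }

  setPositions(positions: Float32Array) {
    this.positions = positions;
    this.geometry = null;
  }

  render() {
    if (this.geometry == null) {
      this.build();
    }
    if (this.geometry == null) {
      return;  // inputs still missing or mismatched; skip this frame
    }
    // ... draw with this.geometry ...
  }

  private build() {
    if (this.labels == null || this.positions == null) {
      return;
    }
    const pointCount = this.positions.length / 3;  // packed xyz
    if (pointCount !== this.labels.length) {
      return;  // wait until labels and positions describe the same points
    }
    this.geometry = {};  // stand-in for the real buffer construction
  }
}
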
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts
index 959b077a4d..ef473eda6c 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts
@@ -29,6 +29,7 @@ const LABEL_FILL_WIDTH = 6;
*/
export class ScatterPlotVisualizerCanvasLabels implements
ScatterPlotVisualizer {
+ private dataSet: DataSet;
private worldSpacePointPositions: Float32Array;
private gc: CanvasRenderingContext2D;
private canvas: HTMLCanvasElement;
@@ -41,6 +42,10 @@ export class ScatterPlotVisualizerCanvasLabels implements
this.canvas.style.pointerEvents = 'none';
}
+ setDataSet(ds: DataSet) {
+ this.dataSet = ds;
+ }
+
private removeAllLabels() {
const pixelWidth = this.canvas.width * window.devicePixelRatio;
const pixelHeight = this.canvas.height * window.devicePixelRatio;
@@ -49,13 +54,18 @@ export class ScatterPlotVisualizerCanvasLabels implements
/** Render all of the non-overlapping visible labels to the canvas. */
private makeLabels(rc: RenderContext) {
+ if (this.dataSet == null) {
+ return;
+ }
if ((rc.labels == null) || (rc.labels.pointIndices.length === 0)) {
return;
}
+ if (this.worldSpacePointPositions == null) {
+ return;
+ }
const lrc = rc.labels;
const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective);
-
const labelHeight = parseInt(this.gc.font, 10);
const dpr = window.devicePixelRatio;
@@ -87,9 +97,11 @@ export class ScatterPlotVisualizerCanvasLabels implements
const n = Math.min(MAX_LABELS_ON_SCREEN, lrc.pointIndices.length);
for (let i = 0; i < n; ++i) {
- const index = lrc.pointIndices[i];
- const point =
- util.vector3FromPackedArray(this.worldSpacePointPositions, index);
+ let point: THREE.Vector3;
+ {
+ const pi = lrc.pointIndices[i];
+ point = util.vector3FromPackedArray(this.worldSpacePointPositions, pi);
+ }
// discard points that are behind the camera
camToPoint.copy(camPos).sub(point);
@@ -112,7 +124,7 @@ export class ScatterPlotVisualizerCanvasLabels implements
};
if (grid.insert(textBoundingBox, true)) {
- const text = lrc.labelAccessor(index);
+ const text = lrc.labelStrings[i];
const fontSize = lrc.defaultFontSize * lrc.scaleFactors[i] * dpr;
this.gc.font = fontSize + 'px roboto';
@@ -160,7 +172,7 @@ export class ScatterPlotVisualizerCanvasLabels implements
this.gc = null;
}
- onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) {
+ onPointPositionsChanged(newPositions: Float32Array) {
this.worldSpacePointPositions = newPositions;
this.removeAllLabels();
}
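
Since lrc.labelStrings is indexed by loop position i rather than by point id, the strings must be precomputed parallel to lrc.pointIndices. A plausible construction, assuming a per-point labelAccessor (hypothetical helper, not shown in this diff):

// Sketch: build label strings aligned with the visible-point index list.
function buildLabelStrings(
    pointIndices: number[],
    labelAccessor: (pointIndex: number) => string): string[] {
  // labelStrings[i] pairs with pointIndices[i], not with point id i.
  return pointIndices.map(pointIndex => labelAccessor(pointIndex));
}
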
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts
index db1fb691fa..1facddba1a 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-import {DataSet} from './data';
import {CameraType, RenderContext} from './renderContext';
import {ScatterPlotVisualizer} from './scatterPlotVisualizer';
import * as util from './util';
@@ -30,7 +29,7 @@ const XYZ_NUM_ELEMENTS = 3;
const VERTEX_SHADER = `
// Index of the specific vertex (passed in as bufferAttribute), and the
// variable that will be used to pass it to the fragment shader.
- attribute float vertexIndex;
+ attribute float spriteIndex;
attribute vec3 color;
attribute float scaleFactor;
@@ -39,14 +38,14 @@ const VERTEX_SHADER = `
uniform bool sizeAttenuation;
uniform float pointSize;
- uniform float imageWidth;
- uniform float imageHeight;
+ uniform float spritesPerRow;
+ uniform float spritesPerColumn;
void main() {
// Pass index and color values to fragment shader.
vColor = color;
- xyIndex = vec2(mod(vertexIndex, imageWidth),
- floor(vertexIndex / imageWidth));
+ xyIndex = vec2(mod(spriteIndex, spritesPerRow),
+ floor(spriteIndex / spritesPerRow));
// Transform current vertex by modelViewMatrix (model world position and
// camera world position matrix).
@@ -93,8 +92,8 @@ const FRAGMENT_SHADER = `
varying vec3 vColor;
uniform sampler2D texture;
- uniform float imageWidth;
- uniform float imageHeight;
+ uniform float spritesPerRow;
+ uniform float spritesPerColumn;
uniform bool isImage;
${THREE.ShaderChunk['common']}
@@ -104,7 +103,8 @@ const FRAGMENT_SHADER = `
void main() {
if (isImage) {
// Coordinates of the vertex within the entire sprite image.
- vec2 coords = (gl_PointCoord + xyIndex) / vec2(imageWidth, imageHeight);
+ vec2 coords =
+ (gl_PointCoord + xyIndex) / vec2(spritesPerRow, spritesPerColumn);
gl_FragColor = vec4(vColor, 1.0) * texture2D(texture, coords);
} else {
bool inside = point_in_unit_circle(gl_PointCoord);
@@ -140,11 +140,13 @@ const FRAGMENT_SHADER_PICKING = `
* Uses GL point sprites to render the dataset.
*/
export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
- private image: HTMLImageElement;
-
private scene: THREE.Scene;
private fog: THREE.Fog;
private texture: THREE.Texture = null;
+ private standinTextureForPoints: THREE.Texture;
+ private spritesPerRow: number;
+ private spritesPerColumn: number;
+ private spriteIndexBufferAttribute: THREE.BufferAttribute;
private renderMaterial: THREE.ShaderMaterial;
private pickingMaterial: THREE.ShaderMaterial;
@@ -153,46 +155,47 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
private pickingColors: Float32Array;
private renderColors: Float32Array;
- /**
- * Create points, set their locations and actually instantiate the
- * geometry.
- */
- private createPointSprites(
- scene: THREE.Scene, positions: Float32Array, dataSet: DataSet) {
- const geometry =
- this.createGeometry(positions.length / XYZ_NUM_ELEMENTS, dataSet);
+ constructor() {
+ this.standinTextureForPoints =
+ util.createTexture(document.createElement('canvas'));
+ this.renderMaterial = this.createRenderMaterial(false);
+ this.pickingMaterial = this.createPickingMaterial(false);
+ }
- const haveImage = (this.image != null);
- this.fog = new THREE.Fog(0xFFFFFF); // unused value, gets overwritten.
+ private createTextureFromSpriteAtlas(
+ spriteAtlas: HTMLImageElement, spriteDimensions: [number, number],
+ spriteIndices: Float32Array) {
+ this.texture = util.createTexture(spriteAtlas);
+ this.spritesPerRow = spriteAtlas.width / spriteDimensions[0];
+ this.spritesPerColumn = spriteAtlas.height / spriteDimensions[1];
- {
- const image = this.image || document.createElement('canvas');
- this.texture = util.createTexture(image);
- }
+ this.spriteIndexBufferAttribute =
+ new THREE.BufferAttribute(spriteIndices, INDEX_NUM_ELEMENTS);
- let imageDim = [1, 1];
- {
- const spriteMetadata = dataSet.spriteAndMetadataInfo.spriteMetadata;
- if (haveImage && spriteMetadata) {
- imageDim[0] = this.image.width / spriteMetadata.singleImageDim[0];
- imageDim[1] = this.image.height / spriteMetadata.singleImageDim[1];
- }
+ if (this.points != null) {
+ (this.points.geometry as THREE.BufferGeometry)
+ .addAttribute('spriteIndex', this.spriteIndexBufferAttribute);
}
+ }
- const uniforms = {
+ private createUniforms(): any {
+ return {
texture: {type: 't'},
- imageWidth: {type: 'f', value: imageDim[0]},
- imageHeight: {type: 'f', value: imageDim[1]},
+ spritesPerRow: {type: 'f'},
+ spritesPerColumn: {type: 'f'},
fogColor: {type: 'c'},
fogNear: {type: 'f'},
fogFar: {type: 'f'},
- isImage: {type: 'bool', value: haveImage},
+ isImage: {type: 'bool'},
sizeAttenuation: {type: 'bool'},
pointSize: {type: 'f'}
};
+ }
- this.renderMaterial = new THREE.ShaderMaterial({
- uniforms: THREE.UniformsUtils.clone(uniforms),
+ private createRenderMaterial(haveImage: boolean): THREE.ShaderMaterial {
+ const uniforms = this.createUniforms();
+ return new THREE.ShaderMaterial({
+ uniforms: uniforms,
vertexShader: VERTEX_SHADER,
fragmentShader: FRAGMENT_SHADER,
transparent: !haveImage,
@@ -201,9 +204,12 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
fog: true,
blending: THREE.MultiplyBlending,
});
+ }
- this.pickingMaterial = new THREE.ShaderMaterial({
- uniforms: THREE.UniformsUtils.clone(uniforms),
+ private createPickingMaterial(haveImage: boolean): THREE.ShaderMaterial {
+ const uniforms = this.createUniforms();
+ return new THREE.ShaderMaterial({
+ uniforms: uniforms,
vertexShader: VERTEX_SHADER,
fragmentShader: FRAGMENT_SHADER_PICKING,
transparent: true,
@@ -212,17 +218,35 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
fog: false,
blending: THREE.NormalBlending,
});
+ }
+
+ /**
+ * Create points, set their locations and actually instantiate the
+ * geometry.
+ */
+ private createPointSprites(scene: THREE.Scene, positions: Float32Array) {
+ const pointCount =
+ (positions != null) ? (positions.length / XYZ_NUM_ELEMENTS) : 0;
+ const geometry = this.createGeometry(pointCount);
+
+ this.fog = new THREE.Fog(0xFFFFFF); // unused value, gets overwritten.
this.points = new THREE.Points(geometry, this.renderMaterial);
this.points.frustumCulled = false;
+ if (this.spriteIndexBufferAttribute != null) {
+ (this.points.geometry as THREE.BufferGeometry)
+ .addAttribute('spriteIndex', this.spriteIndexBufferAttribute);
+ }
scene.add(this.points);
}
private calculatePointSize(sceneIs3D: boolean): number {
- if (this.image != null) {
+ if (this.texture != null) {
return IMAGE_SIZE;
}
- const n = this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS;
+ const n = (this.worldSpacePointPositions != null) ?
+ (this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS) :
+ 1;
const SCALE = 200;
const LOG_BASE = 8;
const DIVISOR = 1.5;
@@ -234,8 +258,7 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
/**
* Set up buffer attributes to be used for the points/images.
*/
- private createGeometry(pointCount: number, dataSet: DataSet):
- THREE.BufferGeometry {
+ private createGeometry(pointCount: number): THREE.BufferGeometry {
const n = pointCount;
// Fill pickingColors with each point's unique id as its color.
@@ -250,13 +273,6 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
}
}
- const spriteIndexes =
- new THREE.BufferAttribute(new Float32Array(n), INDEX_NUM_ELEMENTS);
-
- for (let i = 0; i < n; i++) {
- spriteIndexes.setX(i, dataSet.points[i].index);
- }
-
const geometry = new THREE.BufferGeometry();
geometry.addAttribute(
'position', new THREE.BufferAttribute(null, XYZ_NUM_ELEMENTS));
@@ -264,7 +280,6 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
'color', new THREE.BufferAttribute(null, RGB_NUM_ELEMENTS));
geometry.addAttribute(
'scaleFactor', new THREE.BufferAttribute(null, INDEX_NUM_ELEMENTS));
- geometry.addAttribute('vertexIndex', spriteIndexes);
return geometry;
}
@@ -286,55 +301,83 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
}
dispose() {
- this.scene.remove(this.points);
- this.points.geometry.dispose();
- if (this.renderMaterial.uniforms.texture.value) {
- this.renderMaterial.uniforms.texture.value.dispose();
+ this.disposeGeometry();
+ this.disposeTextureAtlas();
+ this.worldSpacePointPositions = null;
+ }
+
+ private disposeGeometry() {
+ if (this.points != null) {
+ this.scene.remove(this.points);
+ this.points.geometry.dispose();
+ this.points = null;
+ }
+ }
+
+ private disposeTextureAtlas() {
+ if (this.texture != null) {
+ this.texture.dispose();
}
- this.points = null;
+ this.texture = null;
this.renderMaterial = null;
this.pickingMaterial = null;
- this.worldSpacePointPositions = null;
- this.image = null;
}
setScene(scene: THREE.Scene) {
this.scene = scene;
}
- onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) {
+ setSpriteAtlas(
+ spriteImage: HTMLImageElement, spriteDimensions: [number, number],
+ spriteIndices: Float32Array) {
+ this.disposeTextureAtlas();
+ this.createTextureFromSpriteAtlas(
+ spriteImage, spriteDimensions, spriteIndices);
+ this.renderMaterial = this.createRenderMaterial(true);
+ this.pickingMaterial = this.createPickingMaterial(true);
+ }
+
+ clearSpriteAtlas() {
+ this.disposeTextureAtlas();
+ this.renderMaterial = this.createRenderMaterial(false);
+ this.pickingMaterial = this.createPickingMaterial(false);
+ }
+
+ onPointPositionsChanged(newPositions: Float32Array) {
+ if ((newPositions == null) || (newPositions.length === 0)) {
+ this.disposeGeometry();
+ return;
+ }
+
if (this.points != null) {
- const notEnoughSpace = (this.pickingColors.length < newPositions.length);
- const newImage = (dataSet != null) &&
- (this.image !== dataSet.spriteAndMetadataInfo.spriteImage);
- if (notEnoughSpace || newImage) {
- this.dispose();
+ const notEnoughSpace =
+ (this.worldSpacePointPositions.length < newPositions.length);
+ if (notEnoughSpace) {
+ this.disposeGeometry();
}
}
- this.image =
- (dataSet != null) ? dataSet.spriteAndMetadataInfo.spriteImage : null;
this.worldSpacePointPositions = newPositions;
if (this.points == null) {
- this.createPointSprites(this.scene, newPositions, dataSet);
+ this.createPointSprites(this.scene, newPositions);
}
- if (newPositions) {
- const positions = (this.points.geometry as THREE.BufferGeometry)
- .getAttribute('position') as THREE.BufferAttribute;
- positions.array = newPositions;
- positions.needsUpdate = true;
- }
+ const positions = (this.points.geometry as THREE.BufferGeometry)
+ .getAttribute('position') as THREE.BufferAttribute;
+ positions.array = newPositions;
+ positions.needsUpdate = true;
}
onPickingRender(rc: RenderContext) {
- if (!this.points) {
+ if (this.points == null) {
return;
}
const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective);
+ this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerRow;
+ this.pickingMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn;
this.pickingMaterial.uniforms.sizeAttenuation.value = sceneIs3D;
this.pickingMaterial.uniforms.pointSize.value =
this.calculatePointSize(sceneIs3D);
@@ -367,7 +410,11 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
this.renderMaterial.uniforms.fogColor.value = this.scene.fog.color;
this.renderMaterial.uniforms.fogNear.value = this.fog.near;
this.renderMaterial.uniforms.fogFar.value = this.fog.far;
- this.renderMaterial.uniforms.texture.value = this.texture;
+ this.renderMaterial.uniforms.spritesPerRow.value = this.spritesPerRow;
+ this.renderMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn;
+ this.renderMaterial.uniforms.isImage.value = (this.texture != null);
+ this.renderMaterial.uniforms.texture.value =
+ (this.texture != null) ? this.texture : this.standinTextureForPoints;
this.renderMaterial.uniforms.sizeAttenuation.value = sceneIs3D;
this.renderMaterial.uniforms.pointSize.value =
this.calculatePointSize(sceneIs3D);
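
The shader edits above swap the imageWidth/imageHeight uniforms for spritesPerRow/spritesPerColumn and locate each point's cell in the atlas from its spriteIndex; note that both the mod and the floor divide by spritesPerRow, since the index walks the atlas row by row. The same arithmetic in TypeScript, for reference (assuming row-major cells):

// Sketch: map a sprite index plus an in-cell offset to atlas UV coordinates.
function spriteCellUV(
    spriteIndex: number, spritesPerRow: number, spritesPerColumn: number,
    pointCoord: [number, number]  // per-fragment offset in [0, 1) inside a cell
    ): [number, number] {
  const col = spriteIndex % spritesPerRow;
  const row = Math.floor(spriteIndex / spritesPerRow);
  return [
    (pointCoord[0] + col) / spritesPerRow,
    (pointCoord[1] + row) / spritesPerColumn,
  ];
}
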
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts
index ec71a93414..a1ff747ff3 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts
@@ -69,7 +69,7 @@ export class ScatterPlotVisualizerTraces implements ScatterPlotVisualizer {
}
dispose() {
- if (!this.traces) {
+ if (this.traces == null) {
return;
}
for (let i = 0; i < this.traces.length; i++) {
@@ -85,32 +85,30 @@ export class ScatterPlotVisualizerTraces implements ScatterPlotVisualizer {
this.scene = scene;
}
- onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) {
+ setDataSet(dataSet: DataSet) {
this.dataSet = dataSet;
- if (dataSet == null) {
+ }
+
+ onPointPositionsChanged(newPositions: Float32Array) {
+ if ((newPositions == null) || (this.traces != null)) {
+ this.dispose();
+ }
+ if ((newPositions == null) || (this.dataSet == null)) {
return;
}
+ // Set up the position buffer arrays for each trace.
+ for (let i = 0; i < this.dataSet.traces.length; i++) {
+ let dataTrace = this.dataSet.traces[i];
+ const vertexCount = 2 * (dataTrace.pointIndices.length - 1);
- if ((this.traces == null) ||
- (this.traces.length !== dataSet.traces.length)) {
- if (this.traces != null) {
- this.dispose();
- }
- // Set up the position buffer arrays for each trace.
- for (let i = 0; i < this.dataSet.traces.length; i++) {
- let dataTrace = this.dataSet.traces[i];
- const vertexCount = 2 * (dataTrace.pointIndices.length - 1);
-
- let traces = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS);
- this.tracePositionBuffer[i] =
- new THREE.BufferAttribute(traces, XYZ_NUM_ELEMENTS);
-
- let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS);
- this.traceColorBuffer[i] =
- new THREE.BufferAttribute(colors, RGB_NUM_ELEMENTS);
- }
- }
+ let traces = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS);
+ this.tracePositionBuffer[i] =
+ new THREE.BufferAttribute(traces, XYZ_NUM_ELEMENTS);
+ let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS);
+ this.traceColorBuffer[i] =
+ new THREE.BufferAttribute(colors, RGB_NUM_ELEMENTS);
+ }
for (let i = 0; i < this.dataSet.traces.length; i++) {
const dataTrace = this.dataSet.traces[i];
let src = 0;
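
The vertexCount above is 2 * (pointIndices.length - 1) because each trace renders as disjoint line segments: every interior point is written twice, once closing one segment and once opening the next. A sketch of filling such a position buffer from packed xyz data:

// Sketch: expand an n-point polyline into 2*(n-1) segment endpoints.
function fillSegmentPositions(
    pointIndices: number[], packedXYZ: Float32Array): Float32Array {
  if (pointIndices.length < 2) {
    return new Float32Array(0);  // no segments to draw
  }
  const out = new Float32Array(2 * (pointIndices.length - 1) * 3);
  let dst = 0;
  for (let i = 0; i < pointIndices.length - 1; i++) {
    for (const idx of [pointIndices[i], pointIndices[i + 1]]) {
      out[dst++] = packedXYZ[idx * 3 + 0];
      out[dst++] = packedXYZ[idx * 3 + 1];
      out[dst++] = packedXYZ[idx * 3 + 2];
    }
  }
  return out;
}
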
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts
index 0de423cc9e..308e1685d2 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts
@@ -64,11 +64,12 @@ export class BookmarkPanel extends BookmarkPanelPolymer {
setSelectedTensor(
run: string, tensorInfo: EmbeddingInfo, dataProvider: DataProvider) {
+ // Clear any existing bookmarks.
+ this.addStates(null);
if (tensorInfo && tensorInfo.bookmarksPath) {
- this.loadAllStates([]);
// Get any bookmarks that may come when the projector starts up.
dataProvider.getBookmarks(run, tensorInfo.tensorName, bookmarks => {
- this.loadAllStates(bookmarks);
+ this.addStates(bookmarks);
});
}
}
@@ -145,7 +146,7 @@ export class BookmarkPanel extends BookmarkPanelPolymer {
// Verify the bookmarks match.
if (this.savedStatesValid(savedStates)) {
- this.loadAllStates(savedStates);
+ this.addStates(savedStates);
this.loadSavedState(0);
} else {
logging.setWarningMessage(
@@ -157,10 +158,14 @@ export class BookmarkPanel extends BookmarkPanelPolymer {
});
}
- loadAllStates(savedStates: State[]) {
- for (let i = 0; i < savedStates.length; i++) {
- savedStates[i].isSelected = false;
- this.push('savedStates', savedStates[i] as any);
+ addStates(savedStates?: State[]) {
+ if (savedStates == null) {
+ this.savedStates = [];
+ } else {
+ for (let i = 0; i < savedStates.length; i++) {
+ savedStates[i].isSelected = false;
+ this.push('savedStates', savedStates[i] as any);
+ }
}
this.updateHasStates();
}
@@ -168,34 +173,35 @@ export class BookmarkPanel extends BookmarkPanelPolymer {
/** Deselects any selected state selection. */
clearStateSelection() {
for (let i = 0; i < this.savedStates.length; i++) {
- if (this.savedStates[i].isSelected) {
- this.savedStates[i].isSelected = false;
- this.notifyPath('savedStates.' + i + '.isSelected', false, false);
- return;
- }
+ this.setSelectionState(i, false);
}
}
/** Handles a radio button click on a saved state. */
_radioButtonHandler(evt: Event) {
- this.loadSavedState(this.getParentDataIndex(evt));
+ const index = this.getParentDataIndex(evt);
+ this.loadSavedState(index);
+ this.setSelectionState(index, true);
}
loadSavedState(index: number) {
for (let i = 0; i < this.savedStates.length; i++) {
if (this.savedStates[i].isSelected) {
- this.savedStates[i].isSelected = false;
- this.notifyPath('savedStates.' + i + '.isSelected', false, false);
+ this.setSelectionState(i, false);
} else if (index === i) {
- this.savedStates[i].isSelected = true;
- this.notifyPath('savedStates.' + i + '.isSelected', true, false);
-
+ this.setSelectionState(i, true);
this.ignoreNextProjectionEvent = true;
this.projector.loadState(this.savedStates[i]);
}
}
}
+ private setSelectionState(stateIndex: number, selected: boolean) {
+ this.savedStates[stateIndex].isSelected = selected;
+ const path = 'savedStates.' + stateIndex + '.isSelected';
+ this.notifyPath(path, selected, false);
+ }
+
/**
* Crawls up the DOM to find an ancestor with a data-index attribute. This is
* used to match events to their bookmark index.
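
The new setSelectionState helper exists because mutating savedStates[i].isSelected alone would not re-render the Polymer template; the dotted notifyPath call tells Polymer exactly which subproperty changed. Roughly, assuming a Polymer 1.x style element with a savedStates array property:

// Sketch: flip one bookmark's selection and notify the template binding.
interface BookmarkHost {
  savedStates: Array<{isSelected: boolean}>;
  notifyPath(path: string, value: boolean, fromAbove?: boolean): void;
}

function setSelectionState(
    host: BookmarkHost, stateIndex: number, selected: boolean) {
  host.savedStates[stateIndex].isSelected = selected;
  // 'savedStates.<i>.isSelected' is the Polymer path syntax for array items.
  host.notifyPath('savedStates.' + stateIndex + '.isSelected', selected, false);
}
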
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts
index 3b40ec27ce..32e9b0a724 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-import {DataSet, PCA_SAMPLE_DIM, PCA_SAMPLE_SIZE, Projection, ProjectionType, SpriteAndMetadataInfo, State, TSNE_SAMPLE_SIZE} from './data';
+import * as data from './data';
+import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data';
import * as vector from './vector';
import {Vector} from './vector';
import {Projector} from './vz-projector';
@@ -289,18 +290,17 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
this.dataSet = dataSet;
this.originalDataSet = originalDataSet;
this.dim = dim;
- let perplexity =
- Math.max(5, Math.ceil(Math.sqrt(dataSet.points.length) / 4));
+ const pointCount = (dataSet == null) ? 0 : dataSet.points.length;
+ const perplexity = Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4));
this.perplexitySlider.value = perplexity.toString();
this.updateTSNEPerplexityFromSliderChange();
this.clearCentroids();
this.dom.select('#tsne-sampling')
- .style(
- 'display',
- dataSet.points.length > TSNE_SAMPLE_SIZE ? null : 'none');
- let wasSampled =
- dataSet.dim[0] > PCA_SAMPLE_SIZE || dataSet.dim[1] > PCA_SAMPLE_DIM;
+ .style('display', pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none');
+ const wasSampled =
+ (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_SIZE ||
+ dataSet.dim[1] > data.PCA_SAMPLE_DIM);
this.dom.select('#pca-sampling')
.style('display', wasSampled ? null : 'none');
this.showTab('pca');
@@ -374,7 +374,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
return;
}
const accessors =
- dataSet.getPointAccessors('tsne', [0, 1, this.tSNEis3d ? 2 : null]);
+ data.getProjectionComponents('tsne', [0, 1, this.tSNEis3d ? 2 : null]);
const dimensionality = this.tSNEis3d ? 3 : 2;
const projection =
new Projection('tsne', accessors, dimensionality, dataSet);
@@ -427,7 +427,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
}
this.dataSet.projectPCA().then(() => {
// Polymer properties are 1-based.
- const accessors = this.dataSet.getPointAccessors(
+ const accessors = data.getProjectionComponents(
'pca', [this.pcaX, this.pcaY, this.pcaZ]);
const dimensionality = this.pcaIs3d ? 3 : 2;
@@ -459,7 +459,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
const yDir = vector.sub(this.centroids.yUp, this.centroids.yDown);
this.dataSet.projectLinear(yDir, 'linear-y');
- const accessors = this.dataSet.getPointAccessors('custom', ['x', 'y']);
+ const accessors = data.getProjectionComponents('custom', ['x', 'y']);
const projection = new Projection('custom', accessors, 2, this.dataSet);
this.projector.setProjection(projection);
}
@@ -543,15 +543,15 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
}
getPcaSampledDimText() {
- return PCA_SAMPLE_DIM.toLocaleString();
+ return data.PCA_SAMPLE_DIM.toLocaleString();
}
getPcaSampleSizeText() {
- return PCA_SAMPLE_SIZE.toLocaleString();
+ return data.PCA_SAMPLE_SIZE.toLocaleString();
}
getTsneSampleSizeText() {
- return TSNE_SAMPLE_SIZE.toLocaleString();
+ return data.TSNE_SAMPLE_SIZE.toLocaleString();
}
}
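
For reference, the perplexity default above grows with the square root of the dataset size and never drops below 5: 10,000 points give ceil(sqrt(10000) / 4) = 25, while an empty dataset still yields 5.

// Sketch: the t-SNE perplexity default computed in dataSetUpdated above.
const defaultPerplexity = (pointCount: number): number =>
    Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4));
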
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts
index 655a75b7f4..14ea58b24a 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
import {AnalyticsLogger} from './analyticsLogger';
+import * as data from './data';
import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data';
import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider';
import {DemoDataProvider} from './data-provider-demo';
@@ -68,6 +69,7 @@ export class Projector extends ProjectorPolymer implements
private distanceMetricChangedListeners: DistanceMetricChangedListener[];
private originalDataSet: DataSet;
+ private dataSetBeforeFilter: DataSet;
private dom: d3.Selection<any>;
private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter;
private dim: number;
@@ -125,10 +127,8 @@ export class Projector extends ProjectorPolymer implements
setSelectedLabelOption(labelOption: string) {
this.selectedLabelOption = labelOption;
- let labelAccessor = (i: number): string => {
- return this.dataSet.points[i]
- .metadata[this.selectedLabelOption] as string;
- };
+ const labelAccessor = (ds: DataSet, i: number): string =>
+ ds.points[i].metadata[this.selectedLabelOption] as string;
this.metadataCard.setLabelOption(this.selectedLabelOption);
this.projectorScatterPlotAdapter.setLabelPointAccessor(labelAccessor);
this.projectorScatterPlotAdapter.render();
@@ -152,30 +152,41 @@ export class Projector extends ProjectorPolymer implements
metadataFile?: string) {
this.dataSetFilterIndices = null;
this.originalDataSet = ds;
- if (this.projectorScatterPlotAdapter == null || ds == null) {
- return;
+ if (ds != null) {
+ this.normalizeData =
+ this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE;
+ spriteAndMetadata = spriteAndMetadata || {};
+ if (spriteAndMetadata.pointsInfo == null) {
+ let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points);
+ spriteAndMetadata.pointsInfo = pointsInfo;
+ spriteAndMetadata.stats = stats;
+ }
+ ds.mergeMetadata(spriteAndMetadata);
}
- this.normalizeData = this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE;
- spriteAndMetadata = spriteAndMetadata || {};
- if (spriteAndMetadata.pointsInfo == null) {
- let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points);
- spriteAndMetadata.pointsInfo = pointsInfo;
- spriteAndMetadata.stats = stats;
+ if (this.projectorScatterPlotAdapter != null) {
+ if (ds == null) {
+ this.setProjection(null);
+ }
+ this.projectorScatterPlotAdapter.updateScatterPlotPositions();
+ this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
+ this.projectorScatterPlotAdapter.resize();
+ this.projectorScatterPlotAdapter.render();
+ }
+ if (ds != null) {
+ this.dataPanel.setNormalizeData(this.normalizeData);
+ this.setCurrentDataSet(ds.getSubset());
+ this.inspectorPanel.datasetChanged();
+
+ this.inspectorPanel.metadataChanged(spriteAndMetadata);
+ this.projectionsPanel.metadataChanged(spriteAndMetadata);
+ this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile);
+ // Set the container to a fixed height, otherwise in Colab the
+ // height can grow indefinitely.
+ let container = this.dom.select('#container');
+ container.style('height', container.property('clientHeight') + 'px');
+ } else {
+ this.setCurrentDataSet(null);
}
- ds.mergeMetadata(spriteAndMetadata);
- this.dataPanel.setNormalizeData(this.normalizeData);
- this.setCurrentDataSet(this.originalDataSet.getSubset());
- this.inspectorPanel.datasetChanged();
-
- this.inspectorPanel.metadataChanged(spriteAndMetadata);
- this.projectionsPanel.metadataChanged(spriteAndMetadata);
- this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile);
- // Set the container to a fixed height, otherwise in Colab the
- // height can grow indefinitely.
- let container = this.dom.select('#container');
- container.style('height', container.property('clientHeight') + 'px');
- this.projectorScatterPlotAdapter.resize();
- this.projectorScatterPlotAdapter.render();
}
setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) {
@@ -191,17 +202,26 @@ export class Projector extends ProjectorPolymer implements
filterDataset(pointIndices: number[]) {
const selectionSize = this.selectedPointIndices.length;
+ if (this.dataSetBeforeFilter == null) {
+ this.dataSetBeforeFilter = this.dataSet;
+ }
this.setCurrentDataSet(this.dataSet.getSubset(pointIndices));
this.dataSetFilterIndices = pointIndices;
+ this.projectorScatterPlotAdapter.updateScatterPlotPositions();
+ this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
this.adjustSelectionAndHover(d3.range(selectionSize));
}
resetFilterDataset() {
- let originalPointIndices = this.selectedPointIndices.map(localIndex => {
- return this.dataSet.points[localIndex].index;
- });
- this.setCurrentDataSet(this.originalDataSet.getSubset());
+ const originalPointIndices = this.selectedPointIndices.map(
+ filteredIndex => this.dataSet.points[filteredIndex].index);
+ this.setCurrentDataSet(this.dataSetBeforeFilter);
+ if (this.projection != null) {
+ this.projection.dataSet = this.dataSetBeforeFilter;
+ }
+ this.dataSetBeforeFilter = null;
this.projectorScatterPlotAdapter.updateScatterPlotPositions();
+ this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
this.dataSetFilterIndices = [];
this.adjustSelectionAndHover(originalPointIndices);
}
@@ -306,13 +326,12 @@ export class Projector extends ProjectorPolymer implements
}
private getLegendPointColorer(colorOption: ColorOption):
- (index: number) => string {
+ (ds: DataSet, index: number) => string {
if ((colorOption == null) || (colorOption.map == null)) {
return null;
}
- const colorer = (i: number) => {
- let value =
- this.dataSet.points[i].metadata[this.selectedColorOption.name];
+ const colorer = (ds: DataSet, i: number) => {
+ let value = ds.points[i].metadata[this.selectedColorOption.name];
if (value == null) {
return POINT_COLOR_MISSING;
}
@@ -347,19 +366,19 @@ export class Projector extends ProjectorPolymer implements
if (this.dataSet != null) {
this.dataSet.stopTSNE();
}
- this.dataSet = ds;
- if (this.normalizeData) {
- this.dataSet.normalize();
+ if ((ds != null) && this.normalizeData) {
+ ds.normalize();
}
- this.dim = this.dataSet.dim[1];
- this.dom.select('span.numDataPoints').text(this.dataSet.dim[0]);
- this.dom.select('span.dim').text(this.dataSet.dim[1]);
+ this.dim = (ds == null) ? 0 : ds.dim[1];
+ this.dom.select('span.numDataPoints').text((ds == null) ? '0' : ds.dim[0]);
+ this.dom.select('span.dim').text((ds == null) ? '0' : ds.dim[1]);
- this.projection = null;
+ this.dataSet = ds;
this.projectionsPanel.dataSetUpdated(
this.dataSet, this.originalDataSet, this.dim);
+ this.projectorScatterPlotAdapter.setDataSet(this.dataSet);
this.projectorScatterPlotAdapter.scatterPlot
.setCameraParametersForNextCameraCreation(null, true);
}
@@ -494,7 +513,9 @@ export class Projector extends ProjectorPolymer implements
this.setProjection(null);
{
this.projectionsPanel.disablePolymerChangesTriggerReprojection();
- this.resetFilterDataset();
+ if (this.dataSetBeforeFilter != null) {
+ this.resetFilterDataset();
+ }
if (state.filteredPoints != null) {
this.filterDataset(state.filteredPoints);
}
@@ -517,10 +538,11 @@ export class Projector extends ProjectorPolymer implements
this.projectorScatterPlotAdapter.restoreUIFromBookmark(state);
{
const dimensions = stateGetAccessorDimensions(state);
- const accessors =
- this.dataSet.getPointAccessors(state.selectedProjection, dimensions);
+ const components =
+ data.getProjectionComponents(state.selectedProjection, dimensions);
const projection = new Projection(
- state.selectedProjection, accessors, dimensions.length, this.dataSet);
+ state.selectedProjection, components, dimensions.length,
+ this.dataSet);
this.setProjection(projection);
}
this.notifySelectionChanged(state.selectedPoints);
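
filterDataset and resetFilterDataset now form a one-level save/restore pair: the pre-filter DataSet is stashed in dataSetBeforeFilter on the first filter, and reset maps the current selection back to original point ids via point.index before swapping the dataset back. The shape of that logic, sketched with DataSet pared down to the single field used here:

// Sketch: one-level filter/undo over a dataset (not a stack).
interface MiniDataSet {
  points: Array<{index: number}>;  // index = id in the original dataset
}

class FilterState {
  current: MiniDataSet;
  private before: MiniDataSet|null = null;

  constructor(ds: MiniDataSet) {
    this.current = ds;
  }

  filter(subset: MiniDataSet) {
    if (this.before == null) {
      this.before = this.current;  // remember the pre-filter set once
    }
    this.current = subset;
  }

  reset(selectedLocalIndices: number[]): number[] {
    // Translate filtered-local indices to original ids before swapping back.
    const originalIds =
        selectedLocalIndices.map(i => this.current.points[i].index);
    if (this.before != null) {
      this.current = this.before;
      this.before = null;
    }
    return originalIds;
  }
}
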
diff --git a/tensorflow/tensorboard/scripts/generate_testdata.py b/tensorflow/tensorboard/scripts/generate_testdata.py
index 78739aa7cf..2d305a3b84 100644
--- a/tensorflow/tensorboard/scripts/generate_testdata.py
+++ b/tensorflow/tensorboard/scripts/generate_testdata.py
@@ -110,7 +110,7 @@ def WriteImageSeries(writer, tag, n_images=1):
step = 0
session = tf.Session()
p = tf.placeholder("uint8", (1, 4, 4, 3))
- s = tf.image_summary(tag, p)
+ s = tf.contrib.deprecated.image_summary(tag, p)
for _ in xrange(n_images):
im = np.random.random_integers(0, 255, (1, 4, 4, 3))
summ = session.run(s, feed_dict={p: im})
@@ -133,7 +133,7 @@ def WriteAudioSeries(writer, tag, n_audio=1):
p = tf.placeholder("float32", (frequencies_per_run, duration_frames,
num_channels))
- s = tf.audio_summary(tag, p, sample_rate)
+ s = tf.contrib.deprecated.audio_summary(tag, p, sample_rate)
for _ in xrange(n_audio):
# Generate a different frequency for each channel to show stereo works.
@@ -158,7 +158,7 @@ def GenerateTestData(path):
"""Generates the test data directory."""
run1_path = os.path.join(path, "run1")
os.makedirs(run1_path)
- writer1 = tf.train.SummaryWriter(run1_path)
+ writer1 = tf.summary.FileWriter(run1_path)
WriteScalarSeries(writer1, "foo/square", lambda x: x * x)
WriteScalarSeries(writer1, "bar/square", lambda x: x * x)
WriteScalarSeries(writer1, "foo/sin", math.sin)
@@ -171,7 +171,7 @@ def GenerateTestData(path):
run2_path = os.path.join(path, "run2")
os.makedirs(run2_path)
- writer2 = tf.train.SummaryWriter(run2_path)
+ writer2 = tf.summary.FileWriter(run2_path)
WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2)
WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3)
WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2)
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 4951d9da81..502d698468 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -147,8 +147,9 @@ def if_not_windows(a):
return select({
"//tensorflow:windows": [],
"//conditions:default": a,
- })
+ })
+# LINT.IfChange
def tf_copts():
return (["-DEIGEN_AVOID_STL_ARRAY",
"-Iexternal/gemmlowp",
@@ -179,6 +180,7 @@ def tf_opts_nortti_if_android():
"-DGOOGLE_PROTOBUF_NO_RTTI",
"-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER",
])
+# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
@@ -552,29 +554,6 @@ def tf_kernel_library(name, prefix=None, srcs=None, gpu_srcs=None, hdrs=None,
deps = deps,
**kwargs)
-def tf_kernel_libraries(name, prefixes, deps=None, libs=None, **kwargs):
- """Makes one target per prefix, and one target that includes them all.
-
- Args:
- name: The name of the omnibus cc_library target that depends on each
- generated tf_kernel_library target.
- prefixes: A list of source file name prefixes used to generate individual
- libraries. See the definition of tf_kernel_library for details.
- deps: The dependencies list associated with each generated target.
- libs: Additional tf_kernel_library targets that should be included in the
- omnibus cc_library target but not as deps of individual libraries.
- This can be used, for example, if a library that was previously
- generated by this rule is refactored into a separate definition
- in order to specify more or fewer deps for it.
-
- Other attributes are forwarded to each individual target but not to the
- omnibus cc_library target.
- """
- for p in prefixes:
- tf_kernel_library(name=p, prefix=p, deps=deps, **kwargs)
- native.cc_library(name=name,
- deps=[":" + p for p in prefixes] + (libs or []))
-
# Bazel rules for building swig files.
def _py_wrap_cc_impl(ctx):
srcs = ctx.files.srcs
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index c3d9e865b1..c544829e5f 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -100,6 +100,11 @@ Status RunBenchmark(const std::vector<InputLayerInfo>& inputs,
int_tensor = int_tensor.constant(0.0);
break;
}
+ case DT_UINT8: {
+ auto int_tensor = input_tensor.flat<uint8>();
+ int_tensor = int_tensor.constant(0.0);
+ break;
+ }
default:
LOG(FATAL) << "Unsupported input type: " << input.data_type;
}
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index e945df2c61..0d890f5684 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -291,6 +291,71 @@ do_buildifier(){
fi
}
+do_external_licenses_check(){
+ echo "Running do_external_licenses_check"
+ echo ""
+
+ EXTERNAL_LICENSES_CHECK_START_TIME=$(date +'%s')
+
+ EXTERNAL_DEPENDENCIES_FILE="$(mktemp)_external_dependencies.log"
+ LICENSES_FILE="$(mktemp)_licenses.log"
+ MISSING_LICENSES_FILE="$(mktemp)_missing_licenses.log"
+ EXTRA_LICENSES_FILE="$(mktemp)_extra_licenses.log"
+
+ echo "Getting external dependencies for //tensorflow/tools/pip_package:build_pip_package."
+ bazel query 'attr("licenses", "notice", deps(//tensorflow/tools/pip_package:build_pip_package))' --no_implicit_deps --no_host_deps --keep_going \
+ | egrep -v "^//tensorflow" \
+ | sed -e 's|:.*||' \
+ | sort \
+ | uniq 2>&1 \
+ | tee ${EXTERNAL_DEPENDENCIES_FILE}
+
+ echo
+ echo "Getting list of external licenses."
+ bazel query 'deps(//tensorflow/tools/pip_package:licenses)' --no_implicit_deps --no_host_deps --keep_going \
+ | egrep -v "^//tensorflow" \
+ | sed -e 's|:.*||' \
+ | sort \
+ | uniq 2>&1 \
+ | tee ${LICENSES_FILE}
+
+ echo
+ comm -1 -3 ${EXTERNAL_DEPENDENCIES_FILE} ${LICENSES_FILE} 2>&1 | tee ${EXTRA_LICENSES_FILE}
+ echo
+ comm -2 -3 ${EXTERNAL_DEPENDENCIES_FILE} ${LICENSES_FILE} 2>&1 | tee ${MISSING_LICENSES_FILE}
+
+ EXTERNAL_LICENSES_CHECK_END_TIME=$(date +'%s')
+
+ echo
+ echo "do_external_licenses_check took $((${EXTERNAL_LICENSES_CHECK_END_TIME} - ${EXTERNAL_LICENSES_CHECK_START_TIME})) s"
+ echo
+
+ if [[ -s ${MISSING_LICENSES_FILE} ]] || [[ -s ${EXTRA_LICENSES_FILE} ]] ; then
+ echo "FAIL: pip package external dependencies vs licenses mismatch."
+ if [[ -s ${MISSING_LICENSES_FILE} ]] ; then
+ echo "Missing the licenses for the following external dependencies:"
+ cat ${MISSING_LICENSES_FILE}
+ fi
+ if [[ -s ${EXTRA_LICENSES_FILE} ]] ; then
+ echo "Please remove the licenses for the following external dependencies:"
+ cat ${EXTRA_LICENSES_FILE}
+ fi
+ rm -rf ${EXTERNAL_DEPENDENCIES_FILE}
+ rm -rf ${LICENSES_FILE}
+ rm -rf ${MISSING_LICENSES_FILE}
+ rm -rf ${EXTRA_LICENSES_FILE}
+ return 1
+ else
+ echo "PASS: all external licenses included."
+ rm -rf ${EXTERNAL_DEPENDENCIES_FILE}
+ rm -rf ${LICENSES_FILE}
+ rm -rf ${MISSING_LICENSES_FILE}
+ rm -rf ${EXTRA_LICENSES_FILE}
+ return 0
+ fi
+}
+
+
# Run bazel build --nobuild to test the validity of the BUILD files
do_bazel_nobuild() {
BUILD_TARGET="//tensorflow/..."
@@ -311,8 +376,8 @@ do_bazel_nobuild() {
}
# Supply all sanity step commands and descriptions
-SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild")
-SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild")
+SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_external_licenses_check")
+SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "external dependencies licenses check")
INCREMENTAL_FLAG=""
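
The license check above is two set differences over sorted lists: comm -2 -3 yields dependencies with no license entry, comm -1 -3 yields license entries with no matching dependency, and either being non-empty fails the step. The same logic as a hedged TypeScript sketch:

// Sketch: dependencies-vs-licenses diff, mirroring the comm calls above.
function licenseCheck(deps: string[], licenses: string[]) {
  const licSet = new Set(licenses);
  const depSet = new Set(deps);
  const missing = deps.filter(d => !licSet.has(d));    // like comm -2 -3
  const extra = licenses.filter(l => !depSet.has(l));  // like comm -1 -3
  return {ok: missing.length === 0 && extra.length === 0, missing, extra};
}
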
diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py
index f5510c374a..309366e467 100644
--- a/tensorflow/tools/dist_test/python/census_widendeep.py
+++ b/tensorflow/tools/dist_test/python/census_widendeep.py
@@ -53,8 +53,8 @@ FLAGS = flags.FLAGS
# Constants: Data download URLs
-TRAIN_DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
-TEST_DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"
+TRAIN_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data"
+TEST_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test"
# Define features for the model
diff --git a/tensorflow/tools/dist_test/server/Dockerfile.test b/tensorflow/tools/dist_test/server/Dockerfile.test
index 22438f3984..e2feb2227b 100644
--- a/tensorflow/tools/dist_test/server/Dockerfile.test
+++ b/tensorflow/tools/dist_test/server/Dockerfile.test
@@ -63,9 +63,9 @@ RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \
# Download Census data for Wide & Deep test
RUN mkdir -p /tmp/census-data
RUN curl -o /tmp/census-data/adult.data \
- https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data
+ http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data
RUN curl -o /tmp/census-data/adult.test \
- https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test
+ http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test
# Container entry point
ENTRYPOINT ["/var/tf-k8s/server/grpc_tensorflow_server_wrapper.sh"]
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index d6a6d83f9b..d0c813e84f 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -78,12 +78,36 @@ py_binary(
deps = ["//tensorflow:tensorflow_py"],
)
+filegroup(
+ name = "licenses",
+ data = [
+ "//third_party/eigen3:LICENSE",
+ "//third_party/hadoop:LICENSE.txt",
+ "@boringssl//:LICENSE",
+ "@com_googlesource_code_re2//:LICENSE",
+ "@eigen_archive//:COPYING.MPL2",
+ "@farmhash_archive//:COPYING",
+ "@gemmlowp//:LICENSE",
+ "@gif_archive//:COPYING",
+ "@grpc//:LICENSE",
+ "@highwayhash//:LICENSE",
+ "@jpeg//:LICENSE.md",
+ "@local_config_sycl//sycl:LICENSE.text",
+ "@nanopb_git//:LICENSE.txt",
+ "@png_archive//:LICENSE",
+ "@protobuf//:LICENSE",
+ "@six_archive//:LICENSE",
+ "@zlib_archive//:zlib.h",
+ ],
+)
+
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = select({
"//tensorflow:windows": [":simple_console_for_windows"],
"//conditions:default": [
+ ":licenses",
"MANIFEST.in",
"README",
"setup.py",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 62e37f6ad4..9ca2ffc509 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -46,7 +46,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
name = "farmhash_archive",
url = "http://github.com/google/farmhash/archive/92e897b282426729f4724d91a637596c7e2fe28f.zip",
sha256 = "4c626d1f306bda2c6804ab955892f803f5245f4dcaecb4979dc08b091256da54",
- strip_prefix = "farmhash-92e897b282426729f4724d91a637596c7e2fe28f/src",
+ strip_prefix = "farmhash-92e897b282426729f4724d91a637596c7e2fe28f",
build_file = str(Label("//:farmhash.BUILD")),
)
@@ -90,7 +90,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
name = "gif_archive",
url = "http://cdimage.debian.org/mirror/xbmc.org/build-deps/sources/giflib-5.1.4.tar.gz",
sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
- strip_prefix = "giflib-5.1.4/lib",
+ strip_prefix = "giflib-5.1.4",
build_file = str(Label("//:gif.BUILD")),
)
@@ -248,3 +248,15 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
name = "zlib",
actual = "@zlib_archive//:zlib",
)
+
+ # Make junit-4.12 available as //external:junit
+ native.http_jar(
+ name = "junit_jar",
+ url = "https://github.com/junit-team/junit4/releases/download/r4.12/junit-4.12.jar",
+ sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a",
+ )
+
+ native.bind(
+ name = "junit",
+ actual = "@junit_jar//jar",
+ )
diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD
index f697866bde..c2abf78e95 100644
--- a/third_party/eigen3/BUILD
+++ b/third_party/eigen3/BUILD
@@ -9,6 +9,8 @@ licenses([
"notice", # Portions BSD
])
+exports_files(["LICENSE"])
+
cc_library(
name = "eigen3",
hdrs = glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + [
diff --git a/third_party/hadoop/BUILD b/third_party/hadoop/BUILD
index f25208c416..9e98154400 100644
--- a/third_party/hadoop/BUILD
+++ b/third_party/hadoop/BUILD
@@ -2,6 +2,8 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
+exports_files(["LICENSE.txt"])
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/third_party/hadoop/LICENSE.txt b/third_party/hadoop/LICENSE.txt
new file mode 100644
index 0000000000..6ccfd09277
--- /dev/null
+++ b/third_party/hadoop/LICENSE.txt
@@ -0,0 +1,284 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
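As an illustration (not part of the license text): applied to a Bazel/Starlark
file in this tree, the boilerplate above is wrapped in that format's comment
syntax (# comments), with the bracketed fields filled in. For example:

    # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.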
+
+
+APACHE HADOOP SUBCOMPONENTS:
+
+The Apache Hadoop project contains subcomponents with separate copyright
+notices and license terms. Your use of the source code for these
+subcomponents is subject to the terms and conditions of the following
+licenses.
+
+For the org.apache.hadoop.util.bloom.* classes:
+
+/**
+ *
+ * Copyright (c) 2005, European Commission project OneLab under contract
+ * 034819 (http://www.one-lab.org)
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or
+ * without modification, are permitted provided that the following
+ * conditions are met:
+ * - Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * - Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the distribution.
+ * - Neither the name of the University Catholique de Louvain - UCL
+ * nor the names of its contributors may be used to endorse or
+ * promote products derived from this software without specific prior
+ * written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+ * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+ * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+ * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+ * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+ * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+For portions of the native implementation of slicing-by-8 CRC calculation
+in src/main/native/src/org/apache/hadoop/util:
+
+/**
+ * Copyright 2008,2009,2010 Massachusetts Institute of Technology.
+ * All rights reserved. Use of this source code is governed by a
+ * BSD-style license that can be found in the LICENSE file.
+ */
+
+For src/main/native/src/org/apache/hadoop/io/compress/lz4/lz4.c:
+
+/*
+ LZ4 - Fast LZ compression algorithm
+ Copyright (C) 2011, Yann Collet.
+ BSD License
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with the
+ distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 4924a49b57..e1c20e82a7 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -13,6 +13,8 @@ load(
"cmake_var_string",
)
+package(default_visibility = ["@//tensorflow/compiler/xla:internal"])
+
llvm_host_triple = "x86_64-unknown-linux_gnu"
llvm_targets = [
@@ -26,6 +28,7 @@ llvm_targets = [
llvm_target_asm_parsers = [
"AArch64",
"ARM",
+ "NVPTX",
"PowerPC",
"X86",
]
@@ -1334,6 +1337,28 @@ cc_library(
)
cc_library(
+ name = "objc_arc",
+ srcs = glob([
+ "lib/Transforms/ObjCARC/*.c",
+ "lib/Transforms/ObjCARC/*.cpp",
+ "lib/Transforms/ObjCARC/*.inc",
+ "lib/Transforms/ObjCARC/*.h",
+ ]),
+ hdrs = glob([
+ "include/llvm/Transforms/ObjCARC/*.h",
+ "include/llvm/Transforms/ObjCARC/*.def",
+ "include/llvm/Transforms/ObjCARC/*.inc",
+ ]),
+ deps = [
+ ":analysis",
+ ":config",
+ ":core",
+ ":support",
+ ":transform_utils",
+ ],
+)
+
+cc_library(
name = "orc_jit",
srcs = glob([
"lib/ExecutionEngine/Orc/*.c",
diff --git a/third_party/pcre.BUILD b/third_party/pcre.BUILD
index d9ef246672..68aadd1d40 100644
--- a/third_party/pcre.BUILD
+++ b/third_party/pcre.BUILD
@@ -1,6 +1,6 @@
licenses(["notice"]) # BSD
-exports_files(["LICENSE"])
+exports_files(["COPYING"])
cc_library(
name = "pcre",
diff --git a/third_party/sycl/sycl/BUILD.tpl b/third_party/sycl/sycl/BUILD.tpl
index 9e83b1994c..c66a9f007d 100755
--- a/third_party/sycl/sycl/BUILD.tpl
+++ b/third_party/sycl/sycl/BUILD.tpl
@@ -7,6 +7,8 @@ load("platform", "readlink_command")
package(default_visibility = ["//visibility:public"])
+exports_files(["LICENSE.text"])
+
config_setting(
name = "using_sycl",
values = {
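For context on the exports_files() addition above: Bazel packages do not
expose their files to other packages by default, so exporting LICENSE.text is
what allows the generated license file to be referenced by label from
elsewhere, e.g. when collecting third-party license texts. A hypothetical
consumer, with the repository name local_config_sycl assumed from the usual
*_configure naming convention:

    filegroup(
        name = "third_party_licenses",
        srcs = ["@local_config_sycl//sycl:LICENSE.text"],
    )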
diff --git a/third_party/sycl/sycl/LICENSE.text.tpl b/third_party/sycl/sycl/LICENSE.text.tpl
new file mode 100644
index 0000000000..0c2955c4d7
--- /dev/null
+++ b/third_party/sycl/sycl/LICENSE.text.tpl
@@ -0,0 +1,268 @@
+
+---------------------------------------------------------------------
+
+SOFTWARE LICENSE AGREEMENT
+
+---------------------------------------------------------------------
+---------------------------------------------------------------------
+
+By downloading, installing, copying, or otherwise using the
+ComputeCpp Community Edition software, including any associated
+components, media, printed materials, and electronic documentation
+("Software"), the user agrees to the following terms and conditions
+of this Software License Agreement ("Agreement"). Please read the
+terms of this Agreement carefully before beginning your download, as
+pressing the "I AGREE" button at the end of this Agreement will
+confirm your assent. If you do not agree to these terms, then
+Codeplay Software Limited is unwilling to license the Software to
+you; so please press the "CANCEL" button to cancel your download.
+
+ 1. License. Codeplay Software Ltd., a company incorporated in
+ England and Wales with registered number 04567874 and having its
+ registered office at Regent House, 316 Beulah Hill, London,
+ United Kingdom, SE19 3HF ("Codeplay") hereby grants the user,
+ free of charge, a non-exclusive worldwide license to use and
+ replicate (but not modify) the Software for any use, whether
+ commercial or non-commercial, in accordance with this Agreement.
+ Codeplay reserves all rights to the Software that are not
+ expressly granted by this Agreement.
+ 2. Redistribution. The user may copy and redistribute unmodified
+ copies of only those components of the Software which are
+ specified below ("Redistributable Components"), in object code
+ form, as part of the user’s software applications or libraries
+ ("Applications"). The user acknowledges and agrees that it has no
+ right to modify the Redistributable Components in any way. Any
+ use of the Redistributable Components within the user’s
+ Applications will continue to be subject to the terms and
+ conditions of this Agreement, and the user must also distribute a
+ copy of this Agreement and reproduce and include all notices of
+ copyrights or other proprietary rights in the Software. The
+ user’s redistribution of the Redistributable Components will not
+ entitle it to any payment from Codeplay. The user may not
+ transfer any of its rights or obligations under this Agreement.
+
++-------------------------------------------+
+|Redistributable Component|File Name        |
+|-------------------------+-----------------|
+|Runtime (for Linux)      |libComputeCpp.so |
+|-------------------------+-----------------|
+|Runtime (for Windows)    |libComputeCpp.dll|
++-------------------------------------------+
+
+ 3. Restrictions. The user shall not:
+
+ a. circumvent or bypass any technological protection measures in
+ or relating to the Software;
+ b. use the Software to perform any unauthorized transfer of
+ information or for any illegal purpose;
+ c. de-compile, decrypt, disassemble, hack, emulate, exploit or
+ reverse-engineer the Software (other than to the limited
+ extent permitted by law);
+ d. copy or redistribute any components of the Software that are
+ not listed in the table of Redistributable Components;
+ e. publish, rent, lease, sell, export, import, or lend the
+ Software;
+ f. represent in any way that it is selling the Software itself
+ or any license to use the Software, nor refer to Codeplay or
+ ComputeCpp within its marketing materials, without the
+ express prior written permission of Codeplay.
+ 4. Support. Codeplay does not provide any guarantees of support for
+ the Software to the user. Codeplay will use reasonable endeavours
+ to respond to users' support requests, for the most recent
+ release only, via the community support website at
+ https://computecpp.codeplay.com.
+ 5. Intellectual Property. The Software is owned by Codeplay or its
+ licensors, and is protected by the copyright laws of the United
+ Kingdom and other countries and international treaty provisions.
+ Codeplay (and/or its licensors, as the case may be) retains all
+ copyrights, trade secrets and other proprietary rights in the
+ Software, including the rights to make and license the use of all
+ copies. To the extent that any patents owned by Codeplay or its
+ licensors relate to any component of the Software, the licence
+ granted to the user in accordance with this Agreement allows for
+ the lawful use of such patents but only for the purposes of this
+ Agreement and not further or otherwise. Therefore, the user may
+ make no copies of the Software, or the written materials that
+ accompany the Software, or reproduce it in any way, except as set
+ forth above.
+ 6. Terms. This Agreement is effective until terminated. Codeplay or
+ the user may terminate it immediately at any time. Any violation
+ of the terms of this Agreement by the user will result in
+ immediate termination by Codeplay. Upon termination, the user
+ must return or destroy the Software and accompanying materials
+ and notify Codeplay of its actions by email to info@codeplay.com.
+ 7. NO WARRANTIES. Codeplay expressly disclaims any warranty for the
+ Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+ AND NON-INFRINGEMENT. IN NO EVENT SHALL CODEPLAY BE LIABLE FOR
+ ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ CONTRACT, DELICT OR TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE. In particular, Codeplay provides no guarantees of
+ application performance on the target hardware.
+ 8. General. The invalidity of any portion or provision of this
+ Agreement shall not affect any other portions or provisions. This
+ Agreement shall be governed by the laws of Scotland. This
+ Agreement is the complete and exclusive agreement between the
+ user and Codeplay regarding the Software, and it supersedes any
+ prior agreement, oral or written, and any other communication
+ between the user and Codeplay relating to the subject matter of
+ the Agreement. Any amendment or modification of this Agreement
+ must be in writing and signed by both parties. If the user does
+ not agree to the terms of this Agreement, the user must not
+ install or use the Software.
+ 9. Third Party Licenses. The following licenses are for third-party
+ components included in the software.
+
+ a. License for Clang/LLVM compiler technology components:
+
+==============================================================================
+
+LLVM Release License
+
+==============================================================================
+
+University of Illinois/NCSA
+
+Open Source License
+
+Copyright (c) 2007-2014 University of Illinois at Urbana-Champaign.
+
+All rights reserved.
+
+Developed by:
+
+ LLVM Team
+
+ University of Illinois at Urbana-Champaign
+
+ http://llvm.org
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal with
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimers.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimers in the
+ documentation and/or other materials provided with the distribution.
+
+ * Neither the names of the LLVM Team, University of Illinois at
+ Urbana-Champaign, nor the names of its contributors may be used to
+ endorse or promote products derived from this Software without specific
+ prior written permission.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE
+SOFTWARE.
+
+==============================================================================
+
+ b. License for OpenBSD regex components:
+
+$OpenBSD: COPYRIGHT,v 1.3 2003/06/02 20:18:36 millert Exp $
+Copyright 1992, 1993, 1994 Henry Spencer. All rights reserved.
+This software is not subject to any license of the American Telephone
+and Telegraph Company or of the Regents of the University of California.
+Permission is granted to anyone to use this software for any purpose on
+any computer system, and to alter it and redistribute it, subject
+to the following restrictions:
+
+1. The author is not responsible for the consequences of use of this
+ software, no matter how awful, even if they arise from flaws in it.
+
+2. The origin of this software must not be misrepresented, either by
+ explicit claim or by omission. Since few users ever read sources,
+ credits must appear in the documentation.
+
+3. Altered versions must be plainly marked as such, and must not be
+ misrepresented as being the original software. Since few users
+ ever read sources, credits must appear in the documentation.
+
+4. This notice may not be removed or altered.
+
+=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
+
+/*-
+ * Copyright (c) 1994
+ * The Regents of the University of California. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * 3. Neither the name of the University nor the names of its contributors
+ * may be used to endorse or promote products derived from this software
+ * without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
+ * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+ * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+ * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
+ * SUCH DAMAGE.
+ *
+ * @(#)COPYRIGHT 8.1 (Berkeley) 3/16/94
+ */
+
+ c. License for MD5 components:
+
+/*
+ * This code is derived from (original license follows):
+ *
+ * This is an OpenSSL-compatible implementation of the RSA Data Security, Inc.
+ * MD5 Message-Digest Algorithm (RFC 1321).
+ *
+ * Homepage:
+ * http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5
+ *
+ * Author:
+ * Alexander Peslyak, better known as Solar Designer <solar at openwall.com>
+ *
+ * This software was written by Alexander Peslyak in 2001. No copyright is
+ * claimed, and the software is hereby placed in the public domain.
+ * In case this attempt to disclaim copyright and place the software in the
+ * public domain is deemed null and void, then the software is
+ * Copyright (c) 2001 Alexander Peslyak and it is hereby released to the
+ * general public under the following terms:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted.
+ *
+ * There's ABSOLUTELY NO WARRANTY, express or implied.
+ *
+ * (This is a heavily cut-down "BSD license".)
+ *
+ * This differs from Colin Plumb's older public domain implementation in that
+ * no exactly 32-bit integer data type is required (any 32-bit or wider
+ * unsigned integer data type will do), there's no compile-time endianness
+ * configuration, and the function prototypes match OpenSSL's. No code from
+ * Colin Plumb's implementation has been reused; this comment merely compares
+ * the properties of the two independent implementations.
+ *
+ * The primary goals of this implementation are portability and ease of use.
+ * It is meant to be fast, but not as fast as possible. Some known
+ * optimizations are not included to reduce source code size and avoid
+ * compile-time configuration.
+ */
+
+
diff --git a/third_party/sycl/sycl_configure.bzl b/third_party/sycl/sycl_configure.bzl
index 6102ed49c2..38bd7759de 100644
--- a/third_party/sycl/sycl_configure.bzl
+++ b/third_party/sycl/sycl_configure.bzl
@@ -135,6 +135,7 @@ def _create_dummy_repository(repository_ctx):
# Set up BUILD file for sycl/.
_file(repository_ctx, "sycl:build_defs.bzl")
_tpl(repository_ctx, "sycl:BUILD")
+ _tpl(repository_ctx, "sycl:LICENSE.text")
_tpl(repository_ctx, "sycl:platform.bzl")
# Create dummy files for the SYCL toolkit since they are still required by
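The _tpl() call added above is what instantiates the new LICENSE.text.tpl into
the auto-generated SYCL repository alongside the other templates. The
*_configure.bzl files in this tree typically implement _tpl as a thin wrapper
over repository_ctx.template(); a sketch under that assumption (the real
definition lives earlier in sycl_configure.bzl):

    def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
        # "sycl:LICENSE.text" names the template
        # //third_party/sycl/sycl:LICENSE.text.tpl and, by convention, the
        # output path sycl/LICENSE.text inside the generated repository.
        if not out:
            out = tpl.replace(":", "/")
        repository_ctx.template(
            out,
            Label("//third_party/sycl/%s.tpl" % tpl),
            substitutions,
        )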