diff options
424 files changed, 22987 insertions, 7028 deletions
diff --git a/RELEASE.md b/RELEASE.md index 0bb2a92ba6..ead29f0c54 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -59,6 +59,13 @@ * Removes RegisterShape from public API. Use C++ shape function registration instead. indexing now starts from 1 instead of 0, and `bus_id==0` is used where previously `BUS_ANY` was used. +* Most RNN cells and RNN functions now use different variable scopes to be + consistent with layers (`tf.contrib.layers`). This means old checkpoints + written using this code will not load after this change without providing + `Saver` a list of variable renames. Examples of variable scope changes + include `RNN` -> `rnn` in `tf.nn.rnn`, `tf.nn.dynamic_rnn` and moving from + `Linear/Matrix` -> `weights` and `Linear/Bias` -> `biases` in most RNN cells. +* Deprecated tf.select op. tf.where should be used instead. * `Env::FileExists` and `FileSystem::FileExists` now return a `tensorflow::Status` intead of a bool. Any callers to this function can be converted to a bool by adding `.ok()` to the call. @@ -118,7 +125,6 @@ Yuming Wang, Zafar Takhirov, @zhongyuk, Ziming Dong, @guotong1988 We are also grateful to all who filed issues or helped resolve them, asked and answered questions, and were part of inspiring discussions. ->>>>>>> r0.12 # Release 0.11.0 @@ -29,6 +29,13 @@ new_http_archive( sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364" ) +new_http_archive( + name = "mobile_multibox", + build_file = "models.BUILD", + url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip", + sha256 = "b4c178fd6236dcf0a20d25d07c45eebe85281263978c6a6f1dfc49d75befc45f" +) + # TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT new_http_archive( diff --git a/eigen.BUILD b/eigen.BUILD index 8ce28ac076..3fd710dfd4 100644 --- a/eigen.BUILD +++ b/eigen.BUILD @@ -9,6 +9,8 @@ licenses([ "notice", # Portions BSD ]) +exports_files(["COPYING.MPL2",]) + # License-restricted (i.e. not reciprocal or notice) files inside Eigen/... EIGEN_RESTRICTED_FILES = [ "Eigen/src/OrderingMethods/Amd.h", diff --git a/farmhash.BUILD b/farmhash.BUILD index b41c799f8f..d054797a56 100644 --- a/farmhash.BUILD +++ b/farmhash.BUILD @@ -1,5 +1,7 @@ licenses(["notice"]) # MIT +exports_files(["COPYING"]) + config_setting( name = "windows", values = { @@ -10,13 +12,13 @@ config_setting( cc_library( name = "farmhash", - srcs = ["farmhash.cc"], - hdrs = ["farmhash.h"], + srcs = ["src/farmhash.cc"], + hdrs = ["src/farmhash.h"], # Disable __builtin_expect support on Windows copts = select({ ":windows" : ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"], "//conditions:default" : [], }), - includes = ["."], + includes = ["src/."], visibility = ["//visibility:public"], ) @@ -3,22 +3,24 @@ licenses(["notice"]) # MIT +exports_files(["COPYING"]) + cc_library( name = "gif", srcs = [ - "dgif_lib.c", - "egif_lib.c", - "gif_err.c", - "gif_font.c", - "gif_hash.c", - "gif_hash.h", - "gif_lib_private.h", - "gifalloc.c", - "openbsd-reallocarray.c", - "quantize.c", + "lib/dgif_lib.c", + "lib/egif_lib.c", + "lib/gif_err.c", + "lib/gif_font.c", + "lib/gif_hash.c", + "lib/gif_hash.h", + "lib/gif_lib_private.h", + "lib/gifalloc.c", + "lib/openbsd-reallocarray.c", + "lib/quantize.c", ], - hdrs = ["gif_lib.h"], - includes = ["."], + hdrs = ["lib/gif_lib.h"], + includes = ["lib/."], visibility = ["//visibility:public"], deps = select({ ":windows": [":windows_polyfill"], diff --git a/gmock.BUILD b/gmock.BUILD index 66ed60750d..501e322529 100644 --- a/gmock.BUILD +++ b/gmock.BUILD @@ -4,6 +4,8 @@ licenses(["notice"]) # 3-clause BSD +exports_files(["LICENSE"]) + cc_library( name = "gtest", srcs = [ diff --git a/grpc.BUILD b/grpc.BUILD index e74da683e3..e501db57e5 100644 --- a/grpc.BUILD +++ b/grpc.BUILD @@ -45,6 +45,8 @@ licenses(["notice"]) # 3-clause BSD package(default_visibility = ["//visibility:public"]) +exports_files(["LICENSE"]) + genrule( name = "pb_h", outs = ["third_party/nanopb/pb.h"], diff --git a/jsoncpp.BUILD b/jsoncpp.BUILD index 765bf15129..ce672a72ec 100644 --- a/jsoncpp.BUILD +++ b/jsoncpp.BUILD @@ -1,5 +1,7 @@ licenses(["unencumbered"]) # Public Domain or MIT +exports_files(["LICENSE"]) + cc_library( name = "jsoncpp", srcs = [ diff --git a/nanopb.BUILD b/nanopb.BUILD index 8b428689e1..d21866911b 100644 --- a/nanopb.BUILD +++ b/nanopb.BUILD @@ -3,6 +3,8 @@ licenses(["notice"]) # zlib license +exports_files(["LICENSE.txt"]) + cc_library( name = "nanopb", srcs = [ @@ -3,6 +3,8 @@ licenses(["notice"]) # BSD/MIT-like license +exports_files(["LICENSE"]) + cc_library( name = "png", srcs = [ @@ -4,6 +4,8 @@ licenses(["notice"]) # MIT +exports_files(["LICENSE"]) + py_library( name = "six", srcs = ["six.py"], diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 1e6c5d5947..2d7c28feb7 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -160,6 +160,9 @@ filegroup( "//tensorflow/g3doc/how_tos/adding_an_op:all_files", "//tensorflow/g3doc/tutorials:all_files", "//tensorflow/go:all_files", + "//tensorflow/java:all_files", + "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files", + "//tensorflow/java/src/main/native:all_files", "//tensorflow/models/embedding:all_files", "//tensorflow/models/image/alexnet:all_files", "//tensorflow/models/image/cifar10:all_files", diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 90c87210b1..bfa1386d96 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -45,7 +45,7 @@ tf_cc_test( name = "loader_test", srcs = ["loader_test.cc"], data = [ - ":saved_model_half_plus_two", + "//tensorflow/python/saved_model/example:saved_model_half_plus_two_data", ], linkstatic = 1, deps = [ @@ -61,14 +61,6 @@ tf_cc_test( ], ) -filegroup( - name = "saved_model_half_plus_two", - srcs = glob([ - "testdata/half_plus_two_pbtxt/**", - "testdata/half_plus_two_sharded/**", - ]), -) - # ----------------------------------------------------------------------------- # Google-internal targets. diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 82f30c23f6..dbbcc79802 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -29,9 +29,10 @@ limitations under the License. namespace tensorflow { namespace { -constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt"; +constexpr char kTestDataPbTxt[] = + "python/saved_model/example/saved_model_half_plus_two_pbtxt/00000123"; constexpr char kTestDataSharded[] = - "cc/saved_model/testdata/half_plus_two_sharded"; + "python/saved_model/example/saved_model_half_plus_two/00000123"; class LoaderTest : public ::testing::Test { protected: diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/assets/foo.txt deleted file mode 100644 index f9ff036688..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/assets/foo.txt +++ /dev/null @@ -1 +0,0 @@ -asset-file-contents
\ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/saved_model.pbtxt b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/saved_model.pbtxt deleted file mode 100644 index 693262eb4d..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/saved_model.pbtxt +++ /dev/null @@ -1,1951 +0,0 @@ -saved_model_schema_version: 1 -meta_graphs { - meta_info_def { - stripped_op_list { - op { - name: "Add" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_STRING - } - } - } - } - op { - name: "Assign" - input_arg { - name: "ref" - type_attr: "T" - is_ref: true - } - input_arg { - name: "value" - type_attr: "T" - } - output_arg { - name: "output_ref" - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - } - attr { - name: "validate_shape" - type: "bool" - default_value { - b: true - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: true - } - } - allows_uninitialized_input: true - } - op { - name: "Const" - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "value" - type: "tensor" - } - attr { - name: "dtype" - type: "type" - } - } - op { - name: "Identity" - input_arg { - name: "input" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - } - op { - name: "MergeV2Checkpoints" - input_arg { - name: "checkpoint_prefixes" - type: DT_STRING - } - input_arg { - name: "destination_prefix" - type: DT_STRING - } - attr { - name: "delete_old_dirs" - type: "bool" - default_value { - b: true - } - } - } - op { - name: "Mul" - input_arg { - name: "x" - type_attr: "T" - } - input_arg { - name: "y" - type_attr: "T" - } - output_arg { - name: "z" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_HALF - type: DT_FLOAT - type: DT_DOUBLE - type: DT_UINT8 - type: DT_INT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT32 - type: DT_INT64 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - } - } - } - is_commutative: true - } - op { - name: "NoOp" - } - op { - name: "Pack" - input_arg { - name: "values" - type_attr: "T" - number_attr: "N" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "N" - type: "int" - has_minimum: true - minimum: 1 - } - attr { - name: "T" - type: "type" - } - attr { - name: "axis" - type: "int" - default_value { - i: 0 - } - } - } - op { - name: "ParseExample" - input_arg { - name: "serialized" - type: DT_STRING - } - input_arg { - name: "names" - type: DT_STRING - } - input_arg { - name: "sparse_keys" - type: DT_STRING - number_attr: "Nsparse" - } - input_arg { - name: "dense_keys" - type: DT_STRING - number_attr: "Ndense" - } - input_arg { - name: "dense_defaults" - type_list_attr: "Tdense" - } - output_arg { - name: "sparse_indices" - type: DT_INT64 - number_attr: "Nsparse" - } - output_arg { - name: "sparse_values" - type_list_attr: "sparse_types" - } - output_arg { - name: "sparse_shapes" - type: DT_INT64 - number_attr: "Nsparse" - } - output_arg { - name: "dense_values" - type_list_attr: "Tdense" - } - attr { - name: "Nsparse" - type: "int" - has_minimum: true - } - attr { - name: "Ndense" - type: "int" - has_minimum: true - } - attr { - name: "sparse_types" - type: "list(type)" - has_minimum: true - allowed_values { - list { - type: DT_FLOAT - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - name: "Tdense" - type: "list(type)" - has_minimum: true - allowed_values { - list { - type: DT_FLOAT - type: DT_INT64 - type: DT_STRING - } - } - } - attr { - name: "dense_shapes" - type: "list(shape)" - has_minimum: true - } - } - op { - name: "Placeholder" - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "shape" - type: "shape" - default_value { - shape { - } - } - } - } - op { - name: "RestoreV2" - input_arg { - name: "prefix" - type: DT_STRING - } - input_arg { - name: "tensor_names" - type: DT_STRING - } - input_arg { - name: "shape_and_slices" - type: DT_STRING - } - output_arg { - name: "tensors" - type_list_attr: "dtypes" - } - attr { - name: "dtypes" - type: "list(type)" - has_minimum: true - minimum: 1 - } - } - op { - name: "SaveV2" - input_arg { - name: "prefix" - type: DT_STRING - } - input_arg { - name: "tensor_names" - type: DT_STRING - } - input_arg { - name: "shape_and_slices" - type: DT_STRING - } - input_arg { - name: "tensors" - type_list_attr: "dtypes" - } - attr { - name: "dtypes" - type: "list(type)" - has_minimum: true - minimum: 1 - } - } - op { - name: "ShardedFilename" - input_arg { - name: "basename" - type: DT_STRING - } - input_arg { - name: "shard" - type: DT_INT32 - } - input_arg { - name: "num_shards" - type: DT_INT32 - } - output_arg { - name: "filename" - type: DT_STRING - } - } - op { - name: "StringJoin" - input_arg { - name: "inputs" - type: DT_STRING - number_attr: "N" - } - output_arg { - name: "output" - type: DT_STRING - } - attr { - name: "N" - type: "int" - has_minimum: true - minimum: 1 - } - attr { - name: "separator" - type: "string" - default_value { - s: "" - } - } - } - op { - name: "Variable" - output_arg { - name: "ref" - type_attr: "dtype" - is_ref: true - } - attr { - name: "shape" - type: "shape" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "container" - type: "string" - default_value { - s: "" - } - } - attr { - name: "shared_name" - type: "string" - default_value { - s: "" - } - } - is_stateful: true - } - } - tags: "serve" - } - graph_def { - node { - name: "a/initial_value" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } - } - node { - name: "a" - op: "Variable" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - } - node { - name: "a/Assign" - op: "Assign" - input: "a" - input: "a/initial_value" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@a" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "a/read" - op: "Identity" - input: "a" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@a" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "b/initial_value" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 2.0 - } - } - } - } - node { - name: "b" - op: "Variable" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - } - node { - name: "b/Assign" - op: "Assign" - input: "b" - input: "b/initial_value" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@b" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "b/read" - op: "Identity" - input: "b" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@b" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "tf_example" - op: "Placeholder" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "shape" - value { - shape { - } - } - } - } - node { - name: "ParseExample/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "ParseExample/ParseExample/names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - } - } - } - } - } - } - node { - name: "ParseExample/ParseExample/dense_keys_0" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "x" - } - } - } - } - node { - name: "ParseExample/ParseExample" - op: "ParseExample" - input: "tf_example" - input: "ParseExample/ParseExample/names" - input: "ParseExample/ParseExample/dense_keys_0" - input: "ParseExample/Const" - attr { - key: "Ndense" - value { - i: 1 - } - } - attr { - key: "Nsparse" - value { - i: 0 - } - } - attr { - key: "Tdense" - value { - list { - type: DT_FLOAT - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } - attr { - key: "dense_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "sparse_types" - value { - list { - } - } - } - } - node { - name: "x" - op: "Identity" - input: "ParseExample/ParseExample" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } - } - node { - name: "Mul" - op: "Mul" - input: "a/read" - input: "x" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } - } - node { - name: "y" - op: "Add" - input: "Mul" - input: "b/read" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } - } - node { - name: "Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "/tmp/original/export/assets/foo.txt" - } - } - } - } - node { - name: "filename_tensor/initial_value" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "foo.txt" - } - } - } - } - node { - name: "filename_tensor" - op: "Variable" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "shape" - value { - shape { - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } - } - node { - name: "filename_tensor/Assign" - op: "Assign" - input: "filename_tensor" - input: "filename_tensor/initial_value" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@filename_tensor" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "filename_tensor/read" - op: "Identity" - input: "filename_tensor" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@filename_tensor" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "Assign/value" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "foo.txt" - } - } - } - } - node { - name: "Assign" - op: "Assign" - input: "filename_tensor" - input: "Assign/value" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@filename_tensor" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "Identity" - op: "Identity" - input: "y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 1 - } - } - } - } - } - } - node { - name: "init" - op: "NoOp" - input: "^a/Assign" - input: "^b/Assign" - } - node { - name: "group_deps" - op: "NoOp" - input: "^Assign" - } - node { - name: "save/Const" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "model" - } - } - } - } - node { - name: "save/StringJoin/inputs_1" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "_temp_ff2bd25218b646ea9ed224eecdce5e79/part" - } - } - } - } - node { - name: "save/StringJoin" - op: "StringJoin" - input: "save/Const" - input: "save/StringJoin/inputs_1" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "separator" - value { - s: "" - } - } - } - node { - name: "save/num_shards" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node { - name: "save/ShardedFilename/shard" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node { - name: "save/ShardedFilename" - op: "ShardedFilename" - input: "save/StringJoin" - input: "save/ShardedFilename/shard" - input: "save/num_shards" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/SaveV2/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 2 - } - } - string_val: "a" - string_val: "b" - } - } - } - } - node { - name: "save/SaveV2/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 2 - } - } - string_val: "" - string_val: "" - } - } - } - } - node { - name: "save/SaveV2" - op: "SaveV2" - input: "save/ShardedFilename" - input: "save/SaveV2/tensor_names" - input: "save/SaveV2/shape_and_slices" - input: "a" - input: "b" - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - } - } - } - } - node { - name: "save/control_dependency" - op: "Identity" - input: "save/ShardedFilename" - input: "^save/SaveV2" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@save/ShardedFilename" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/MergeV2Checkpoints/checkpoint_prefixes" - op: "Pack" - input: "save/ShardedFilename" - input: "^save/control_dependency" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node { - name: "save/MergeV2Checkpoints" - op: "MergeV2Checkpoints" - input: "save/MergeV2Checkpoints/checkpoint_prefixes" - input: "save/Const" - attr { - key: "delete_old_dirs" - value { - b: true - } - } - } - node { - name: "save/Identity" - op: "Identity" - input: "save/Const" - input: "^save/control_dependency" - input: "^save/MergeV2Checkpoints" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node { - name: "save/RestoreV2/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "a" - } - } - } - } - node { - name: "save/RestoreV2/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } - } - node { - name: "save/RestoreV2" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2/tensor_names" - input: "save/RestoreV2/shape_and_slices" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - } - } - } - } - node { - name: "save/Assign" - op: "Assign" - input: "a" - input: "save/RestoreV2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@a" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "save/RestoreV2_1/tensor_names" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "b" - } - } - } - } - node { - name: "save/RestoreV2_1/shape_and_slices" - op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - dim { - size: 1 - } - } - string_val: "" - } - } - } - } - node { - name: "save/RestoreV2_1" - op: "RestoreV2" - input: "save/Const" - input: "save/RestoreV2_1/tensor_names" - input: "save/RestoreV2_1/shape_and_slices" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "dtypes" - value { - list { - type: DT_FLOAT - } - } - } - } - node { - name: "save/Assign_1" - op: "Assign" - input: "b" - input: "save/RestoreV2_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@b" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "use_locking" - value { - b: true - } - } - attr { - key: "validate_shape" - value { - b: true - } - } - } - node { - name: "save/restore_shard" - op: "NoOp" - input: "^save/Assign" - input: "^save/Assign_1" - } - node { - name: "save/restore_all" - op: "NoOp" - input: "^save/restore_shard" - } - versions { - producer: 15 - } - } - saver_def { - filename_tensor_name: "save/Const:0" - save_tensor_name: "save/Identity:0" - restore_op_name: "save/restore_all" - max_to_keep: 5 - sharded: true - keep_checkpoint_every_n_hours: 10000.0 - version: V2 - } - collection_def { - key: "asset_filepaths" - value { - node_list { - value: "Const:0" - } - } - } - collection_def { - key: "legacy_init_op" - value { - node_list { - value: "group_deps" - } - } - } - collection_def { - key: "saved_model_assets" - value { - any_list { - value { - type_url: "type.googleapis.com/tensorflow.AssetFileDef" - value: "\n\t\n\007Const:0\022\007foo.txt" - } - } - } - } - collection_def { - key: "trainable_variables" - value { - bytes_list { - value: "\n\003a:0\022\010a/Assign\032\010a/read:0" - value: "\n\003b:0\022\010b/Assign\032\010b/read:0" - } - } - } - collection_def { - key: "variables" - value { - bytes_list { - value: "\n\003a:0\022\010a/Assign\032\010a/read:0" - value: "\n\003b:0\022\010b/Assign\032\010b/read:0" - } - } - } - signature_def { - key: "tensorflow/serving/regress" - value { - inputs { - key: "inputs" - value { - name: "tf_example:0" - } - } - outputs { - key: "outputs" - value { - name: "Identity:0" - } - } - method_name: "tensorflow/serving/regress" - } - } -} diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001 Binary files differdeleted file mode 100644 index 20bc7d454d..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.data-00000-of-00001 +++ /dev/null diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index Binary files differdeleted file mode 100644 index e7df518f5b..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_pbtxt/variables/variables.index +++ /dev/null diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/assets/foo.txt deleted file mode 100644 index f9ff036688..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/assets/foo.txt +++ /dev/null @@ -1 +0,0 @@ -asset-file-contents
\ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb Binary files differdeleted file mode 100644 index 0df49f2168..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb +++ /dev/null diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001 Binary files differdeleted file mode 100644 index 20bc7d454d..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.data-00000-of-00001 +++ /dev/null diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index Binary files differdeleted file mode 100644 index e7df518f5b..0000000000 --- a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/variables.index +++ /dev/null diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 10ff80c9cd..5d6710ea5c 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -33,6 +33,16 @@ Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, return (*result)->Init(queue_runner_def); } +void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) { + mutex_lock l(cb_mu_); + callbacks_.push_back(cb); +} + +void QueueRunner::ClearErrorCallbacks() { + mutex_lock l(cb_mu_); + callbacks_.clear(); +} + Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { queue_name_ = queue_runner_def.queue_name(); enqueue_op_names_.clear(); @@ -100,7 +110,6 @@ Status QueueRunner::Start(Session* sess, int wait_for) { } void QueueRunner::Stop(Session* sess) { - DCHECK(coord_ != nullptr); if (cancel_op_name_.empty()) { return; } @@ -127,6 +136,10 @@ void QueueRunner::UpdateStatus(const Status& status) { if (coord_) { coord_->ReportStatus(status); } + mutex_lock l(cb_mu_); + for (auto& cb : callbacks_) { + cb(status); + } } void QueueRunner::Run(Session* sess, const string& enqueue_op) { diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 273eb39671..fd9f97a958 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -46,6 +46,12 @@ class QueueRunner : public RunnerInterface { static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord, std::unique_ptr<QueueRunner>* result); + // Adds a callback that the queue runner will call when it detects an error. + void AddErrorCallback(const std::function<void(Status)>& cb); + + // Delete the previously registered callbacks. + void ClearErrorCallbacks(); + // The destructor would join all the threads. ~QueueRunner(); @@ -56,6 +62,11 @@ class QueueRunner : public RunnerInterface { // specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for_ms); + // Requests to stop and runs the cancel op. It would be called in a separate + // thread when coordinator is set. If there is no coordinator it should be + // called before calling Join. + void Stop(Session* sess); + // Joins all the threads. Returns okay if all threads run successfully; // otherwise returns the first captured failure status. Status Join() final; @@ -72,10 +83,6 @@ class QueueRunner : public RunnerInterface { // The Run function for each thread. void Run(Session* sess, const string& enqueue_op); - // Requests to stop and runs the cancel op. It would be called in a separate - // thread when coordinator is set. - void Stop(Session* sess); - // Updates the internal status; it only keeps OK or the first unexpected error // status. void UpdateStatus(const Status& status); @@ -100,6 +107,9 @@ class QueueRunner : public RunnerInterface { std::unique_ptr<BlockingCounter> counter_; Coordinator* coord_; + + mutex cb_mu_; + std::vector<std::function<void(Status)>> callbacks_; }; } // namespace tensorflow diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 0e7e94b40f..1661c5c91b 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -328,5 +328,21 @@ TEST(QueueRunnerTest, TestCoordinatorStop) { TF_EXPECT_OK(coord.Join()); } +TEST(QueueRunnerTest, CallbackCalledOnError) { + GraphDef graph_def = BuildSimpleGraph(); + auto session = BuildSessionAndInitVariable(graph_def); + + QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( + kQueueName, {kIllegalOpName1, kIllegalOpName2}, kCountUpToOpName, "", {}); + + std::unique_ptr<QueueRunner> qr; + TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); + bool error_caught = false; + qr->AddErrorCallback([&error_caught](const Status&) { error_caught = true; }); + TF_EXPECT_OK(qr->Start(session.get())); + qr->Join(); + EXPECT_TRUE(error_caught); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt new file mode 100644 index 0000000000..a7e0d581aa --- /dev/null +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -0,0 +1,61 @@ +# +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +cmake_minimum_required(VERSION 3.4.1) +include(ExternalProject) + +# TENSORFLOW_ROOT_DIR: +# root directory of tensorflow repo +# used for shared source files and pre-built libs +get_filename_component(TENSORFLOW_ROOT_DIR ../../../.. ABSOLUTE) +set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen) + +add_library(lib_proto STATIC IMPORTED ) +set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION + ${PREBUILT_DIR}/protobuf/lib/libprotobuf.a) + +add_library(lib_tf STATIC IMPORTED ) +set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION + ${PREBUILT_DIR}/lib/libtensorflow-core.a) +# Change to compile flags should be replicated into bazel build file +# LINT.IfChange +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -fno-rtti -fno-exceptions \ + -fpic -O2 -mfpu=neon -DTF_LEAN_BINARY \ + -DGOOGLE_PROTOBUF_NO_RTTI \ + -DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER") +# LINT.ThenChange(//tensorflow/tensorflow.bzl) + +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \ + -Wl,--allow-multiple-definition \ + -Wl,--whole-archive") + +file(GLOB tensorflow_inference_sources + ${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc) +add_library(tensorflow_inference SHARED ${tensorflow_inference_sources}) + +# Include libraries needed for hello-jni lib +target_link_libraries(tensorflow_inference + android + log + m + z + lib_tf + lib_proto) +include_directories( + ${PREBUILT_DIR}/proto + ${PREBUILT_DIR}/protobuf/include + ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen + ${TENSORFLOW_ROOT_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tensorflow/contrib/android/cmake/README.md b/tensorflow/contrib/android/cmake/README.md new file mode 100644 index 0000000000..ad9e1720c7 --- /dev/null +++ b/tensorflow/contrib/android/cmake/README.md @@ -0,0 +1,44 @@ +TensorFlow-Android-Inference +============================ +Android Java interface to the TensorFlow native APIs + +Usage +----- +Add TensorFlow-Android-Inference as a dependency of your Android application + +* settings.gradle + +``` +include ':TensorFlow-Android-Inference' +findProject(":TensorFlow-Android-Inference").projectDir = + new File("${/path/to/tensorflow_repo}/contrib/android/cmake") +``` + +* application's build.gradle (adding dependency): + +``` +debugCompile project(path: ':tensorflow_inference', configuration: 'debug') +releaseCompile project(path: ':tensorflow_inference', configuration: 'release') +``` +Note: this makes native code in the lib traceable from your app. + +Dependencies +------------ +TensorFlow-Android-Inference depends on the TensorFlow static libs already built in your +local TensorFlow repo directory. For Linux/Mac OS, build_all_android.sh is used +in build.gradle to build it. It DOES take time to build the core libs; +so, by default, it is commented out to avoid confusion (otherwise +Android Studio would appear to hang during opening the project). +To enable it, refer to the comment in + +* build.gradle + +Output +------ +- TensorFlow-Inference-debug.aar +- TensorFlow-Inference-release.aar + +File libtensorflow_inference.so should be packed under jni/${ANDROID_ABI}/ +in the above aar, and it is transparent to the app as it will acccess them via +equivalent java APIs. + diff --git a/tensorflow/contrib/android/cmake/build.gradle b/tensorflow/contrib/android/cmake/build.gradle new file mode 100644 index 0000000000..8791fac18a --- /dev/null +++ b/tensorflow/contrib/android/cmake/build.gradle @@ -0,0 +1,97 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 24 + buildToolsVersion "24.0.2" + + // for debugging native code purpose + publishNonDefault true + + defaultConfig { + archivesBaseName = "Tensorflow-Android-Inference" + minSdkVersion 21 + targetSdkVersion 21 + versionCode 1 + versionName "1.0" + ndk { + abiFilters 'armeabi-v7a' + } + externalNativeBuild { + cmake { + arguments '-DANDROID_TOOLCHAIN=gcc', + '-DANDROID_STL=gnustl_static' + } + } + } + sourceSets { + main { + java.srcDirs = ["../java"] + } + } + + externalNativeBuild { + cmake { + path 'CMakeLists.txt' + } + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), + 'proguard-rules.pro' + } + } +} + +// Build libtensorflow-core.a if necessary +// Note: the environment needs to be set up already +// [ such as installing autoconfig, make, etc ] +// How to use: +// 1) install all of the necessary tools to build libtensorflow-core.a +// 2) inside Android Studio IDE, uncomment buildTensorFlow in +// whenTaskAdded{...} +// 3) re-sync and re-build. It could take a long time if NOT building +// with multiple processes. +import org.apache.tools.ant.taskdefs.condition.Os + +Properties properties = new Properties() +properties.load(project.rootProject.file('local.properties') + .newDataInputStream()) +def ndkDir = properties.getProperty('ndk.dir') +if (ndkDir == null || ndkDir == "") { + ndkDir = System.getenv('ANDROID_NDK_HOME') +} + +if(! Os.isFamily(Os.FAMILY_WINDOWS)) { + // This script is for non-Windows OS. For Windows OS, MANUALLY build + // (or copy the built) libs/headers to the + // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen + // refer to CMakeLists.txt about lib and header directories for details + task buildTensorflow(type: Exec) { + group 'buildTensorflowLib' + workingDir getProjectDir().toString() + '/../../../../' + environment PATH: '/opt/local/bin:/opt/local/sbin:' + + System.getenv('PATH') + environment NDK_ROOT: ndkDir + commandLine 'tensorflow/contrib/makefile/build_all_android.sh' + } + + tasks.whenTaskAdded { task -> + group 'buildTensorflowLib' + if (task.name.toLowerCase().contains('sources')) { + def tensorflowTarget = new File(getProjectDir().toString() + + '/../../makefile/gen/lib/libtensorflow-core.a') + if (!tensorflowTarget.exists()) { + // Note: + // just uncomment this line to use it: + // it can take long time to build by default + // it is disabled to avoid false first impression + // task.dependsOn buildTensorflow + } + } + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) +} diff --git a/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..bced47e046 --- /dev/null +++ b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml @@ -0,0 +1,9 @@ +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.contrib.android"> + + <application android:allowBackup="true" android:label="@string/app_name" + android:supportsRtl="true"> + + </application> + +</manifest> diff --git a/tensorflow/contrib/android/cmake/src/main/res/values/strings.xml b/tensorflow/contrib/android/cmake/src/main/res/values/strings.xml new file mode 100644 index 0000000000..92dc3a1baf --- /dev/null +++ b/tensorflow/contrib/android/cmake/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ +<resources> + <string name="app_name">TensorFlowInference</string> +</resources> diff --git a/tensorflow/contrib/bayesflow/python/ops/special_math.py b/tensorflow/contrib/bayesflow/python/ops/special_math.py index 77e7c0e093..5e5cde5c1f 100644 --- a/tensorflow/contrib/bayesflow/python/ops/special_math.py +++ b/tensorflow/contrib/bayesflow/python/ops/special_math.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops __all__ = [ @@ -90,9 +91,9 @@ def _ndtr(x): 0.5 * math.sqrt(2.), dtype=x.dtype, name="half_sqrt_2") w = x * half_sqrt_2 z = math_ops.abs(w) - y = math_ops.select(math_ops.less(z, half_sqrt_2), + y = array_ops.where(math_ops.less(z, half_sqrt_2), 1. + math_ops.erf(w), - math_ops.select(math_ops.greater(w, 0.), + array_ops.where(math_ops.greater(w, 0.), 2. - math_ops.erfc(z), math_ops.erfc(z))) return 0.5 * y @@ -180,10 +181,10 @@ def log_ndtr(x, series_order=3, name="log_ndtr"): # the gradient of a select involves the calculation 1*dy+0*(-inf)=nan # regardless of whether dy is finite. Note that the minimum is a NOP if # the branch is chosen. - return math_ops.select( + return array_ops.where( math_ops.greater(x, upper_segment), -_ndtr(-x), # log(1-x) ~= -x, x << 1 - math_ops.select(math_ops.greater(x, lower_segment), + array_ops.where(math_ops.greater(x, lower_segment), math_ops.log(_ndtr(math_ops.maximum(x, lower_segment))), _log_ndtr_lower(math_ops.minimum(x, lower_segment), series_order))) diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py index 1abe77e295..84839f40c4 100644 --- a/tensorflow/contrib/distributions/python/ops/beta.py +++ b/tensorflow/contrib/distributions/python/ops/beta.py @@ -252,7 +252,7 @@ class Beta(distribution.Distribution): mode = (self.a - 1.)/ (self.a_b_sum - 2.) if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return math_ops.select( + return array_ops.where( math_ops.logical_and( math_ops.greater(self.a, 1.), math_ops.greater(self.b, 1.)), diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py index c89108b3f6..b29c272405 100644 --- a/tensorflow/contrib/distributions/python/ops/bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijector.py @@ -1468,7 +1468,7 @@ class ScaleAndShift(Bijector): batch_ndims: `Tensor` (0D, `int32`). The ndims of the `batch` portion. """ ndims = array_ops.rank(scale) - left = math_ops.select( + left = array_ops.where( math_ops.reduce_any([ math_ops.reduce_all([ math_ops.equal(ndims, 0), @@ -1478,7 +1478,7 @@ class ScaleAndShift(Bijector): math_ops.equal(ndims, 2), math_ops.equal(event_ndims, 1) ])]), 1, 0) - right = math_ops.select(math_ops.equal(event_ndims, 0), 2, 0) + right = array_ops.where(math_ops.equal(event_ndims, 0), 2, 0) pad = array_ops.concat(0, ( array_ops.ones([left], dtype=dtypes.int32), array_ops.shape(scale), diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py index 2a2ea4ec26..c485038fb2 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py @@ -239,7 +239,7 @@ class Dirichlet(distribution.Distribution): if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) shape = array_ops.concat(0, (self.batch_shape(), self.event_shape())) - return math_ops.select( + return array_ops.where( math_ops.greater(self.alpha, 1.), mode, array_ops.fill(shape, nan, name="nan")) diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py index cd193f4d6d..9f52e1f0dd 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution.py +++ b/tensorflow/contrib/distributions/python/ops/distribution.py @@ -130,7 +130,10 @@ class _DistributionMeta(abc.ABCMeta): if not baseclasses: # Nothing to be done for Distribution raise TypeError("Expected non-empty baseclass. Does Distribution " "not subclass _BaseDistribution?") - base = baseclasses[0] + which_base = [ + base for base in baseclasses + if base == _BaseDistribution or issubclass(base, Distribution)] + base = which_base[0] if base == _BaseDistribution: # Nothing to be done for Distribution return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) if not issubclass(base, Distribution): diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 094854b672..1da931c08e 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -335,7 +335,7 @@ def rotate_transpose(x, shift, name="rotate_transpose"): # Finally, we transform shift by modulo length so it can be specified # independently from the array upon which it operates (like python). ndims = array_ops.rank(x) - shift = math_ops.select(math_ops.less(shift, 0), + shift = array_ops.where(math_ops.less(shift, 0), math_ops.mod(-shift, ndims), ndims - math_ops.mod(shift, ndims)) first = math_ops.range(0, shift) @@ -396,8 +396,8 @@ def pick_vector(cond, false_vector.name, false_vector.dtype)) n = array_ops.shape(true_vector)[0] return array_ops.slice(array_ops.concat(0, (true_vector, false_vector)), - [math_ops.select(cond, 0, n)], - [math_ops.select(cond, n, -1)]) + [array_ops.where(cond, 0, n)], + [array_ops.where(cond, n, -1)]) def gen_new_seed(seed, salt): @@ -578,8 +578,8 @@ class AppendDocstring(object): if "\n" in value: raise ValueError( "Parameter description for \"%s\" contains newlines." % key) - bullets.append("* <b>`%s`</b>: %s" % (key, value)) - self._additional_note += ("\n\n##### <b>`condition_kwargs`</b>:\n\n" + + bullets.append("* `%s`: %s" % (key, value)) + self._additional_note += ("\n\n##### `condition_kwargs`:\n\n" + "\n".join(bullets)) def __call__(self, fn): diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py index 6d2a1ee953..fcc5281c55 100644 --- a/tensorflow/contrib/distributions/python/ops/gamma.py +++ b/tensorflow/contrib/distributions/python/ops/gamma.py @@ -208,7 +208,7 @@ class Gamma(distribution.Distribution): mode = (self.alpha - 1.) / self.beta if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return math_ops.select( + return array_ops.where( self.alpha >= 1., mode, array_ops.fill(self.batch_shape(), nan, name="nan")) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py new file mode 100644 index 0000000000..d0f3ce4933 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -0,0 +1,205 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Gumbel distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + + +class _Gumbel(distribution.Distribution): + """The scalar Gumbel distribution with location and scale parameters. + + #### Mathematical details + + The PDF of this distribution is: + + ```pdf(x) = exp(-(x - loc)/scale - exp(-(x - loc)/scale))``` + + with support on (-inf, inf). The CDF of this distribution is: + + ```cdf(x) = exp(-exp(-(x - loc)/scale))``` + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar Gumbel distribution. + dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued Gumbels. + # The first has mean 1 and scale 11, the second 2 and 22. + dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.]) + + # Evaluate the pdf of the first distribution on 0, and the second on 1.5, + # returning a length two tensor. + dist.pdf([0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + Arguments are broadcast when possible. + + ```python + # Define a batch of two scalar valued Logistics. + # Both have mean 1, but different scales. + dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.]) + + # Evaluate the pdf of both distributions on the same point, 3.0, + # returning a length 2 tensor. + dist.pdf(3.0) + ``` + + """ + + def __init__(self, + loc, + scale, + validate_args=False, + allow_nan_stats=True, + name="Gumbel"): + """Construct Gumbel distributions with location and scale `loc` and `scale`. + + The parameters `loc` and `scale` must be shaped in a way that supports + broadcasting (e.g. `loc + scale` is a valid operation). + + Args: + loc: Floating point tensor, the means of the distribution(s). + scale: Floating point tensor, the scales of the distribution(s). + scale must contain only positive values. + validate_args: `Boolean`, default `False`. Whether to assert that + `scale > 0`. If `validate_args` is `False`, correct output is not + guaranteed when input is invalid. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: The name to give Ops created by the initializer. + + Raises: + TypeError: if loc and scale are different dtypes. + """ + parameters = locals() + parameters.pop("self") + with ops.name_scope(name, values=[loc, scale]) as ns: + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): + self._loc = array_ops.identity(loc, name="loc") + self._scale = array_ops.identity(scale, name="scale") + contrib_tensor_util.assert_same_float_dtype((self._loc, self._scale)) + super(_Gumbel, self).__init__( + dtype=self._scale.dtype, + is_continuous=True, + is_reparameterized=True, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._loc, self._scale], + name=ns) + + @staticmethod + def _param_shapes(sample_shape): + return dict( + zip(("loc", "scale"), ([ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32)] * 2))) + + @property + def loc(self): + """Distribution parameter for the location.""" + return self._loc + + @property + def scale(self): + """Distribution parameter for scale.""" + return self._scale + + def _batch_shape(self): + return array_ops.shape(self.loc + self.scale) + + def _get_batch_shape(self): + return common_shapes.broadcast_shape(self.loc.get_shape(), + self.scale.get_shape()) + + def _event_shape(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _get_event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat(0, ([n], array_ops.shape(self.mean()))) + np_dtype = self.dtype.as_numpy_dtype() + minval = np.nextafter(np_dtype(0), np_dtype(1)) + uniform = random_ops.random_uniform(shape=shape, + minval=minval, + maxval=1, + dtype=self.dtype, + seed=seed) + sampled = -math_ops.log(-math_ops.log(uniform)) + return sampled * self.scale + self.loc + + def _log_prob(self, x): + z = self._z(x) + return - z - math_ops.log(self.scale) - math_ops.exp(-z) + + def _prob(self, x): + return math_ops.exp(self._log_prob(x)) + + def _log_cdf(self, x): + return -math_ops.exp(-self._z(x)) + + def _cdf(self, x): + return math_ops.exp(-math_ops.exp(-self._z(x))) + + def _entropy(self): + # Use broadcasting rules to calculate the full broadcast sigma. + scale = self.scale * array_ops.ones_like(self.loc) + return 1 + math_ops.log(scale) + np.euler_gamma + + def _mean(self): + return self.loc + self.scale * np.euler_gamma + + def _variance(self): + return math_ops.square(self.std()) + + def _std(self): + return self.scale * array_ops.ones_like(self.loc) * math.pi / math.sqrt(6) + + def _mode(self): + return self.loc * array_ops.ones_like(self.scale) + + def _z(self, x): + """Standardize input `x` to a unit logistic.""" + with ops.name_scope("standardize", values=[x]): + return (x - self.loc) / self.scale diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index ffe72f09a9..feb0bf2f90 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -185,7 +185,7 @@ class InverseGamma(distribution.Distribution): mean = self.beta / (self.alpha - 1.) if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return math_ops.select( + return array_ops.where( self.alpha > 1., mean, array_ops.fill(self.batch_shape(), nan, name="nan")) else: @@ -204,7 +204,7 @@ class InverseGamma(distribution.Distribution): (math_ops.square(self.alpha - 1.) * (self.alpha - 2.))) if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return math_ops.select( + return array_ops.where( self.alpha > 2., var, array_ops.fill(self.batch_shape(), nan, name="nan")) else: diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py new file mode 100644 index 0000000000..9a20c653ae --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -0,0 +1,210 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Logistic distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops + + +class _Logistic(distribution.Distribution): + """The scalar Logistic distribution with location and scale parameters. + + #### Mathematical details + + The CDF of this distribution is: + + ```cdf(x) = 1/(1+exp(-(x - loc) / scale))``` + + with support on (-inf, inf). + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar Logistic distribution. + dist = tf.contrib.distributions.Logistic(loc=0., scale=3.) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued Logistics. + # The first has mean 1 and scale 11, the second 2 and 22. + dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.]) + + # Evaluate the pdf of the first distribution on 0, and the second on 1.5, + # returning a length two tensor. + dist.pdf([0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + Arguments are broadcast when possible. + + ```python + # Define a batch of two scalar valued Logistics. + # Both have mean 1, but different scales. + dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.]) + + # Evaluate the pdf of both distributions on the same point, 3.0, + # returning a length 2 tensor. + dist.pdf(3.0) + ``` + + """ + + def __init__(self, + loc, + scale, + validate_args=False, + allow_nan_stats=True, + name="Logistic"): + """Construct Logistic distributions with mean and scale `loc` and `scale`. + + The parameters `loc` and `scale` must be shaped in a way that supports + broadcasting (e.g. `loc + scale` is a valid operation). + + Args: + loc: Floating point tensor, the means of the distribution(s). + scale: Floating point tensor, the scales of the distribution(s). + scale must contain only positive values. + validate_args: `Boolean`, default `False`. Whether to assert that + `scale > 0`. If `validate_args` is `False`, correct output is not + guaranteed when input is invalid. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: The name to give Ops created by the initializer. + + Raises: + TypeError: if loc and scale are different dtypes. + """ + parameters = locals() + parameters.pop("self") + with ops.name_scope(name, values=[loc, scale]) as ns: + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): + self._loc = array_ops.identity(loc, name="loc") + self._scale = array_ops.identity(scale, name="scale") + contrib_tensor_util.assert_same_float_dtype((self._loc, self._scale)) + super(_Logistic, self).__init__( + dtype=self._scale.dtype, + is_continuous=True, + is_reparameterized=True, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._loc, self._scale], + name=ns) + + @staticmethod + def _param_shapes(sample_shape): + return dict( + zip(("loc", "scale"), ([ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32)] * 2))) + + @property + def loc(self): + """Distribution parameter for the location.""" + return self._loc + + @property + def scale(self): + """Distribution parameter for scale.""" + return self._scale + + def _batch_shape(self): + return array_ops.shape(self.loc + self.scale) + + def _get_batch_shape(self): + return common_shapes.broadcast_shape(self.loc.get_shape(), + self.scale.get_shape()) + + def _event_shape(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _get_event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat(0, ([n], array_ops.shape(self.mean()))) + np_dtype = self.dtype.as_numpy_dtype() + minval = np.nextafter(np_dtype(0), np_dtype(1)) + uniform = random_ops.random_uniform(shape=shape, + minval=minval, + maxval=1, + dtype=self.dtype, + seed=seed) + sampled = math_ops.log(uniform) - math_ops.log(1-uniform) + return sampled * self.scale + self.loc + + def _log_prob(self, x): + z = self._z(x) + return - z - math_ops.log(self.scale) - 2*nn_ops.softplus(-z) + + def _prob(self, x): + return math_ops.exp(self._log_prob(x)) + + def _log_cdf(self, x): + return nn_ops.softplus(-self._z(x)) + + def _cdf(self, x): + return math_ops.sigmoid(self._z(x)) + + def _log_survival_function(self, x): + return nn_ops.softplus(self._z(x)) + + def _survival_function(self, x): + return math_ops.sigmoid(-self._z(x)) + + def _entropy(self): + # Use broadcasting rules to calculate the full broadcast sigma. + scale = self.scale * array_ops.ones_like(self.loc) + return 2 + math_ops.log(scale) + + def _mean(self): + return self.loc * array_ops.ones_like(self.scale) + + def _variance(self): + return math_ops.square(self.std()) + + def _std(self): + return self.scale * array_ops.ones_like(self.loc) * math.pi / math.sqrt(3) + + def _mode(self): + return self._mean() + + def _z(self, x): + """Standardize input `x` to a unit logistic.""" + with ops.name_scope("standardize", values=[x]): + return (x - self.loc) / self.scale diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py new file mode 100644 index 0000000000..bb05c90a12 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -0,0 +1,262 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The OneHotCategorical distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import kullback_leibler +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops + + +class _OneHotCategorical(distribution.Distribution): + """OneHotCategorical distribution. + + The categorical distribution is parameterized by the log-probabilities + of a set of classes. The difference between OneHotCategorical and Categorical + distributions is that OneHotCategorical is a discrete distribution over + one-hot bit vectors whereas Categorical is a discrete distribution over + positive integers. + + This class provides methods to create indexed batches of OneHotCategorical + distributions. If the provided `logits` or `p` is rank 2 or higher, for + every fixed set of leading dimensions, the last dimension represents one + single OneHotCategorical distribution. When calling distribution + functions (e.g. `dist.prob(x)`), `logits` and `x` are broadcast to the + same shape (if possible). In all cases, the last dimension of `logits/x` + represents single OneHotCategorical distributions. + + #### Examples + + Creates a 3-class distiribution, with the 2nd class, the most likely to be + drawn from. + + ```python + p = [0.1, 0.5, 0.4] + dist = OneHotCategorical(p=p) + ``` + + Creates a 3-class distiribution, with the 2nd class the most likely to be + drawn from, using logits. + + ```python + logits = [-2, 2, 0] + dist = OneHotCategorical(logits=logits) + ``` + + Creates a 3-class distribution, with the 3rd class is most likely to be drawn. + + ```python + # counts is a scalar. + p = [0.1, 0.4, 0.5] + dist = OneHotCategorical(p=p) + dist.pmf([0,1,0]) # Shape [] + + # p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match. + samples = [[0,1,0], [1,0,0]] + dist.pmf(samples) # Shape [2] + ``` + + """ + + def __init__( + self, + logits=None, + p=None, + dtype=dtypes.int32, + validate_args=False, + allow_nan_stats=True, + name="OneHotCategorical"): + """Initialize OneHotCategorical distributions using class log-probabilities. + + Args: + logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities + of a set of Categorical distributions. The first `N - 1` dimensions + index into a batch of independent distributions and the last dimension + represents a vector of logits for each class. Only one of `logits` or + `p` should be passed in. + p: An N-D `Tensor`, `N >= 1`, representing the probabilities + of a set of Categorical distributions. The first `N - 1` dimensions + index into a batch of independent distributions and the last dimension + represents a vector of probabilities for each class. Only one of + `logits` or `p` should be passed in. + dtype: The type of the event samples (default: int32). + validate_args: Unused in this distribution. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: A name for this distribution (optional). + """ + parameters = locals() + parameters.pop("self") + with ops.name_scope(name, values=[logits]) as ns: + self._logits, self._p = distribution_util.get_logits_and_prob( + name=name, logits=logits, p=p, validate_args=validate_args, + multidimensional=True) + + logits_shape_static = self._logits.get_shape().with_rank_at_least(1) + if logits_shape_static.ndims is not None: + self._batch_rank = ops.convert_to_tensor( + logits_shape_static.ndims - 1, + dtype=dtypes.int32, + name="batch_rank") + else: + with ops.name_scope(name="batch_rank"): + self._batch_rank = array_ops.rank(self._logits) - 1 + + logits_shape = array_ops.shape(self._logits, name="logits_shape") + if logits_shape_static[-1].value is not None: + self._num_classes = ops.convert_to_tensor( + logits_shape_static[-1].value, + dtype=dtypes.int32, + name="num_classes") + else: + self._num_classes = array_ops.gather(logits_shape, + self._batch_rank, + name="num_classes") + + if logits_shape_static[:-1].is_fully_defined(): + self._batch_shape_val = constant_op.constant( + logits_shape_static[:-1].as_list(), + dtype=dtypes.int32, + name="batch_shape") + else: + with ops.name_scope(name="batch_shape"): + self._batch_shape_val = logits_shape[:-1] + super(_OneHotCategorical, self).__init__( + dtype=dtype, + is_continuous=False, + is_reparameterized=False, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._logits, self._num_classes], + name=ns) + + @property + def num_classes(self): + """Scalar `int32` tensor: the number of classes.""" + return self._num_classes + + @property + def logits(self): + """Vector of coordinatewise logits.""" + return self._logits + + @property + def p(self): + """Vector of probabilities summing to one. + + Each element is the probability of drawing that coordinate.""" + return self._p + + def _batch_shape(self): + # Use identity to inherit callers "name". + return array_ops.identity(self._batch_shape_val) + + def _get_batch_shape(self): + return self.logits.get_shape()[:-1] + + def _event_shape(self): + return array_ops.shape(self.logits)[-1] + + def _get_event_shape(self): + return self.logits.get_shape().with_rank_at_least(1)[-1:] + + def _sample_n(self, n, seed=None): + sample_shape = array_ops.concat(0, ([n], array_ops.shape(self.logits))) + logits = self.logits + if logits.get_shape().ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, self.num_classes]) + samples = random_ops.multinomial(logits_2d, n, seed=seed) + samples = array_ops.transpose(samples) + samples = array_ops.one_hot(samples, self.num_classes, dtype=self.dtype) + ret = array_ops.reshape(samples, sample_shape) + return ret + + def _log_prob(self, x): + x = ops.convert_to_tensor(x, name="x") + # broadcast logits or x if need be. + logits = self.logits + if (not x.get_shape().is_fully_defined() or + not logits.get_shape().is_fully_defined() or + x.get_shape() != logits.get_shape()): + logits = array_ops.ones_like(x, dtype=logits.dtype) * logits + x = array_ops.ones_like(logits, dtype=x.dtype) * x + + logits_shape = array_ops.shape(logits) + if logits.get_shape().ndims == 2: + logits_2d = logits + x_2d = x + else: + logits_2d = array_ops.reshape(logits, [-1, self.num_classes]) + x_2d = array_ops.reshape(x, [-1, self.num_classes]) + ret = -nn_ops.softmax_cross_entropy_with_logits(logits_2d, x_2d) + ret = array_ops.reshape(ret, logits_shape) + return ret + + def _prob(self, x): + return math_ops.exp(self._log_prob(x)) + + def _entropy(self): + if self.logits.get_shape().ndims == 2: + logits_2d = self.logits + else: + logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes]) + histogram_2d = nn_ops.softmax(logits_2d) + ret = array_ops.reshape( + nn_ops.softmax_cross_entropy_with_logits(logits_2d, histogram_2d), + self.batch_shape()) + ret.set_shape(self.get_batch_shape()) + return ret + + def _mode(self): + ret = math_ops.argmax(self.logits, axis=self._batch_rank) + ret = array_ops.one_hot(ret, self.num_classes, dtype=self.dtype) + ret.set_shape(self.logits.get_shape()) + return ret + + +@kullback_leibler.RegisterKL(_OneHotCategorical, _OneHotCategorical) +def _kl_categorical_categorical(a, b, name=None): + """Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical. + + Args: + a: instance of a OneHotCategorical distribution object. + b: instance of a OneHotCategorical distribution object. + name: (optional) Name to use for created operations. + default is "kl_categorical_categorical". + + Returns: + Batchwise KL(a || b) + """ + with ops.name_scope( + name, "kl_categorical_categorical", [a.logits, b.logits]): + # sum(p*ln(p/q)) + return math_ops.reduce_sum( + nn_ops.softmax(a.logits)*(nn_ops.log_softmax(a.logits) + - nn_ops.log_softmax(b.logits)), reduction_indices=[-1]) diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index bca3f99604..fd3ec553c0 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -284,11 +284,11 @@ class QuantizedDistribution(distributions.Distribution): result_so_far = math_ops.ceil(x_samps) if lower_cutoff is not None: - result_so_far = math_ops.select(result_so_far < lower_cutoff, + result_so_far = array_ops.where(result_so_far < lower_cutoff, lower_cutoff * ones, result_so_far) if upper_cutoff is not None: - result_so_far = math_ops.select(result_so_far > upper_cutoff, + result_so_far = array_ops.where(result_so_far > upper_cutoff, upper_cutoff * ones, result_so_far) return result_so_far @@ -327,8 +327,8 @@ class QuantizedDistribution(distributions.Distribution): # In either case, we are doing Log[ exp{big} - exp{small} ] # We want to use the sf items precisely when we are on the right side of the # median, which occurs when logsf_y < logcdf_y. - big = math_ops.select(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y) - small = math_ops.select(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1) + big = array_ops.where(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y) + small = array_ops.where(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1) return _logsum_expbig_minus_expsmall(big, small) @@ -357,7 +357,7 @@ class QuantizedDistribution(distributions.Distribution): cdf_y_minus_1 = self.cdf(y - 1) # sf_prob has greater precision iff we're on the right side of the median. - return math_ops.select( + return array_ops.where( sf_y < cdf_y, # True iff we're on the right side of the median. sf_y_minus_1 - sf_y, cdf_y - cdf_y_minus_1) @@ -386,9 +386,9 @@ class QuantizedDistribution(distributions.Distribution): # Re-define values at the cutoffs. if lower_cutoff is not None: neg_inf = -np.inf * array_ops.ones_like(result_so_far) - result_so_far = math_ops.select(j < lower_cutoff, neg_inf, result_so_far) + result_so_far = array_ops.where(j < lower_cutoff, neg_inf, result_so_far) if upper_cutoff is not None: - result_so_far = math_ops.select(j >= upper_cutoff, + result_so_far = array_ops.where(j >= upper_cutoff, array_ops.zeros_like(result_so_far), result_so_far) @@ -418,11 +418,11 @@ class QuantizedDistribution(distributions.Distribution): # Re-define values at the cutoffs. if lower_cutoff is not None: - result_so_far = math_ops.select(j < lower_cutoff, + result_so_far = array_ops.where(j < lower_cutoff, array_ops.zeros_like(result_so_far), result_so_far) if upper_cutoff is not None: - result_so_far = math_ops.select(j >= upper_cutoff, + result_so_far = array_ops.where(j >= upper_cutoff, array_ops.ones_like(result_so_far), result_so_far) @@ -452,12 +452,12 @@ class QuantizedDistribution(distributions.Distribution): # Re-define values at the cutoffs. if lower_cutoff is not None: - result_so_far = math_ops.select(j < lower_cutoff, + result_so_far = array_ops.where(j < lower_cutoff, array_ops.zeros_like(result_so_far), result_so_far) if upper_cutoff is not None: neg_inf = -np.inf * array_ops.ones_like(result_so_far) - result_so_far = math_ops.select(j >= upper_cutoff, neg_inf, result_so_far) + result_so_far = array_ops.where(j >= upper_cutoff, neg_inf, result_so_far) return result_so_far @@ -485,11 +485,11 @@ class QuantizedDistribution(distributions.Distribution): # Re-define values at the cutoffs. if lower_cutoff is not None: - result_so_far = math_ops.select(j < lower_cutoff, + result_so_far = array_ops.where(j < lower_cutoff, array_ops.ones_like(result_so_far), result_so_far) if upper_cutoff is not None: - result_so_far = math_ops.select(j >= upper_cutoff, + result_so_far = array_ops.where(j >= upper_cutoff, array_ops.zeros_like(result_so_far), result_so_far) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py new file mode 100644 index 0000000000..7994c5d433 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -0,0 +1,213 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The RelaxedBernoulli distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import bijector +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import logistic +from tensorflow.contrib.distributions.python.ops import transformed_distribution +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops + + +class _RelaxedBernoulli(transformed_distribution.TransformedDistribution): + """RelaxedBernoulli distribution with temperature and logits parameters. + + The RelaxedBernoulli is a distribution over the unit interval (0,1), which + continuously approximates a Bernoulli. The degree of approximation is + controlled by a temperature: as the temperaturegoes to 0 the RelaxedBernoulli + becomes discrete with a distribution described by the `logits` or `p` + parameters, as the temperature goes to infinity the RelaxedBernoulli + becomes the constant distribution that is identically 0.5. + + The RelaxedBernoulli distribution is a reparameterized continuous + distribution that is the binary special case of the RelaxedOneHotCategorical + distribution (Maddison et al., 2016; Jang et al., 2016). For details on the + binary special case see the appendix of Maddison et al. (2016) where it is + referred to as BinConcrete. If you use this distribution, please cite both + papers. + + Some care needs to be taken for loss functions that depend on the + log-probability of RelaxedBernoullis, because computing log-probabilities of + the RelaxedBernoulli can suffer from underflow issues. In many case loss + functions such as these are invariant under invertible transformations of + the random variables. The KL divergence, found in the variational autoencoder + loss, is an example. Because RelaxedBernoullis are sampled by by a Logistic + random variable followed by a `tf.sigmoid` op, one solution is to treat + the Logistic as the random variable and `tf.sigmoid` as downstream. The + KL divergences of two Logistics, which are always followed by a `tf.sigmoid` + op, is equivalent to evaluating KL divergences of RelaxedBernoulli samples. + See Maddison et al., 2016 for more details where this distribution is called + the BinConcrete. + + #### Examples + + Creates three continuous distributions, which approximate 3 Bernoullis with + probabilities (0.1, 0.5, 0.4). Samples from these distributions will be in + the unit interval (0,1). + + ```python + temperature = 0.5 + p = [0.1, 0.5, 0.4] + dist = RelaxedBernoulli(temperature, p=p) + ``` + + Creates three continuous distributions, which approximate 3 Bernoullis with + logits (-2, 2, 0). Samples from these distributions will be in + the unit interval (0,1). + + ```python + temperature = 0.5 + logits = [-2, 2, 0] + dist = RelaxedBernoulli(temperature, logits=logits) + ``` + + Creates three continuous distributions, whose sigmoid approximate 3 Bernoullis + with logits (-2, 2, 0). + + ```python + temperature = 0.5 + logits = [-2, 2, 0] + dist = Logistic(logits/temperature, 1./temperature) + samples = dist.sample() + sigmoid_samples = tf.sigmoid(samples) + # sigmoid_samples has the same distribution as samples from + # RelaxedBernoulli(temperature, logits=logits) + ``` + + Creates three continuous distributions, which approximate 3 Bernoullis with + logits (-2, 2, 0). Samples from these distributions will be in + the unit interval (0,1). Because the temperature is very low, samples from + these distributions are almost discrete, usually taking values very close to 0 + or 1. + + ```python + temperature = 1e-5 + logits = [-2, 2, 0] + dist = RelaxedBernoulli(temperature, logits=logits) + ``` + + Creates three continuous distributions, which approximate 3 Bernoullis with + logits (-2, 2, 0). Samples from these distributions will be in + the unit interval (0,1). Because the temperature is very high, samples from + these distributions are usually close to the (0.5, 0.5, 0.5) vector. + + ```python + temperature = 100 + logits = [-2, 2, 0] + dist = RelaxedBernoulli(temperature, logits=logits) + ``` + + Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution: + A Continuous Relaxation of Discrete Random Variables. 2016. + + Eric Jang, Shixiang Gu, and Ben Poole. Categorical Reparameterization with + Gumbel-Softmax. 2016. + """ + + def __init__(self, + temperature, + logits=None, + p=None, + validate_args=False, + allow_nan_stats=True, + name="RelaxedBernoulli"): + """Construct RelaxedBernoulli distributions. + + Args: + temperature: An 0-D `Tensor`, representing the temperature + of a set of RelaxedBernoulli distributions. The temperature should be + positive. + logits: An N-D `Tensor` representing the log-odds + of a positive event. Each entry in the `Tensor` parametrizes + an independent RelaxedBernoulli distribution where the probability of an + event is sigmoid(logits). Only one of `logits` or `p` should be passed + in. + p: An N-D `Tensor` representing the probability of a positive + event. Each entry in the `Tensor` parameterizes an independent + Bernoulli distribution. Only one of `logits` or `p` should be passed + in. + validate_args: `Boolean`, default `False`. Whether to validate that + `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are + invalid, methods like `log_pmf` may return `NaN` values. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: A name for this distribution. + + Raises: + ValueError: If p and logits are passed, or if neither are passed. + """ + parameters = locals() + parameters.pop("self") + with ops.name_scope(name, values=[logits, p, temperature]) as ns: + with ops.control_dependencies([check_ops.assert_positive(temperature)] + if validate_args else []): + self._temperature = array_ops.identity(temperature, name="temperature") + + self._logits, self._p = distribution_util.get_logits_and_prob( + logits=logits, p=p, validate_args=validate_args) + with ops.name_scope("q"): + self._q = 1. - self._p + dist = logistic._Logistic(self._logits / self._temperature, + 1./self._temperature, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=ns) + + def inverse_log_det_jacobian_fn(y): + return -math_ops.reduce_sum(math_ops.log(y) + math_ops.log(1-y), + reduction_indices=-1) + + sigmoidbijector = bijector.Inline( + forward_fn=math_ops.sigmoid, + inverse_fn=(lambda y: math_ops.log(y) - math_ops.log(1-y)), + inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn, + name="sigmoid") + super(_RelaxedBernoulli, self).__init__(dist, + sigmoidbijector, + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return {"logits": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} + + @property + def temperature(self): + """Distribution parameter for the location.""" + return self._temperature + + @property + def logits(self): + """Log-odds of success.""" + return self._logits + + @property + def p(self): + """Probability of success.""" + return self._p + + @property + def q(self): + """Probability of failure.""" + return self._q diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py new file mode 100644 index 0000000000..1b60b32ff6 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -0,0 +1,420 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Relaxed OneHotCategorical distribution classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.distributions.python.ops import bijector +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import transformed_distribution +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops + + +class _ExpRelaxedOneHotCategorical(distribution.Distribution): + """ExpRelaxedOneHotCategorical distribution with temperature and logits. + + An ExpRelaxedOneHotCategorical distribution is a log-transformed + RelaxedOneHotCategorical distribution. The RelaxedOneHotCategorical is a + distribution over random probability vectors, vectors of positive real + values that sum to one, which continuously approximates a OneHotCategorical. + The degree of approximation is controlled by a temperature: as the temperature + goes to 0 the RelaxedOneHotCategorical becomes discrete with a distribution + described by the logits, as the temperature goes to infinity the + RelaxedOneHotCategorical becomes the constant distribution that is identically + the constant vector of (1/num_classes, ..., 1/num_classes). + + Because computing log-probabilities of the RelaxedOneHotCategorical can + suffer from underflow issues, this class is one solution for loss + functions that depend on log-probabilities, such as the KL Divergence found + in the variational autoencoder loss. The KL divergence between two + distributions is invariant under invertible transformations, so evaluating + KL divergences of ExpRelaxedOneHotCategorical samples, which are always + followed by a `tf.exp` op, is equivalent to evaluating KL divergences of + RelaxedOneHotCategorical samples. See the appendix of Maddison et al., 2016 + for more mathematical details, where this distribution is called the + ExpConcrete. + + #### Examples + + Creates a continuous distribution, whoe exp approximates a 3-class one-hot + categorical distiribution. The 2nd class is the most likely to be the + largest component in samples drawn from this distribution. If those samples + are followed by a `tf.exp` op, then they are distributed as a relaxed onehot + categorical. + + ```python + temperature = 0.5 + p = [0.1, 0.5, 0.4] + dist = ExpRelaxedOneHotCategorical(temperature, p=p) + samples = dist.sample() + exp_samples = tf.exp(samples) + # exp_samples has the same distribution as samples from + # RelaxedOneHotCategorical(temperature, p=p) + ``` + + Creates a continuous distribution, whose exp approximates a 3-class one-hot + categorical distiribution. The 2nd class is the most likely to be the + largest component in samples drawn from this distribution. + + ```python + temperature = 0.5 + logits = [-2, 2, 0] + dist = ExpRelaxedOneHotCategorical(temperature, logits=logits) + samples = dist.sample() + exp_samples = tf.exp(samples) + # exp_samples has the same distribution as samples from + # RelaxedOneHotCategorical(temperature, p=p) + ``` + + Creates a continuous distribution, whose exp approximates a 3-class one-hot + categorical distiribution. Because the temperature is very low, samples from + this distribution are almost discrete, with one component almost 0 and the + others very negative. The 2nd class is the most likely to be the largest + component in samples drawn from this distribution. + + ```python + temperature = 1e-5 + logits = [-2, 2, 0] + dist = ExpRelaxedOneHotCategorical(temperature, logits=logits) + samples = dist.sample() + exp_samples = tf.exp(samples) + # exp_samples has the same distribution as samples from + # RelaxedOneHotCategorical(temperature, p=p) + ``` + + Creates a continuous distribution, whose exp approximates a 3-class one-hot + categorical distiribution. Because the temperature is very high, samples from + this distribution are usually close to the (-log(3), -log(3), -log(3)) vector. + The 2nd class is still the most likely to be the largest component + in samples drawn from this distribution. + + ```python + temperature = 10 + logits = [-2, 2, 0] + dist = ExpRelaxedOneHotCategorical(temperature, logits=logits) + samples = dist.sample() + exp_samples = tf.exp(samples) + # exp_samples has the same distribution as samples from + # RelaxedOneHotCategorical(temperature, p=p) + ``` + + Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution: + A Continuous Relaxation of Discrete Random Variables. 2016. + """ + + def __init__( + self, + temperature, + logits=None, + p=None, + dtype=dtypes.float32, + validate_args=False, + allow_nan_stats=True, + name="ExpRelaxedOneHotCategorical"): + """Initialize ExpRelaxedOneHotCategorical using class log-probabilities. + + Args: + temperature: An 0-D `Tensor`, representing the temperature + of a set of ExpRelaxedCategorical distributions. The temperature should + be positive. + logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities + of a set of ExpRelaxedCategorical distributions. The first + `N - 1` dimensions index into a batch of independent distributions and + the last dimension represents a vector of logits for each class. Only + one of `logits` or `p` should be passed in. + p: An N-D `Tensor`, `N >= 1`, representing the probabilities + of a set of ExpRelaxedCategorical distributions. The first + `N - 1` dimensions index into a batch of independent distributions and + the last dimension represents a vector of probabilities for each + class. Only one of `logits` or `p` should be passed in. + dtype: The type of the event samples (default: int32). + validate_args: Unused in this distribution. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: A name for this distribution (optional). + """ + parameters = locals() + parameters.pop("self") + with ops.name_scope(name, values=[logits, p, temperature]) as ns: + with ops.control_dependencies([check_ops.assert_positive(temperature)] + if validate_args else []): + self._temperature = array_ops.identity(temperature, name="temperature") + self._logits, self._p = distribution_util.get_logits_and_prob( + name=name, logits=logits, p=p, validate_args=validate_args, + multidimensional=True) + + logits_shape_static = self._logits.get_shape().with_rank_at_least(1) + if logits_shape_static.ndims is not None: + self._batch_rank = ops.convert_to_tensor( + logits_shape_static.ndims - 1, + dtype=dtypes.int32, + name="batch_rank") + else: + with ops.name_scope(name="batch_rank"): + self._batch_rank = array_ops.rank(self._logits) - 1 + + logits_shape = array_ops.shape(self._logits, name="logits_shape") + if logits_shape_static[-1].value is not None: + self._num_classes = ops.convert_to_tensor( + logits_shape_static[-1].value, + dtype=dtypes.int32, + name="num_classes") + else: + self._num_classes = array_ops.gather(logits_shape, + self._batch_rank, + name="num_classes") + + if logits_shape_static[:-1].is_fully_defined(): + self._batch_shape_val = constant_op.constant( + logits_shape_static[:-1].as_list(), + dtype=dtypes.int32, + name="batch_shape") + else: + with ops.name_scope(name="batch_shape"): + self._batch_shape_val = logits_shape[:-1] + super(_ExpRelaxedOneHotCategorical, self).__init__( + dtype=dtype, + is_continuous=True, + is_reparameterized=True, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._logits, self._temperature, self._num_classes], + name=ns) + + @property + def num_classes(self): + """Scalar `int32` tensor: the number of classes.""" + return self._num_classes + + @property + def temperature(self): + """A scalar representing the temperature.""" + return self._temperature + + @property + def logits(self): + """Vector of coordinatewise logits.""" + return self._logits + + @property + def p(self): + """Vector of probabilities summing to one.""" + return self._p + + def _batch_shape(self): + # Use identity to inherit callers "name". + return array_ops.identity(self._batch_shape_val) + + def _get_batch_shape(self): + return self.logits.get_shape()[:-1] + + def _event_shape(self): + return array_ops.shape(self.logits)[-1] + + def _get_event_shape(self): + return self.logits.get_shape().with_rank_at_least(1)[-1:] + + def _sample_n(self, n, seed=None): + sample_shape = array_ops.concat(0, ([n], array_ops.shape(self.logits))) + logits = self.logits * array_ops.ones(sample_shape) + if logits.get_shape().ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, self.num_classes]) + np_dtype = self.dtype.as_numpy_dtype() + minval = np.nextafter(np_dtype(0), np_dtype(1)) + uniform = random_ops.random_uniform(shape=array_ops.shape(logits_2d), + minval=minval, + maxval=1, + dtype=self.dtype, + seed=seed) + gumbel = - math_ops.log(- math_ops.log(uniform)) + noisy_logits = math_ops.div(gumbel + logits_2d, self.temperature) + samples = nn_ops.log_softmax(noisy_logits) + ret = array_ops.reshape(samples, sample_shape) + return ret + + def _log_prob(self, x): + x = ops.convert_to_tensor(x, name="x") + x = self._assert_valid_sample(x) + # broadcast logits or x if need be. + logits = self.logits + if (not x.get_shape().is_fully_defined() or + not logits.get_shape().is_fully_defined() or + x.get_shape() != logits.get_shape()): + logits = array_ops.ones_like(x, dtype=logits.dtype) * logits + x = array_ops.ones_like(logits, dtype=x.dtype) * x + + logits_shape = array_ops.shape(logits) + if logits.get_shape().ndims == 2: + logits_2d = logits + x_2d = x + else: + logits_2d = array_ops.reshape(logits, [-1, self.num_classes]) + x_2d = array_ops.reshape(x, [-1, self.num_classes]) + # compute the normalization constant + log_norm_const = (math_ops.lgamma(self.num_classes) + + (self.num_classes - 1) + * math_ops.log(self.temperature)) + # compute the unnormalized density + log_softmax = nn_ops.log_softmax(logits_2d - x_2d * self.temperature) + log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keep_dims=False) + # combine unnormalized density with normalization constant + log_prob = log_norm_const + log_unnorm_prob + ret = array_ops.reshape(log_prob, logits_shape) + return ret + + def _prob(self, x): + return math_ops.exp(self._log_prob(x)) + + def _assert_valid_sample(self, x): + if not self.validate_args: return x + return control_flow_ops.with_dependencies([ + check_ops.assert_non_positive(x), + distribution_util.assert_close( + array_ops.zeros((), dtype=self.dtype), + math_ops.reduce_logsumexp(x, reduction_indices=[-1])), + ], x) + + +class _RelaxedOneHotCategorical( + transformed_distribution.TransformedDistribution): + """RelaxedOneHotCategorical distribution with temperature and logits. + + The RelaxedOneHotCategorical is a distribution over random probability + vectors, vectors of positive real values that sum to one, which continuously + approximates a OneHotCategorical. The degree of approximation is controlled by + a temperature: as the temperaturegoes to 0 the RelaxedOneHotCategorical + becomes discrete with a distribution described by the `logits` or `p` + parameters, as the temperature goes to infinity the RelaxedOneHotCategorical + becomes the constant distribution that is identically the constant vector of + (1/num_classes, ..., 1/num_classes). + + The RelaxedOneHotCategorical distribution was concurrently introduced as the + Gumbel-Softmax (Jang et al., 2016) and Concrete (Maddison et al., 2016) + distributions for use as a reparameterized continuous approximation to the + `Categorical` one-hot distribution. If you use this distribution, please cite + both papers. + + #### Examples + + Creates a continuous distribution, which approximates a 3-class one-hot + categorical distiribution. The 2nd class is the most likely to be the + largest component in samples drawn from this distribution. + + ```python + temperature = 0.5 + p = [0.1, 0.5, 0.4] + dist = RelaxedOneHotCategorical(temperature, p=p) + ``` + + Creates a continuous distribution, which approximates a 3-class one-hot + categorical distiribution. The 2nd class is the most likely to be the + largest component in samples drawn from this distribution. + + ```python + temperature = 0.5 + logits = [-2, 2, 0] + dist = RelaxedOneHotCategorical(temperature, logits=logits) + ``` + + Creates a continuous distribution, which approximates a 3-class one-hot + categorical distiribution. Because the temperature is very low, samples from + this distribution are almost discrete, with one component almost 1 and the + others nearly 0. The 2nd class is the most likely to be the largest component + in samples drawn from this distribution. + + ```python + temperature = 1e-5 + logits = [-2, 2, 0] + dist = RelaxedOneHotCategorical(temperature, logits=logits) + ``` + + Creates a continuous distribution, which approximates a 3-class one-hot + categorical distiribution. Because the temperature is very high, samples from + this distribution are usually close to the (1/3, 1/3, 1/3) vector. The 2nd + class is still the most likely to be the largest component + in samples drawn from this distribution. + + ```python + temperature = 10 + logits = [-2, 2, 0] + dist = RelaxedOneHotCategorical(temperature, logits=logits) + ``` + + Eric Jang, Shixiang Gu, and Ben Poole. Categorical Reparameterization with + Gumbel-Softmax. 2016. + + Chris J. Maddison, Andriy Mnih, and Yee Whye Teh. The Concrete Distribution: + A Continuous Relaxation of Discrete Random Variables. 2016. + """ + + def __init__( + self, + temperature, + logits=None, + p=None, + dtype=dtypes.float32, + validate_args=False, + allow_nan_stats=True, + name="RelaxedOneHotCategorical"): + """Initialize RelaxedOneHotCategorical using class log-probabilities. + + Args: + temperature: An 0-D `Tensor`, representing the temperature + of a set of RelaxedOneHotCategorical distributions. The temperature + should be positive. + logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities + of a set of RelaxedOneHotCategorical distributions. The first + `N - 1` dimensions index into a batch of independent distributions and + the last dimension represents a vector of logits for each class. Only + one of `logits` or `p` should be passed in. + p: An N-D `Tensor`, `N >= 1`, representing the probabilities + of a set of RelaxedOneHotCategorical distributions. The first + `N - 1` dimensions index into a batch of independent distributions and + the last dimension represents a vector of probabilities for each + class. Only one of `logits` or `p` should be passed in. + dtype: The type of the event samples (default: int32). + validate_args: Unused in this distribution. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member. If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: A name for this distribution (optional). + """ + dist = _ExpRelaxedOneHotCategorical(temperature, + logits=logits, + p=p, + dtype=dtype, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats) + super(_RelaxedOneHotCategorical, self).__init__(dist, + bijector.Exp(), + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index f5b2d94a8a..fa1596d555 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -422,7 +422,7 @@ class _DistributionShape(object): batch_shape = array_ops.slice(s, (1,), (self.batch_ndims,)) # Since sample_dims=1 and is left-most, we add 1 to the number of # batch_ndims to get the event start dim. - event_start = math_ops.select( + event_start = array_ops.where( self._batch_ndims_is_0, 2, 1 + self.batch_ndims) event_shape = array_ops.slice(s, (event_start,), (self.event_ndims,)) new_shape = array_ops.concat(0, (sample_shape, batch_shape, event_shape)) diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py index cf21aedc37..9c086d126c 100644 --- a/tensorflow/contrib/distributions/python/ops/student_t.py +++ b/tensorflow/contrib/distributions/python/ops/student_t.py @@ -230,7 +230,7 @@ class StudentT(distribution.Distribution): mean = self.mu * self._ones() if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return math_ops.select( + return array_ops.where( math_ops.greater(self.df, self._ones()), mean, array_ops.fill(self.batch_shape(), nan, name="nan")) else: @@ -255,14 +255,14 @@ class StudentT(distribution.Distribution): math_ops.square(self.sigma) * self.df / (self.df - 2)) # When 1 < df <= 2, variance is infinite. inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype()) - result_where_defined = math_ops.select( + result_where_defined = array_ops.where( math_ops.greater(self.df, array_ops.fill(self.batch_shape(), 2.)), var, array_ops.fill(self.batch_shape(), inf, name="inf")) if self.allow_nan_stats: nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype()) - return math_ops.select( + return array_ops.where( math_ops.greater(self.df, self._ones()), result_where_defined, array_ops.fill(self.batch_shape(), nan, name="nan")) diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py index 074c51b8d5..74115fe542 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py @@ -120,7 +120,7 @@ class TransformedDistribution(distributions.Distribution): forward_fn=tf.exp, inverse_fn=tf.log, inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(x), reduction_indices=-1)), + lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)), name="LogNormalTransformedDistribution") ``` @@ -144,7 +144,7 @@ class TransformedDistribution(distributions.Distribution): """Construct a Transformed Distribution. Args: - distribution: The base distribution class to transform. Typically an + distribution: The base distribution instance to transform. Typically an instance of `Distribution`. bijector: The object responsible for calculating the transformation. Typically an instance of `Bijector`. @@ -244,7 +244,7 @@ class TransformedDistribution(distributions.Distribution): bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - return self.distribution.log_cdf(x, distribution_kwargs) + return self.distribution.log_cdf(x, **distribution_kwargs) @distribution_util.AppendDocstring( condition_kwargs_dict=_condition_kwargs_dict) diff --git a/tensorflow/contrib/distributions/python/ops/uniform.py b/tensorflow/contrib/distributions/python/ops/uniform.py index 6b64ca4669..9c9fdf42bd 100644 --- a/tensorflow/contrib/distributions/python/ops/uniform.py +++ b/tensorflow/contrib/distributions/python/ops/uniform.py @@ -148,10 +148,10 @@ class Uniform(distribution.Distribution): def _prob(self, x): broadcasted_x = x * array_ops.ones(self.batch_shape()) - return math_ops.select( + return array_ops.where( math_ops.is_nan(broadcasted_x), broadcasted_x, - math_ops.select( + array_ops.where( math_ops.logical_or(broadcasted_x < self.a, broadcasted_x > self.b), array_ops.zeros_like(broadcasted_x), @@ -164,9 +164,9 @@ class Uniform(distribution.Distribution): broadcasted_x = x * array_ops.ones(self.batch_shape()) zeros = array_ops.zeros_like(x + self.a + self.b, dtype=self.dtype) ones = array_ops.ones_like(x + self.a + self.b, dtype=self.dtype) - result_if_not_big = math_ops.select( + result_if_not_big = array_ops.where( x < self.a, zeros, (broadcasted_x - self.a) / self.range()) - return math_ops.select(x >= self.b, ones, result_if_not_big) + return array_ops.where(x >= self.b, ones, result_if_not_big) def _entropy(self): return math_ops.log(self.range()) diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index ae106628cc..b478a12d36 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -367,7 +367,7 @@ class _WishartOperatorPD(distribution.Distribution): def _mode(self): s = self.df - self.dimension - 1. - s = math_ops.select( + s = array_ops.where( math_ops.less(s, 0.), constant_op.constant(float("NaN"), dtype=self.dtype, name="nan"), s) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 3228c1f3df..7784c6dbda 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -243,7 +243,7 @@ class KMeansClustering(estimator.Estimator, ).training_graph() incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1) self._loss = tf.reduce_sum(losses) - tf.scalar_summary('loss/raw', self._loss) + tf.contrib.deprecated.scalar_summary('loss/raw', self._loss) training_op = with_dependencies([training_op, incr_step], self._loss) return training_op, self._loss diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index 34420fc87f..c149d14849 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -98,29 +98,32 @@ def assert_same_float_dtype(tensors=None, dtype=None): return dtype -def assert_scalar_int(tensor): +def assert_scalar_int(tensor, name=None): """Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`. Args: - tensor: Tensor to test. + tensor: `Tensor` to test. + name: Name of the op and of the new `Tensor` if one is created. Returns: `tensor`, for chaining. Raises: ValueError: if `tensor` is not 0-D, of type `tf.int32` or `tf.int64`. """ - tensor = ops.convert_to_tensor(tensor) - data_type = tensor.dtype - if data_type.base_dtype not in [dtypes.int32, dtypes.int64]: - raise ValueError('Unexpected type %s for %s.' % (data_type, tensor.name)) - assert_scalar(tensor) - - -def assert_scalar(tensor): - tensor = ops.convert_to_tensor(tensor) - shape = tensor.get_shape() - if shape.ndims != 0: - raise ValueError('Unexpected shape %s for %s.' % (shape, tensor.name)) - return tensor + with ops.name_scope(name, 'assert_scalar_int', [tensor]) as name_scope: + tensor = ops.convert_to_tensor(tensor) + data_type = tensor.dtype + if data_type.base_dtype not in [dtypes.int32, dtypes.int64]: + raise ValueError('Unexpected type %s for %s.' % (data_type, tensor.name)) + return assert_scalar(tensor, name=name_scope) + + +def assert_scalar(tensor, name=None): + with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: + tensor = ops.convert_to_tensor(tensor, name=name_scope) + shape = tensor.get_shape() + if shape.ndims != 0: + raise ValueError('Unexpected shape %s for %s.' % (shape, tensor.name)) + return tensor def reduce_sum_n(tensors, name=None): @@ -141,14 +144,15 @@ def reduce_sum_n(tensors, name=None): """ if not tensors: raise ValueError('No tensors provided.') - tensors = [math_ops.reduce_sum(t, name='%s/sum' % t.op.name) for t in tensors] - if len(tensors) == 1: - return tensors[0] - with ops.name_scope(name, 'reduce_sum_n', tensors) as scope: - return math_ops.add_n(tensors, name=scope) + with ops.name_scope(name, 'reduce_sum_n', tensors) as name_scope: + tensors = [ + math_ops.reduce_sum(t, name='%s/sum' % t.op.name) for t in tensors] + if len(tensors) == 1: + return tensors[0] + return math_ops.add_n(tensors, name=name_scope) -def remove_squeezable_dimensions(predictions, labels): +def remove_squeezable_dimensions(predictions, labels, name=None): """Squeeze last dim if ranks of `predictions` and `labels` differ by 1. This will use static shape if available. Otherwise, it will add graph @@ -157,41 +161,44 @@ def remove_squeezable_dimensions(predictions, labels): Args: predictions: Predicted values, a `Tensor` of arbitrary dimensions. labels: Label values, a `Tensor` whose dimensions match `predictions`. + name: Name of the op. Returns: Tuple of `predictions` and `labels`, possibly with last dim squeezed. """ - predictions = ops.convert_to_tensor(predictions) - labels = ops.convert_to_tensor(labels) - predictions_shape = predictions.get_shape() - predictions_rank = predictions_shape.ndims - labels_shape = labels.get_shape() - labels_rank = labels_shape.ndims - if (labels_rank is not None) and (predictions_rank is not None): - # Use static rank. - rank_diff = predictions_rank - labels_rank - if rank_diff == -1: - labels = array_ops.squeeze(labels, [-1]) - elif rank_diff == 1: - predictions = array_ops.squeeze(predictions, [-1]) + with ops.name_scope(name, 'remove_squeezable_dimensions', + [predictions, labels]): + predictions = ops.convert_to_tensor(predictions) + labels = ops.convert_to_tensor(labels) + predictions_shape = predictions.get_shape() + predictions_rank = predictions_shape.ndims + labels_shape = labels.get_shape() + labels_rank = labels_shape.ndims + if (labels_rank is not None) and (predictions_rank is not None): + # Use static rank. + rank_diff = predictions_rank - labels_rank + if rank_diff == -1: + labels = array_ops.squeeze(labels, [-1]) + elif rank_diff == 1: + predictions = array_ops.squeeze(predictions, [-1]) + return predictions, labels + + # Use dynamic rank. + rank_diff = array_ops.rank(predictions) - array_ops.rank(labels) + if (predictions_rank is None) or ( + predictions_shape.dims[-1].is_compatible_with(1)): + predictions = control_flow_ops.cond( + math_ops.equal(1, rank_diff), + lambda: array_ops.squeeze(predictions, [-1]), + lambda: predictions) + if (labels_rank is None) or ( + labels_shape.dims[-1].is_compatible_with(1)): + labels = control_flow_ops.cond( + math_ops.equal(-1, rank_diff), + lambda: array_ops.squeeze(labels, [-1]), + lambda: labels) return predictions, labels - # Use dynamic rank. - rank_diff = array_ops.rank(predictions) - array_ops.rank(labels) - if (predictions_rank is None) or ( - predictions_shape.dims[-1].is_compatible_with(1)): - predictions = control_flow_ops.cond( - math_ops.equal(1, rank_diff), - lambda: array_ops.squeeze(predictions, [-1]), - lambda: predictions) - if (labels_rank is None) or ( - labels_shape.dims[-1].is_compatible_with(1)): - labels = control_flow_ops.cond( - math_ops.equal(-1, rank_diff), - lambda: array_ops.squeeze(labels, [-1]), - lambda: labels) - return predictions, labels - def _all_equal(tensor0, tensor1): with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope: diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 3ec9de0af2..2db91cd889 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -171,7 +171,8 @@ def local_variable(initial_value, validate_shape=True, name=None): @contrib_add_arg_scope def variable(name, shape=None, dtype=None, initializer=None, regularizer=None, trainable=True, collections=None, - caching_device=None, device=None): + caching_device=None, device=None, + partitioner=None, custom_getter=None): """Gets an existing variable with these parameters or creates a new one. Args: @@ -191,6 +192,11 @@ def variable(name, shape=None, dtype=None, initializer=None, device. device: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. + partitioner: Optional callable that accepts a fully defined `TensorShape` + and dtype of the `Variable` to be created, and returns a list of + partitions for each axis (currently only one axis can be partitioned). + custom_getter: Callable that allows overwriting the internal + get_variable method and has to have the same signature. Returns: The created or existing variable. @@ -199,19 +205,24 @@ def variable(name, shape=None, dtype=None, initializer=None, # Remove duplicates collections = set(collections) + getter = variable_scope.get_variable + if custom_getter is not None: + getter = custom_getter with ops.device(device or ''): - return variable_scope.get_variable(name, shape=shape, dtype=dtype, - initializer=initializer, - regularizer=regularizer, - trainable=trainable, - collections=collections, - caching_device=caching_device) + return getter(name, shape=shape, dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner) @contrib_add_arg_scope def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, regularizer=None, trainable=True, collections=None, - caching_device=None, device=None): + caching_device=None, device=None, partitioner=None, + custom_getter=None): """Gets an existing model variable with these parameters or creates a new one. Args: @@ -232,16 +243,23 @@ def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, device. device: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. + partitioner: Optional callable that accepts a fully defined `TensorShape` + and dtype of the `Variable` to be created, and returns a list of + partitions for each axis (currently only one axis can be partitioned). + custom_getter: Callable that allows overwriting the internal + get_variable method and has to have the same signature. Returns: The created or existing variable. """ collections = list(collections or []) collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES] - return variable(name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, - trainable=trainable, collections=collections, - caching_device=caching_device, device=device) + var = variable(name, shape=shape, dtype=dtype, + initializer=initializer, regularizer=regularizer, + trainable=trainable, collections=collections, + caching_device=caching_device, device=device, + partitioner=partitioner, custom_getter=custom_getter) + return var def add_model_variable(var): diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 590590bf7b..d846b013fe 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -727,10 +727,10 @@ def matmul(a, b, name=None): b = core.convert_to_labeled_tensor(b) if len(a.axes) > 2 or len(b.axes) > 2: - # We could use tf.batch_matmul to make this work, but we would also need - # to use tf.tile and/or tf.transpose. These are more expensive than doing - # reshapes, so it's not clear if it's a good idea to do this - # automatically. + # We could pass batched inputs to tf.matmul to make this work, but we + # would also need to use tf.tile and/or tf.transpose. These are more + # expensive than doing reshapes, so it's not clear if it's a good idea to + # do this automatically. raise NotImplementedError( 'matmul currently requires inputs with rank 2 or less, but ' 'inputs have ranks %r and %r' % (len(a.axes), len(b.axes))) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 57debbc148..b7832be73f 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -33,12 +33,14 @@ common machine learning algorithms. @@repeat @@safe_embedding_lookup_sparse @@separable_convolution2d -@@stack @@unit_norm Aliases for fully_connected which set a default activation function are available: `relu`, `relu6` and `linear`. +`stack` operation is also available. It builds a stack of layers by applying +a layer repeatedly. + ## Regularizers Regularization can help prevent overfitting. These have the signature @@ -118,4 +120,8 @@ from tensorflow.contrib.layers.python.ops import sparse_ops from tensorflow.python.util.all_util import make_all # pylint: enable=unused-import,wildcard-import + +# Note: `stack` operation is available, just excluded from the document above +# due to collision with tf.stack. + __all__ = make_all(__name__) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 4713d0b5c7..25a871cd15 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -150,7 +150,7 @@ def safe_embedding_lookup_sparse(embedding_weights, array_ops.reshape(is_row_empty, [-1, 1]), array_ops.pack([1, array_ops.shape(result)[1]])) - result = math_ops.select(is_row_empty, + result = array_ops.where(is_row_empty, array_ops.zeros_like(result), result, name=scope) diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index d8e5485373..e3ef7328a4 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -850,6 +850,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( shared_embedding_name: (Optional). The common name for shared embedding. shared_vocab_size: (Optional). The common vocab_size used for shared embedding space. + max_norm: (Optional). If not None, embedding values are l2-normalized to + the value of max_norm. Raises: ValueError: if `initializer` is specified and is not callable. Also, @@ -959,7 +961,8 @@ def embedding_column(sparse_id_column, combiner=None, initializer=None, ckpt_to_load_from=None, - tensor_name_in_ckpt=None): + tensor_name_in_ckpt=None, + max_norm=None): """Creates an `_EmbeddingColumn` for feeding sparse data into a DNN. Args: @@ -984,6 +987,8 @@ def embedding_column(sparse_id_column, tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. + max_norm: (Optional). If not None, embedding values are l2-normalized to + the value of max_norm. Returns: An `_EmbeddingColumn`. @@ -993,7 +998,8 @@ def embedding_column(sparse_id_column, "to \"sqrtn\" after 2016/11/01.") combiner = "mean" return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer, - ckpt_to_load_from, tensor_name_in_ckpt) + ckpt_to_load_from, tensor_name_in_ckpt, + max_norm=max_norm) def shared_embedding_columns(sparse_id_columns, @@ -1002,7 +1008,8 @@ def shared_embedding_columns(sparse_id_columns, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, - tensor_name_in_ckpt=None): + tensor_name_in_ckpt=None, + max_norm=None): """Creates a list of `_EmbeddingColumn` sharing the same embedding. Args: @@ -1030,6 +1037,8 @@ def shared_embedding_columns(sparse_id_columns, tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. + max_norm: (Optional). If not None, embedding values are l2-normalized to + the value of max_norm. Returns: A tuple of `_EmbeddingColumn` with shared embedding space. @@ -1061,7 +1070,7 @@ def shared_embedding_columns(sparse_id_columns, return [ _EmbeddingColumn(sparse_id_columns[0], dimension, combiner, initializer, ckpt_to_load_from, tensor_name_in_ckpt, - shared_embedding_name)] + shared_embedding_name, max_norm=max_norm)] else: # check compatibility of sparse_id_columns compatible = True @@ -1090,7 +1099,8 @@ def shared_embedding_columns(sparse_id_columns, embedded_columns.append( _EmbeddingColumn(column, dimension, combiner, initializer, ckpt_to_load_from, tensor_name_in_ckpt, - shared_embedding_name, shared_vocab_size)) + shared_embedding_name, shared_vocab_size, + max_norm=max_norm)) return tuple(embedded_columns) diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index af6cfa9418..8a49e14c08 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -778,6 +778,25 @@ class CreateInputLayersForDNNsTest(tf.test.TestCase): # score: (number of values) self.assertAllEqual(output.eval(), [[1.], [2.], [0.]]) + def testEmbeddingColumnWithMaxNormForDNN(self): + hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10) + wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"], + indices=[[0, 0], [1, 0], [1, 1]], + shape=[3, 2]) + features = {"wire": wire_tensor} + embedded_sparse = tf.contrib.layers.embedding_column( + hashed_sparse, + 1, + combiner="sum", + initializer=init_ops.ones_initializer(), + max_norm=0.5) + output = tf.contrib.layers.input_from_feature_columns(features, + [embedded_sparse]) + with self.test_session(): + tf.global_variables_initializer().run() + # score: (number of values * 0.5) + self.assertAllClose(output.eval(), [[0.5], [1.], [0.]]) + def testEmbeddingColumnWithWeightedSparseColumnForDNN(self): ids = tf.contrib.layers.sparse_column_with_keys( "ids", ["marlo", "omar", "stringer"]) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 30f6690c68..5c6559b826 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import moving_averages # TODO(b/28426988): Replace legacy_* fns migrated from slim. @@ -1153,12 +1154,16 @@ def dropout(inputs, Returns: a tensor representing the output of the operation. """ - with ops.name_scope(scope, 'Dropout', [inputs]) as sc: + with variable_scope.variable_scope( + scope, 'Dropout', [inputs], custom_getter=_model_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) - dropout_fn = lambda: nn.dropout(inputs, keep_prob, noise_shape) - id_fn = lambda: array_ops.identity(inputs) - outputs = utils.smart_cond(is_training, dropout_fn, id_fn) - return utils.collect_named_outputs(outputs_collections, sc, outputs) + layer = core_layers.Dropout(rate=1 - keep_prob, + noise_shape=noise_shape, + name=sc.name, + _scope=sc) + outputs = layer.apply(inputs, training=is_training) + return utils.collect_named_outputs( + outputs_collections, sc.original_name_scope, outputs) @add_arg_scope @@ -1264,6 +1269,31 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): return utils.collect_named_outputs(output_collections, sc, flattened) +def _model_variable_getter(getter, name, shape=None, dtype=None, + initializer=None, regularizer=None, trainable=True, + collections=None, caching_device=None, + partitioner=None, **_): + """Getter that uses model_variable for compatibility with core layers.""" + return variables.model_variable( + name, shape=shape, dtype=dtype, initializer=initializer, + regularizer=regularizer, collections=collections, trainable=trainable, + caching_device=caching_device, partitioner=partitioner, + custom_getter=getter) + + +def _add_variable_to_collections(variable, collections_set, collections_name): + """Adds variable (or all its parts) to all collections with that name.""" + collections = utils.get_variable_collections( + collections_set, collections_name) or [] + variables_list = [variable] + if isinstance(variable, tf_variables.PartitionedVariable): + variables_list = [v for v in variable] + for collection in collections: + for var in variables_list: + if var not in ops.get_collection(collection): + ops.add_to_collection(collection, var) + + @add_arg_scope def fully_connected(inputs, num_outputs, @@ -1325,41 +1355,17 @@ def fully_connected(inputs, if not (isinstance(num_outputs, six.integer_types)): raise ValueError('num_outputs should be int or long, got %s.', num_outputs) - # Currently, the layers in this module do not create variables via - # `tf.get_variable`, rather they use their own variable management system - # which wraps `tf.get_variable` (the `model_variable()` interface from Slim). - # This interface is globally-configured via an argscope. This global - # configuration mechanism is used for instance by Slim-deploy to globally - # configure the target device of the variables of a model. - # - # We have the following the constraints: - # - Argscopes are not currently moving into core, thus core layers cannot - # rely on the Slim variable wrapper, and should instead - # use `tf.get_variable`. - # - Contrib layers require to use the argscope-enabled Slim variable wrapper - # rather than raw TF variables. - # - We want to be able to reuse at least the logic across core layers - # and contrib layers. - # - # We use the following strategy: - # - We instantiate variables in the contrib layer via the Slim interface. - # - We instantiate a core layer and set its variables to be the Slim ones. - # - We call the core layer. - # - # This enables us to reuse the `call` method across both implementations. - - with variable_scope.variable_scope(scope, 'fully_connected', [inputs], - reuse=reuse) as sc: + with variable_scope.variable_scope( + scope, 'fully_connected', [inputs], + reuse=reuse, custom_getter=_model_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) - - # Instantiate the FullyConnected layer. layer = core_layers.FullyConnected( - num_outputs, + units=num_outputs, activation=None, use_bias=not normalizer_fn and biases_initializer, - w_initializer=weights_initializer, + weights_initializer=weights_initializer, bias_initializer=biases_initializer, - w_regularizer=weights_regularizer, + weights_regularizer=weights_regularizer, bias_regularizer=biases_regularizer, activity_regularizer=None, trainable=trainable, @@ -1367,39 +1373,12 @@ def fully_connected(inputs, dtype=inputs.dtype.base_dtype, _scope=sc, _reuse_weights=reuse) + outputs = layer.apply(inputs) - dtype = inputs.dtype.base_dtype - inputs_shape = inputs.get_shape() - num_input_units = utils.last_dimension(inputs_shape, min_rank=2) - - static_shape = inputs_shape.as_list() - static_shape[-1] = num_outputs - - weights_shape = [num_input_units, num_outputs] - weights_collections = utils.get_variable_collections( - variables_collections, 'weights') - weights = variables.model_variable('weights', - shape=weights_shape, - dtype=dtype, - initializer=weights_initializer, - regularizer=weights_regularizer, - collections=weights_collections, - trainable=trainable) - layer.w = weights - - if layer.use_bias: - biases_collections = utils.get_variable_collections( - variables_collections, 'biases') - biases = variables.model_variable('biases', - shape=[num_outputs,], - dtype=dtype, - initializer=biases_initializer, - regularizer=biases_regularizer, - collections=biases_collections, - trainable=trainable) - layer.bias = biases - - outputs = layer.call(inputs) + # Add variables to collections. + _add_variable_to_collections(layer.w, variables_collections, 'weights') + if layer.bias: + _add_variable_to_collections(layer.bias, variables_collections, 'biases') # Apply normalizer function / layer. if normalizer_fn is not None: @@ -2099,4 +2078,3 @@ conv2d = convolution2d conv2d_transpose = convolution2d_transpose conv2d_in_plane = convolution2d_in_plane separable_conv2d = separable_convolution2d - diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index a9c8997885..b28e3363dc 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -628,7 +628,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWithStrideOneValidPaddingNCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 32 input_size = [5, 3, 10, 12] @@ -644,7 +644,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWithStrideTwoValidPaddingNCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 32 input_size = [5, 3, 9, 11] @@ -661,7 +661,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWith1x1StrideTwoSamePaddingNCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 1, 1] @@ -678,7 +678,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWith1x1StrideTwoValidPaddingNCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 1, 1] @@ -693,7 +693,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWith2x2StrideTwoSamePaddingNCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 2, 2] @@ -708,7 +708,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWith2x2StrideTwoValidPaddingNCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 2, 2] @@ -723,7 +723,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWithStride2x1NCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 3, 2] @@ -738,7 +738,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWithStride2x4NCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 3, 2] @@ -753,7 +753,7 @@ class Convolution2dTransposeTests(tf.test.TestCase): self.assertListEqual(list(output.eval().shape), expected_size) def testOutputSizeWithStride2x5NCHW(self): - if tf.test.is_gpu_available(): + if tf.test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 1 input_size = [1, 1, 3, 2] @@ -1181,7 +1181,6 @@ class DropoutTest(tf.test.TestCase): is_training = tf.constant(True) images = tf.random_uniform((5, height, width, 3), seed=1) output = tf.contrib.layers.dropout(images, is_training=is_training) - self.assertEqual(output.op.name, 'Dropout/dropout/mul') output.get_shape().assert_is_compatible_with(images.get_shape()) def testCreateDropoutWithConstantFalse(self): @@ -1190,7 +1189,6 @@ class DropoutTest(tf.test.TestCase): is_training = tf.constant(False) images = tf.random_uniform((5, height, width, 3), seed=1) output = tf.contrib.layers.dropout(images, is_training=is_training) - self.assertEqual(output.op.name, 'Dropout/Identity') output.get_shape().assert_is_compatible_with(images.get_shape()) def testCreateDropoutWithPlaceholder(self): @@ -1220,8 +1218,8 @@ class DropoutTest(tf.test.TestCase): num_elem = tf.reduce_mean(tf.to_float(output > 0)) sess.run(tf.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) - self.assertLess(num_elem, num_elem_initial/2 + 0.1) - self.assertGreater(num_elem, num_elem_initial/2 - 0.1) + self.assertLess(num_elem, num_elem_initial / 2 + 0.1) + self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) def testCreateDropoutNoTraining(self): height, width = 3, 3 @@ -1246,8 +1244,8 @@ class DropoutTest(tf.test.TestCase): num_elem = tf.reduce_mean(tf.to_float(output > 0)) sess.run(tf.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) - self.assertLess(num_elem, num_elem_initial/2 + 0.1) - self.assertGreater(num_elem, num_elem_initial/2 - 0.1) + self.assertLess(num_elem, num_elem_initial / 2 + 0.1) + self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) def testCreateFCWithDropout(self): height, width = 3, 3 diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 2908096c6c..ada679d83b 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -360,7 +360,7 @@ def adaptive_clipping_fn(std_factor=2., summary.scalar("global_norm/adaptive_max_gradient_norm", max_norm) # factor will be 1. if norm is smaller than max_norm - factor = math_ops.select(norm < max_norm, + factor = array_ops.where(norm < max_norm, array_ops.ones_like(norm), math_ops.exp(log_mean) / norm) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 5b7a2d76d8..764971935f 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -21,6 +21,10 @@ py_library( "//tensorflow/contrib/session_bundle:exporter", "//tensorflow/contrib/tensor_forest:client_lib", "//tensorflow/python:framework", + "//tensorflow/python/saved_model:builder", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow/python/saved_model:tag_constants", ], ) @@ -663,6 +667,32 @@ py_test( ) py_test( + name = "gc_test", + size = "small", + srcs = ["python/learn/utils/gc_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( + name = "saved_model_export_utils_test", + size = "small", + srcs = ["python/learn/utils/saved_model_export_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( name = "stability_test", size = "small", srcs = ["python/learn/estimators/stability_test.py"], diff --git a/tensorflow/python/util/net_lib.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py index d8566eb7c7..aee4541627 100644 --- a/tensorflow/python/util/net_lib.py +++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A Python interface for creating TensorFlow tests.""" +"""Constants regarding Estimators.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import pywrap_tensorflow - -def pick_unused_port_or_die(): - """Find an unused port on localhost.""" - return pywrap_tensorflow.PickUnusedPortOrDie() +class ProblemType(object): + UNSPECIFIED = 0 + CLASSIFICATION = 1 + LINEAR_REGRESSION = 2 + LOGISTIC_REGRESSION = 3 diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index db0a24e508..98947cc6d4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -19,9 +19,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import evaluable @@ -89,6 +92,9 @@ def _dnn_model_fn(features, labels, mode, params): * gradient_clip_norm: A float > 0. If provided, gradients are clipped to their global norm with this clipping ratio. * num_ps_replicas: The number of parameter server replicas. + * embedding_lr_multipliers: Optional. A dictionary from + `EmbeddingColumn` to a `float` multiplier. Multiplier will be used to + multiply with learning rate for the embedding variables. Returns: predictions: A dict of `Tensor` objects. @@ -103,6 +109,7 @@ def _dnn_model_fn(features, labels, mode, params): dropout = params.get("dropout") gradient_clip_norm = params.get("gradient_clip_norm") num_ps_replicas = params.get("num_ps_replicas", 0) + embedding_lr_multipliers = params.get("embedding_lr_multipliers", {}) features = _get_feature_dict(features) parent_scope = "dnn" @@ -111,9 +118,10 @@ def _dnn_model_fn(features, labels, mode, params): partitioned_variables.min_max_variable_partitioner( max_partitions=num_ps_replicas, min_slice_size=64 << 20)) + input_layer_scope = parent_scope + "/input_from_feature_columns" with variable_scope.variable_scope( - parent_scope + "/input_from_feature_columns", - values=features.values(), + input_layer_scope, + values=list(six.itervalues(features)), partitioner=input_layer_partitioner) as scope: net = layers.input_from_feature_columns( columns_to_tensors=features, @@ -160,6 +168,9 @@ def _dnn_model_fn(features, labels, mode, params): global_step=contrib_variables.get_global_step(), learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer), + gradient_multipliers=( + dnn_linear_combined._extract_embedding_lr_multipliers( # pylint: disable=protected-access + embedding_lr_multipliers, parent_scope, input_layer_scope)), clip_gradients=gradient_clip_norm, name=parent_scope, # Empty summaries to prevent optimizers from logging the training_loss. @@ -234,7 +245,8 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable): gradient_clip_norm=None, enable_centered_bias=False, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + embedding_lr_multipliers=None): """Initializes a DNNClassifier instance. Args: @@ -271,6 +283,9 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable): labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to + a `float` multiplier. Multiplier will be used to multiply with + learning rate for the embedding variables. Returns: A `DNNClassifier` estimator. @@ -287,17 +302,27 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable): model_dir=model_dir, config=config, params={ - "head": head_lib._multi_class_head( # pylint: disable=protected-access - n_classes, - weight_column_name=weight_column_name, - enable_centered_bias=enable_centered_bias), - "hidden_units": hidden_units, - "feature_columns": feature_columns, - "optimizer": optimizer, - "activation_fn": activation_fn, - "dropout": dropout, - "gradient_clip_norm": gradient_clip_norm, - "num_ps_replicas": config.num_ps_replicas if config else 0, + "head": + head_lib._multi_class_head( # pylint: disable=protected-access + n_classes, + weight_column_name=weight_column_name, + enable_centered_bias=enable_centered_bias), + "hidden_units": + hidden_units, + "feature_columns": + feature_columns, + "optimizer": + optimizer, + "activation_fn": + activation_fn, + "dropout": + dropout, + "gradient_clip_norm": + gradient_clip_norm, + "num_ps_replicas": + config.num_ps_replicas if config else 0, + "embedding_lr_multipliers": + embedding_lr_multipliers, }, feature_engineering_fn=feature_engineering_fn) @@ -428,6 +453,22 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property def model_dir(self): return self._estimator.model_dir diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 10b288f1fb..256e074079 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -26,7 +26,9 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import evaluable @@ -365,6 +367,31 @@ def _add_hidden_layer_summary(value, tag): logging_ops.histogram_summary("%s:activation" % tag, value) +def _get_embedding_variable(column, collection_key, input_layer_scope): + return ops.get_collection(collection_key, + input_layer_scope + "/" + column.name) + + +def _extract_embedding_lr_multipliers(embedding_lr_multipliers, collection_key, + input_layer_scope): + """Converts embedding lr multipliers to variable based gradient multiplier.""" + if not embedding_lr_multipliers: + return None + gradient_multipliers = {} + for column, lr_mult in embedding_lr_multipliers.items(): + if not isinstance(column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access + raise ValueError( + "learning rate multipler can only be defined for embedding columns. " + "It is defined for {}".format(column)) + embedding = _get_embedding_variable( + column, collection_key, input_layer_scope) + if not embedding: + raise ValueError("Couldn't find a variable for column {}".format(column)) + for v in embedding: + gradient_multipliers[v] = lr_mult + return gradient_multipliers + + def _dnn_linear_combined_model_fn(features, labels, mode, params): """Deep Neural Net and Linear combined model_fn. @@ -396,6 +423,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params): * gradient_clip_norm: A float > 0. If provided, gradients are clipped to their global norm with this clipping ratio. * num_ps_replicas: The number of parameter server replicas. + * embedding_lr_multipliers: Optional. A dictionary from + `EmbeddingColumn` to a `float` multiplier. Multiplier will be used to + multiply with learning rate for the embedding variables. Returns: `ModelFnOps` @@ -414,7 +444,8 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params): dnn_activation_fn = params.get("dnn_activation_fn") dnn_dropout = params.get("dnn_dropout") gradient_clip_norm = params.get("gradient_clip_norm") - num_ps_replicas = params["num_ps_replicas"] + num_ps_replicas = params.get("num_ps_replicas", 0) + embedding_lr_multipliers = params.get("embedding_lr_multipliers", {}) if not linear_feature_columns and not dnn_feature_columns: raise ValueError( @@ -432,8 +463,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params): partitioned_variables.min_max_variable_partitioner( max_partitions=num_ps_replicas, min_slice_size=64 << 20)) + input_layer_scope = dnn_parent_scope + "/input_from_feature_columns" with variable_scope.variable_scope( - dnn_parent_scope + "/input_from_feature_columns", + input_layer_scope, values=features.values(), partitioner=input_layer_partitioner) as scope: net = layers.input_from_feature_columns( @@ -521,6 +553,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params): global_step=contrib_variables.get_global_step(), learning_rate=_DNN_LEARNING_RATE, optimizer=_get_optimizer(dnn_optimizer), + gradient_multipliers=_extract_embedding_lr_multipliers( # pylint: disable=protected-access + embedding_lr_multipliers, dnn_parent_scope, + input_layer_scope), clip_gradients=gradient_clip_norm, variables=ops.get_collection(dnn_parent_scope), name=dnn_parent_scope, @@ -612,7 +647,8 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable): gradient_clip_norm=None, enable_centered_bias=False, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + embedding_lr_multipliers=None): """Constructs a DNNLinearCombinedClassifier instance. Args: @@ -656,6 +692,9 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable): labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to + a `float` multiplier. Multiplier will be used to multiply with + learning rate for the embedding variables. Raises: ValueError: If `n_classes` < 2. @@ -695,6 +734,7 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable): "dnn_dropout": dnn_dropout, "gradient_clip_norm": gradient_clip_norm, "num_ps_replicas": config.num_ps_replicas if config else 0, + "embedding_lr_multipliers": embedding_lr_multipliers, }, feature_engineering_fn=feature_engineering_fn) @@ -829,6 +869,22 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property def model_dir(self): return self._estimator.model_dir diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 26c38d0789..33d0d2eb4f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -27,7 +27,9 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import _sklearn +from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec @@ -39,6 +41,82 @@ def _assert_metrics_in_range(keys, metrics): 0.0 - epsilon, 1.0 + epsilon, key, metrics) +class EmbeddingMultiplierTest(tf.test.TestCase): + """dnn_model_fn tests.""" + + def testRaisesNonEmbeddingColumn(self): + one_hot_language = tf.contrib.layers.one_hot_column( + tf.contrib.layers.sparse_column_with_hash_bucket('language', 10)) + + params = { + 'dnn_feature_columns': [one_hot_language], + 'head': head_lib._multi_class_head(2), + 'dnn_hidden_units': [1], + # Set lr mult to 0. to keep embeddings constant. + 'embedding_lr_multipliers': { + one_hot_language: 0.0 + }, + 'dnn_optimizer': 'Adagrad', + } + features = { + 'language': + tf.SparseTensor( + values=['en', 'fr', 'zh'], + indices=[[0, 0], [1, 0], [2, 0]], + shape=[3, 1]), + } + labels = tf.constant([[0], [0], [0]], dtype=tf.int32) + with self.assertRaisesRegexp( + ValueError, 'can only be defined for embedding columns'): + dnn_linear_combined._dnn_linear_combined_model_fn( + features, labels, tf.contrib.learn.ModeKeys.TRAIN, params) + + def testMultipliesGradient(self): + embedding_language = tf.contrib.layers.embedding_column( + tf.contrib.layers.sparse_column_with_hash_bucket('language', 10), + dimension=1, initializer=tf.constant_initializer(0.1)) + embedding_wire = tf.contrib.layers.embedding_column( + tf.contrib.layers.sparse_column_with_hash_bucket('wire', 10), + dimension=1, initializer=tf.constant_initializer(0.1)) + + params = { + 'dnn_feature_columns': [embedding_language, embedding_wire], + 'head': head_lib._multi_class_head(2), + 'dnn_hidden_units': [1], + # Set lr mult to 0. to keep embeddings constant. + 'embedding_lr_multipliers': { + embedding_language: 0.0 + }, + 'dnn_optimizer': 'Adagrad', + } + features = { + 'language': + tf.SparseTensor( + values=['en', 'fr', 'zh'], + indices=[[0, 0], [1, 0], [2, 0]], + shape=[3, 1]), + 'wire': + tf.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [2, 0]], + shape=[3, 1]), + } + labels = tf.constant([[0], [0], [0]], dtype=tf.int32) + model_ops = dnn_linear_combined._dnn_linear_combined_model_fn( + features, labels, tf.contrib.learn.ModeKeys.TRAIN, params) + with tf.train.MonitoredSession() as sess: + language_var = dnn_linear_combined._get_embedding_variable( + embedding_language, 'dnn', 'dnn/input_from_feature_columns') + wire_var = dnn_linear_combined._get_embedding_variable( + embedding_wire, 'dnn', 'dnn/input_from_feature_columns') + for _ in range(2): + _, language_value, wire_value = sess.run( + [model_ops.train_op, language_var, wire_var]) + initial_value = np.full_like(language_value, 0.1) + self.assertTrue(np.all(np.isclose(language_value, initial_value))) + self.assertFalse(np.all(np.isclose(wire_value, initial_value))) + + class DNNLinearCombinedClassifierTest(tf.test.TestCase): def testEstimatorContract(self): @@ -54,6 +132,18 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase): dnn_feature_columns=None, dnn_hidden_units=[3, 3]) + def testEmbeddingMultiplier(self): + embedding_language = tf.contrib.layers.embedding_column( + tf.contrib.layers.sparse_column_with_hash_bucket('language', 10), + dimension=1, initializer=tf.constant_initializer(0.1)) + classifier = tf.contrib.learn.DNNLinearCombinedClassifier( + dnn_feature_columns=[embedding_language], + dnn_hidden_units=[3, 3], + embedding_lr_multipliers={embedding_language: 0.8}) + self.assertEqual( + {embedding_language: 0.8}, + classifier._estimator.params['embedding_lr_multipliers']) + def testLogisticRegression_MatrixData(self): """Tests binary classification using matrix data as input.""" iris = test_data.prepare_iris_data_for_logistic_regression() diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index a55ca08b1b..9196d78d22 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -27,12 +27,89 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import _sklearn +from tensorflow.contrib.learn.python.learn.estimators import dnn +from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.python.ops import math_ops +class EmbeddingMultiplierTest(tf.test.TestCase): + """dnn_model_fn tests.""" + + def testRaisesNonEmbeddingColumn(self): + one_hot_language = tf.contrib.layers.one_hot_column( + tf.contrib.layers.sparse_column_with_hash_bucket('language', 10)) + + params = { + 'feature_columns': [one_hot_language], + 'head': head_lib._multi_class_head(2), + 'hidden_units': [1], + # Set lr mult to 0. to keep embeddings constant. + 'embedding_lr_multipliers': { + one_hot_language: 0.0 + }, + } + features = { + 'language': + tf.SparseTensor( + values=['en', 'fr', 'zh'], + indices=[[0, 0], [1, 0], [2, 0]], + shape=[3, 1]), + } + labels = tf.constant([[0], [0], [0]], dtype=tf.int32) + with self.assertRaisesRegexp( + ValueError, 'can only be defined for embedding columns'): + dnn._dnn_model_fn(features, labels, + tf.contrib.learn.ModeKeys.TRAIN, params) + + def testMultipliesGradient(self): + embedding_language = tf.contrib.layers.embedding_column( + tf.contrib.layers.sparse_column_with_hash_bucket('language', 10), + dimension=1, initializer=tf.constant_initializer(0.1)) + embedding_wire = tf.contrib.layers.embedding_column( + tf.contrib.layers.sparse_column_with_hash_bucket('wire', 10), + dimension=1, initializer=tf.constant_initializer(0.1)) + + params = { + 'feature_columns': [embedding_language, embedding_wire], + 'head': head_lib._multi_class_head(2), + 'hidden_units': [1], + # Set lr mult to 0. to keep embeddings constant. + 'embedding_lr_multipliers': { + embedding_language: 0.0 + }, + } + features = { + 'language': + tf.SparseTensor( + values=['en', 'fr', 'zh'], + indices=[[0, 0], [1, 0], [2, 0]], + shape=[3, 1]), + 'wire': + tf.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [2, 0]], + shape=[3, 1]), + } + labels = tf.constant([[0], [0], [0]], dtype=tf.int32) + model_ops = dnn._dnn_model_fn(features, labels, + tf.contrib.learn.ModeKeys.TRAIN, params) + with tf.train.MonitoredSession() as sess: + language_var = dnn_linear_combined._get_embedding_variable( + embedding_language, 'dnn', 'dnn/input_from_feature_columns') + wire_var = dnn_linear_combined._get_embedding_variable( + embedding_wire, 'dnn', 'dnn/input_from_feature_columns') + for _ in range(2): + _, language_value, wire_value = sess.run( + [model_ops.train_op, language_var, wire_var]) + initial_value = np.full_like(language_value, 0.1) + self.assertTrue(np.all(np.isclose(language_value, initial_value))) + self.assertFalse(np.all(np.isclose(wire_value, initial_value))) + + class DNNClassifierTest(tf.test.TestCase): def _assertInRange(self, expected_min, expected_max, actual): @@ -43,6 +120,18 @@ class DNNClassifierTest(tf.test.TestCase): estimator_test_utils.assert_estimator_contract( self, tf.contrib.learn.DNNClassifier) + def testEmbeddingMultiplier(self): + embedding_language = tf.contrib.layers.embedding_column( + tf.contrib.layers.sparse_column_with_hash_bucket('language', 10), + dimension=1, initializer=tf.constant_initializer(0.1)) + classifier = tf.contrib.learn.DNNClassifier( + feature_columns=[embedding_language], + hidden_units=[3, 3], + embedding_lr_multipliers={embedding_language: 0.8}) + self.assertEqual( + {embedding_language: 0.8}, + classifier._estimator.params['embedding_lr_multipliers']) + def testLogisticRegression_MatrixData(self): """Tests binary classification using matrix data as input.""" cont_features = [ @@ -118,10 +207,10 @@ class DNNClassifierTest(tf.test.TestCase): classifier = tf.contrib.learn.DNNClassifier( n_classes=2, feature_columns=feature_columns, - hidden_units=[3, 3], + hidden_units=[10, 10], config=tf.contrib.learn.RunConfig(tf_random_seed=1)) - classifier.fit(input_fn=_input_fn, steps=5) + classifier.fit(input_fn=_input_fn, steps=50) scores = classifier.evaluate(input_fn=_input_fn, steps=1) self._assertInRange(0.0, 1.0, scores['accuracy']) @@ -222,7 +311,7 @@ class DNNClassifierTest(tf.test.TestCase): n_classes=3, feature_columns=feature_columns, hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) classifier.fit(x=train_x, y=train_y, steps=200) scores = classifier.evaluate(x=train_x, y=train_y, steps=1) @@ -310,7 +399,7 @@ class DNNClassifierTest(tf.test.TestCase): weight_column_name='w', feature_columns=[tf.contrib.layers.real_valued_column('x')], hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) classifier.fit(input_fn=_input_fn_train, steps=5) scores = classifier.evaluate(input_fn=_input_fn_eval, steps=1) @@ -339,8 +428,8 @@ class DNNClassifierTest(tf.test.TestCase): classifier = tf.contrib.learn.DNNClassifier( n_classes=3, feature_columns=feature_columns, - hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + hidden_units=[10, 10], + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) classifier.fit(input_fn=_input_fn, steps=100) @@ -524,7 +613,7 @@ class DNNClassifierTest(tf.test.TestCase): } with tf.test.mock.patch.dict('os.environ', {'TF_CONFIG': json.dumps(tf_config)}): - config = tf.contrib.learn.RunConfig(tf_random_seed=5) + config = tf.contrib.learn.RunConfig(tf_random_seed=1) # Because we did not start a distributed cluster, we need to pass an # empty ClusterSpec, otherwise the device_setter will look for # distributed jobs, such as "/job:ps" which are not present. @@ -707,7 +796,7 @@ class DNNRegressorTest(tf.test.TestCase): regressor = tf.contrib.learn.DNNRegressor( feature_columns=[tf.contrib.layers.real_valued_column('x')], hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) regressor.fit(input_fn=_input_fn_train, steps=5) scores = regressor.evaluate(input_fn=_input_fn_train, steps=1) @@ -772,7 +861,7 @@ class DNNRegressorTest(tf.test.TestCase): weight_column_name='w', feature_columns=[tf.contrib.layers.real_valued_column('x')], hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) regressor.fit(input_fn=_input_fn_train, steps=5) scores = regressor.evaluate(input_fn=_input_fn_eval, steps=1) @@ -803,7 +892,7 @@ class DNNRegressorTest(tf.test.TestCase): regressor = tf.contrib.learn.DNNRegressor( feature_columns=feature_columns, hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) regressor.fit(input_fn=_input_fn, steps=200) @@ -837,7 +926,7 @@ class DNNRegressorTest(tf.test.TestCase): regressor = tf.contrib.learn.DNNRegressor( feature_columns=feature_columns, hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) regressor.fit(input_fn=_input_fn, steps=200) @@ -918,7 +1007,7 @@ class DNNRegressorTest(tf.test.TestCase): model_dir=model_dir, feature_columns=feature_columns, hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) regressor.fit(input_fn=_input_fn, steps=5) predict_input_fn = functools.partial(_input_fn, num_epochs=1) @@ -929,7 +1018,7 @@ class DNNRegressorTest(tf.test.TestCase): model_dir=model_dir, feature_columns=feature_columns, hidden_units=[3, 3], - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) predictions2 = list(regressor2.predict(input_fn=predict_input_fn)) self.assertAllClose(predictions, predictions2) @@ -1004,7 +1093,7 @@ class DNNRegressorTest(tf.test.TestCase): feature_columns=feature_columns, hidden_units=[3, 3], enable_centered_bias=True, - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) regressor.fit(input_fn=_input_fn, steps=5) self.assertIn('centered_bias_weight', regressor.get_variable_names()) @@ -1037,7 +1126,7 @@ class DNNRegressorTest(tf.test.TestCase): feature_columns=feature_columns, hidden_units=[3, 3], enable_centered_bias=False, - config=tf.contrib.learn.RunConfig(tf_random_seed=3)) + config=tf.contrib.learn.RunConfig(tf_random_seed=1)) regressor.fit(input_fn=_input_fn, steps=5) self.assertNotIn('centered_bias_weight', regressor.get_variable_names()) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index cdd8300e2e..be28f07a7f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -59,25 +59,6 @@ _CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell, 'gru': rnn_cell.GRUCell,} -# TODO(jamieas): move `padding_mask` to array_ops. -def padding_mask(sequence_lengths, padded_length): - """Creates a mask used for calculating losses with padded input. - - Args: - sequence_lengths: A `Tensor` of shape `[batch_size]` containing the unpadded - length of each sequence. - padded_length: A scalar `Tensor` indicating the length of the sequences - after padding - Returns: - A boolean `Tensor` M of shape `[batch_size, padded_length]` where - `M[i, j] == True` when `lengths[i] > j`. - - """ - range_tensor = math_ops.range(padded_length) - return math_ops.less(array_ops.expand_dims(range_tensor, 0), - array_ops.expand_dims(sequence_lengths, 1)) - - def mask_activations_and_labels(activations, labels, sequence_lengths): """Remove entries outside `sequence_lengths` and returned flattened results. @@ -89,7 +70,7 @@ def mask_activations_and_labels(activations, labels, sequence_lengths): Returns: activations_masked: `logit` values with those beyond `sequence_lengths` - removed for each batch. Batches are then concatenated. Shape + removed for each batch. Batches are then concatenated. Shape `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and shape `[batch_size * padded_length, k]` otherwise. labels_masked: Label values after removing unneeded entries. Shape @@ -107,7 +88,7 @@ def mask_activations_and_labels(activations, labels, sequence_lengths): [flattened_dimension, -1]) labels_masked = array_ops.reshape(labels, [flattened_dimension]) else: - mask = padding_mask(sequence_lengths, padded_length) + mask = array_ops.sequence_mask(sequence_lengths, padded_length) activations_masked = array_ops.boolean_mask(activations, mask) labels_masked = array_ops.boolean_mask(labels, mask) return activations_masked, labels_masked @@ -236,7 +217,7 @@ def construct_rnn(initial_state, num_label_columns, dtype=dtypes.float32, parallel_iterations=32, - swap_memory=False): + swap_memory=True): """Build an RNN and apply a fully connected layer to get the desired output. Args: @@ -273,6 +254,9 @@ def construct_rnn(initial_state, num_outputs=num_label_columns, activation_fn=None, trainable=True) + # Use `identitiy` to rename `final_state`. + final_state = array_ops.identity( + final_state, name=RNNKeys.FINAL_STATE_KEY) return activations, final_state @@ -371,13 +355,15 @@ def _multi_value_predictions( probability_shape = array_ops.concat(0, [activations_shape[:2], [2]]) else: probability_shape = activations_shape - probabilities = array_ops.reshape(flat_probabilities, probability_shape) + probabilities = array_ops.reshape( + flat_probabilities, probability_shape, name=RNNKeys.PROBABILITIES_KEY) prediction_dict[RNNKeys.PROBABILITIES_KEY] = probabilities else: flat_predictions = target_column.logits_to_predictions( flattened_activations, proba=False) predictions = array_ops.reshape( - flat_predictions, [activations_shape[0], activations_shape[1]]) + flat_predictions, [activations_shape[0], activations_shape[1]], + name=RNNKeys.PREDICTIONS_KEY) prediction_dict[RNNKeys.PREDICTIONS_KEY] = predictions return prediction_dict @@ -474,7 +460,7 @@ def apply_dropout( cell: An `RNNCell`. input_keep_probability: Probability to keep inputs to `cell`. If `None`, no dropout is applied. - output_keep_probability: Probability to keep outputs to `cell`. If `None`, + output_keep_probability: Probability to keep outputs of `cell`. If `None`, no dropout is applied. random_seed: Seed for random dropout. @@ -509,13 +495,12 @@ def _get_dynamic_rnn_model_fn(cell, initial_state_key=RNNKeys.INITIAL_STATE_KEY, dtype=dtypes.float32, parallel_iterations=None, - swap_memory=False, + swap_memory=True, name='DynamicRNNModel'): """Creates an RNN model function for an `Estimator`. Args: cell: An initialized `RNNCell` to be used in the RNN. - 'basic_rnn,' 'lstm' or 'gru'. target_column: An initialized `TargetColumn`, used to calculate prediction and loss. problem_type: `ProblemType.CLASSIFICATION` or`ProblemType.REGRESSION`. @@ -527,23 +512,23 @@ def _get_dynamic_rnn_model_fn(cell, describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. context_feature_columns: An iterable containing all the feature columns - describing context features i.e. features that apply accross all time + describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. predict_probabilities: A boolean indicating whether to predict probabilities - for all classes. Should only be used with `ProblemType.CLASSIFICATION`. + for all classes. Must only be used with `ProblemType.CLASSIFICATION`. learning_rate: Learning rate used for optimization. This argument has no effect if `optimizer` is an instance of an `Optimizer`. gradient_clipping_norm: A float. Gradients will be clipped to this value. input_keep_probability: Probability to keep inputs to `cell`. If `None`, no dropout is applied. - output_keep_probability: Probability to keep outputs to `cell`. If `None`, + output_keep_probability: Probability to keep outputs of `cell`. If `None`, no dropout is applied. sequence_length_key: The key that will be used to look up sequence length in the `features` dict. initial_state_key: The key that will be used to look up initial_state in the `features` dict. - dtype: The dtype of the state and output for the given `cell_num` + dtype: The dtype of the state and output of the given `cell`. parallel_iterations: Number of iterations to run in parallel. Values >> 1 use more memory but take less time, while smaller values use less memory but computations take longer. @@ -601,30 +586,41 @@ def _get_dynamic_rnn_model_fn(cell, dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory) + + loss = None # Created below for modes TRAIN and EVAL. if prediction_type == PredictionType.MULTIPLE_VALUE: prediction_dict = _multi_value_predictions( rnn_activations, target_column, predict_probabilities) - loss = _multi_value_loss( - rnn_activations, labels, sequence_length, target_column, features) + if mode != model_fn.ModeKeys.INFER: + loss = _multi_value_loss( + rnn_activations, labels, sequence_length, target_column, features) elif prediction_type == PredictionType.SINGLE_VALUE: prediction_dict = _single_value_predictions( rnn_activations, sequence_length, target_column, predict_probabilities) - loss = _single_value_loss( - rnn_activations, labels, sequence_length, target_column, features) - # TODO(roumposg): Return eval_metric_ops here, instead of default_metrics. - default_metrics = _get_default_metrics( - problem_type, prediction_type, sequence_length) + if mode != model_fn.ModeKeys.INFER: + loss = _single_value_loss( + rnn_activations, labels, sequence_length, target_column, features) prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state - eval_metric_ops = estimator._make_metrics_ops( # pylint: disable=protected-access - default_metrics, features, labels, prediction_dict) - train_op = optimizers.optimize_loss( - loss=loss, - global_step=None, - learning_rate=learning_rate, - optimizer=optimizer, - clip_gradients=gradient_clipping_norm, - summaries=optimizers.OPTIMIZER_SUMMARIES) + + eval_metric_ops = None + if mode != model_fn.ModeKeys.INFER: + # TODO(roumposg): Return eval_metric_ops instead of default_metrics. + default_metrics = _get_default_metrics( + problem_type, prediction_type, sequence_length) + eval_metric_ops = estimator._make_metrics_ops( # pylint: disable=protected-access + default_metrics, features, labels, prediction_dict) + + train_op = None + if mode == model_fn.ModeKeys.TRAIN: + train_op = optimizers.optimize_loss( + loss=loss, + global_step=None, # Get it internally. + learning_rate=learning_rate, + optimizer=optimizer, + clip_gradients=gradient_clipping_norm, + summaries=optimizers.OPTIMIZER_SUMMARIES) + return model_fn.ModelFnOps(mode=mode, predictions=prediction_dict, loss=loss, @@ -674,43 +670,43 @@ def multi_value_rnn_regressor(num_units, optimizer_type='SGD', learning_rate=0.1, momentum=None, - gradient_clipping_norm=10.0, + gradient_clipping_norm=5.0, input_keep_probability=None, output_keep_probability=None, model_dir=None, config=None, - params=None, feature_engineering_fn=None): - """Creates a RNN `Estimator` that predicts sequences of values. Args: - num_units: The size of the RNN cells. + num_units: The size of the RNN cells. This argument has no effect + if `cell_type` is an instance of `RNNCell`. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. context_feature_columns: An iterable containing all the feature columns - describing context features i.e. features that apply accross all time + describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. - cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of + cell_type: A subclass of `RNNCell`, an instance of an `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'. - num_rnn_layers: Number of RNN layers. + num_rnn_layers: Number of RNN layers. Leave this at its default value 1 + if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer` or a string. Strings must be one of 'Adagrad', 'Momentum' or 'SGD'. - learning_rate: Learning rate. + learning_rate: Learning rate. This argument has no effect if `optimizer` + is an instance of an `Optimizer`. momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. input_keep_probability: Probability to keep inputs to `cell`. If `None`, no dropout is applied. - output_keep_probability: Probability to keep outputs to `cell`. If `None`, + output_keep_probability: Probability to keep outputs of `cell`. If `None`, no dropout is applied. - model_dir: Directory to use for The directory in which to save and restore - the model graph, parameters, etc. + model_dir: The directory in which to save and restore the model graph, + parameters, etc. config: A `RunConfig` instance. - params: `dict` of hyperparameters. Passed through to `Estimator`. feature_engineering_fn: Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into `model_fn`. Please check `model_fn` for a definition of features and @@ -739,7 +735,6 @@ def multi_value_rnn_regressor(num_units, return estimator.Estimator(model_fn=dynamic_rnn_model_fn, model_dir=model_dir, config=config, - params=params, feature_engineering_fn=feature_engineering_fn) @@ -754,32 +749,34 @@ def multi_value_rnn_classifier(num_classes, learning_rate=0.1, predict_probabilities=False, momentum=None, - gradient_clipping_norm=10.0, + gradient_clipping_norm=5.0, input_keep_probability=None, output_keep_probability=None, model_dir=None, config=None, - params=None, feature_engineering_fn=None): """Creates a RNN `Estimator` that predicts sequences of labels. Args: num_classes: The number of classes for categorization. - num_units: The size of the RNN cells. + num_units: The size of the RNN cells. This argument has no effect + if `cell_type` is an instance of `RNNCell`. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. context_feature_columns: An iterable containing all the feature columns - describing context features i.e. features that apply accross all time + describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of 'basic_rnn,' 'lstm' or 'gru'. - num_rnn_layers: Number of RNN layers. + num_rnn_layers: Number of RNN layers. Leave this at its default value 1 + if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer` or a string. Strings must be one of 'Adagrad', 'Momentum' or 'SGD'. - learning_rate: Learning rate. + learning_rate: Learning rate. This argument has no effect if `optimizer` + is an instance of an `Optimizer`. predict_probabilities: A boolean indicating whether to predict probabilities for all classes. momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. @@ -787,12 +784,11 @@ def multi_value_rnn_classifier(num_classes, then no clipping is performed. input_keep_probability: Probability to keep inputs to `cell`. If `None`, no dropout is applied. - output_keep_probability: Probability to keep outputs to `cell`. If `None`, + output_keep_probability: Probability to keep outputs of `cell`. If `None`, no dropout is applied. - model_dir: Directory to use for The directory in which to save and restore - the model graph, parameters, etc. + model_dir: The directory in which to save and restore the model graph, + parameters, etc. config: A `RunConfig` instance. - params: `dict` of hyperparameters. Passed through to `Estimator`. feature_engineering_fn: Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into `model_fn`. Please check `model_fn` for a definition of features and @@ -822,7 +818,6 @@ def multi_value_rnn_classifier(num_classes, return estimator.Estimator(model_fn=dynamic_rnn_model_fn, model_dir=model_dir, config=config, - params=params, feature_engineering_fn=feature_engineering_fn) @@ -835,42 +830,43 @@ def single_value_rnn_regressor(num_units, optimizer_type='SGD', learning_rate=0.1, momentum=None, - gradient_clipping_norm=10.0, + gradient_clipping_norm=5.0, input_keep_probability=None, output_keep_probability=None, model_dir=None, config=None, - params=None, feature_engineering_fn=None): """Create a RNN `Estimator` that predicts single values. Args: - num_units: The size of the RNN cells. + num_units: The size of the RNN cells. This argument has no effect + if `cell_type` is an instance of `RNNCell`. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. context_feature_columns: An iterable containing all the feature columns - describing context features i.e. features that apply accross all time + describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of 'basic_rnn,' 'lstm' or 'gru'. - num_rnn_layers: Number of RNN layers. + num_rnn_layers: Number of RNN layers. Leave this at its default value 1 + if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer` or a string. Strings must be one of 'Adagrad', 'Momentum' or 'SGD'. - learning_rate: Learning rate. + learning_rate: Learning rate. This argument has no effect if `optimizer` + is an instance of an `Optimizer`. momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. gradient_clipping_norm: Parameter used for gradient clipping. If `None`, then no clipping is performed. input_keep_probability: Probability to keep inputs to `cell`. If `None`, no dropout is applied. - output_keep_probability: Probability to keep outputs to `cell`. If `None`, + output_keep_probability: Probability to keep outputs of `cell`. If `None`, no dropout is applied. - model_dir: Directory to use for The directory in which to save and restore - the model graph, parameters, etc. + model_dir: The directory in which to save and restore the model graph, + parameters, etc. config: A `RunConfig` instance. - params: `dict` of hyperparameters. Passed through to `Estimator`. feature_engineering_fn: Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into `model_fn`. Please check `model_fn` for a definition of features and @@ -899,7 +895,6 @@ def single_value_rnn_regressor(num_units, return estimator.Estimator(model_fn=dynamic_rnn_model_fn, model_dir=model_dir, config=config, - params=params, feature_engineering_fn=feature_engineering_fn) @@ -914,32 +909,34 @@ def single_value_rnn_classifier(num_classes, learning_rate=0.1, predict_probabilities=False, momentum=None, - gradient_clipping_norm=10.0, + gradient_clipping_norm=5.0, input_keep_probability=None, output_keep_probability=None, model_dir=None, config=None, - params=None, feature_engineering_fn=None): """Creates a RNN `Estimator` that predicts single labels. Args: num_classes: The number of classes for categorization. - num_units: The size of the RNN cells. + num_units: The size of the RNN cells. This argument has no effect + if `cell_type` is an instance of `RNNCell`. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the set should be instances of classes derived from `FeatureColumn`. context_feature_columns: An iterable containing all the feature columns - describing context features i.e. features that apply accross all time + describing context features, i.e., features that apply accross all time steps. All items in the set should be instances of classes derived from `FeatureColumn`. cell_type: A subclass of `RNNCell`, an instance of an `RNNCell or one of 'basic_rnn,' 'lstm' or 'gru'. - num_rnn_layers: Number of RNN layers. + num_rnn_layers: Number of RNN layers. Leave this at its default value 1 + if passing a `cell_type` that is already a MultiRNNCell. optimizer_type: The type of optimizer to use. Either a subclass of `Optimizer`, an instance of an `Optimizer` or a string. Strings must be one of 'Adagrad', 'Momentum' or 'SGD'. - learning_rate: Learning rate. + learning_rate: Learning rate. This argument has no effect if `optimizer` + is an instance of an `Optimizer`. predict_probabilities: A boolean indicating whether to predict probabilities for all classes. momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'. @@ -947,12 +944,11 @@ def single_value_rnn_classifier(num_classes, then no clipping is performed. input_keep_probability: Probability to keep inputs to `cell`. If `None`, no dropout is applied. - output_keep_probability: Probability to keep outputs to `cell`. If `None`, + output_keep_probability: Probability to keep outputs of `cell`. If `None`, no dropout is applied. - model_dir: Directory to use for The directory in which to save and restore - the model graph, parameters, etc. + model_dir: The directory in which to save and restore the model graph, + parameters, etc. config: A `RunConfig` instance. - params: `dict` of hyperparameters. Passed through to `Estimator`. feature_engineering_fn: Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into `model_fn`. Please check `model_fn` for a definition of features and @@ -982,5 +978,4 @@ def single_value_rnn_classifier(num_classes, return estimator.Estimator(model_fn=dynamic_rnn_model_fn, model_dir=model_dir, config=config, - params=params, feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index a2df6de6fd..f534789270 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import tempfile + import numpy as np import tensorflow as tf @@ -69,17 +71,6 @@ class MockTargetColumn(object): self._num_label_columns = n -class MockOptimizer(object): - - def compute_gradients(self, loss, var_list): - raise NotImplementedError( - 'MockOptimizer.compute_gradients called unexpectedly.') - - def apply_gradients(self, processed_gradients, global_step): - raise NotImplementedError( - 'MockOptimizer.apply_gradients called unexpectedly.') - - def sequence_length_mask(values, lengths): masked = values for i, length in enumerate(lengths): @@ -95,6 +86,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): 'inputs', dimension=NUM_LABEL_COLUMNS) def setUp(self): + super(DynamicRnnEstimatorTest, self).setUp() self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS) self.mock_target_column = MockTargetColumn( num_label_columns=self.NUM_LABEL_COLUMNS) @@ -112,7 +104,9 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): 'measurements', dimension=2) self.sequence_feature_columns = [measurements, wire_cast_embedded] - self.columns_to_tensors = { + def GetColumnsToTensors(self): + """Get columns_to_tensors matching setUp(), in the current default graph.""" + return { 'location': tf.SparseTensor( indices=[[0, 0], [1, 0], [2, 0]], values=['west_side', 'west_side', 'nyc'], @@ -125,11 +119,16 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): b'omar', b'stringer', b'marlo', b'marlo'], shape=[3, 2, 2]), - 'measurements': tf.random_uniform([3, 2, 2])} + 'measurements': tf.random_uniform([3, 2, 2], seed=4711)} + + def GetClassificationTargetsOrNone(self, mode): + """Get targets matching setUp() and mode, in the current default graph.""" + return (tf.random_uniform([3, 2, 1], 0, 2, dtype=tf.int64, seed=1412) + if mode != tf.contrib.learn.ModeKeys.INFER else None) def testBuildSequenceInputInput(self): sequence_input = dynamic_rnn_estimator.build_sequence_input( - self.columns_to_tensors, + self.GetColumnsToTensors(), self.sequence_feature_columns, self.context_feature_columns) with self.test_session() as sess: @@ -146,7 +145,7 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): def testConstructRNN(self): initial_state = None sequence_input = dynamic_rnn_estimator.build_sequence_input( - self.columns_to_tensors, + self.GetColumnsToTensors(), self.sequence_feature_columns, self.context_feature_columns) activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn( @@ -166,30 +165,6 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS]) self.assertAllEqual(expected_state_shape, final_state.shape) - def testPaddingMask(self): - """Test `padding_mask`.""" - batch_size = 16 - padded_length = 32 - np.random.seed(1234) - sequence_lengths = np.random.randint(0, padded_length + 1, batch_size) - - padding_mask_t = dynamic_rnn_estimator.padding_mask( - tf.constant(sequence_lengths, dtype=tf.int32), - tf.constant(padded_length, dtype=tf.int32)) - - with tf.Session() as sess: - padding_mask = sess.run(padding_mask_t) - - for i in range(batch_size): - actual_mask = padding_mask[i] - expected_mask = np.concatenate( - [np.ones(sequence_lengths[i]), - np.zeros(padded_length - sequence_lengths[i])], - axis=0) - np.testing.assert_equal(actual_mask, expected_mask, - 'Mismatch on row {}. Got {}; expected {}.'.format( - i, actual_mask, expected_mask)) - def testMaskActivationsAndLabels(self): """Test `mask_activations_and_labels`.""" batch_size = 4 @@ -275,9 +250,90 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): ' Expected {}; got {}.'.format(i, expected_activations, actual_activations)) + # testGetDynamicRnnModelFn{Train,Eval,Infer}() test which fields + # of ModelFnOps are set depending on mode. + def testGetDynamicRnnModelFnTrain(self): + model_fn_ops = self._GetModelFnOpsForMode(tf.contrib.learn.ModeKeys.TRAIN) + self.assertIsNotNone(model_fn_ops.predictions) + self.assertIsNotNone(model_fn_ops.loss) + self.assertIsNotNone(model_fn_ops.train_op) + # None may get normalized to {}; we accept neither. + self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0) + + def testGetDynamicRnnModelFnEval(self): + model_fn_ops = self._GetModelFnOpsForMode(tf.contrib.learn.ModeKeys.EVAL) + self.assertIsNotNone(model_fn_ops.predictions) + self.assertIsNotNone(model_fn_ops.loss) + self.assertIsNone(model_fn_ops.train_op) + # None may get normalized to {}; we accept neither. + self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0) + + def testGetDynamicRnnModelFnInfer(self): + model_fn_ops = self._GetModelFnOpsForMode(tf.contrib.learn.ModeKeys.INFER) + self.assertIsNotNone(model_fn_ops.predictions) + self.assertIsNone(model_fn_ops.loss) + self.assertIsNone(model_fn_ops.train_op) + # None may get normalized to {}; we accept both. + self.assertFalse(model_fn_ops.eval_metric_ops) + + def _GetModelFnOpsForMode(self, mode): + """Helper for testGetDynamicRnnModelFn{Train,Eval,Infer}().""" + model_fn = dynamic_rnn_estimator._get_dynamic_rnn_model_fn( + self.rnn_cell, + target_column=tf.contrib.layers.multi_class_target(n_classes=2), + # Only CLASSIFICATION yields eval metrics to test for. + problem_type=dynamic_rnn_estimator.ProblemType.CLASSIFICATION, + prediction_type=dynamic_rnn_estimator.PredictionType.MULTIPLE_VALUE, + optimizer='SGD', + sequence_feature_columns=self.sequence_feature_columns, + context_feature_columns=self.context_feature_columns, + learning_rate=0.1) + labels = self.GetClassificationTargetsOrNone(mode) + model_fn_ops = model_fn(features=self.GetColumnsToTensors(), + labels=labels, mode=mode) + return model_fn_ops + + def testExport(self): + input_feature_key = 'magic_input_feature_key' + def get_input_fn(mode): + def input_fn(): + features = self.GetColumnsToTensors() + if mode == tf.contrib.learn.ModeKeys.INFER: + input_examples = tf.placeholder(tf.string) + features[input_feature_key] = input_examples + # Real code would now parse features out of input_examples, + # but this test can just stick to the constants above. + return features, self.GetClassificationTargetsOrNone(mode) + return input_fn + + model_dir = tempfile.mkdtemp() + def estimator_fn(): + return dynamic_rnn_estimator.multi_value_rnn_classifier( + num_classes=2, + num_units=self.NUM_RNN_CELL_UNITS, + sequence_feature_columns=self.sequence_feature_columns, + context_feature_columns=self.context_feature_columns, + predict_probabilities=True, + model_dir=model_dir) + + # Train a bit to create an exportable checkpoint. + estimator_fn().fit( + input_fn=get_input_fn(tf.contrib.learn.ModeKeys.TRAIN), steps=100) + # Now export, but from a fresh estimator instance, like you would + # in an export binary. That means .export() has to work without + # .fit() being called on the same object. + export_dir = tempfile.mkdtemp() + print('Exporting to', export_dir) + estimator_fn().export( + export_dir, + input_fn=get_input_fn(tf.contrib.learn.ModeKeys.INFER), + use_deprecated_input_fn=False, + input_feature_key=input_feature_key) + + # TODO(jamieas): move all tests below to a benchmark test. class DynamicRNNEstimatorLearningTest(tf.test.TestCase): - """Learning tests for dymanic RNN Estimators.""" + """Learning tests for dynamic RNN Estimators.""" def testLearnSineFunction(self): """Tests learning a sine function.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 5402c20297..ade126d6d1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -38,6 +38,8 @@ from tensorflow.contrib.framework import deprecated_arg_values from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables from tensorflow.contrib.framework import load_variable +from tensorflow.contrib.framework.python.framework import experimental +from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import graph_actions from tensorflow.contrib.learn.python.learn import metric_spec @@ -51,14 +53,21 @@ from tensorflow.contrib.learn.python.learn.estimators import tensor_signature from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError from tensorflow.contrib.learn.python.learn.learn_io import data_feeder from tensorflow.contrib.learn.python.learn.utils import export - +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils +from tensorflow.python.client import session as tf_session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import device_setter from tensorflow.python.training import saver +from tensorflow.python.util import compat AS_ITERABLE_DATE = '2016-09-15' @@ -553,13 +562,12 @@ class BaseEstimator( use_deprecated_input_fn=use_deprecated_input_fn, default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) - # pylint: enable=protected-access @abc.abstractproperty def _get_train_ops(self, features, labels): """Method that builds model graph and returns trainer ops. - Expected to be overriden by sub-classes that require custom support. + Expected to be overridden by sub-classes that require custom support. Args: features: `Tensor` or `dict` of `Tensor` objects. @@ -1126,6 +1134,106 @@ class Estimator(BaseEstimator): self._labels_info[model_fn_lib.ModeKeys.INFER]) return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER) + @experimental + def export_savedmodel( + self, export_dir_base, input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + """Exports inference graph as a SavedModel into given dir. + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + input_fn: A function that takes no argument and + returns an `InputFnOps`. + default_output_alternative_key: the name of the head to serve when none is + specified. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel. Each key should give the destination + path (including the filename) relative to the assets.extra directory. + The corresponding value gives the full path of the source file to be + copied. For example, the simple case of copying a single file without + renaming it is specified as + `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + as_text: whether to write the SavedModel proto in text format. + exports_to_keep: Number of exports to keep. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if an unrecognized export_type is requested. + """ + if input_fn is None: + raise ValueError('input_fn must be defined.') + + with ops.Graph().as_default() as g: + contrib_variables.create_global_step(g) + + # Call the input_fn and collect the input alternatives. + input_ops = input_fn() + input_alternatives, features = ( + saved_model_export_utils.get_input_alternatives(input_ops)) + + # Call the model_fn and collect the output alternatives. + model_fn_ops = self._call_model_fn(features, None, + model_fn_lib.ModeKeys.INFER) + output_alternatives, actual_default_output_alternative_key = ( + saved_model_export_utils.get_output_alternatives( + model_fn_ops, default_output_alternative_key)) + + # Build the SignatureDefs from all pairs of input and output signatures + signature_def_map = saved_model_export_utils.build_all_signature_defs( + input_alternatives, output_alternatives, + actual_default_output_alternative_key) + + # Locate the latest checkpoint + # TODO(soergel): does it help that we know we have one from this step? + checkpoint_path = saver.latest_checkpoint(self._model_dir) + if not checkpoint_path: + raise NotFittedError("Couldn't find trained model at %s." + % self._model_dir) + + export_dir = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + + with tf_session.Session('') as session: + variables.initialize_local_variables() + data_flow_ops.initialize_all_tables() + saver_for_restore = saver.Saver( + variables.global_variables(), + sharded=True) + saver_for_restore.restore(session, checkpoint_path) + + init_op = control_flow_ops.group( + variables.local_variables_initializer(), + data_flow_ops.initialize_all_tables()) + + # Perform the export + builder = saved_model_builder.SavedModelBuilder(export_dir) + builder.add_meta_graph_and_variables( + session, [tag_constants.SERVING], + signature_def_map=signature_def_map, + assets_collection=ops.get_collection( + ops.GraphKeys.ASSET_FILEPATHS), + legacy_init_op=init_op) + builder.save(as_text) + + # Add the extra assets + if assets_extra: + assets_extra_path = os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets.extra')) + for dest_relative, source in assets_extra.items(): + dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), + compat.as_bytes(dest_relative)) + dest_path = os.path.dirname(dest_absolute) + gfile.MakeDirs(dest_path) + gfile.Copy(source, dest_absolute) + + return export_dir + # For time of deprecation x,y from Estimator allow direct access # pylint: disable=protected-access diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index ae38e7a79e..a43b960a96 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -22,6 +22,7 @@ from __future__ import print_function import functools import itertools import json +import os import tempfile import numpy as np @@ -33,6 +34,11 @@ from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn.estimators import _sklearn from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.python.framework import ops +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.util import compat _BOSTON_INPUT_DIM = 13 @@ -105,6 +111,8 @@ def linear_model_fn(features, labels, mode): tf.contrib.learn.ModeKeys.TRAIN, tf.contrib.learn.ModeKeys.EVAL, tf.contrib.learn.ModeKeys.INFER) + if isinstance(features, dict): + (_, features), = features.items() prediction, loss = ( tf.contrib.learn.models.linear_regression_zero_init(features, labels) ) @@ -144,6 +152,45 @@ def logistic_model_no_mode_fn(features, labels): learning_rate=0.1) return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op +VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n' +EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' + + +def _build_estimator_for_export_tests(tmpdir): + def _input_fn(): + iris = tf.contrib.learn.datasets.load_iris() + return { + 'feature': tf.constant(iris.data, dtype=tf.float32) + }, tf.constant(iris.target, shape=[150], dtype=tf.int32) + + feature_columns = [tf.contrib.layers.real_valued_column('feature', + dimension=4)] + + est = tf.contrib.learn.LinearRegressor(feature_columns) + est.fit(input_fn=_input_fn, steps=20) + + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + export_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) + + # hack in an op that uses an asset, in order to test asset export. + # this is not actually valid, of course. + def export_input_fn_with_asset(): + features, labels, inputs = export_input_fn() + + vocab_file_name = os.path.join(tmpdir, 'my_vocab_file') + vocab_file = tf.gfile.GFile(vocab_file_name, mode='w') + vocab_file.write(VOCAB_FILE_CONTENT) + vocab_file.close() + hashtable = tf.contrib.lookup.HashTable( + tf.contrib.lookup.TextFileStringTableInitializer(vocab_file_name), 'x') + features['bogus_lookup'] = hashtable.lookup( + tf.to_int64(features['feature'])) + + return input_fn_utils.InputFnOps(features, labels, inputs) + + return est, export_input_fn_with_asset + class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor): @@ -585,6 +632,76 @@ class EstimatorTest(tf.test.TestCase): self.assertEquals(expected, actual) + def test_export_savedmodel(self): + tmpdir = tempfile.mkdtemp() + est, export_input_fn = _build_estimator_for_export_tests(tmpdir) + + extra_file_name = os.path.join(compat.as_bytes(tmpdir), + compat.as_bytes('my_extra_file')) + extra_file = tf.gfile.GFile(extra_file_name, mode='w') + extra_file.write(EXTRA_FILE_CONTENT) + extra_file.close() + assets_extra = {'some/sub/directory/my_extra_file': extra_file_name} + + export_dir_base = os.path.join(compat.as_bytes(tmpdir), + compat.as_bytes('export')) + export_dir = est.export_savedmodel(export_dir_base, export_input_fn, + assets_extra=assets_extra) + + self.assertTrue(tf.gfile.Exists(export_dir_base)) + self.assertTrue(tf.gfile.Exists(export_dir)) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(tf.gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), compat.as_bytes('assets')))) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets/my_vocab_file')))) + self.assertEqual( + compat.as_bytes(VOCAB_FILE_CONTENT), + compat.as_bytes(tf.gfile.GFile( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets/my_vocab_file'))).read())) + + expected_extra_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('assets.extra/some/sub/directory/my_extra_file')) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets.extra')))) + self.assertTrue(tf.gfile.Exists(expected_extra_path)) + self.assertEqual( + compat.as_bytes(EXTRA_FILE_CONTENT), + compat.as_bytes(tf.gfile.GFile(expected_extra_path).read())) + + expected_vocab_file = os.path.join(compat.as_bytes(tmpdir), + compat.as_bytes('my_vocab_file')) + # Restore, to validate that the export was well-formed. + with tf.Graph().as_default() as graph: + with tf.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + assets = [x.eval() + for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)] + self.assertItemsEqual([expected_vocab_file], assets) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('linear/linear/feature/matmul' in graph_ops) + + # cleanup + tf.gfile.DeleteRecursively(tmpdir) + class InferRealValuedColumnsTest(tf.test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 389b5b2b62..ad7a3f5a46 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -19,10 +19,12 @@ from __future__ import division from __future__ import print_function import abc +import six from tensorflow.contrib import losses from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.learn.python.learn import metric_spec +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import metric_key from tensorflow.contrib.learn.python.learn.estimators import model_fn @@ -33,9 +35,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.ops import variables +from tensorflow.python.ops import variable_scope from tensorflow.python.training import training @@ -64,8 +67,7 @@ def _regression_head(label_name=None, Returns: An instance of _Head """ - return _RegressionHead(loss_fn=_mean_squared_loss, - label_name=label_name, + return _RegressionHead(label_name=label_name, weight_column_name=weight_column_name, label_dimension=label_dimension, enable_centered_bias=enable_centered_bias, @@ -198,6 +200,9 @@ class _Head(object): """ __metaclass__ = abc.ABCMeta + def __init__(self, head_name): + self._head_name = head_name + @abc.abstractproperty def logits_dimension(self): raise NotImplementedError("Calling an abstract method.") @@ -215,8 +220,7 @@ class _Head(object): optimize with the loss. logits: logits to be used for the head. logits_input: tensor to build logits from. - scope: Optional scope for variable_scope. Only used by heads which create - variables. + scope: Optional scope for variable_scope. Returns: `ModelFnOps`. @@ -226,16 +230,48 @@ class _Head(object): """ raise NotImplementedError("Calling an abstract method.") + def _create_output_alternatives(self, predictions): + """Creates output alternative for the Head. + + Args: + predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a + symbolic name for an output Tensor possibly but not necessarily taken + from `PredictionKey`, and 'Tensor' is the corresponding output Tensor + itself. + + Returns: + `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where + 'submodel_name' is a submodel identifier that should be consistent across + the pipeline (here likely taken from the head_name), + 'problem_type' is a `ProblemType`, + 'tensor_name' is a symbolic name for an output Tensor possibly but not + necessarily taken from `PredictionKey`, and + 'Tensor' is the corresponding output Tensor itself. + """ + return {self._head_name: (self._problem_type, predictions)} + + +# TODO(zakaria): use contrib losses. +def _mean_squared_loss(logits, labels): + with ops.name_scope(None, "mean_squared_loss", (logits, labels)) as name: + # To prevent broadcasting inside "-". + if len(labels.get_shape()) == 1: + labels = array_ops.expand_dims(labels, dim=(1,)) + # TODO(zakaria): make sure it does not recreate the broadcast bug. + if len(logits.get_shape()) == 1: + logits = array_ops.expand_dims(logits, dim=(1,)) + logits.get_shape().assert_is_compatible_with(labels.get_shape()) + return math_ops.square(logits - math_ops.to_float(labels), name=name) + class _RegressionHead(_Head): """_Head for regression.""" - def __init__(self, loss_fn, label_name, weight_column_name, label_dimension, - enable_centered_bias, head_name): + def __init__(self, label_name, weight_column_name, label_dimension, + enable_centered_bias, head_name, loss_fn=_mean_squared_loss): """Base type for all single heads. Args: - loss_fn: Loss function. label_name: String, name of the key in label dict. Can be null if label is a tensor (single headed models). weight_column_name: A string defining feature column name representing @@ -247,15 +283,16 @@ class _RegressionHead(_Head): residual after centered bias. head_name: name of the head. If provided, predictions, summary and metrics keys will be prefixed by the head_name and an underscore. + loss_fn: Loss function. """ + super(_RegressionHead, self).__init__(head_name=head_name) + self._loss_fn = loss_fn self._logits_dimension = label_dimension self._label_name = label_name self._weight_column_name = weight_column_name - self._head_name = head_name self._enable_centered_bias = enable_centered_bias - self._centered_bias_weight_collection = _head_prefixed(head_name, - "centered_bias") + self._problem_type = constants.ProblemType.LINEAR_REGRESSION @property def logits_dimension(self): @@ -266,16 +303,29 @@ class _RegressionHead(_Head): """See `_Head`.""" _check_mode_valid(mode) _check_logits_input_not_supported(logits, logits_input) - predictions = self._predictions(logits) - if (mode == model_fn.ModeKeys.INFER) or (labels is None): - loss = None - train_op = None - eval_metric_ops = None - else: - loss = self._training_loss(features, labels, logits) - train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL - else self._train_op(loss, labels, train_op_fn)) - eval_metric_ops = self._eval_metric_ops(features, labels, predictions) + + centered_bias = None + if self._enable_centered_bias: + centered_bias = _centered_bias(self._logits_dimension) + logits = nn.bias_add(logits, centered_bias) + + predictions = self._logits_to_predictions(logits) + loss = None + train_op = None + eval_metric_ops = None + if (mode != model_fn.ModeKeys.INFER) and (labels is not None): + labels = _check_labels(labels, self._label_name) + loss = _training_loss( + features, labels, logits, + loss_fn=self._loss_fn, + weight_column_name=self._weight_column_name, + head_name=self._head_name) + if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None): + train_op = _train_op( + loss, labels, train_op_fn, centered_bias, self.logits_dimension, + self._loss_fn) + eval_metric_ops = _eval_metric_ops( + self._default_metrics(), features, labels, predictions) return model_fn.ModelFnOps( mode=mode, @@ -283,79 +333,8 @@ class _RegressionHead(_Head): loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops, - signature_fn=self._signature_fn()) - - def _training_loss(self, features, labels, logits, name="training_loss"): - """Returns training loss tensor for this head. - - Training loss is different from the loss reported on the tensorboard as we - should respect the example weights when computing the gradient. - - L = sum_{i} w_{i} * l_{i} / B - - where B is the number of examples in the batch, l_{i}, w_{i} are individual - losses, and example weight. - - Args: - features: features dict. - labels: either a tensor for labels or in multihead case, a dict of string - to labels tensor. - logits: logits, a float tensor. - name: Op name. - - Returns: - A loss `Tensor`. - """ - labels = _check_labels(labels, self._label_name) - - if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( - self.logits_dimension, - self._centered_bias_weight_collection)) - - loss_unweighted = self._loss_fn(logits, labels) - loss, weighted_average_loss = _loss( - loss_unweighted, - _weight_tensor(features, self._weight_column_name), - name=name) - summary.scalar( - _head_prefixed(self._head_name, "loss"), weighted_average_loss) - return loss - - def _train_op(self, loss, labels, train_op_fn): - """Returns op for the training step.""" - train_op = train_op_fn(loss) - - if self._enable_centered_bias: - centered_bias_step = [_centered_bias_step( - self.logits_dimension, - self._centered_bias_weight_collection, - labels, - self._loss_fn)] - train_op = control_flow_ops.group(train_op, *centered_bias_step) - - return train_op - - def _eval_metric_ops(self, features, labels, predictions): - """Returns a dict of metric ops keyed by name.""" - labels = _check_labels(labels, self._label_name) - return estimator._make_metrics_ops( # pylint: disable=protected-access - self._default_metrics(), features, labels, predictions) - - def _predictions(self, logits): - """Returns a dict of predictions. - - Args: - logits: logits `Tensor` before applying possible centered bias. - - Returns: - Dict of prediction `Tensor` keyed by `PredictionKey`. - """ - if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( - self.logits_dimension, - self._centered_bias_weight_collection)) - return self._logits_to_predictions(logits) + signature_fn=self._signature_fn(), + output_alternatives=self._create_output_alternatives(predictions)) def _logits_to_predictions(self, logits): """Returns a dict of predictions. @@ -366,13 +345,11 @@ class _RegressionHead(_Head): Returns: Dict of prediction `Tensor` keyed by `PredictionKey`. """ - predictions = {} - if self.logits_dimension == 1: - predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze( - logits, squeeze_dims=[1], name=prediction_key.PredictionKey.SCORES) - else: - predictions[prediction_key.PredictionKey.SCORES] = logits - return predictions + key = prediction_key.PredictionKey.SCORES + with ops.name_scope(None, "predictions", (logits,)): + if self.logits_dimension == 1: + logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key) + return {key: logits} def _signature_fn(self): """Returns the signature_fn to be used in exporting.""" @@ -399,11 +376,17 @@ class _RegressionHead(_Head): def _log_loss_with_two_classes(logits, labels): - # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. - if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=[1]) - return nn.sigmoid_cross_entropy_with_logits( - logits, math_ops.to_float(labels)) + with ops.name_scope( + None, "log_loss_with_two_classes", (logits, labels)) as name: + # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. + if len(labels.get_shape()) == 1: + labels = array_ops.expand_dims(labels, dim=(1,)) + return nn.sigmoid_cross_entropy_with_logits( + logits, math_ops.to_float(labels), name=name) + + +def _one_class_to_two_class_logits(logits): + return array_ops.concat(1, (array_ops.zeros_like(logits), logits)) class _BinaryLogisticHead(_Head): @@ -430,34 +413,45 @@ class _BinaryLogisticHead(_Head): Raises: ValueError: if n_classes is invalid. """ - self._thresholds = thresholds if thresholds else [.5] + super(_BinaryLogisticHead, self).__init__(head_name=head_name) + self._thresholds = thresholds if thresholds else (.5,) self._label_name = label_name self._weight_column_name = weight_column_name - self._head_name = head_name self._loss_fn = loss_fn self._enable_centered_bias = enable_centered_bias - self._centered_bias_weight_collection = _head_prefixed(head_name, - "centered_bias") @property def logits_dimension(self): return 1 def head_ops(self, features, labels, mode, train_op_fn, logits=None, - logits_input=None): + logits_input=None, scope=None): """See `_Head`.""" _check_mode_valid(mode) _check_logits_input_not_supported(logits, logits_input) - predictions = self._predictions(logits) - if (mode == model_fn.ModeKeys.INFER) or (labels is None): - loss = None - train_op = None - eval_metric_ops = None - else: - loss = self._training_loss(features, labels, logits) - train_op = (None if train_op_fn is None - else self._train_op(loss, labels, train_op_fn)) - eval_metric_ops = self._eval_metric_ops(features, labels, predictions) + + centered_bias = None + if self._enable_centered_bias: + centered_bias = _centered_bias(1) + logits = nn.bias_add(logits, centered_bias) + + predictions = self._logits_to_predictions(logits) + loss = None + train_op = None + eval_metric_ops = None + if (mode != model_fn.ModeKeys.INFER) and (labels is not None): + labels = _check_labels(labels, self._label_name) + loss = _training_loss( + features, labels, logits, + loss_fn=self._loss_fn, + weight_column_name=self._weight_column_name, + head_name=self._head_name) + if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None): + train_op = _train_op( + loss, labels, train_op_fn, centered_bias, self.logits_dimension, + self._loss_fn) + eval_metric_ops = _eval_metric_ops( + self._default_metrics(), features, labels, predictions) return model_fn.ModelFnOps( mode=mode, @@ -467,78 +461,6 @@ class _BinaryLogisticHead(_Head): eval_metric_ops=eval_metric_ops, signature_fn=self._signature_fn()) - def _training_loss(self, features, labels, logits=None, name="training_loss"): - """Returns training loss tensor for this head. - - Training loss is different from the loss reported on the tensorboard as we - should respect the example weights when computing the gradient. - - L = sum_{i} w_{i} * l_{i} / B - - where B is the number of examples in the batch, l_{i}, w_{i} are individual - losses, and example weight. - - Args: - features: features dict. - labels: either a tensor for labels or in multihead case, a dict of string - to labels tensor. - logits: logits, a float tensor. - name: Op name. - - Returns: - A loss `Output`. - """ - labels = _check_labels(labels, self._label_name) - - if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( - self.logits_dimension, - self._centered_bias_weight_collection)) - - loss_unweighted = self._loss_fn(logits, labels) - loss, weighted_average_loss = _loss( - loss_unweighted, - _weight_tensor(features, self._weight_column_name), - name=name) - summary.scalar( - _head_prefixed(self._head_name, "loss"), weighted_average_loss) - return loss - - def _train_op(self, loss, labels, train_op_fn): - """Returns op for the training step.""" - train_op = train_op_fn(loss) - - if self._enable_centered_bias: - centered_bias_step = [_centered_bias_step( - self.logits_dimension, - self._centered_bias_weight_collection, - labels, - self._loss_fn)] - train_op = control_flow_ops.group(train_op, *centered_bias_step) - - return train_op - - def _eval_metric_ops(self, features, labels, predictions): - """Returns a dict of metric ops keyed by name.""" - labels = _check_labels(labels, self._label_name) - return estimator._make_metrics_ops( # pylint: disable=protected-access - self._default_metrics(), features, labels, predictions) - - def _predictions(self, logits): - """Returns a dict of predictions. - - Args: - logits: logits `Output` before applying possible centered bias. - - Returns: - Dict of prediction `Output` keyed by `PredictionKey`. - """ - if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( - self.logits_dimension, - self._centered_bias_weight_collection)) - return self._logits_to_predictions(logits) - def _logits_to_predictions(self, logits): """Returns a dict of predictions. @@ -548,15 +470,18 @@ class _BinaryLogisticHead(_Head): Returns: Dict of prediction `Output` keyed by `PredictionKey`. """ - predictions = {prediction_key.PredictionKey.LOGITS: logits} - predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid( - logits) - logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits]) - predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax( - logits) - predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax( - logits, 1) - return predictions + with ops.name_scope(None, "predictions", (logits,)): + two_class_logits = _one_class_to_two_class_logits(logits) + return { + prediction_key.PredictionKey.LOGITS: logits, + prediction_key.PredictionKey.LOGISTIC: math_ops.sigmoid( + logits, name=prediction_key.PredictionKey.LOGISTIC), + prediction_key.PredictionKey.PROBABILITIES: nn.softmax( + two_class_logits, + name=prediction_key.PredictionKey.PROBABILITIES), + prediction_key.PredictionKey.CLASSES: math_ops.argmax( + two_class_logits, 1, name=prediction_key.PredictionKey.CLASSES) + } def _signature_fn(self): """Returns the signature_fn to be used in exporting.""" @@ -628,14 +553,17 @@ class _BinaryLogisticHead(_Head): def _softmax_cross_entropy_loss(logits, labels): - # Check that we got integer for classification. - if not labels.dtype.is_integer: - raise ValueError("Labels dtype should be integer " - "Instead got %s." % labels.dtype) - # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. - if len(labels.get_shape()) == 2: - labels = array_ops.squeeze(labels, squeeze_dims=[1]) - return nn.sparse_softmax_cross_entropy_with_logits(logits, labels) + with ops.name_scope( + None, "softmax_cross_entropy_loss", (logits, labels,)) as name: + # Check that we got integer for classification. + if not labels.dtype.is_integer: + raise ValueError("Labels dtype should be integer " + "Instead got %s." % labels.dtype) + # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. + if len(labels.get_shape()) == 2: + labels = array_ops.squeeze(labels, squeeze_dims=(1,)) + return nn.sparse_softmax_cross_entropy_with_logits( + logits, labels, name=name) class _MultiClassHead(_Head): @@ -665,18 +593,17 @@ class _MultiClassHead(_Head): Raises: ValueError: if n_classes is invalid. """ + super(_MultiClassHead, self).__init__(head_name=head_name) + if (n_classes is None) or (n_classes <= 2): raise ValueError("n_classes must be > 2: %s." % n_classes) - self._thresholds = thresholds if thresholds else [.5] - + self._thresholds = thresholds if thresholds else (.5,) self._logits_dimension = n_classes self._label_name = label_name self._weight_column_name = weight_column_name - self._head_name = head_name self._loss_fn = loss_fn self._enable_centered_bias = enable_centered_bias - self._centered_bias_weight_collection = _head_prefixed(head_name, - "centered_bias") + self._problem_type = constants.ProblemType.CLASSIFICATION @property def logits_dimension(self): @@ -687,16 +614,29 @@ class _MultiClassHead(_Head): """See `_Head`.""" _check_mode_valid(mode) _check_logits_input_not_supported(logits, logits_input) - predictions = self._predictions(logits) - if (mode == model_fn.ModeKeys.INFER) or (labels is None): - loss = None - train_op = None - eval_metric_ops = None - else: - loss = self._training_loss(features, labels, logits) - train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL - else self._train_op(loss, labels, train_op_fn)) - eval_metric_ops = self._eval_metric_ops(features, labels, predictions) + + centered_bias = None + if self._enable_centered_bias: + centered_bias = _centered_bias(self._logits_dimension) + logits = nn.bias_add(logits, centered_bias) + + predictions = self._logits_to_predictions(logits) + loss = None + train_op = None + eval_metric_ops = None + if (mode != model_fn.ModeKeys.INFER) and (labels is not None): + labels = _check_labels(labels, self._label_name) + loss = _training_loss( + features, labels, logits, + loss_fn=self._loss_fn, + weight_column_name=self._weight_column_name, + head_name=self._head_name) + if (mode == model_fn.ModeKeys.TRAIN) and (train_op_fn is not None): + train_op = _train_op( + loss, labels, train_op_fn, centered_bias, self._logits_dimension, + self._loss_fn) + eval_metric_ops = _eval_metric_ops( + self._default_metrics(), features, labels, predictions) return model_fn.ModelFnOps( mode=mode, @@ -704,79 +644,8 @@ class _MultiClassHead(_Head): loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops, - signature_fn=self._signature_fn()) - - def _training_loss(self, features, labels, logits=None, name="training_loss"): - """Returns training loss tensor for this head. - - Training loss is different from the loss reported on the tensorboard as we - should respect the example weights when computing the gradient. - - L = sum_{i} w_{i} * l_{i} / B - - where B is the number of examples in the batch, l_{i}, w_{i} are individual - losses, and example weight. - - Args: - features: features dict. - labels: either a tensor for labels or in multihead case, a dict of string - to labels tensor. - logits: logits, a float tensor. - name: Op name. - - Returns: - A loss `Tensor`. - """ - labels = _check_labels(labels, self._label_name) - - if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( - self.logits_dimension, - self._centered_bias_weight_collection)) - - loss_unweighted = self._loss_fn(logits, labels) - loss, weighted_average_loss = _loss( - loss_unweighted, - _weight_tensor(features, self._weight_column_name), - name=name) - summary.scalar( - _head_prefixed(self._head_name, "loss"), weighted_average_loss) - return loss - - def _train_op(self, loss, labels, train_op_fn): - """Returns op for the training step.""" - train_op = train_op_fn(loss) - - if self._enable_centered_bias: - centered_bias_step = [_centered_bias_step( - self.logits_dimension, - self._centered_bias_weight_collection, - labels, - self._loss_fn)] - train_op = control_flow_ops.group(train_op, *centered_bias_step) - - return train_op - - def _eval_metric_ops(self, features, labels, predictions): - """Returns a dict of metric ops keyed by name.""" - labels = _check_labels(labels, self._label_name) - return estimator._make_metrics_ops( # pylint: disable=protected-access - self._default_metrics(), features, labels, predictions) - - def _predictions(self, logits): - """Returns a dict of predictions. - - Args: - logits: logits `Tensor` before applying possible centered bias. - - Returns: - Dict of prediction `Tensor` keyed by `PredictionKey`. - """ - if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( - self.logits_dimension, - self._centered_bias_weight_collection)) - return self._logits_to_predictions(logits) + signature_fn=self._signature_fn(), + output_alternatives=self._create_output_alternatives(predictions)) def _logits_to_predictions(self, logits): """Returns a dict of predictions. @@ -787,13 +656,14 @@ class _MultiClassHead(_Head): Returns: Dict of prediction `Tensor` keyed by `PredictionKey`. """ - predictions = {prediction_key.PredictionKey.LOGITS: logits} - predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax( - logits) - predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax( - logits, 1) - - return predictions + with ops.name_scope(None, "predictions", (logits,)): + return { + prediction_key.PredictionKey.LOGITS: logits, + prediction_key.PredictionKey.PROBABILITIES: nn.softmax( + logits, name=prediction_key.PredictionKey.PROBABILITIES), + prediction_key.PredictionKey.CLASSES: math_ops.argmax( + logits, 1, name=prediction_key.PredictionKey.CLASSES) + } def _signature_fn(self): """Returns the signature_fn to be used in exporting.""" @@ -849,31 +719,32 @@ class _BinarySvmHead(_BinaryLogisticHead): def __init__(self, label_name, weight_column_name, enable_centered_bias, head_name, thresholds): def _loss_fn(logits, labels): - check_shape_op = control_flow_ops.Assert( - math_ops.less_equal(array_ops.rank(labels), 2), - ["labels shape should be either [batch_size, 1] or [batch_size]"]) - with ops.control_dependencies([check_shape_op]): - labels = array_ops.reshape( - labels, shape=[array_ops.shape(labels)[0], 1]) - return losses.hinge_loss(logits, labels) + with ops.name_scope(None, "hinge_loss", (logits, labels)) as name: + check_shape_op = control_flow_ops.Assert( + math_ops.less_equal(array_ops.rank(labels), 2), + ("labels shape should be either [batch_size, 1] or [batch_size]",)) + with ops.control_dependencies((check_shape_op,)): + labels = array_ops.reshape( + labels, shape=(array_ops.shape(labels)[0], 1)) + return losses.hinge_loss(logits, labels, scope=name) super(_BinarySvmHead, self).__init__( - loss_fn=_loss_fn, label_name=label_name, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, head_name=head_name, + loss_fn=_loss_fn, thresholds=thresholds) def _logits_to_predictions(self, logits): """See `_MultiClassHead`.""" - predictions = {} - predictions[prediction_key.PredictionKey.LOGITS] = logits - logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits]) - predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax( - logits, 1, name=prediction_key.PredictionKey.CLASSES) - - return predictions + with ops.name_scope(None, "predictions", (logits,)): + return { + prediction_key.PredictionKey.LOGITS: logits, + prediction_key.PredictionKey.CLASSES: math_ops.argmax( + _one_class_to_two_class_logits(logits), 1, + name=prediction_key.PredictionKey.CLASSES) + } def _default_metrics(self): """See `_MultiClassHead`.""" @@ -901,60 +772,62 @@ class _MultiLabelHead(_MultiClassHead): thresholds): super(_MultiLabelHead, self).__init__( - loss_fn=_sigmoid_cross_entropy_loss, n_classes=n_classes, label_name=label_name, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, head_name=head_name, + loss_fn=_sigmoid_cross_entropy_loss, thresholds=thresholds) def _logits_to_predictions(self, logits): """See `_MultiClassHead`.""" - predictions = {prediction_key.PredictionKey.LOGITS: logits} - if self.logits_dimension == 1: - predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid( - logits, name=prediction_key.PredictionKey.LOGISTIC) - logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits]) - predictions[ - prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid( - logits, name=prediction_key.PredictionKey.PROBABILITIES) - predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64( - math_ops.greater(logits, 0), - name=prediction_key.PredictionKey.CLASSES) - return predictions + with ops.name_scope(None, "predictions", (logits,)): + return { + prediction_key.PredictionKey.LOGITS: logits, + prediction_key.PredictionKey.PROBABILITIES: math_ops.sigmoid( + logits, name=prediction_key.PredictionKey.PROBABILITIES), + prediction_key.PredictionKey.CLASSES: math_ops.to_int64( + math_ops.greater(logits, 0), + name=prediction_key.PredictionKey.CLASSES) + } def _weighted_loss(loss, weight): """Returns cumulative weighted loss.""" - unweighted_loss = array_ops.reshape(loss, shape=(-1,)) - weighted_loss = math_ops.mul(unweighted_loss, - array_ops.reshape( - weight, shape=(-1,))) - return weighted_loss + with ops.name_scope(None, "weighted_loss", (loss, weight)) as name: + unweighted_loss = array_ops.reshape(loss, shape=(-1,)) + weighted_loss = math_ops.mul(unweighted_loss, + array_ops.reshape( + weight, shape=(-1,)), + name=name) + return weighted_loss def _weight_tensor(features, weight_column_name): if not weight_column_name: return None - else: + with ops.name_scope( + None, "weight_tensor", tuple(six.itervalues(features))) as name: return array_ops.reshape( math_ops.to_float(features[weight_column_name]), - shape=(-1,)) + shape=(-1,), + name=name) def _loss(loss_unweighted, weight, name): - """Returns loss.""" - if weight is None: - loss = math_ops.reduce_mean(loss_unweighted, name=name) - return loss, loss - loss_weighted = _weighted_loss(loss_unweighted, weight) - weighted_average_loss = math_ops.div( - math_ops.reduce_sum(loss_weighted), - math_ops.to_float(math_ops.reduce_sum(weight)), - name="weighted_average_loss") - loss = math_ops.reduce_mean(loss_weighted, name=name) - return loss, weighted_average_loss + """Returns a tuple of (loss, weighted_average_loss).""" + with ops.name_scope(name, values=(loss_unweighted, weight)) as name_scope: + if weight is None: + loss = math_ops.reduce_mean(loss_unweighted, name=name_scope) + return loss, loss + loss_weighted = _weighted_loss(loss_unweighted, weight) + weighted_average_loss = math_ops.div( + math_ops.reduce_sum(loss_weighted), + math_ops.to_float(math_ops.reduce_sum(weight)), + name="weighted_average_loss") + loss = math_ops.reduce_mean(loss_weighted, name=name_scope) + return loss, weighted_average_loss def _check_logits_input_not_supported(logits, logits_input): @@ -971,63 +844,128 @@ def _check_mode_valid(mode): raise ValueError("mode=%s unrecognized." % str(mode)) -def _centered_bias(logits_dimension, weight_collection): - """Creates and returns centered bias.""" - centered_bias = variables.Variable( - array_ops.zeros([logits_dimension]), - collections=[weight_collection, ops.GraphKeys.GLOBAL_VARIABLES], - name="centered_bias_weight") +def _centered_bias(logits_dimension): + """Returns `logits`, optionally with centered bias applied. + + Args: + logits_dimension: Last dimension of `logits`. Must be >= 1. + + Returns: + Centered bias `Variable`. - biases = array_ops.reshape(centered_bias, [-1]) - for cb in range(logits_dimension): - summary.scalar("centered_bias_%d" % cb, biases[cb]) + Raises: + ValueError: if `logits_dimension` is invalid. + """ + if (logits_dimension is None) or (logits_dimension < 1): + raise ValueError("Invalid logits_dimension %s." % logits_dimension) + centered_bias = variable_scope.get_variable( + name="centered_bias_weight", + shape=(logits_dimension,), + initializer=init_ops.zeros_initializer, + trainable=True) + for dim in range(logits_dimension): + summary.scalar("centered_bias_%d" % dim, centered_bias[dim]) return centered_bias -def _centered_bias_step(logits_dimension, weight_collection, labels, loss_fn): +def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn): """Creates and returns training op for centered bias.""" - centered_bias = ops.get_collection(weight_collection) - batch_size = array_ops.shape(labels)[0] - logits = array_ops.reshape( - array_ops.tile(centered_bias[0], [batch_size]), - [batch_size, logits_dimension]) - with ops.name_scope(None, "centered_bias", (labels, logits)): - centered_bias_loss = math_ops.reduce_mean( - loss_fn(logits, labels), name="training_loss") - # Learn central bias by an optimizer. 0.1 is a convervative lr for a - # single variable. - return training.AdagradOptimizer(0.1).minimize( - centered_bias_loss, var_list=centered_bias) + if (logits_dimension is None) or (logits_dimension < 1): + raise ValueError("Invalid logits_dimension %s." % logits_dimension) + with ops.name_scope(None, "centered_bias_step", (labels,)) as name: + batch_size = array_ops.shape(labels)[0] + logits = array_ops.reshape( + array_ops.tile(centered_bias, (batch_size,)), + (batch_size, logits_dimension)) + with ops.name_scope(None, "centered_bias", (labels, logits)): + centered_bias_loss = math_ops.reduce_mean( + loss_fn(logits, labels), name="training_loss") + # Learn central bias by an optimizer. 0.1 is a convervative lr for a + # single variable. + return training.AdagradOptimizer(0.1).minimize( + centered_bias_loss, var_list=(centered_bias,), name=name) def _head_prefixed(head_name, val): return "%s_%s" % (head_name, val) if head_name else val -# TODO(zakaria): use contrib losses. -def _mean_squared_loss(logits, labels): - # To prevent broadcasting inside "-". - if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=[1]) - # TODO(zakaria): make sure it does not recreate the broadcast bug. - if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, dim=[1]) - logits.get_shape().assert_is_compatible_with(labels.get_shape()) - return math_ops.square(logits - math_ops.to_float(labels)) +def _training_loss( + features, labels, logits, loss_fn, weight_column_name=None, head_name=None): + """Returns training loss tensor. + + Training loss is different from the loss reported on the tensorboard as we + should respect the example weights when computing the gradient. + + L = sum_{i} w_{i} * l_{i} / B + + where B is the number of examples in the batch, l_{i}, w_{i} are individual + losses, and example weight. + + Args: + features: Features `dict`. + labels: Either a `Tensor` for labels or in multihead case, a `dict` of + string to `Tensor`. + logits: logits, a float `Tensor`. Shape is `(batch_size, logits_dimension)`. + loss_fn: Function taking `logits` and `labels`, and returning the raw + unweighted loss. + weight_column_name: Key for weights `Tensor` in `features`, if applicable. + head_name: Head name, used for summary. + + Returns: + A loss `Output`. + """ + with ops.name_scope( + None, "training_loss", + tuple(six.itervalues(features)) + (labels, logits)) as name: + loss, weighted_average_loss = _loss( + loss_fn(logits, labels), + _weight_tensor(features, weight_column_name), + name=name) + summary.scalar(_head_prefixed(head_name, "loss"), weighted_average_loss) + return loss + + +def _train_op( + loss, labels, train_op_fn, centered_bias=None, logits_dimension=None, + loss_fn=None): + """Returns op for the training step.""" + with ops.name_scope(None, "train_op", (loss, labels)): + train_op = train_op_fn(loss) + if centered_bias is not None: + centered_bias_step = _centered_bias_step( + centered_bias, logits_dimension, labels, loss_fn) + train_op = control_flow_ops.group(train_op, centered_bias_step) + return train_op + + +def _eval_metric_ops(metrics, features, labels, predictions): + with ops.name_scope( + None, "metrics", + (tuple(six.itervalues(features)) + + (labels,) + + tuple(six.itervalues(predictions)))): + # pylint: disable=protected-access + return estimator._make_metrics_ops(metrics, features, labels, predictions) + # pylint: enable=protected-access def _sigmoid_cross_entropy_loss(logits, labels): - # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. - return nn.sigmoid_cross_entropy_with_logits(logits, math_ops.to_float(labels)) + with ops.name_scope( + None, "sigmoid_cross_entropy_loss", (logits, labels)) as name: + # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. + return nn.sigmoid_cross_entropy_with_logits( + logits, math_ops.to_float(labels), name=name) def _float_weights_or_none(weights): if weights is None: return None - return math_ops.to_float(weights) + with ops.name_scope(None, "float_weights", (weights,)) as name: + return math_ops.to_float(weights, name=name) -def _weighted_average_loss_metric_spec(loss_fn, predictoin_key, +def _weighted_average_loss_metric_spec(loss_fn, pred_key, label_key, weight_key): def _streaming_weighted_average_loss(predictions, labels, weights=None): loss_unweighted = loss_fn(predictions, labels) @@ -1038,7 +976,7 @@ def _weighted_average_loss_metric_spec(loss_fn, predictoin_key, name="eval_loss") return metrics_lib.streaming_mean(weighted_average_loss) return metric_spec.MetricSpec(_streaming_weighted_average_loss, - predictoin_key, label_key, weight_key) + pred_key, label_key, weight_key) def _labels_streaming_mean(unused_predictions, labels, weights=None): @@ -1070,7 +1008,7 @@ def _streaming_at_threshold(streaming_metrics_fn, threshold): def _streaming_metrics(predictions, labels, weights=None): precision_tensor, update_op = streaming_metrics_fn( - predictions, labels=labels, thresholds=[threshold], + predictions, labels=labels, thresholds=(threshold,), weights=_float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), update_op diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 440615371d..673fbaefbb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -18,13 +18,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np +import six import tensorflow as tf from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +def _assert_variables( + test_case, expected_global=None, expected_model=None, + expected_trainable=None): + test_case.assertItemsEqual( + [] if expected_global is None else expected_global, + [k.name for k in tf.global_variables()]) + test_case.assertItemsEqual( + [] if expected_model is None else expected_model, + [k.name for k in tf.model_variables()]) + test_case.assertItemsEqual( + [] if expected_trainable is None else expected_trainable, + [k.name for k in tf.trainable_variables()]) + + +def _assert_no_variables(test_case): + _assert_variables(test_case, set([]), set([]), set([])) + + class RegressionModelHeadTest(tf.test.TestCase): + def _assert_metrics(self, model_fn_ops): + self.assertItemsEqual(( + "loss", + ), six.iterkeys(model_fn_ops.eval_metric_ops)) + # TODO(zakaria): test multilabel regresssion. def testRegression(self): head = head_lib._regression_head() @@ -34,8 +59,15 @@ class RegressionModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=prediction) + self._assert_metrics(model_fn_ops) + _assert_no_variables(self) self.assertAlmostEqual(5. / 3, sess.run(model_fn_ops.loss)) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.EVAL, + _noop_train_op, logits=prediction) + self.assertIsNone(model_fn_ops.train_op) + def testRegressionWithWeights(self): head = head_lib._regression_head( weight_column_name="label_weight") @@ -46,6 +78,28 @@ class RegressionModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=prediction) + self._assert_metrics(model_fn_ops) + _assert_no_variables(self) + self.assertAlmostEqual(2. / 3, sess.run(model_fn_ops.loss), places=3) + + def testRegressionWithCenteredBias(self): + head = head_lib._regression_head( + weight_column_name="label_weight", enable_centered_bias=True) + with tf.Graph().as_default(), tf.Session() as sess: + features = {"label_weight": tf.constant([[2.], [5.], [0.]])} + prediction = tf.constant([[1.], [1.], [3.]]) + labels = tf.constant([[0.], [1.], [1.]]) + model_fn_ops = head.head_ops(features, labels, + tf.contrib.learn.ModeKeys.TRAIN, + _noop_train_op, logits=prediction) + self._assert_metrics(model_fn_ops) + _assert_variables(self, expected_global=( + "centered_bias_weight:0", + "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + ), expected_trainable=( + "centered_bias_weight:0", + )) + tf.global_variables_initializer().run() self.assertAlmostEqual(2. / 3, sess.run(model_fn_ops.loss), places=3) def testErrorInSparseTensorLabels(self): @@ -64,6 +118,12 @@ class RegressionModelHeadTest(tf.test.TestCase): class MultiLabelModelHeadTest(tf.test.TestCase): + def _assert_metrics(self, model_fn_ops): + self.assertItemsEqual(( + "accuracy", + "loss", + ), six.iterkeys(model_fn_ops.eval_metric_ops)) + def testMultiLabel(self): head = head_lib._multi_label_head(n_classes=3) with tf.Graph().as_default(), tf.Session() as sess: @@ -72,8 +132,15 @@ class MultiLabelModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) + self._assert_metrics(model_fn_ops) + _assert_no_variables(self) self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss)) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.EVAL, + _noop_train_op, logits=logits) + self.assertIsNone(model_fn_ops.train_op) + def testMultiLabelWithWeight(self): head = head_lib._multi_label_head( n_classes=3, weight_column_name="label_weight") @@ -84,11 +151,44 @@ class MultiLabelModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) + self._assert_metrics(model_fn_ops) + _assert_no_variables(self) self.assertAlmostEqual(0.089985214, sess.run(model_fn_ops.loss)) + def testMultiLabelWithCenteredBias(self): + head = head_lib._multi_label_head(n_classes=3, enable_centered_bias=True) + with tf.Graph().as_default(), tf.Session() as sess: + logits = tf.constant([[1., 0., 0.]]) + labels = tf.constant([[0, 0, 1]]) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.TRAIN, + _noop_train_op, logits=logits) + self._assert_metrics(model_fn_ops) + _assert_variables(self, expected_global=( + "centered_bias_weight:0", + "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + ), expected_trainable=( + "centered_bias_weight:0", + )) + tf.global_variables_initializer().run() + self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss)) + class MultiClassModelHeadTest(tf.test.TestCase): + def _assert_binary_metrics(self, model_fn_ops): + self.assertItemsEqual(( + "accuracy", + "accuracy/baseline_label_mean", + "accuracy/threshold_0.500000_mean", + "auc", + "labels/actual_label_mean", + "labels/prediction_mean", + "loss", + "precision/positive_threshold_0.500000_mean", + "recall/positive_threshold_0.500000_mean", + ), six.iterkeys(model_fn_ops.eval_metric_ops)) + def testBinaryClassification(self): head = head_lib._multi_class_head(n_classes=2) with tf.Graph().as_default(), tf.Session() as sess: @@ -99,8 +199,14 @@ class MultiClassModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) + self._assert_binary_metrics(model_fn_ops) + _assert_no_variables(self) self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss), delta=1e-6) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.EVAL, + _noop_train_op, logits=logits) + self.assertIsNone(model_fn_ops.train_op) def testErrorInSparseTensorLabels(self): head = head_lib._multi_class_head(n_classes=2) @@ -127,11 +233,41 @@ class MultiClassModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) + self._assert_binary_metrics(model_fn_ops) + _assert_no_variables(self) self.assertAlmostEqual(.31326166 / 2, sess.run(model_fn_ops.loss), delta=1e-6) + def testBinaryClassificationWithCenteredBias(self): + head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True) + with tf.Graph().as_default(), tf.Session() as sess: + logits = tf.constant([[1.], [1.]]) + labels = tf.constant([[1.], [0.]]) + # logloss: z:label, x:logit + # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.TRAIN, + _noop_train_op, logits=logits) + self._assert_binary_metrics(model_fn_ops) + _assert_variables(self, expected_global=( + "centered_bias_weight:0", + "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + ), expected_trainable=( + "centered_bias_weight:0", + )) + tf.global_variables_initializer().run() + self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss), + delta=1e-6) + + def _assert_multi_class_metrics(self, model_fn_ops): + self.assertItemsEqual(( + "accuracy", + "loss", + ), six.iterkeys(model_fn_ops.eval_metric_ops)) + def testMultiClass(self): - head = head_lib._multi_class_head(n_classes=3) + n_classes = 3 + head = head_lib._multi_class_head(n_classes=n_classes) with tf.Graph().as_default(), tf.Session() as sess: logits = tf.constant([[1., 0., 0.]]) labels = tf.constant([2]) @@ -140,11 +276,18 @@ class MultiClassModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops({}, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) + self._assert_multi_class_metrics(model_fn_ops) + _assert_no_variables(self) self.assertAlmostEqual(1.5514446, sess.run(model_fn_ops.loss)) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.EVAL, + _noop_train_op, logits=logits) + self.assertIsNone(model_fn_ops.train_op) def testMultiClassWithWeight(self): + n_classes = 3 head = head_lib._multi_class_head( - n_classes=3, weight_column_name="label_weight") + n_classes=n_classes, weight_column_name="label_weight") with tf.Graph().as_default(), tf.Session() as sess: features = {"label_weight": tf.constant([0.1])} logits = tf.constant([[1., 0., 0.]]) @@ -154,6 +297,8 @@ class MultiClassModelHeadTest(tf.test.TestCase): model_fn_ops = head.head_ops(features, labels, tf.contrib.learn.ModeKeys.TRAIN, _noop_train_op, logits=logits) + self._assert_multi_class_metrics(model_fn_ops) + _assert_no_variables(self) self.assertAlmostEqual(.15514446, sess.run(model_fn_ops.loss)) def testInvalidNClasses(self): @@ -164,34 +309,73 @@ class MultiClassModelHeadTest(tf.test.TestCase): class BinarySvmModelHeadTest(tf.test.TestCase): + def setUp(self): + # Prediction for first example is in the right side of the hyperplane + # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss + # incurred by this example. The 2nd prediction is outside the margin so it + # incurs no loss at all. + self._predictions = ((-0.5,), (1.2,)) + self._labels = (0, 1) + self._expected_losses = (0.5, 0.0) + + def _assert_metrics(self, model_fn_ops): + self.assertItemsEqual(( + "accuracy", + "loss", + ), six.iterkeys(model_fn_ops.eval_metric_ops)) + def testBinarySVMDefaultWeights(self): head = head_lib._binary_svm_head() - predictions = tf.constant([[-0.5], [1.2]]) - labels = tf.constant([0, 1]) + with tf.Graph().as_default(), tf.Session(): + predictions = tf.constant(self._predictions) + labels = tf.constant(self._labels) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.TRAIN, + _noop_train_op, logits=predictions) + self._assert_metrics(model_fn_ops) + _assert_no_variables(self) + self.assertAlmostEqual( + np.average(self._expected_losses), model_fn_ops.loss.eval()) + model_fn_ops = head.head_ops({}, labels, - tf.contrib.learn.ModeKeys.TRAIN, + tf.contrib.learn.ModeKeys.EVAL, _noop_train_op, logits=predictions) - # Prediction for first example is in the right side of the hyperplane (i.e., - # < 0) but it is within the [-1,1] margin. There is a 0.5 loss incurred by - # this example. The 2nd prediction is outside the margin so it incurs no - # loss at all. The overall (normalized) loss is therefore 0.5/(1+1) = 0.25. - with tf.Session() as sess: - self.assertAlmostEqual(0.25, sess.run(model_fn_ops.loss)) + self.assertIsNone(model_fn_ops.train_op) def testBinarySVMWithWeights(self): - head = head_lib._binary_svm_head( - weight_column_name="weights") - predictions = tf.constant([[-0.7], [0.2]]) - labels = tf.constant([0, 1]) - features = {"weights": tf.constant([2.0, 10.0])} - model_fn_ops = head.head_ops(features, labels, - tf.contrib.learn.ModeKeys.TRAIN, - _noop_train_op, logits=predictions) - # Prediction for both examples are in the right side of the hyperplane but - # within the margin. The (weighted) loss incurred is 2*0.3=0.6 and 10*0.8=8 - # respectively. The overall (normalized) loss is therefore 8.6/12. - with tf.Session() as sess: - self.assertAlmostEqual(8.6 / 2, sess.run(model_fn_ops.loss), places=3) + head = head_lib._binary_svm_head(weight_column_name="weights") + with tf.Graph().as_default(), tf.Session(): + predictions = tf.constant(self._predictions) + labels = tf.constant(self._labels) + weights = (7.0, 11.0) + features = {"weights": tf.constant(weights)} + model_fn_ops = head.head_ops(features, labels, + tf.contrib.learn.ModeKeys.TRAIN, + _noop_train_op, logits=predictions) + self._assert_metrics(model_fn_ops) + _assert_no_variables(self) + self.assertAlmostEqual( + np.sum(np.multiply(weights, self._expected_losses)) / 2.0, + model_fn_ops.loss.eval()) + + def testBinarySVMWithCenteredBias(self): + head = head_lib._binary_svm_head(enable_centered_bias=True) + with tf.Graph().as_default(), tf.Session(): + predictions = tf.constant(self._predictions) + labels = tf.constant(self._labels) + model_fn_ops = head.head_ops({}, labels, + tf.contrib.learn.ModeKeys.TRAIN, + _noop_train_op, logits=predictions) + self._assert_metrics(model_fn_ops) + _assert_variables(self, expected_global=( + "centered_bias_weight:0", + "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + ), expected_trainable=( + "centered_bias_weight:0", + )) + tf.global_variables_initializer().run() + self.assertAlmostEqual( + np.average(self._expected_losses), model_fn_ops.loss.eval()) def _noop_train_op(unused_loss): diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index d043468654..0405eb0476 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -27,6 +27,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import monitors as monitor_lib @@ -519,6 +520,22 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property @deprecated("2016-10-30", "This method will be removed after the deprecation date. " @@ -761,6 +778,22 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property @deprecated("2016-10-30", "This method will be removed after the deprecation date. " diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 0c99de5dd1..50f0d2d75d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -533,7 +533,7 @@ class LinearClassifierTest(tf.test.TestCase): classifier = tf.contrib.learn.LinearClassifier( feature_columns=[age, language], enable_centered_bias=False) classifier.fit(input_fn=input_fn, steps=100) - self.assertFalse('centered_bias_weight' in classifier.get_variable_names()) + self.assertNotIn('centered_bias_weight', classifier.get_variable_names()) def testEnableCenteredBias(self): """Tests that we can disable centered bias.""" @@ -552,7 +552,7 @@ class LinearClassifierTest(tf.test.TestCase): classifier = tf.contrib.learn.LinearClassifier( feature_columns=[age, language], enable_centered_bias=True) classifier.fit(input_fn=input_fn, steps=100) - self.assertTrue('centered_bias_weight' in classifier.get_variable_names()) + self.assertIn('centered_bias_weight', classifier.get_variable_names()) def testTrainOptimizerWithL1Reg(self): """Tests l1 regularized model has higher loss.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 3f9351ce22..42f21bd196 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -49,13 +49,33 @@ class ModeKeys(object): # TODO(roumposg): Pass output_signature_fn instead of signature_fn. class ModelFnOps(collections.namedtuple( 'ModelFnOps', - ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn'])): + ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn', + 'output_alternatives'])): """Ops returned from a model_fn.""" + # TODO(soergel): remove signature_fn once sessionbundle export is deprecated. + def __new__(cls, mode, predictions=None, loss=None, train_op=None, - eval_metric_ops=None, signature_fn=None): + eval_metric_ops=None, signature_fn=None, + output_alternatives=None): """Creates a validated `ModelFnOps` instance. + For a multi-headed model, the predictions dict here will contain the outputs + of all of the heads. However: at serving time, requests will be made + specifically for one or more heads, and the RPCs used for these requests may + differ by problem type (i.e., regression, classification, other). The + purpose of the output_alternatives dict is to aid in exporting a SavedModel + from which such head-specific queries can be served. These + output_alternatives will be combined with input_alternatives (see + `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying + the valid requests that can be served from this model. + + For a single-headed model, it is still adviseable to provide + output_alternatives with a single entry, because this is how the problem + type is communicated for export and serving. If output_alternatives is not + given, the resulting SavedModel will support only one head of unspecified + type. + Args: mode: One of `ModeKeys`. Specifies if this training, evaluation or prediction. @@ -65,6 +85,14 @@ class ModelFnOps(collections.namedtuple( eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, such as `Tensor`. signature_fn: The signature_fn used for exporting. + output_alternatives: a dict of + `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where + `submodel_name` is a submodel identifier that should be consistent + across the pipeline (here likely taken from the name of each `Head`, + for models that use them), `problem_type` is a `ProblemType`, + `tensor_name` is a symbolic name for an output Tensor possibly but not + necessarily taken from `PredictionKey`, and `Tensor` is the + corresponding output Tensor itself. Returns: A validated `ModelFnOps` object. @@ -122,4 +150,5 @@ class ModelFnOps(collections.namedtuple( raise ValueError('signature_fn is not callable.') return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op, - eval_metric_ops, signature_fn) + eval_metric_ops, signature_fn, + output_alternatives) diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py index c2c41255c9..deb55efc9f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py +++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib import framework as contrib_framework from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import trainable @@ -352,3 +353,19 @@ class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable): self._estimator._model_fn = orig_model_fn # pylint: enable=protected-access return result + + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index eeee673c5a..a6e4e7b6a3 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -19,12 +19,14 @@ from __future__ import division from __future__ import print_function import inspect +import re import tempfile from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated_arg_values from tensorflow.contrib.framework import list_variables from tensorflow.contrib.framework import load_variable +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import estimator @@ -235,6 +237,22 @@ class SVM(trainable.Trainable, evaluable.Evaluable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property def weights_(self): values = {} diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 9c70cc8dea..edd363b728 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -924,9 +924,9 @@ class ExportMonitor(EveryN): `None`). input_feature_key: String key into the features dict returned by `input_fn` that corresponds to the raw `Example` strings `Tensor` that - the exported model will take as input. Can only be `None` if you're - using a custom `signature_fn` that does not use the first arg - (examples). + the exported model will take as input. Should be `None` if and only if + you're passing in a `signature_fn` that does not use the first arg + (`Tensor` of `Example` strings). exports_to_keep: int, number of exports to keep. signature_fn: Function that returns a default signature and a named signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s diff --git a/tensorflow/contrib/learn/python/learn/ops/array_ops.py b/tensorflow/contrib/learn/python/learn/ops/array_ops.py index a04e91b830..9196a9b9ad 100644 --- a/tensorflow/contrib/learn/python/learn/ops/array_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/array_ops.py @@ -21,25 +21,32 @@ from __future__ import print_function from tensorflow.contrib.framework import deprecated from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import math_ops @deprecated('2016-12-01', 'Use `tf.one_hot` instead.') -def one_hot_matrix(tensor_in, num_classes, on_value=1.0, off_value=0.0): +def one_hot_matrix(tensor_in, num_classes, on_value=1.0, off_value=0.0, + name=None): """Encodes indices from given tensor as one-hot tensor. TODO(ilblackdragon): Ideally implementation should be part of TensorFlow with Eigen-native operation. Args: - tensor_in: Input tensor of shape [N1, N2]. + tensor_in: Input `Tensor` of shape [N1, N2]. num_classes: Number of classes to expand index into. - on_value: Tensor or float, value to fill-in given index. - off_value: Tensor or float, value to fill-in everything else. + on_value: `Tensor` or float, value to fill-in given index. + off_value: `Tensor` or float, value to fill-in everything else. + name: Name of the op. Returns: - Tensor of shape [N1, N2, num_classes] with 1.0 for each id in original + `Tensor` of shape `[N1, N2, num_classes]` with 1.0 for each id in original tensor. """ - return array_ops_.one_hot( - math_ops.cast(tensor_in, dtypes.int64), num_classes, on_value, off_value) + with ops.name_scope( + name, 'one_hot_matrix', + [tensor_in, num_classes, on_value, off_value]) as name_scope: + return array_ops_.one_hot( + math_ops.cast(tensor_in, dtypes.int64), num_classes, on_value, + off_value, name=name_scope) diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py new file mode 100644 index 0000000000..dd4376f051 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/gc.py @@ -0,0 +1,205 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""System for specifying garbage collection (GC) of path based data. + +This framework allows for GC of data specified by path names, for example files +on disk. gc.Path objects each represent a single item stored at a path and may +be a base directory, + /tmp/exports/0/... + /tmp/exports/1/... + ... +or a fully qualified file, + /tmp/train-1.ckpt + /tmp/train-2.ckpt + ... + +A gc filter function takes and returns a list of gc.Path items. Filter +functions are responsible for selecting Path items for preservation or deletion. +Note that functions should always return a sorted list. + +For example, + base_dir = "/tmp" + # create the directories + for e in xrange(10): + os.mkdir("%s/%d" % (base_dir, e), 0o755) + + # create a simple parser that pulls the export_version from the directory + def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + path_list = gc.get_paths("/tmp", parser) # contains all ten Paths + + every_fifth = gc.mod_export_version(5) + print every_fifth(path_list) # shows ["/tmp/0", "/tmp/5"] + + largest_three = gc.largest_export_versions(3) + print largest_three(all_paths) # shows ["/tmp/7", "/tmp/8", "/tmp/9"] + + both = gc.union(every_fifth, largest_three) + print both(all_paths) # shows ["/tmp/0", "/tmp/5", + # "/tmp/7", "/tmp/8", "/tmp/9"] + # delete everything not in 'both' + to_delete = gc.negation(both) + for p in to_delete(all_paths): + gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", + # "/tmp/3", "/tmp/4", "/tmp/6", +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import heapq +import math +import os + +from tensorflow.python.platform import gfile + +Path = collections.namedtuple('Path', 'path export_version') + + +def largest_export_versions(n): + """Creates a filter that keeps the largest n export versions. + + Args: + n: number of versions to keep. + + Returns: + A filter function that keeps the n largest paths. + """ + def keep(paths): + heap = [] + for idx, path in enumerate(paths): + if path.export_version is not None: + heapq.heappush(heap, (path.export_version, idx)) + keepers = [paths[i] for _, i in heapq.nlargest(n, heap)] + return sorted(keepers) + + return keep + + +def one_of_every_n_export_versions(n): + """Creates a filter that keeps one of every n export versions. + + Args: + n: interval size. + + Returns: + A filter function that keeps exactly one path from each interval + [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an + interval the largest is kept. + """ + def keep(paths): + """A filter function that keeps exactly one out of every n paths.""" + + keeper_map = {} # map from interval to largest path seen in that interval + for p in paths: + if p.export_version is None: + # Skip missing export_versions. + continue + # Find the interval (with a special case to map export_version = 0 to + # interval 0. + interval = math.floor( + (p.export_version - 1) / n) if p.export_version else 0 + existing = keeper_map.get(interval, None) + if (not existing) or (existing.export_version < p.export_version): + keeper_map[interval] = p + return sorted(keeper_map.values()) + + return keep + + +def mod_export_version(n): + """Creates a filter that keeps every export that is a multiple of n. + + Args: + n: step size. + + Returns: + A filter function that keeps paths where export_version % n == 0. + """ + def keep(paths): + keepers = [] + for p in paths: + if p.export_version % n == 0: + keepers.append(p) + return sorted(keepers) + return keep + + +def union(lf, rf): + """Creates a filter that keeps the union of two filters. + + Args: + lf: first filter + rf: second filter + + Returns: + A filter function that keeps the n largest paths. + """ + def keep(paths): + l = set(lf(paths)) + r = set(rf(paths)) + return sorted(list(l|r)) + return keep + + +def negation(f): + """Negate a filter. + + Args: + f: filter function to invert + + Returns: + A filter function that returns the negation of f. + """ + def keep(paths): + l = set(paths) + r = set(f(paths)) + return sorted(list(l-r)) + return keep + + +def get_paths(base_dir, parser): + """Gets a list of Paths in a given directory. + + Args: + base_dir: directory. + parser: a function which gets the raw Path and can augment it with + information such as the export_version, or ignore the path by returning + None. An example parser may extract the export version from a path + such as "/tmp/exports/100" an another may extract from a full file + name such as "/tmp/checkpoint-99.out". + + Returns: + A list of Paths contained in the base directory with the parsing function + applied. + By default the following fields are populated, + - Path.path + The parsing function is responsible for populating, + - Path.export_version + """ + raw_paths = gfile.ListDirectory(base_dir) + paths = [] + for r in raw_paths: + p = parser(Path(os.path.join(base_dir, r), None)) + if p: + paths.append(p) + return sorted(paths) diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py new file mode 100644 index 0000000000..dbe3304f21 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -0,0 +1,120 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for learn.utils.gc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re + +from six.moves import xrange # pylint: disable=redefined-builtin + +import tensorflow as tf + +from tensorflow.contrib.learn.python.learn.utils import gc +from tensorflow.python.framework import test_util +from tensorflow.python.platform import gfile + + +def tearDownModule(): + gfile.DeleteRecursively(tf.test.get_temp_dir()) + + +class GcTest(test_util.TensorFlowTestCase): + + def testLargestExportVersions(self): + paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)] + newest = gc.largest_export_versions(2) + n = newest(paths) + self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)]) + + def testLargestExportVersionsDoesNotDeleteZeroFolder(self): + paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)] + newest = gc.largest_export_versions(2) + n = newest(paths) + self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)]) + + def testModExportVersion(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.mod_export_version(2) + self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)]) + mod = gc.mod_export_version(3) + self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)]) + + def testOneOfEveryNExportVersions(self): + paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3), + gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 33)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 3), gc.Path("/foo", 6), + gc.Path("/foo", 8), gc.Path("/foo", 33)]) + + def testOneOfEveryNExportVersionsZero(self): + # Zero is a special case since it gets rolled into the first interval. + # Test that here. + paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 0), gc.Path("/foo", 5)]) + + def testUnion(self): + paths = [] + for i in xrange(10): + paths.append(gc.Path("/foo", i)) + f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) + self.assertEquals( + f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3), + gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 9)]) + + def testNegation(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.negation(gc.mod_export_version(2)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)]) + mod = gc.negation(gc.mod_export_version(3)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)]) + + def testPathsWithParse(self): + base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse") + self.assertFalse(gfile.Exists(base_dir)) + for p in xrange(3): + gfile.MakeDirs(os.path.join(base_dir, "%d" % p)) + # add a base_directory to ignore + gfile.MakeDirs(os.path.join(base_dir, "ignore")) + + # create a simple parser that pulls the export_version from the directory. + def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + self.assertEquals( + gc.get_paths(base_dir, parser=parser), + [gc.Path(os.path.join(base_dir, "0"), 0), + gc.Path(os.path.join(base_dir, "1"), 1), + gc.Path(os.path.join(base_dir, "2"), 2)]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py new file mode 100644 index 0000000000..2cb7173d5a --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -0,0 +1,97 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities for creating input_fns.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops + + +# A return type allowing input_fns to return multiple values in a well- +# defined way (analogous to ModelFnOps). +# The expected return values are: +# features: a dict of string to Tensor, giving the features to be passed to +# the model. +# labels: a dict of string to Tensor, giving labels (aka targets) for training. +# default_inputs: a dict of string to Tensor, giving the input Tensors (if +# any) that this input_fn expects to be fed. +InputFnOps = collections.namedtuple('InputFnOps', + ['features', + 'labels', + 'default_inputs']) + + +def build_parsing_serving_input_fn(feature_spec, default_batch_size=1): + """Build an input_fn appropriate for serving, expecting fed tf.Examples. + + Creates an input_fn that expects a serialized tf.Example fed into a string + placeholder. The function parses the tf.Example according to the provided + feature_spec, and returns all parsed Tensors as features. This input_fn is + for use at serving time, so the labels return value is always None. + + Args: + feature_spec: a dict of string to `VarLenFeature`/`FixedLenFeature`. + default_batch_size: the number of query examples expected per batch. + + Returns: + An input_fn suitable for use in serving. + """ + def input_fn(): + """An input_fn that expects a serialized tf.Example.""" + serialized_tf_example = array_ops.placeholder(dtype=dtypes.string, + shape=[default_batch_size], + name='input_example_tensor') + inputs = {'examples': serialized_tf_example} + features = parsing_ops.parse_example(serialized_tf_example, feature_spec) + labels = None # these are not known in serving! + return InputFnOps(features, labels, inputs) + return input_fn + + +def build_default_serving_input_fn(features, default_batch_size=1): + """Build an input_fn appropriate for serving, expecting feature Tensors. + + Creates an input_fn that expects all features to be fed directly. + This input_fn is for use at serving time, so the labels return value is always + None. + + Args: + features: a dict of string to `Tensor`. + default_batch_size: the number of query examples expected per batch. + + Returns: + An input_fn suitable for use in serving. + """ + def input_fn(): + """an input_fn that expects all features to be fed directly.""" + features_placeholders = {} + for name, t in features.items(): + shape_list = t.get_shape().as_list() + shape_list[0] = default_batch_size + shape = tensor_shape.TensorShape(shape_list) + + features_placeholders[name] = array_ops.placeholder(dtype=t.dtype, + shape=shape, + name=t.name) + labels = None # these are not known in serving! + return InputFnOps(features_placeholders, labels, features_placeholders) + return input_fn diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py new file mode 100644 index 0000000000..54bb0fb3d7 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -0,0 +1,248 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities supporting export to SavedModel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os +import re +import time + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.contrib.learn.python.learn.utils import gc +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.python.platform import gfile +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils + +from tensorflow.python.util import compat + +# A key for use in the input_alternatives dict indicating the default input. +# This is the input that will be expected when a serving request does not +# specify a specific signature. +# The default input alternative specifies placeholders that the input_fn +# requires to be fed (in the typical case, a single placeholder for a +# serialized tf.Example). +DEFAULT_INPUT_ALTERNATIVE_KEY = 'default_input_alternative' + +# A key for use in the input_alternatives dict indicating the features input. +# The features inputs alternative specifies the feature Tensors provided as +# input to the model_fn, i.e. the outputs of the input_fn. +FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative' + +# A key for use in the output_alternatives dict indicating the default output. +# This is the output that will be provided when a serving request does not +# specify a specific signature. +# In a single-headed model, the single output is automatically the default. +# In a multi-headed model, the name of the desired default head should be +# provided to get_output_alternatives. +DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative' + + +def build_standardized_signature_def( + input_tensors, output_tensors, problem_type): + """Build a SignatureDef using problem type and input and output Tensors. + + Note that this delegates the actual creation of the signatures to methods in + //third_party/tensorflow/python/saved_model/signature_def_utils.py, which may + assign names to the input and output tensors (depending on the problem type) + that are standardized in the context of SavedModel. + + Args: + input_tensors: a dict of string key to `Tensor` + output_tensors: a dict of string key to `Tensor` + problem_type: an instance of constants.ProblemType, specifying + classification, regression, etc. + + Returns: + A SignatureDef using SavedModel standard keys where possible. + + Raises: + ValueError: if input_tensors or output_tensors is None or empty. + """ + + if not input_tensors: + raise ValueError('input_tensors must be provided.') + if not output_tensors: + raise ValueError('output_tensors must be provided.') + + # Per-method signature_def functions will standardize the keys if possible + if _is_classification_problem(problem_type, input_tensors, output_tensors): + (_, examples), = input_tensors.items() + classes = output_tensors.get(prediction_key.PredictionKey.CLASSES) + scores = output_tensors.get(prediction_key.PredictionKey.SCORES) + if not (classes or scores): + (_, classes), = output_tensors.items() + return signature_def_utils.classification_signature_def( + examples, classes, scores) + elif _is_regression_problem(problem_type, input_tensors, output_tensors): + (_, examples), = input_tensors.items() + (_, predictions), = output_tensors.items() + return signature_def_utils.regression_signature_def(examples, predictions) + else: + return signature_def_utils.predict_signature_def( + input_tensors, output_tensors) + + +def _is_classification_problem(problem_type, input_tensors, output_tensors): + classes = output_tensors.get(prediction_key.PredictionKey.CLASSES) + scores = output_tensors.get(prediction_key.PredictionKey.SCORES) + return ((problem_type == constants.ProblemType.CLASSIFICATION or + problem_type == constants.ProblemType.LOGISTIC_REGRESSION) + and len(input_tensors) == 1 + and (classes or scores or len(output_tensors) == 1)) + + +def _is_regression_problem(problem_type, input_tensors, output_tensors): + return (problem_type == constants.ProblemType.LINEAR_REGRESSION + and len(input_tensors) == 1 + and len(output_tensors) == 1) + + +def get_input_alternatives(input_ops): + """Obtain all input alternatives using the input_fn output and heuristics.""" + input_alternatives = {} + if isinstance(input_ops, input_fn_utils.InputFnOps): + features, unused_labels, default_inputs = input_ops + input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] = default_inputs + else: + features, unused_labels = input_ops + + if not features: + raise ValueError('Features must be defined.') + + # Add the "features" input_signature in any case. + # Note defensive copy because model_fns alter the features dict. + input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = ( + copy.copy(features)) + + return input_alternatives, features + + +def get_output_alternatives( + model_fn_ops, + default_output_alternative_key=DEFAULT_OUTPUT_ALTERNATIVE_KEY): + """Obtain all output alternatives using the model_fn output and heuristics.""" + output_alternatives = model_fn_ops.output_alternatives + + # Identify the default outputs, creating them if needed. + if (output_alternatives + and default_output_alternative_key not in output_alternatives): + raise ValueError('default_output_alternative_key not in ' + 'output_alternatives: %s' % default_output_alternative_key) + + if (output_alternatives + and default_output_alternative_key in output_alternatives): + # If a default head is provided, use it. + actual_default_output_alternative_key = default_output_alternative_key + return output_alternatives, actual_default_output_alternative_key + + if output_alternatives and len(output_alternatives) == 1: + # If there is only one head, use it as the default. + (actual_default_output_alternative_key, _), = output_alternatives.items() + return output_alternatives, actual_default_output_alternative_key + + # Lacking provided output alternatives, the best we can do is to + # interpret the model as single-headed of unknown type. + default_problem_type = constants.ProblemType.UNSPECIFIED + default_outputs = model_fn_ops.predictions + actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY + output_alternatives = {actual_default_output_alternative_key: + (default_problem_type, default_outputs)} + return output_alternatives, actual_default_output_alternative_key + + +def build_all_signature_defs(input_alternatives, output_alternatives, + actual_default_output_alternative_key): + """Build `SignatureDef`s from all pairs of input and output alternatives.""" + + signature_def_map = { + ('%s:%s' % (input_key, output_key or 'None')): + build_standardized_signature_def( + inputs, outputs, problem_type) + for input_key, inputs in input_alternatives.items() + for output_key, (problem_type, outputs) + in output_alternatives.items()} + + # Add the default SignatureDef + default_inputs = input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] + if not default_inputs: + default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] + # default outputs are guaranteed to exist above + (default_problem_type, default_outputs) = ( + output_alternatives[actual_default_output_alternative_key]) + signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = ( + build_standardized_signature_def( + default_inputs, default_outputs, default_problem_type)) + + return signature_def_map + + +def get_timestamped_export_dir(export_dir_base): + """Builds a path to a new subdirectory within the base directory. + + Each export is written into a new subdirectory named using the + current time. This guarantees monotonically increasing version + numbers even across multiple runs of the pipeline. + The timestamp used is the number of milliseconds since epoch UTC. + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + Returns: + The full path of the new subdirectory (which is not actually created yet). + """ + export_timestamp = int(time.time() * 1e3) + + export_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(str(export_timestamp))) + return export_dir + + +def garbage_collect_exports(export_dir_base, exports_to_keep): + """Deletes older exports, retaining only a given number of the most recent. + + Export subdirectories are assumed to be named with monotonically increasing + integers; the most recent are taken to be those with the largest values. + + Args: + export_dir_base: the base directory under which each export is in a + versioned subdirectory. + exports_to_keep: the number of recent exports to retain. + """ + if exports_to_keep is None: + return + + keep_filter = gc.largest_export_versions(exports_to_keep) + delete_filter = gc.negation(keep_filter) + + # Export dir must not end with / or it will break the re match below. + if export_dir_base.endswith('/'): + export_dir_base = export_dir_base[:-1] + + # create a simple parser that pulls the export_version from the directory. + def parser(path): + match = re.match('^' + export_dir_base + '/(\\d{13})$', path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + for p in delete_filter(gc.get_paths(export_dir_base, parser=parser)): + gfile.DeleteRecursively(p.path) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py new file mode 100644 index 0000000000..538e0ab104 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -0,0 +1,228 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests of utilities supporting export to SavedModel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import time + +import tensorflow as tf + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils + + +class SavedModelExportUtilsTest(tf.test.TestCase): + + def test_build_standardized_signature_def(self): + input_tensors = { + "input-1": tf.placeholder(tf.float32, 1, name="input-tensor-1")} + output_tensors = { + "output-1": tf.placeholder(tf.float32, 1, name="output-tensor-1")} + problem_type = constants.ProblemType.LINEAR_REGRESSION + regression_signature_def = ( + saved_model_export_utils.build_standardized_signature_def( + input_tensors, output_tensors, problem_type)) + expected_regression_signature_def = meta_graph_pb2.SignatureDef() + shape = tensor_shape_pb2.TensorShapeProto( + dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) + dtype = types_pb2.DataType.Value("DT_FLOAT") + expected_regression_signature_def.inputs[ + signature_constants.REGRESS_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo(name="input-tensor-1:0", + dtype=dtype, + tensor_shape=shape)) + expected_regression_signature_def.outputs[ + signature_constants.REGRESS_OUTPUTS].CopyFrom( + meta_graph_pb2.TensorInfo(name="output-tensor-1:0", + dtype=dtype, + tensor_shape=shape)) + + expected_regression_signature_def.method_name = ( + signature_constants.REGRESS_METHOD_NAME) + self.assertEqual(regression_signature_def, + expected_regression_signature_def) + + def test_get_input_alternatives(self): + input_ops = input_fn_utils.InputFnOps("bogus features dict", None, + "bogus default input dict") + + input_alternatives, _ = saved_model_export_utils.get_input_alternatives( + input_ops) + self.assertEqual( + input_alternatives[ + saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY], + "bogus default input dict") + self.assertEqual( + input_alternatives[ + saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY], + "bogus features dict") + + def test_get_output_alternatives_explicit(self): + provided_output_alternatives = { + "head-1": (constants.ProblemType.LINEAR_REGRESSION, + "bogus output dict"), + "head-2": (constants.ProblemType.CLASSIFICATION, + "bogus output dict 2"), + "head-3": (constants.ProblemType.UNSPECIFIED, + "bogus output dict 3"), + } + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions={"some_output": "bogus_tensor"}, + output_alternatives=provided_output_alternatives) + output_alternatives, _ = saved_model_export_utils.get_output_alternatives( + model_fn_ops, "head-1") + + self.assertEqual(provided_output_alternatives, output_alternatives) + + def test_get_output_alternatives_implicit(self): + prediction_tensor = tf.constant(["bogus"]) + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions={"some_output": prediction_tensor}, + output_alternatives=None) + + output_alternatives, _ = saved_model_export_utils.get_output_alternatives( + model_fn_ops, "some_output") + self.assertEqual( + {"default_output_alternative": (constants.ProblemType.UNSPECIFIED, + {"some_output": prediction_tensor})}, + output_alternatives) + + def test_build_all_signature_defs(self): + input_features = tf.constant(["10"]) + input_example = tf.constant(["11"]) + input_ops = input_fn_utils.InputFnOps( + {"features": input_features}, + None, + {"default input": input_example}) + input_alternatives, _ = ( + saved_model_export_utils.get_input_alternatives(input_ops)) + output_1 = tf.constant(["1"]) + output_2 = tf.constant(["2"]) + output_3 = tf.constant(["3"]) + provided_output_alternatives = { + "head-1": (constants.ProblemType.LINEAR_REGRESSION, + {"some_output_1": output_1}), + "head-2": (constants.ProblemType.CLASSIFICATION, + {"some_output_2": output_2}), + "head-3": (constants.ProblemType.UNSPECIFIED, + {"some_output_3": output_3}), + } + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions={"some_output": tf.constant(["4"])}, + output_alternatives=provided_output_alternatives) + output_alternatives, _ = ( + saved_model_export_utils.get_output_alternatives(model_fn_ops, + "head-1")) + + signature_defs = saved_model_export_utils.build_all_signature_defs( + input_alternatives, output_alternatives, "head-1") + + expected_signature_defs = { + "serving_default": + signature_def_utils.regression_signature_def( + input_example, output_1), + "default_input_alternative:head-1": + signature_def_utils.regression_signature_def( + input_example, output_1), + "default_input_alternative:head-2": + signature_def_utils.classification_signature_def( + input_example, output_2, None), + "default_input_alternative:head-3": + signature_def_utils.predict_signature_def( + {"input": input_example}, {"output": output_3}), + "features_input_alternative:head-1": + signature_def_utils.regression_signature_def( + input_features, output_1), + "features_input_alternative:head-2": + signature_def_utils.classification_signature_def( + input_features, output_2, None), + "features_input_alternative:head-3": + signature_def_utils.predict_signature_def( + {"input": input_features}, {"output": output_3}), + } + + self.assertDictEqual(expected_signature_defs, signature_defs) + + def test_get_timestamped_export_dir(self): + export_dir_base = tempfile.mkdtemp() + "export/" + export_dir_1 = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + time.sleep(0.001) + export_dir_2 = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + time.sleep(0.001) + export_dir_3 = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + + # Export directories should be named using a timestamp that is milliseconds + # since epoch. Such a timestamp is 13 digits long. + time_1 = os.path.basename(export_dir_1) + self.assertEqual(13, len(time_1)) + time_2 = os.path.basename(export_dir_2) + self.assertEqual(13, len(time_2)) + time_3 = os.path.basename(export_dir_3) + self.assertEqual(13, len(time_3)) + + self.assertTrue(int(time_1) < int(time_2)) + self.assertTrue(int(time_2) < int(time_3)) + + def test_garbage_collect_exports(self): + export_dir_base = tempfile.mkdtemp() + "export/" + tf.gfile.MkDir(export_dir_base) + export_dir_1 = _create_test_export_dir(export_dir_base) + export_dir_2 = _create_test_export_dir(export_dir_base) + export_dir_3 = _create_test_export_dir(export_dir_base) + export_dir_4 = _create_test_export_dir(export_dir_base) + + self.assertTrue(tf.gfile.Exists(export_dir_1)) + self.assertTrue(tf.gfile.Exists(export_dir_2)) + self.assertTrue(tf.gfile.Exists(export_dir_3)) + self.assertTrue(tf.gfile.Exists(export_dir_4)) + + # Garbage collect all but the most recent 2 exports, + # where recency is determined based on the timestamp directory names. + saved_model_export_utils.garbage_collect_exports(export_dir_base, 2) + + self.assertFalse(tf.gfile.Exists(export_dir_1)) + self.assertFalse(tf.gfile.Exists(export_dir_2)) + self.assertTrue(tf.gfile.Exists(export_dir_3)) + self.assertTrue(tf.gfile.Exists(export_dir_4)) + + +def _create_test_export_dir(export_dir_base): + export_dir = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + tf.gfile.MkDir(export_dir) + time.sleep(0.001) + return export_dir + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index a2f012d349..e3ed248dd5 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -24,7 +24,7 @@ cuda_py_tests( cuda_py_tests( name = "linear_operator_diag_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/linear_operator_diag_test.py"], additional_deps = [ ":linalg_py", diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py index 98eac39683..d03fb1d66f 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py @@ -26,50 +26,18 @@ linalg = tf.contrib.linalg tf.set_random_seed(23) -class LinearOperatorDiagtest( - linear_operator_test_util.LinearOperatorDerivedClassTest): +class LinearOperatorDiagTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - @property - def _dtypes_to_test(self): - return [tf.float32, tf.float64] - - @property - def _shapes_to_test(self): - # non-batch operators (n, n) and batch operators. - return [(0, 0), (1, 1), (1, 3, 3), (3, 2, 2), (2, 1, 3, 3)] - - def _make_rhs(self, operator): - # This operator is square, so rhs and x will have same shape. - return self._make_x(operator) - - def _make_x(self, operator): - # Return the number of systems to solve, R, equal to 1 or 2. - r = self._get_num_systems(operator) - # If operator.shape = [B1,...,Bb, N, N] this returns a random matrix of - # shape [B1,...,Bb, N, R], R = 1 or 2. - if operator.shape.is_fully_defined(): - batch_shape = operator.batch_shape.as_list() - n = operator.domain_dimension.value - rhs_shape = batch_shape + [n, r] - else: - batch_shape = operator.batch_shape_dynamic() - n = operator.domain_dimension_dynamic() - rhs_shape = tf.concat(0, (batch_shape, [n, r])) - return tf.random_normal(shape=rhs_shape, dtype=operator.dtype) - - def _get_num_systems(self, operator): - """Get some number, either 1 or 2, depending on operator.""" - if operator.tensor_rank is None or operator.tensor_rank % 2: - return 1 - else: - return 2 - def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): shape = list(shape) diag_shape = shape[:-1] - diag = tf.random_normal(diag_shape, dtype=dtype) + diag = tf.random_normal(diag_shape, dtype=dtype.real_dtype) + if dtype.is_complex: + diag = tf.complex( + diag, tf.random_normal(diag_shape, dtype=dtype.real_dtype)) diag_ph = tf.placeholder(dtype=dtype) if use_placeholder: @@ -87,15 +55,32 @@ class LinearOperatorDiagtest( return operator, mat, feed_dict - def test_assert_positive_definite(self): - # Singlular matrix with one positive eigenvalue and one zero eigenvalue. + def test_assert_positive_definite_raises_for_zero_eigenvalue(self): + # Matrix with one positive eigenvalue and one zero eigenvalue. + with self.test_session(): + diag = [1.0, 0.0] + operator = linalg.LinearOperatorDiag(diag) + with self.assertRaisesOpError("non-positive.*not positive definite"): + operator.assert_positive_definite().run() + + def test_assert_positive_definite_raises_for_negative_real_eigvalues(self): with self.test_session(): - diag = [1.0, -1.0] + diag_x = [1.0, -2.0] + diag_y = [0., 0.] # Imaginary eigenvalues should not matter. + diag = tf.complex(diag_x, diag_y) operator = linalg.LinearOperatorDiag(diag) - with self.assertRaisesOpError("was not positive definite"): + with self.assertRaisesOpError("non-positive real.*not positive definite"): operator.assert_positive_definite().run() - def test_assert_non_singular(self): + def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self): + with self.test_session(): + x = [1., 2.] + y = [1., 0.] + diag = tf.complex(x, y) # Re[diag] > 0. + # Should not fail + linalg.LinearOperatorDiag(diag).assert_positive_definite().run() + + def test_assert_non_singular_raises_if_zero_eigenvalue(self): # Singlular matrix with one positive eigenvalue and one zero eigenvalue. with self.test_session(): diag = [1.0, 0.0] @@ -103,10 +88,36 @@ class LinearOperatorDiagtest( with self.assertRaisesOpError("Singular operator"): operator.assert_non_singular().run() + def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self): + with self.test_session(): + x = [1., 0.] + y = [0., 1.] + diag = tf.complex(x, y) + # Should not raise. + linalg.LinearOperatorDiag(diag).assert_non_singular().run() + + def test_assert_self_adjoint_raises_if_diag_has_complex_part(self): + with self.test_session(): + x = [1., 0.] + y = [0., 1.] + diag = tf.complex(x, y) + operator = linalg.LinearOperatorDiag(diag) + with self.assertRaisesOpError("imaginary.*not self-adjoint"): + operator.assert_self_adjoint().run() + + def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self): + with self.test_session(): + x = [1., 0.] + y = [0., 0.] + diag = tf.complex(x, y) + operator = linalg.LinearOperatorDiag(diag) + # Should not raise + operator.assert_self_adjoint().run() + def test_broadcast_apply_and_solve(self): # These cannot be done in the automated (base test class) tests since they - # test shapes that tf.batch_matmul cannot handle. - # In particular, tf.batch_matmul does not broadcast. + # test shapes that tf.matmul cannot handle. + # In particular, tf.matmul does not broadcast. with self.test_session() as sess: x = tf.random_normal(shape=(2, 2, 3, 4)) @@ -122,7 +133,7 @@ class LinearOperatorDiagtest( self.assertAllEqual((2, 2, 3, 3), mat.get_shape()) # being pedantic. operator_apply = operator.apply(x) - mat_apply = tf.batch_matmul(mat, x) + mat_apply = tf.matmul(mat, x) self.assertAllEqual(operator_apply.get_shape(), mat_apply.get_shape()) self.assertAllClose(*sess.run([operator_apply, mat_apply])) diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py index eb279177ab..4228903388 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py @@ -16,9 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np import tensorflow as tf linalg = tf.contrib.linalg +rng = np.random.RandomState(123) class LinearOperatorShape(linalg.LinearOperator): @@ -44,6 +46,31 @@ class LinearOperatorShape(linalg.LinearOperator): return tf.constant(self._stored_shape, dtype=tf.int32) +class LinearOperatorApplyOnly(linalg.LinearOperator): + """LinearOperator that simply wraps a [batch] matrix and implements apply.""" + + def __init__(self, + matrix, + is_non_singular=None, + is_self_adjoint=None, + is_positive_definite=None): + self._matrix = tf.convert_to_tensor(matrix, name="matrix") + super(LinearOperatorApplyOnly, self).__init__( + dtype=matrix.dtype, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite,) + + def _shape(self): + return self._matrix.get_shape() + + def _shape_dynamic(self): + return tf.shape(self._matrix) + + def _apply(self, x, adjoint=False): + return tf.matmul(self._matrix, x, adjoint_a=adjoint) + + class LinearOperatorTest(tf.test.TestCase): def test_all_shape_properties_defined_by_the_one_property_shape(self): @@ -78,6 +105,23 @@ class LinearOperatorTest(tf.test.TestCase): self.assertTrue(operator.is_self_adjoint) self.assertFalse(operator.is_positive_definite) + def test_generic_to_dense_method_non_square_matrix_static(self): + matrix = rng.randn(2, 3, 4) + operator = LinearOperatorApplyOnly(matrix) + with self.test_session(): + operator_dense = operator.to_dense() + self.assertAllEqual((2, 3, 4), operator_dense.get_shape()) + self.assertAllClose(matrix, operator_dense.eval()) + + def test_generic_to_dense_method_non_square_matrix_dynamic(self): + matrix = rng.randn(2, 3, 4) + matrix_ph = tf.placeholder(tf.float64) + operator = LinearOperatorApplyOnly(matrix_ph) + with self.test_session(): + operator_dense = operator.to_dense() + self.assertAllClose( + matrix, operator_dense.eval(feed_dict={matrix_ph: matrix})) + -if __name__ == '__main__': +if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py index d5aa3fdf25..6199518af0 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py @@ -23,6 +23,7 @@ import contextlib from tensorflow.contrib import framework as contrib_framework from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops __all__ = ["LinearOperator"] @@ -114,6 +115,19 @@ class LinearOperator(object): ### Performance FILL THIS IN + + ### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint` etc... + These have the following meaning + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. """ def __init__(self, @@ -123,21 +137,11 @@ class LinearOperator(object): is_self_adjoint=None, is_positive_definite=None, name=None): - """Initialize the `LinearOperator`. + r"""Initialize the `LinearOperator`. **This is a private method for subclass use.** **Subclasses should copy-paste this `__init__` documentation.** - For `X = non_singular, self_adjoint` etc... - `is_X` is a Python `bool` initialization argument with the following meaning - * If `is_X == True`, callers should expect the operator to have the - attribute `X`. This is a promise that should be fulfilled, but is *not* a - runtime assert. Issues, such as floating point error, could mean the - operator violates this promise. - * If `is_X == False`, callers should expect the operator to not have `X`. - * If `is_X == None` (the default), callers should have no expectation either - way. - Args: dtype: The type of the this `LinearOperator`. Arguments to `apply` and `solve` will have to be this type. @@ -146,18 +150,21 @@ class LinearOperator(object): is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. If `dtype` is real, this is equivalent to being symmetric. - is_positive_definite: Expect that this operator is positive definite. - name: A name for this `LinearOperator`. Default: subclass name. + is_positive_definite: Expect that this operator is positive definite, + meaning the real part of all eigenvalues is positive. We do not require + the operator to be self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ + #Extension_for_non_symmetric_matrices + name: A name for this `LinearOperator`. Raises: ValueError: if any member of graph_parents is `None` or not a `Tensor`. """ - if is_positive_definite and not is_self_adjoint: - raise ValueError( - "A positive definite matrix is by definition self adjoint") - if is_positive_definite and not is_non_singular: - raise ValueError( - "A positive definite matrix is by definition non-singular") + # Check and auto-set flags. + if is_positive_definite: + if is_non_singular is False: + raise ValueError("A positive definite matrix is always non-singular.") + is_non_singular = True graph_parents = [] if graph_parents is None else graph_parents for i, t in enumerate(graph_parents): @@ -384,10 +391,28 @@ class LinearOperator(object): raise NotImplementedError("assert_positive_definite is not implemented.") def assert_positive_definite(self, name="assert_positive_definite"): - """Returns an `Op` that asserts this operator is positive definite.""" + """Returns an `Op` that asserts this operator is positive definite. + + Here, positive definite means the real part of all eigenvalues is positive. + We do not require the operator to be self-adjoint. + + Args: + name: A name to give this `Op`. + + Returns: + An `Op` that asserts this operator is positive definite. + """ with self._name_scope(name): return self._assert_positive_definite() + def _assert_self_adjoint(self): + raise NotImplementedError("assert_self_adjoint is not implemented.") + + def assert_self_adjoint(self, name="assert_self_adjoint"): + """Returns an `Op` that asserts this operator is self-adjoint.""" + with self._name_scope(name): + return self._assert_self_adjoint() + def _apply(self, x, adjoint=False): raise NotImplementedError("_apply is not implemented.") @@ -485,9 +510,38 @@ class LinearOperator(object): return self._solve(rhs, adjoint=adjoint) def _to_dense(self): - raise NotImplementedError("_to_dense is not implemented.") + """Generic and often inefficient implementation. Override often.""" + if self.batch_shape.is_fully_defined(): + batch_shape = self.batch_shape + else: + batch_shape = self.batch_shape_dynamic() + + if self.domain_dimension.value is not None: + n = self.domain_dimension.value + else: + n = self.domain_dimension_dynamic() + + eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype) + return self.apply(eye) def to_dense(self, name="to_dense"): """Return a dense (batch) matrix representing this operator.""" with self._name_scope(name): return self._to_dense() + + def _add_to_tensor(self, x): + raise NotImplementedError("_add_to_tensor is not implemented.") + + def add_to_tensor(self, x, name="add_to_tensor"): + """Add matrix represented by this operator to `x`. Equivalent to `A + x`. + + Args: + x: `Tensor` with same `dtype` and shape broadcastable to `self.shape`. + name: A name to give this `Op`. + + Returns: + A `Tensor` with broadcast shape and same `dtype` as `self`. + """ + with self._name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + return self._add_to_tensor(x) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py index 6f1c769758..f65ed9a6c8 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.linalg.python.ops import linear_operator +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -34,7 +35,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): This operator acts like a [batch] matrix `A` with shape `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is - an `m x n` matrix. Again, this matrix `A` may not be materialized, but for + an `N x N` matrix. This matrix `A` is not materialized, but for purposes of broadcasting this shape will be relevant. `LinearOperatorDiag` is initialized with a (batch) vector. @@ -48,7 +49,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): ==> [[1., 0.] [0., -1.]] - operator.shape() + operator.shape ==> [2, 2] operator.log_determinant() @@ -83,7 +84,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): ### Performance - Suppose `operator` is a `LinearOperatorDiag` is of shape `[N, N]`, + Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`, and `x.shape = [N, R]`. Then * `operator.apply(x)` involves `N*R` multiplications. @@ -92,6 +93,19 @@ class LinearOperatorDiag(linear_operator.LinearOperator): If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`. + + ### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint` etc... + These have the following meaning + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. """ def __init__(self, @@ -102,44 +116,45 @@ class LinearOperatorDiag(linear_operator.LinearOperator): name="LinearOperatorDiag"): """Initialize a `LinearOperatorDiag`. - For `X = non_singular, self_adjoint` etc... - `is_X` is a Python `bool` initialization argument with the following meaning - * If `is_X == True`, callers should expect the operator to have the - attribute `X`. This is a promise that should be fulfilled, but is *not* a - runtime assert. Issues, such as floating point error, could mean the - operator violates this promise. - * If `is_X == False`, callers should expect the operator to not have `X`. - * If `is_X == None` (the default), callers should have no expectation either - way. - Args: - diag: Shape `[B1,...,Bb, N]` real float type `Tensor` with `b >= 0`, - `N >= 0`. The diagonal of the operator. + diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. + The diagonal of the operator. Allowed dtypes: `float32`, `float64`, + `complex64`, `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. Since this is a real (not complex) diagonal operator, it is always self adjoint. - is_positive_definite: Expect that this operator is positive definite. - name: A name for this `LinearOperator`. Default: subclass name. + is_positive_definite: Expect that this operator is positive definite, + meaning the real part of all eigenvalues is positive. We do not require + the operator to be self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix + #Extension_for_non_symmetric_matrices + name: A name for this `LinearOperator`. Raises: - ValueError: If `diag.dtype` is not floating point. + TypeError: If `diag.dtype` is not an allowed type. ValueError: If `is_self_adjoint` is not `True`. """ + allowed_dtypes = [ + dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + with ops.name_scope(name, values=[diag]): self._diag = ops.convert_to_tensor(diag, name="diag") - if not self._diag.dtype.is_floating: - raise ValueError("Only real floating point matrices are supported.") - if not is_self_adjoint: - raise ValueError("A real diagonal matrix is always self adjoint.") + dtype = self._diag.dtype + if dtype not in allowed_dtypes: + raise TypeError( + "Argument diag must have dtype in %s. Found: %s" + % (allowed_dtypes, dtype)) + if dtype.is_floating and not is_self_adjoint: + raise ValueError("A real diagonal operator is always self adjoint.") super(LinearOperatorDiag, self).__init__( - dtype=self._diag.dtype, + dtype=dtype, graph_parents=[self._diag], is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, - is_positive_definite=is_non_singular, + is_positive_definite=is_positive_definite, name=name) def _shape(self): @@ -153,20 +168,42 @@ class LinearOperatorDiag(linear_operator.LinearOperator): return array_ops.concat(0, (d_shape, [k])) def _assert_non_singular(self): + if self.dtype.is_complex: + should_be_nonzero = math_ops.complex_abs(self._diag) + else: + should_be_nonzero = self._diag + nonzero_diag = math_ops.reduce_all( - math_ops.logical_not(math_ops.equal(self._diag, 0))) + math_ops.logical_not(math_ops.equal(should_be_nonzero, 0))) + return control_flow_ops.Assert( nonzero_diag, data=["Singular operator: diag contained zero values.", self._diag]) def _assert_positive_definite(self): + if self.dtype.is_complex: + message = ( + "Diagonal operator had diagonal entries with non-positive real part, " + "thus was not positive definite.") + else: + message = ( + "Real diagonal operator had non-positive diagonal entries, " + "thus was not positive definite.") + return check_ops.assert_positive( + math_ops.real(self._diag), + message=message) + + def _assert_self_adjoint(self): + return _assert_imag_part_zero( self._diag, - message="Operator was not positive definite: diag was not all positive") + message=( + "This diagonal operator contained non-zero imaginary values. " + " Thus it was not self-adjoint.")) def _apply(self, x, adjoint=False): - # adjoint has no effect since this matrix is self-adjoint. - diag_mat = array_ops.expand_dims(self._diag, -1) + diag_term = math_ops.conj(self._diag) if adjoint else self._diag + diag_mat = array_ops.expand_dims(diag_term, -1) return diag_mat * x def _determinant(self): @@ -177,9 +214,29 @@ class LinearOperatorDiag(linear_operator.LinearOperator): math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1]) def _solve(self, rhs, adjoint=False): - # adjoint has no effect since this matrix is self-adjoint. - inv_diag_mat = array_ops.expand_dims(1. / self._diag, -1) + diag_term = math_ops.conj(self._diag) if adjoint else self._diag + inv_diag_mat = array_ops.expand_dims(1. / diag_term, -1) return rhs * inv_diag_mat def _to_dense(self): return array_ops.matrix_diag(self._diag) + + def _add_to_tensor(self, x): + x_diag = array_ops.matrix_diag_part(x) + new_diag = self._diag + x_diag + return array_ops.matrix_set_diag(x, new_diag) + + +def _assert_imag_part_zero(x, message=None): + """Assert that floating or complex 'x' is real.""" + dtype = x.dtype.base_dtype + if dtype.is_floating: + return control_flow_ops.no_op() + + if not dtype.is_complex: + raise TypeError( + "imag_part_zero only handles float or complex types. Found: %s" + % dtype) + + zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype) + return check_ops.assert_equal(zero, math_ops.imag(x), message=message) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py index adbdb9b3d2..20136bfbd0 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py @@ -31,10 +31,25 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase): test methods to work. """ - @abc.abstractproperty + # Absolute/relative tolerance for tests. + _atol = { + tf.float16: 1e-3, tf.float32: 1e-6, tf.float64: 1e-12, tf.complex64: 1e-6, + tf.complex128: 1e-12} + _rtol = { + tf.float16: 1e-3, tf.float32: 1e-6, tf.float64: 1e-12, tf.complex64: 1e-6, + tf.complex128: 1e-12} + + def assertAC(self, x, y): + """Derived classes can set _atol, _rtol to get different tolerance.""" + dtype = tf.as_dtype(x.dtype) + atol = self._atol[dtype] + rtol = self._rtol[dtype] + self.assertAllClose(x, y, atol=atol, rtol=rtol) + + @property def _dtypes_to_test(self): - """Returns list of numpy or tensorflow dtypes. Each will be tested.""" - raise NotImplementedError("dtypes_to_test has not been implemented.") + # TODO(langmore) Test tf.float16 once tf.matrix_diag works in 16bit. + return [tf.float32, tf.float64, tf.complex64, tf.complex128] @abc.abstractproperty def _shapes_to_test(self): @@ -57,8 +72,9 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase): Returns: operator: `LinearOperator` subclass instance. mat: `Tensor` representing operator. - feed_dict: Dictionary. If placholder is True, this must be fed to - sess.run calls at runtime to make the operator work. + feed_dict: Dictionary. + If placholder is True, this must contains everything needed to be fed + to sess.run calls at runtime to make the operator work. """ # Create a matrix as a numpy array with desired shape/dtype. # Create a LinearOperator that should have the same behavior as the matrix. @@ -74,107 +90,145 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase): """Make a rhs appropriate for calling operator.apply(rhs).""" raise NotImplementedError("_make_x is not defined.") - def _maybe_adjoint(self, x, adjoint): - if adjoint: - return tf.matrix_transpose(x) - else: - return x + @property + def _tests_to_skip(self): + """List of test names to skip.""" + # Subclasses should over-ride if they want to skip some tests. + # To skip "test_foo", add "foo" to this list. + return [] + + def _maybe_skip(self, test_name): + if test_name in self._tests_to_skip: + self.skipTest("%s skipped because it was added to self._tests_to_skip.") def test_to_dense(self): + self._maybe_skip("to_dense") with self.test_session() as sess: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - operator, mat, _ = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=False) - op_dense = operator.to_dense() - self.assertAllEqual(shape, op_dense.get_shape()) - op_dense_v, mat_v = sess.run([op_dense, mat]) - self.assertAllClose(op_dense_v, mat_v) - - def test_to_dense_dynamic(self): - with self.test_session() as sess: - for shape in self._shapes_to_test: - for dtype in self._dtypes_to_test: - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=True) - op_dense_v, mat_v = sess.run( - [operator.to_dense(), mat], feed_dict=feed_dict) - self.assertAllClose(op_dense_v, mat_v) + for use_placeholder in False, True: + operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + shape, dtype, use_placeholder=use_placeholder) + op_dense = operator.to_dense() + if not use_placeholder: + self.assertAllEqual(shape, op_dense.get_shape()) + op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict) + self.assertAC(op_dense_v, mat_v) def test_det(self): + self._maybe_skip("det") with self.test_session() as sess: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - operator, mat, _ = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=False) - op_det = operator.determinant() - self.assertAllEqual(shape[:-2], op_det.get_shape()) - op_det_v, mat_det_v = sess.run([op_det, tf.matrix_determinant(mat)]) - self.assertAllClose(op_det_v, mat_det_v) - - def test_det_dynamic(self): - with self.test_session() as sess: - for shape in self._shapes_to_test: - for dtype in self._dtypes_to_test: - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=True) - op_det_v, mat_det_v = sess.run( - [operator.determinant(), tf.matrix_determinant(mat)], - feed_dict=feed_dict) - self.assertAllClose(op_det_v, mat_det_v) + if dtype.is_complex: + self.skipTest( + "tf.matrix_determinant does not work with complex, so this test" + " is being skipped.") + for use_placeholder in False, True: + operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + shape, dtype, use_placeholder=use_placeholder) + op_det = operator.determinant() + if not use_placeholder: + self.assertAllEqual(shape[:-2], op_det.get_shape()) + op_det_v, mat_det_v = sess.run( + [op_det, tf.matrix_determinant(mat)], feed_dict=feed_dict) + self.assertAC(op_det_v, mat_det_v) def test_apply(self): + self._maybe_skip("apply") with self.test_session() as sess: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - operator, mat, _ = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=False) - for adjoint in [False, True]: - if adjoint and operator.is_self_adjoint: - continue - x = self._make_x(operator) - op_apply = operator.apply(x, adjoint=adjoint) - mat_apply = tf.batch_matmul(self._maybe_adjoint(mat, adjoint), x) - self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape()) - op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply]) - self.assertAllClose(op_apply_v, mat_apply_v) - - def test_apply_dynamic(self): - with self.test_session() as sess: - for shape in self._shapes_to_test: - for dtype in self._dtypes_to_test: - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=True) - x = self._make_x(operator) - op_apply_v, mat_apply_v = sess.run( - [operator.apply(x), tf.batch_matmul(mat, x)], - feed_dict=feed_dict) - self.assertAllClose(op_apply_v, mat_apply_v) + for use_placeholder in False, True: + for adjoint in [False, True]: + operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + shape, dtype, use_placeholder=use_placeholder) + x = self._make_x(operator) + op_apply = operator.apply(x, adjoint=adjoint) + mat_apply = tf.matmul(mat, x, adjoint_a=adjoint) + if not use_placeholder: + self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape()) + op_apply_v, mat_apply_v = sess.run( + [op_apply, mat_apply], feed_dict=feed_dict) + self.assertAC(op_apply_v, mat_apply_v) def test_solve(self): + self._maybe_skip("solve") with self.test_session() as sess: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - operator, mat, _ = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=False) - for adjoint in [False, True]: - if adjoint and operator.is_self_adjoint: - continue - rhs = self._make_rhs(operator) - op_solve = operator.solve(rhs, adjoint=adjoint) - mat_solve = tf.matrix_solve(self._maybe_adjoint(mat, adjoint), rhs) - self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape()) - op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve]) - self.assertAllClose(op_solve_v, mat_solve_v) - - def test_solve_dynamic(self): + for use_placeholder in False, True: + for adjoint in [False, True]: + operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + shape, dtype, use_placeholder=use_placeholder) + rhs = self._make_rhs(operator) + op_solve = operator.solve(rhs, adjoint=adjoint) + mat_solve = tf.matrix_solve(mat, rhs, adjoint=adjoint) + if not use_placeholder: + self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape()) + op_solve_v, mat_solve_v = sess.run( + [op_solve, mat_solve], feed_dict=feed_dict) + self.assertAC(op_solve_v, mat_solve_v) + + def test_add_to_tensor(self): + self._maybe_skip("add_to_tensor") with self.test_session() as sess: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( - shape, dtype, use_placeholder=True) - rhs = self._make_rhs(operator) - op_solve_v, mat_solve_v = sess.run( - [operator.solve(rhs), tf.matrix_solve(mat, rhs)], - feed_dict=feed_dict) - self.assertAllClose(op_solve_v, mat_solve_v) + for use_placeholder in False, True: + operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + shape, dtype, use_placeholder=use_placeholder) + op_plus_2mat = operator.add_to_tensor(2 * mat) + + if not use_placeholder: + self.assertAllEqual(shape, op_plus_2mat.get_shape()) + + op_plus_2mat_v, mat_v = sess.run( + [op_plus_2mat, mat], feed_dict=feed_dict) + + self.assertAC(op_plus_2mat_v, 3 * mat_v) + + +@six.add_metaclass(abc.ABCMeta) +class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest): + """Base test class appropriate for square operators. + + Sub-classes must still define all abstractmethods from + LinearOperatorDerivedClassTest that are not defined here. + """ + + @property + def _shapes_to_test(self): + # non-batch operators (n, n) and batch operators. + return [(0, 0), (1, 1), (1, 3, 3), (3, 4, 4), (2, 1, 4, 4)] + + def _make_rhs(self, operator): + # This operator is square, so rhs and x will have same shape. + return self._make_x(operator) + + def _make_x(self, operator): + # Return the number of systems to solve, R, equal to 1 or 2. + r = self._get_num_systems(operator) + # If operator.shape = [B1,...,Bb, N, N] this returns a random matrix of + # shape [B1,...,Bb, N, R], R = 1 or 2. + if operator.shape.is_fully_defined(): + batch_shape = operator.batch_shape.as_list() + n = operator.domain_dimension.value + rhs_shape = batch_shape + [n, r] + else: + batch_shape = operator.batch_shape_dynamic() + n = operator.domain_dimension_dynamic() + rhs_shape = tf.concat(0, (batch_shape, [n, r])) + + x = tf.random_normal(shape=rhs_shape, dtype=operator.dtype.real_dtype) + if operator.dtype.is_complex: + x = tf.complex( + x, tf.random_normal(shape=rhs_shape, dtype=operator.dtype.real_dtype)) + return x + + def _get_num_systems(self, operator): + """Get some number, either 1 or 2, depending on operator.""" + if operator.tensor_rank is None or operator.tensor_rank % 2: + return 1 + else: + return 2 diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index c607707bbb..3b68c51413 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -17,7 +17,8 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/lookup:lookup_py", + ":sharded_mutable_dense_hashtable_py", + ":sparse_feature_column_py", ], ) @@ -34,6 +35,47 @@ py_test( ], ) +py_library( + name = "sharded_mutable_dense_hashtable_py", + srcs = ["python/ops/sharded_mutable_dense_hashtable.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/lookup:lookup_py", + ], +) + +py_test( + name = "sharded_mutable_dense_hashtable_test", + size = "small", + srcs = ["python/ops/sharded_mutable_dense_hashtable_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":sharded_mutable_dense_hashtable_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "sparse_feature_column_py", + srcs = ["python/ops/sparse_feature_column.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "sparse_feature_column_test", + size = "small", + srcs = ["python/ops/sparse_feature_column_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":sparse_feature_column_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/linear_optimizer/__init__.py b/tensorflow/contrib/linear_optimizer/__init__.py index 40445c456f..83bd8b5fcf 100644 --- a/tensorflow/contrib/linear_optimizer/__init__.py +++ b/tensorflow/contrib/linear_optimizer/__init__.py @@ -23,5 +23,5 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel -from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SparseFeatureColumn +from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn from tensorflow.contrib.linear_optimizer.python.sdca_optimizer import SDCAOptimizer diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 40a6404881..8e918e8880 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -22,9 +22,8 @@ from threading import Thread import tensorflow as tf -from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import _ShardedMutableDenseHashTable from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel -from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SparseFeatureColumn +from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn from tensorflow.python.framework.test_util import TensorFlowTestCase from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.platform import googletest @@ -980,27 +979,6 @@ class SdcaWithSmoothHingeLossTest(SdcaModelTest): self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02) -class SparseFeatureColumnTest(SdcaModelTest): - """Tests for SparseFeatureColumn. - """ - - def testBasic(self): - expected_example_indices = [1, 1, 1, 2] - expected_feature_indices = [0, 1, 2, 0] - sfc = SparseFeatureColumn(expected_example_indices, - expected_feature_indices, None) - self.assertTrue(isinstance(sfc.example_indices, tf.Tensor)) - self.assertTrue(isinstance(sfc.feature_indices, tf.Tensor)) - self.assertEqual(sfc.feature_values, None) - with self._single_threaded_test_session(): - self.assertAllEqual(expected_example_indices, sfc.example_indices.eval()) - self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval()) - expected_feature_values = [1.0, 2.0, 3.0, 4.0] - sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0], - expected_feature_values) - with self._single_threaded_test_session(): - self.assertAllEqual(expected_feature_values, sfc.feature_values.eval()) - class SdcaFprintTest(SdcaModelTest): """Tests for the SdcaFprint op. @@ -1020,74 +998,5 @@ class SdcaFprintTest(SdcaModelTest): [603227410218889250, 8762207001949257490]], out_data.eval()) - -class ShardedMutableDenseHashTableTest(SdcaModelTest): - """Tests for the _ShardedMutableHashTable class.""" - - def testShardedMutableHashTable(self): - for num_shards in [1, 3, 10]: - with self._single_threaded_test_session(): - default_val = -1 - empty_key = 0 - keys = tf.constant([11, 12, 13], tf.int64) - values = tf.constant([0, 1, 2], tf.int64) - table = _ShardedMutableDenseHashTable( - tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = tf.constant([11, 12, 14], tf.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - self.assertAllEqual([0, 1, -1], output.eval()) - - def testShardedMutableHashTableVectors(self): - for num_shards in [1, 3, 10]: - with self._single_threaded_test_session(): - default_val = [-0.1, 0.2] - empty_key = [0, 1] - keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64) - values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32) - table = _ShardedMutableDenseHashTable( - tf.int64, tf.float32, default_val, empty_key, num_shards=num_shards) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64) - output = table.lookup(input_string) - self.assertAllEqual([3, 2], output.get_shape()) - self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]], - output.eval()) - - def testExportSharded(self): - with self._single_threaded_test_session(): - empty_key = -2 - default_val = -1 - num_shards = 2 - keys = tf.constant([10, 11, 12], tf.int64) - values = tf.constant([2, 3, 4], tf.int64) - table = _ShardedMutableDenseHashTable( - tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - keys_list, values_list = table.export_sharded() - self.assertAllEqual(num_shards, len(keys_list)) - self.assertAllEqual(num_shards, len(values_list)) - - # Exported keys include empty key buckets set to the empty_key - self.assertAllEqual(set([-2, 10, 12]), set(keys_list[0].eval().flatten())) - self.assertAllEqual(set([-2, 11]), set(keys_list[1].eval().flatten())) - # Exported values include empty value buckets set to 0 - self.assertAllEqual(set([0, 2, 4]), set(values_list[0].eval().flatten())) - self.assertAllEqual(set([0, 3]), set(values_list[1].eval().flatten())) - - if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 7143520e3f..415aa752ac 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,16 +22,14 @@ import collections from six.moves import range -from tensorflow.contrib.lookup import lookup_ops +from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable from tensorflow.python import summary from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -41,233 +39,6 @@ from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits __all__ = ['SdcaModel'] -class _ShardedMutableDenseHashTable(lookup_ops.LookupInterface): - """A sharded version of MutableDenseHashTable. - - It is designed to be interface compatible with LookupInterface and - MutableDenseHashTable, with the exception of the export method, which is - replaced by an export_sharded method. - - The _ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable - internally. The shard is computed via the modulo operation on the key. - """ - - # TODO(andreasst): consider moving this to lookup_ops - - def __init__(self, - key_dtype, - value_dtype, - default_value, - empty_key, - num_shards=1, - name='ShardedMutableHashTable'): - with ops.name_scope(name, 'sharded_mutable_hash_table') as scope: - super(_ShardedMutableDenseHashTable, self).__init__(key_dtype, - value_dtype, scope) - table_shards = [] - for i in range(num_shards): - table_shards.append( - lookup_ops.MutableDenseHashTable( - key_dtype=key_dtype, - value_dtype=value_dtype, - default_value=default_value, - empty_key=empty_key, - name='%s-%d-of-%d' % (name, i + 1, num_shards))) - self._table_shards = table_shards - # TODO(andreasst): add a value_shape() method to LookupInterface - # pylint: disable=protected-access - self._value_shape = self._table_shards[0]._value_shape - # pylint: enable=protected-access - - @property - def _num_shards(self): - return len(self._table_shards) - - @property - def table_shards(self): - return self._table_shards - - def size(self, name=None): - with ops.name_scope(name, 'sharded_mutable_hash_table_size'): - sizes = [ - self._table_shards[i].size() for i in range(self._num_shards) - ] - return math_ops.add_n(sizes) - - def _shard_indices(self, keys): - key_shape = keys.get_shape() - if key_shape.ndims > 1: - # If keys are a matrix (i.e. a single key is a vector), we use the first - # element of each key vector to determine the shard. - keys = array_ops.slice(keys, [0, 0], [key_shape[0].value, 1]) - keys = array_ops.reshape(keys, [-1]) - indices = math_ops.mod(math_ops.abs(keys), self._num_shards) - return math_ops.cast(indices, dtypes.int32) - - def _check_keys(self, keys): - if not keys.get_shape().is_fully_defined(): - raise ValueError('Key shape must be fully defined, got %s.' % - keys.get_shape()) - if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2: - raise ValueError('Expected a vector or matrix for keys, got %s.' % - keys.get_shape()) - - def lookup(self, keys, name=None): - if keys.dtype != self._key_dtype: - raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' % - (self._key_dtype, keys.dtype)) - self._check_keys(keys) - num_shards = self._num_shards - if num_shards == 1: - return self._table_shards[0].lookup(keys, name=name) - - shard_indices = self._shard_indices(keys) - # TODO(andreasst): support 'keys' that are not vectors - key_shards = data_flow_ops.dynamic_partition(keys, shard_indices, - num_shards) - value_shards = [ - self._table_shards[i].lookup(key_shards[i], name=name) - for i in range(num_shards) - ] - - num_keys = keys.get_shape().dims[0] - original_indices = math_ops.range(num_keys) - partitioned_indices = data_flow_ops.dynamic_partition(original_indices, - shard_indices, - num_shards) - result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards) - result.set_shape( - tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape)) - return result - - def insert(self, keys, values, name=None): - self._check_keys(keys) - num_shards = self._num_shards - if num_shards == 1: - return self._table_shards[0].insert(keys, values, name=name) - - shard_indices = self._shard_indices(keys) - # TODO(andreasst): support 'keys' that are not vectors - key_shards = data_flow_ops.dynamic_partition(keys, shard_indices, - num_shards) - value_shards = data_flow_ops.dynamic_partition(values, shard_indices, - num_shards) - return_values = [ - self._table_shards[i].insert(key_shards[i], value_shards[i], name=name) - for i in range(num_shards) - ] - - return control_flow_ops.group(*return_values) - - def export_sharded(self, name=None): - """Returns lists of the keys and values tensors in the sharded table. - - Returns: - A pair of lists with the first list containing the key tensors and the - second list containing the value tensors from each shard. - """ - keys_list = [] - values_list = [] - for table_shard in self._table_shards: - exported_keys, exported_values = table_shard.export(name=name) - keys_list.append(exported_keys) - values_list.append(exported_values) - return keys_list, values_list - - -class SparseFeatureColumn(object): - """Represents a sparse feature column. - - Contains three tensors representing a sparse feature column, they are - example indices (int64), feature indices (int64), and feature values (float). - Feature weights are optional, and are treated as 1.0f if missing. - - For example, consider a batch of 4 examples, which contains the following - features in a particular SparseFeatureColumn: - Example 0: feature 5, value 1 - Example 1: feature 6, value 1 and feature 10, value 0.5 - Example 2: no features - Example 3: two copies of feature 2, value 1 - - This SparseFeatureColumn will be represented as follows: - <0, 5, 1> - <1, 6, 1> - <1, 10, 0.5> - <3, 2, 1> - <3, 2, 1> - - For a batch of 2 examples below: - Example 0: feature 5 - Example 1: feature 6 - - is represented by SparseFeatureColumn as: - <0, 5, 1> - <1, 6, 1> - - ``` - - @@__init__ - @@example_indices - @@feature_indices - @@feature_values - """ - - def __init__(self, example_indices, feature_indices, feature_values): - """Creates a `SparseFeatureColumn` representation. - - Args: - example_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts - python lists, or numpy arrays. - feature_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts - python lists, or numpy arrays. - feature_values: An optional 1-D tensor float tensor of shape `[N]`. Also, - accepts python lists, or numpy arrays. - - Returns: - A `SparseFeatureColumn` - """ - with name_scope(None, 'SparseFeatureColumn', - [example_indices, feature_indices]): - self._example_indices = internal_convert_to_tensor(example_indices, - name='example_indices', - dtype=dtypes.int64) - self._feature_indices = internal_convert_to_tensor(feature_indices, - name='feature_indices', - dtype=dtypes.int64) - self._feature_values = None - if feature_values is not None: - with name_scope(None, 'SparseFeatureColumn', [feature_values]): - self._feature_values = internal_convert_to_tensor(feature_values, - name='feature_values', - dtype=dtypes.float32) - - @property - def example_indices(self): - """The example indices represented as a dense tensor. - - Returns: - A 1-D Tensor of int64 with shape `[N]`. - """ - return self._example_indices - - @property - def feature_indices(self): - """The feature indices represented as a dense tensor. - - Returns: - A 1-D Tensor of int64 with shape `[N]`. - """ - return self._feature_indices - - @property - def feature_values(self): - """The feature values represented as a dense tensor. - - Returns: - May return None, or a 1-D Tensor of float32 with shape `[N]`. - """ - return self._feature_values - # TODO(sibyl-Aix6ihai): add name_scope to appropriate methods. class SdcaModel(object): @@ -372,7 +143,7 @@ class SdcaModel(object): self._variables = variables self._options = options self._create_slots() - self._hashtable = _ShardedMutableDenseHashTable( + self._hashtable = ShardedMutableDenseHashTable( key_dtype=dtypes.int64, value_dtype=dtypes.float32, num_shards=self._num_table_shards(), diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py new file mode 100644 index 0000000000..494dfb6c99 --- /dev/null +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -0,0 +1,167 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sharded mutable dense hash table.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import range + +from tensorflow.contrib.lookup import lookup_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import math_ops + + +class ShardedMutableDenseHashTable(lookup_ops.LookupInterface): + """A sharded version of MutableDenseHashTable. + + It is designed to be interface compatible with LookupInterface and + MutableDenseHashTable, with the exception of the export method, which is + replaced by an export_sharded method. + + The _ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable + internally. The shard is computed via the modulo operation on the key. + """ + + # TODO(andreasst): consider moving this to lookup_ops + + def __init__(self, + key_dtype, + value_dtype, + default_value, + empty_key, + num_shards=1, + name='ShardedMutableHashTable'): + with ops.name_scope(name, 'sharded_mutable_hash_table') as scope: + super(ShardedMutableDenseHashTable, self).__init__(key_dtype, + value_dtype, scope) + table_shards = [] + for i in range(num_shards): + table_shards.append( + lookup_ops.MutableDenseHashTable( + key_dtype=key_dtype, + value_dtype=value_dtype, + default_value=default_value, + empty_key=empty_key, + name='%s-%d-of-%d' % (name, i + 1, num_shards))) + self._table_shards = table_shards + # TODO(andreasst): add a value_shape() method to LookupInterface + # pylint: disable=protected-access + self._value_shape = self._table_shards[0]._value_shape + # pylint: enable=protected-access + + @property + def _num_shards(self): + return len(self._table_shards) + + @property + def table_shards(self): + return self._table_shards + + def size(self, name=None): + with ops.name_scope(name, 'sharded_mutable_hash_table_size'): + sizes = [ + self._table_shards[i].size() for i in range(self._num_shards) + ] + return math_ops.add_n(sizes) + + def _shard_indices(self, keys): + key_shape = keys.get_shape() + if key_shape.ndims > 1: + # If keys are a matrix (i.e. a single key is a vector), we use the first + # element of each key vector to determine the shard. + keys = array_ops.slice(keys, [0, 0], [key_shape[0].value, 1]) + keys = array_ops.reshape(keys, [-1]) + indices = math_ops.mod(math_ops.abs(keys), self._num_shards) + return math_ops.cast(indices, dtypes.int32) + + def _check_keys(self, keys): + if not keys.get_shape().is_fully_defined(): + raise ValueError('Key shape must be fully defined, got %s.' % + keys.get_shape()) + if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2: + raise ValueError('Expected a vector or matrix for keys, got %s.' % + keys.get_shape()) + + def lookup(self, keys, name=None): + if keys.dtype != self._key_dtype: + raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' % + (self._key_dtype, keys.dtype)) + self._check_keys(keys) + num_shards = self._num_shards + if num_shards == 1: + return self._table_shards[0].lookup(keys, name=name) + + shard_indices = self._shard_indices(keys) + # TODO(andreasst): support 'keys' that are not vectors + key_shards = data_flow_ops.dynamic_partition(keys, shard_indices, + num_shards) + value_shards = [ + self._table_shards[i].lookup(key_shards[i], name=name) + for i in range(num_shards) + ] + + num_keys = keys.get_shape().dims[0] + original_indices = math_ops.range(num_keys) + partitioned_indices = data_flow_ops.dynamic_partition(original_indices, + shard_indices, + num_shards) + result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards) + result.set_shape( + tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape)) + return result + + def insert(self, keys, values, name=None): + self._check_keys(keys) + num_shards = self._num_shards + if num_shards == 1: + return self._table_shards[0].insert(keys, values, name=name) + + shard_indices = self._shard_indices(keys) + # TODO(andreasst): support 'keys' that are not vectors + key_shards = data_flow_ops.dynamic_partition(keys, shard_indices, + num_shards) + value_shards = data_flow_ops.dynamic_partition(values, shard_indices, + num_shards) + return_values = [ + self._table_shards[i].insert(key_shards[i], value_shards[i], name=name) + for i in range(num_shards) + ] + + return control_flow_ops.group(*return_values) + + def export_sharded(self, name=None): + """Returns lists of the keys and values tensors in the sharded table. + + Args: + name: name of the table. + + Returns: + A pair of lists with the first list containing the key tensors and the + second list containing the value tensors from each shard. + """ + keys_list = [] + values_list = [] + for table_shard in self._table_shards: + exported_keys, exported_values = table_shard.export(name=name) + keys_list.append(exported_keys) + values_list.append(exported_values) + return keys_list, values_list diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py new file mode 100644 index 0000000000..8c83700d51 --- /dev/null +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -0,0 +1,97 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sharded_mutable_dense_hashtable.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.platform import googletest + + +class ShardedMutableDenseHashTableTest(TensorFlowTestCase): + """Tests for the ShardedMutableHashTable class.""" + + def testShardedMutableHashTable(self): + for num_shards in [1, 3, 10]: + with self.test_session(): + default_val = -1 + empty_key = 0 + keys = tf.constant([11, 12, 13], tf.int64) + values = tf.constant([0, 1, 2], tf.int64) + table = ShardedMutableDenseHashTable( + tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards) + self.assertAllEqual(0, table.size().eval()) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + input_string = tf.constant([11, 12, 14], tf.int64) + output = table.lookup(input_string) + self.assertAllEqual([3], output.get_shape()) + self.assertAllEqual([0, 1, -1], output.eval()) + + def testShardedMutableHashTableVectors(self): + for num_shards in [1, 3, 10]: + with self.test_session(): + default_val = [-0.1, 0.2] + empty_key = [0, 1] + keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64) + values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32) + table = ShardedMutableDenseHashTable( + tf.int64, tf.float32, default_val, empty_key, num_shards=num_shards) + self.assertAllEqual(0, table.size().eval()) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + input_string = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64) + output = table.lookup(input_string) + self.assertAllEqual([3, 2], output.get_shape()) + self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]], + output.eval()) + + def testExportSharded(self): + with self.test_session(): + empty_key = -2 + default_val = -1 + num_shards = 2 + keys = tf.constant([10, 11, 12], tf.int64) + values = tf.constant([2, 3, 4], tf.int64) + table = ShardedMutableDenseHashTable( + tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards) + self.assertAllEqual(0, table.size().eval()) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + keys_list, values_list = table.export_sharded() + self.assertAllEqual(num_shards, len(keys_list)) + self.assertAllEqual(num_shards, len(values_list)) + + # Exported keys include empty key buckets set to the empty_key + self.assertAllEqual(set([-2, 10, 12]), set(keys_list[0].eval().flatten())) + self.assertAllEqual(set([-2, 11]), set(keys_list[1].eval().flatten())) + # Exported values include empty value buckets set to 0 + self.assertAllEqual(set([0, 2, 4]), set(values_list[0].eval().flatten())) + self.assertAllEqual(set([0, 3]), set(values_list[1].eval().flatten())) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py new file mode 100644 index 0000000000..ed7105b5c9 --- /dev/null +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py @@ -0,0 +1,114 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sparse feature column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework.ops import internal_convert_to_tensor +from tensorflow.python.framework.ops import name_scope + + +class SparseFeatureColumn(object): + """Represents a sparse feature column. + + Contains three tensors representing a sparse feature column, they are + example indices (int64), feature indices (int64), and feature values (float). + Feature weights are optional, and are treated as 1.0f if missing. + + For example, consider a batch of 4 examples, which contains the following + features in a particular SparseFeatureColumn: + Example 0: feature 5, value 1 + Example 1: feature 6, value 1 and feature 10, value 0.5 + Example 2: no features + Example 3: two copies of feature 2, value 1 + + This SparseFeatureColumn will be represented as follows: + <0, 5, 1> + <1, 6, 1> + <1, 10, 0.5> + <3, 2, 1> + <3, 2, 1> + + For a batch of 2 examples below: + Example 0: feature 5 + Example 1: feature 6 + + is represented by SparseFeatureColumn as: + <0, 5, 1> + <1, 6, 1> + + ``` + + @@__init__ + @@example_indices + @@feature_indices + @@feature_values + """ + + def __init__(self, example_indices, feature_indices, feature_values): + """Creates a `SparseFeatureColumn` representation. + + Args: + example_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts + python lists, or numpy arrays. + feature_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts + python lists, or numpy arrays. + feature_values: An optional 1-D tensor float tensor of shape `[N]`. Also, + accepts python lists, or numpy arrays. + + Returns: + A `SparseFeatureColumn` + """ + with name_scope(None, 'SparseFeatureColumn', + [example_indices, feature_indices]): + self._example_indices = internal_convert_to_tensor( + example_indices, name='example_indices', dtype=dtypes.int64) + self._feature_indices = internal_convert_to_tensor( + feature_indices, name='feature_indices', dtype=dtypes.int64) + self._feature_values = None + if feature_values is not None: + with name_scope(None, 'SparseFeatureColumn', [feature_values]): + self._feature_values = internal_convert_to_tensor( + feature_values, name='feature_values', dtype=dtypes.float32) + + @property + def example_indices(self): + """The example indices represented as a dense tensor. + + Returns: + A 1-D Tensor of int64 with shape `[N]`. + """ + return self._example_indices + + @property + def feature_indices(self): + """The feature indices represented as a dense tensor. + + Returns: + A 1-D Tensor of int64 with shape `[N]`. + """ + return self._feature_indices + + @property + def feature_values(self): + """The feature values represented as a dense tensor. + + Returns: + May return None, or a 1-D Tensor of float32 with shape `[N]`. + """ + return self._feature_values diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py new file mode 100644 index 0000000000..f2e4ca0c88 --- /dev/null +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py @@ -0,0 +1,51 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sparse_feature_column.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.platform import googletest + + +class SparseFeatureColumnTest(TensorFlowTestCase): + """Tests for SparseFeatureColumn. + """ + + def testBasic(self): + expected_example_indices = [1, 1, 1, 2] + expected_feature_indices = [0, 1, 2, 0] + sfc = SparseFeatureColumn(expected_example_indices, + expected_feature_indices, None) + self.assertTrue(isinstance(sfc.example_indices, tf.Tensor)) + self.assertTrue(isinstance(sfc.feature_indices, tf.Tensor)) + self.assertEqual(sfc.feature_values, None) + with self.test_session(): + self.assertAllEqual(expected_example_indices, sfc.example_indices.eval()) + self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval()) + expected_feature_values = [1.0, 2.0, 3.0, 4.0] + sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0], + expected_feature_values) + with self.test_session(): + self.assertAllEqual(expected_feature_values, sfc.feature_values.eval()) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 6ff4bf3175..644347f0b5 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -18,6 +18,7 @@ from __future__ import print_function from tensorflow.contrib import layers from tensorflow.contrib.linear_optimizer.python.ops import sdca_ops +from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -86,7 +87,7 @@ class SDCAOptimizer(object): sparse_values = array_ops.gather_nd(dense_tensor, sparse_indices) # TODO(sibyl-Aix6ihai, sibyl-vie3Poto): Makes this efficient, as now SDCA supports # very sparse features with weights and not weights. - return sdca_ops.SparseFeatureColumn( + return SparseFeatureColumn( array_ops.reshape( array_ops.split(1, 2, sparse_indices)[0], [-1]), array_ops.reshape( @@ -134,7 +135,7 @@ class SDCAOptimizer(object): columns_to_variables[column][0]) elif isinstance(column, (layers.feature_column._CrossedColumn, layers.feature_column._SparseColumn)): - sparse_features.append(sdca_ops.SparseFeatureColumn( + sparse_features.append(SparseFeatureColumn( array_ops.reshape( array_ops.split(1, 2, transformed_tensor.indices)[0], [-1]), array_ops.reshape(transformed_tensor.values, [-1]), None)) @@ -142,7 +143,7 @@ class SDCAOptimizer(object): elif isinstance(column, layers.feature_column._WeightedSparseColumn): id_tensor = column.id_tensor(transformed_tensor) weight_tensor = column.weight_tensor(transformed_tensor) - sparse_feature_with_values.append(sdca_ops.SparseFeatureColumn( + sparse_feature_with_values.append(SparseFeatureColumn( array_ops.reshape( array_ops.split(1, 2, id_tensor.indices)[0], [-1]), array_ops.reshape(id_tensor.values, [-1]), array_ops.reshape( diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 7610f9275f..c17b251d3e 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -117,9 +117,9 @@ def _safe_div(numerator, denominator, name="value"): Returns: The element-wise value of the numerator divided by the denominator. """ - return math_ops.select( + return array_ops.where( math_ops.greater(denominator, 0), - math_ops.div(numerator, math_ops.select( + math_ops.div(numerator, array_ops.where( math_ops.equal(denominator, 0), array_ops.ones_like(denominator), denominator)), array_ops.zeros_like(numerator), @@ -144,12 +144,13 @@ def _safe_mean(losses, num_present): @deprecated_args( "2016-11-25", "`weight` is being deprecated, use `weights`.", "weight") def compute_weighted_loss( - losses, weights=_WEIGHT_SENTINEL, weight=_WEIGHT_SENTINEL): + losses, weights=_WEIGHT_SENTINEL, scope=None, weight=_WEIGHT_SENTINEL): """Computes the weighted loss. Args: losses: A tensor of size [batch_size, d1, ... dN]. weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + scope: the scope for the operations performed in computing the loss. weight: Deprecated alias for `weights`. Returns: @@ -161,27 +162,28 @@ def compute_weighted_loss( `weights` is missing. """ weights = _weights(weights, weight) - losses = ops.convert_to_tensor(losses) - input_dtype = losses.dtype - losses = math_ops.to_float(losses) - weights = math_ops.to_float(ops.convert_to_tensor(weights)) + with ops.name_scope(scope, "weighted_loss", [losses, weights]): + losses = ops.convert_to_tensor(losses) + input_dtype = losses.dtype + losses = math_ops.to_float(losses) + weights = math_ops.to_float(ops.convert_to_tensor(weights)) - if losses.get_shape().ndims is None: - raise ValueError("losses.get_shape().ndims cannot be None") - weights_shape = weights.get_shape() - if weights_shape.ndims is None: - raise ValueError("weight.get_shape().ndims cannot be None") + if losses.get_shape().ndims is None: + raise ValueError("losses.get_shape().ndims cannot be None") + weights_shape = weights.get_shape() + if weights_shape.ndims is None: + raise ValueError("weight.get_shape().ndims cannot be None") - if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1): - weights = array_ops.squeeze(weights, [-1]) + if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1): + weights = array_ops.squeeze(weights, [-1]) - total_loss = _scale_losses(losses, weights) - num_present = _num_present(losses, weights) - mean_loss = _safe_mean(total_loss, num_present) - # convert the result back to the input type - mean_loss = math_ops.cast(mean_loss, input_dtype) - add_loss(mean_loss) - return mean_loss + total_loss = _scale_losses(losses, weights) + num_present = _num_present(losses, weights) + mean_loss = _safe_mean(total_loss, num_present) + # convert the result back to the input type + mean_loss = math_ops.cast(mean_loss, input_dtype) + add_loss(mean_loss) + return mean_loss def _num_present(losses, weights, per_batch=False): @@ -211,7 +213,7 @@ def _num_present(losses, weights, per_batch=False): [0], [1]), []) num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)), math_ops.to_float(batch_size)) - num_per_batch = math_ops.select(math_ops.equal(weights, 0), + num_per_batch = array_ops.where(math_ops.equal(weights, 0), 0.0, num_per_batch) num_per_batch = math_ops.mul(array_ops.ones( array_ops.reshape(batch_size, [1])), num_per_batch) @@ -334,7 +336,7 @@ def absolute_difference( predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) losses = math_ops.abs(math_ops.sub(predictions, labels)) - return compute_weighted_loss(losses, weights) + return compute_weighted_loss(losses, weights, scope=scope) @deprecated_args( @@ -373,7 +375,7 @@ def sigmoid_cross_entropy( """ weights = _weights(weights, weight) with ops.name_scope(scope, "sigmoid_cross_entropy_loss", - [logits, multi_class_labels, weights]): + [logits, multi_class_labels, weights]) as scope: logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape()) multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype) @@ -384,7 +386,7 @@ def sigmoid_cross_entropy( losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels, name="xentropy") - return compute_weighted_loss(losses, weights) + return compute_weighted_loss(losses, weights, scope=scope) @deprecated_args( @@ -421,7 +423,7 @@ def softmax_cross_entropy( """ weights = _weights(weights, weight) with ops.name_scope(scope, "softmax_cross_entropy_loss", - [logits, onehot_labels, weights]): + [logits, onehot_labels, weights]) as scope: logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape()) onehot_labels = math_ops.cast(onehot_labels, logits.dtype) @@ -435,7 +437,7 @@ def softmax_cross_entropy( losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels, name="xentropy") - return compute_weighted_loss(losses, weights) + return compute_weighted_loss(losses, weights, scope=scope) @deprecated_args( @@ -468,13 +470,13 @@ def sparse_softmax_cross_entropy( """ weights = _weights(weights, weight) with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss", - [logits, labels, weights]): + [logits, labels, weights]) as scope: labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]]) weights = array_ops.squeeze(weights) losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name="xentropy") - return compute_weighted_loss(losses, weights) + return compute_weighted_loss(losses, weights, scope=scope) @deprecated_args( @@ -523,7 +525,7 @@ def log_loss( labels, math_ops.log(predictions + epsilon)) - math_ops.mul( (1 - labels), math_ops.log(1 - predictions + epsilon)) - return compute_weighted_loss(losses, weights) + return compute_weighted_loss(losses, weights, scope=scope) @deprecated_args( @@ -597,7 +599,7 @@ def mean_squared_error( predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) losses = math_ops.square(math_ops.sub(predictions, labels)) - return compute_weighted_loss(losses, weights) + return compute_weighted_loss(losses, weights, scope=scope) @deprecated_args( @@ -681,7 +683,7 @@ def mean_pairwise_squared_error( loss = _scale_losses(term1 - term2, weights) - mean_loss = math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0, + mean_loss = array_ops.where(math_ops.reduce_sum(num_present_per_batch) > 0, loss, array_ops.zeros_like(loss), name="value") @@ -732,4 +734,4 @@ def cosine_distance( radial_diffs = math_ops.mul(predictions, labels) losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,]) - return compute_weighted_loss(losses, weights) + return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 380f2b440a..4c0d32f115 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -204,6 +204,7 @@ ifeq ($(TARGET),PI) endif # Set up Android building +# LINT.IfChange ifeq ($(TARGET),ANDROID) # Override NDK_ROOT on the command line with your own NDK location, e.g. # make -f tensorflow/contrib/makefile/Makefile TARGET=ANDROID \ @@ -276,6 +277,7 @@ ifeq ($(TARGET),ANDROID) endif endif # ANDROID +# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) # Settings for iOS. ifeq ($(TARGET),IOS) diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index d07d1508a3..d39dc1d430 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -75,6 +75,8 @@ tensorflow/core/kernels/padding_fifo_queue.cc tensorflow/core/kernels/pad_op.cc tensorflow/core/kernels/pack_op.cc tensorflow/core/kernels/ops_util.cc +tensorflow/core/kernels/one_hot_op.cc +tensorflow/core/kernels/non_max_suppression_op.cc tensorflow/core/kernels/no_op.cc tensorflow/core/kernels/mirror_pad_op.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 1199f8dd95..c6d6b50e90 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -52,7 +52,7 @@ def _safe_div(numerator, denominator, name): Returns: 0 if `denominator` <= 0, else `numerator` / `denominator` """ - return math_ops.select( + return array_ops.where( math_ops.greater(denominator, 0), math_ops.truediv(numerator, denominator), 0, @@ -587,7 +587,7 @@ def streaming_precision(predictions, labels, weights=None, updates_collections=None, name=None) def compute_precision(name): - return math_ops.select( + return array_ops.where( math_ops.greater(true_positives + false_positives, 0), math_ops.div(true_positives, true_positives + false_positives), 0, @@ -661,7 +661,7 @@ def streaming_recall(predictions, labels, weights=None, updates_collections=None, name=None) def compute_recall(true_positives, false_negatives, name): - return math_ops.select( + return array_ops.where( math_ops.greater(true_positives + false_negatives, 0), math_ops.div(true_positives, true_positives + false_negatives), 0, @@ -2388,7 +2388,7 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, predictions, normalizer = tensor_util.remove_squeezable_dimensions( predictions, normalizer) predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) - relative_errors = math_ops.select( + relative_errors = array_ops.where( math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels), math_ops.div(math_ops.abs(labels - predictions), normalizer)) @@ -2923,7 +2923,7 @@ def streaming_mean_iou(predictions, # If the value of the denominator is 0, set it to 1 to avoid # zero division. - denominator = math_ops.select( + denominator = array_ops.where( math_ops.greater(denominator, 0), denominator, array_ops.ones_like(denominator)) diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 9890e712c1..80ca709b3a 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -998,7 +999,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell): # pylint: disable=protected-access -_linear = rnn_cell._linear +_linear = rnn_cell_impl._linear # pylint: enable=protected-access diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 7750e54569..b6bfcc748a 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -254,13 +254,6 @@ cc_library( load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -filegroup( - name = "saved_model_half_plus_two", - srcs = glob([ - "testdata/saved_model_half_plus_two/**", - ]), -) - cc_library( name = "bundle_shim", srcs = ["bundle_shim.cc"], @@ -287,8 +280,8 @@ cc_test( size = "small", srcs = ["bundle_shim_test.cc"], data = [ - ":saved_model_half_plus_two", "//tensorflow/contrib/session_bundle/example:half_plus_two", + "//tensorflow/python/saved_model/example:saved_model_half_plus_two_data", ], # Link in all registered kernels. linkstatic = 1, diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc index 47a0935472..1ce2753c57 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim.cc @@ -127,10 +127,16 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures, AddOutputToSignatureDef(map_entry.second.tensor_name(), map_entry.first, &signature_def); } - // Add the `default` key to the signature def map of the meta graph def and - // map it to the constructed signature def. - (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = - signature_def; + // Add the constructed signature def to the signature def map of the meta + // graph def. Use the default key if it isn't already in use. + const bool already_has_default_signature = + meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) != + meta_graph_def->signature_def().end(); + const string signature_def_key = + already_has_default_signature + ? strings::StrCat(kDefaultServingSignatureDefKey, "_from_named") + : kDefaultServingSignatureDefKey; + (*meta_graph_def->mutable_signature_def())[signature_def_key] = signature_def; return Status::OK(); } @@ -138,9 +144,12 @@ Status ConvertSignaturesToSignatureDef(MetaGraphDef* meta_graph_def) { Signatures signatures; GetSignatures(*meta_graph_def, &signatures); if (signatures.has_default_signature()) { - return ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def); - } else if (!signatures.named_signatures().empty()) { - return ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def); + TF_RETURN_IF_ERROR( + ConvertDefaultSignatureToSignatureDef(signatures, meta_graph_def)); + } + if (!signatures.named_signatures().empty()) { + TF_RETURN_IF_ERROR( + ConvertNamedSignaturesToSignatureDef(signatures, meta_graph_def)); } return Status::OK(); } diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index cfdd05e608..a8dca12195 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -35,7 +35,7 @@ constexpr char kSessionBundlePath[] = constexpr char kSessionBundleMetaGraphFilename[] = "export.meta"; constexpr char kSessionBundleVariablesFilename[] = "export-00000-of-00001"; constexpr char kSavedModelBundlePath[] = - "contrib/session_bundle/testdata/saved_model_half_plus_two"; + "python/saved_model/example/saved_model_half_plus_two/00000123"; string MakeSerializedExample(float x) { tensorflow::Example example; @@ -72,16 +72,20 @@ void LoadAndValidateSavedModelBundle(const string& export_dir, session_options, run_options, export_dir, tags, &saved_model_bundle)); const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def; const auto& signature_def_map = meta_graph_def.signature_def(); - EXPECT_EQ(1, signature_def_map.size()); const auto& regression_entry = signature_def_map.find(signature_def_key); + ASSERT_FALSE(regression_entry == signature_def_map.end()); SignatureDef regression_signature_def = regression_entry->second; EXPECT_EQ(1, regression_signature_def.inputs_size()); + ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) == + regression_signature_def.inputs().end()); TensorInfo input_tensor_info = regression_signature_def.inputs().find(kRegressInputs)->second; EXPECT_EQ(1, regression_signature_def.outputs_size()); + ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) == + regression_signature_def.outputs().end()); TensorInfo output_tensor_info = regression_signature_def.outputs().find(kRegressOutputs)->second; ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(), @@ -261,9 +265,14 @@ TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) { EXPECT_EQ(1, meta_graph_def.signature_def_size()); const auto actual_signature_def = meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); + ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end()); + ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") == + actual_signature_def->second.inputs().end()); EXPECT_EQ( "foo-input", actual_signature_def->second.inputs().find("foo-input")->second.name()); + ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") == + actual_signature_def->second.outputs().end()); EXPECT_EQ( "foo-output", actual_signature_def->second.outputs().find("foo-output")->second.name()); @@ -318,10 +327,40 @@ TEST(BundleShimTest, NamedSignatureGenericOnlyInput) { // Checks a basic up conversion for half plus two for SessionBundle. TEST(BundleShimTest, BasicExportSessionBundle) { + const std::unordered_set<string> tags = {"tag"}; const string session_bundle_export_dir = test_util::TestSrcDirPath(kSessionBundlePath); - LoadAndValidateSavedModelBundle(session_bundle_export_dir, {"tag"}, + LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags, kDefaultServingSignatureDefKey); + + // Verify that the named signature is also present. + SessionOptions session_options; + RunOptions run_options; + SavedModelBundle saved_model_bundle; + TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options, + session_bundle_export_dir, + tags, &saved_model_bundle)); + const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def; + const auto& signature_def_map = meta_graph_def.signature_def(); + bool found_named_signature = false; + for (const auto& entry : signature_def_map) { + const string& key = entry.first; + const SignatureDef& signature_def = entry.second; + + // We're looking for the key that is *not* kDefaultServingSignatureDefKey. + if (key == kDefaultServingSignatureDefKey) { + continue; + } + found_named_signature = true; + + EXPECT_EQ(1, signature_def.inputs_size()); + EXPECT_FALSE(signature_def.inputs().find("x") == + signature_def.inputs().end()); + EXPECT_EQ(1, signature_def.outputs_size()); + EXPECT_FALSE(signature_def.outputs().find("y") == + signature_def.outputs().end()); + } + EXPECT_TRUE(found_named_signature); } // Checks a basic load for half plus two for SavedModelBundle. diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc index 2b608a1348..bc6fdcd4de 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.cc +++ b/tensorflow/contrib/session_bundle/session_bundle.cc @@ -43,12 +43,13 @@ namespace serving { namespace { auto* load_attempt_count = monitoring::Counter<2>::New( - "/tensorflow/contrib/session_bundle/load_attempt_count", "model_path", - "status", - "The number of times a SessionBundle was requested to be loaded."); + "/tensorflow/contrib/session_bundle/load_attempt_count", + "The number of times a SessionBundle was requested to be loaded.", + "model_path", "status"); auto* load_latency = monitoring::Counter<1>::New( - "/tensorflow/contrib/session_bundle/load_latency", "model_path", - "Latency in microseconds for SessionBundles that were successfully loaded."); + "/tensorflow/contrib/session_bundle/load_latency", + "Latency in microseconds for SessionBundles that were successfully loaded.", + "model_path"); constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; diff --git a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/saved_model.pb b/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/saved_model.pb Binary files differdeleted file mode 100644 index e894f9b101..0000000000 --- a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/saved_model.pb +++ /dev/null diff --git a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.data-00000-of-00001 b/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.data-00000-of-00001 Binary files differdeleted file mode 100644 index 20bc7d454d..0000000000 --- a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.data-00000-of-00001 +++ /dev/null diff --git a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index b/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index Binary files differdeleted file mode 100644 index e7df518f5b..0000000000 --- a/tensorflow/contrib/session_bundle/testdata/saved_model_half_plus_two/variables/variables.index +++ /dev/null diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py index b122799a32..b89eca46ea 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation.py +++ b/tensorflow/contrib/slim/python/slim/evaluation.py @@ -86,7 +86,7 @@ more summaries and call the evaluation_loop method: logdir, num_evals=num_evals, eval_op=names_to_updates.values(), - summary_op=tf.merge_summary(summary_ops), + summary_op=tf.contrib.deprecated.merge_summary(summary_ops), eval_interval_secs=600) ************************************************** @@ -113,7 +113,7 @@ with only summaries. The user need only leave out the 'eval_op' argument: checkpoint_dir, logdir, num_evals=1, - summary_op=tf.merge_summary(summary_ops), + summary_op=tf.contrib.deprecated.merge_summary(summary_ops), eval_interval_secs=600) """ @@ -122,160 +122,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time - -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.contrib.training.python.training import evaluation from tensorflow.python import summary -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver as tf_saver -from tensorflow.python.training import summary_io -from tensorflow.python.training import supervisor -from tensorflow.python.training import training_util +from tensorflow.python.training import monitored_session __all__ = [ 'evaluate_once', - 'evaluation', 'evaluation_loop', 'wait_for_new_checkpoint', 'checkpoints_iterator', ] - -def wait_for_new_checkpoint(checkpoint_dir, - last_checkpoint, - seconds_to_sleep=1, - timeout=None): - """Waits until a new checkpoint file is found. - - Args: - checkpoint_dir: The directory in which checkpoints are saved. - last_checkpoint: The last checkpoint path used. - seconds_to_sleep: The number of seconds to sleep for before looking for a - new checkpoint. - timeout: The maximum amount of time to wait. If left as `None`, then the - process will wait indefinitely. - - Returns: - a new checkpoint path, or None if the timeout was reached. - """ - logging.info('Waiting for new checkpoint at %s', checkpoint_dir) - stop_time = time.time() + timeout if timeout is not None else None - while True: - checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir) - if checkpoint_path is None or checkpoint_path == last_checkpoint: - if stop_time is not None and time.time() + seconds_to_sleep > stop_time: - return None - time.sleep(seconds_to_sleep) - else: - logging.info('Found new checkpoint at %s', checkpoint_path) - return checkpoint_path - - -def checkpoints_iterator(checkpoint_dir, - min_interval_secs=0, - timeout=None): - """Continuously yield new checkpoint files as they appear. - - The iterator only checks for new checkpoints when control flow has been - reverted to it. This means it can miss checkpoints if your code takes longer - to run between iterations than `min_interval_secs` or the interval at which - new checkpoints are written. - - Args: - checkpoint_dir: The directory in which checkpoints are saved. - min_interval_secs: The minimum number of seconds between yielding - checkpoints. - timeout: The maximum amount of time to wait between checkpoints. If left as - `None`, then the process will wait indefinitely. - - Yields: - String paths to latest checkpoint files as they arrive. Stops yielding only - if/when waiting for a checkpoint times out. - """ - checkpoint_path = None - while True: - checkpoint_path = wait_for_new_checkpoint( - checkpoint_dir, checkpoint_path, timeout=timeout) - if checkpoint_path is None: - # timed out - return - start = time.time() - yield checkpoint_path - time_to_next_eval = start + min_interval_secs - time.time() - if time_to_next_eval > 0: - time.sleep(time_to_next_eval) - - -def evaluation(sess, - num_evals=1, - initial_op=None, - initial_op_feed_dict=None, - eval_op=None, - eval_op_feed_dict=None, - final_op=None, - final_op_feed_dict=None, - summary_op=None, - summary_op_feed_dict=None, - summary_writer=None, - global_step=None): - """Performs a single evaluation run. - - A single evaluation consists of several steps run in the following order: - (1) an initialization op, (2) an evaluation op which is executed `num_evals` - times (3) a finalization op and (4) the execution of a summary op which is - written out using a summary writer. - - Args: - sess: The current TensorFlow `Session`. - num_evals: The number of times to execute `eval_op`. - initial_op: An operation run at the beginning of evaluation. - initial_op_feed_dict: A feed dictionary to use when executing `initial_op`. - eval_op: A operation run `num_evals` times. - eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`. - final_op: An operation to execute after all of the `eval_op` executions. The - value of `final_op` is returned. - final_op_feed_dict: A feed dictionary to use when executing `final_op`. - summary_op: A summary op executed after `eval_op` and `finalize_op`. - summary_op_feed_dict: An optional feed dictionary to use when executing the - `summary_op`. - summary_writer: The summery writer used if `summary_op` is provided. - global_step: the global step variable. If left as `None`, then - slim.variables.global_step() is used. - - Returns: - The value of `final_op` or `None` if `final_op` is `None`. - - Raises: - ValueError: if `summary_op` is provided but `global_step` is `None`. - """ - if initial_op is not None: - logging.info('Executing initial eval op') - sess.run(initial_op, initial_op_feed_dict) - - if eval_op is not None: - logging.info('Executing eval ops') - for i in range(int(num_evals)): - logging.info('Executing eval_op %d/%d', i + 1, num_evals) - sess.run(eval_op, eval_op_feed_dict) - - if final_op is not None: - logging.info('Executing final op') - final_op_value = sess.run(final_op, final_op_feed_dict) - else: - final_op_value = None - - if summary_op is not None: - logging.info('Executing summary op') - if global_step is None: - global_step = variables.get_or_create_global_step() - - global_step = training_util.global_step(sess, global_step) - summary_str = sess.run(summary_op, summary_op_feed_dict) - summary_writer.add_summary(summary_str, global_step) - summary_writer.flush() - - return final_op_value +wait_for_new_checkpoint = evaluation.wait_for_new_checkpoint +checkpoints_iterator = evaluation.checkpoints_iterator _USE_DEFAULT = 0 @@ -325,43 +184,27 @@ def evaluate_once(master, if summary_op == _USE_DEFAULT: summary_op = summary.merge_all() - global_step = variables.get_or_create_global_step() - - saver = tf_saver.Saver(variables_to_restore or - variables.get_variables_to_restore()) - - summary_writer = summary_io.SummaryWriter(logdir) - - sv = supervisor.Supervisor(graph=ops.get_default_graph(), - logdir=logdir, - summary_op=None, - summary_writer=None, - global_step=None, - saver=None) - - logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', - time.gmtime())) - with sv.managed_session( - master, start_standard_services=False, config=session_config) as sess: - saver.restore(sess, checkpoint_path) - sv.start_queue_runners(sess) - final_op_value = evaluation(sess, - num_evals=num_evals, - initial_op=initial_op, - initial_op_feed_dict=initial_op_feed_dict, - eval_op=eval_op, - eval_op_feed_dict=eval_op_feed_dict, - final_op=final_op, - final_op_feed_dict=final_op_feed_dict, - summary_op=summary_op, - summary_op_feed_dict=summary_op_feed_dict, - summary_writer=summary_writer, - global_step=global_step) - - logging.info('Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', - time.gmtime())) - - return final_op_value + hooks = [ + evaluation.StopAfterNEvalsHook(num_evals), + ] + + if summary_op is not None: + hooks.append( + evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict)) + + return evaluation.evaluate_once( + checkpoint_path, + master=master, + scaffold=monitored_session.Scaffold( + init_op=initial_op, + init_feed_dict=initial_op_feed_dict), + eval_ops=eval_op, + feed_dict=eval_op_feed_dict, + final_ops=final_op, + final_ops_feed_dict=final_op_feed_dict, + variables_to_restore=variables_to_restore, + hooks=hooks, + config=session_config) def evaluation_loop(master, @@ -416,53 +259,27 @@ def evaluation_loop(master, if summary_op == _USE_DEFAULT: summary_op = summary.merge_all() - global_step = variables.get_or_create_global_step() - - saver = tf_saver.Saver(variables_to_restore or - variables.get_variables_to_restore()) - - summary_writer = summary_io.SummaryWriter(logdir) - - sv = supervisor.Supervisor(graph=ops.get_default_graph(), - logdir=logdir, - summary_op=None, - summary_writer=None, - global_step=None, - saver=saver) - - number_of_evaluations = 0 - for checkpoint_path in checkpoints_iterator(checkpoint_dir, - eval_interval_secs, - timeout): - logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', - time.gmtime())) - - with sv.managed_session( - master, start_standard_services=False, config=session_config) as sess: - sv.saver.restore(sess, checkpoint_path) - sv.start_queue_runners(sess) - final_op_value = evaluation(sess, - num_evals=num_evals, - initial_op=initial_op, - initial_op_feed_dict=initial_op_feed_dict, - eval_op=eval_op, - eval_op_feed_dict=eval_op_feed_dict, - final_op=final_op, - final_op_feed_dict=final_op_feed_dict, - summary_op=summary_op, - summary_op_feed_dict=summary_op_feed_dict, - summary_writer=summary_writer, - global_step=global_step) - - logging.info('Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', - time.gmtime())) - number_of_evaluations += 1 - if (max_number_of_evaluations and - number_of_evaluations >= max_number_of_evaluations): - logging.info('Reached max_number_of_evaluations=%s. Exit', - max_number_of_evaluations) - return final_op_value - - logging.info( - 'Timed-out waiting for new checkpoint file. Exiting evaluation loop.') - return final_op_value + hooks = [ + evaluation.StopAfterNEvalsHook(num_evals), + ] + + if summary_op is not None: + hooks.append( + evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict)) + + return evaluation.evaluate_repeatedly( + checkpoint_dir, + master=master, + scaffold=monitored_session.Scaffold( + init_op=initial_op, + init_feed_dict=initial_op_feed_dict), + eval_ops=eval_op, + feed_dict=eval_op_feed_dict, + final_ops=final_op, + final_ops_feed_dict=final_op_feed_dict, + variables_to_restore=variables_to_restore, + eval_interval_secs=eval_interval_secs, + hooks=hooks, + config=session_config, + max_number_of_evaluations=max_number_of_evaluations, + timeout=timeout) diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index b12c82e985..a308a515bd 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -73,28 +73,6 @@ class EvaluationTest(tf.test.TestCase): self._labels = tf.constant(labels, dtype=tf.int64) self._predictions, self._scale = TestModel(self._inputs) - def testUpdateOpsAreEvaluated(self): - accuracy, update_op = slim.metrics.streaming_accuracy( - self._predictions, self._labels) - initial_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - - with self.test_session() as sess: - slim.evaluation.evaluation( - sess, initial_op=initial_op, eval_op=update_op) - self.assertAlmostEqual(accuracy.eval(), self._expected_accuracy) - - def testFinalOpsIsEvaluated(self): - _, update_op = slim.metrics.streaming_accuracy( - self._predictions, self._labels) - initial_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - - with self.test_session() as sess: - accuracy_value = slim.evaluation.evaluation( - sess, initial_op=initial_op, final_op=update_op) - self.assertAlmostEqual(accuracy_value, self._expected_accuracy) - def testFinalOpsOnEvaluationLoop(self): value_op, update_op = slim.metrics.streaming_accuracy( self._predictions, self._labels) @@ -153,96 +131,6 @@ class EvaluationTest(tf.test.TestCase): for name in names_to_values: self.assertAlmostEqual(names_to_values[name], saved_results[name]) - def testSummariesAreFlushedToDisk(self): - output_dir = os.path.join(self.get_temp_dir(), 'flush_test') - if tf.gfile.Exists(output_dir): # For running on jenkins. - tf.gfile.DeleteRecursively(output_dir) - - names_to_metrics, names_to_updates = self._create_names_to_metrics( - self._predictions, self._labels) - - for k in names_to_metrics: - v = names_to_metrics[k] - tf.summary.scalar(k, v) - - summary_writer = tf.train.SummaryWriter(output_dir) - - initial_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - eval_op = tf.group(*names_to_updates.values()) - - with self.test_session() as sess: - slim.evaluation.evaluation( - sess, - initial_op=initial_op, - eval_op=eval_op, - summary_op=tf.summary.merge_all(), - summary_writer=summary_writer, - global_step=self._global_step) - - names_to_values = {name: names_to_metrics[name].eval() - for name in names_to_metrics} - self._verify_summaries(output_dir, names_to_values) - - def testSummariesAreFlushedToDiskWithoutGlobalStep(self): - output_dir = os.path.join(self.get_temp_dir(), 'flush_test_no_global_step') - if tf.gfile.Exists(output_dir): # For running on jenkins. - tf.gfile.DeleteRecursively(output_dir) - - names_to_metrics, names_to_updates = self._create_names_to_metrics( - self._predictions, self._labels) - - for k in names_to_metrics: - v = names_to_metrics[k] - tf.summary.scalar(k, v) - - summary_writer = tf.train.SummaryWriter(output_dir) - - initial_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - eval_op = tf.group(*names_to_updates.values()) - - with self.test_session() as sess: - slim.evaluation.evaluation( - sess, - initial_op=initial_op, - eval_op=eval_op, - summary_op=tf.summary.merge_all(), - summary_writer=summary_writer) - - names_to_values = {name: names_to_metrics[name].eval() - for name in names_to_metrics} - self._verify_summaries(output_dir, names_to_values) - - def testWithFeedDict(self): - accuracy, update_op = slim.metrics.streaming_accuracy( - self._predictions, self._labels) - initial_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - - with self.test_session() as sess: - slim.evaluation.evaluation( - sess, - initial_op=initial_op, - eval_op=update_op, - eval_op_feed_dict={self._scale: np.ones([], dtype=np.float32)}) - self.assertAlmostEqual(accuracy.eval(), self._expected_accuracy) - - def testWithQueueRunning(self): - strings = ['the', 'cat', 'in', 'the', 'hat'] - _ = tf.train.string_input_producer(strings, capacity=5) - - accuracy, update_op = slim.metrics.streaming_accuracy( - self._predictions, self._labels) - - initial_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - - with self.test_session() as sess: - slim.evaluation.evaluation( - sess, initial_op=initial_op, eval_op=update_op) - self.assertAlmostEqual(accuracy.eval(), self._expected_accuracy) - def testLatestCheckpointReturnsNoneAfterTimeout(self): start = time.time() ret = slim.evaluation.wait_for_new_checkpoint( @@ -259,38 +147,6 @@ class EvaluationTest(tf.test.TestCase): '/non-existent-dir', timeout=0)) self.assertEqual(ret, []) - def testEvaluationLoopTimeout(self): - _, update_op = slim.metrics.streaming_accuracy( - self._predictions, self._labels) - init_op = tf.group(tf.global_variables_initializer(), - tf.local_variables_initializer()) - - # Create checkpoint and log directories. - chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/') - gfile.MakeDirs(chkpt_dir) - logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/') - gfile.MakeDirs(logdir) - - # Save initialized variables to checkpoint directory. - saver = tf.train.Saver() - with self.test_session() as sess: - init_op.run() - saver.save(sess, os.path.join(chkpt_dir, 'chkpt')) - - # Run the evaluation loop with a timeout. - with self.test_session() as sess: - start = time.time() - slim.evaluation.evaluation_loop( - '', chkpt_dir, logdir, eval_op=update_op, - eval_interval_secs=2.0, timeout=6.0) - end = time.time() - - # Check we've waited for the timeout. - self.assertGreater(end - start, 6.0) - - # Then the timeout kicked in and stops the loop. - self.assertLess(end - start, 8.0) - class SingleEvaluationTest(tf.test.TestCase): diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc index 77d7f4290d..626b6a1daf 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc @@ -214,26 +214,15 @@ class TreePredictions : public OpKernel { errors::InvalidArgument("node_index not in valid range.")) const int32 left_child = tree(node_index, CHILDREN_INDEX); if (left_child == LEAF_NODE) { - float sum = node_pcw(node_index, 0); - float parent_weight = 0.0; - if (sum < valid_leaf_threshold_ && parent >= 0) { - VLOG(1) << "not enough samples at leaf, including parent counts." - << "child sum = " << sum; - float parent_sum = node_pcw(parent, 0); - // Weight the parent's counts just enough so that the new sum is - // valid_leaf_threshold_, but never give any counts a weight of - // more than 1. - parent_weight = std::min(1.0f, - (valid_leaf_threshold_ - sum) / parent_sum); - sum += parent_weight * parent_sum; - VLOG(1) << "Sum w/ parent included = " << sum; - } + const int32 flat_leaf_index = node_index * num_classes + 1; + const int32 flat_parent_index = parent * num_classes + 1; + std::vector<float> means(num_classes - 1); + tensorforest::GetParentWeightedMean( + node_pcw(node_index, 0), node_pcw.data() + flat_leaf_index, + node_pcw(parent, 0), node_pcw.data() + flat_parent_index, + valid_leaf_threshold_, num_classes - 1, &means); for (int c = 1; c < num_classes; c++) { - float w = node_pcw(node_index, c); - if (parent_weight > 0.0) { - w += parent_weight * node_pcw(parent, c); - } - out(i, c - 1) = w / sum; + out(i, c - 1) = means[c - 1]; } break; } else if (left_child == FREE_NODE) { diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc index 544336b1ba..5f538c9e41 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc @@ -555,6 +555,31 @@ bool IsAllInitialized(const Tensor& features) { return feature_vec(feature_vec.size() - 1) >= 0; } +void GetParentWeightedMean(float leaf_sum, const float* leaf_data, + float parent_sum, const float* parent_data, + float valid_leaf_threshold, int num_outputs, + std::vector<float>* mean) { + float parent_weight = 0.0; + if (leaf_sum < valid_leaf_threshold && parent_sum >= 0) { + VLOG(1) << "not enough samples at leaf, including parent counts." + << "child sum = " << leaf_sum; + // Weight the parent's counts just enough so that the new sum is + // valid_leaf_threshold_, but never give any counts a weight of + // more than 1. + parent_weight = + std::min(1.0f, (valid_leaf_threshold - leaf_sum) / parent_sum); + leaf_sum += parent_weight * parent_sum; + VLOG(1) << "Sum w/ parent included = " << leaf_sum; + } + + for (int c = 0; c < num_outputs; c++) { + float w = leaf_data[c]; + if (parent_weight > 0.0) { + w += parent_weight * parent_data[c]; + } + (*mean)[c] = w / leaf_sum; + } +} } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h index 7c7193f0f4..a17622d8f5 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h @@ -228,6 +228,11 @@ inline bool CheckTensorBounds(OpKernelContext* context, const Tensor& tensor) { return true; } +void GetParentWeightedMean(float leaf_sum, const float* leaf_data, + float parent_sum, const float* parent_data, + float valid_leaf_threshold, int num_outputs, + std::vector<float>* mean); + } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index 94b7222737..3321abb7e9 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -30,7 +30,10 @@ py_library( name = "plugins", srcs = ["plugins/__init__.py"], srcs_version = "PY2AND3", - deps = [":projector"], + deps = [ + ":projector", + ":trace", + ], ) # API methods and protos in `tf.contrib.tensorboard.plugins.projector` package. @@ -55,6 +58,31 @@ py_test( ], ) +# API methods and protos in `tf.contrib.tensorboard.plugins.trace` package. +py_library( + name = "trace", + srcs = glob( + ["plugins/trace/**/*.py"], + exclude = ["**/*test*"], + ), + srcs_version = "PY2AND3", + deps = [ + ":protos_all_py", + "//tensorflow/python:lib", + ], +) + +py_test( + name = "trace_test", + size = "small", + srcs = ["plugins/trace/trace_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":trace", + "//tensorflow:tensorflow_py", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/tensorboard/plugins/__init__.py b/tensorflow/contrib/tensorboard/plugins/__init__.py index 88336714a7..41aa77910c 100644 --- a/tensorflow/contrib/tensorboard/plugins/__init__.py +++ b/tensorflow/contrib/tensorboard/plugins/__init__.py @@ -20,3 +20,4 @@ from __future__ import print_function # Add projects here, they will show up under tf.contrib.tensorboard.plugins from tensorflow.contrib.tensorboard.plugins import projector +from tensorflow.contrib.tensorboard.plugins import trace diff --git a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py index 284f3ba24e..6bb310db3e 100644 --- a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py +++ b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py @@ -38,7 +38,7 @@ class ProjectorApiTest(tf.test.TestCase): # Call the API method to save the configuration to a temporary dir. temp_dir = self.get_temp_dir() self.addCleanup(shutil.rmtree, temp_dir) - writer = tf.train.SummaryWriter(temp_dir) + writer = tf.summary.FileWriter(temp_dir) tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer, config) diff --git a/tensorflow/python/util/net_lib_test.py b/tensorflow/contrib/tensorboard/plugins/trace/__init__.py index 1e2ad53cda..2c99f4077e 100644 --- a/tensorflow/python/util/net_lib_test.py +++ b/tensorflow/contrib/tensorboard/plugins/trace/__init__.py @@ -12,28 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Public API for the Trace plugin.""" -"""Tests for the SWIG-wrapped test lib.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import tensorflow as tf - -from tensorflow.python.util import net_lib - - -class TestLibTest(tf.test.TestCase): - - def testPickUnusedPortOrDie(self): - port0 = net_lib.pick_unused_port_or_die() - port1 = net_lib.pick_unused_port_or_die() - self.assertGreater(port0, 0) - self.assertLess(port0, 65536) - self.assertGreater(port1, 0) - self.assertLess(port1, 65536) - self.assertNotEqual(port0, port1) - - -if __name__ == "__main__": - tf.test.main() +# pylint: disable=wildcard-import +from tensorflow.contrib.tensorboard.plugins.trace.trace import * +from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import * +# pylint: enable=wildcard-import diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace.py b/tensorflow/contrib/tensorboard/plugins/trace/trace.py new file mode 100644 index 0000000000..0c645889af --- /dev/null +++ b/tensorflow/contrib/tensorboard/plugins/trace/trace.py @@ -0,0 +1,162 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Stores debugging information regarding TensorFlow model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import parser +import re +import token +import tensorflow as tf + +from google.protobuf import json_format +from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import TraceInfo + +# List of regex patterns that match files in the core tensorflow library. +TF_LIB_REGEX_FPATHS = [os.sep + os.path.join('tensorflow', 'python')] + +LEFT_TOKENS = [token.LPAR, token.LSQB, token.LBRACE] +RIGHT_TOKENS = [token.RPAR, token.RSQB, token.RBRACE] +TOKENS = LEFT_TOKENS + RIGHT_TOKENS + + +def store_trace_info(output_file_path, graph=tf.get_default_graph(), + ignore_regex_fpaths=None): + """Collects and stores trace information for a TensorFlow model. + + The output proto is stored in json format. + + Args: + output_file_path: The path where to store the output proto. + graph: Optional. The data flow graph. Defaults to `tf.get_default_graph()`. + ignore_regex_fpaths: Optional. Files whose path matches any of the regexes + in this list will be ignored. Defaults to patterns that match the core + tensorflow python library. + """ + if not ignore_regex_fpaths: + ignore_regex_fpaths = TF_LIB_REGEX_FPATHS + + trace_info = TraceInfo() + # Extract trace information for every op in the graph. + source_fpaths = set() + for op in graph.get_operations(): + op_info = trace_info.ops.add() + op_info.name = op.name + op_info.op_type = op.type + op_info.device = op.device + for trace in op.traceback: + fname, lineno, _, _ = trace + # Ignore traces in specified file paths. + if os.path.isabs(fname) and not _ignore_file_path(fname, + ignore_regex_fpaths): + line_trace = op_info.traceback.add() + line_trace.file_path = fname + line_trace.line_number = lineno + source_fpaths.add(fname) + _add_data_from_tensors(op.inputs, op_info.inputs) + _add_data_from_tensors(op.outputs, op_info.outputs) + + # Read the source files involved in the graph construction. + for fpath in source_fpaths: + file_info = trace_info.files.add() + + with tf.gfile.Open(fpath, 'r') as f: + source = f.read().decode('utf-8') + + file_info.file_path = fpath + file_info.source_code = source + + line2start = find_multiline_statements(source) + + for key, value in line2start.items(): + file_info.multiline_statements[key] = value + + # Make sure the directory for the output file exists. + output_file_path = os.path.expanduser(output_file_path) + output_dir = os.path.dirname(output_file_path) + if not tf.gfile.Exists(output_dir): + tf.gfile.MakeDirs(output_dir) + + # Store the debug information. + with tf.gfile.Open(output_file_path, 'w') as f: + f.write(json_format.MessageToJson(trace_info)) + + +def find_multiline_statements(source): + """Parses the python source and finds multiline statements. + + Based on counting the number of open and closed parenthesis on each line. + + Args: + source: The source code string. + + Returns: + A dict that maps a line index A to a line index B, where A is the end of a + multiline statement and B is the start. Line indexing is 0-based. + """ + # Get the AST. + tree = parser.suite(source) + line2paren_count = [0] * (source.count('\n') + 1) + _count_brackets_braces_parenthesis(tree.totuple(True), line2paren_count) + + line2start = {} + for end in range(len(line2paren_count)): + if line2paren_count[end] >= 0: + # This is not the end of a multiline statement. + continue + cumulative_paren_count = 0 + for start in range(end, -1, -1): + cumulative_paren_count += line2paren_count[start] + if cumulative_paren_count == 0: + line2start[end] = start + break + return line2start + + +def _add_data_from_tensors(tensors, info): + for t in tensors: + tensor_info = info.add() + + shape = t.get_shape() + if shape.ndims: + shape = [(-1 if s is None else s) for s in shape.as_list()] + tensor_info.shape.extend(shape) + tensor_info.dtype = t.dtype.name + tensor_info.num_bytes_per_elem = t.dtype.size + + for c in t.consumers(): + tensor_info.consumers.append(c.name) + + +def _ignore_file_path(fname, ignore_regex_fpaths): + for regex_pattern in ignore_regex_fpaths: + if re.search(regex_pattern, fname): + return True + return False + + +def _count_brackets_braces_parenthesis(node, line2par): + if isinstance(node[1], tuple): + for child in node[1:]: + _count_brackets_braces_parenthesis(child, line2par) + else: + tok = node[0] + if tok in TOKENS: + lineno = node[2] + line2par[lineno - 1] += (1 if tok in LEFT_TOKENS else -1) + return line2par diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto b/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto new file mode 100644 index 0000000000..09013c6387 --- /dev/null +++ b/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto @@ -0,0 +1,60 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; + +message TraceInfo { + repeated OpInfo ops = 1; + repeated FileInfo files = 2; +} + +message OpInfo { + string name = 1; + string op_type = 2; + string device = 3; + repeated LineTrace traceback = 4; + repeated TensorInfo inputs = 5; + repeated TensorInfo outputs = 6; +} + +message LineTrace { + // Absolute file path. + string file_path = 1; + // 1-based line number. + uint32 line_number = 2; +} + +message TensorInfo { + // Size of the tensor for each dimension. Value of -1 denotes "unknown" + // size for that dimension. + repeated int32 shape = 1; + // The data type of the tensor. + string dtype = 2; + // Number of bytes per element in the tensor. + uint32 num_bytes_per_elem = 3; + // List of operation names that consume this tensor. + repeated string consumers = 4; +} + +message FileInfo { + // Absolute file path to the source code. + string file_path = 1; + string source_code = 2; + // Map from end of statement to start of statement. End and start are 0-based + // line indexes. + map<uint32, uint32> multiline_statements = 3; +} diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py b/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py new file mode 100644 index 0000000000..e67bde9d59 --- /dev/null +++ b/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py @@ -0,0 +1,91 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.contrib.tensorboard.plugins.trace package.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import tensorflow as tf + +from google.protobuf import json_format +from tensorflow.contrib.tensorboard.plugins import trace + + +class TraceTest(tf.test.TestCase): + + def setUp(self): + self._temp_dir = tempfile.mkdtemp() + self._temp_trace_json = self._temp_dir + 'trace.json' + + def tearDown(self): + tf.gfile.DeleteRecursively(self._temp_dir) + + def testEmptyGraph(self): + trace_info = self._store_and_read_trace_info() + self.assertEqual(len(trace_info.ops), 0) + + def testHasSourceCodeOfThisFile(self): + tf.constant(0) + trace_info = self._store_and_read_trace_info() + + self.assertTrue(trace_info.files) + for file_info in trace_info.files: + if file_info.file_path.endswith('trace_test.py'): + return + self.fail('trace_test file not found in the trace info json') + + def testHasTheConstantOp(self): + tf.constant(0) + trace_info = self._store_and_read_trace_info() + + self.assertTrue(trace_info.ops) + + for op in trace_info.ops: + if op.op_type == 'Const': + return + self.fail('Could not find operation of type `Const` in the graph') + + def testMultilineStatements(self): + source = """def test(): + a(4, + 3, + 1) + + b(3, 4, 5) + + c((4, 3), + (), + ) + """ + line2start = trace.find_multiline_statements(source) + + self.assertEqual(line2start[3], 1) + self.assertEqual(line2start[9], 7) + self.assertEqual(len(line2start), 2) + + def _store_and_read_trace_info(self): + trace.store_trace_info(self._temp_trace_json) + trace_info = trace.TraceInfo() + + with tf.gfile.Open(self._temp_trace_json) as f: + text = f.read().decode('utf-8') + json_format.Parse(text, trace_info) + + return trace_info + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py index 7966b1a2c0..3b06bcc0e0 100644 --- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py +++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import summary_pb2 +from tensorflow.python import summary from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import summary_io # TODO(ptucker): Replace with mock framework. @@ -33,16 +33,16 @@ class FakeSummaryWriter(object): def install(cls): if cls._replaced_summary_writer: raise ValueError('FakeSummaryWriter already installed.') - cls._replaced_summary_writer = summary_io.SummaryWriter - summary_io.SummaryWriter = FakeSummaryWriter - writer_cache.SummaryWriter = FakeSummaryWriter + cls._replaced_summary_writer = summary.FileWriter + summary.FileWriter = FakeSummaryWriter + writer_cache.FileWriter = FakeSummaryWriter @classmethod def uninstall(cls): if not cls._replaced_summary_writer: raise ValueError('FakeSummaryWriter not installed.') - summary_io.SummaryWriter = cls._replaced_summary_writer - writer_cache.SummaryWriter = cls._replaced_summary_writer + summary.FileWriter = cls._replaced_summary_writer + writer_cache.FileWriter = cls._replaced_summary_writer cls._replaced_summary_writer = None def __init__(self, logdir, graph=None): @@ -86,18 +86,18 @@ class FakeSummaryWriter(object): if expected_session_logs is not None: test_case.assertEqual(expected_session_logs, self._added_session_logs) - def add_summary(self, summary, current_global_step): + def add_summary(self, summ, current_global_step): """Add summary.""" - if isinstance(summary, bytes): + if isinstance(summ, bytes): summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary) - summary = summary_proto + summary_proto.ParseFromString(summ) + summ = summary_proto if current_global_step in self._summaries: step_summaries = self._summaries[current_global_step] else: step_summaries = [] self._summaries[current_global_step] = step_summaries - step_summaries.append(summary) + step_summaries.append(summ) # NOTE: Ignore global_step since its value is non-deterministic. def add_graph(self, graph, global_step=None, graph_def=None): diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index ee4208d312..9070cd6e8d 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -144,7 +144,6 @@ from tensorflow.contrib.framework.python.ops import variables from tensorflow.core.protobuf import saver_pb2 from tensorflow.python import summary from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import monitored_session @@ -254,7 +253,7 @@ def get_or_create_eval_step(): class StopAfterNEvalsHook(session_run_hook.SessionRunHook): - """A run hook used by the evaluation routines to run the `eval_ops` N times.""" + """Run hook used by the evaluation routines to run the `eval_ops` N times.""" def __init__(self, num_evals): """Constructs the run hook. @@ -274,6 +273,7 @@ class StopAfterNEvalsHook(session_run_hook.SessionRunHook): def after_run(self, run_context, run_values): evals_completed = run_values.results['evals_completed'] + logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) if evals_completed >= self._num_evals: run_context.request_stop() @@ -299,7 +299,7 @@ class _FinalOpsHook(session_run_hook.SessionRunHook): return self._final_ops_values def end(self, session): - if self._final_ops: + if self._final_ops is not None: self._final_ops_values = session.run(self._final_ops, feed_dict=self._final_ops_feed_dict) @@ -379,14 +379,14 @@ def evaluate_once( the requested number of times. Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of - `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is evaluated - a single time after `eval_ops` has finished running and the fetched values of - `final_ops` are returned. If `final_ops` is left as `None`, then `None` is - returned. + `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is + evaluated a single time after `eval_ops` has finished running and the fetched + values of `final_ops` are returned. If `final_ops` is left as `None`, then + `None` is returned. One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record - summaries after the `eval_ops` have run. If `eval_ops` is `None`, the summaries - run immedietly after the model checkpoint has been restored. + summaries after the `eval_ops` have run. If `eval_ops` is `None`, the + summaries run immedietly after the model checkpoint has been restored. Note that `evaluate_once` creates a local variable used to track the number of evaluations run via `tf.contrib.training.get_or_create_eval_step`. @@ -403,8 +403,8 @@ def evaluate_once( eval_ops: A operation which is run until the session is requested to stop, commonly done by a `tf.contrib.training.StopAfterNEvalsHook`. feed_dict: The feed dictionary to use when executing the `eval_ops`. - final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to - `Tensors`. + final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names + to `Tensors`. final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`. variables_to_restore: A list of TensorFlow variables to restore during evaluation. If the argument is left as `None` then @@ -420,9 +420,14 @@ def evaluate_once( eval_step = get_or_create_eval_step() if eval_ops is not None: - eval_ops = control_flow_ops.with_dependencies( - [eval_ops], - state_ops.assign_add(eval_step, 1)) + update_eval_step = state_ops.assign_add(eval_step, 1) + + if isinstance(eval_ops, dict): + eval_ops['update_eval_step'] = update_eval_step + elif isinstance(eval_ops, (tuple, list)): + eval_ops = list(eval_ops) + [update_eval_step] + else: + eval_ops = [eval_ops, update_eval_step] # Must come before the scaffold check. if scaffold and scaffold.saver: @@ -484,14 +489,14 @@ def evaluate_repeatedly( the requested number of times. Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of - `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is evaluated - a single time after `eval_ops` has finished running and the fetched values of - `final_ops` are returned. If `final_ops` is left as `None`, then `None` is - returned. + `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is + evaluated a single time after `eval_ops` has finished running and the fetched + values of `final_ops` are returned. If `final_ops` is left as `None`, then + `None` is returned. One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record - summaries after the `eval_ops` have run. If `eval_ops` is `None`, the summaries - run immedietly after the model checkpoint has been restored. + summaries after the `eval_ops` have run. If `eval_ops` is `None`, the + summaries run immedietly after the model checkpoint has been restored. Note that `evaluate_once` creates a local variable used to track the number of evaluations run via `tf.contrib.training.get_or_create_eval_step`. @@ -508,8 +513,8 @@ def evaluate_repeatedly( eval_ops: A operation which is run until the session is requested to stop, commonly done by a `tf.contrib.training.StopAfterNEvalsHook`. feed_dict: The feed dictionary to use when executing the `eval_ops`. - final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to - `Tensors`. + final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names + to `Tensors`. final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`. variables_to_restore: A list of TensorFlow variables to restore during evaluation. If the argument is left as `None` then @@ -530,9 +535,14 @@ def evaluate_repeatedly( eval_step = get_or_create_eval_step() if eval_ops is not None: - eval_ops = control_flow_ops.with_dependencies( - [eval_ops], - state_ops.assign_add(eval_step, 1)) + update_eval_step = state_ops.assign_add(eval_step, 1) + + if isinstance(eval_ops, dict): + eval_ops['update_eval_step'] = update_eval_step + elif isinstance(eval_ops, (tuple, list)): + eval_ops = list(eval_ops) + [update_eval_step] + else: + eval_ops = [eval_ops, update_eval_step] # Must come before the scaffold check. if scaffold and scaffold.saver: @@ -572,7 +582,9 @@ def evaluate_repeatedly( 'Finished evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime())) num_evaluations += 1 - if num_evaluations >= max_number_of_evaluations: + + reached_max = num_evaluations >= max_number_of_evaluations + if max_number_of_evaluations and reached_max: return final_ops_hook.final_ops_values logging.info('Timed-out waiting for a checkpoint.') diff --git a/tensorflow/contrib/training/python/training/resample.py b/tensorflow/contrib/training/python/training/resample.py index 513b7f59a0..ffaf6b8a03 100644 --- a/tensorflow/contrib/training/python/training/resample.py +++ b/tensorflow/contrib/training/python/training/resample.py @@ -128,7 +128,7 @@ def resample_at_rate(inputs, rates, scope=None, seed=None, back_prop=False): def weighted_resample(inputs, weights, overall_rate, scope=None, - mean_decay=0.999, warmup=10, seed=None): + mean_decay=0.999, seed=None): """Performs an approximate weighted resampling of `inputs`. This method chooses elements from `inputs` where each item's rate of @@ -142,9 +142,6 @@ def weighted_resample(inputs, weights, overall_rate, scope=None, overall_rate: Desired overall rate of resampling. scope: Scope to use for the op. mean_decay: How quickly to decay the running estimate of the mean weight. - warmup: Until the resulting tensor has been evaluated `warmup` - times, the resampling menthod uses the true mean over all calls - as its weight estimate, rather than a decayed mean. seed: Random seed. Returns: @@ -158,26 +155,16 @@ def weighted_resample(inputs, weights, overall_rate, scope=None, # overall rate, and a weight twice the average has twice the rate, # etc. with ops.name_scope(scope, 'weighted_resample', inputs) as opscope: - # First: Maintain a running estimated mean weight, with decay - # adjusted (by also maintaining an invocation count) during the - # warmup period so that at the beginning, there aren't too many - # zeros mixed in, throwing the average off. + # First: Maintain a running estimated mean weight, with zero debiasing + # enabled (by default) to avoid throwing the average off. with variable_scope.variable_scope(scope, 'estimate_mean', inputs): - count_so_far = variable_scope.get_local_variable( - 'resample_count', initializer=0) - estimated_mean = variable_scope.get_local_variable( 'estimated_mean', initializer=0.0) - count = count_so_far.assign_add(1) - real_decay = math_ops.minimum( - math_ops.truediv((count - 1), math_ops.minimum(count, warmup)), - mean_decay) - batch_mean = math_ops.reduce_mean(weights) mean = moving_averages.assign_moving_average( - estimated_mean, batch_mean, real_decay, zero_debias=False) + estimated_mean, batch_mean, mean_decay) # Then, normalize the weights into rates using the mean weight and # overall target rate: diff --git a/tensorflow/contrib/training/python/training/resample_test.py b/tensorflow/contrib/training/python/training/resample_test.py index 9324a940d3..1f8d332125 100644 --- a/tensorflow/contrib/training/python/training/resample_test.py +++ b/tensorflow/contrib/training/python/training/resample_test.py @@ -40,7 +40,8 @@ class ResampleTest(tf.test.TestCase): resampled_back_out = tf.contrib.training.resample_at_rate( resampled_in, 1.0/rates, seed=456) - init = tf.local_variables_initializer() + init = tf.group(tf.local_variables_initializer(), + tf.global_variables_initializer()) with self.test_session() as s: s.run(init) # initialize @@ -81,7 +82,8 @@ class ResampleTest(tf.test.TestCase): invrates = 1.0/rates - init = tf.local_variables_initializer() + init = tf.group(tf.local_variables_initializer(), + tf.global_variables_initializer()) expected_sum_op = tf.reduce_sum(vals) with self.test_session() as s: s.run(init) diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py index 2efc50cb4e..d5e6878e75 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops.py +++ b/tensorflow/contrib/training/python/training/sampling_ops.py @@ -387,7 +387,7 @@ def _calculate_acceptance_probabilities(init_probs, target_probs): ratio_l = target_probs / init_probs # Replace NaNs with 0s. - ratio_l = math_ops.select(math_ops.is_nan(ratio_l), + ratio_l = array_ops.where(math_ops.is_nan(ratio_l), array_ops.zeros_like(ratio_l), ratio_l) diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py index f0501b3f3e..788e01efd7 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops_test.py +++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py @@ -133,7 +133,8 @@ class StratifiedSampleTest(tf.test.TestCase): val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs) batches += tf.contrib.training.stratified_sample( val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs) - summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES)) + summary_op = tf.contrib.deprecated.merge_summary( + tf.get_collection(tf.GraphKeys.SUMMARIES)) with self.test_session() as sess: coord = tf.train.Coordinator() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 3b81fa859a..3f9b94128a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1292,7 +1292,10 @@ cc_library( hdrs = [ "platform/regexp.h", ], - visibility = ["//tensorflow/tools/tfprof:__subpackages__"], + visibility = [ + "//tensorflow/compiler:__subpackages__", + "//tensorflow/tools/tfprof:__subpackages__", + ], deps = [":lib_internal"], ) diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 8a49c7d3ab..e1f2c55230 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -407,7 +407,9 @@ Status DirectSession::Run(const RunOptions& run_options, &executors_and_keys, &run_state_args)); // Create a run state and start execution. - RunState run_state(input_tensor_names, output_names); + Executor::Args args; + args.step_id = step_id_counter_.fetch_add(1); + RunState run_state(input_tensor_names, output_names, args.step_id, &devices_); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); CancellationManager step_cancellation_manager; @@ -425,8 +427,6 @@ Status DirectSession::Run(const RunOptions& run_options, run_state.executors_done.Notify(); }); - Executor::Args args; - args.step_id = step_id_counter_.fetch_add(1); args.rendezvous = run_state.rendez; args.cancellation_manager = &step_cancellation_manager; args.runner = [this, pool](Executor::Args::Closure c) { @@ -434,7 +434,7 @@ Status DirectSession::Run(const RunOptions& run_options, }; args.session_state = &session_state_; args.tensor_store = &run_state.tensor_store; - args.step_resource_manager = &run_state.step_resource_manager; + args.step_container = &run_state.step_container; if (LogMemory::IsEnabled()) { LogMemory::RecordStep(args.step_id, run_state_args.handle); } @@ -582,7 +582,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names, &run_state_args)); // Create the run state and save it for future PRun calls. - RunState* run_state = new RunState(input_names, output_names); + Executor::Args args; + args.step_id = step_id_counter_.fetch_add(1); + RunState* run_state = + new RunState(input_names, output_names, args.step_id, &devices_); run_state->rendez = new IntraProcessRendezvous(device_mgr_.get()); { mutex_lock l(executor_lock_); @@ -606,8 +609,6 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names, run_state->executors_done.Notify(); }); - Executor::Args args; - args.step_id = step_id_counter_.fetch_add(1); args.rendezvous = run_state->rendez; args.cancellation_manager = cancellation_manager_; args.runner = [this, pool](Executor::Args::Closure c) { @@ -615,7 +616,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names, }; args.session_state = &session_state_; args.tensor_store = &run_state->tensor_store; - args.step_resource_manager = &run_state->step_resource_manager; + args.step_container = &run_state->step_container; if (LogMemory::IsEnabled()) { LogMemory::RecordStep(args.step_id, run_state_args.handle); } @@ -1173,7 +1174,16 @@ Status DirectSession::CreateGraphs( } DirectSession::RunState::RunState(const std::vector<string>& input_names, - const std::vector<string>& output_names) { + const std::vector<string>& output_names, + int64 step_id, + const std::vector<Device*>* devices) + : step_container(step_id, [devices](const string& name) { + for (auto d : *devices) { + if (!d->resource_manager()->Cleanup(name).ok()) { + // Do nothing... + } + } + }) { // Initially all the feeds and fetches are pending. for (auto& name : input_names) { pending_inputs.emplace(name); diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 37f0277a40..127c08d0a4 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -151,10 +151,11 @@ class DirectSession : public Session { std::unordered_set<string> pending_inputs; std::unordered_set<string> pending_outputs; TensorStore tensor_store; - ResourceMgr step_resource_manager; + ScopedStepContainer step_container; RunState(const std::vector<string>& input_names, - const std::vector<string>& output_names); + const std::vector<string>& output_names, int64 step_id, + const std::vector<Device*>* devices); ~RunState(); }; diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index ef531dc6c5..542ed70b4c 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -873,8 +873,8 @@ class ExecutorState { Rendezvous* rendezvous_; SessionState* session_state_; TensorStore* tensor_store_; - // Step-local resource manager. - ResourceMgr* step_resource_manager_; + // Step-local container. + ScopedStepContainer* step_container_; StepStatsCollector* stats_collector_; // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper // instead of a pointer? (avoids having to delete). @@ -992,7 +992,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) rendezvous_(args.rendezvous), session_state_(args.session_state), tensor_store_(args.tensor_store), - step_resource_manager_(args.step_resource_manager), + step_container_(args.step_container), stats_collector_(args.stats_collector), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), @@ -1220,7 +1220,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { params.call_frame = call_frame_; params.function_library = impl_->params_.function_library; params.resource_manager = device->resource_manager(); - params.step_resource_manager = step_resource_manager_; + params.step_container = step_container_; params.slice_reader_cache = slice_reader_cache_; params.inputs = &inputs; params.input_device_contexts = &input_device_contexts; diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 2e9990d951..8cca22fb6f 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -88,7 +88,7 @@ class Executor { CancellationManager* cancellation_manager = nullptr; SessionState* session_state = nullptr; TensorStore* tensor_store = nullptr; - ResourceMgr* step_resource_manager = nullptr; + ScopedStepContainer* step_container = nullptr; // If true, calls Sync() on the device. bool sync_on_finish = false; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 2d29f5176d..695c7244ae 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -261,7 +261,7 @@ class CallOp : public AsyncOpKernel { done); FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); - opts.step_resource_manager = ctx->step_resource_manager(); + opts.step_container = ctx->step_container(); opts.runner = ctx->runner(); std::vector<Tensor> args; args.reserve(ctx->num_inputs()); @@ -558,7 +558,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, Executor::Args exec_args; // Inherit the step_id from the caller. exec_args.step_id = opts.step_id; - exec_args.step_resource_manager = opts.step_resource_manager; + exec_args.step_container = opts.step_container; exec_args.call_frame = frame; exec_args.cancellation_manager = opts.cancellation_manager; exec_args.runner = *opts.runner; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 731edd6ac3..4e89226752 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -359,8 +359,8 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, return; } - StartParallelExecutors(handle, item, rendezvous, collector, cost_graph, - cancellation_manager, + StartParallelExecutors(handle, step_id, item, rendezvous, collector, + cost_graph, cancellation_manager, [this, item, rendezvous, done](const Status& s) { done(s); rendezvous->Unref(); @@ -368,22 +368,25 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, }); } -void GraphMgr::StartParallelExecutors(const string& handle, Item* item, - Rendezvous* rendezvous, +void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, + Item* item, Rendezvous* rendezvous, StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, StatusCallback done) { const int num_units = item->units.size(); CHECK_GE(num_units, 1); - ResourceMgr* step_resource_manager = new ResourceMgr; + ScopedStepContainer* step_container = + new ScopedStepContainer(step_id, [this](const string& name) { + worker_env_->device_mgr->ClearContainers({name}); + }); // NOTE: Transfer one ref of rendezvous and item. ExecutorBarrier* barrier = new ExecutorBarrier( - num_units, rendezvous, [this, item, collector, cost_graph, - step_resource_manager, done](const Status& s) { + num_units, rendezvous, [this, item, collector, cost_graph, step_container, + done](const Status& s) { BuildCostModel(item, collector, cost_graph); done(s); - delete step_resource_manager; + delete step_container; }); Executor::Args args; { @@ -393,7 +396,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, Item* item, args.rendezvous = rendezvous; args.cancellation_manager = cancellation_manager; args.stats_collector = collector; - args.step_resource_manager = step_resource_manager; + args.step_container = step_container; args.sync_on_finish = true; if (LogMemory::IsEnabled()) { LogMemory::RecordStep(args.step_id, handle); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index a3771e6747..e9b8d415ed 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -140,7 +140,7 @@ class GraphMgr { // mechanism to gc these graphs. std::unordered_map<string, Item*> table_; - void StartParallelExecutors(const string& handle, Item* item, + void StartParallelExecutors(const string& handle, int64 step_id, Item* item, Rendezvous* rendezvous, StepStatsCollector* collector, CostGraphDef* cost_graph, diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 4f8eb04c95..06859c5290 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -66,8 +66,13 @@ struct AllocatorStats { // device memory. class Allocator { public: +#ifdef EIGEN_VECTORIZE_AVX512 + // Align to 64 byte boundary. + static constexpr size_t kAllocatorAlignment = 64; +#else // Align to 32 byte boundary. static constexpr size_t kAllocatorAlignment = 32; +#endif virtual ~Allocator(); diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index dcf0ae40d5..bc1441ac6e 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -739,6 +739,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, // Merge all the non-concat dims, and sum the concat dim to make an output // shape. const int32 concat_dim = concat_dim_t->scalar<int32>()(); + // Minimum required number of dimensions. const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1; @@ -749,7 +750,11 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before)); DimensionHandle output_middle = c->Dim(input, concat_dim); - TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after)); + if (concat_dim == -1) { + output_after = c->Scalar(); // no dimensions. + } else { + TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after)); + } for (int i = end_value_index - 2; i >= start_value_index; --i) { ShapeHandle before; @@ -758,7 +763,11 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before)); DimensionHandle middle = c->Dim(input, concat_dim); - TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after)); + if (concat_dim == -1) { + after = c->Scalar(); + } else { + TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after)); + } TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before)); TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle)); diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 5cb4e28faf..1fa3aee517 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -35,6 +35,7 @@ namespace tensorflow { class CancellationManager; class OpKernel; class ResourceMgr; +class ScopedStepContainer; // FunctionDefHelper::Create is a convenient helper to construct a // FunctionDef proto. @@ -381,8 +382,8 @@ class FunctionLibraryRuntime { // The id of the step that is calling this function. int64 step_id = 0; - // Per-step resource manager. Does not take ownership. - ResourceMgr* step_resource_manager = nullptr; + // Per-step container. + ScopedStepContainer* step_container; std::function<void(std::function<void()>)>* runner = nullptr; }; diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 50520bb3fd..c4023d2ced 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -222,7 +222,7 @@ OpKernelContext::~OpKernelContext() { Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { Allocator* allocator = - params_->device->GetStepAllocator(attr, step_resource_manager()); + params_->device->GetStepAllocator(attr, resource_manager()); if (params_->track_allocations) { mutex_lock lock(mu_); for (const auto& wrapped : wrapped_allocators_) { diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 4a66d43e50..7318a2dc7d 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -511,8 +511,9 @@ class OpKernelContext { // Shared resources accessible by this op kernel invocation. ResourceMgr* resource_manager = nullptr; - // Per-step resources accessible by this op kernel invocation. - ResourceMgr* step_resource_manager = nullptr; + // Per-step resources accessible by this op kernel invocation should be + // stored in this container.. + ScopedStepContainer* step_container = nullptr; // Mechanism used by this op kernel invocation to communicate with // computations running on other devices. @@ -938,9 +939,9 @@ class OpKernelContext { // not be called from Op kernels. void retrieve_accessed_tensors(TensorReferenceVector* out_vector); - // Per-step resource manager for use by white-listed internal ops. - ResourceMgr* step_resource_manager() const { - return params_->step_resource_manager; + // Per-step container for use by white-listed internal ops. + ScopedStepContainer* step_container() const { + return params_->step_container; } // Helper routines for the OP_REQUIRES macros diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index ae4186ee71..a1053669b7 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -79,6 +79,24 @@ class ResourceBase : public core::RefCounted { virtual string DebugString() = 0; }; +// Container used for per-step resources. +class ScopedStepContainer { + public: + // step_id: the unique ID of this step. Doesn't have to be sequential, just + // has to be unique. + // cleanup: callback to delete a container of this name. + ScopedStepContainer(const int64 step_id, + std::function<void(const string&)> cleanup) + : name_(strings::StrCat("__per_step_", step_id)), cleanup_(cleanup) {} + ~ScopedStepContainer() { cleanup_(name_); } + + const string& name() const { return name_; } + + private: + const string name_; + const std::function<void(const string&)> cleanup_; +}; + class ResourceMgr { public: ResourceMgr(); @@ -165,6 +183,9 @@ class ResourceMgr { template <typename T> ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container, const string& name); +template <typename T> +ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, + const string& name); // Returns a resource handle from a numbered op input. ResourceHandle HandleFromInput(OpKernelContext* ctx, int input); @@ -385,6 +406,12 @@ ResourceHandle MakeResourceHandle(OpKernelContext* ctx, const string& container, return result; } +template <typename T> +ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, + const string& name) { + return MakeResourceHandle<T>(ctx, ctx->step_container()->name(), name); +} + namespace internal { template <typename T> diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 8e9eceb699..a35f3ff15c 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -74,6 +74,28 @@ struct RecvInfo { typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq> DupRecvTable; +struct DupControlKey { + int dst_node_id; // Edge's dst node id + GraphDef* src_graph; // Edge's src node is in this subgraph +}; + +struct DupControlKeyHash { + size_t operator()(const DupControlKey& k) const { + return Hash64(reinterpret_cast<const char*>(&k.src_graph), + sizeof(k.src_graph), k.dst_node_id); + } +}; + +struct DupControlKeyEq { + bool operator()(const DupControlKey& x, const DupControlKey& y) const { + return (x.dst_node_id == y.dst_node_id) && (x.src_graph == y.src_graph); + } +}; + +typedef std::unordered_map<DupControlKey, NodeDef*, DupControlKeyHash, + DupControlKeyEq> + DupControlTable; + struct PairIntHash { public: std::size_t operator()(const std::pair<int, int>& x) const { @@ -825,6 +847,7 @@ Status Partition(const PartitionOptions& opts, Graph* g, string dstp; std::vector<const Edge*> inputs; DupRecvTable dup_recv(3); + DupControlTable dup_control(3); // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref // edge to dst. We will add a control edge for every pair in @@ -918,7 +941,9 @@ Status Partition(const PartitionOptions& opts, Graph* g, } // Check whether there is already a send/recv pair transferring - // the same tensor/control from the src to dst partition. + // the same tensor/control from src to the dst partition. This + // handles the dedup case when a single source in one partition + // going to multiple destinations in another partition. const bool on_host = IsDstInputOnHost(edge, g_info); DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host}; auto iter = dup_recv.find(key); @@ -943,6 +968,16 @@ Status Partition(const PartitionOptions& opts, Graph* g, NodeDefBuilder::NodeOut send_from; if (edge->IsControlEdge()) { + // This handles the dedup case when multiple control edges going from + // one partition to a single destination in another partition. + DupControlKey key{dst->id(), src_graph}; + auto iter = dup_control.find(key); + if (iter != dup_control.end()) { + // This could cause start_time(src) > start_time(iter->second). + AddInput(iter->second, src->name(), Graph::kControlSlot); + continue; + } + // Insert a dummy const node that will generate a tiny // data element to be sent from send to recv. VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "[" @@ -956,6 +991,7 @@ Status Partition(const PartitionOptions& opts, Graph* g, } AddInput(dummy, src->name(), Graph::kControlSlot); send_from.Reset(dummy->name(), 0, DT_FLOAT); + dup_control[key] = dummy; } else { send_from.Reset(src->name(), edge->src_output(), EdgeType(edge)); } diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index d8322e6077..fd259f0b40 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -398,5 +398,37 @@ TEST_F(GraphPartitionTest, PartitionIncompleteGraph) { EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status; } +TEST_F(GraphPartitionTest, CrossDevice_MultiControl) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + auto a1 = Input(in_.WithOpName("A1")); + auto a2 = Input(in_.WithOpName("A2")); + auto b1 = Input(in_.WithOpName("B1")); + Combine( + in_.WithOpName("B2").WithControlDependencies(a1).WithControlDependencies( + a2), + b1, b1); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(scope_a_.WithOpName("A1")); + a2 = Input(scope_a_.WithOpName("A2")); + auto c = Const(scope_a_.WithOpName("A1/_0") + .WithControlDependencies(a1) + .WithControlDependencies(a2), + {}); + _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b); + ExpectMatchA(); + + auto recv = + _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b); + auto id = Identity(scope_b_.WithOpName("A1/_3"), recv); + b1 = Input(scope_b_.WithOpName("B1")); + Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1); + ExpectMatchB(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index cd76a40a47..3e8da9884e 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -30,7 +30,6 @@ load( "tf_cc_tests", "tf_copts", "tf_opts_nortti_if_android", - "tf_kernel_libraries", "tf_kernel_library", "cc_header_only_library", ) @@ -396,7 +395,6 @@ ARRAY_DEPS = [ ":fill_functor", ":gather_functor", ":ops_util", - ":strided_slice_op", ":transpose_functor", "//tensorflow/core:array_grad", "//tensorflow/core:array_ops_op_lib", @@ -432,49 +430,185 @@ tf_kernel_library( ], ) -tf_kernel_libraries( +cc_library( name = "array", - libs = [ + deps = [ ":batch_space_ops", + ":bcast_ops", + ":bitcast_op", + ":concat_op", + ":constant_op", ":depth_space_ops", + ":diag_op", + ":edit_distance_op", ":extract_image_patches_op", + ":gather_nd_op", + ":gather_op", + ":identity_op", + ":listdiff_op", + ":matrix_band_part_op", + ":matrix_diag_op", + ":matrix_set_diag_op", + ":mirror_pad_op", + ":one_hot_op", + ":pack_op", + ":pad_op", + ":quantize_and_dequantize_op", + ":reshape_op", + ":reverse_op", + ":reverse_sequence_op", + ":shape_ops", + ":slice_op", ":split_op", ":split_v_op", + ":strided_slice_op", + ":tile_ops", + ":transpose_op", + ":unique_op", ":unpack_op", + ":where_op", ], - prefixes = [ - "bcast_ops", - "bitcast_op", - "concat_op", - "constant_op", - "diag_op", - "matrix_band_part_op", - "matrix_diag_op", - "matrix_set_diag_op", - "edit_distance_op", - "gather_op", - "gather_nd_op", - "identity_op", - "listdiff_op", - "mirror_pad_op", - "one_hot_op", - "pack_op", - "pad_op", - "quantize_and_dequantize_op", - "reshape_op", - "reverse_op", - "reverse_sequence_op", - "shape_ops", - "slice_op", - "tile_ops", - "transpose_op", - "unique_op", - "where_op", - ], +) + +tf_kernel_library( + name = "bcast_ops", + prefix = "bcast_ops", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "bitcast_op", + prefix = "bitcast_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "concat_op", + prefix = "concat_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "constant_op", + prefix = "constant_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "diag_op", + prefix = "diag_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "edit_distance_op", + prefix = "edit_distance_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "gather_nd_op", + prefix = "gather_nd_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "gather_op", + prefix = "gather_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "identity_op", + prefix = "identity_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "listdiff_op", + prefix = "listdiff_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "matrix_band_part_op", + prefix = "matrix_band_part_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "matrix_diag_op", + prefix = "matrix_diag_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "matrix_set_diag_op", + prefix = "matrix_set_diag_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "mirror_pad_op", + prefix = "mirror_pad_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "one_hot_op", + prefix = "one_hot_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "pack_op", + prefix = "pack_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "pad_op", + prefix = "pad_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "quantize_and_dequantize_op", + prefix = "quantize_and_dequantize_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "reshape_op", + prefix = "reshape_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "reverse_op", + prefix = "reverse_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "reverse_sequence_op", + prefix = "reverse_sequence_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "shape_ops", + prefix = "shape_ops", deps = ARRAY_DEPS, ) tf_kernel_library( + name = "slice_op", + prefix = "slice_op", + deps = ARRAY_DEPS + [":strided_slice_op"], +) + +tf_kernel_library( name = "split_op", gpu_srcs = ["cuda_device_array.h"], prefix = "split_op", @@ -489,11 +623,35 @@ tf_kernel_library( ) tf_kernel_library( + name = "tile_ops", + prefix = "tile_ops", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "transpose_op", + prefix = "transpose_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( + name = "unique_op", + prefix = "unique_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( name = "unpack_op", prefix = "unpack_op", deps = ARRAY_DEPS + [":split_lib"], ) +tf_kernel_library( + name = "where_op", + prefix = "where_op", + deps = ARRAY_DEPS, +) + tf_cc_test( name = "batch_norm_op_test", size = "small", @@ -918,84 +1076,167 @@ tf_cc_test( ], ) -tf_kernel_libraries( +cc_library( name = "data_flow", - libs = [ - ":dynamic", - ":lookup", - ], - prefixes = [ - "conditional_accumulator_base_op", - "conditional_accumulator_op", - "barrier_ops", - "fifo_queue_op", - "priority_queue_op", - "padding_fifo_queue_op", - "queue_ops", - "random_shuffle_queue_op", - "session_ops", - "sparse_conditional_accumulator_op", - "stack_ops", - "tensor_array_ops", - ], deps = [ - ":bounds_check", - ":concat_lib", - ":conditional_accumulator", - ":conditional_accumulator_base", - ":fifo_queue", - ":initializable_lookup_table", - ":lookup_util", - ":padding_fifo_queue", - ":priority_queue", - ":queue_base", - ":queue_op", - ":sparse_conditional_accumulator", - ":split_lib", - ":tensor_array", - ":typed_conditional_accumulator_base", - ":typed_queue", - "//tensorflow/core:core_cpu", - "//tensorflow/core:data_flow_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//third_party/eigen3", - ], + ":barrier_ops", + ":conditional_accumulator_base_op", + ":conditional_accumulator_op", + ":dynamic_partition_op", + ":dynamic_stitch_op", + ":fifo_queue_op", + ":lookup_table_init_op", + ":lookup_table_op", + ":padding_fifo_queue_op", + ":priority_queue_op", + ":queue_ops", + ":random_shuffle_queue_op", + ":session_ops", + ":sparse_conditional_accumulator_op", + ":stack_ops", + ":tensor_array_ops", + ], +) + +DATA_FLOW_DEPS = [ + ":bounds_check", + ":concat_lib", + ":conditional_accumulator", + ":conditional_accumulator_base", + ":fifo_queue", + ":initializable_lookup_table", + ":lookup_util", + ":padding_fifo_queue", + ":priority_queue", + ":queue_base", + ":queue_op", + ":sparse_conditional_accumulator", + ":split_lib", + ":tensor_array", + ":typed_conditional_accumulator_base", + ":typed_queue", + "//third_party/eigen3", + "//tensorflow/core:core_cpu", + "//tensorflow/core:data_flow_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", +] + +tf_kernel_library( + name = "conditional_accumulator_base_op", + prefix = "conditional_accumulator_base_op", + deps = DATA_FLOW_DEPS, ) -tf_kernel_libraries( - name = "dynamic", - prefixes = [ - "dynamic_partition_op", - "dynamic_stitch_op", - ], - deps = [ - ":bounds_check", - "//tensorflow/core:core_cpu", - "//tensorflow/core:data_flow_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], +tf_kernel_library( + name = "conditional_accumulator_op", + prefix = "conditional_accumulator_op", + deps = DATA_FLOW_DEPS, ) -tf_kernel_libraries( - name = "lookup", - prefixes = [ - "lookup_table_init_op", - "lookup_table_op", - ], - deps = [ - ":bounds_check", - ":initializable_lookup_table", - ":lookup_util", - "//tensorflow/core:core_cpu", - "//tensorflow/core:data_flow_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], +tf_kernel_library( + name = "barrier_ops", + prefix = "barrier_ops", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "fifo_queue_op", + prefix = "fifo_queue_op", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "padding_fifo_queue_op", + prefix = "padding_fifo_queue_op", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "priority_queue_op", + prefix = "priority_queue_op", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "queue_ops", + prefix = "queue_ops", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "random_shuffle_queue_op", + prefix = "random_shuffle_queue_op", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "session_ops", + prefix = "session_ops", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "sparse_conditional_accumulator_op", + prefix = "sparse_conditional_accumulator_op", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "stack_ops", + prefix = "stack_ops", + deps = DATA_FLOW_DEPS, +) + +tf_kernel_library( + name = "tensor_array_ops", + prefix = "tensor_array_ops", + deps = DATA_FLOW_DEPS, +) + +DYNAMIC_DEPS = [ + ":bounds_check", + "//tensorflow/core:core_cpu", + "//tensorflow/core:data_flow_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", +] + +tf_kernel_library( + name = "dynamic_partition_op", + prefix = "dynamic_partition_op", + deps = DYNAMIC_DEPS, +) + +tf_kernel_library( + name = "dynamic_stitch_op", + prefix = "dynamic_stitch_op", + deps = DYNAMIC_DEPS, +) + +LOOKUP_DEPS = [ + ":bounds_check", + ":initializable_lookup_table", + ":lookup_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:data_flow_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", +] + +tf_kernel_library( + name = "lookup_table_init_op", + prefix = "lookup_table_init_op", + deps = LOOKUP_DEPS, +) + +tf_kernel_library( + name = "lookup_table_op", + prefix = "lookup_table_op", + deps = LOOKUP_DEPS, ) tf_cc_tests( @@ -1136,41 +1377,150 @@ tf_kernel_library( ], ) -tf_kernel_libraries( +cc_library( name = "image", - prefixes = [ - "adjust_contrast_op", - "adjust_hue_op", - "colorspace_op", - "crop_and_resize_op", - "decode_jpeg_op", - "decode_png_op", - "decode_gif_op", - "draw_bounding_box_op", - "encode_jpeg_op", - "attention_ops", - "encode_png_op", - "non_max_suppression_op", - "random_crop_op", - "resize_area_op", - "resize_bicubic_op", - "resize_bilinear_op", - "resize_nearest_neighbor_op", - "sample_distorted_bounding_box_op", - ], deps = [ - ":bounds_check", - ":eigen_helpers", - ":image_resizer_state", - "//tensorflow/core:framework", - "//tensorflow/core:gif_internal", - "//tensorflow/core:image_ops_op_lib", - "//tensorflow/core:jpeg_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//third_party/eigen3", - ], + ":adjust_contrast_op", + ":adjust_hue_op", + ":attention_ops", + ":colorspace_op", + ":crop_and_resize_op", + ":decode_gif_op", + ":decode_jpeg_op", + ":decode_png_op", + ":draw_bounding_box_op", + ":encode_jpeg_op", + ":encode_png_op", + ":non_max_suppression_op", + ":random_crop_op", + ":resize_area_op", + ":resize_bicubic_op", + ":resize_bilinear_op", + ":resize_nearest_neighbor_op", + ":sample_distorted_bounding_box_op", + ], +) + +IMAGE_DEPS = [ + ":bounds_check", + ":eigen_helpers", + ":image_resizer_state", + "//third_party/eigen3", + "//tensorflow/core:framework", + "//tensorflow/core:gif_internal", + "//tensorflow/core:image_ops_op_lib", + "//tensorflow/core:jpeg_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", +] + +tf_kernel_library( + name = "adjust_contrast_op", + prefix = "adjust_contrast_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "adjust_hue_op", + prefix = "adjust_hue_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "attention_ops", + prefix = "attention_ops", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "colorspace_op", + prefix = "colorspace_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "crop_and_resize_op", + prefix = "crop_and_resize_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "decode_jpeg_op", + prefix = "decode_jpeg_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "decode_png_op", + prefix = "decode_png_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "decode_gif_op", + prefix = "decode_gif_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "draw_bounding_box_op", + prefix = "draw_bounding_box_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "encode_jpeg_op", + prefix = "encode_jpeg_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "encode_png_op", + prefix = "encode_png_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "non_max_suppression_op", + prefix = "non_max_suppression_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "random_crop_op", + prefix = "random_crop_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "resize_area_op", + prefix = "resize_area_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "resize_bicubic_op", + prefix = "resize_bicubic_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "resize_bilinear_op", + prefix = "resize_bilinear_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "resize_nearest_neighbor_op", + prefix = "resize_nearest_neighbor_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( + name = "sample_distorted_bounding_box_op", + prefix = "sample_distorted_bounding_box_op", + deps = IMAGE_DEPS, ) tf_cc_tests( @@ -1254,47 +1604,102 @@ tf_cuda_cc_test( ], ) -tf_kernel_libraries( +cc_library( name = "io", - libs = [":save_restore"], - prefixes = [ - "fixed_length_record_reader_op", - "identity_reader_op", - "matching_files_op", - "reader_ops", - "text_line_reader_op", - "tf_record_reader_op", - "whole_file_read_ops", - ], deps = [ - ":ops_util", - ":reader_base", - "//tensorflow/core:framework", - "//tensorflow/core:io_ops_op_lib", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/util/tensor_bundle", + ":fixed_length_record_reader_op", + ":identity_reader_op", + ":matching_files_op", + ":reader_ops", + ":restore_op", + ":save_op", + ":save_restore_v2_ops", + ":text_line_reader_op", + ":tf_record_reader_op", + ":whole_file_read_ops", ], ) -tf_kernel_libraries( - name = "save_restore", - prefixes = [ - "restore_op", - "save_op", - "save_restore_v2_ops", - ], - deps = [ - ":bounds_check_lib", - ":save_restore_tensor", - "//tensorflow/core:framework", - "//tensorflow/core:io_ops_op_lib", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/util/tensor_bundle", - ], +IO_DEPS = [ + ":ops_util", + ":reader_base", + "//tensorflow/core:framework", + "//tensorflow/core:io_ops_op_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/util/tensor_bundle", +] + +tf_kernel_library( + name = "fixed_length_record_reader_op", + prefix = "fixed_length_record_reader_op", + deps = IO_DEPS, +) + +tf_kernel_library( + name = "identity_reader_op", + prefix = "identity_reader_op", + deps = IO_DEPS, +) + +tf_kernel_library( + name = "matching_files_op", + prefix = "matching_files_op", + deps = IO_DEPS, +) + +tf_kernel_library( + name = "reader_ops", + prefix = "reader_ops", + deps = IO_DEPS, +) + +SAVE_RESTORE_DEPS = [ + ":bounds_check_lib", + ":save_restore_tensor", + "//tensorflow/core:framework", + "//tensorflow/core:io_ops_op_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/util/tensor_bundle", +] + +tf_kernel_library( + name = "restore_op", + prefix = "restore_op", + deps = SAVE_RESTORE_DEPS, +) + +tf_kernel_library( + name = "save_op", + prefix = "save_op", + deps = SAVE_RESTORE_DEPS, +) + +tf_kernel_library( + name = "save_restore_v2_ops", + prefix = "save_restore_v2_ops", + deps = SAVE_RESTORE_DEPS, +) + +tf_kernel_library( + name = "text_line_reader_op", + prefix = "text_line_reader_op", + deps = IO_DEPS, +) + +tf_kernel_library( + name = "tf_record_reader_op", + prefix = "tf_record_reader_op", + deps = IO_DEPS, +) + +tf_kernel_library( + name = "whole_file_read_ops", + prefix = "whole_file_read_ops", + deps = IO_DEPS, ) tf_cc_tests( @@ -1323,30 +1728,97 @@ tf_cc_tests( ], ) -tf_kernel_libraries( +cc_library( name = "linalg", - prefixes = [ - "cholesky_op", - "cholesky_grad", - "determinant_op", - "self_adjoint_eig_op", - "self_adjoint_eig_v2_op", - "matrix_inverse_op", - "matrix_solve_ls_op", - "matrix_solve_op", - "matrix_triangular_solve_op", - "qr_op", - "svd_op", - ], deps = [ - ":linalg_ops_common", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:linalg_ops_op_lib", - "//third_party/eigen3", + ":cholesky_grad", + ":cholesky_op", + ":determinant_op", + ":matrix_inverse_op", + ":matrix_solve_ls_op", + ":matrix_solve_op", + ":matrix_triangular_solve_op", + ":qr_op", + ":self_adjoint_eig_op", + ":self_adjoint_eig_v2_op", + ":svd_op", ], ) +LINALG_DEPS = [ + ":linalg_ops_common", + "//third_party/eigen3", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:linalg_ops_op_lib", +] + +tf_kernel_library( + name = "cholesky_op", + prefix = "cholesky_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "cholesky_grad", + prefix = "cholesky_grad", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "determinant_op", + prefix = "determinant_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "self_adjoint_eig_op", + prefix = "self_adjoint_eig_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "self_adjoint_eig_v2_op", + prefix = "self_adjoint_eig_v2_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "matrix_inverse_op", + prefix = "matrix_inverse_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "matrix_solve_ls_op", + prefix = "matrix_solve_ls_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "matrix_solve_op", + prefix = "matrix_solve_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "matrix_triangular_solve_op", + prefix = "matrix_triangular_solve_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "qr_op", + prefix = "qr_op", + deps = LINALG_DEPS, +) + +tf_kernel_library( + name = "svd_op", + prefix = "svd_op", + deps = LINALG_DEPS, +) + cc_library( name = "linalg_ops_common", srcs = ["linalg_ops_common.cc"], @@ -1359,24 +1831,55 @@ cc_library( ], ) -tf_kernel_libraries( +cc_library( name = "logging", - prefixes = [ - "logging_ops", - "summary_audio_op", - "summary_image_op", - "summary_op", - "summary_tensor_op", - ], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:logging_ops_op_lib", - "//tensorflow/core:protos_all_cc", + ":logging_ops", + ":summary_audio_op", + ":summary_image_op", + ":summary_op", + ":summary_tensor_op", ], ) +LOGGING_DEPS = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:logging_ops_op_lib", + "//tensorflow/core:protos_all_cc", +] + +tf_kernel_library( + name = "logging_ops", + prefix = "logging_ops", + deps = LOGGING_DEPS, +) + +tf_kernel_library( + name = "summary_audio_op", + prefix = "summary_audio_op", + deps = LOGGING_DEPS, +) + +tf_kernel_library( + name = "summary_image_op", + prefix = "summary_image_op", + deps = LOGGING_DEPS, +) + +tf_kernel_library( + name = "summary_op", + prefix = "summary_op", + deps = LOGGING_DEPS, +) + +tf_kernel_library( + name = "summary_tensor_op", + prefix = "summary_tensor_op", + deps = LOGGING_DEPS, +) + tf_cc_tests( size = "small", srcs = [ @@ -1411,32 +1914,120 @@ MATH_DEPS = [ "//third_party/eigen3", ] -tf_kernel_libraries( +cc_library( name = "math_not_windows", - prefixes = [ - "sparse_matmul_op", + deps = [ + ":sparse_matmul_op", ], +) + +tf_kernel_library( + name = "sparse_matmul_op", + prefix = "sparse_matmul_op", deps = MATH_DEPS, ) -tf_kernel_libraries( +cc_library( name = "math", - prefixes = [ - "aggregate_ops", - "argmax_op", - "batch_matmul_op", - "betainc_op", - "cast_op", - "check_numerics_op", - "cross_op", - "cwise_op", - "fft_ops", - "matmul_op", - "reduction_ops", - "segment_reduction_ops", - "scan_ops", - "sequence_ops", + deps = [ + ":aggregate_ops", + ":argmax_op", + ":batch_matmul_op", + ":betainc_op", + ":cast_op", + ":check_numerics_op", + ":cross_op", + ":cwise_op", + ":fft_ops", + ":matmul_op", + ":reduction_ops", + ":scan_ops", + ":segment_reduction_ops", + ":sequence_ops", ], +) + +tf_kernel_library( + name = "aggregate_ops", + prefix = "aggregate_ops", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "argmax_op", + prefix = "argmax_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "batch_matmul_op", + prefix = "batch_matmul_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "betainc_op", + prefix = "betainc_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "cast_op", + prefix = "cast_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "check_numerics_op", + prefix = "check_numerics_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "cross_op", + prefix = "cross_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "cwise_op", + prefix = "cwise_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "fft_ops", + prefix = "fft_ops", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "matmul_op", + prefix = "matmul_op", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "reduction_ops", + prefix = "reduction_ops", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "segment_reduction_ops", + prefix = "segment_reduction_ops", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "scan_ops", + prefix = "scan_ops", + deps = MATH_DEPS, +) + +tf_kernel_library( + name = "sequence_ops", + prefix = "sequence_ops", deps = MATH_DEPS, ) @@ -1696,42 +2287,104 @@ tf_kernel_library( ], ) -tf_kernel_libraries( +cc_library( name = "nn", - libs = [ - ":l2loss_op", - ], - prefixes = [ - "batch_norm_op", - "bias_op", - "fused_batch_norm_op", - "in_topk_op", - "lrn_op", - "relu_op", - "softmax_op", - "softplus_op", - "softsign_op", - "topk_op", - "xent_op", - ], deps = [ - ":bounds_check", - ":conv_2d", + ":batch_norm_op", + ":bias_op", ":conv_ops", ":dilation_ops", - ":fused_batch_norm_util_gpu", - ":ops_util", - ":pooling_ops", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:nn_grad", - "//tensorflow/core:nn_ops_op_lib", - "//third_party/eigen3", - ] + if_not_windows([ - ":depthwise_conv_grad_op", - ":depthwise_conv_op", - ]), + ":fused_batch_norm_op", + ":in_topk_op", + ":l2loss_op", + ":lrn_op", + ":relu_op", + ":softmax_op", + ":softplus_op", + ":softsign_op", + ":topk_op", + ":xent_op", + ] + if_not_windows([":depthwise_conv_op"]), +) + +NN_DEPS = [ + ":bounds_check", + ":conv_2d", + ":fused_batch_norm_util_gpu", + ":ops_util", + ":pooling_ops", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:nn_grad", + "//tensorflow/core:nn_ops_op_lib", + "//third_party/eigen3", +] + +tf_kernel_library( + name = "batch_norm_op", + prefix = "batch_norm_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "bias_op", + prefix = "bias_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "fused_batch_norm_op", + prefix = "fused_batch_norm_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "in_topk_op", + prefix = "in_topk_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "lrn_op", + prefix = "lrn_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "relu_op", + prefix = "relu_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "softmax_op", + prefix = "softmax_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "softplus_op", + prefix = "softplus_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "softsign_op", + prefix = "softsign_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "topk_op", + prefix = "topk_op", + deps = NN_DEPS, +) + +tf_kernel_library( + name = "xent_op", + prefix = "xent_op", + deps = NN_DEPS, ) tf_kernel_library( @@ -1965,39 +2618,83 @@ tf_kernel_library( ], ) -tf_kernel_libraries( +cc_library( name = "parsing", - prefixes = [ - "decode_csv_op", - "decode_raw_op", - "example_parsing_ops", - "parse_tensor_op", - "string_to_number_op", - ], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:parsing_ops_op_lib", - "//tensorflow/core:proto_text", - "//tensorflow/core:protos_all_cc", + ":decode_csv_op", + ":decode_raw_op", + ":example_parsing_ops", + ":parse_tensor_op", + ":string_to_number_op", ], ) -tf_kernel_libraries( +PARSING_DEPS = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:parsing_ops_op_lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", +] + +tf_kernel_library( + name = "decode_csv_op", + prefix = "decode_csv_op", + deps = PARSING_DEPS, +) + +tf_kernel_library( + name = "decode_raw_op", + prefix = "decode_raw_op", + deps = PARSING_DEPS, +) + +tf_kernel_library( + name = "example_parsing_ops", + prefix = "example_parsing_ops", + deps = PARSING_DEPS, +) + +tf_kernel_library( + name = "parse_tensor_op", + prefix = "parse_tensor_op", + deps = PARSING_DEPS, +) + +tf_kernel_library( + name = "string_to_number_op", + prefix = "string_to_number_op", + deps = PARSING_DEPS, +) + +cc_library( name = "random_ops", - prefixes = [ - "random_op", - "random_shuffle_op", - ], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:random_ops_op_lib", + ":random_op", + ":random_shuffle_op", ], ) +RANDOM_OPS_DEPS = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:random_ops_op_lib", +] + +tf_kernel_library( + name = "random_op", + prefix = "random_op", + deps = RANDOM_OPS_DEPS, +) + +tf_kernel_library( + name = "random_shuffle_op", + prefix = "random_shuffle_op", + deps = RANDOM_OPS_DEPS, +) + tf_cuda_cc_test( name = "random_op_test", size = "small", @@ -2013,52 +2710,162 @@ tf_cuda_cc_test( ], ) -tf_kernel_libraries( +cc_library( name = "required", - prefixes = [ - "no_op", - "sendrecv_ops", - ], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:no_op_op_lib", - "//tensorflow/core:sendrecv_ops_op_lib", + ":no_op", + ":sendrecv_ops", ], ) -tf_kernel_libraries( +REQUIRED_DEPS = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:no_op_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", +] + +tf_kernel_library( + name = "no_op", + prefix = "no_op", + deps = REQUIRED_DEPS, +) + +tf_kernel_library( + name = "sendrecv_ops", + prefix = "sendrecv_ops", + deps = REQUIRED_DEPS, +) + +cc_library( name = "sparse", - prefixes = [ - "sparse_add_grad_op", - "sparse_add_op", - "sparse_concat_op", - "sparse_reduce_sum_op", - "sparse_dense_binary_op_shared", - "sparse_sparse_binary_op_shared", - "sparse_reorder_op", - "sparse_reshape_op", - "sparse_softmax", - "sparse_split_op", - "sparse_tensor_dense_add_op", - "sparse_tensor_dense_matmul_op", - "sparse_to_dense_op", - "sparse_xent_op", - "serialize_sparse_op", - "sparse_tensors_map_ops", - ], deps = [ - ":bounds_check", - ":cwise_op", - ":fill_functor", - ":scatter_functor", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:sparse_ops_op_lib", - "//third_party/eigen3", + ":serialize_sparse_op", + ":sparse_add_grad_op", + ":sparse_add_op", + ":sparse_concat_op", + ":sparse_dense_binary_op_shared", + ":sparse_reduce_sum_op", + ":sparse_reorder_op", + ":sparse_reshape_op", + ":sparse_softmax", + ":sparse_sparse_binary_op_shared", + ":sparse_split_op", + ":sparse_tensor_dense_add_op", + ":sparse_tensor_dense_matmul_op", + ":sparse_tensors_map_ops", + ":sparse_to_dense_op", + ":sparse_xent_op", ], ) +SPARSE_DEPS = [ + ":bounds_check", + ":cwise_op", + ":fill_functor", + ":scatter_functor", + "//third_party/eigen3", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:sparse_ops_op_lib", +] + +tf_kernel_library( + name = "sparse_add_grad_op", + prefix = "sparse_add_grad_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_add_op", + prefix = "sparse_add_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_concat_op", + prefix = "sparse_concat_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_reduce_sum_op", + prefix = "sparse_reduce_sum_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_dense_binary_op_shared", + prefix = "sparse_dense_binary_op_shared", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_sparse_binary_op_shared", + prefix = "sparse_sparse_binary_op_shared", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_reorder_op", + prefix = "sparse_reorder_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_reshape_op", + prefix = "sparse_reshape_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_softmax", + prefix = "sparse_softmax", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_split_op", + prefix = "sparse_split_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_tensor_dense_add_op", + prefix = "sparse_tensor_dense_add_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_tensor_dense_matmul_op", + prefix = "sparse_tensor_dense_matmul_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_to_dense_op", + prefix = "sparse_to_dense_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_xent_op", + prefix = "sparse_xent_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "serialize_sparse_op", + prefix = "serialize_sparse_op", + deps = SPARSE_DEPS, +) + +tf_kernel_library( + name = "sparse_tensors_map_ops", + prefix = "sparse_tensors_map_ops", + deps = SPARSE_DEPS, +) + tf_cuda_cc_tests( size = "small", srcs = [ @@ -2151,27 +2958,58 @@ cc_library( ], ) -tf_kernel_libraries( +cc_library( name = "state", - prefixes = [ - "count_up_to_op", - "dense_update_ops", - "scatter_op", - "scatter_nd_op", - "variable_ops", - ], deps = [ - ":assign_op", - ":bounds_check", - ":fill_functor", - ":scatter_functor", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:state_ops_op_lib", - "//third_party/eigen3", + ":count_up_to_op", + ":dense_update_ops", + ":scatter_nd_op", + ":scatter_op", + ":variable_ops", ], ) +STATE_DEPS = [ + ":assign_op", + ":bounds_check", + ":fill_functor", + ":scatter_functor", + "//third_party/eigen3", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:state_ops_op_lib", +] + +tf_kernel_library( + name = "count_up_to_op", + prefix = "count_up_to_op", + deps = STATE_DEPS, +) + +tf_kernel_library( + name = "dense_update_ops", + prefix = "dense_update_ops", + deps = STATE_DEPS, +) + +tf_kernel_library( + name = "scatter_op", + prefix = "scatter_op", + deps = STATE_DEPS, +) + +tf_kernel_library( + name = "scatter_nd_op", + prefix = "scatter_nd_op", + deps = STATE_DEPS, +) + +tf_kernel_library( + name = "variable_ops", + prefix = "variable_ops", + deps = STATE_DEPS, +) + tf_cc_test( name = "scatter_op_test", size = "small", @@ -2208,27 +3046,70 @@ tf_cc_test( ], ) -tf_kernel_libraries( +cc_library( name = "string", - prefixes = [ - "string_to_hash_bucket_op", - "reduce_join_op", - "string_join_op", - "string_split_op", - "substr_op", - "as_string_op", - "base64_ops", - ], deps = [ - ":bounds_check", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:string_ops_op_lib", - "//third_party/eigen3", + ":as_string_op", + ":base64_ops", + ":reduce_join_op", + ":string_join_op", + ":string_split_op", + ":string_to_hash_bucket_op", + ":substr_op", ], ) +STRING_DEPS = [ + ":bounds_check", + "//third_party/eigen3", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:string_ops_op_lib", +] + +tf_kernel_library( + name = "string_to_hash_bucket_op", + prefix = "string_to_hash_bucket_op", + deps = STRING_DEPS, +) + +tf_kernel_library( + name = "reduce_join_op", + prefix = "reduce_join_op", + deps = STRING_DEPS, +) + +tf_kernel_library( + name = "string_join_op", + prefix = "string_join_op", + deps = STRING_DEPS, +) + +tf_kernel_library( + name = "string_split_op", + prefix = "string_split_op", + deps = STRING_DEPS, +) + +tf_kernel_library( + name = "substr_op", + prefix = "substr_op", + deps = STRING_DEPS, +) + +tf_kernel_library( + name = "as_string_op", + prefix = "as_string_op", + deps = STRING_DEPS, +) + +tf_kernel_library( + name = "base64_ops", + prefix = "base64_ops", + deps = STRING_DEPS, +) + tf_kernel_library( name = "training_ops", prefix = "training_ops", @@ -2398,6 +3279,10 @@ filegroup( "matmul_op.h", "no_op.cc", "no_op.h", + "non_max_suppression_op.cc", + "non_max_suppression_op.h", + "one_hot_op.cc", + "one_hot_op.h", "ops_util.h", "pack_op.cc", "pooling_ops_common.h", diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 7ab37a8abd..2d1b21d9e4 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -507,7 +507,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch( transformed_output.template flat<T>().size()); static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( - "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default + // default value is in bytes despite the name of the environment variable + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB ); int device_id = stream->parent()->device_ordinal(); diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h index 8eea1b0f9d..e13c8b9835 100644 --- a/tensorflow/core/kernels/eigen_pooling.h +++ b/tensorflow/core/kernels/eigen_pooling.h @@ -329,7 +329,11 @@ struct AvgPoolMeanReducer { } #if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) -#ifdef EIGEN_VECTORIZE_AVX +#ifdef EIGEN_VECTORIZE_AVX512 +#define pequal(a, b) \ + _mm512_maskz_set1_epi32(_mm512_cmp_ps_mask(a, b, _CMP_EQ_UQ), -1) +#define psel(a, b, false_mask) _mm512_ternarylogic_epi64(false_mask, a, b, 0xca) +#elif defined EIGEN_VECTORIZE_AVX #define pequal(a, b) _mm256_cmp_ps(a, b, _CMP_EQ_UQ) #define psel(a, b, false_mask) _mm256_blendv_ps(a, b, false_mask) #else diff --git a/tensorflow/core/kernels/hexagon/BUILD b/tensorflow/core/kernels/hexagon/BUILD index f6af111493..444180f986 100644 --- a/tensorflow/core/kernels/hexagon/BUILD +++ b/tensorflow/core/kernels/hexagon/BUILD @@ -59,6 +59,10 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:quantized_ops", + "//tensorflow/core/kernels:reduction_ops", + "//tensorflow/core/kernels:reshape_op", + "//tensorflow/core/kernels:softmax_op", ], ) diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index 422be39d54..da13d64052 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/tensor_slice_writer.h" @@ -31,6 +32,9 @@ namespace tensorflow { constexpr bool DBG_DUMP_VERIFICATION_STRING = false; constexpr bool DBG_DUMP_PARAMS = false; +const string RESHAPE_NODE_TYPE_STRING = "Reshape"; +const string SOURCE_NODE_NAME = "_SOURCE"; +const string SINK_NODE_NAME = "_SINK"; const string INPUTS_NODE_PREFIX = "inputs_for_"; const string OUTPUTS_NODE_PREFIX = "outputs_for_"; const string DATA_NODE_PREFIX = "data_for_op_"; @@ -83,9 +87,13 @@ Status GraphTransferer::LoadGraphFromProto( } for (const Node* const node : graph.nodes()) { - RegisterNodeIfAllInputsAreCached(ops_definitions, shape_refiner, *node, - false, input_node_info_list, - output_node_names, output_tensor_map); + status = RegisterNodeIfAllInputsAreCached( + ops_definitions, shape_refiner, *node, false, input_node_info_list, + output_node_names, output_tensor_map); + if (!status.ok()) { + LOG(ERROR) << "Failed to transfer graph " << status; + return status; + } } ClearCache(); if (DBG_DUMP_PARAMS) { @@ -101,11 +109,13 @@ Status GraphTransferer::LoadGraphFromProtoFile( const IGraphTransferOpsDefinitions& ops_definitions, const string& graph_def_path, const std::vector<InputNodeInfo>& input_node_info_list, - const std::vector<string>& output_node_names, - const OutputTensorMap& output_tensor_map, const bool is_text_proto) { + const std::vector<string>& output_node_names, const bool is_text_proto, + const bool dry_run_for_unknown_shape, + OutputTensorInfo* output_tensor_info) { GraphDef graph_def; string output; Status status; + VLOG(1) << "Parse file " << graph_def_path; if (is_text_proto) { status = ReadFileToString(Env::Default(), graph_def_path, &output); if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) { @@ -115,30 +125,21 @@ Status GraphTransferer::LoadGraphFromProtoFile( status = ReadBinaryProto(Env::Default(), graph_def_path, &graph_def); } if (!status.ok()) { + VLOG(1) << "Failed to load graph " << status; return status; } - return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list, - output_node_names, output_tensor_map); -} - -Status GraphTransferer::LoadGraphFromProtoFile( - const IGraphTransferOpsDefinitions& ops_definitions, - const string& graph_def_path, - const std::vector<InputNodeInfo>& input_node_info_list, - const std::vector<string>& output_node_names, - const OutputTensorMap& output_tensor_map) { - GraphDef graph_def; - string output; - Status status = ReadFileToString(Env::Default(), graph_def_path, &output); - if (!status.ok()) { - return status; - } - if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) { - return errors::InvalidArgument("Cannot parse proto string."); + if (dry_run_for_unknown_shape) { + VLOG(1) << "Dry run graph to obtain shape of nodes"; + status = DryRunInferenceForAllNode(graph_def, input_node_info_list, true, + output_tensor_info); + if (!status.ok()) { + return status; + } } - LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list, - output_node_names, output_tensor_map); - return Status(); + VLOG(1) << "Load graph with output tensors"; + return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list, + output_node_names, + output_tensor_info->output_tensor_map); } /** @@ -172,17 +173,17 @@ Status GraphTransferer::LoadGraphFromProtoFile( switch (data_type) { case DT_INT32: { auto int_tensor = input_tensor.flat<int32>(); - int_tensor = int_tensor.constant(0.0); + int_tensor = int_tensor.constant(0); break; } case DT_FLOAT: { auto float_tensor = input_tensor.flat<float>(); - float_tensor = float_tensor.constant(0.0); + float_tensor = float_tensor.constant(0.0f); break; } case DT_QUINT8: { auto int_tensor = input_tensor.flat<quint8>(); - int_tensor = int_tensor.constant(0.0); + int_tensor = int_tensor.constant(0); break; } default: @@ -234,7 +235,12 @@ Status GraphTransferer::LoadGraphFromProtoFile( const Status status = DryRunInference(graph_def, input_node_info_list, output_node_names, initialize_by_zero, &output_tensors); - CHECK(output_node_names.size() == output_tensors.size()); + if (!status.ok()) { + VLOG(1) << "Failed to dryrun " << status; + return status; + } + CHECK(output_node_names.size() == output_tensors.size()) + << output_node_names.size() << ", " << output_tensors.size(); // Append output tensor of input node in advance to create a map // to avoid memory reallocation inside vector @@ -257,6 +263,10 @@ Status GraphTransferer::LoadGraphFromProtoFile( return status; } +void GraphTransferer::EnableStrictCheckMode(const bool enable) { + strict_check_mode_ = enable; +} + const std::vector<GraphTransferer::ConstNodeTransferParams>& GraphTransferer::GetConstNodeParams() const { return const_node_transfer_params_list_; @@ -314,13 +324,16 @@ bool GraphTransferer::AreAllInputsCached(const Node& node) const { return true; } -void GraphTransferer::RegisterNode( +Status GraphTransferer::RegisterNode( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, const Node& node, const std::vector<InputNodeInfo>& input_node_info_list, const std::vector<string>& output_node_names) { VLOG(1) << "Register node: " << node.name(); - if (IsInputNode(input_node_info_list, node.name())) { + if (node.name() == SOURCE_NODE_NAME || node.name() == SINK_NODE_NAME) { + // Just ignore sink and source + return Status(); + } else if (IsInputNode(input_node_info_list, node.name())) { RegisterInputNode(ops_definitions, shape_refiner, output_tensor_map, node); } else if (std::find(output_node_names.begin(), output_node_names.end(), node.name()) != output_node_names.end()) { @@ -330,10 +343,18 @@ void GraphTransferer::RegisterNode( } else if (HasPaddingAndStrides(node)) { RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, output_tensor_map, node); + } else if (IsNodeFlattenReshape(node, output_tensor_map, shape_refiner)) { + RegisterFlattenNode(ops_definitions, shape_refiner, output_tensor_map, + node); + } else if (ops_definitions.GetOpIdFor(node.type_string()) != + IGraphTransferOpsDefinitions::INVALID_OP_ID) { + RegisterGenericNode(ops_definitions, shape_refiner, output_tensor_map, + node); } else { - // TODO(satok): register params for nodes which are supported by SOC - VLOG(1) << "Not implemented for " << node.type_string(); + return errors::InvalidArgument(node.type_string() + + " has not implemented yet."); } + return Status(); } void GraphTransferer::RegisterConstantNode( @@ -348,8 +369,8 @@ void GraphTransferer::RegisterConstantNode( // TODO(satok): support multiple outputs? const int output_index = 0; const DataType dt = node.output_type(output_index); - const size_t max_bytes_per_data = - checkpoint::TensorSliceWriter::MaxBytesPerElement(dt); + const size_t max_bytes_per_data = DataTypeSize(dt); + CHECK(max_bytes_per_data > 0); shape_inference::InferenceContext* context = shape_refiner.GetContext(&node); shape_inference::ShapeHandle shape_handle = context->output(output_index); const shape_inference::DimensionHandle num_elements_dim = @@ -402,6 +423,45 @@ bool GraphTransferer::HasPaddingAndStrides(const Node& node) { node.def().attr().count(STRIDES_ATTR_NAME) > 0; } +bool GraphTransferer::IsNodeFlattenReshape( + const Node& node, const OutputTensorMap& output_tensor_map, + const ShapeRefiner& shape_refiner) { + // Check if node is reshape op + if (node.type_string() != RESHAPE_NODE_TYPE_STRING) { + return false; + } + + shape_inference::InferenceContext* context = shape_refiner.GetContext(&node); + // Check if output count is valid + if (context->num_outputs() != 1) { + return false; + } + + shape_inference::ShapeHandle shape_handle = context->output(0); + std::array<int64, SHAPE_ARRAY_SIZE> shape; + const shape_inference::DimensionHandle dim_handle = + context->NumElements(shape_handle); + + // Obtain shape of output of node + if (context->ValueKnown(dim_handle)) { + shape = BuildShapeArray(shape_handle, context); + } else { + // Use output tensor for unknown shape + // TODO(stok): Remove this fallback + CHECK(!output_tensor_map.empty()); + const TensorShape& tensor_shape = + output_tensor_map.at(node.name())->shape(); + shape = ToTensorShapeArray(tensor_shape); + } + + // check if reshape op just does flatten + if (shape[0] == 1 && shape[1] == 1 && shape[2] == 1) { + return true; + } else { + return false; + } +} + void GraphTransferer::RegisterNodeWithPaddingAndStrides( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, @@ -428,7 +488,8 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides( padding == VALID ? PADDING_VALID_STR : PADDING_SAME_STR; const int op_type_id = ops_definitions.GetOpIdFor(node.type_string()); CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()) - << node.type_string(); + << "Op " << node.type_string() << " not found in map(id = " << op_type_id + << ")"; AppendNodeParamsWithIoParams(shape_refiner, output_tensor_map, node, node.name(), id, node.type_string(), op_type_id, padding_str, node.num_inputs(), extra_inputs, @@ -448,9 +509,8 @@ void GraphTransferer::RegisterInputNode( CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); AppendNodeParamsWithIoParams( shape_refiner, output_tensor_map, node, node.name(), id, - IGraphTransferOpsDefinitions::INPUT_OP_NAME, op_type_id, PADDING_NA, - node.num_inputs(), {}, node.num_outputs(), true /* append_input */, - true /* append_output */); + node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {}, + node.num_outputs(), true /* append_input */, true /* append_output */); } void GraphTransferer::RegisterOutputNode( @@ -465,14 +525,47 @@ void GraphTransferer::RegisterOutputNode( CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); // TODO(satok): Set output for output node? AppendNodeParamsWithIoParams( - shape_refiner, output_tensor_map, node, node.name(), id, op_type, - op_type_id, PADDING_NA, node.num_inputs(), {}, 0 /* outputs_size */, - true /* append_input */, false /* append_output */); + shape_refiner, output_tensor_map, node, node.name(), id, + node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {}, + 0 /* outputs_size */, true /* append_input */, false /* append_output */); +} + +void GraphTransferer::RegisterFlattenNode( + const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, + const Node& node) { + VLOG(1) << "Register flatten node: " << node.name(); + CHECK(node_name_to_id_cache_map_.count(node.name()) == 1); + const int id = node_name_to_id_cache_map_[node.name()]; + const string op_type = IGraphTransferOpsDefinitions::FLATTEN_OP_NAME; + const int op_type_id = ops_definitions.GetOpIdFor(op_type); + CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); + + AppendNodeParamsWithIoParams( + shape_refiner, output_tensor_map, node, node.name(), id, + node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {}, + node.num_outputs(), true /* append_input */, true /* append_output */); +} + +void GraphTransferer::RegisterGenericNode( + const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, + const Node& node) { + VLOG(1) << "Register generic node: " << node.name(); + CHECK(node_name_to_id_cache_map_.count(node.name()) == 1); + const int id = node_name_to_id_cache_map_[node.name()]; + const int op_type_id = ops_definitions.GetOpIdFor(node.type_string()); + CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); + + AppendNodeParamsWithIoParams( + shape_refiner, output_tensor_map, node, node.name(), id, + node.type_string(), op_type_id, PADDING_NA, node.num_inputs(), {}, + node.num_outputs(), true /* append_input */, true /* append_output */); } // TODO(satok): Remove this function. // TODO(satok): Remove only_register_const_node. -bool GraphTransferer::RegisterNodeIfAllInputsAreCached( +Status GraphTransferer::RegisterNodeIfAllInputsAreCached( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const Node& node, const bool only_register_const_node, @@ -480,12 +573,11 @@ bool GraphTransferer::RegisterNodeIfAllInputsAreCached( const std::vector<string>& output_node_names, const OutputTensorMap& output_tensor_map) { if (only_register_const_node && !node.IsConstant()) { - return false; + return Status(); } CHECK(AreAllInputsCached(node)); - RegisterNode(ops_definitions, shape_refiner, output_tensor_map, node, - input_node_info_list, output_node_names); - return true; + return RegisterNode(ops_definitions, shape_refiner, output_tensor_map, node, + input_node_info_list, output_node_names); } // CAVEAT: Append inputs and outputs params accordingly @@ -542,7 +634,7 @@ void GraphTransferer::AppendNodeOutputParams( output_node = output_edge->src(); } } - CHECK(output_node != nullptr); + CHECK(output_node != nullptr) << node.name() << ", " << node.type_string(); const int output_index = i; const DataType dt = node.output_type(output_index); const size_t max_bytes_per_data = @@ -556,11 +648,14 @@ void GraphTransferer::AppendNodeOutputParams( if (context->ValueKnown(num_elements_dim)) { const int64 num_output_elements = context->Value(num_elements_dim); data_size = max_bytes_per_data * num_output_elements; - if (!output_tensor_map.empty()) { + if (!output_tensor_map.empty() && strict_check_mode_) { CHECK(output_tensor_map.count(node.name()) == 1) << node.name(); const TensorShape& tensor_shape = output_tensor_map.at(node.name())->shape(); - CHECK(num_output_elements == tensor_shape.num_elements()); + CHECK(num_output_elements == tensor_shape.num_elements()) + << "num elements of node " << node.name() << " doesn't match " + << num_output_elements << " vs " << tensor_shape.num_elements() + << ", " << node.type_string(); } } else { // Use dryrun result to get the output data size @@ -718,7 +813,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const { for (const ConstNodeTransferParams& params : const_node_transfer_params_list_) { std::stringstream sstream; - sstream << "---(CONST) [" << std::hex << params.node_id << "," + sstream << "---(CONST) [" << std::hex << params.node_id << std::dec << "," << params.shape[0] << "," << params.shape[1] << "," << params.shape[2] << "," << params.shape[3] << "," << params.data_name << "," << params.data_size << "," << params.name @@ -729,7 +824,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const { for (const NodeTransferParams& params : node_transfer_params_list_) { std::stringstream sstream; sstream << "---(OP) [" << params.name.c_str() << "," << std::hex - << params.node_id << "," << params.soc_op_id << "," + << params.node_id << std::dec << "," << params.soc_op_id << "," << params.padding << "," << params.inputs_name << "," << params.inputs_size << "," << params.outputs_name << "," << params.outputs_size << "," << params.type << "]"; @@ -738,7 +833,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const { LOG(INFO) << "Op node count = " << node_transfer_params_list_.size(); for (const NodeInputParams& params : node_input_params_list_) { std::stringstream sstream; - sstream << "---(INPUT) [" << std::hex << params.node_id; + sstream << "---(INPUT) [" << std::hex << params.node_id << std::dec; for (const std::tuple<int, int>& pair : params.input_node_id_and_output_port_list) { sstream << "," << std::get<0>(pair) << "," << std::get<1>(pair); @@ -749,7 +844,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const { LOG(INFO) << "Input params count = " << node_input_params_list_.size(); for (const NodeOutputParams& params : node_output_params_list_) { std::stringstream sstream; - sstream << "---(OUTPUT) [" << std::hex << params.node_id; + sstream << "---(OUTPUT) [" << std::hex << params.node_id << std::dec; for (const int max_size : params.max_sizes) { sstream << "," << max_size; } diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index 666c5889ad..71bd1d3375 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -100,16 +100,9 @@ class GraphTransferer { const IGraphTransferOpsDefinitions& ops_definitions, const string& graph_def_path, const std::vector<InputNodeInfo>& input_node_info_list, - const std::vector<string>& output_node_names, - const OutputTensorMap& output_tensor_map, const bool is_text_proto); - - // Load graph structure into GraphTransferer from protobuf file - Status LoadGraphFromProtoFile( - const IGraphTransferOpsDefinitions& ops_definitions, - const string& graph_def_path, - const std::vector<InputNodeInfo>& input_node_info_list, - const std::vector<string>& output_node_names, - const OutputTensorMap& output_tensor_map); + const std::vector<string>& output_node_names, const bool is_text_proto, + const bool dry_run_for_unknown_shape, + OutputTensorInfo* output_tensor_info); // Dry run inference and cache the result to get memory mapping static Status DryRunInference( @@ -128,6 +121,8 @@ class GraphTransferer { const std::vector<InputNodeInfo>& input_node_info_list, const bool initialize_by_zero, OutputTensorInfo* output_tensor_info); + void EnableStrictCheckMode(bool enable); + // Return const node parameters for transfer const std::vector<ConstNodeTransferParams>& GetConstNodeParams() const; @@ -142,51 +137,84 @@ class GraphTransferer { private: int CacheNode(const Node& node); + static bool IsInputNode( const std::vector<InputNodeInfo>& input_node_info_list, const string& node_name); + bool AreAllInputsCached(const Node& node) const; - void RegisterNode(const IGraphTransferOpsDefinitions& ops_definitions, - const ShapeRefiner& shape_refiner, - const OutputTensorMap& output_tensor_map, const Node& node, - const std::vector<InputNodeInfo>& input_node_info_list, - const std::vector<string>& output_node_names); + + Status RegisterNode(const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, + const OutputTensorMap& output_tensor_map, + const Node& node, + const std::vector<InputNodeInfo>& input_node_info_list, + const std::vector<string>& output_node_names); + void RegisterConstantNode(const ShapeRefiner& shape_refiner, const Node& node, const OutputTensorMap& output_tensor_map); + int RegisterConstantShape(const std::vector<int>& shape); + bool HasPaddingAndStrides(const Node& node); + + // Return true if the node is a reshape op which just flattens input + // TODO(satok): Remove this method once generic reshape op is implemented in + // SOC + bool IsNodeFlattenReshape(const Node& node, + const OutputTensorMap& output_tensor_map, + const ShapeRefiner& shape_refiner); + void RegisterNodeWithPaddingAndStrides( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, const Node& node); + void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, const Node& node); + void RegisterOutputNode(const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, const Node& node); - bool RegisterNodeIfAllInputsAreCached( + + void RegisterFlattenNode(const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, + const OutputTensorMap& output_tensor_map, + const Node& node); + + void RegisterGenericNode(const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, + const OutputTensorMap& output_tensor_map, + const Node& node); + + Status RegisterNodeIfAllInputsAreCached( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const Node& node, const bool only_register_const_node, const std::vector<InputNodeInfo>& input_node_info_list, const std::vector<string>& output_node_names, const OutputTensorMap& output_tensor_map); + void AppendNodeParams(const string& name, const int id, const string& type, const int type_id, const string& padding_str, const int inputs_size, const std::vector<int>& extra_inputs, const int outputs_size); + void AppendNodeInputParams(const int id, const Node& node, const std::vector<int>& extra_inputs); + void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, const int id, const Node& node); + static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray( const shape_inference::ShapeHandle& shape_handle, shape_inference::InferenceContext* context); + void AppendNodeParamsWithIoParams( const ShapeRefiner& shape_refiner, const OutputTensorMap& output_tensor_map, const Node& node, @@ -194,14 +222,19 @@ class GraphTransferer { const string& padding_str, const int inputs_size, const std::vector<int>& extra_inputs, const int outputs_size, const bool append_input_params, const bool append_output_params); + static std::array<int64, SHAPE_ARRAY_SIZE> ToTensorShapeArray( const TensorShape& shape); + static void CheckShape(const OutputTensorMap& output_tensor_map, const string& node_name, const std::array<int64, SHAPE_ARRAY_SIZE>& actual); + void ClearCache(); + // Dump pretty print of parameters void DumpNodeTransferParams() const; + // Dump verification string of parameters to verify with offline tools void DumpVerificationStringOfNodeTransferParams() const; @@ -213,6 +246,10 @@ class GraphTransferer { std::vector<const Node*> node_name_cache_list_; std::unordered_map<string, int> node_name_to_id_cache_map_; + // strict check mode is true by default. Disable this if the ops' shape + // inferences are not implemented correctly. + bool strict_check_mode_{true}; + TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferer); }; diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc index 71f19e1ea2..23d57ff3e9 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc @@ -46,8 +46,8 @@ class GraphTransfererTest : public ::testing::Test { GraphTransferer gt_; }; -static const std::vector<string> OP_TYPES{"INPUT", "OUTPUT", "Conv2D", - "MaxPool"}; +static const std::vector<string> OP_TYPES{"INPUT", "OUTPUT", "Conv2D", + "MaxPool", "NoOp", "Add"}; const GraphTransferer::OutputTensorMap EMPTY_OUTPUT_TENSOR_MAP; class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions { @@ -193,7 +193,8 @@ static void SanityCheckNodes(const GraphTransferer& gt) { TEST_F(GraphTransfererTest, LoadAddGraph) { GraphDef def = CreateAddGraphDef(); ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def, - {}, {}, EMPTY_OUTPUT_TENSOR_MAP) + {}, std::vector<string>{NAME_A_PLUS_B}, + EMPTY_OUTPUT_TENSOR_MAP) .ok()); SanityCheckNodes(gt_); @@ -399,21 +400,29 @@ TEST(HexagonOpsDefinitions, CheckOpsDefinitions) { } TEST(GraphTransferer, LoadGraphFromProtoFile) { + const IGraphTransferOpsDefinitions* ops_definitions = + &TEST_GRAPH_TRANSFER_OPS_DEFINITIONS; string filename = io::JoinPath(testing::TensorFlowSrcRoot(), "core/example/testdata/parse_example_graph_def.pbtxt"); std::vector<GraphTransferer::InputNodeInfo> input_node_info_list = {}; std::vector<string> output_node_names = {}; bool is_text_proto = true; + // Keep following comments for debugging purpose for now - // filename = ""; - // input_node_names = { "Mul" }; - // output_node_names = { "softmax" }; + // filename = "v3_stripped_quantized_graph_opt.pb"; + // input_node_info_list.emplace_back( + // GraphTransferer::InputNodeInfo{"Mul", Tensor{DT_FLOAT, {1,299,299,3}}}); + // output_node_names.emplace_back("softmax"); // is_text_proto = false; + // ops_definitions = &HexagonOpsDefinitions::getInstance(); + + GraphTransferer::OutputTensorInfo output_tensor_info; GraphTransferer gt; + gt.EnableStrictCheckMode(false); Status status = gt.LoadGraphFromProtoFile( - TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, filename, input_node_info_list, - output_node_names, EMPTY_OUTPUT_TENSOR_MAP, is_text_proto); + *ops_definitions, filename, input_node_info_list, output_node_names, + is_text_proto, true, &output_tensor_info); // TODO(satok): Uncomment following assert once we fix the loader problem // ASSERT_TRUE(status.ok()) << status; } diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc index 8db1ee4b04..f170a4d556 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc @@ -21,6 +21,7 @@ limitations under the License. namespace tensorflow { +// HVX internal supported ops names enum class SupportedOpType { INPUT, OUTPUT, @@ -69,26 +70,28 @@ enum class SupportedOpType { SUPPORTED_OP_TYPE_COUNT, }; -static const std::unordered_map<string, SupportedOpType> - OP_NAME_TO_SOC_OP_TYPE_MAP{ - // Custom Op name - {IGraphTransferOpsDefinitions::INPUT_OP_NAME, SupportedOpType::INPUT}, - {IGraphTransferOpsDefinitions::OUTPUT_OP_NAME, SupportedOpType::OUTPUT}, - // Tensorflow op name - {"QuantizedConv2D", SupportedOpType::QUANTIZEDCONV2D_8X8TO32}, - {"QuantizedMatMul", SupportedOpType::QUANTIZEDMATMUL_8X8TO32}, - {"QuantizeDownAndShrinkRange", - SupportedOpType::QUANTIZEDOWNANDSHRINKRANGE_32TO8}, - {"QuantizedRelu", SupportedOpType::QUANTIZEDRELU_8}, - {"QuantizedReluX", SupportedOpType::QUANTIZEDRELUX_8}, - {"QuantizedMaxPool", SupportedOpType::QUANTIZEDMAXPOOL_8}, - {"QuantizedAvgPool", SupportedOpType::QUANTIZEDAVGPOOL_8}, - {"QuantizedConcat", SupportedOpType::QUANTIZEDCONCAT_8}, - {"QuantizedBiasAdd", SupportedOpType::QUANTIZEDBIASADD_8P8TO32}, - {"Min", SupportedOpType::MIN_F}, - {"Max", SupportedOpType::MAX_F}, - {"QuantizeV2", SupportedOpType::QUANTIZE}, - }; +const std::unordered_map<string, SupportedOpType> OP_NAME_TO_SOC_OP_TYPE_MAP{ + // Custom Op name + {IGraphTransferOpsDefinitions::INPUT_OP_NAME, SupportedOpType::INPUT}, + {IGraphTransferOpsDefinitions::OUTPUT_OP_NAME, SupportedOpType::OUTPUT}, + {"NoOp", SupportedOpType::NOP}, + {IGraphTransferOpsDefinitions::FLATTEN_OP_NAME, SupportedOpType::FLATTEN}, + // Tensorflow op name + {"QuantizedConv2D", SupportedOpType::QUANTIZEDCONV2D_8X8TO32}, + {"QuantizedMatMul", SupportedOpType::QUANTIZEDMATMUL_8X8TO32}, + {"QuantizeDownAndShrinkRange", + SupportedOpType::QUANTIZEDOWNANDSHRINKRANGE_32TO8}, + {"QuantizedRelu", SupportedOpType::QUANTIZEDRELU_8}, + {"QuantizedReluX", SupportedOpType::QUANTIZEDRELUX_8}, + {"QuantizedMaxPool", SupportedOpType::QUANTIZEDMAXPOOL_8}, + {"QuantizedAvgPool", SupportedOpType::QUANTIZEDAVGPOOL_8}, + {"QuantizedConcat", SupportedOpType::QUANTIZEDCONCAT_8}, + {"QuantizedBiasAdd", SupportedOpType::QUANTIZEDBIASADD_8P8TO32}, + {"Min", SupportedOpType::MIN_F}, + {"Max", SupportedOpType::MAX_F}, + {"QuantizeV2", SupportedOpType::QUANTIZE}, + {"Dequantize", SupportedOpType::DEQUANTIZE}, +}; /* static */ const IGraphTransferOpsDefinitions& HexagonOpsDefinitions::getInstance() { diff --git a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc index 8e44c680f6..a4f6ec402e 100644 --- a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc +++ b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.cc @@ -21,4 +21,6 @@ namespace tensorflow { IGraphTransferOpsDefinitions::INPUT_OP_NAME; /* static */ constexpr const char* const IGraphTransferOpsDefinitions::OUTPUT_OP_NAME; +/* static */ constexpr const char* const + IGraphTransferOpsDefinitions::FLATTEN_OP_NAME; } diff --git a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h index 039c4376e4..7e733e1f63 100644 --- a/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h +++ b/tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h @@ -32,6 +32,8 @@ class IGraphTransferOpsDefinitions { static constexpr const char* const INPUT_OP_NAME = "INPUT"; // Custom op name for output node static constexpr const char* const OUTPUT_OP_NAME = "OUTPUT"; + // Custom op name for flatten node + static constexpr const char* const FLATTEN_OP_NAME = "FLATTEN"; IGraphTransferOpsDefinitions() = default; virtual ~IGraphTransferOpsDefinitions() = default; diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h index a098179034..683a2f5b41 100644 --- a/tensorflow/core/kernels/quantization_utils.h +++ b/tensorflow/core/kernels/quantization_utils.h @@ -112,10 +112,9 @@ void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b, // input_array is an eigen Tensor. q2f is a QuantizedToFloatStruct. // This evaluates to an eigen tensor expression, to be used like: // auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f); -#define DEQUANTIZE_WITH_EIGEN(input_array, q2f) \ - (q2f.range_min + \ - (((input_array.template cast<float>() - q2f.lowest_quantized())) * \ - q2f.range_scale)); +#define DEQUANTIZE_WITH_EIGEN(input_array, q2f) \ + ((q2f.range_min - q2f.lowest_quantized() * q2f.range_scale) + \ + input_array.template cast<float>() * q2f.range_scale) // input_array is an eigen Tensor. f2q is a FloatToQuantizedStruct. // OutputType is the type of output (e.g. quint8). diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc index 55b5193ce1..8456604740 100644 --- a/tensorflow/core/kernels/quantization_utils_test.cc +++ b/tensorflow/core/kernels/quantization_utils_test.cc @@ -252,12 +252,11 @@ class QuantizationUtilsTest : public ::testing::Test { Eigen::ThreadPoolDevice* eigen_device) { // These are the float values we're going to test the conversions on. typedef std::pair<float, float> FPair; - for (FPair min_and_max : std::vector<FPair>{FPair(-255.0f, 255.0f), // - FPair(-1.0f, 1.0f), // - FPair(-1.0f, 255.0f), // - FPair(0.0f, 1e6), // - FPair(0.0f, 1.0f), // - FPair(-31.0f, 13.0f)}) { + for (FPair min_and_max : std::vector<FPair>{ + FPair(-255.0f, 255.0f), FPair(-1.0f, 1.0f), FPair(-1.0f, 255.0f), + FPair(0.0f, 1e6), FPair(0.0f, 1.0f), FPair(-31.0f, 13.0f), + FPair(-5.89505e+08, 5.89505e+08), + }) { const float f_min = min_and_max.first; const float f_max = min_and_max.second; const int values_count = sizeof(T) == 1 ? 256 : 50000; @@ -272,8 +271,8 @@ class QuantizationUtilsTest : public ::testing::Test { } else { int64 offset = static_cast<int64>(q_range / values_count * i); input_array(i) = static_cast<int32>( - Eigen::NumTraits<T>::lowest() + - std::min<int64>(Eigen::NumTraits<T>::highest(), offset)); + std::min<int64>(Eigen::NumTraits<T>::lowest() + offset, + Eigen::NumTraits<T>::highest())); } } @@ -285,7 +284,7 @@ class QuantizationUtilsTest : public ::testing::Test { for (int i = 0; i < values_count; ++i) { float expected = QuantizedToFloat<T>(input_array(i), f_min, f_max); float actual = output_array(i); - ASSERT_NEAR(expected, actual, range * 1e-6) + ASSERT_NEAR(expected, actual, range * 1.1e-7) << "expected=" << expected << " actual=" << actual << " v=" << input_array(i) << " i=" << i << " f_min=" << f_min << " f_max=" << f_max @@ -340,14 +339,14 @@ TEST_F(QuantizationUtilsTest, QuantizedToFloat) { const int int32_min = std::numeric_limits<int>::min(); const int int32_max = std::numeric_limits<int>::max(); - EXPECT_LT( - fabsf(-1.0f - QuantizedToFloat<qint32>(qint32(int32_min), -1.0f, 1.0f)), - 1e-5f); - EXPECT_LT(fabsf(0.0f - QuantizedToFloat<qint32>(qint32(0), -1.0f, 1.0f)), - 1e-5f); - EXPECT_LT( - fabsf(1.0f - QuantizedToFloat<qint32>(qint32(int32_max), -1.0f, 1.0f)), - 1e-5f); + EXPECT_NEAR(-1.0f, QuantizedToFloat<qint32>(qint32(int32_min), -1.0f, 1.0f), + 1e-5f); + EXPECT_NEAR(0.0f, QuantizedToFloat<qint32>(qint32(0), -1.0f, 1.0f), 1e-5f); + EXPECT_NEAR(1.0f, QuantizedToFloat<qint32>(qint32(int32_max), -1.0f, 1.0f), + 1e-5f); + + EXPECT_NEAR(32.0f, QuantizedToFloat<qint32>(qint32(32), int32_min, int32_max), + 1.0); } TEST_F(QuantizationUtilsTest, AvoidBias) { @@ -531,6 +530,32 @@ TEST_F(QuantizationUtilsTest, QuantizedTensorToFloat) { -103.0f, 115.0f, 116.0f, 117.0f}); Tensor output = QuantizedTensorToFloat<quint8>(input, input_min, input_max); test::ExpectTensorEqual<float>(expected, output); + + // Test for signed 32 bit. + // Note that we cannot use input mins and maxes that match the range because + // there are 7 too few bits of mantissa accuracy in floats to represent + // 2**31-1 accurately. Also there is no good fraction to use because 2**31-1 + // is a mersenne prime. + Tensor input32(DT_QINT32, TensorShape({input_height, input_width})); + + // Use a quantizer centered at 0. + float input_range = 1LL << 25; + int64 num_levels = (1LL << 32) - 1; + float step_size = + static_cast<float>(static_cast<double>(input_range) / num_levels); + float q_compatible_min_value = + roundf(-(input_range / 2.0) / step_size) * step_size; + float q_compatible_max_value = q_compatible_min_value + input_range; + test::FillValues<qint32>(&input32, {-16384, 0, 16256, -13440, -13312, -13184, + 14720, 14848, 14976}); + + Tensor output32 = QuantizedTensorToFloat<qint32>( + input32, q_compatible_min_value, q_compatible_max_value); + test::FillValues<float>(&expected, {-128.0f, 0.0f, 127.0f, -105.0f, -104.0f, + -103.0f, 115.0f, 116.0f, 117.0f}); + // The quantization error in going between 1<<25 and 1<<32 levels. + const double kTolerance = .5 / 128.0; + test::ExpectTensorNear<float>(expected, output32, kTolerance); } // Verify that QuantizedToFloatInPlaceUsingEigen is same result as diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 7704c5f65a..55d3ee36da 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -37,7 +37,7 @@ static bool ValidUpdateShape(const TensorShape& params_shape, const Tensor& indices, const Tensor& updates) { int64 indices_nd = 1; if (indices.dims() > 1) { - indices_nd = indices.dim_size(1); + indices_nd = indices.dim_size(indices.dims() - 1); } for (int d = indices_nd; d < params_shape.dims(); d++) { if (updates.dim_size(d - indices_nd + 1) != params_shape.dim_size(d)) { @@ -71,13 +71,13 @@ static void PrepareAndValidateInputs(OpKernelContext* c, "The outermost dimension of updates and indices ", "must match. Got indices.shape ", indices_shape.DebugString(), ", updates.shape ", updates_shape.DebugString())); - OP_REQUIRES( - c, ValidUpdateShape(params_shape, indices, updates), - errors::InvalidArgument( - "Must have updates.shape = indices.shape[0] + params_shape[IXDIM:], ", - "got updates.shape ", updates_shape.DebugString(), ", indices.shape ", - indices_shape.DebugString(), ", params_shape ", - params_shape.DebugString())); + OP_REQUIRES(c, ValidUpdateShape(params_shape, indices, updates), + errors::InvalidArgument( + "Must have updates.shape = indices.shape[:IXDIM] + ", + "params_shape[IXDIM:], got updates.shape ", + updates_shape.DebugString(), ", indices.shape ", + indices_shape.DebugString(), ", params_shape ", + params_shape.DebugString())); // Check that we have enough index space const int64 N_big = indices.NumElements(); OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(), diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index d69909b6de..cc4772c001 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -217,8 +217,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); Status s = RunOpKernel(); EXPECT_TRUE(StringPiece(s.ToString()) - .contains("Must have updates.shape = indices.shape[0] + " - "params_shape[IXDIM:], got")) + .contains("Must have updates.shape = indices.shape[:IXDIM]")) << s; } @@ -233,9 +232,10 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { AddInputFromArray<float>(TensorShape({2, 3}), {100, 101, 102, 10000, 10001, 10002}); Status s = RunOpKernel(); - EXPECT_TRUE(StringPiece(s.ToString()) - .contains("The outermost dimension of updates and indices " - "must match. Got ")) + EXPECT_TRUE( + StringPiece(s.ToString()) + .contains( + "The outermost dimension of updates and indices must match.")) << s; } diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc index 2d1539fb9d..cc0f86ce05 100644 --- a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc +++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc @@ -163,6 +163,10 @@ class SparseDenseBinaryOpShared : public OpKernel { } }; +// NOTE(aselle): If Div is extended to non-reals, make sure to use the same +// separation of operator semantics as done for dense cwise ops. I.e. you +// should make SparseDenseCwiseRealDiv, SparseDenseCwiseTruncateDiv, +// SparseDenseCwiseFloorDiv, and then deprecate, SparseDenseCwiseDiv. // TODO(zongheng): extend to other eligible cwise operations as requested. #define REGISTER_KERNELS(T) \ REGISTER_KERNEL_BUILDER( \ diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h index 4e14f0099a..170d4ec18b 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.h +++ b/tensorflow/core/kernels/sparse_matmul_op.h @@ -209,6 +209,77 @@ EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) { #endif +#ifdef EIGEN_VECTORIZE_AVX512 +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_first<Packet16f>(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(a); +} +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_second<Packet16f>(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1))); +} +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_third<Packet16f>(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2))); +} +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_fourth<Packet16f>(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3))); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) { + Packet2d a = _mm512_castpd512_pd128(a_in); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) { + Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) { + Packet2d a = _mm512_extractf32x4_ps(a_in, 1); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) { + Packet2d a = _mm_permute_pd(_mm512_extractf32x4_ps(a_in, 1), 3); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_first<Packet16i>(const Packet16i& a_in) { + Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(a); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_second<Packet16i>(const Packet16i& a_in) { + Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1))); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_third<Packet16i>(const Packet16i& a_in) { + Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2))); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_fourth<Packet16i>(const Packet16i& a_in) { + Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3))); +} +#endif + #ifdef EIGEN_VECTORIZE_AVX // For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords template <> @@ -245,6 +316,25 @@ EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) { _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); } +#ifdef EIGEN_VECTORIZE_AVX512 +// Return a Packet with 4 floats loaded from 4 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from)); + return _mm512_castps128_ps512( + _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); +} +// Return a Packet with 2 floats loaded from 2 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castps_si128(_mm_load_ps1(from)); + return _mm512_castps128_ps512( + _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); +} +#endif + // For each 128-bit lane convert 4 bfloat to 4 float values from the lower half // of the 128-bit lane template <typename Packet> @@ -313,6 +403,22 @@ EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) { } #endif + +#ifdef EIGEN_VECTORIZE_AVX512 + +template <typename Packet> +EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) { + return _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(from)), + 16); +} + +template <typename Packet> +EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) { + return _mm512_slli_epi32( + _mm512_cvtepu16_epi32(_mm512_extractf64x4_pd(from, 1)), 16); +} + +#endif } // namespace internal } // namespace Eigen #endif diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc index 45cad2e23b..b155e45187 100644 --- a/tensorflow/core/kernels/sparse_matmul_op_test.cc +++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc @@ -200,7 +200,7 @@ class SparseMatmulOpTest : public ::testing::Test { // zero out lower 16-bits of mantissa of data3 values // copy bfloat representation to data3_bfloat16 - for (int i = 0; i < kMaxPacketSize; ++i) { + for (int i = 0; i < kMaxPacketSize * 2; ++i) { uint16_t* data3_p = reinterpret_cast<uint16_t*>(&data3[i]); uint16_t* data3_bfloat16_p = reinterpret_cast<uint16_t*>(data3_bfloat16) + i; @@ -222,7 +222,13 @@ class SparseMatmulOpTest : public ::testing::Test { return true; } +#ifdef EIGEN_VECTORIZE_AVX512 static const int kMaxPacketSize = 16; +#elif defined EIGEN_VECTORIZE_AVX || defined EIGEN_VECTORIZE_AVX2 + static const int kMaxPacketSize = 8; +#else + static const int kMaxPacketSize = 4; +#endif typedef typename Eigen::internal::packet_traits<float>::type Packet; const int PacketSize; // float values @@ -230,9 +236,9 @@ class SparseMatmulOpTest : public ::testing::Test { // output of intrinsics EIGEN_ALIGN_MAX float data2[kMaxPacketSize]; // float values with only 7 mantissa bits (bfloat representable) - EIGEN_ALIGN_MAX float data3[kMaxPacketSize]; + EIGEN_ALIGN_MAX float data3[kMaxPacketSize * 2]; // bfloat16 representation of data3 - EIGEN_ALIGN_MAX float data3_bfloat16[kMaxPacketSize / 2]; + EIGEN_ALIGN_MAX float data3_bfloat16[kMaxPacketSize]; EIGEN_ALIGN_MAX float ref[kMaxPacketSize]; }; diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index 33705aac6a..e275b63de4 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -130,11 +129,12 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) { } const string& container = Tstack_handle.flat<string>()(0); const string& stack_name = Tstack_handle.flat<string>()(1); - ResourceMgr* rm = ctx->step_resource_manager(); + ResourceMgr* rm = ctx->resource_manager(); if (rm == nullptr) { - return errors::Internal("No per-step resource manager."); + return errors::Internal("No resource manager."); } - TF_RETURN_IF_ERROR(rm->Lookup(container, stack_name, stack)); + TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), + strings::StrCat(container, stack_name), stack)); return Status::OK(); } @@ -162,12 +162,13 @@ class StackOp : public OpKernel { auto handle = stack_handle.flat<string>(); handle(0) = "_stacks"; handle(1) = strings::StrCat(stack_name_, "_", stack_id); - // Store the handle in a container of the per-step RM. - ResourceMgr* rm = ctx->step_resource_manager(); - OP_REQUIRES(ctx, rm != nullptr, - errors::Internal("No per-step resource manager.")); + // Store the handle in a per-step container. + ResourceMgr* rm = ctx->resource_manager(); + OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager.")); Stack* stack = new Stack(elem_type_, stack_handle); - OP_REQUIRES_OK(ctx, rm->Create(handle(0), handle(1), stack)); + OP_REQUIRES_OK(ctx, + rm->Create(ctx->step_container()->name(), + strings::StrCat(handle(0), handle(1)), stack)); ctx->set_output_ref(0, stack->mu(), stack->handle()); } diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc index c4584993fa..3226e5e0f8 100644 --- a/tensorflow/core/kernels/string_split_op.cc +++ b/tensorflow/core/kernels/string_split_op.cc @@ -30,7 +30,7 @@ namespace { std::vector<string> Split(const string& str, const string& delimiter) { if (delimiter.size()) { - return str_util::Split(str, delimiter[0], str_util::SkipEmpty()); + return str_util::Split(str, delimiter, str_util::SkipEmpty()); } std::vector<string> char_vector(str.size()); for (size_t i = 0; i < str.size(); ++i) { @@ -64,10 +64,6 @@ class StringSplitOp : public OpKernel { const auto delimiter_vec = delimiter_tensor->flat<string>(); const string& delimiter = delimiter_vec(0); // Empty delimiter means split the input character by character. - OP_REQUIRES(ctx, delimiter.size() < 2, - errors::InvalidArgument("Delimiter must be a character, got", - delimiter)); - std::vector<string> tokens; // Guess that we'll be unpacking a handful of tokens per example. static constexpr int kReserveSize = 4; diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 318e8ba160..fa26232468 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -45,6 +45,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #endif // GOOGLE_CUDA +// clang-format on + namespace tensorflow { Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { @@ -72,9 +74,10 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { string container; string ta_handle; TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle)); - ResourceMgr* rm = ctx->step_resource_manager(); - if (rm == nullptr) return errors::Internal("No per-step resource manager."); - TF_RETURN_IF_ERROR(rm->Lookup(container, ta_handle, tensor_array)); + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) return errors::Internal("No resource manager."); + TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), + container + ta_handle, tensor_array)); return Status::OK(); } @@ -104,10 +107,9 @@ class TensorArrayCreationOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_temp( tensorflow::DT_STRING, tensorflow::TensorShape({2}), &tensor_array_output_handle, alloc_attr)); - // Store the handle in a container of the per-step RM. - ResourceMgr* rm = ctx->step_resource_manager(); - OP_REQUIRES(ctx, rm != nullptr, - errors::Internal("No per-step resource manager.")); + // Store the handle in a per-step container of the RM. + ResourceMgr* rm = ctx->resource_manager(); + OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager.")); TensorArray* output_tensor_array; OP_REQUIRES_OK(ctx, CreateTensorArray(ctx, rm, &tensor_array_output_handle, @@ -167,8 +169,9 @@ class TensorArrayOp : public TensorArrayCreationOp { false /* multiple_writes_aggregate */, false /* is_grad */, -1 /* marked_size */, clear_after_read_); - TF_RETURN_IF_ERROR( - rm->Create(handle(0), unique_tensor_array_name, tensor_array)); + TF_RETURN_IF_ERROR(rm->Create( + ctx->step_container()->name(), + strings::StrCat(handle(0), unique_tensor_array_name), tensor_array)); *output_tensor_array = tensor_array; @@ -236,7 +239,9 @@ class TensorArrayGradOp : public TensorArrayCreationOp { output_handle(1) = strings::StrCat(tensor_array_name, "@", source_); TensorArray* tensor_array; - TF_RETURN_IF_ERROR(rm->Lookup(container, tensor_array_name, &tensor_array)); + TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), + strings::StrCat(container, tensor_array_name), + &tensor_array)); core::ScopedUnref unref(tensor_array); // Once gradients are being calculated, the forward TensorArray @@ -268,7 +273,9 @@ class TensorArrayGradOp : public TensorArrayCreationOp { }; Status s = rm->LookupOrCreate<TensorArray>( - output_handle(0), output_handle(1), output_tensor_array, creator); + ctx->step_container()->name(), + strings::StrCat(output_handle(0), output_handle(1)), + output_tensor_array, creator); (*output_tensor_array)->Unref(); return s; diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index f44f94c51b..d8d8831702 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -102,7 +102,7 @@ class TemporaryVariableOp : public OpKernel { void Compute(OpKernelContext* context) override { Status s; - ResourceMgr* rm = context->step_resource_manager(); + ResourceMgr* rm = context->resource_manager(); OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); auto* tmp_var = new TmpVar; OP_REQUIRES(context, tmp_var, @@ -111,7 +111,8 @@ class TemporaryVariableOp : public OpKernel { s = context->allocate_temp(dtype_, shape_, &tmp_var->val); if (!s.ok()) tmp_var->Unref(); OP_REQUIRES_OK(context, s); - OP_REQUIRES_OK(context, rm->Create("tmp_var", var_name_, tmp_var)); + OP_REQUIRES_OK(context, rm->Create(context->step_container()->name(), + var_name_, tmp_var)); context->set_output_ref(0, &tmp_var->mu, &tmp_var->val); } @@ -149,10 +150,10 @@ class DestroyTemporaryVariableOp : public OpKernel { CHECK(IsRefType(context->input_dtype(0))); Tensor tmpvar = context->mutable_input(0, false); context->set_output(0, tmpvar); - ResourceMgr* rm = context->step_resource_manager(); + ResourceMgr* rm = context->resource_manager(); OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); - OP_REQUIRES_OK( - context, rm->Delete<TemporaryVariableOp::TmpVar>("tmp_var", var_name_)); + OP_REQUIRES_OK(context, rm->Delete<TemporaryVariableOp::TmpVar>( + context->step_container()->name(), var_name_)); } private: diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h index ee197e54e3..183b18a5c6 100644 --- a/tensorflow/core/lib/strings/str_util.h +++ b/tensorflow/core/lib/strings/str_util.h @@ -108,9 +108,12 @@ struct SkipWhitespace { } }; -std::vector<string> Split(StringPiece text, char delim); +// Split strings using any of the supplied delimiters. For example: +// Split("a,b.c,d", ".,") would return {"a", "b", "c", "d"}. +std::vector<string> Split(StringPiece text, StringPiece delims); + template <typename Predicate> -std::vector<string> Split(StringPiece text, char delim, Predicate p); +std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p); // Split "text" at "delim" characters, and parse each component as // an integer. If successful, adds the individual numbers in order @@ -157,17 +160,17 @@ string Join(const T& s, const char* sep, Formatter f) { return result; } -inline std::vector<string> Split(StringPiece text, char delim) { - return Split(text, delim, AllowEmpty()); +inline std::vector<string> Split(StringPiece text, StringPiece delims) { + return Split(text, delims, AllowEmpty()); } template <typename Predicate> -std::vector<string> Split(StringPiece text, char delim, Predicate p) { +std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p) { std::vector<string> result; size_t token_start = 0; if (!text.empty()) { for (size_t i = 0; i < text.size() + 1; i++) { - if ((i == text.size()) || (text[i] == delim)) { + if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) { StringPiece token(text.data() + token_start, i - token_start); if (p(token)) { result.push_back(token.ToString()); @@ -179,6 +182,15 @@ std::vector<string> Split(StringPiece text, char delim, Predicate p) { return result; } +inline std::vector<string> Split(StringPiece text, char delim) { + return Split(text, StringPiece(&delim, 1)); +} + +template <typename Predicate> +std::vector<string> Split(StringPiece text, char delims, Predicate p) { + return Split(text, StringPiece(&delims, 1), p); +} + } // namespace str_util } // namespace tensorflow diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc index 055e1e4ac0..afdc5855cc 100644 --- a/tensorflow/core/lib/strings/str_util_test.cc +++ b/tensorflow/core/lib/strings/str_util_test.cc @@ -239,6 +239,8 @@ TEST(Split, Basic) { EXPECT_EQ(str_util::Join(str_util::Split("a,b,c", ','), "|"), "a|b|c"); EXPECT_EQ(str_util::Join(str_util::Split("a,,,b,,c,", ','), "|"), "a|||b||c|"); + EXPECT_EQ(str_util::Join(str_util::Split("a!,!b,!c,", ",!"), "|"), + "a|||b||c|"); EXPECT_EQ(str_util::Join( str_util::Split("a,,,b,,c,", ',', str_util::SkipEmpty()), "|"), "a|b|c"); @@ -246,6 +248,10 @@ TEST(Split, Basic) { str_util::Join( str_util::Split("a, ,b,,c,", ',', str_util::SkipWhitespace()), "|"), "a|b|c"); + EXPECT_EQ(str_util::Join(str_util::Split("a. !b,;c,", ".,;!", + str_util::SkipWhitespace()), + "|"), + "a|b|c"); } TEST(SplitAndParseAsInts, Int32) { diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 8c459ed92b..7e8132f689 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -556,24 +556,32 @@ TEST(ArrayOpsTest, Concat_ShapeFn) { set_n(2); // Sum dim 0, merge the other two dims. - concat_dim_t = test::AsScalar(0); - INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]"); - INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op, - "[];[100,2,5];[10,?,3]"); - // concat_dim can't be summed, as one value is unknown. - INFER_OK(op, "[];[100,2,?];[?,?,3]", "[?,d1_1,d2_2]"); - INFER_OK(op, "[];[?,2,?];[10,?,3]", "[?,d1_1,d2_2]"); + for (int concat_dim : {0, -3}) { + concat_dim_t = test::AsScalar(concat_dim); + INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]"); + INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op, + "[];[100,2,5];[10,?,3]"); + // concat_dim can't be summed, as one value is unknown. + INFER_OK(op, "[];[100,2,?];[?,?,3]", "[?,d1_1,d2_2]"); + INFER_OK(op, "[];[?,2,?];[10,?,3]", "[?,d1_1,d2_2]"); + } // Test with a higher concat_dim. - concat_dim_t = test::AsScalar(1); - INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]"); - INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]"); - INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]"); - // concat_dim is too high. - INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, - "[];[100];[10,?]"); - INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, - "[];[100,5];[10]"); + for (bool use_negative : {false, true}) { + concat_dim_t = test::AsScalar(use_negative ? -2 : 1); + INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]"); + concat_dim_t = test::AsScalar(use_negative ? -1 : 1); + INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]"); + INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]"); + + // concat_dim is out of bounds. + concat_dim_t = test::AsScalar(use_negative ? -2 : 1); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, + "[];[100];[10,?]"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, + "[];[100,5];[10]"); + } + // concat_dim is too low. concat_dim_t = test::AsScalar(-2); INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 6fbdb86c45..ff46aa2725 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -20029,7 +20029,7 @@ op { } input_arg { name: "delimiter" - description: "0-D. Delimiter character, or empty string." + description: "0-D. Delimiter characters (bytes), or empty string." type: DT_STRING } output_arg { @@ -20048,7 +20048,7 @@ op { type: DT_INT64 } summary: "Split elements of `input` based on `delimiter` into a `SparseTensor`." - description: "Let N be the size of source (typically N will be the batch size). Split each\nelement of `input` based on `delimiter` and return a `SparseTensor`\ncontaining the splitted tokens. Empty tokens are ignored.\n\n`delimiter` can be empty or a single-byte character. If `delimiter` is an empty\n string, each element of `input` is split into individual single-byte character\n strings, including splitting of UTF-8 multibyte sequences.\n\nFor example:\n N = 2, input[0] is \'hello world\' and input[1] is \'a b c\', then the output\n will be\n\n indices = [0, 0;\n 0, 1;\n 1, 0;\n 1, 1;\n 1, 2]\n shape = [2, 3]\n values = [\'hello\', \'world\', \'a\', \'b\', \'c\']" + description: "Let N be the size of source (typically N will be the batch size). Split each\nelement of `input` based on `delimiter` and return a `SparseTensor`\ncontaining the splitted tokens. Empty tokens are ignored.\n\n`delimiter` can be empty, or a string of split characters. If `delimiter` is an\n empty string, each element of `input` is split into individual single-byte\n character strings, including splitting of UTF-8 multibyte sequences. Otherwise\n every character of `delimiter` is a potential split point.\n\nFor example:\n N = 2, input[0] is \'hello world\' and input[1] is \'a b c\', then the output\n will be\n\n indices = [0, 0;\n 0, 1;\n 1, 0;\n 1, 1;\n 1, 2]\n shape = [2, 3]\n values = [\'hello\', \'world\', \'a\', \'b\', \'c\']" } op { name: "StringToHashBucket" diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 53d75e4519..cef40289bf 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -222,9 +222,10 @@ Let N be the size of source (typically N will be the batch size). Split each element of `input` based on `delimiter` and return a `SparseTensor` containing the splitted tokens. Empty tokens are ignored. -`delimiter` can be empty or a single-byte character. If `delimiter` is an empty - string, each element of `input` is split into individual single-byte character - strings, including splitting of UTF-8 multibyte sequences. +`delimiter` can be empty, or a string of split characters. If `delimiter` is an + empty string, each element of `input` is split into individual single-byte + character strings, including splitting of UTF-8 multibyte sequences. Otherwise + every character of `delimiter` is a potential split point. For example: N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output @@ -239,7 +240,7 @@ For example: values = ['hello', 'world', 'a', 'b', 'c'] input: 1-D. Strings to split. -delimiter: 0-D. Delimiter character, or empty string. +delimiter: 0-D. Delimiter characters (bytes), or empty string. indices: A dense matrix of int64 representing the indices of the sparse tensor. values: A vector of strings corresponding to the splited values. shape: a length-2 vector of int64 representing the shape of the sparse diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index bb52e75df3..810675fbcb 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -38,10 +38,17 @@ tf_cuda_library( ) cc_library( + name = "stream_executor_cuda", + deps = [ + "//tensorflow/stream_executor", + ] + select({ + "@local_config_cuda//cuda:darwin": ["IOKit"], + "//conditions:default": [], + }), +) + +cc_library( name = "stream_executor_no_cuda", - hdrs = [ - "stream_executor_no_cuda.h", - ], deps = [ "//tensorflow/stream_executor", ], diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h index eaae673464..961fb8b4ad 100644 --- a/tensorflow/core/platform/default/logging.h +++ b/tensorflow/core/platform/default/logging.h @@ -67,6 +67,8 @@ class LogMessageFatal : public LogMessage { #define _TF_LOG_FATAL \ ::tensorflow::internal::LogMessageFatal(__FILE__, __LINE__) +#define _TF_LOG_QFATAL _TF_LOG_FATAL + #define LOG(severity) _TF_LOG_##severity // TODO(jeff): Define a proper implementation of VLOG_IS_ON diff --git a/tensorflow/core/platform/default/stacktrace.h b/tensorflow/core/platform/default/stacktrace.h index 8dc27b5d63..5f3073262a 100644 --- a/tensorflow/core/platform/default/stacktrace.h +++ b/tensorflow/core/platform/default/stacktrace.h @@ -22,12 +22,14 @@ namespace tensorflow { inline string CurrentStackTrace() { return "No stack trace available"; } +inline void DebugWriteToString(const char* data, void* arg) {} + // A dummy class that does nothing. Someday, add real support. class SavedStackTrace { public: SavedStackTrace() {} - void CreateCurrent() {} + void CreateCurrent(int skip_count) {} void Reset() {} diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 52ced38ac8..104ad42439 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -63,6 +63,23 @@ class DefaultEnvTest : public ::testing::Test { Env* env_ = Env::Default(); }; +TEST_F(DefaultEnvTest, IncompleteReadOutOfRange) { + const string filename = io::JoinPath(BaseDir(), "out_of_range"); + const string input = CreateTestFile(env_, filename, 2); + std::unique_ptr<RandomAccessFile> f; + TF_EXPECT_OK(env_->NewRandomAccessFile(filename, &f)); + + // Reading past EOF should give an OUT_OF_RANGE error + StringPiece result; + char scratch[3]; + EXPECT_EQ(error::OUT_OF_RANGE, f->Read(0, 3, &result, scratch).code()); + EXPECT_EQ(input, result); + + // Exact read to EOF works. + TF_EXPECT_OK(f->Read(0, 2, &result, scratch)); + EXPECT_EQ(input, result); +} + TEST_F(DefaultEnvTest, ReadFileToString) { for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000, (1 << 20) - 1, 1 << 20, (1 << 20) + 1}) { diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 2be35eb455..a3b8b400a3 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -78,6 +78,19 @@ message OptimizerOptions { } Level opt_level = 3; + + // Control the use of the compiler/jit. Experimental. + enum GlobalJitLevel { + DEFAULT = 0; // Default setting ("off" now, but later expected to be "on") + OFF = -1; + // The following settings turn on compilation, with higher values being + // more aggressive. Higher values may reduce opportunities for parallelism + // and may use more memory. (At present, there is no distinction, but this + // is expected to change.) + ON_1 = 1; + ON_2 = 2; + } + GlobalJitLevel global_jit_level = 5; } message GraphOptions { diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 2ffe186e12..4693b4c005 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -511,12 +511,7 @@ TEST(TensorBundleTest, TruncatedTensorContents) { BundleReader reader(env, Prefix("end")); TF_ASSERT_OK(reader.status()); Tensor val(DT_FLOAT, TensorShape({2, 3})); -#if defined(PLATFORM_GOOGLE) - EXPECT_EQ("Data loss: Requested 24 bytes but read 23 bytes.", - reader.Lookup("key", &val).ToString()); -#else EXPECT_TRUE(errors::IsOutOfRange(reader.Lookup("key", &val))); -#endif } TEST(TensorBundleTest, HeaderEntry) { diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/examples/android/AndroidManifest.xml index 0a48d3d50b..e388734564 100644 --- a/tensorflow/examples/android/AndroidManifest.xml +++ b/tensorflow/examples/android/AndroidManifest.xml @@ -41,6 +41,15 @@ <category android:name="android.intent.category.LAUNCHER" /> </intent-filter> </activity> + + <activity android:name="org.tensorflow.demo.DetectorActivity" + android:screenOrientation="portrait" + android:label="@string/activity_name_detection"> + <intent-filter> + <action android:name="android.intent.action.MAIN" /> + <category android:name="android.intent.category.LAUNCHER" /> + </intent-filter> + </activity> </application> </manifest> diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index beb8337702..3ba3a494ab 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -5,7 +5,11 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_opts_nortti_if_android", +) exports_files(["LICENSE"]) @@ -35,6 +39,7 @@ cc_binary( "notap", ], deps = [ + ":demo_proto_lib_cc", "//tensorflow/contrib/android:android_tensorflow_inference_jni", "//tensorflow/core:android_tensorflow_lib", LINKER_SCRIPT, @@ -60,6 +65,7 @@ android_binary( assets = [ "//tensorflow/examples/android/assets:asset_files", "@inception5h//:model_files", + "@mobile_multibox//:model_files", ], assets_dir = "", custom_package = "org.tensorflow.demo", @@ -111,3 +117,20 @@ filegroup( ) exports_files(["AndroidManifest.xml"]) + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", +) + +tf_proto_library( + name = "demo_proto_lib", + srcs = glob( + ["**/*.proto"], + ), + cc_api_version = 2, + visibility = ["//visibility:public"], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets go here (must be at the end). diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index b0465f7faa..b6556cdef4 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -1,11 +1,24 @@ # TensorFlow Android Camera Demo -This folder contains a simple camera-based demo application utilizing TensorFlow. +This folder contains an example application utilizing TensorFlow for Android +devices. ## Description -This demo uses a Google Inception model to classify camera frames in real-time, -displaying the top results in an overlay on the camera image. +The demos in this folder are designed to give straightforward samples of using +TensorFlow in mobile applications. Inference is done using the Java JNI API +exposed by `tensorflow/contrib/android`. + +Current samples: + +1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java): + Uses the [Google Inception](https://arxiv.org/abs/1409.4842) + model to classify camera frames in real-time, displaying the top results + in an overlay on the camera image. +2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java): + Demonstrates a model based on [Scalable Object Detection + using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to + localize and track people in the camera preview in real-time. ## To build/install/run @@ -19,9 +32,9 @@ installed on your system. 3. The Android SDK and build tools may be obtained from: https://developer.android.com/tools/revisions/build-tools.html -The Android entries in [`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L2-L13) must be -uncommented with the paths filled in appropriately depending on where you -installed the NDK and SDK. Otherwise an error such as: +The Android entries in [`<workspace_root>/WORKSPACE`](../../../WORKSPACE#L2-L13) +must be uncommented with the paths filled in appropriately depending on where +you installed the NDK and SDK. Otherwise an error such as: "The external label '//external:android/sdk' is not bound to anything" will be reported. @@ -29,19 +42,21 @@ The TensorFlow `GraphDef` that contains the model definition and weights is not packaged in the repo because of its size. It will be downloaded automatically via a new_http_archive defined in WORKSPACE. -**Optional**: If you wish to place the model in your assets manually (E.g. for -non-Bazel builds), remove the -`inception_5` entry in `BUILD` and download the archive yourself to the -`assets` directory in the source tree: +**Optional**: If you wish to place the models in your assets manually (E.g. for +non-Bazel builds), remove the `inception_5` and `mobile_multibox` entries in +`BUILD` and download the archives yourself to the `assets` directory in the +source tree: ```bash $ curl -L https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -o /tmp/inception5h.zip +$ curl -L https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1.zip -o /tmp/mobile_multibox_v1.zip $ unzip /tmp/inception5h.zip -d tensorflow/examples/android/assets/ +$ unzip /tmp/mobile_multibox_v1.zip -d tensorflow/examples/android/assets/ ``` -The labels file describing the possible classification will also be in the -assets directory. +The associated label and box prior files for the models will also be extracted +into the assets directory. After editing your WORKSPACE file to update the SDK/NDK configuration, you may build the APK. Run this from your workspace root: diff --git a/tensorflow/examples/android/jni/box_coder_jni.cc b/tensorflow/examples/android/jni/box_coder_jni.cc new file mode 100644 index 0000000000..be85414fc1 --- /dev/null +++ b/tensorflow/examples/android/jni/box_coder_jni.cc @@ -0,0 +1,92 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file loads the box coder mappings. + +#include <android/asset_manager.h> +#include <android/asset_manager_jni.h> +#include <android/bitmap.h> + +#include <jni.h> +#include <pthread.h> +#include <sys/stat.h> +#include <unistd.h> +#include <map> +#include <queue> +#include <sstream> +#include <string> + +#include "tensorflow/contrib/android/jni/jni_utils.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/proto/box_coder.pb.h" + +#define TENSORFLOW_METHOD(METHOD_NAME) \ + Java_org_tensorflow_demo_TensorFlowMultiBoxDetector_##METHOD_NAME // NOLINT + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)( + JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location, + jfloatArray priors); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +JNIEXPORT void JNICALL TENSORFLOW_METHOD(loadCoderOptions)( + JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring location, + jfloatArray priors) { + AAssetManager* const asset_manager = + AAssetManager_fromJava(env, java_asset_manager); + LOG(INFO) << "Acquired AssetManager."; + + const std::string location_str = GetString(env, location); + + org_tensorflow_demo::MultiBoxCoderOptions multi_options; + + LOG(INFO) << "Reading file to proto: " << location_str; + ReadFileToProtoOrDie(asset_manager, location_str.c_str(), &multi_options); + + LOG(INFO) << "Read file. " << multi_options.box_coder_size() << " entries."; + + jboolean iCopied = JNI_FALSE; + jfloat* values = env->GetFloatArrayElements(priors, &iCopied); + + const int array_length = env->GetArrayLength(priors); + LOG(INFO) << "Array length: " << array_length + << " (/8 = " << (array_length / 8) << ")"; + CHECK_EQ(array_length % 8, 0); + + const int num_items = + std::min(array_length / 8, multi_options.box_coder_size()); + + for (int i = 0; i < num_items; ++i) { + const org_tensorflow_demo::BoxCoderOptions& options = + multi_options.box_coder(i); + + for (int j = 0; j < 4; ++j) { + const org_tensorflow_demo::BoxCoderPrior& prior = options.priors(j); + values[i * 8 + j * 2] = prior.mean(); + values[i * 8 + j * 2 + 1] = prior.stddev(); + } + } + env->ReleaseFloatArrayElements(priors, values, 0); + + LOG(INFO) << "Read " << num_items << " options"; +} diff --git a/tensorflow/examples/android/jni/object_tracking/config.h b/tensorflow/examples/android/jni/object_tracking/config.h new file mode 100644 index 0000000000..86e9fc71b6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/config.h @@ -0,0 +1,300 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ + +#include <math.h> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" + +namespace tf_tracking { + +// Arbitrary keypoint type ids for labeling the origin of tracked keypoints. +enum KeypointType { + KEYPOINT_TYPE_DEFAULT = 0, + KEYPOINT_TYPE_FAST = 1, + KEYPOINT_TYPE_INTEREST = 2 +}; + +// Struct that can be used to more richly store the results of a detection +// than a single number, while still maintaining comparability. +struct MatchScore { + explicit MatchScore(double val) : value(val) {} + MatchScore() { value = 0.0; } + + double value; + + MatchScore& operator+(const MatchScore& rhs) { + value += rhs.value; + return *this; + } + + friend std::ostream& operator<<(std::ostream& stream, + const MatchScore& detection) { + stream << detection.value; + return stream; + } +}; +inline bool operator< (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value < cC2.value; +} +inline bool operator> (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value > cC2.value; +} +inline bool operator>= (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value >= cC2.value; +} +inline bool operator<= (const MatchScore& cC1, const MatchScore& cC2) { + return cC1.value <= cC2.value; +} + +// Fixed seed used for all random number generators. +static const int kRandomNumberSeed = 11111; + +// TODO(andrewharp): Move as many of these settings as possible into a settings +// object which can be passed in from Java at runtime. + +// Whether or not to use ESM instead of LK flow. +static const bool kUseEsm = false; + +// This constant gets added to the diagonal of the Hessian +// before solving for translation in 2dof ESM. +// It ensures better behavior especially in the absence of +// strong texture. +static const int kEsmRegularizer = 20; + +// Do we want to brightness-normalize each keypoint patch when we compute +// its flow using ESM? +static const bool kDoBrightnessNormalize = true; + +// Whether or not to use fixed-point interpolated pixel lookups in optical flow. +#define USE_FIXED_POINT_FLOW 1 + +// Whether to normalize keypoint windows for intensity in LK optical flow. +// This is a define for now because it helps keep the code streamlined. +#define NORMALIZE 1 + +// Number of keypoints to store per frame. +static const int kMaxKeypoints = 76; + +// Keypoint detection. +static const int kMaxTempKeypoints = 1024; + +// Number of floats each keypoint takes up when exporting to an array. +static const int kKeypointStep = 7; + +// Number of frame deltas to keep around in the circular queue. +static const int kNumFrames = 512; + +// Number of iterations to do tracking on each keypoint at each pyramid level. +static const int kNumIterations = 3; + +// The number of bins (on a side) to divide each bin from the previous +// cache level into. Higher numbers will decrease performance by increasing +// cache misses, but mean that cache hits are more locally relevant. +static const int kCacheBranchFactor = 2; + +// Number of levels to put in the cache. +// Each level of the cache is a square grid of bins, length: +// branch_factor^(level - 1) on each side. +// +// This may be greater than kNumPyramidLevels. Setting it to 0 means no +// caching is enabled. +static const int kNumCacheLevels = 3; + +// The level at which the cache pyramid gets cut off and replaced by a matrix +// transform if such a matrix has been provided to the cache. +static const int kCacheCutoff = 1; + +static const int kNumPyramidLevels = 4; + +// The minimum number of keypoints needed in an object's area. +static const int kMaxKeypointsForObject = 16; + +// Minimum number of pyramid levels to use after getting cached value. +// This allows fine-scale adjustment from the cached value, which is taken +// from the center of the corresponding top cache level box. +// Can be [0, kNumPyramidLevels). +static const int kMinNumPyramidLevelsToUseForAdjustment = 1; + +// Window size to integrate over to find local image derivative. +static const int kFlowIntegrationWindowSize = 3; + +// Total area of integration windows. +static const int kFlowArraySize = + (2 * kFlowIntegrationWindowSize + 1) * (2 * kFlowIntegrationWindowSize + 1); + +// Error that's considered good enough to early abort tracking. +static const float kTrackingAbortThreshold = 0.03f; + +// Maximum number of deviations a keypoint-correspondence delta can be from the +// weighted average before being thrown out for region-based queries. +static const float kNumDeviations = 2.0f; + +// The length of the allowed delta between the forward and the backward +// flow deltas in terms of the length of the forward flow vector. +static const float kMaxForwardBackwardErrorAllowed = 0.5f; + +// Threshold for pixels to be considered different. +static const int kFastDiffAmount = 10; + +// How far from edge of frame to stop looking for FAST keypoints. +static const int kFastBorderBuffer = 10; + +// Determines if non-detected arbitrary keypoints should be added to regions. +// This will help if no keypoints have been detected in the region yet. +static const bool kAddArbitraryKeypoints = true; + +// How many arbitrary keypoints to add along each axis as candidates for each +// region? +static const int kNumToAddAsCandidates = 1; + +// In terms of region dimensions, how closely can we place keypoints +// next to each other? +static const float kClosestPercent = 0.6f; + +// How many FAST qualifying pixels must be connected to a pixel for it to be +// considered a candidate keypoint for Harris filtering. +static const int kMinNumConnectedForFastKeypoint = 8; + +// Size of the window to integrate over for Harris filtering. +// Compare to kFlowIntegrationWindowSize. +static const int kHarrisWindowSize = 2; + + +// DETECTOR PARAMETERS + +// Before relocalizing, make sure the new proposed position is better than +// the existing position by a small amount to prevent thrashing. +static const MatchScore kMatchScoreBuffer(0.01f); + +// Minimum score a tracked object can have and still be considered a match. +// TODO(andrewharp): Make this a per detector thing. +static const MatchScore kMinimumMatchScore(0.5f); + +static const float kMinimumCorrelationForTracking = 0.4f; + +static const MatchScore kMatchScoreForImmediateTermination(0.0f); + +// Run the detector every N frames. +static const int kDetectEveryNFrames = 4; + +// How many features does each feature_set contain? +static const int kFeaturesPerFeatureSet = 10; + +// The number of FeatureSets managed by the object detector. +// More FeatureSets can increase recall at the cost of performance. +static const int kNumFeatureSets = 7; + +// How many FeatureSets must respond affirmatively for a candidate descriptor +// and position to be given more thorough attention? +static const int kNumFeatureSetsForCandidate = 2; + +// How large the thumbnails used for correlation validation are. Used for both +// width and height. +static const int kNormalizedThumbnailSize = 11; + +// The area of intersection divided by union for the bounding boxes that tells +// if this tracking has slipped enough to invalidate all unlocked examples. +static const float kPositionOverlapThreshold = 0.6f; + +// The number of detection failures allowed before an object goes invisible. +// Tracking will still occur, so if it is actually still being tracked and +// comes back into a detectable position, it's likely to be found. +static const int kMaxNumDetectionFailures = 4; + + +// Minimum square size to scan with sliding window. +static const float kScanMinSquareSize = 16.0f; + +// Minimum square size to scan with sliding window. +static const float kScanMaxSquareSize = 64.0f; + +// Scale difference for consecutive scans of the sliding window. +static const float kScanScaleFactor = sqrtf(2.0f); + +// Step size for sliding window. +static const int kScanStepSize = 10; + + +// How tightly to pack the descriptor boxes for confirmed exemplars. +static const float kLockedScaleFactor = 1 / sqrtf(2.0f); + +// How tightly to pack the descriptor boxes for unconfirmed exemplars. +static const float kUnlockedScaleFactor = 1 / 2.0f; + +// How tightly the boxes to scan centered at the last known position will be +// packed. +static const float kLastKnownPositionScaleFactor = 1.0f / sqrtf(2.0f); + +// The bounds on how close a new object example must be to existing object +// examples for detection to be valid. +static const float kMinCorrelationForNewExample = 0.75f; +static const float kMaxCorrelationForNewExample = 0.99f; + + +// The number of safe tries an exemplar has after being created before +// missed detections count against it. +static const int kFreeTries = 5; + +// A false positive is worth this many missed detections. +static const int kFalsePositivePenalty = 5; + +struct ObjectDetectorConfig { + const Size image_size; + + explicit ObjectDetectorConfig(const Size& image_size) + : image_size(image_size) {} + virtual ~ObjectDetectorConfig() = default; +}; + +struct KeypointDetectorConfig { + const Size image_size; + + bool detect_skin; + + explicit KeypointDetectorConfig(const Size& image_size) + : image_size(image_size), + detect_skin(false) {} +}; + + +struct OpticalFlowConfig { + const Size image_size; + + explicit OpticalFlowConfig(const Size& image_size) + : image_size(image_size) {} +}; + +struct TrackerConfig { + const Size image_size; + KeypointDetectorConfig keypoint_detector_config; + OpticalFlowConfig flow_config; + bool always_track; + + float object_box_scale_factor_for_features; + + explicit TrackerConfig(const Size& image_size) + : image_size(image_size), + keypoint_detector_config(image_size), + flow_config(image_size), + always_track(false), + object_box_scale_factor_for_features(1.0f) {} +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/examples/android/jni/object_tracking/flow_cache.h new file mode 100644 index 0000000000..8813ab6d71 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/flow_cache.h @@ -0,0 +1,306 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +namespace tf_tracking { + +// Class that helps OpticalFlow to speed up flow computation +// by caching coarse-grained flow. +class FlowCache { + public: + explicit FlowCache(const OpticalFlowConfig* const config) + : config_(config), + image_size_(config->image_size), + optical_flow_(config), + fullframe_matrix_(NULL) { + for (int i = 0; i < kNumCacheLevels; ++i) { + const int curr_dims = BlockDimForCacheLevel(i); + has_cache_[i] = new Image<bool>(curr_dims, curr_dims); + displacements_[i] = new Image<Point2f>(curr_dims, curr_dims); + } + } + + ~FlowCache() { + for (int i = 0; i < kNumCacheLevels; ++i) { + SAFE_DELETE(has_cache_[i]); + SAFE_DELETE(displacements_[i]); + } + delete[](fullframe_matrix_); + fullframe_matrix_ = NULL; + } + + void NextFrame(ImageData* const new_frame, + const float* const align_matrix23) { + ClearCache(); + SetFullframeAlignmentMatrix(align_matrix23); + optical_flow_.NextFrame(new_frame); + } + + void ClearCache() { + for (int i = 0; i < kNumCacheLevels; ++i) { + has_cache_[i]->Clear(false); + } + delete[](fullframe_matrix_); + fullframe_matrix_ = NULL; + } + + // Finds the flow at a point, using the cache for performance. + bool FindFlowAtPoint(const float u_x, const float u_y, + float* const flow_x, float* const flow_y) const { + // Get the best guess from the cache. + const Point2f guess_from_cache = LookupGuess(u_x, u_y); + + *flow_x = guess_from_cache.x; + *flow_y = guess_from_cache.y; + + // Now refine the guess using the image pyramid. + for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1; + pyramid_level >= 0; --pyramid_level) { + if (!optical_flow_.FindFlowAtPointSingleLevel( + pyramid_level, u_x, u_y, false, flow_x, flow_y)) { + return false; + } + } + + return true; + } + + // Determines the displacement of a point, and uses that to calculate a new + // position. + // Returns true iff the displacement determination worked and the new position + // is in the image. + bool FindNewPositionOfPoint(const float u_x, const float u_y, + float* final_x, float* final_y) const { + float flow_x; + float flow_y; + if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) { + return false; + } + + // Add in the displacement to get the final position. + *final_x = u_x + flow_x; + *final_y = u_y + flow_y; + + // Assign the best guess, if we're still in the image. + if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) && + InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) { + return true; + } else { + return false; + } + } + + // Comparison function for qsort. + static int Compare(const void* a, const void* b) { + return *reinterpret_cast<const float*>(a) - + *reinterpret_cast<const float*>(b); + } + + // Returns the median flow within the given bounding box as determined + // by a grid_width x grid_height grid. + Point2f GetMedianFlow(const BoundingBox& bounding_box, + const bool filter_by_fb_error, + const int grid_width, + const int grid_height) const { + const int kMaxPoints = 100; + SCHECK(grid_width * grid_height <= kMaxPoints, + "Too many points for Median flow!"); + + const BoundingBox valid_box = bounding_box.Intersect( + BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1)); + + if (valid_box.GetArea() <= 0.0f) { + return Point2f(0, 0); + } + + float x_deltas[kMaxPoints]; + float y_deltas[kMaxPoints]; + + int curr_offset = 0; + for (int i = 0; i < grid_width; ++i) { + for (int j = 0; j < grid_height; ++j) { + const float x_in = valid_box.left_ + + (valid_box.GetWidth() * i) / (grid_width - 1); + + const float y_in = valid_box.top_ + + (valid_box.GetHeight() * j) / (grid_height - 1); + + float curr_flow_x; + float curr_flow_y; + const bool success = FindNewPositionOfPoint(x_in, y_in, + &curr_flow_x, &curr_flow_y); + + if (success) { + x_deltas[curr_offset] = curr_flow_x; + y_deltas[curr_offset] = curr_flow_y; + ++curr_offset; + } else { + LOGW("Tracking failure!"); + } + } + } + + if (curr_offset > 0) { + qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare); + qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare); + + return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]); + } + + LOGW("No points were valid!"); + return Point2f(0, 0); + } + + void SetFullframeAlignmentMatrix(const float* const align_matrix23) { + if (align_matrix23 != NULL) { + if (fullframe_matrix_ == NULL) { + fullframe_matrix_ = new float[6]; + } + + memcpy(fullframe_matrix_, align_matrix23, + 6 * sizeof(fullframe_matrix_[0])); + } + } + + private: + Point2f LookupGuessFromLevel( + const int cache_level, const float x, const float y) const { + // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level); + + // Cutoff at the target level and use the matrix transform instead. + if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) { + const float xnew = x * fullframe_matrix_[0] + + y * fullframe_matrix_[1] + + fullframe_matrix_[2]; + const float ynew = x * fullframe_matrix_[3] + + y * fullframe_matrix_[4] + + fullframe_matrix_[5]; + + return Point2f(xnew - x, ynew - y); + } + + const int level_dim = BlockDimForCacheLevel(cache_level); + const int pixels_per_cache_block_x = + (image_size_.width + level_dim - 1) / level_dim; + const int pixels_per_cache_block_y = + (image_size_.height + level_dim - 1) / level_dim; + const int index_x = x / pixels_per_cache_block_x; + const int index_y = y / pixels_per_cache_block_y; + + Point2f displacement; + if (!(*has_cache_[cache_level])[index_y][index_x]) { + (*has_cache_[cache_level])[index_y][index_x] = true; + + // Get the lower cache level's best guess, if it exists. + displacement = cache_level >= kNumCacheLevels - 1 ? + Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y); + // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level, + // best_guess.x, best_guess.y); + + // Find the center of the block. + const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x; + const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y; + const int pyramid_level = PyramidLevelForCacheLevel(cache_level); + + // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] " + // "Querying %5.2f, %5.2f at pyramid level %d, ", + // cache_level, index_x, index_y, + // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y, + // center_x, center_y, pyramid_level); + + // TODO(andrewharp): Turn on FB error filtering. + const bool success = optical_flow_.FindFlowAtPointSingleLevel( + pyramid_level, center_x, center_y, false, + &displacement.x, &displacement.y); + + if (!success) { + LOGV("Computation of cached value failed for level %d!", cache_level); + } + + // Store the value for later use. + (*displacements_[cache_level])[index_y][index_x] = displacement; + } else { + displacement = (*displacements_[cache_level])[index_y][index_x]; + } + + // LOGI("Returning %5.2f, %5.2f for level %d", + // displacement.x, displacement.y, cache_level); + return displacement; + } + + Point2f LookupGuess(const float x, const float y) const { + if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) { + return Point2f(0, 0); + } + + // LOGI("Looking up guess at %5.2f %5.2f.", x, y); + if (kNumCacheLevels > 0) { + return LookupGuessFromLevel(0, x, y); + } else { + return Point2f(0, 0); + } + } + + // Returns the number of cache bins in each dimension for a given level + // of the cache. + int BlockDimForCacheLevel(const int cache_level) const { + // The highest (coarsest) cache level has a block dim of kCacheBranchFactor, + // thus if there are 4 cache levels, requesting level 3 (0-based) should + // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2, + // and so on. + int block_dim = kNumCacheLevels; + for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level; + --curr_level) { + block_dim *= kCacheBranchFactor; + } + return block_dim; + } + + // Returns the level of the image pyramid that a given cache level maps to. + int PyramidLevelForCacheLevel(const int cache_level) const { + // Higher cache and pyramid levels have smaller dimensions. The highest + // cache level should refer to the highest image pyramid level. The + // lower, finer image pyramid levels are uncached (assuming + // kNumCacheLevels < kNumPyramidLevels). + return cache_level + (kNumPyramidLevels - kNumCacheLevels); + } + + const OpticalFlowConfig* const config_; + + const Size image_size_; + OpticalFlow optical_flow_; + + float* fullframe_matrix_; + + // Whether this value is currently present in the cache. + Image<bool>* has_cache_[kNumCacheLevels]; + + // The cached displacement values. + Image<Point2f>* displacements_[kNumCacheLevels]; + + TF_DISALLOW_COPY_AND_ASSIGN(FlowCache); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.cc b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc new file mode 100644 index 0000000000..fa86e2363c --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.cc @@ -0,0 +1,308 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <float.h> + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" + +namespace tf_tracking { + +void FramePair::Init(const int64 start_time, const int64 end_time) { + start_time_ = start_time; + end_time_ = end_time; + memset(optical_flow_found_keypoint_, false, + sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints); + number_of_keypoints_ = 0; +} + +void FramePair::AdjustBox(const BoundingBox box, + float* const translation_x, + float* const translation_y, + float* const scale_x, + float* const scale_y) const { + static float weights[kMaxKeypoints]; + static Point2f deltas[kMaxKeypoints]; + memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints); + + BoundingBox resized_box(box); + resized_box.Scale(0.4f, 0.4f); + FillWeights(resized_box, weights); + FillTranslations(deltas); + + const Point2f translation = GetWeightedMedian(weights, deltas); + + *translation_x = translation.x; + *translation_y = translation.y; + + const Point2f old_center = box.GetCenter(); + const int good_scale_points = + FillScales(old_center, translation, weights, deltas); + + // Default scale factor is 1 for x and y. + *scale_x = 1.0f; + *scale_y = 1.0f; + + // The assumption is that all deltas that make it to this stage with a + // correspondending optical_flow_found_keypoint_[i] == true are not in + // themselves degenerate. + // + // The degeneracy with scale arose because if the points are too close to the + // center of the objects, the scale ratio determination might be incalculable. + // + // The check for kMinNumInRange is not a degeneracy check, but merely an + // attempt to ensure some sort of stability. The actual degeneracy check is in + // the comparison to EPSILON in FillScales (which I've updated to return the + // number good remaining as well). + static const int kMinNumInRange = 5; + if (good_scale_points >= kMinNumInRange) { + const float scale_factor = GetWeightedMedianScale(weights, deltas); + + if (scale_factor > 0.0f) { + *scale_x = scale_factor; + *scale_y = scale_factor; + } + } +} + +int FramePair::FillWeights(const BoundingBox& box, + float* const weights) const { + // Compute the max score. + float max_score = -FLT_MAX; + float min_score = FLT_MAX; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (optical_flow_found_keypoint_[i]) { + max_score = MAX(max_score, frame1_keypoints_[i].score_); + min_score = MIN(min_score, frame1_keypoints_[i].score_); + } + } + + int num_in_range = 0; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + weights[i] = 0.0f; + continue; + } + + const bool in_box = box.Contains(frame1_keypoints_[i].pos_); + if (in_box) { + ++num_in_range; + } + + // The weighting based off distance. Anything within the bounding box + // has a weight of 1, and everything outside of that is within the range + // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio. + float distance_score = 1.0f; + if (!in_box) { + const Point2f initial = box.GetCenter(); + const float sq_x_dist = + Square(initial.x - frame1_keypoints_[i].pos_.x); + const float sq_y_dist = + Square(initial.y - frame1_keypoints_[i].pos_.y); + const float squared_half_width = Square(box.GetWidth() / 2.0f); + const float squared_half_height = Square(box.GetHeight() / 2.0f); + + static const float kOutOfBoxMultiplier = 0.5f; + distance_score = kOutOfBoxMultiplier * + MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist); + } + + // The weighting based on relative score strength. kBaseScore - 1.0f. + float intrinsic_score = 1.0f; + if (max_score > min_score) { + static const float kBaseScore = 0.5f; + intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) / + (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore; + } + + // The final score will be in the range [0, 1]. + weights[i] = distance_score * intrinsic_score; + } + + return num_in_range; +} + +void FramePair::FillTranslations(Point2f* const translations) const { + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + continue; + } + translations[i].x = + frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x; + translations[i].y = + frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y; + } +} + +int FramePair::FillScales(const Point2f& old_center, + const Point2f& translation, + float* const weights, + Point2f* const scales) const { + int num_good = 0; + for (int i = 0; i < kMaxKeypoints; ++i) { + if (!optical_flow_found_keypoint_[i]) { + continue; + } + + const Keypoint keypoint1 = frame1_keypoints_[i]; + const Keypoint keypoint2 = frame2_keypoints_[i]; + + const float dist1_x = keypoint1.pos_.x - old_center.x; + const float dist1_y = keypoint1.pos_.y - old_center.y; + + const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x; + const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y; + + // Make sure that the scale makes sense; points too close to the center + // will result in either NaNs or infinite results for scale due to + // limited tracking and floating point resolution. + // Also check that the parity of the points is the same with respect to + // x and y, as we can't really make sense of data that has flipped. + if (((dist2_x > EPSILON && dist1_x > EPSILON) || + (dist2_x < -EPSILON && dist1_x < -EPSILON)) && + ((dist2_y > EPSILON && dist1_y > EPSILON) || + (dist2_y < -EPSILON && dist1_y < -EPSILON))) { + scales[i].x = dist2_x / dist1_x; + scales[i].y = dist2_y / dist1_y; + ++num_good; + } else { + weights[i] = 0.0f; + scales[i].x = 1.0f; + scales[i].y = 1.0f; + } + } + return num_good; +} + +struct WeightedDelta { + float weight; + float delta; +}; + +// Sort by delta, not by weight. +inline int WeightedDeltaCompare(const void* const a, const void* const b) { + return (reinterpret_cast<const WeightedDelta*>(a)->delta - + reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1; +} + +// Returns the median delta from a sorted set of weighted deltas. +static float GetMedian(const int num_items, + const WeightedDelta* const weighted_deltas, + const float sum) { + if (num_items == 0 || sum < EPSILON) { + return 0.0f; + } + + float current_weight = 0.0f; + const float target_weight = sum / 2.0f; + for (int i = 0; i < num_items; ++i) { + if (weighted_deltas[i].weight > 0.0f) { + current_weight += weighted_deltas[i].weight; + if (current_weight >= target_weight) { + return weighted_deltas[i].delta; + } + } + } + LOGW("Median not found! %d points, sum of %.2f", num_items, sum); + return 0.0f; +} + +Point2f FramePair::GetWeightedMedian( + const float* const weights, const Point2f* const deltas) const { + Point2f median_delta; + + // TODO(andrewharp): only sort deltas that could possibly have an effect. + static WeightedDelta weighted_deltas[kMaxKeypoints]; + + // Compute median X value. + { + float total_weight = 0.0f; + + // Compute weighted mean and deltas. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i].delta = deltas[i].x; + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), + WeightedDeltaCompare); + median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); + } + + // Compute median Y value. + { + float total_weight = 0.0f; + + // Compute weighted mean and deltas. + for (int i = 0; i < kMaxKeypoints; ++i) { + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + weighted_deltas[i].delta = deltas[i].y; + if (weight > 0.0f) { + total_weight += weight; + } + } + qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta), + WeightedDeltaCompare); + median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight); + } + + return median_delta; +} + +float FramePair::GetWeightedMedianScale( + const float* const weights, const Point2f* const deltas) const { + float median_delta; + + // TODO(andrewharp): only sort deltas that could possibly have an effect. + static WeightedDelta weighted_deltas[kMaxKeypoints * 2]; + + // Compute median scale value across x and y. + { + float total_weight = 0.0f; + + // Add X values. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i].delta = deltas[i].x; + const float weight = weights[i]; + weighted_deltas[i].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + + // Add Y values. + for (int i = 0; i < kMaxKeypoints; ++i) { + weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y; + const float weight = weights[i]; + weighted_deltas[i + kMaxKeypoints].weight = weight; + if (weight > 0.0f) { + total_weight += weight; + } + } + + qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta), + WeightedDeltaCompare); + + median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight); + } + + return median_delta; +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.h b/tensorflow/examples/android/jni/object_tracking/frame_pair.h new file mode 100644 index 0000000000..3f2559a5e0 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.h @@ -0,0 +1,103 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ + +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" + +namespace tf_tracking { + +// A class that records keypoint correspondences from pairs of +// consecutive frames. +class FramePair { + public: + FramePair() + : start_time_(0), + end_time_(0), + number_of_keypoints_(0) {} + + // Cleans up the FramePair so that they can be reused. + void Init(const int64 start_time, const int64 end_time); + + void AdjustBox(const BoundingBox box, + float* const translation_x, + float* const translation_y, + float* const scale_x, + float* const scale_y) const; + + private: + // Returns the weighted median of the given deltas, computed independently on + // x and y. Returns 0,0 in case of failure. The assumption is that a + // translation of 0.0 in the degenerate case is the best that can be done, and + // should not be considered an error. + // + // In the case of scale, a slight exception is made just to be safe and + // there is a check for 0.0 explicitly, but that shouldn't ever be possible to + // happen naturally because of the non-zero + parity checks in FillScales. + Point2f GetWeightedMedian(const float* const weights, + const Point2f* const deltas) const; + + float GetWeightedMedianScale(const float* const weights, + const Point2f* const deltas) const; + + // Weights points based on the query_point and cutoff_dist. + int FillWeights(const BoundingBox& box, + float* const weights) const; + + // Fills in the array of deltas with the translations of the points + // between frames. + void FillTranslations(Point2f* const translations) const; + + // Fills in the array of deltas with the relative scale factor of points + // relative to a given center. Has the ability to override the weight to 0 if + // a degenerate scale is detected. + // Translation is the amount the center of the box has moved from one frame to + // the next. + int FillScales(const Point2f& old_center, + const Point2f& translation, + float* const weights, + Point2f* const scales) const; + + // TODO(andrewharp): Make these private. + public: + // The time at frame1. + int64 start_time_; + + // The time at frame2. + int64 end_time_; + + // This array will contain the keypoints found in frame 1. + Keypoint frame1_keypoints_[kMaxKeypoints]; + + // Contain the locations of the keypoints from frame 1 in frame 2. + Keypoint frame2_keypoints_[kMaxKeypoints]; + + // The number of keypoints in frame 1. + int number_of_keypoints_; + + // Keeps track of which keypoint correspondences were actually found from one + // frame to another. + // The i-th element of this array will be non-zero if and only if the i-th + // keypoint of frame 1 was found in frame 2. + bool optical_flow_found_keypoint_[kMaxKeypoints]; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(FramePair); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/geom.h b/tensorflow/examples/android/jni/object_tracking/geom.h new file mode 100644 index 0000000000..5d5249cd97 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/geom.h @@ -0,0 +1,319 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ + +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +struct Size { + Size(const int width, const int height) : width(width), height(height) {} + + int width; + int height; +}; + + +class Point2f { + public: + Point2f() : x(0.0f), y(0.0f) {} + Point2f(const float x, const float y) : x(x), y(y) {} + + inline Point2f operator- (const Point2f& that) const { + return Point2f(this->x - that.x, this->y - that.y); + } + + inline Point2f operator+ (const Point2f& that) const { + return Point2f(this->x + that.x, this->y + that.y); + } + + inline Point2f& operator+= (const Point2f& that) { + this->x += that.x; + this->y += that.y; + return *this; + } + + inline Point2f& operator-= (const Point2f& that) { + this->x -= that.x; + this->y -= that.y; + return *this; + } + + inline Point2f operator- (const Point2f& that) { + return Point2f(this->x - that.x, this->y - that.y); + } + + inline float LengthSquared() { + return Square(this->x) + Square(this->y); + } + + inline float Length() { + return sqrtf(LengthSquared()); + } + + inline float DistanceSquared(const Point2f& that) { + return Square(this->x - that.x) + Square(this->y - that.y); + } + + inline float Distance(const Point2f& that) { + return sqrtf(DistanceSquared(that)); + } + + float x; + float y; +}; + +inline std::ostream& operator<<(std::ostream& stream, const Point2f& point) { + stream << point.x << "," << point.y; + return stream; +} + +class BoundingBox { + public: + BoundingBox() + : left_(0), + top_(0), + right_(0), + bottom_(0) {} + + BoundingBox(const BoundingBox& bounding_box) + : left_(bounding_box.left_), + top_(bounding_box.top_), + right_(bounding_box.right_), + bottom_(bounding_box.bottom_) { + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_); + SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_); + } + + BoundingBox(const float left, + const float top, + const float right, + const float bottom) + : left_(left), + top_(top), + right_(right), + bottom_(bottom) { + SCHECK(left_ < right_, "Bounds out of whack! %.2f vs %.2f!", left_, right_); + SCHECK(top_ < bottom_, "Bounds out of whack! %.2f vs %.2f!", top_, bottom_); + } + + BoundingBox(const Point2f& point1, const Point2f& point2) + : left_(MIN(point1.x, point2.x)), + top_(MIN(point1.y, point2.y)), + right_(MAX(point1.x, point2.x)), + bottom_(MAX(point1.y, point2.y)) {} + + inline void CopyToArray(float* const bounds_array) const { + bounds_array[0] = left_; + bounds_array[1] = top_; + bounds_array[2] = right_; + bounds_array[3] = bottom_; + } + + inline float GetWidth() const { + return right_ - left_; + } + + inline float GetHeight() const { + return bottom_ - top_; + } + + inline float GetArea() const { + const float width = GetWidth(); + const float height = GetHeight(); + if (width <= 0 || height <= 0) { + return 0.0f; + } + + return width * height; + } + + inline Point2f GetCenter() const { + return Point2f((left_ + right_) / 2.0f, + (top_ + bottom_) / 2.0f); + } + + inline bool ValidBox() const { + return GetArea() > 0.0f; + } + + // Returns a bounding box created from the overlapping area of these two. + inline BoundingBox Intersect(const BoundingBox& that) const { + const float new_left = MAX(this->left_, that.left_); + const float new_right = MIN(this->right_, that.right_); + + if (new_left >= new_right) { + return BoundingBox(); + } + + const float new_top = MAX(this->top_, that.top_); + const float new_bottom = MIN(this->bottom_, that.bottom_); + + if (new_top >= new_bottom) { + return BoundingBox(); + } + + return BoundingBox(new_left, new_top, new_right, new_bottom); + } + + // Returns a bounding box that can contain both boxes. + inline BoundingBox Union(const BoundingBox& that) const { + return BoundingBox(MIN(this->left_, that.left_), + MIN(this->top_, that.top_), + MAX(this->right_, that.right_), + MAX(this->bottom_, that.bottom_)); + } + + inline float PascalScore(const BoundingBox& that) const { + SCHECK(GetArea() > 0.0f, "Empty bounding box!"); + SCHECK(that.GetArea() > 0.0f, "Empty bounding box!"); + + const float intersect_area = this->Intersect(that).GetArea(); + + if (intersect_area <= 0) { + return 0; + } + + const float score = + intersect_area / (GetArea() + that.GetArea() - intersect_area); + SCHECK(InRange(score, 0.0f, 1.0f), "Invalid score! %.2f", score); + return score; + } + + inline bool Intersects(const BoundingBox& that) const { + return InRange(that.left_, left_, right_) + || InRange(that.right_, left_, right_) + || InRange(that.top_, top_, bottom_) + || InRange(that.bottom_, top_, bottom_); + } + + // Returns whether another bounding box is completely inside of this bounding + // box. Sharing edges is ok. + inline bool Contains(const BoundingBox& that) const { + return that.left_ >= left_ && + that.right_ <= right_ && + that.top_ >= top_ && + that.bottom_ <= bottom_; + } + + inline bool Contains(const Point2f& point) const { + return InRange(point.x, left_, right_) && InRange(point.y, top_, bottom_); + } + + inline void Shift(const Point2f shift_amount) { + left_ += shift_amount.x; + top_ += shift_amount.y; + right_ += shift_amount.x; + bottom_ += shift_amount.y; + } + + inline void ScaleOrigin(const float scale_x, const float scale_y) { + left_ *= scale_x; + right_ *= scale_x; + top_ *= scale_y; + bottom_ *= scale_y; + } + + inline void Scale(const float scale_x, const float scale_y) { + const Point2f center = GetCenter(); + const float half_width = GetWidth() / 2.0f; + const float half_height = GetHeight() / 2.0f; + + left_ = center.x - half_width * scale_x; + right_ = center.x + half_width * scale_x; + + top_ = center.y - half_height * scale_y; + bottom_ = center.y + half_height * scale_y; + } + + float left_; + float top_; + float right_; + float bottom_; +}; +inline std::ostream& operator<<(std::ostream& stream, const BoundingBox& box) { + stream << "[" << box.left_ << " - " << box.right_ + << ", " << box.top_ << " - " << box.bottom_ + << ", w:" << box.GetWidth() << " h:" << box.GetHeight() << "]"; + return stream; +} + + +class BoundingSquare { + public: + BoundingSquare(const float x, const float y, const float size) + : x_(x), y_(y), size_(size) {} + + explicit BoundingSquare(const BoundingBox& box) + : x_(box.left_), y_(box.top_), size_(box.GetWidth()) { +#ifdef SANITY_CHECKS + if (std::abs(box.GetWidth() - box.GetHeight()) > 0.1f) { + LOG(WARNING) << "This is not a square: " << box << std::endl; + } +#endif + } + + inline BoundingBox ToBoundingBox() const { + return BoundingBox(x_, y_, x_ + size_, y_ + size_); + } + + inline bool ValidBox() { + return size_ > 0.0f; + } + + inline void Shift(const Point2f shift_amount) { + x_ += shift_amount.x; + y_ += shift_amount.y; + } + + inline void Scale(const float scale) { + const float new_size = size_ * scale; + const float position_diff = (new_size - size_) / 2.0f; + x_ -= position_diff; + y_ -= position_diff; + size_ = new_size; + } + + float x_; + float y_; + float size_; +}; +inline std::ostream& operator<<(std::ostream& stream, + const BoundingSquare& square) { + stream << "[" << square.x_ << "," << square.y_ << " " << square.size_ << "]"; + return stream; +} + + +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box, + const float size) { + const float width_diff = (original_box.GetWidth() - size) / 2.0f; + const float height_diff = (original_box.GetHeight() - size) / 2.0f; + return BoundingSquare(original_box.left_ + width_diff, + original_box.top_ + height_diff, + size); +} + +inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) { + return GetCenteredSquare( + original_box, MIN(original_box.GetWidth(), original_box.GetHeight())); +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/gl_utils.h b/tensorflow/examples/android/jni/object_tracking/gl_utils.h new file mode 100755 index 0000000000..bd5c233f4f --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/gl_utils.h @@ -0,0 +1,55 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ + +#include <GLES/gl.h> +#include <GLES/glext.h> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" + +namespace tf_tracking { + +// Draws a box at the given position. +inline static void DrawBox(const BoundingBox& bounding_box) { + const GLfloat line[] = { + bounding_box.left_, bounding_box.bottom_, + bounding_box.left_, bounding_box.top_, + bounding_box.left_, bounding_box.top_, + bounding_box.right_, bounding_box.top_, + bounding_box.right_, bounding_box.top_, + bounding_box.right_, bounding_box.bottom_, + bounding_box.right_, bounding_box.bottom_, + bounding_box.left_, bounding_box.bottom_ + }; + + glVertexPointer(2, GL_FLOAT, 0, line); + glEnableClientState(GL_VERTEX_ARRAY); + + glDrawArrays(GL_LINES, 0, 8); +} + + +// Changes the coordinate system such that drawing to an arbitrary square in +// the world can thereafter be drawn to using coordinates 0 - 1. +inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) { + glScalef(square.size_, square.size_, 1.0f); + glTranslatef(square.x_ / square.size_, square.y_ / square.size_, 0.0f); +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image-inl.h b/tensorflow/examples/android/jni/object_tracking/image-inl.h new file mode 100644 index 0000000000..18123cef01 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image-inl.h @@ -0,0 +1,642 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +template <typename T> +Image<T>::Image(const int width, const int height) + : width_less_one_(width - 1), + height_less_one_(height - 1), + data_size_(width * height), + own_data_(true), + width_(width), + height_(height), + stride_(width) { + Allocate(); +} + +template <typename T> +Image<T>::Image(const Size& size) + : width_less_one_(size.width - 1), + height_less_one_(size.height - 1), + data_size_(size.width * size.height), + own_data_(true), + width_(size.width), + height_(size.height), + stride_(size.width) { + Allocate(); +} + +// Constructor that creates an image from preallocated data. +// Note: The image takes ownership of the data lifecycle, unless own_data is +// set to false. +template <typename T> +Image<T>::Image(const int width, const int height, T* const image_data, + const bool own_data) : + width_less_one_(width - 1), + height_less_one_(height - 1), + data_size_(width * height), + own_data_(own_data), + width_(width), + height_(height), + stride_(width) { + image_data_ = image_data; + SCHECK(image_data_ != NULL, "Can't create image with NULL data!"); +} + +template <typename T> +Image<T>::~Image() { + if (own_data_) { + delete[] image_data_; + } + image_data_ = NULL; +} + +template<typename T> +template<class DstType> +bool Image<T>::ExtractPatchAtSubpixelFixed1616(const int fp_x, + const int fp_y, + const int patchwidth, + const int patchheight, + DstType* to_data) const { + // Calculate weights. + const int trunc_x = fp_x >> 16; + const int trunc_y = fp_y >> 16; + + if (trunc_x < 0 || trunc_y < 0 || + (trunc_x + patchwidth) >= width_less_one_ || + (trunc_y + patchheight) >= height_less_one_) { + return false; + } + + // Now walk over destination patch and fill from interpolated source image. + for (int y = 0; y < patchheight; ++y, to_data += patchwidth) { + for (int x = 0; x < patchwidth; ++x) { + to_data[x] = + static_cast<DstType>(GetPixelInterpFixed1616(fp_x + (x << 16), + fp_y + (y << 16))); + } + } + + return true; +} + +template <typename T> +Image<T>* Image<T>::Crop( + const int left, const int top, const int right, const int bottom) const { + SCHECK(left >= 0 && left < width_, "out of bounds at %d!", left); + SCHECK(right >= 0 && right < width_, "out of bounds at %d!", right); + SCHECK(top >= 0 && top < height_, "out of bounds at %d!", top); + SCHECK(bottom >= 0 && bottom < height_, "out of bounds at %d!", bottom); + + SCHECK(left <= right, "mismatch!"); + SCHECK(top <= bottom, "mismatch!"); + + const int new_width = right - left + 1; + const int new_height = bottom - top + 1; + + Image<T>* const cropped_image = new Image(new_width, new_height); + + for (int y = 0; y < new_height; ++y) { + memcpy((*cropped_image)[y], ((*this)[y + top] + left), + new_width * sizeof(T)); + } + + return cropped_image; +} + +template <typename T> +inline float Image<T>::GetPixelInterp(const float x, const float y) const { + // Do int conversion one time. + const int floored_x = static_cast<int>(x); + const int floored_y = static_cast<int>(y); + + // Note: it might be the case that the *_[min|max] values are clipped, and + // these (the a b c d vals) aren't (for speed purposes), but that doesn't + // matter. We'll just be blending the pixel with itself in that case anyway. + const float b = x - floored_x; + const float a = 1.0f - b; + + const float d = y - floored_y; + const float c = 1.0f - d; + + SCHECK(ValidInterpPixel(x, y), + "x or y out of bounds! %.2f [0 - %d), %.2f [0 - %d)", + x, width_less_one_, y, height_less_one_); + + const T* const pix_ptr = (*this)[floored_y] + floored_x; + + // Get the pixel values surrounding this point. + const T& p1 = pix_ptr[0]; + const T& p2 = pix_ptr[1]; + const T& p3 = pix_ptr[width_]; + const T& p4 = pix_ptr[width_ + 1]; + + // Simple bilinear interpolation between four reference pixels. + // If x is the value requested: + // a b + // ------- + // c |p1 p2| + // | x | + // d |p3 p4| + // ------- + return c * ((a * p1) + (b * p2)) + + d * ((a * p3) + (b * p4)); +} + + +template <typename T> +inline T Image<T>::GetPixelInterpFixed1616( + const int fp_x_whole, const int fp_y_whole) const { + static const int kFixedPointOne = 0x00010000; + static const int kFixedPointHalf = 0x00008000; + static const int kFixedPointTruncateMask = 0xFFFF0000; + + int trunc_x = fp_x_whole & kFixedPointTruncateMask; + int trunc_y = fp_y_whole & kFixedPointTruncateMask; + const int fp_x = fp_x_whole - trunc_x; + const int fp_y = fp_y_whole - trunc_y; + + // Scale the truncated values back to regular ints. + trunc_x >>= 16; + trunc_y >>= 16; + + const int one_minus_fp_x = kFixedPointOne - fp_x; + const int one_minus_fp_y = kFixedPointOne - fp_y; + + const T* trunc_start = (*this)[trunc_y] + trunc_x; + + const T a = trunc_start[0]; + const T b = trunc_start[1]; + const T c = trunc_start[stride_]; + const T d = trunc_start[stride_ + 1]; + + return ((one_minus_fp_y * static_cast<int64>(one_minus_fp_x * a + fp_x * b) + + fp_y * static_cast<int64>(one_minus_fp_x * c + fp_x * d) + + kFixedPointHalf) >> 32); +} + +template <typename T> +inline bool Image<T>::ValidPixel(const int x, const int y) const { + return InRange(x, ZERO, width_less_one_) && + InRange(y, ZERO, height_less_one_); +} + +template <typename T> +inline BoundingBox Image<T>::GetContainingBox() const { + return BoundingBox( + 0, 0, width_less_one_ - EPSILON, height_less_one_ - EPSILON); +} + +template <typename T> +inline bool Image<T>::Contains(const BoundingBox& bounding_box) const { + // TODO(andrewharp): Come up with a more elegant way of ensuring that bounds + // are ok. + return GetContainingBox().Contains(bounding_box); +} + +template <typename T> +inline bool Image<T>::ValidInterpPixel(const float x, const float y) const { + // Exclusive of max because we can be more efficient if we don't handle + // interpolating on or past the last pixel. + return (x >= ZERO) && (x < width_less_one_) && + (y >= ZERO) && (y < height_less_one_); +} + +template <typename T> +void Image<T>::DownsampleAveraged(const T* const original, const int stride, + const int factor) { +#ifdef __ARM_NEON + if (factor == 4 || factor == 2) { + DownsampleAveragedNeon(original, stride, factor); + return; + } +#endif + + // TODO(andrewharp): delete or enable this for non-uint8 downsamples. + const int pixels_per_block = factor * factor; + + // For every pixel in resulting image. + for (int y = 0; y < height_; ++y) { + const int orig_y = y * factor; + const int y_bound = orig_y + factor; + + // Sum up the original pixels. + for (int x = 0; x < width_; ++x) { + const int orig_x = x * factor; + const int x_bound = orig_x + factor; + + // Making this int32 because type U or T might overflow. + int32 pixel_sum = 0; + + // Grab all the pixels that make up this pixel. + for (int curr_y = orig_y; curr_y < y_bound; ++curr_y) { + const T* p = original + curr_y * stride + orig_x; + + for (int curr_x = orig_x; curr_x < x_bound; ++curr_x) { + pixel_sum += *p++; + } + } + + (*this)[y][x] = pixel_sum / pixels_per_block; + } + } +} + +template <typename T> +void Image<T>::DownsampleInterpolateNearest(const Image<T>& original) { + // Calculating the scaling factors based on target image size. + const float factor_x = static_cast<float>(original.GetWidth()) / + static_cast<float>(width_); + const float factor_y = static_cast<float>(original.GetHeight()) / + static_cast<float>(height_); + + // Calculating initial offset in x-axis. + const float offset_x = 0.5f * (original.GetWidth() - width_) / width_; + + // Calculating initial offset in y-axis. + const float offset_y = 0.5f * (original.GetHeight() - height_) / height_; + + float orig_y = offset_y; + + // For every pixel in resulting image. + for (int y = 0; y < height_; ++y) { + float orig_x = offset_x; + + // Finding nearest pixel on y-axis. + const int nearest_y = static_cast<int>(orig_y + 0.5f); + const T* row_data = original[nearest_y]; + + T* pixel_ptr = (*this)[y]; + + for (int x = 0; x < width_; ++x) { + // Finding nearest pixel on x-axis. + const int nearest_x = static_cast<int>(orig_x + 0.5f); + + *pixel_ptr++ = row_data[nearest_x]; + + orig_x += factor_x; + } + + orig_y += factor_y; + } +} + +template <typename T> +void Image<T>::DownsampleInterpolateLinear(const Image<T>& original) { + // TODO(andrewharp): Turn this into a general compare sizes/bulk + // copy method. + if (original.GetWidth() == GetWidth() && + original.GetHeight() == GetHeight() && + original.stride() == stride()) { + memcpy(image_data_, original.data(), data_size_ * sizeof(T)); + return; + } + + // Calculating the scaling factors based on target image size. + const float factor_x = static_cast<float>(original.GetWidth()) / + static_cast<float>(width_); + const float factor_y = static_cast<float>(original.GetHeight()) / + static_cast<float>(height_); + + // Calculating initial offset in x-axis. + const float offset_x = 0; + const int offset_x_fp = RealToFixed1616(offset_x); + + // Calculating initial offset in y-axis. + const float offset_y = 0; + const int offset_y_fp = RealToFixed1616(offset_y); + + // Get the fixed point scaling factor value. + // Shift by 8 so we can fit everything into a 4 byte int later for speed + // reasons. This means the precision is limited to 1 / 256th of a pixel, + // but this should be good enough. + const int factor_x_fp = RealToFixed1616(factor_x) >> 8; + const int factor_y_fp = RealToFixed1616(factor_y) >> 8; + + int src_y_fp = offset_y_fp >> 8; + + static const int kFixedPointOne8 = 0x00000100; + static const int kFixedPointHalf8 = 0x00000080; + static const int kFixedPointTruncateMask8 = 0xFFFFFF00; + + // For every pixel in resulting image. + for (int y = 0; y < height_; ++y) { + int src_x_fp = offset_x_fp >> 8; + + int trunc_y = src_y_fp & kFixedPointTruncateMask8; + const int fp_y = src_y_fp - trunc_y; + + // Scale the truncated values back to regular ints. + trunc_y >>= 8; + + const int one_minus_fp_y = kFixedPointOne8 - fp_y; + + T* pixel_ptr = (*this)[y]; + + // Make sure not to read from an invalid row. + const int trunc_y_b = MIN(original.height_less_one_, trunc_y + 1); + const T* other_top_ptr = original[trunc_y]; + const T* other_bot_ptr = original[trunc_y_b]; + + int last_trunc_x = -1; + int trunc_x = -1; + + T a = 0; + T b = 0; + T c = 0; + T d = 0; + + for (int x = 0; x < width_; ++x) { + trunc_x = src_x_fp & kFixedPointTruncateMask8; + + const int fp_x = (src_x_fp - trunc_x) >> 8; + + // Scale the truncated values back to regular ints. + trunc_x >>= 8; + + // It's possible we're reading from the same pixels + if (trunc_x != last_trunc_x) { + // Make sure not to read from an invalid column. + const int trunc_x_b = MIN(original.width_less_one_, trunc_x + 1); + a = other_top_ptr[trunc_x]; + b = other_top_ptr[trunc_x_b]; + c = other_bot_ptr[trunc_x]; + d = other_bot_ptr[trunc_x_b]; + last_trunc_x = trunc_x; + } + + const int one_minus_fp_x = kFixedPointOne8 - fp_x; + + const int32 value = + ((one_minus_fp_y * one_minus_fp_x * a + fp_x * b) + + (fp_y * one_minus_fp_x * c + fp_x * d) + + kFixedPointHalf8) >> 16; + + *pixel_ptr++ = value; + + src_x_fp += factor_x_fp; + } + src_y_fp += factor_y_fp; + } +} + +template <typename T> +void Image<T>::DownsampleSmoothed3x3(const Image<T>& original) { + for (int y = 0; y < height_; ++y) { + const int orig_y = Clip(2 * y, ZERO, original.height_less_one_); + const int min_y = Clip(orig_y - 1, ZERO, original.height_less_one_); + const int max_y = Clip(orig_y + 1, ZERO, original.height_less_one_); + + for (int x = 0; x < width_; ++x) { + const int orig_x = Clip(2 * x, ZERO, original.width_less_one_); + const int min_x = Clip(orig_x - 1, ZERO, original.width_less_one_); + const int max_x = Clip(orig_x + 1, ZERO, original.width_less_one_); + + // Center. + int32 pixel_sum = original[orig_y][orig_x] * 4; + + // Sides. + pixel_sum += (original[orig_y][max_x] + + original[orig_y][min_x] + + original[max_y][orig_x] + + original[min_y][orig_x]) * 2; + + // Diagonals. + pixel_sum += (original[min_y][max_x] + + original[min_y][min_x] + + original[max_y][max_x] + + original[max_y][min_x]); + + (*this)[y][x] = pixel_sum >> 4; // 16 + } + } +} + +template <typename T> +void Image<T>::DownsampleSmoothed5x5(const Image<T>& original) { + const int max_x = original.width_less_one_; + const int max_y = original.height_less_one_; + + // The JY Bouget paper on Lucas-Kanade recommends a + // [1/16 1/4 3/8 1/4 1/16]^2 filter. + // This works out to a [1 4 6 4 1]^2 / 256 array, precomputed below. + static const int window_radius = 2; + static const int window_size = window_radius*2 + 1; + static const int window_weights[] = {1, 4, 6, 4, 1, // 16 + + 4, 16, 24, 16, 4, // 64 + + 6, 24, 36, 24, 6, // 96 + + 4, 16, 24, 16, 4, // 64 + + 1, 4, 6, 4, 1}; // 16 = 256 + + // We'll multiply and sum with the the whole numbers first, then divide by + // the total weight to normalize at the last moment. + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + int32 pixel_sum = 0; + + const int* w = window_weights; + const int start_x = Clip((x << 1) - window_radius, ZERO, max_x); + + // Clip the boundaries to the size of the image. + for (int window_y = 0; window_y < window_size; ++window_y) { + const int start_y = + Clip((y << 1) - window_radius + window_y, ZERO, max_y); + + const T* p = original[start_y] + start_x; + + for (int window_x = 0; window_x < window_size; ++window_x) { + pixel_sum += *p++ * *w++; + } + } + + // Conversion to type T will happen here after shifting right 8 bits to + // divide by 256. + (*this)[y][x] = pixel_sum >> 8; + } + } +} + +template <typename T> +template <typename U> +inline T Image<T>::ScharrPixelX(const Image<U>& original, + const int center_x, const int center_y) const { + const int min_x = Clip(center_x - 1, ZERO, original.width_less_one_); + const int max_x = Clip(center_x + 1, ZERO, original.width_less_one_); + const int min_y = Clip(center_y - 1, ZERO, original.height_less_one_); + const int max_y = Clip(center_y + 1, ZERO, original.height_less_one_); + + // Convolution loop unrolled for performance... + return (3 * (original[min_y][max_x] + + original[max_y][max_x] + - original[min_y][min_x] + - original[max_y][min_x]) + + 10 * (original[center_y][max_x] + - original[center_y][min_x])) / 32; +} + +template <typename T> +template <typename U> +inline T Image<T>::ScharrPixelY(const Image<U>& original, + const int center_x, const int center_y) const { + const int min_x = Clip(center_x - 1, 0, original.width_less_one_); + const int max_x = Clip(center_x + 1, 0, original.width_less_one_); + const int min_y = Clip(center_y - 1, 0, original.height_less_one_); + const int max_y = Clip(center_y + 1, 0, original.height_less_one_); + + // Convolution loop unrolled for performance... + return (3 * (original[max_y][min_x] + + original[max_y][max_x] + - original[min_y][min_x] + - original[min_y][max_x]) + + 10 * (original[max_y][center_x] + - original[min_y][center_x])) / 32; +} + +template <typename T> +template <typename U> +inline void Image<T>::ScharrX(const Image<U>& original) { + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + SetPixel(x, y, ScharrPixelX(original, x, y)); + } + } +} + +template <typename T> +template <typename U> +inline void Image<T>::ScharrY(const Image<U>& original) { + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + SetPixel(x, y, ScharrPixelY(original, x, y)); + } + } +} + +template <typename T> +template <typename U> +void Image<T>::DerivativeX(const Image<U>& original) { + for (int y = 0; y < height_; ++y) { + const U* const source_row = original[y]; + T* const dest_row = (*this)[y]; + + // Compute first pixel. Approximated with forward difference. + dest_row[0] = source_row[1] - source_row[0]; + + // All the pixels in between. Central difference method. + const U* source_prev_pixel = source_row; + T* dest_pixel = dest_row + 1; + const U* source_next_pixel = source_row + 2; + for (int x = 1; x < width_less_one_; ++x) { + *dest_pixel++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++); + } + + // Last pixel. Approximated with backward difference. + dest_row[width_less_one_] = + source_row[width_less_one_] - source_row[width_less_one_ - 1]; + } +} + +template <typename T> +template <typename U> +void Image<T>::DerivativeY(const Image<U>& original) { + const int src_stride = original.stride(); + + // Compute 1st row. Approximated with forward difference. + { + const U* const src_row = original[0]; + T* dest_row = (*this)[0]; + for (int x = 0; x < width_; ++x) { + dest_row[x] = src_row[x + src_stride] - src_row[x]; + } + } + + // Compute all rows in between using central difference. + for (int y = 1; y < height_less_one_; ++y) { + T* dest_row = (*this)[y]; + + const U* source_prev_pixel = original[y - 1]; + const U* source_next_pixel = original[y + 1]; + for (int x = 0; x < width_; ++x) { + *dest_row++ = HalfDiff(*source_prev_pixel++, *source_next_pixel++); + } + } + + // Compute last row. Approximated with backward difference. + { + const U* const src_row = original[height_less_one_]; + T* dest_row = (*this)[height_less_one_]; + for (int x = 0; x < width_; ++x) { + dest_row[x] = src_row[x] - src_row[x - src_stride]; + } + } +} + +template <typename T> +template <typename U> +inline T Image<T>::ConvolvePixel3x3(const Image<U>& original, + const int* const filter, + const int center_x, const int center_y, + const int total) const { + int32 sum = 0; + for (int filter_y = 0; filter_y < 3; ++filter_y) { + const int y = Clip(center_y - 1 + filter_y, 0, original.GetHeight()); + for (int filter_x = 0; filter_x < 3; ++filter_x) { + const int x = Clip(center_x - 1 + filter_x, 0, original.GetWidth()); + sum += original[y][x] * filter[filter_y * 3 + filter_x]; + } + } + return sum / total; +} + +template <typename T> +template <typename U> +inline void Image<T>::Convolve3x3(const Image<U>& original, + const int32* const filter) { + int32 sum = 0; + for (int i = 0; i < 9; ++i) { + sum += abs(filter[i]); + } + for (int y = 0; y < height_; ++y) { + for (int x = 0; x < width_; ++x) { + SetPixel(x, y, ConvolvePixel3x3(original, filter, x, y, sum)); + } + } +} + +template <typename T> +inline void Image<T>::FromArray(const T* const pixels, const int stride, + const int factor) { + if (factor == 1 && stride == width_) { + // If not subsampling, memcpy per line should be faster. + memcpy(this->image_data_, pixels, data_size_ * sizeof(T)); + return; + } + + DownsampleAveraged(pixels, stride, factor); +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image.h b/tensorflow/examples/android/jni/object_tracking/image.h new file mode 100644 index 0000000000..29b0adbda8 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image.h @@ -0,0 +1,346 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +using namespace tensorflow; + +// TODO(andrewharp): Make this a cast to uint32 if/when we go unsigned for +// operations. +#define ZERO 0 + +#ifdef SANITY_CHECKS + #define CHECK_PIXEL(IMAGE, X, Y) {\ + SCHECK((IMAGE)->ValidPixel((X), (Y)), \ + "CHECK_PIXEL(%d,%d) in %dx%d image.", \ + static_cast<int>(X), static_cast<int>(Y), \ + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\ + } + + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {\ + SCHECK((IMAGE)->validInterpPixel((X), (Y)), \ + "CHECK_PIXEL_INTERP(%.2f, %.2f) in %dx%d image.", \ + static_cast<float>(X), static_cast<float>(Y), \ + (IMAGE)->GetWidth(), (IMAGE)->GetHeight());\ + } +#else + #define CHECK_PIXEL(image, x, y) {} + #define CHECK_PIXEL_INTERP(IMAGE, X, Y) {} +#endif + +namespace tf_tracking { + +#ifdef SANITY_CHECKS +// Class which exists solely to provide bounds checking for array-style image +// data access. +template <typename T> +class RowData { + public: + RowData(T* const row_data, const int max_col) + : row_data_(row_data), max_col_(max_col) {} + + inline T& operator[](const int col) const { + SCHECK(InRange(col, 0, max_col_), + "Column out of range: %d (%d max)", col, max_col_); + return row_data_[col]; + } + + inline operator T*() const { + return row_data_; + } + + private: + T* const row_data_; + const int max_col_; +}; +#endif + +// Naive templated sorting function. +template <typename T> +int Comp(const void* a, const void* b) { + const T val1 = *reinterpret_cast<const T*>(a); + const T val2 = *reinterpret_cast<const T*>(b); + + if (val1 == val2) { + return 0; + } else if (val1 < val2) { + return -1; + } else { + return 1; + } +} + +// TODO(andrewharp): Make explicit which operations support negative numbers or +// struct/class types in image data (possibly create fast multi-dim array class +// for data where pixel arithmetic does not make sense). + +// Image class optimized for working on numeric arrays as grayscale image data. +// Supports other data types as a 2D array class, so long as no pixel math +// operations are called (convolution, downsampling, etc). +template <typename T> +class Image { + public: + Image(const int width, const int height); + explicit Image(const Size& size); + + // Constructor that creates an image from preallocated data. + // Note: The image takes ownership of the data lifecycle, unless own_data is + // set to false. + Image(const int width, const int height, T* const image_data, + const bool own_data = true); + + ~Image(); + + // Extract a pixel patch from this image, starting at a subpixel location. + // Uses 16:16 fixed point format for representing real values and doing the + // bilinear interpolation. + // + // Arguments fp_x and fp_y tell the subpixel position in fixed point format, + // patchwidth/patchheight give the size of the patch in pixels and + // to_data must be a valid pointer to a *contiguous* destination data array. + template<class DstType> + bool ExtractPatchAtSubpixelFixed1616(const int fp_x, + const int fp_y, + const int patchwidth, + const int patchheight, + DstType* to_data) const; + + Image<T>* Crop( + const int left, const int top, const int right, const int bottom) const; + + inline int GetWidth() const { return width_; } + inline int GetHeight() const { return height_; } + + // Bilinearly sample a value between pixels. Values must be within the image. + inline float GetPixelInterp(const float x, const float y) const; + + // Bilinearly sample a pixels at a subpixel position using fixed point + // arithmetic. + // Avoids float<->int conversions. + // Values must be within the image. + // Arguments fp_x and fp_y tell the subpixel position in + // 16:16 fixed point format. + // + // Important: This function only makes sense for integer-valued images, such + // as Image<uint8> or Image<int> etc. + inline T GetPixelInterpFixed1616(const int fp_x_whole, + const int fp_y_whole) const; + + // Returns true iff the pixel is in the image's boundaries. + inline bool ValidPixel(const int x, const int y) const; + + inline BoundingBox GetContainingBox() const; + + inline bool Contains(const BoundingBox& bounding_box) const; + + inline T GetMedianValue() { + qsort(image_data_, data_size_, sizeof(image_data_[0]), Comp<T>); + return image_data_[data_size_ >> 1]; + } + + // Returns true iff the pixel is in the image's boundaries for interpolation + // purposes. + // TODO(andrewharp): check in interpolation follow-up change. + inline bool ValidInterpPixel(const float x, const float y) const; + + // Safe lookup with boundary enforcement. + inline T GetPixelClipped(const int x, const int y) const { + return (*this)[Clip(y, ZERO, height_less_one_)] + [Clip(x, ZERO, width_less_one_)]; + } + +#ifdef SANITY_CHECKS + inline RowData<T> operator[](const int row) { + SCHECK(InRange(row, 0, height_less_one_), + "Row out of range: %d (%d max)", row, height_less_one_); + return RowData<T>(image_data_ + row * stride_, width_less_one_); + } + + inline const RowData<T> operator[](const int row) const { + SCHECK(InRange(row, 0, height_less_one_), + "Row out of range: %d (%d max)", row, height_less_one_); + return RowData<T>(image_data_ + row * stride_, width_less_one_); + } +#else + inline T* operator[](const int row) { + return image_data_ + row * stride_; + } + + inline const T* operator[](const int row) const { + return image_data_ + row * stride_; + } +#endif + + const T* data() const { return image_data_; } + + inline int stride() const { return stride_; } + + // Clears image to a single value. + inline void Clear(const T& val) { + memset(image_data_, val, sizeof(*image_data_) * data_size_); + } + +#ifdef __ARM_NEON + void Downsample2x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x); + + void Downsample4x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x); + + void DownsampleAveragedNeon(const uint8* const original, const int stride, + const int factor); +#endif + + // Naive downsampler that reduces image size by factor by averaging pixels in + // blocks of size factor x factor. + void DownsampleAveraged(const T* const original, const int stride, + const int factor); + + // Naive downsampler that reduces image size by factor by averaging pixels in + // blocks of size factor x factor. + inline void DownsampleAveraged(const Image<T>& original, const int factor) { + DownsampleAveraged(original.data(), original.GetWidth(), factor); + } + + // Native downsampler that reduces image size using nearest interpolation + void DownsampleInterpolateNearest(const Image<T>& original); + + // Native downsampler that reduces image size using fixed-point bilinear + // interpolation + void DownsampleInterpolateLinear(const Image<T>& original); + + // Relatively efficient downsampling of an image by a factor of two with a + // low-pass 3x3 smoothing operation thrown in. + void DownsampleSmoothed3x3(const Image<T>& original); + + // Relatively efficient downsampling of an image by a factor of two with a + // low-pass 5x5 smoothing operation thrown in. + void DownsampleSmoothed5x5(const Image<T>& original); + + // Optimized Scharr filter on a single pixel in the X direction. + // Scharr filters are like central-difference operators, but have more + // rotational symmetry in their response because they also consider the + // diagonal neighbors. + template <typename U> + inline T ScharrPixelX(const Image<U>& original, + const int center_x, const int center_y) const; + + // Optimized Scharr filter on a single pixel in the X direction. + // Scharr filters are like central-difference operators, but have more + // rotational symmetry in their response because they also consider the + // diagonal neighbors. + template <typename U> + inline T ScharrPixelY(const Image<U>& original, + const int center_x, const int center_y) const; + + // Convolve the image with a Scharr filter in the X direction. + // Much faster than an equivalent generic convolution. + template <typename U> + inline void ScharrX(const Image<U>& original); + + // Convolve the image with a Scharr filter in the Y direction. + // Much faster than an equivalent generic convolution. + template <typename U> + inline void ScharrY(const Image<U>& original); + + static inline T HalfDiff(int32 first, int32 second) { + return (second - first) / 2; + } + + template <typename U> + void DerivativeX(const Image<U>& original); + + template <typename U> + void DerivativeY(const Image<U>& original); + + // Generic function for convolving pixel with 3x3 filter. + // Filter pixels should be in row major order. + template <typename U> + inline T ConvolvePixel3x3(const Image<U>& original, + const int* const filter, + const int center_x, const int center_y, + const int total) const; + + // Generic function for convolving an image with a 3x3 filter. + // TODO(andrewharp): Generalize this for any size filter. + template <typename U> + inline void Convolve3x3(const Image<U>& original, + const int32* const filter); + + // Load this image's data from a data array. The data at pixels is assumed to + // have dimensions equivalent to this image's dimensions * factor. + inline void FromArray(const T* const pixels, const int stride, + const int factor = 1); + + // Copy the image back out to an appropriately sized data array. + inline void ToArray(T* const pixels) const { + // If not subsampling, memcpy should be faster. + memcpy(pixels, this->image_data_, data_size_ * sizeof(T)); + } + + // Precompute these for efficiency's sake as they're used by a lot of + // clipping code and loop code. + // TODO(andrewharp): make these only accessible by other Images. + const int width_less_one_; + const int height_less_one_; + + // The raw size of the allocated data. + const int data_size_; + + private: + inline void Allocate() { + image_data_ = new T[data_size_]; + if (image_data_ == NULL) { + LOGE("Couldn't allocate image data!"); + } + } + + T* image_data_; + + bool own_data_; + + const int width_; + const int height_; + + // The image stride (offset to next row). + // TODO(andrewharp): Make sure that stride is honored in all code. + const int stride_; + + TF_DISALLOW_COPY_AND_ASSIGN(Image); +}; + +template <typename t> +inline std::ostream& operator<<(std::ostream& stream, const Image<t>& image) { + for (int y = 0; y < image.GetHeight(); ++y) { + for (int x = 0; x < image.GetWidth(); ++x) { + stream << image[y][x] << " "; + } + stream << std::endl; + } + return stream; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/examples/android/jni/object_tracking/image_data.h new file mode 100644 index 0000000000..16b1864ee6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image_data.h @@ -0,0 +1,270 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ + +#include <memory> + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" + +using namespace tensorflow; + +namespace tf_tracking { + +// Class that encapsulates all bulky processed data for a frame. +class ImageData { + public: + explicit ImageData(const int width, const int height) + : uv_frame_width_(width << 1), + uv_frame_height_(height << 1), + timestamp_(0), + image_(width, height) { + InitPyramid(width, height); + ResetComputationCache(); + } + + private: + void ResetComputationCache() { + uv_data_computed_ = false; + integral_image_computed_ = false; + for (int i = 0; i < kNumPyramidLevels; ++i) { + spatial_x_computed_[i] = false; + spatial_y_computed_[i] = false; + pyramid_sqrt2_computed_[i * 2] = false; + pyramid_sqrt2_computed_[i * 2 + 1] = false; + } + } + + void InitPyramid(const int width, const int height) { + int level_width = width; + int level_height = height; + + for (int i = 0; i < kNumPyramidLevels; ++i) { + pyramid_sqrt2_[i * 2] = NULL; + pyramid_sqrt2_[i * 2 + 1] = NULL; + spatial_x_[i] = NULL; + spatial_y_[i] = NULL; + + level_width /= 2; + level_height /= 2; + } + + // Alias the first pyramid level to image_. + pyramid_sqrt2_[0] = &image_; + } + + public: + ~ImageData() { + // The first pyramid level is actually an alias to image_, + // so make sure it doesn't get deleted here. + pyramid_sqrt2_[0] = NULL; + + for (int i = 0; i < kNumPyramidLevels; ++i) { + SAFE_DELETE(pyramid_sqrt2_[i * 2]); + SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]); + SAFE_DELETE(spatial_x_[i]); + SAFE_DELETE(spatial_y_[i]); + } + } + + void SetData(const uint8* const new_frame, const int stride, + const int64 timestamp, const int downsample_factor) { + SetData(new_frame, NULL, stride, timestamp, downsample_factor); + } + + void SetData(const uint8* const new_frame, + const uint8* const uv_frame, + const int stride, + const int64 timestamp, const int downsample_factor) { + ResetComputationCache(); + + timestamp_ = timestamp; + + TimeLog("SetData!"); + + pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor); + pyramid_sqrt2_computed_[0] = true; + TimeLog("Downsampled image"); + + if (uv_frame != NULL) { + if (u_data_.get() == NULL) { + u_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_)); + v_data_.reset(new Image<uint8>(uv_frame_width_, uv_frame_height_)); + } + + GetUV(uv_frame, u_data_.get(), v_data_.get()); + uv_data_computed_ = true; + TimeLog("Copied UV data"); + } else { + LOGV("No uv data!"); + } + +#ifdef LOG_TIME + // If profiling is enabled, precompute here to make it easier to distinguish + // total costs. + Precompute(); +#endif + } + + inline const uint64 GetTimestamp() const { + return timestamp_; + } + + inline const Image<uint8>* GetImage() const { + SCHECK(pyramid_sqrt2_computed_[0], "image not set!"); + return pyramid_sqrt2_[0]; + } + + const Image<uint8>* GetPyramidSqrt2Level(const int level) const { + if (!pyramid_sqrt2_computed_[level]) { + SCHECK(level != 0, "Level equals 0!"); + if (level == 1) { + const Image<uint8>& upper_level = *GetPyramidSqrt2Level(0); + if (pyramid_sqrt2_[level] == NULL) { + const int new_width = + (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2; + const int new_height = + (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 * + 2; + + pyramid_sqrt2_[level] = new Image<uint8>(new_width, new_height); + } + pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level); + } else { + const Image<uint8>& upper_level = *GetPyramidSqrt2Level(level - 2); + if (pyramid_sqrt2_[level] == NULL) { + pyramid_sqrt2_[level] = new Image<uint8>( + upper_level.GetWidth() / 2, upper_level.GetHeight() / 2); + } + pyramid_sqrt2_[level]->DownsampleAveraged( + upper_level.data(), upper_level.stride(), 2); + } + pyramid_sqrt2_computed_[level] = true; + } + return pyramid_sqrt2_[level]; + } + + inline const Image<int32>* GetSpatialX(const int level) const { + if (!spatial_x_computed_[level]) { + const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2); + if (spatial_x_[level] == NULL) { + spatial_x_[level] = new Image<int32>(src.GetWidth(), src.GetHeight()); + } + spatial_x_[level]->DerivativeX(src); + spatial_x_computed_[level] = true; + } + return spatial_x_[level]; + } + + inline const Image<int32>* GetSpatialY(const int level) const { + if (!spatial_y_computed_[level]) { + const Image<uint8>& src = *GetPyramidSqrt2Level(level * 2); + if (spatial_y_[level] == NULL) { + spatial_y_[level] = new Image<int32>(src.GetWidth(), src.GetHeight()); + } + spatial_y_[level]->DerivativeY(src); + spatial_y_computed_[level] = true; + } + return spatial_y_[level]; + } + + // The integral image is currently only used for object detection, so lazily + // initialize it on request. + inline const IntegralImage* GetIntegralImage() const { + if (integral_image_.get() == NULL) { + integral_image_.reset(new IntegralImage(image_)); + } else if (!integral_image_computed_) { + integral_image_->Recompute(image_); + } + integral_image_computed_ = true; + return integral_image_.get(); + } + + inline const Image<uint8>* GetU() const { + SCHECK(uv_data_computed_, "UV data not provided!"); + return u_data_.get(); + } + + inline const Image<uint8>* GetV() const { + SCHECK(uv_data_computed_, "UV data not provided!"); + return v_data_.get(); + } + + private: + void Precompute() { + // Create the smoothed pyramids. + for (int i = 0; i < kNumPyramidLevels * 2; i += 2) { + (void) GetPyramidSqrt2Level(i); + } + TimeLog("Created smoothed pyramids"); + + // Create the smoothed pyramids. + for (int i = 1; i < kNumPyramidLevels * 2; i += 2) { + (void) GetPyramidSqrt2Level(i); + } + TimeLog("Created smoothed sqrt pyramids"); + + // Create the spatial derivatives for frame 1. + for (int i = 0; i < kNumPyramidLevels; ++i) { + (void) GetSpatialX(i); + (void) GetSpatialY(i); + } + TimeLog("Created spatial derivatives"); + + (void) GetIntegralImage(); + TimeLog("Got integral image!"); + } + + const int uv_frame_width_; + const int uv_frame_height_; + + int64 timestamp_; + + Image<uint8> image_; + + bool uv_data_computed_; + std::unique_ptr<Image<uint8> > u_data_; + std::unique_ptr<Image<uint8> > v_data_; + + mutable bool spatial_x_computed_[kNumPyramidLevels]; + mutable Image<int32>* spatial_x_[kNumPyramidLevels]; + + mutable bool spatial_y_computed_[kNumPyramidLevels]; + mutable Image<int32>* spatial_y_[kNumPyramidLevels]; + + // Mutable so the lazy initialization can work when this class is const. + // Whether or not the integral image has been computed for the current image. + mutable bool integral_image_computed_; + mutable std::unique_ptr<IntegralImage> integral_image_; + + mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2]; + mutable Image<uint8>* pyramid_sqrt2_[kNumPyramidLevels * 2]; + + TF_DISALLOW_COPY_AND_ASSIGN(ImageData); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/image_neon.cc b/tensorflow/examples/android/jni/object_tracking/image_neon.cc new file mode 100644 index 0000000000..ddd8447bf3 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image_neon.cc @@ -0,0 +1,270 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NEON implementations of Image methods for compatible devices. Control +// should never enter this compilation unit on incompatible devices. + +#ifdef __ARM_NEON + +#include <arm_neon.h> + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +using namespace tensorflow; + +namespace tf_tracking { + +// This function does the bulk of the work. +template <> +void Image<uint8>::Downsample2x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x) { + // Divide input x offset by 2 to find output offset. + const int new_x = orig_x >> 1; + + // Initial offset into top row. + const uint8* offset = original + orig_x; + + // This points to the leftmost pixel of our 8 horizontally arranged + // pixels in the destination data. + uint8* ptr_dst = (*this)[0] + new_x; + + // Sum along vertical columns. + // Process 32x2 input pixels and 16x1 output pixels per iteration. + for (int new_y = 0; new_y < height_; ++new_y) { + uint16x8_t accum1 = vdupq_n_u16(0); + uint16x8_t accum2 = vdupq_n_u16(0); + + // Go top to bottom across the four rows of input pixels that make up + // this output row. + for (int row_num = 0; row_num < 2; ++row_num) { + // First 16 bytes. + { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data1 = vld1q_u8(offset); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum1 = vpadalq_u8(accum1, curr_data1); + } + + // Second 16 bytes. + { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data2 = vld1q_u8(offset + 16); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum2 = vpadalq_u8(accum2, curr_data2); + } + + // Move offset down one row. + offset += stride; + } + + // Divide by 4 (number of input pixels per output + // pixel) and narrow data from 16 bits per pixel to 8 bpp. + const uint8x8_t tmp_pix1 = vqshrn_n_u16(accum1, 2); + const uint8x8_t tmp_pix2 = vqshrn_n_u16(accum2, 2); + + // Concatenate 8x1 pixel strips into 16x1 pixel strip. + const uint8x16_t allpixels = vcombine_u8(tmp_pix1, tmp_pix2); + + // Copy all pixels from composite 16x1 vector into output strip. + vst1q_u8(ptr_dst, allpixels); + + ptr_dst += stride_; + } +} + +// This function does the bulk of the work. +template <> +void Image<uint8>::Downsample4x32ColumnsNeon(const uint8* const original, + const int stride, + const int orig_x) { + // Divide input x offset by 4 to find output offset. + const int new_x = orig_x >> 2; + + // Initial offset into top row. + const uint8* offset = original + orig_x; + + // This points to the leftmost pixel of our 8 horizontally arranged + // pixels in the destination data. + uint8* ptr_dst = (*this)[0] + new_x; + + // Sum along vertical columns. + // Process 32x4 input pixels and 8x1 output pixels per iteration. + for (int new_y = 0; new_y < height_; ++new_y) { + uint16x8_t accum1 = vdupq_n_u16(0); + uint16x8_t accum2 = vdupq_n_u16(0); + + // Go top to bottom across the four rows of input pixels that make up + // this output row. + for (int row_num = 0; row_num < 4; ++row_num) { + // First 16 bytes. + { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data1 = vld1q_u8(offset); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum1 = vpadalq_u8(accum1, curr_data1); + } + + // Second 16 bytes. + { + // Load 16 bytes of data from current offset. + const uint8x16_t curr_data2 = vld1q_u8(offset + 16); + + // Pairwise add and accumulate into accum vectors (16 bit to account + // for values above 255). + accum2 = vpadalq_u8(accum2, curr_data2); + } + + // Move offset down one row. + offset += stride; + } + + // Add and widen, then divide by 16 (number of input pixels per output + // pixel) and narrow data from 32 bits per pixel to 16 bpp. + const uint16x4_t tmp_pix1 = vqshrn_n_u32(vpaddlq_u16(accum1), 4); + const uint16x4_t tmp_pix2 = vqshrn_n_u32(vpaddlq_u16(accum2), 4); + + // Combine 4x1 pixel strips into 8x1 pixel strip and narrow from + // 16 bits to 8 bits per pixel. + const uint8x8_t allpixels = vmovn_u16(vcombine_u16(tmp_pix1, tmp_pix2)); + + // Copy all pixels from composite 8x1 vector into output strip. + vst1_u8(ptr_dst, allpixels); + + ptr_dst += stride_; + } +} + + +// Hardware accelerated downsampling method for supported devices. +// Requires that image size be a multiple of 16 pixels in each dimension, +// and that downsampling be by a factor of 2 or 4. +template <> +void Image<uint8>::DownsampleAveragedNeon(const uint8* const original, + const int stride, const int factor) { + // TODO(andrewharp): stride is a bad approximation for the src image's width. + // Better to pass that in directly. + SCHECK(width_ * factor <= stride, "Uh oh!"); + const int last_starting_index = width_ * factor - 32; + + // We process 32 input pixels lengthwise at a time. + // The output per pass of this loop is an 8 wide by downsampled height tall + // pixel strip. + int orig_x = 0; + for (; orig_x <= last_starting_index; orig_x += 32) { + if (factor == 2) { + Downsample2x32ColumnsNeon(original, stride, orig_x); + } else { + Downsample4x32ColumnsNeon(original, stride, orig_x); + } + } + + // If a last pass is required, push it to the left enough so that it never + // goes out of bounds. This will result in some extra computation on devices + // whose frame widths are multiples of 16 and not 32. + if (orig_x < last_starting_index + 32) { + if (factor == 2) { + Downsample2x32ColumnsNeon(original, stride, last_starting_index); + } else { + Downsample4x32ColumnsNeon(original, stride, last_starting_index); + } + } +} + + +// Puts the image gradient matrix about a pixel into the 2x2 float array G. +// vals_x should be an array of the window x gradient values, whose indices +// can be in any order but are parallel to the vals_y entries. +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details. +void CalculateGNeon(const float* const vals_x, const float* const vals_y, + const int num_vals, float* const G) { + const float32_t* const arm_vals_x = (const float32_t*) vals_x; + const float32_t* const arm_vals_y = (const float32_t*) vals_y; + + // Running sums. + float32x4_t xx = vdupq_n_f32(0.0f); + float32x4_t xy = vdupq_n_f32(0.0f); + float32x4_t yy = vdupq_n_f32(0.0f); + + // Maximum index we can load 4 consecutive values from. + // e.g. if there are 81 values, our last full pass can be from index 77: + // 81-4=>77 (77, 78, 79, 80) + const int max_i = num_vals - 4; + + // Defined here because we want to keep track of how many values were + // processed by NEON, so that we can finish off the remainder the normal + // way. + int i = 0; + + // Process values 4 at a time, accumulating the sums of + // the pixel-wise x*x, x*y, and y*y values. + for (; i <= max_i; i += 4) { + // Load xs + float32x4_t x = vld1q_f32(arm_vals_x + i); + + // Multiply x*x and accumulate. + xx = vmlaq_f32(xx, x, x); + + // Load ys + float32x4_t y = vld1q_f32(arm_vals_y + i); + + // Multiply x*y and accumulate. + xy = vmlaq_f32(xy, x, y); + + // Multiply y*y and accumulate. + yy = vmlaq_f32(yy, y, y); + } + + static float32_t xx_vals[4]; + static float32_t xy_vals[4]; + static float32_t yy_vals[4]; + + vst1q_f32(xx_vals, xx); + vst1q_f32(xy_vals, xy); + vst1q_f32(yy_vals, yy); + + // Accumulated values are store in sets of 4, we have to manually add + // the last bits together. + for (int j = 0; j < 4; ++j) { + G[0] += xx_vals[j]; + G[1] += xy_vals[j]; + G[3] += yy_vals[j]; + } + + // Finishes off last few values (< 4) from above. + for (; i < num_vals; ++i) { + G[0] += Square(vals_x[i]); + G[1] += vals_x[i] * vals_y[i]; + G[3] += Square(vals_y[i]); + } + + // The matrix is symmetric, so this is a given. + G[2] = G[1]; +} + +} // namespace tf_tracking + +#endif diff --git a/tensorflow/examples/android/jni/object_tracking/image_utils.h b/tensorflow/examples/android/jni/object_tracking/image_utils.h new file mode 100644 index 0000000000..5357a9352f --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/image_utils.h @@ -0,0 +1,301 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +using namespace tensorflow; + +namespace tf_tracking { + +inline void GetUV( + const uint8* const input, Image<uint8>* const u, Image<uint8>* const v) { + const uint8* pUV = input; + + for (int row = 0; row < u->GetHeight(); ++row) { + uint8* u_curr = (*u)[row]; + uint8* v_curr = (*v)[row]; + for (int col = 0; col < u->GetWidth(); ++col) { +#ifdef __APPLE__ + *u_curr++ = *pUV++; + *v_curr++ = *pUV++; +#else + *v_curr++ = *pUV++; + *u_curr++ = *pUV++; +#endif + } + } +} + +// Marks every point within a circle of a given radius on the given boolean +// image true. +template <typename U> +inline static void MarkImage(const int x, const int y, const int radius, + Image<U>* const img) { + SCHECK(img->ValidPixel(x, y), "Marking invalid pixel in image! %d, %d", x, y); + + // Precomputed for efficiency. + const int squared_radius = Square(radius); + + // Mark every row in the circle. + for (int d_y = 0; d_y <= radius; ++d_y) { + const int squared_y_dist = Square(d_y); + + const int min_y = MAX(y - d_y, 0); + const int max_y = MIN(y + d_y, img->height_less_one_); + + // The max d_x of the circle must be strictly greater or equal to + // radius - d_y for any positive d_y. Thus, starting from radius - d_y will + // reduce the number of iterations required as compared to starting from + // either 0 and counting up or radius and counting down. + for (int d_x = radius - d_y; d_x <= radius; ++d_x) { + // The first time this critera is met, we know the width of the circle at + // this row (without using sqrt). + if (squared_y_dist + Square(d_x) >= squared_radius) { + const int min_x = MAX(x - d_x, 0); + const int max_x = MIN(x + d_x, img->width_less_one_); + + // Mark both above and below the center row. + bool* const top_row_start = (*img)[min_y] + min_x; + bool* const bottom_row_start = (*img)[max_y] + min_x; + + const int x_width = max_x - min_x + 1; + memset(top_row_start, true, sizeof(*top_row_start) * x_width); + memset(bottom_row_start, true, sizeof(*bottom_row_start) * x_width); + + // This row is marked, time to move on to the next row. + break; + } + } + } +} + +#ifdef __ARM_NEON +void CalculateGNeon( + const float* const vals_x, const float* const vals_y, + const int num_vals, float* const G); +#endif + +// Puts the image gradient matrix about a pixel into the 2x2 float array G. +// vals_x should be an array of the window x gradient values, whose indices +// can be in any order but are parallel to the vals_y entries. +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for more details. +inline void CalculateG(const float* const vals_x, const float* const vals_y, + const int num_vals, float* const G) { +#ifdef __ARM_NEON + CalculateGNeon(vals_x, vals_y, num_vals, G); + return; +#endif + + // Non-accelerated version. + for (int i = 0; i < num_vals; ++i) { + G[0] += Square(vals_x[i]); + G[1] += vals_x[i] * vals_y[i]; + G[3] += Square(vals_y[i]); + } + + // The matrix is symmetric, so this is a given. + G[2] = G[1]; +} + + +inline void CalculateGInt16(const int16* const vals_x, + const int16* const vals_y, + const int num_vals, int* const G) { + // Non-accelerated version. + for (int i = 0; i < num_vals; ++i) { + G[0] += Square(vals_x[i]); + G[1] += vals_x[i] * vals_y[i]; + G[3] += Square(vals_y[i]); + } + + // The matrix is symmetric, so this is a given. + G[2] = G[1]; +} + + +// Puts the image gradient matrix about a pixel into the 2x2 float array G. +// Looks up interpolated pixels, then calls above method for implementation. +inline void CalculateG(const int window_radius, + const float center_x, const float center_y, + const Image<int32>& I_x, const Image<int32>& I_y, + float* const G) { + SCHECK(I_x.ValidPixel(center_x, center_y), "Problem in calculateG!"); + + // Hardcoded to allow for a max window radius of 5 (9 pixels x 9 pixels). + static const int kMaxWindowRadius = 5; + SCHECK(window_radius <= kMaxWindowRadius, + "Window %d > %d!", window_radius, kMaxWindowRadius); + + // Diameter of window is 2 * radius + 1 for center pixel. + static const int kWindowBufferSize = + (kMaxWindowRadius * 2 + 1) * (kMaxWindowRadius * 2 + 1); + + // Preallocate buffers statically for efficiency. + static int16 vals_x[kWindowBufferSize]; + static int16 vals_y[kWindowBufferSize]; + + const int src_left_fixed = RealToFixed1616(center_x - window_radius); + const int src_top_fixed = RealToFixed1616(center_y - window_radius); + + int16* vals_x_ptr = vals_x; + int16* vals_y_ptr = vals_y; + + const int window_size = 2 * window_radius + 1; + for (int y = 0; y < window_size; ++y) { + const int fp_y = src_top_fixed + (y << 16); + + for (int x = 0; x < window_size; ++x) { + const int fp_x = src_left_fixed + (x << 16); + + *vals_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y); + *vals_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y); + } + } + + int32 g_temp[] = {0, 0, 0, 0}; + CalculateGInt16(vals_x, vals_y, window_size * window_size, g_temp); + + for (int i = 0; i < 4; ++i) { + G[i] = g_temp[i]; + } +} + +inline float ImageCrossCorrelation(const Image<float>& image1, + const Image<float>& image2, + const int x_offset, const int y_offset) { + SCHECK(image1.GetWidth() == image2.GetWidth() && + image1.GetHeight() == image2.GetHeight(), + "Dimension mismatch! %dx%d vs %dx%d", + image1.GetWidth(), image1.GetHeight(), + image2.GetWidth(), image2.GetHeight()); + + const int num_pixels = image1.GetWidth() * image1.GetHeight(); + const float* data1 = image1.data(); + const float* data2 = image2.data(); + return ComputeCrossCorrelation(data1, data2, num_pixels); +} + +// Copies an arbitrary region of an image to another (floating point) +// image, scaling as it goes using bilinear interpolation. +inline void CopyArea(const Image<uint8>& image, + const BoundingBox& area_to_copy, + Image<float>* const patch_image) { + VLOG(2) << "Copying from: " << area_to_copy << std::endl; + + const int patch_width = patch_image->GetWidth(); + const int patch_height = patch_image->GetHeight(); + + const float x_dist_between_samples = patch_width > 0 ? + area_to_copy.GetWidth() / (patch_width - 1) : 0; + + const float y_dist_between_samples = patch_height > 0 ? + area_to_copy.GetHeight() / (patch_height - 1) : 0; + + for (int y_index = 0; y_index < patch_height; ++y_index) { + const float sample_y = + y_index * y_dist_between_samples + area_to_copy.top_; + + for (int x_index = 0; x_index < patch_width; ++x_index) { + const float sample_x = + x_index * x_dist_between_samples + area_to_copy.left_; + + if (image.ValidInterpPixel(sample_x, sample_y)) { + // TODO(andrewharp): Do area averaging when downsampling. + (*patch_image)[y_index][x_index] = + image.GetPixelInterp(sample_x, sample_y); + } else { + (*patch_image)[y_index][x_index] = -1.0f; + } + } + } +} + + +// Takes a floating point image and normalizes it in-place. +// +// First, negative values will be set to the mean of the non-negative pixels +// in the image. +// +// Then, the resulting will be normalized such that it has mean value of 0.0 and +// a standard deviation of 1.0. +inline void NormalizeImage(Image<float>* const image) { + const float* const data_ptr = image->data(); + + // Copy only the non-negative values to some temp memory. + float running_sum = 0.0f; + int num_data_gte_zero = 0; + { + float* const curr_data = (*image)[0]; + for (int i = 0; i < image->data_size_; ++i) { + if (curr_data[i] >= 0.0f) { + running_sum += curr_data[i]; + ++num_data_gte_zero; + } else { + curr_data[i] = -1.0f; + } + } + } + + // If none of the pixels are valid, just set the entire thing to 0.0f. + if (num_data_gte_zero == 0) { + image->Clear(0.0f); + return; + } + + const float corrected_mean = running_sum / num_data_gte_zero; + + float* curr_data = (*image)[0]; + for (int i = 0; i < image->data_size_; ++i) { + const float curr_val = *curr_data; + *curr_data++ = curr_val < 0 ? 0 : curr_val - corrected_mean; + } + + const float std_dev = ComputeStdDev(data_ptr, image->data_size_, 0.0f); + + if (std_dev > 0.0f) { + curr_data = (*image)[0]; + for (int i = 0; i < image->data_size_; ++i) { + *curr_data++ /= std_dev; + } + +#ifdef SANITY_CHECKS + LOGV("corrected_mean: %1.2f std_dev: %1.2f", corrected_mean, std_dev); + const float correlation = + ComputeCrossCorrelation(image->data(), + image->data(), + image->data_size_); + + if (std::abs(correlation - 1.0f) > EPSILON) { + LOG(ERROR) << "Bad image!" << std::endl; + LOG(ERROR) << *image << std::endl; + } + + SCHECK(std::abs(correlation - 1.0f) < EPSILON, + "Correlation wasn't 1.0f: %.10f", correlation); +#endif + } +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/integral_image.h b/tensorflow/examples/android/jni/object_tracking/integral_image.h new file mode 100755 index 0000000000..28b2045572 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/integral_image.h @@ -0,0 +1,187 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +typedef uint8 Code; + +class IntegralImage: public Image<uint32> { + public: + explicit IntegralImage(const Image<uint8>& image_base) : + Image<uint32>(image_base.GetWidth(), image_base.GetHeight()) { + Recompute(image_base); + } + + IntegralImage(const int width, const int height) : + Image<uint32>(width, height) {} + + void Recompute(const Image<uint8>& image_base) { + SCHECK(image_base.GetWidth() == GetWidth() && + image_base.GetHeight() == GetHeight(), "Dimensions don't match!"); + + // Sum along first row. + { + int x_sum = 0; + for (int x = 0; x < image_base.GetWidth(); ++x) { + x_sum += image_base[0][x]; + (*this)[0][x] = x_sum; + } + } + + // Sum everything else. + for (int y = 1; y < image_base.GetHeight(); ++y) { + uint32* curr_sum = (*this)[y]; + + // Previously summed pointers. + const uint32* up_one = (*this)[y - 1]; + + // Current value pointer. + const uint8* curr_delta = image_base[y]; + + uint32 row_till_now = 0; + + for (int x = 0; x < GetWidth(); ++x) { + // Add the one above and the one to the left. + row_till_now += *curr_delta; + *curr_sum = *up_one + row_till_now; + + // Scoot everything along. + ++curr_sum; + ++up_one; + ++curr_delta; + } + } + + SCHECK(VerifyData(image_base), "Images did not match!"); + } + + bool VerifyData(const Image<uint8>& image_base) { + for (int y = 0; y < GetHeight(); ++y) { + for (int x = 0; x < GetWidth(); ++x) { + uint32 curr_val = (*this)[y][x]; + + if (x > 0) { + curr_val -= (*this)[y][x - 1]; + } + + if (y > 0) { + curr_val -= (*this)[y - 1][x]; + } + + if (x > 0 && y > 0) { + curr_val += (*this)[y - 1][x - 1]; + } + + if (curr_val != image_base[y][x]) { + LOGE("Mismatch! %d vs %d", curr_val, image_base[y][x]); + return false; + } + + if (GetRegionSum(x, y, x, y) != curr_val) { + LOGE("Mismatch!"); + } + } + } + + return true; + } + + // Returns the sum of all pixels in the specified region. + inline uint32 GetRegionSum(const int x1, const int y1, + const int x2, const int y2) const { + SCHECK(x1 >= 0 && y1 >= 0 && + x2 >= x1 && y2 >= y1 && x2 < GetWidth() && y2 < GetHeight(), + "indices out of bounds! %d-%d / %d, %d-%d / %d, ", + x1, x2, GetWidth(), y1, y2, GetHeight()); + + const uint32 everything = (*this)[y2][x2]; + + uint32 sum = everything; + if (x1 > 0 && y1 > 0) { + // Most common case. + const uint32 left = (*this)[y2][x1 - 1]; + const uint32 top = (*this)[y1 - 1][x2]; + const uint32 top_left = (*this)[y1 - 1][x1 - 1]; + + sum = everything - left - top + top_left; + SCHECK(sum >= 0, "Both: %d - %d - %d + %d => %d! indices: %d %d %d %d", + everything, left, top, top_left, sum, x1, y1, x2, y2); + } else if (x1 > 0) { + // Flush against top of image. + // Subtract out the region to the left only. + const uint32 top = (*this)[y2][x1 - 1]; + sum = everything - top; + SCHECK(sum >= 0, "Top: %d - %d => %d!", everything, top, sum); + } else if (y1 > 0) { + // Flush against left side of image. + // Subtract out the region above only. + const uint32 left = (*this)[y1 - 1][x2]; + sum = everything - left; + SCHECK(sum >= 0, "Left: %d - %d => %d!", everything, left, sum); + } + + SCHECK(sum >= 0, "Negative sum!"); + + return sum; + } + + // Returns the 2bit code associated with this region, which represents + // the overall gradient. + inline Code GetCode(const BoundingBox& bounding_box) const { + return GetCode(bounding_box.left_, bounding_box.top_, + bounding_box.right_, bounding_box.bottom_); + } + + inline Code GetCode(const int x1, const int y1, + const int x2, const int y2) const { + SCHECK(x1 < x2 && y1 < y2, "Bounds out of order!! TL:%d,%d BR:%d,%d", + x1, y1, x2, y2); + + // Gradient computed vertically. + const int box_height = (y2 - y1) / 2; + const int top_sum = GetRegionSum(x1, y1, x2, y1 + box_height); + const int bottom_sum = GetRegionSum(x1, y2 - box_height, x2, y2); + const bool vertical_code = top_sum > bottom_sum; + + // Gradient computed horizontally. + const int box_width = (x2 - x1) / 2; + const int left_sum = GetRegionSum(x1, y1, x1 + box_width, y2); + const int right_sum = GetRegionSum(x2 - box_width, y1, x2, y2); + const bool horizontal_code = left_sum > right_sum; + + const Code final_code = (vertical_code << 1) | horizontal_code; + + SCHECK(InRange(final_code, static_cast<Code>(0), static_cast<Code>(3)), + "Invalid code! %d", final_code); + + // Returns a value 0-3. + return final_code; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(IntegralImage); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h new file mode 100644 index 0000000000..92458536b6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h @@ -0,0 +1,62 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ + +#include <android/log.h> + +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +// The JniIntField class is used to access Java fields from native code. This +// technique of hiding pointers to native objects in opaque Java fields is how +// the Android hardware libraries work. This reduces the amount of static +// native methods and makes it easier to manage the lifetime of native objects. +class JniIntField { + public: + JniIntField(const char* field_name) : field_name_(field_name), field_ID_(0) {} + + int get(JNIEnv* env, jobject thiz) { + if (field_ID_ == 0) { + jclass cls = env->GetObjectClass(thiz); + CHECK_ALWAYS(cls != 0, "Unable to find class"); + field_ID_ = env->GetFieldID(cls, field_name_, "I"); + CHECK_ALWAYS(field_ID_ != 0, + "Unable to find field %s. (Check proguard cfg)", field_name_); + } + + return env->GetIntField(thiz, field_ID_); + } + + void set(JNIEnv* env, jobject thiz, int value) { + if (field_ID_ == 0) { + jclass cls = env->GetObjectClass(thiz); + CHECK_ALWAYS(cls != 0, "Unable to find class"); + field_ID_ = env->GetFieldID(cls, field_name_, "I"); + CHECK_ALWAYS(field_ID_ != 0, + "Unable to find field %s (Check proguard cfg)", field_name_); + } + + env->SetIntField(thiz, field_ID_, value); + } + + private: + const char* const field_name_; + + // This is just a cache + jfieldID field_ID_; +}; + +#endif diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint.h b/tensorflow/examples/android/jni/object_tracking/keypoint.h new file mode 100644 index 0000000000..82917261cb --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/keypoint.h @@ -0,0 +1,48 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" + +namespace tf_tracking { + +// For keeping track of keypoints. +struct Keypoint { + Keypoint() : pos_(0.0f, 0.0f), score_(0.0f), type_(0) {} + Keypoint(const float x, const float y) + : pos_(x, y), score_(0.0f), type_(0) {} + + Point2f pos_; + float score_; + uint8 type_; +}; + +inline std::ostream& operator<<(std::ostream& stream, const Keypoint keypoint) { + return stream << "[" << keypoint.pos_ << ", " + << keypoint.score_ << ", " << keypoint.type_ << "]"; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc new file mode 100644 index 0000000000..6cc6b4e73f --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc @@ -0,0 +1,549 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Various keypoint detecting functions. + +#include <float.h> + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" + +namespace tf_tracking { + +static inline int GetDistSquaredBetween(const int* vec1, const int* vec2) { + return Square(vec1[0] - vec2[0]) + Square(vec1[1] - vec2[1]); +} + +void KeypointDetector::ScoreKeypoints(const ImageData& image_data, + const int num_candidates, + Keypoint* const candidate_keypoints) { + const Image<int>& I_x = *image_data.GetSpatialX(0); + const Image<int>& I_y = *image_data.GetSpatialY(0); + + if (config_->detect_skin) { + const Image<uint8>& u_data = *image_data.GetU(); + const Image<uint8>& v_data = *image_data.GetV(); + + static const int reference[] = {111, 155}; + + // Score all the keypoints. + for (int i = 0; i < num_candidates; ++i) { + Keypoint* const keypoint = candidate_keypoints + i; + + const int x_pos = keypoint->pos_.x * 2; + const int y_pos = keypoint->pos_.y * 2; + + const int curr_color[] = {u_data[y_pos][x_pos], v_data[y_pos][x_pos]}; + keypoint->score_ = + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y) / + GetDistSquaredBetween(reference, curr_color); + } + } else { + // Score all the keypoints. + for (int i = 0; i < num_candidates; ++i) { + Keypoint* const keypoint = candidate_keypoints + i; + keypoint->score_ = + HarrisFilter(I_x, I_y, keypoint->pos_.x, keypoint->pos_.y); + } + } +} + + +inline int KeypointCompare(const void* const a, const void* const b) { + return (reinterpret_cast<const Keypoint*>(a)->score_ - + reinterpret_cast<const Keypoint*>(b)->score_) <= 0 ? 1 : -1; +} + + +// Quicksorts detected keypoints by score. +void KeypointDetector::SortKeypoints(const int num_candidates, + Keypoint* const candidate_keypoints) const { + qsort(candidate_keypoints, num_candidates, sizeof(Keypoint), KeypointCompare); + +#ifdef SANITY_CHECKS + // Verify that the array got sorted. + float last_score = FLT_MAX; + for (int i = 0; i < num_candidates; ++i) { + const float curr_score = candidate_keypoints[i].score_; + + // Scores should be monotonically increasing. + SCHECK(last_score >= curr_score, + "Quicksort failure! %d: %.5f > %d: %.5f (%d total)", + i - 1, last_score, i, curr_score, num_candidates); + + last_score = curr_score; + } +#endif +} + + +int KeypointDetector::SelectKeypointsInBox( + const BoundingBox& box, + const Keypoint* const candidate_keypoints, + const int num_candidates, + const int max_keypoints, + const int num_existing_keypoints, + const Keypoint* const existing_keypoints, + Keypoint* const final_keypoints) const { + if (max_keypoints <= 0) { + return 0; + } + + // This is the distance within which keypoints may be placed to each other + // within this box, roughly based on the box dimensions. + const int distance = + MAX(1, MIN(box.GetWidth(), box.GetHeight()) * kClosestPercent / 2.0f); + + // First, mark keypoints that already happen to be inside this region. Ignore + // keypoints that are outside it, however close they might be. + interest_map_->Clear(false); + for (int i = 0; i < num_existing_keypoints; ++i) { + const Keypoint& candidate = existing_keypoints[i]; + + const int x_pos = candidate.pos_.x; + const int y_pos = candidate.pos_.y; + if (box.Contains(candidate.pos_)) { + MarkImage(x_pos, y_pos, distance, interest_map_.get()); + } + } + + // Now, go through and check which keypoints will still fit in the box. + int num_keypoints_selected = 0; + for (int i = 0; i < num_candidates; ++i) { + const Keypoint& candidate = candidate_keypoints[i]; + + const int x_pos = candidate.pos_.x; + const int y_pos = candidate.pos_.y; + + if (!box.Contains(candidate.pos_) || + !interest_map_->ValidPixel(x_pos, y_pos)) { + continue; + } + + if (!(*interest_map_)[y_pos][x_pos]) { + final_keypoints[num_keypoints_selected++] = candidate; + if (num_keypoints_selected >= max_keypoints) { + break; + } + MarkImage(x_pos, y_pos, distance, interest_map_.get()); + } + } + return num_keypoints_selected; +} + + +void KeypointDetector::SelectKeypoints( + const std::vector<BoundingBox>& boxes, + const Keypoint* const candidate_keypoints, + const int num_candidates, + FramePair* const curr_change) const { + // Now select all the interesting keypoints that fall insider our boxes. + curr_change->number_of_keypoints_ = 0; + for (std::vector<BoundingBox>::const_iterator iter = boxes.begin(); + iter != boxes.end(); ++iter) { + const BoundingBox bounding_box = *iter; + + // Count up keypoints that have already been selected, and fall within our + // box. + int num_keypoints_already_in_box = 0; + for (int i = 0; i < curr_change->number_of_keypoints_; ++i) { + if (bounding_box.Contains(curr_change->frame1_keypoints_[i].pos_)) { + ++num_keypoints_already_in_box; + } + } + + const int max_keypoints_to_find_in_box = + MIN(kMaxKeypointsForObject - num_keypoints_already_in_box, + kMaxKeypoints - curr_change->number_of_keypoints_); + + const int num_new_keypoints_in_box = SelectKeypointsInBox( + bounding_box, + candidate_keypoints, + num_candidates, + max_keypoints_to_find_in_box, + curr_change->number_of_keypoints_, + curr_change->frame1_keypoints_, + curr_change->frame1_keypoints_ + curr_change->number_of_keypoints_); + + curr_change->number_of_keypoints_ += num_new_keypoints_in_box; + + LOGV("Selected %d keypoints!", curr_change->number_of_keypoints_); + } +} + + +// Walks along the given circle checking for pixels above or below the center. +// Returns a score, or 0 if the keypoint did not pass the criteria. +// +// Parameters: +// circle_perimeter: the circumference in pixels of the circle. +// threshold: the minimum number of contiguous pixels that must be above or +// below the center value. +// center_ptr: the location of the center pixel in memory +// offsets: the relative offsets from the center pixel of the edge pixels. +inline int TestCircle(const int circle_perimeter, const int threshold, + const uint8* const center_ptr, + const int* offsets) { + // Get the actual value of the center pixel for easier reference later on. + const int center_value = static_cast<int>(*center_ptr); + + // Number of total pixels to check. Have to wrap around some in case + // the contiguous section is split by the array edges. + const int num_total = circle_perimeter + threshold - 1; + + int num_above = 0; + int above_diff = 0; + + int num_below = 0; + int below_diff = 0; + + // Used to tell when this is definitely not going to meet the threshold so we + // can early abort. + int minimum_by_now = threshold - num_total + 1; + + // Go through every pixel along the perimeter of the circle, and then around + // again a little bit. + for (int i = 0; i < num_total; ++i) { + // This should be faster than mod. + const int perim_index = i < circle_perimeter ? i : i - circle_perimeter; + + // This gets the value of the current pixel along the perimeter by using + // a precomputed offset. + const int curr_value = + static_cast<int>(center_ptr[offsets[perim_index]]); + + const int difference = curr_value - center_value; + + if (difference > kFastDiffAmount) { + above_diff += difference; + ++num_above; + + num_below = 0; + below_diff = 0; + + if (num_above >= threshold) { + return above_diff; + } + } else if (difference < -kFastDiffAmount) { + below_diff += difference; + ++num_below; + + num_above = 0; + above_diff = 0; + + if (num_below >= threshold) { + return below_diff; + } + } else { + num_above = 0; + num_below = 0; + above_diff = 0; + below_diff = 0; + } + + // See if there's any chance of making the threshold. + if (MAX(num_above, num_below) < minimum_by_now) { + // Didn't pass. + return 0; + } + ++minimum_by_now; + } + + // Didn't pass. + return 0; +} + + +// Returns a score in the range [0.0, positive infinity) which represents the +// relative likelihood of a point being a corner. +float KeypointDetector::HarrisFilter(const Image<int32>& I_x, + const Image<int32>& I_y, + const float x, const float y) const { + if (I_x.ValidInterpPixel(x - kHarrisWindowSize, y - kHarrisWindowSize) && + I_x.ValidInterpPixel(x + kHarrisWindowSize, y + kHarrisWindowSize)) { + // Image gradient matrix. + float G[] = { 0, 0, 0, 0 }; + CalculateG(kHarrisWindowSize, x, y, I_x, I_y, G); + + const float dx = G[0]; + const float dy = G[3]; + const float dxy = G[1]; + + // Harris-Nobel corner score. + return (dx * dy - Square(dxy)) / (dx + dy + FLT_MIN); + } + + return 0.0f; +} + + +int KeypointDetector::AddExtraCandidatesForBoxes( + const std::vector<BoundingBox>& boxes, + const int max_num_keypoints, + Keypoint* const keypoints) const { + int num_keypoints_added = 0; + + for (std::vector<BoundingBox>::const_iterator iter = boxes.begin(); + iter != boxes.end(); ++iter) { + const BoundingBox box = *iter; + + for (int i = 0; i < kNumToAddAsCandidates; ++i) { + for (int j = 0; j < kNumToAddAsCandidates; ++j) { + if (num_keypoints_added >= max_num_keypoints) { + LOGW("Hit cap of %d for temporary keypoints!", max_num_keypoints); + return num_keypoints_added; + } + + Keypoint curr_keypoint = keypoints[num_keypoints_added++]; + curr_keypoint.pos_ = Point2f( + box.left_ + box.GetWidth() * (i + 0.5f) / kNumToAddAsCandidates, + box.top_ + box.GetHeight() * (j + 0.5f) / kNumToAddAsCandidates); + curr_keypoint.type_ = KEYPOINT_TYPE_INTEREST; + } + } + } + + return num_keypoints_added; +} + + +void KeypointDetector::FindKeypoints(const ImageData& image_data, + const std::vector<BoundingBox>& rois, + const FramePair& prev_change, + FramePair* const curr_change) { + // Copy keypoints from second frame of last pass to temp keypoints of this + // pass. + int number_of_tmp_keypoints = CopyKeypoints(prev_change, tmp_keypoints_); + + const int max_num_fast = kMaxTempKeypoints - number_of_tmp_keypoints; + number_of_tmp_keypoints += + FindFastKeypoints(image_data, max_num_fast, + tmp_keypoints_ + number_of_tmp_keypoints); + + TimeLog("Found FAST keypoints"); + + if (number_of_tmp_keypoints >= kMaxTempKeypoints) { + LOGW("Hit cap of %d for temporary keypoints (FAST)! %d keypoints", + kMaxTempKeypoints, number_of_tmp_keypoints); + } + + if (kAddArbitraryKeypoints) { + // Add some for each object prior to scoring. + const int max_num_box_keypoints = + kMaxTempKeypoints - number_of_tmp_keypoints; + number_of_tmp_keypoints += + AddExtraCandidatesForBoxes(rois, max_num_box_keypoints, + tmp_keypoints_ + number_of_tmp_keypoints); + TimeLog("Added box keypoints"); + + if (number_of_tmp_keypoints >= kMaxTempKeypoints) { + LOGW("Hit cap of %d for temporary keypoints (boxes)! %d keypoints", + kMaxTempKeypoints, number_of_tmp_keypoints); + } + } + + // Score them... + LOGV("Scoring %d keypoints!", number_of_tmp_keypoints); + ScoreKeypoints(image_data, number_of_tmp_keypoints, tmp_keypoints_); + TimeLog("Scored keypoints"); + + // Now pare it down a bit. + SortKeypoints(number_of_tmp_keypoints, tmp_keypoints_); + TimeLog("Sorted keypoints"); + + LOGV("%d keypoints to select from!", number_of_tmp_keypoints); + + SelectKeypoints(rois, tmp_keypoints_, number_of_tmp_keypoints, curr_change); + TimeLog("Selected keypoints"); + + LOGV("Picked %d (%d max) final keypoints out of %d potential.", + curr_change->number_of_keypoints_, + kMaxKeypoints, number_of_tmp_keypoints); +} + + +int KeypointDetector::CopyKeypoints(const FramePair& prev_change, + Keypoint* const new_keypoints) { + int number_of_keypoints = 0; + + // Caching values from last pass, just copy and compact. + for (int i = 0; i < prev_change.number_of_keypoints_; ++i) { + if (prev_change.optical_flow_found_keypoint_[i]) { + new_keypoints[number_of_keypoints] = + prev_change.frame2_keypoints_[i]; + + new_keypoints[number_of_keypoints].score_ = + prev_change.frame1_keypoints_[i].score_; + + ++number_of_keypoints; + } + } + + TimeLog("Copied keypoints"); + return number_of_keypoints; +} + + +// FAST keypoint detector. +int KeypointDetector::FindFastKeypoints(const Image<uint8>& frame, + const int quadrant, + const int downsample_factor, + const int max_num_keypoints, + Keypoint* const keypoints) { + /* + // Reference for a circle of diameter 7. + const int circle[] = {0, 0, 1, 1, 1, 0, 0, + 0, 1, 0, 0, 0, 1, 0, + 1, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 0, 0, 0, 1, + 0, 1, 0, 0, 0, 1, 0, + 0, 0, 1, 1, 1, 0, 0}; + const int circle_offset[] = + {2, 3, 4, 8, 12, 14, 20, 21, 27, 28, 34, 36, 40, 44, 45, 46}; + */ + + // Quick test of compass directions. Any length 16 circle with a break of up + // to 4 pixels will have at least 3 of these 4 pixels active. + static const int short_circle_perimeter = 4; + static const int short_threshold = 3; + static const int short_circle_x[] = { -3, 0, +3, 0 }; + static const int short_circle_y[] = { 0, -3, 0, +3 }; + + // Precompute image offsets. + int short_offsets[short_circle_perimeter]; + for (int i = 0; i < short_circle_perimeter; ++i) { + short_offsets[i] = short_circle_x[i] + short_circle_y[i] * frame.GetWidth(); + } + + // Large circle values. + static const int full_circle_perimeter = 16; + static const int full_threshold = 12; + static const int full_circle_x[] = + { -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2, -3, -3, -3, -2 }; + static const int full_circle_y[] = + { -3, -3, -3, -2, -1, 0, +1, +2, +3, +3, +3, +2, +1, +0, -1, -2 }; + + // Precompute image offsets. + int full_offsets[full_circle_perimeter]; + for (int i = 0; i < full_circle_perimeter; ++i) { + full_offsets[i] = full_circle_x[i] + full_circle_y[i] * frame.GetWidth(); + } + + const int scratch_stride = frame.stride(); + + keypoint_scratch_->Clear(0); + + // Set up the bounds on the region to test based on the passed-in quadrant. + const int quadrant_width = (frame.GetWidth() / 2) - kFastBorderBuffer; + const int quadrant_height = (frame.GetHeight() / 2) - kFastBorderBuffer; + const int start_x = + kFastBorderBuffer + ((quadrant % 2 == 0) ? 0 : quadrant_width); + const int start_y = + kFastBorderBuffer + ((quadrant < 2) ? 0 : quadrant_height); + const int end_x = start_x + quadrant_width; + const int end_y = start_y + quadrant_height; + + // Loop through once to find FAST keypoint clumps. + for (int img_y = start_y; img_y < end_y; ++img_y) { + const uint8* curr_pixel_ptr = frame[img_y] + start_x; + + for (int img_x = start_x; img_x < end_x; ++img_x) { + // Only insert it if it meets the quick minimum requirements test. + if (TestCircle(short_circle_perimeter, short_threshold, + curr_pixel_ptr, short_offsets) != 0) { + // Longer test for actual keypoint score.. + const int fast_score = TestCircle(full_circle_perimeter, + full_threshold, + curr_pixel_ptr, + full_offsets); + + // Non-zero score means the keypoint was found. + if (fast_score != 0) { + uint8* const center_ptr = (*keypoint_scratch_)[img_y] + img_x; + + // Increase the keypoint count on this pixel and the pixels in all + // 4 cardinal directions. + *center_ptr += 5; + *(center_ptr - 1) += 1; + *(center_ptr + 1) += 1; + *(center_ptr - scratch_stride) += 1; + *(center_ptr + scratch_stride) += 1; + } + } + + ++curr_pixel_ptr; + } // x + } // y + + TimeLog("Found FAST keypoints."); + + int num_keypoints = 0; + // Loop through again and Harris filter pixels in the center of clumps. + // We can shrink the window by 1 pixel on every side. + for (int img_y = start_y + 1; img_y < end_y - 1; ++img_y) { + const uint8* curr_pixel_ptr = (*keypoint_scratch_)[img_y] + start_x; + + for (int img_x = start_x + 1; img_x < end_x - 1; ++img_x) { + if (*curr_pixel_ptr >= kMinNumConnectedForFastKeypoint) { + Keypoint* const keypoint = keypoints + num_keypoints; + keypoint->pos_ = Point2f( + img_x * downsample_factor, img_y * downsample_factor); + keypoint->score_ = 0; + keypoint->type_ = KEYPOINT_TYPE_FAST; + + ++num_keypoints; + if (num_keypoints >= max_num_keypoints) { + return num_keypoints; + } + } + + ++curr_pixel_ptr; + } // x + } // y + + TimeLog("Picked FAST keypoints."); + + return num_keypoints; +} + +int KeypointDetector::FindFastKeypoints(const ImageData& image_data, + const int max_num_keypoints, + Keypoint* const keypoints) { + int downsample_factor = 1; + int num_found = 0; + + // TODO(andrewharp): Get this working for multiple image scales. + for (int i = 0; i < 1; ++i) { + const Image<uint8>& frame = *image_data.GetPyramidSqrt2Level(i); + num_found += FindFastKeypoints( + frame, fast_quadrant_, + downsample_factor, max_num_keypoints, keypoints + num_found); + downsample_factor *= 2; + } + + // Increment the current quadrant. + fast_quadrant_ = (fast_quadrant_ + 1) % 4; + + return num_found; +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h new file mode 100644 index 0000000000..6cdd5dde11 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h @@ -0,0 +1,133 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ + +#include <vector> + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +using namespace tensorflow; + +namespace tf_tracking { + +struct Keypoint; + +class KeypointDetector { + public: + explicit KeypointDetector(const KeypointDetectorConfig* const config) + : config_(config), + keypoint_scratch_(new Image<uint8>(config_->image_size)), + interest_map_(new Image<bool>(config_->image_size)), + fast_quadrant_(0) { + interest_map_->Clear(false); + } + + ~KeypointDetector() {} + + // Finds a new set of keypoints for the current frame, picked from the current + // set of keypoints and also from a set discovered via a keypoint detector. + // Special attention is applied to make sure that keypoints are distributed + // within the supplied ROIs. + void FindKeypoints(const ImageData& image_data, + const std::vector<BoundingBox>& rois, + const FramePair& prev_change, + FramePair* const curr_change); + + private: + // Compute the corneriness of a point in the image. + float HarrisFilter(const Image<int32>& I_x, const Image<int32>& I_y, + const float x, const float y) const; + + // Adds a grid of candidate keypoints to the given box, up to + // max_num_keypoints or kNumToAddAsCandidates^2, whichever is lower. + int AddExtraCandidatesForBoxes( + const std::vector<BoundingBox>& boxes, + const int max_num_keypoints, + Keypoint* const keypoints) const; + + // Scan the frame for potential keypoints using the FAST keypoint detector. + // Quadrant is an argument 0-3 which refers to the quadrant of the image in + // which to detect keypoints. + int FindFastKeypoints(const Image<uint8>& frame, + const int quadrant, + const int downsample_factor, + const int max_num_keypoints, + Keypoint* const keypoints); + + int FindFastKeypoints(const ImageData& image_data, + const int max_num_keypoints, + Keypoint* const keypoints); + + // Score a bunch of candidate keypoints. Assigns the scores to the input + // candidate_keypoints array entries. + void ScoreKeypoints(const ImageData& image_data, + const int num_candidates, + Keypoint* const candidate_keypoints); + + void SortKeypoints(const int num_candidates, + Keypoint* const candidate_keypoints) const; + + // Selects a set of keypoints falling within the supplied box such that the + // most highly rated keypoints are picked first, and so that none of them are + // too close together. + int SelectKeypointsInBox( + const BoundingBox& box, + const Keypoint* const candidate_keypoints, + const int num_candidates, + const int max_keypoints, + const int num_existing_keypoints, + const Keypoint* const existing_keypoints, + Keypoint* const final_keypoints) const; + + // Selects from the supplied sorted keypoint pool a set of keypoints that will + // best cover the given set of boxes, such that each box is covered at a + // resolution proportional to its size. + void SelectKeypoints( + const std::vector<BoundingBox>& boxes, + const Keypoint* const candidate_keypoints, + const int num_candidates, + FramePair* const frame_change) const; + + // Copies and compacts the found keypoints in the second frame of prev_change + // into the array at new_keypoints. + static int CopyKeypoints(const FramePair& prev_change, + Keypoint* const new_keypoints); + + const KeypointDetectorConfig* const config_; + + // Scratch memory for keypoint candidacy detection and non-max suppression. + std::unique_ptr<Image<uint8> > keypoint_scratch_; + + // Regions of the image to pay special attention to. + std::unique_ptr<Image<bool> > interest_map_; + + // The current quadrant of the image to detect FAST keypoints in. + // Keypoint detection is staggered for performance reasons. Every four frames + // a full scan of the frame will have been performed. + int fast_quadrant_; + + Keypoint tmp_keypoints_[kMaxTempKeypoints]; +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/log_streaming.h b/tensorflow/examples/android/jni/object_tracking/log_streaming.h new file mode 100644 index 0000000000..e68945cc72 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/log_streaming.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ + +#include <string.h> +#include <string> + +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +using namespace tensorflow; + +namespace tf_tracking { + +#define LOGV(...) +#define LOGD(...) +#define LOGI(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__); +#define LOGW(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__); +#define LOGE(...) LOG(INFO) << tensorflow::strings::Printf(__VA_ARGS__); + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.cc b/tensorflow/examples/android/jni/object_tracking/object_detector.cc new file mode 100644 index 0000000000..7f65716fdf --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_detector.cc @@ -0,0 +1,27 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOTE: no native object detectors are currently provided or used by the code +// in this directory. This class remains mainly for historical reasons. +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. + +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" + +namespace tf_tracking { + +// This is here so that the vtable gets created properly. +ObjectDetectorBase::~ObjectDetectorBase() {} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/examples/android/jni/object_tracking/object_detector.h new file mode 100644 index 0000000000..043f606e1d --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_detector.h @@ -0,0 +1,232 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOTE: no native object detectors are currently provided or used by the code +// in this directory. This class remains mainly for historical reasons. +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. + +// Defines the ObjectDetector class that is the main interface for detecting +// ObjectModelBases in frames. + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ + +#include <float.h> +#include <map> +#include <memory> +#include <sstream> +#include <string> +#include <vector> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#ifdef __RENDER_OPENGL__ +#include "tensorflow/examples/android/jni/object_tracking/sprite.h" +#endif +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/object_model.h" + +namespace tf_tracking { + +// Adds BoundingSquares to a vector such that the first square added is centered +// in the position given and of square_size, and the remaining squares are added +// concentrentically, scaling down by scale_factor until the minimum threshold +// size is passed. +// Squares that do not fall completely within image_bounds will not be added. +static inline void FillWithSquares( + const BoundingBox& image_bounds, + const BoundingBox& position, + const float starting_square_size, + const float smallest_square_size, + const float scale_factor, + std::vector<BoundingSquare>* const squares) { + BoundingSquare descriptor_area = + GetCenteredSquare(position, starting_square_size); + + SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor); + + // Use a do/while loop to ensure that at least one descriptor is created. + do { + if (image_bounds.Contains(descriptor_area.ToBoundingBox())) { + squares->push_back(descriptor_area); + } + descriptor_area.Scale(scale_factor); + } while (descriptor_area.size_ >= smallest_square_size - EPSILON); + LOGV("Created %zu squares starting from size %.2f to min size %.2f " + "using scale factor: %.2f", + squares->size(), starting_square_size, smallest_square_size, + scale_factor); +} + + +// Represents a potential detection of a specific ObjectExemplar and Descriptor +// at a specific position in the image. +class Detection { + public: + explicit Detection(const ObjectModelBase* const object_model, + const MatchScore match_score, + const BoundingBox& bounding_box) + : object_model_(object_model), + match_score_(match_score), + bounding_box_(bounding_box) {} + + Detection(const Detection& other) + : object_model_(other.object_model_), + match_score_(other.match_score_), + bounding_box_(other.bounding_box_) {} + + virtual ~Detection() {} + + inline BoundingBox GetObjectBoundingBox() const { + return bounding_box_; + } + + inline MatchScore GetMatchScore() const { + return match_score_; + } + + inline const ObjectModelBase* GetObjectModel() const { + return object_model_; + } + + inline bool Intersects(const Detection& other) { + // Check if any of the four axes separates us, there must be at least one. + return bounding_box_.Intersects(other.bounding_box_); + } + + struct Comp { + inline bool operator()(const Detection& a, const Detection& b) const { + return a.match_score_ > b.match_score_; + } + }; + + // TODO(andrewharp): add accessors to update these instead. + const ObjectModelBase* object_model_; + MatchScore match_score_; + BoundingBox bounding_box_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const Detection& detection) { + const BoundingBox actual_area = detection.GetObjectBoundingBox(); + stream << actual_area; + return stream; +} + +class ObjectDetectorBase { + public: + explicit ObjectDetectorBase(const ObjectDetectorConfig* const config) + : config_(config), + image_data_(NULL) {} + + virtual ~ObjectDetectorBase(); + + // Sets the current image data. All calls to ObjectDetector other than + // FillDescriptors use the image data last set. + inline void SetImageData(const ImageData* const image_data) { + image_data_ = image_data; + } + + // Main entry point into the detection algorithm. + // Scans the frame for candidates, tweaks them, and fills in the + // given std::vector of Detection objects with acceptable matches. + virtual void Detect(const std::vector<BoundingSquare>& positions, + std::vector<Detection>* const detections) const = 0; + + virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0; + + virtual void DeleteObjectModel(const std::string& name) = 0; + + virtual void GetObjectModels( + std::vector<const ObjectModelBase*>* models) const = 0; + + // Creates a new ObjectExemplar from the given position in the context of + // the last frame passed to NextFrame. + // Will return null in the case that there's no room for a descriptor to be + // created in the example area, or the example area is not completely + // contained within the frame. + virtual void UpdateModel( + const Image<uint8>& base_image, + const IntegralImage& integral_image, + const BoundingBox& bounding_box, + const bool locked, + ObjectModelBase* model) const = 0; + + virtual void Draw() const = 0; + + virtual bool AllowSpontaneousDetections() = 0; + + protected: + const std::unique_ptr<const ObjectDetectorConfig> config_; + + // The latest frame data, upon which all detections will be performed. + // Not owned by this object, just provided for reference by ObjectTracker + // via SetImageData(). + const ImageData* image_data_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase); +}; + +template <typename ModelType> +class ObjectDetector : public ObjectDetectorBase { + public: + explicit ObjectDetector(const ObjectDetectorConfig* const config) + : ObjectDetectorBase(config) {} + + virtual ~ObjectDetector() { + typename std::map<std::string, ModelType*>::const_iterator it = + object_models_.begin(); + for (; it != object_models_.end(); ++it) { + ModelType* model = it->second; + delete model; + } + } + + virtual void DeleteObjectModel(const std::string& name) { + ModelType* model = object_models_[name]; + CHECK_ALWAYS(model != NULL, "Model was null!"); + object_models_.erase(name); + SAFE_DELETE(model); + } + + virtual void GetObjectModels( + std::vector<const ObjectModelBase*>* models) const { + typename std::map<std::string, ModelType*>::const_iterator it = + object_models_.begin(); + for (; it != object_models_.end(); ++it) { + models->push_back(it->second); + } + } + + virtual bool AllowSpontaneousDetections() { + return false; + } + + protected: + std::map<std::string, ModelType*> object_models_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/examples/android/jni/object_tracking/object_model.h new file mode 100644 index 0000000000..2d359668b2 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_model.h @@ -0,0 +1,101 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NOTE: no native object detectors are currently provided or used by the code +// in this directory. This class remains mainly for historical reasons. +// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. + +// Contains ObjectModelBase declaration. + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ + +#ifdef __RENDER_OPENGL__ +#include <GLES/gl.h> +#include <GLES/glext.h> +#endif + +#include <vector> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#ifdef __RENDER_OPENGL__ +#include "tensorflow/examples/android/jni/object_tracking/sprite.h" +#endif +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" + +namespace tf_tracking { + +// The ObjectModelBase class represents all the known appearance information for +// an object. It is not a specific instance of the object in the world, +// but just the general appearance information that enables detection. An +// ObjectModelBase can be reused across multiple-instances of TrackedObjects. +class ObjectModelBase { + public: + ObjectModelBase(const std::string& name) : name_(name) {} + + virtual ~ObjectModelBase() {} + + // Called when the next step in an ongoing track occurs. + virtual void TrackStep( + const BoundingBox& position, const Image<uint8>& image, + const IntegralImage& integral_image, const bool authoritative) {} + + // Called when an object track is lost. + virtual void TrackLost() {} + + // Called when an object track is confirmed as legitimate. + virtual void TrackConfirmed() {} + + virtual float GetMaxCorrelation(const Image<float>& patch_image) const = 0; + + virtual MatchScore GetMatchScore( + const BoundingBox& position, const ImageData& image_data) const = 0; + + virtual void Draw(float* const depth) const = 0; + + inline const std::string& GetName() const { + return name_; + } + + protected: + const std::string name_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModelBase); +}; + +template <typename DetectorType> +class ObjectModel : public ObjectModelBase { + public: + ObjectModel<DetectorType>(const DetectorType* const detector, + const std::string& name) + : ObjectModelBase(name), detector_(detector) {} + + protected: + const DetectorType* const detector_; + + TF_DISALLOW_COPY_AND_ASSIGN(ObjectModel<DetectorType>); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc new file mode 100644 index 0000000000..1d867b934b --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.cc @@ -0,0 +1,690 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef __RENDER_OPENGL__ +#include <GLES/gl.h> +#include <GLES/glext.h> +#endif + +#include <string> +#include <map> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +namespace tf_tracking { + +ObjectTracker::ObjectTracker(const TrackerConfig* const config, + ObjectDetectorBase* const detector) + : config_(config), + frame_width_(config->image_size.width), + frame_height_(config->image_size.height), + curr_time_(0), + num_frames_(0), + flow_cache_(&config->flow_config), + keypoint_detector_(&config->keypoint_detector_config), + curr_num_frame_pairs_(0), + first_frame_index_(0), + frame1_(new ImageData(frame_width_, frame_height_)), + frame2_(new ImageData(frame_width_, frame_height_)), + detector_(detector), + num_detected_(0) { + for (int i = 0; i < kNumFrames; ++i) { + frame_pairs_[i].Init(-1, -1); + } +} + + +ObjectTracker::~ObjectTracker() { + for (TrackedObjectMap::iterator iter = objects_.begin(); + iter != objects_.end(); iter++) { + TrackedObject* object = iter->second; + SAFE_DELETE(object); + } +} + + +// Finds the correspondences for all the points in the current pair of frames. +// Stores the results in the given FramePair. +void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const { + // Keypoints aren't found until they're found. + memset(frame_pair->optical_flow_found_keypoint_, false, + sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints); + TimeLog("Cleared old found keypoints"); + + int num_keypoints_found = 0; + + // For every keypoint... + for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) { + Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat; + Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat; + + if (flow_cache_.FindNewPositionOfPoint( + keypoint1->pos_.x, keypoint1->pos_.y, + &keypoint2->pos_.x, &keypoint2->pos_.y)) { + frame_pair->optical_flow_found_keypoint_[i_feat] = true; + ++num_keypoints_found; + } + } + + TimeLog("Found correspondences"); + + LOGV("Found %d of %d keypoint correspondences", + num_keypoints_found, frame_pair->number_of_keypoints_); +} + + +void ObjectTracker::NextFrame(const uint8* const new_frame, + const uint8* const uv_frame, + const int64 timestamp, + const float* const alignment_matrix_2x3) { + IncrementFrameIndex(); + LOGV("Received frame %d", num_frames_); + + FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0); + curr_change->Init(curr_time_, timestamp); + + CHECK_ALWAYS(curr_time_ < timestamp, + "Timestamp must monotonically increase! Went from %lld to %lld" + " on frame %d.", + curr_time_, timestamp, num_frames_); + curr_time_ = timestamp; + + // Swap the frames. + frame1_.swap(frame2_); + + frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1); + + if (detector_.get() != NULL) { + detector_->SetImageData(frame2_.get()); + } + + flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3); + + if (num_frames_ == 1) { + // This must be the first frame, so abort. + return; + } + + if (config_->always_track || objects_.size() > 0) { + LOGV("Tracking %zu targets", objects_.size()); + ComputeKeypoints(true); + TimeLog("Keypoints computed!"); + + FindCorrespondences(curr_change); + TimeLog("Flow computed!"); + + TrackObjects(); + } + TimeLog("Targets tracked!"); + + if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) { + DetectTargets(); + } + TimeLog("Detected objects."); +} + + +TrackedObject* ObjectTracker::MaybeAddObject( + const std::string& id, + const Image<uint8>& source_image, + const BoundingBox& bounding_box, + const ObjectModelBase* object_model) { + // Train the detector if this is a new object. + if (objects_.find(id) != objects_.end()) { + return objects_[id]; + } + + // Need to get a non-const version of the model, or create a new one if it + // wasn't given. + ObjectModelBase* model = NULL; + if (detector_ != NULL) { + // If a detector is registered, then this new object must have a model. + CHECK_ALWAYS(object_model != NULL, "No model given!"); + model = detector_->CreateObjectModel(object_model->GetName()); + } + TrackedObject* const object = + new TrackedObject(id, source_image, bounding_box, model); + + objects_[id] = object; + return object; +} + + +void ObjectTracker::RegisterNewObjectWithAppearance( + const std::string& id, const uint8* const new_frame, + const BoundingBox& bounding_box) { + ObjectModelBase* object_model = NULL; + + Image<uint8> image(frame_width_, frame_height_); + image.FromArray(new_frame, frame_width_, 1); + + if (detector_ != NULL) { + object_model = detector_->CreateObjectModel(id); + CHECK_ALWAYS(object_model != NULL, "Null object model!"); + + const IntegralImage integral_image(image); + object_model->TrackStep(bounding_box, image, integral_image, true); + } + + // Create an object at this position. + CHECK_ALWAYS(!HaveObject(id), "Already have this object!"); + if (objects_.find(id) == objects_.end()) { + TrackedObject* const object = + MaybeAddObject(id, image, bounding_box, object_model); + CHECK_ALWAYS(object != NULL, "Object not created!"); + } +} + + +void ObjectTracker::SetPreviousPositionOfObject(const std::string& id, + const BoundingBox& bounding_box, + const int64 timestamp) { + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp); + CHECK_ALWAYS(timestamp <= curr_time_, + "Timestamp too great! %lld vs %lld", timestamp, curr_time_); + + TrackedObject* const object = GetObject(id); + + // Track this bounding box from the past to the current time. + const BoundingBox current_position = TrackBox(bounding_box, timestamp); + + object->UpdatePosition(current_position, curr_time_, *frame2_, false); + + VLOG(2) << "Set tracked position for " << id << " to " << bounding_box + << std::endl; +} + + +void ObjectTracker::SetCurrentPositionOfObject( + const std::string& id, const BoundingBox& bounding_box) { + SetPreviousPositionOfObject(id, bounding_box, curr_time_); +} + + +void ObjectTracker::ForgetTarget(const std::string& id) { + LOGV("Forgetting object %s", id.c_str()); + TrackedObject* const object = GetObject(id); + delete object; + objects_.erase(id); + + if (detector_ != NULL) { + detector_->DeleteObjectModel(id); + } +} + + +int ObjectTracker::GetKeypointsPacked(uint16* const out_data, + const float scale) const { + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)]; + uint16* curr_data = out_data; + int num_keypoints = 0; + + for (int i = 0; i < change.number_of_keypoints_; ++i) { + if (change.optical_flow_found_keypoint_[i]) { + ++num_keypoints; + const Point2f& point1 = change.frame1_keypoints_[i].pos_; + *curr_data++ = RealToFixed115(point1.x * scale); + *curr_data++ = RealToFixed115(point1.y * scale); + + const Point2f& point2 = change.frame2_keypoints_[i].pos_; + *curr_data++ = RealToFixed115(point2.x * scale); + *curr_data++ = RealToFixed115(point2.y * scale); + } + } + + return num_keypoints; +} + + +int ObjectTracker::GetKeypoints(const bool only_found, + float* const out_data) const { + int curr_keypoint = 0; + const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)]; + + for (int i = 0; i < change.number_of_keypoints_; ++i) { + if (!only_found || change.optical_flow_found_keypoint_[i]) { + const int base = curr_keypoint * kKeypointStep; + out_data[base + 0] = change.frame1_keypoints_[i].pos_.x; + out_data[base + 1] = change.frame1_keypoints_[i].pos_.y; + + out_data[base + 2] = + change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f; + out_data[base + 3] = change.frame2_keypoints_[i].pos_.x; + out_data[base + 4] = change.frame2_keypoints_[i].pos_.y; + + out_data[base + 5] = change.frame1_keypoints_[i].score_; + out_data[base + 6] = change.frame1_keypoints_[i].type_; + ++curr_keypoint; + } + } + + LOGV("Got %d keypoints.", curr_keypoint); + + return curr_keypoint; +} + + +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region, + const FramePair& frame_pair) const { + float translation_x; + float translation_y; + + float scale_x; + float scale_y; + + BoundingBox tracked_box(region); + frame_pair.AdjustBox( + tracked_box, &translation_x, &translation_y, &scale_x, &scale_y); + + tracked_box.Shift(Point2f(translation_x, translation_y)); + + if (scale_x > 0 && scale_y > 0) { + tracked_box.Scale(scale_x, scale_y); + } + return tracked_box; +} + + +BoundingBox ObjectTracker::TrackBox(const BoundingBox& region, + const int64 timestamp) const { + CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %lld", timestamp); + CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!"); + + // Anything that ended before the requested timestamp is of no concern to us. + bool found_it = false; + int num_frames_back = -1; + for (int i = 0; i < curr_num_frame_pairs_; ++i) { + const FramePair& frame_pair = + frame_pairs_[GetNthIndexFromEnd(i)]; + + if (frame_pair.end_time_ <= timestamp) { + num_frames_back = i - 1; + + if (num_frames_back > 0) { + LOGV("Went %d out of %d frames before finding frame. (index: %d)", + num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i)); + } + + found_it = true; + break; + } + } + + if (!found_it) { + LOGW("History did not go back far enough! %lld vs %lld", + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - + frame_pairs_[GetNthIndexFromStart(0)].end_time_, + frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp); + } + + // Loop over all the frames in the queue, tracking the accumulated delta + // of the point from frame to frame. It's possible the point could + // go out of frame, but keep tracking as best we can, using points near + // the edge of the screen where it went out of bounds. + BoundingBox tracked_box(region); + for (int i = num_frames_back; i >= 0; --i) { + const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)]; + SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!"); + tracked_box = TrackBox(tracked_box, frame_pair); + } + return tracked_box; +} + + +// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4 +// 3d transformation matrix. +inline void Convert3x3To4x4( + const float* const in_matrix, float* const out_matrix) { + // X + out_matrix[0] = in_matrix[0]; + out_matrix[1] = in_matrix[3]; + out_matrix[2] = 0.0f; + out_matrix[3] = 0.0f; + + // Y + out_matrix[4] = in_matrix[1]; + out_matrix[5] = in_matrix[4]; + out_matrix[6] = 0.0f; + out_matrix[7] = 0.0f; + + // Z + out_matrix[8] = 0.0f; + out_matrix[9] = 0.0f; + out_matrix[10] = 1.0f; + out_matrix[11] = 0.0f; + + // Translation + out_matrix[12] = in_matrix[2]; + out_matrix[13] = in_matrix[5]; + out_matrix[14] = 0.0f; + out_matrix[15] = 1.0f; +} + + +void ObjectTracker::Draw(const int canvas_width, const int canvas_height, + const float* const frame_to_canvas) const { +#ifdef __RENDER_OPENGL__ + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + + glMatrixMode(GL_PROJECTION); + glLoadIdentity(); + + glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f); + + // To make Y go the right direction (0 at top of frame). + glScalef(1.0f, -1.0f, 1.0f); + glTranslatef(0.0f, -canvas_height, 0.0f); + + glMatrixMode(GL_MODELVIEW); + glLoadIdentity(); + + glPushMatrix(); + + // Apply the frame to canvas transformation. + static GLfloat transformation[16]; + Convert3x3To4x4(frame_to_canvas, transformation); + glMultMatrixf(transformation); + + // Draw tracked object bounding boxes. + for (TrackedObjectMap::const_iterator iter = objects_.begin(); + iter != objects_.end(); ++iter) { + TrackedObject* tracked_object = iter->second; + tracked_object->Draw(); + } + + static const bool kRenderDebugPyramid = false; + if (kRenderDebugPyramid) { + glColor4f(1.0f, 1.0f, 1.0f, 1.0f); + for (int i = 0; i < kNumPyramidLevels * 2; ++i) { + Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw(); + } + } + + static const bool kRenderDebugDerivative = false; + if (kRenderDebugDerivative) { + glColor4f(1.0f, 1.0f, 1.0f, 1.0f); + for (int i = 0; i < kNumPyramidLevels; ++i) { + const Image<int32>& dx = *frame1_->GetSpatialX(i); + Image<uint8> render_image(dx.GetWidth(), dx.GetHeight()); + for (int y = 0; y < dx.GetHeight(); ++y) { + const int32* dx_ptr = dx[y]; + uint8* dst_ptr = render_image[y]; + for (int x = 0; x < dx.GetWidth(); ++x) { + *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255); + } + } + + Sprite(render_image).Draw(); + } + } + + if (detector_ != NULL) { + glDisable(GL_CULL_FACE); + detector_->Draw(); + } + glPopMatrix(); +#endif +} + +static void AddQuadrants(const BoundingBox& box, + std::vector<BoundingBox>* boxes) { + const Point2f center = box.GetCenter(); + + float x1 = box.left_; + float x2 = center.x; + float x3 = box.right_; + + float y1 = box.top_; + float y2 = center.y; + float y3 = box.bottom_; + + // Upper left. + boxes->push_back(BoundingBox(x1, y1, x2, y2)); + + // Upper right. + boxes->push_back(BoundingBox(x2, y1, x3, y2)); + + // Bottom left. + boxes->push_back(BoundingBox(x1, y2, x2, y3)); + + // Bottom right. + boxes->push_back(BoundingBox(x2, y2, x3, y3)); + + // Whole thing. + boxes->push_back(box); +} + +void ObjectTracker::ComputeKeypoints(const bool cached_ok) { + const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)]; + FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)]; + + std::vector<BoundingBox> boxes; + + for (TrackedObjectMap::iterator object_iter = objects_.begin(); + object_iter != objects_.end(); ++object_iter) { + BoundingBox box = object_iter->second->GetPosition(); + box.Scale(config_->object_box_scale_factor_for_features, + config_->object_box_scale_factor_for_features); + AddQuadrants(box, &boxes); + } + + AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes); + + keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change); +} + + +// Given a vector of detections and a model, simply returns the Detection for +// that model with the highest correlation. +bool ObjectTracker::GetBestObjectForDetection( + const Detection& detection, TrackedObject** match) const { + TrackedObject* best_match = NULL; + float best_overlap = -FLT_MAX; + + LOGV("Looking for matches in %zu objects!", objects_.size()); + for (TrackedObjectMap::const_iterator object_iter = objects_.begin(); + object_iter != objects_.end(); ++object_iter) { + TrackedObject* const tracked_object = object_iter->second; + + const float overlap = tracked_object->GetPosition().PascalScore( + detection.GetObjectBoundingBox()); + + if (!detector_->AllowSpontaneousDetections() && + (detection.GetObjectModel() != tracked_object->GetModel())) { + if (overlap > 0.0f) { + return false; + } + continue; + } + + const float jump_distance = + (tracked_object->GetPosition().GetCenter() - + detection.GetObjectBoundingBox().GetCenter()).LengthSquared(); + + const float allowed_distance = + tracked_object->GetAllowableDistanceSquared(); + + LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f", + jump_distance, allowed_distance, overlap); + + // TODO(andrewharp): No need to do this verification twice, eliminate + // one of the score checks (the other being in OnDetection). + if (jump_distance < allowed_distance && + overlap > best_overlap && + tracked_object->GetMatchScore() + kMatchScoreBuffer < + detection.GetMatchScore()) { + best_match = tracked_object; + best_overlap = overlap; + } else if (overlap > 0.0f) { + return false; + } + } + + *match = best_match; + return true; +} + + +void ObjectTracker::ProcessDetections( + std::vector<Detection>* const detections) { + LOGV("Initial detection done, iterating over %zu detections now.", + detections->size()); + + const bool spontaneous_detections_allowed = + detector_->AllowSpontaneousDetections(); + for (std::vector<Detection>::const_iterator it = detections->begin(); + it != detections->end(); ++it) { + const Detection& detection = *it; + SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()), + "Frame does not contain bounding box!"); + + TrackedObject* best_match = NULL; + + const bool no_collisions = + GetBestObjectForDetection(detection, &best_match); + + // Need to get a non-const version of the model, or create a new one if it + // wasn't given. + ObjectModelBase* model = + const_cast<ObjectModelBase*>(detection.GetObjectModel()); + + if (best_match != NULL) { + if (model != best_match->GetModel()) { + CHECK_ALWAYS(detector_->AllowSpontaneousDetections(), + "Model for object changed but spontaneous detections not allowed!"); + } + best_match->OnDetection(model, + detection.GetObjectBoundingBox(), + detection.GetMatchScore(), + curr_time_, *frame2_); + } else if (no_collisions && spontaneous_detections_allowed) { + if (detection.GetMatchScore() > kMinimumMatchScore) { + LOGV("No match, adding it!"); + const ObjectModelBase* model = detection.GetObjectModel(); + std::ostringstream ss; + // TODO(andrewharp): Generate this in a more general fashion. + ss << "hand_" << num_detected_++; + std::string object_name = ss.str(); + MaybeAddObject(object_name, *frame2_->GetImage(), + detection.GetObjectBoundingBox(), model); + } + } + } +} + + +void ObjectTracker::DetectTargets() { + // Detect all object model types that we're currently tracking. + std::vector<const ObjectModelBase*> object_models; + detector_->GetObjectModels(&object_models); + if (object_models.size() == 0) { + LOGV("No objects to search for, aborting."); + return; + } + + LOGV("Trying to detect %zu models", object_models.size()); + + LOGV("Creating test vector!"); + std::vector<BoundingSquare> positions; + + for (TrackedObjectMap::iterator object_iter = objects_.begin(); + object_iter != objects_.end(); ++object_iter) { + TrackedObject* const tracked_object = object_iter->second; + +#if DEBUG_PREDATOR + positions.push_back(GetCenteredSquare( + frame2_->GetImage()->GetContainingBox(), 32.0f)); +#else + const BoundingBox& position = tracked_object->GetPosition(); + + const float square_size = MAX( + kScanMinSquareSize / (kLastKnownPositionScaleFactor * + kLastKnownPositionScaleFactor), + MIN(position.GetWidth(), + position.GetHeight())) / kLastKnownPositionScaleFactor; + + FillWithSquares(frame2_->GetImage()->GetContainingBox(), + tracked_object->GetPosition(), + square_size, + kScanMinSquareSize, + kLastKnownPositionScaleFactor, + &positions); + } +#endif + + LOGV("Created test vector!"); + + std::vector<Detection> detections; + LOGV("Detecting!"); + detector_->Detect(positions, &detections); + LOGV("Found %zu detections", detections.size()); + + TimeLog("Finished detection."); + + ProcessDetections(&detections); + + TimeLog("iterated over detections"); + + LOGV("Done detecting!"); +} + + +void ObjectTracker::TrackObjects() { + // TODO(andrewharp): Correlation should be allowed to remove objects too. + const bool automatic_removal_allowed = detector_.get() != NULL ? + detector_->AllowSpontaneousDetections() : false; + + LOGV("Tracking %zu objects!", objects_.size()); + std::vector<std::string> dead_objects; + for (TrackedObjectMap::iterator iter = objects_.begin(); + iter != objects_.end(); iter++) { + TrackedObject* object = iter->second; + const BoundingBox tracked_position = TrackBox( + object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]); + object->UpdatePosition(tracked_position, curr_time_, *frame2_, false); + + if (automatic_removal_allowed && + object->GetNumConsecutiveFramesBelowThreshold() > + kMaxNumDetectionFailures * 5) { + dead_objects.push_back(iter->first); + } + } + + if (detector_ != NULL && automatic_removal_allowed) { + for (std::vector<std::string>::iterator iter = dead_objects.begin(); + iter != dead_objects.end(); iter++) { + LOGE("Removing object! %s", iter->c_str()); + ForgetTarget(*iter); + } + } + TimeLog("Tracked all objects."); + + LOGV("%zu objects tracked!", objects_.size()); +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h new file mode 100644 index 0000000000..3d2a9af360 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h @@ -0,0 +1,271 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ + +#include <map> +#include <string> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/object_model.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" + +namespace tf_tracking { + +typedef std::map<const std::string, TrackedObject*> TrackedObjectMap; + +inline std::ostream& operator<<(std::ostream& stream, + const TrackedObjectMap& map) { + for (TrackedObjectMap::const_iterator iter = map.begin(); + iter != map.end(); ++iter) { + const TrackedObject& tracked_object = *iter->second; + const std::string& key = iter->first; + stream << key << ": " << tracked_object; + } + return stream; +} + + +// ObjectTracker is the highest-level class in the tracking/detection framework. +// It handles basic image processing, keypoint detection, keypoint tracking, +// object tracking, and object detection/relocalization. +class ObjectTracker { + public: + ObjectTracker(const TrackerConfig* const config, + ObjectDetectorBase* const detector); + virtual ~ObjectTracker(); + + virtual void NextFrame(const uint8* const new_frame, + const int64 timestamp, + const float* const alignment_matrix_2x3) { + NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3); + } + + // Called upon the arrival of a new frame of raw data. + // Does all image processing, keypoint detection, and object + // tracking/detection for registered objects. + // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that + // represents the main transformation that has happened between the last + // and the current frame. + // Argument align_level is the pyramid level (where 0 == finest) that + // the matrix is valid for. + virtual void NextFrame(const uint8* const new_frame, + const uint8* const uv_frame, + const int64 timestamp, + const float* const alignment_matrix_2x3); + + virtual void RegisterNewObjectWithAppearance( + const std::string& id, const uint8* const new_frame, + const BoundingBox& bounding_box); + + // Updates the position of a tracked object, given that it was known to be at + // a certain position at some point in the past. + virtual void SetPreviousPositionOfObject(const std::string& id, + const BoundingBox& bounding_box, + const int64 timestamp); + + // Sets the current position of the object in the most recent frame provided. + virtual void SetCurrentPositionOfObject(const std::string& id, + const BoundingBox& bounding_box); + + // Tells the ObjectTracker to stop tracking a target. + void ForgetTarget(const std::string& id); + + // Fills the given out_data buffer with the latest detected keypoint + // correspondences, first scaled by scale_factor (to adjust for downsampling + // that may have occurred elsewhere), then packed in a fixed-point format. + int GetKeypointsPacked(uint16* const out_data, + const float scale_factor) const; + + // Copy the keypoint arrays after computeFlow is called. + // out_data should be at least kMaxKeypoints * kKeypointStep long. + // Currently, its format is [x1 y1 found x2 y2 score] repeated N times, + // where N is the number of keypoints tracked. N is returned as the result. + int GetKeypoints(const bool only_found, float* const out_data) const; + + // Returns the current position of a box, given that it was at a certain + // position at the given time. + BoundingBox TrackBox(const BoundingBox& region, + const int64 timestamp) const; + + // Returns the number of frames that have been passed to NextFrame(). + inline int GetNumFrames() const { + return num_frames_; + } + + inline bool HaveObject(const std::string& id) const { + return objects_.find(id) != objects_.end(); + } + + // Returns the TrackedObject associated with the given id. + inline const TrackedObject* GetObject(const std::string& id) const { + TrackedObjectMap::const_iterator iter = objects_.find(id); + CHECK_ALWAYS(iter != objects_.end(), + "Unknown object key! \"%s\"", id.c_str()); + TrackedObject* const object = iter->second; + return object; + } + + // Returns the TrackedObject associated with the given id. + inline TrackedObject* GetObject(const std::string& id) { + TrackedObjectMap::iterator iter = objects_.find(id); + CHECK_ALWAYS(iter != objects_.end(), + "Unknown object key! \"%s\"", id.c_str()); + TrackedObject* const object = iter->second; + return object; + } + + bool IsObjectVisible(const std::string& id) const { + SCHECK(HaveObject(id), "Don't have this object."); + + const TrackedObject* object = GetObject(id); + return object->IsVisible(); + } + + virtual void Draw(const int canvas_width, const int canvas_height, + const float* const frame_to_canvas) const; + + protected: + // Creates a new tracked object at the given position. + // If an object model is provided, then that model will be associated with the + // object. If not, a new model may be created from the appearance at the + // initial position and registered with the object detector. + virtual TrackedObject* MaybeAddObject(const std::string& id, + const Image<uint8>& image, + const BoundingBox& bounding_box, + const ObjectModelBase* object_model); + + // Find the keypoints in the frame before the current frame. + // If only one frame exists, keypoints will be found in that frame. + void ComputeKeypoints(const bool cached_ok = false); + + // Finds the correspondences for all the points in the current pair of frames. + // Stores the results in the given FramePair. + void FindCorrespondences(FramePair* const curr_change) const; + + inline int GetNthIndexFromEnd(const int offset) const { + return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset); + } + + BoundingBox TrackBox(const BoundingBox& region, + const FramePair& frame_pair) const; + + inline void IncrementFrameIndex() { + // Move the current framechange index up. + ++num_frames_; + ++curr_num_frame_pairs_; + + // If we've got too many, push up the start of the queue. + if (curr_num_frame_pairs_ > kNumFrames) { + first_frame_index_ = GetNthIndexFromStart(1); + --curr_num_frame_pairs_; + } + } + + inline int GetNthIndexFromStart(const int offset) const { + SCHECK(offset >= 0 && offset < curr_num_frame_pairs_, + "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_); + return (first_frame_index_ + offset) % kNumFrames; + } + + void TrackObjects(); + + const std::unique_ptr<const TrackerConfig> config_; + + const int frame_width_; + const int frame_height_; + + int64 curr_time_; + + int num_frames_; + + TrackedObjectMap objects_; + + FlowCache flow_cache_; + + KeypointDetector keypoint_detector_; + + int curr_num_frame_pairs_; + int first_frame_index_; + + std::unique_ptr<ImageData> frame1_; + std::unique_ptr<ImageData> frame2_; + + FramePair frame_pairs_[kNumFrames]; + + std::unique_ptr<ObjectDetectorBase> detector_; + + int num_detected_; + + private: + void TrackTarget(TrackedObject* const object); + + bool GetBestObjectForDetection( + const Detection& detection, TrackedObject** match) const; + + void ProcessDetections(std::vector<Detection>* const detections); + + void DetectTargets(); + + // Temp object used in ObjectTracker::CreateNewExample. + mutable std::vector<BoundingSquare> squares; + + friend std::ostream& operator<<(std::ostream& stream, + const ObjectTracker& tracker); + + TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker); +}; + +inline std::ostream& operator<<(std::ostream& stream, + const ObjectTracker& tracker) { + stream << "Frame size: " << tracker.frame_width_ << "x" + << tracker.frame_height_ << std::endl; + + stream << "Num frames: " << tracker.num_frames_ << std::endl; + + stream << "Curr time: " << tracker.curr_time_ << std::endl; + + const int first_frame_index = tracker.GetNthIndexFromStart(0); + const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index]; + + const int last_frame_index = tracker.GetNthIndexFromEnd(0); + const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index]; + + stream << "first frame: " << first_frame_index << "," + << first_frame_pair.end_time_ << " " + << "last frame: " << last_frame_index << "," + << last_frame_pair.end_time_ << " diff: " + << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms" + << std::endl; + + stream << "Tracked targets:"; + stream << tracker.objects_; + + return stream; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc new file mode 100644 index 0000000000..30c5974654 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc @@ -0,0 +1,463 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <android/log.h> +#include <jni.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> +#include <cstdint> + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" + +using namespace tensorflow; + +namespace tf_tracking { + +#define OBJECT_TRACKER_METHOD(METHOD_NAME) \ + Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME // NOLINT + +JniIntField object_tracker_field("nativeObjectTracker"); + +ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) { + ObjectTracker* const object_tracker = + reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz)); + CHECK_ALWAYS(object_tracker != NULL, "null object tracker!"); + return object_tracker; +} + +void set_object_tracker(JNIEnv* env, jobject thiz, + const ObjectTracker* object_tracker) { + object_tracker_field.set(env, thiz, + reinterpret_cast<intptr_t>(object_tracker)); +} + +#ifdef __cplusplus +extern "C" { +#endif +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz, + jint width, jint height, + jboolean always_track); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env, + jobject thiz); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jbyteArray frame_data); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jlong timestamp); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2); + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz, + jstring object_id); + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env, + jobject thiz, + jstring object_id); + +JNIEXPORT +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env, + jobject thiz, + jstring object_id); + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env, + jobject thiz, + jstring object_id); + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz, + jstring object_id); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz, + jbyteArray y_data, + jbyteArray uv_data, + jlong timestamp, + jfloatArray vg_matrix_2x3); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz, + jstring object_id); + +JNIEXPORT +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)( + JNIEnv* env, jobject thiz, jfloat scale_factor); + +JNIEXPORT +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)( + JNIEnv* env, jobject thiz, jboolean only_found_); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)( + JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1, + jfloat position_y1, jfloat position_x2, jfloat position_y2, + jfloatArray delta); + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj, + jint view_width, + jint view_height, + jfloatArray delta); + +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)( + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride, + jbyteArray input, jint factor, jbyteArray output); + +#ifdef __cplusplus +} +#endif + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz, + jint width, jint height, + jboolean always_track) { + LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz); + const Size image_size(width, height); + TrackerConfig* const tracker_config = new TrackerConfig(image_size); + tracker_config->always_track = always_track; + + // XXX detector + ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL); + set_object_tracker(env, thiz, tracker); + LOGI("Initialized!"); + + CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker, + "Failure to set hand tracker!"); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env, + jobject thiz) { + delete get_object_tracker(env, thiz); + set_object_tracker(env, thiz, NULL); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jbyteArray frame_data) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1, + x2, y2); + + jboolean iCopied = JNI_FALSE; + + // Copy image into currFrame. + jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied); + + BoundingBox bounding_box(x1, y1, x2, y2); + get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance( + id_str, reinterpret_cast<const uint8*>(pixels), bounding_box); + + env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2, jlong timestamp) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + LOGI( + "Registering the position of %s at %.2f,%.2f,%.2f,%.2f" + " at time %lld", + id_str, x1, y1, x2, y2, static_cast<int64>(timestamp)); + + get_object_tracker(env, thiz)->SetPreviousPositionOfObject( + id_str, BoundingBox(x1, y1, x2, y2), timestamp); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1, + jfloat x2, jfloat y2) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1, + x2, y2); + + get_object_tracker(env, thiz)->SetCurrentPositionOfObject( + id_str, BoundingBox(x1, y1, x2, y2)); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str); + env->ReleaseStringUTFChars(object_id, id_str); + return haveObject; +} + +JNIEXPORT +jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env, + jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str); + env->ReleaseStringUTFChars(object_id, id_str); + return visible; +} + +JNIEXPORT +jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env, + jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + const TrackedObject* const object = + get_object_tracker(env, thiz)->GetObject(id_str); + env->ReleaseStringUTFChars(object_id, id_str); + jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str()); + return model_name; +} + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env, + jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const float correlation = + get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation(); + env->ReleaseStringUTFChars(object_id, id_str); + return correlation; +} + +JNIEXPORT +jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const float match_score = + get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value; + env->ReleaseStringUTFChars(object_id, id_str); + return match_score; +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)( + JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) { + jboolean iCopied = JNI_FALSE; + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + const BoundingBox bounding_box = + get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition(); + env->ReleaseStringUTFChars(object_id, id_str); + + jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied); + bounding_box.CopyToArray(reinterpret_cast<float*>(rect)); + env->ReleaseFloatArrayElements(rect_array, rect, 0); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz, + jbyteArray y_data, + jbyteArray uv_data, + jlong timestamp, + jfloatArray vg_matrix_2x3) { + TimeLog("Starting object tracker"); + + jboolean iCopied = JNI_FALSE; + + float vision_gyro_matrix_array[6]; + jfloat* jmat = NULL; + + if (vg_matrix_2x3 != NULL) { + // Copy the alignment matrix into a float array. + jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied); + for (int i = 0; i < 6; ++i) { + vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]); + } + } + // Copy image into currFrame. + jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied); + jbyte* uv_pixels = + uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL; + + TimeLog("Got elements"); + + // Add the frame to the object tracker object. + get_object_tracker(env, thiz)->NextFrame( + reinterpret_cast<uint8*>(pixels), reinterpret_cast<uint8*>(uv_pixels), + timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL); + + env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT); + + if (uv_data != NULL) { + env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT); + } + + if (vg_matrix_2x3 != NULL) { + env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT); + } + + TimeLog("Released elements"); + + PrintTimeLog(); + ResetTimeLog(); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz, + jstring object_id) { + const char* const id_str = env->GetStringUTFChars(object_id, 0); + + get_object_tracker(env, thiz)->ForgetTarget(id_str); + + env->ReleaseStringUTFChars(object_id, id_str); +} + +JNIEXPORT +jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)( + JNIEnv* env, jobject thiz, jboolean only_found) { + jfloat keypoint_arr[kMaxKeypoints * kKeypointStep]; + + const int number_of_keypoints = + get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr); + + // Create and return the array that will be passed back to Java. + jfloatArray keypoints = + env->NewFloatArray(number_of_keypoints * kKeypointStep); + if (keypoints == NULL) { + LOGE("null array!"); + return NULL; + } + env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep, + keypoint_arr); + + return keypoints; +} + +JNIEXPORT +jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)( + JNIEnv* env, jobject thiz, jfloat scale_factor) { + // 2 bytes to a uint16 and two pairs of xy coordinates per keypoint. + const int bytes_per_keypoint = sizeof(uint16) * 2 * 2; + jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint]; + + const int number_of_keypoints = + get_object_tracker(env, thiz)->GetKeypointsPacked( + reinterpret_cast<uint16*>(keypoint_arr), scale_factor); + + // Create and return the array that will be passed back to Java. + jbyteArray keypoints = + env->NewByteArray(number_of_keypoints * bytes_per_keypoint); + + if (keypoints == NULL) { + LOGE("null array!"); + return NULL; + } + + env->SetByteArrayRegion( + keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr); + + return keypoints; +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)( + JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1, + jfloat position_y1, jfloat position_x2, jfloat position_y2, + jfloatArray delta) { + jfloat point_arr[4]; + + const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox( + BoundingBox(position_x1, position_y1, position_x2, position_y2), + timestamp); + + new_position.CopyToArray(point_arr); + env->SetFloatArrayRegion(delta, 0, 4, point_arr); +} + +JNIEXPORT +void JNICALL OBJECT_TRACKER_METHOD(drawNative)( + JNIEnv* env, jobject thiz, jint view_width, jint view_height, + jfloatArray frame_to_canvas_arr) { + ObjectTracker* object_tracker = get_object_tracker(env, thiz); + if (object_tracker != NULL) { + jfloat* frame_to_canvas = + env->GetFloatArrayElements(frame_to_canvas_arr, NULL); + + object_tracker->Draw(view_width, view_height, frame_to_canvas); + env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas, + JNI_ABORT); + } +} + +JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)( + JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride, + jbyteArray input, jint factor, jbyteArray output) { + if (input == NULL || output == NULL) { + LOGW("Received null arrays, hopefully this is a test!"); + return; + } + + jbyte* const input_array = env->GetByteArrayElements(input, 0); + jbyte* const output_array = env->GetByteArrayElements(output, 0); + + { + tf_tracking::Image<uint8> full_image( + width, height, reinterpret_cast<uint8*>(input_array), false); + + const int new_width = (width + factor - 1) / factor; + const int new_height = (height + factor - 1) / factor; + + tf_tracking::Image<uint8> downsampled_image( + new_width, new_height, reinterpret_cast<uint8*>(output_array), false); + + downsampled_image.DownsampleAveraged(reinterpret_cast<uint8*>(input_array), + row_stride, factor); + } + + env->ReleaseByteArrayElements(input, input_array, JNI_ABORT); + env->ReleaseByteArrayElements(output, output_array, 0); +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.cc b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc new file mode 100644 index 0000000000..fab0a3155d --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.cc @@ -0,0 +1,490 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <math.h> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" + +namespace tf_tracking { + +OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config) + : config_(config), + frame1_(NULL), + frame2_(NULL), + working_size_(config->image_size) {} + + +void OpticalFlow::NextFrame(const ImageData* const image_data) { + // Special case for the first frame: make sure the image ends up in + // frame1_ so that keypoint detection can be done on it if desired. + frame1_ = (frame1_ == NULL) ? image_data : frame2_; + frame2_ = image_data; +} + + +// Static heart of the optical flow computation. +// Lucas Kanade algorithm. +bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8>& img_I, + const Image<uint8>& img_J, + const Image<int32>& I_x, + const Image<int32>& I_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y) { + float g_x = *out_g_x; + float g_y = *out_g_y; + // Get values for frame 1. They remain constant through the inner + // iteration loop. + float vals_I[kFlowArraySize]; + float vals_I_x[kFlowArraySize]; + float vals_I_y[kFlowArraySize]; + + const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1; + const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize); + +#if USE_FIXED_POINT_FLOW + const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1; + const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1; +#else + const float real_x_max = I_x.width_less_one_ - EPSILON; + const float real_y_max = I_x.height_less_one_ - EPSILON; +#endif + + // Get the window around the original point. + const float src_left_real = p_x - kWindowSizeFloat; + const float src_top_real = p_y - kWindowSizeFloat; + float* vals_I_ptr = vals_I; + float* vals_I_x_ptr = vals_I_x; + float* vals_I_y_ptr = vals_I_y; +#if USE_FIXED_POINT_FLOW + // Source integer coordinates. + const int src_left_fixed = RealToFixed1616(src_left_real); + const int src_top_fixed = RealToFixed1616(src_top_real); + + for (int y = 0; y < kPatchSize; ++y) { + const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max); + + for (int x = 0; x < kPatchSize; ++x) { + const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max); + + *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y); + *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y); + *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y); + } + } +#else + for (int y = 0; y < kPatchSize; ++y) { + const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max); + + for (int x = 0; x < kPatchSize; ++x) { + const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max); + + *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos); + *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos); + *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos); + } + } +#endif + + // Compute the spatial gradient matrix about point p. + float G[] = { 0, 0, 0, 0 }; + CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G); + + // Find the inverse of G. + float G_inv[4]; + if (!Invert2x2(G, G_inv)) { + return false; + } + +#if NORMALIZE + const float mean_I = ComputeMean(vals_I, kFlowArraySize); + const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I); +#endif + + // Iterate kNumIterations times or until we converge. + for (int iteration = 0; iteration < kNumIterations; ++iteration) { + // Get values for frame 2. + float vals_J[kFlowArraySize]; + + // Get the window around the destination point. + const float left_real = p_x + g_x - kWindowSizeFloat; + const float top_real = p_y + g_y - kWindowSizeFloat; + float* vals_J_ptr = vals_J; +#if USE_FIXED_POINT_FLOW + // The top-left sub-pixel is set for the current iteration (in 16:16 + // fixed). This is constant over one iteration. + const int left_fixed = RealToFixed1616(left_real); + const int top_fixed = RealToFixed1616(top_real); + + for (int win_y = 0; win_y < kPatchSize; ++win_y) { + const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max); + for (int win_x = 0; win_x < kPatchSize; ++win_x) { + const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max); + *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y); + } + } +#else + for (int win_y = 0; win_y < kPatchSize; ++win_y) { + const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max); + for (int win_x = 0; win_x < kPatchSize; ++win_x) { + const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max); + *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos); + } + } +#endif + +#if NORMALIZE + const float mean_J = ComputeMean(vals_J, kFlowArraySize); + const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J); + + // TODO(andrewharp): Probably better to completely detect and handle the + // "corner case" where the patch is fully outside the image diagonally. + const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f; +#endif + + // Compute image mismatch vector. + float b_x = 0.0f; + float b_y = 0.0f; + + vals_I_ptr = vals_I; + vals_J_ptr = vals_J; + vals_I_x_ptr = vals_I_x; + vals_I_y_ptr = vals_I_y; + + for (int win_y = 0; win_y < kPatchSize; ++win_y) { + for (int win_x = 0; win_x < kPatchSize; ++win_x) { +#if NORMALIZE + // Normalized Image difference. + const float dI = + (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio; +#else + const float dI = *vals_I_ptr++ - *vals_J_ptr++; +#endif + b_x += dI * *vals_I_x_ptr++; + b_y += dI * *vals_I_y_ptr++; + } + } + + // Optical flow... solve n = G^-1 * b + const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y); + const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y); + + // Update best guess with residual displacement from this level and + // iteration. + g_x += n_x; + g_y += n_y; + + // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y); + + // Abort early if we're already below the threshold. + if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) { + break; + } + } // Iteration. + + // Copy value back into output. + *out_g_x = g_x; + *out_g_y = g_y; + return true; +} + + +// Pointwise flow using translational 2dof ESM. +bool OpticalFlow::FindFlowAtPoint_ESM(const Image<uint8>& img_I, + const Image<uint8>& img_J, + const Image<int32>& I_x, + const Image<int32>& I_y, + const Image<int32>& J_x, + const Image<int32>& J_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y) { + float g_x = *out_g_x; + float g_y = *out_g_y; + const float area_inv = 1.0f / static_cast<float>(kFlowArraySize); + + // Get values for frame 1. They remain constant through the inner + // iteration loop. + uint8 vals_I[kFlowArraySize]; + uint8 vals_J[kFlowArraySize]; + int16 src_gradient_x[kFlowArraySize]; + int16 src_gradient_y[kFlowArraySize]; + + // TODO(rspring): try out the IntegerPatchAlign() method once + // the code for that is in ../common. + const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize); + const int src_left_fixed = RealToFixed1616(p_x - wsize_float); + const int src_top_fixed = RealToFixed1616(p_y - wsize_float); + const int patch_size = 2 * kFlowIntegrationWindowSize + 1; + + // Create the keypoint template patch from a subpixel location. + if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, + patch_size, patch_size, vals_I) || + !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, + patch_size, patch_size, + src_gradient_x) || + !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed, + patch_size, patch_size, + src_gradient_y)) { + return false; + } + + int bright_offset = 0; + int sum_diff = 0; + + // The top-left sub-pixel is set for the current iteration (in 16:16 fixed). + // This is constant over one iteration. + int left_fixed = RealToFixed1616(p_x + g_x - wsize_float); + int top_fixed = RealToFixed1616(p_y + g_y - wsize_float); + + // The truncated version gives the most top-left pixel that is used. + int left_trunc = left_fixed >> 16; + int top_trunc = top_fixed >> 16; + + // Compute an initial brightness offset. + if (kDoBrightnessNormalize && + left_trunc >= 0 && top_trunc >= 0 && + (left_trunc + patch_size) < img_J.width_less_one_ && + (top_trunc + patch_size) < img_J.height_less_one_) { + int templ_index = 0; + const uint8* j_row = img_J[top_trunc] + left_trunc; + + const int j_stride = img_J.stride(); + + for (int y = 0; y < patch_size; ++y, j_row += j_stride) { + for (int x = 0; x < patch_size; ++x) { + sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++]; + } + } + + bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv); + } + + // Iterate kNumIterations times or until we go out of image. + for (int iteration = 0; iteration < kNumIterations; ++iteration) { + int jtj[3] = { 0, 0, 0 }; + int jtr[2] = { 0, 0 }; + sum_diff = 0; + + // Extract the target image values. + // Extract the gradient from the target image patch and accumulate to + // the gradient of the source image patch. + if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed, + patch_size, patch_size, + vals_J)) { + break; + } + + const uint8* templ_row = vals_I; + const uint8* extract_row = vals_J; + const int16* src_dx_row = src_gradient_x; + const int16* src_dy_row = src_gradient_y; + + for (int y = 0; y < patch_size; ++y, templ_row += patch_size, + src_dx_row += patch_size, src_dy_row += patch_size, + extract_row += patch_size) { + const int fp_y = top_fixed + (y << 16); + for (int x = 0; x < patch_size; ++x) { + const int fp_x = left_fixed + (x << 16); + int32 target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y); + int32 target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y); + + // Combine the two Jacobians. + // Right-shift by one to account for the fact that we add + // two Jacobians. + int32 dx = (src_dx_row[x] + target_dx) >> 1; + int32 dy = (src_dy_row[x] + target_dy) >> 1; + + // The current residual b - h(q) == extracted - (template + offset) + int32 diff = static_cast<int32>(extract_row[x]) - + static_cast<int32>(templ_row[x]) - + bright_offset; + + jtj[0] += dx * dx; + jtj[1] += dx * dy; + jtj[2] += dy * dy; + + jtr[0] += dx * diff; + jtr[1] += dy * diff; + + sum_diff += diff; + } + } + + const float jtr1_float = static_cast<float>(jtr[0]); + const float jtr2_float = static_cast<float>(jtr[1]); + + // Add some baseline stability to the system. + jtj[0] += kEsmRegularizer; + jtj[2] += kEsmRegularizer; + + const int64 prod1 = static_cast<int64>(jtj[0]) * jtj[2]; + const int64 prod2 = static_cast<int64>(jtj[1]) * jtj[1]; + + // One ESM step. + const float jtj_1[4] = { static_cast<float>(jtj[2]), + static_cast<float>(-jtj[1]), + static_cast<float>(-jtj[1]), + static_cast<float>(jtj[0]) }; + const double det_inv = 1.0 / static_cast<double>(prod1 - prod2); + + g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float); + g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float); + + if (kDoBrightnessNormalize) { + bright_offset += + static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f); + } + + // Update top left position. + left_fixed = RealToFixed1616(p_x + g_x - wsize_float); + top_fixed = RealToFixed1616(p_y + g_y - wsize_float); + + left_trunc = left_fixed >> 16; + top_trunc = top_fixed >> 16; + + // Abort iterations if we go out of borders. + if (left_trunc < 0 || top_trunc < 0 || + (left_trunc + patch_size) >= J_x.width_less_one_ || + (top_trunc + patch_size) >= J_y.height_less_one_) { + break; + } + } // Iteration. + + // Copy value back into output. + *out_g_x = g_x; + *out_g_y = g_y; + return true; +} + + +bool OpticalFlow::FindFlowAtPointReversible( + const int level, const float u_x, const float u_y, + const bool reverse_flow, + float* flow_x, float* flow_y) const { + const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_; + const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_; + + // Images I (prev) and J (next). + const Image<uint8>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2); + const Image<uint8>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2); + + // Computed gradients. + const Image<int32>& I_x = *frame_a.GetSpatialX(level); + const Image<int32>& I_y = *frame_a.GetSpatialY(level); + const Image<int32>& J_x = *frame_b.GetSpatialX(level); + const Image<int32>& J_y = *frame_b.GetSpatialY(level); + + // Shrink factor from original. + const float shrink_factor = (1 << level); + + // Image position vector (p := u^l), scaled for this level. + const float scaled_p_x = u_x / shrink_factor; + const float scaled_p_y = u_y / shrink_factor; + + float scaled_flow_x = *flow_x / shrink_factor; + float scaled_flow_y = *flow_y / shrink_factor; + + // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level, + // scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y); + + const bool success = kUseEsm ? + FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y, + scaled_p_x, scaled_p_y, + &scaled_flow_x, &scaled_flow_y) : + FindFlowAtPoint_LK(img_I, img_J, I_x, I_y, + scaled_p_x, scaled_p_y, + &scaled_flow_x, &scaled_flow_y); + + *flow_x = scaled_flow_x * shrink_factor; + *flow_y = scaled_flow_y * shrink_factor; + + return success; +} + + +bool OpticalFlow::FindFlowAtPointSingleLevel( + const int level, + const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const { + if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) { + return false; + } + + if (filter_by_fb_error) { + const float new_position_x = u_x + *flow_x; + const float new_position_y = u_y + *flow_y; + + float reverse_flow_x = 0.0f; + float reverse_flow_y = 0.0f; + + // Now find the backwards flow and confirm it lines up with the original + // starting point. + if (!FindFlowAtPointReversible(level, new_position_x, new_position_y, + true, + &reverse_flow_x, &reverse_flow_y)) { + LOGE("Backward error!"); + return false; + } + + const float discrepancy_length = + sqrtf(Square(*flow_x + reverse_flow_x) + + Square(*flow_y + reverse_flow_y)); + + const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y)); + + return discrepancy_length < + (kMaxForwardBackwardErrorAllowed * flow_length); + } + + return true; +} + + +// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm. +// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details. +bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const { + const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment, + kNumPyramidLevels - kNumCacheLevels); + + // For every level in the pyramid, update the coordinates of the best match. + for (int l = max_level - 1; l >= 0; --l) { + if (!FindFlowAtPointSingleLevel(l, u_x, u_y, + filter_by_fb_error, flow_x, flow_y)) { + return false; + } + } + + return true; +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.h b/tensorflow/examples/android/jni/object_tracking/optical_flow.h new file mode 100644 index 0000000000..1329927b99 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.h @@ -0,0 +1,111 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" +#include "tensorflow/examples/android/jni/object_tracking/image_data.h" +#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" + +using namespace tensorflow; + +namespace tf_tracking { + +class FlowCache; + +// Class encapsulating all the data and logic necessary for performing optical +// flow. +class OpticalFlow { + public: + explicit OpticalFlow(const OpticalFlowConfig* const config); + + // Add a new frame to the optical flow. Will update all the non-keypoint + // related member variables. + // + // new_frame should be a buffer of grayscale values, one byte per pixel, + // at the original frame_width and frame_height used to initialize the + // OpticalFlow object. Downsampling will be handled internally. + // + // time_stamp should be a time in milliseconds that later calls to this and + // other methods will be relative to. + void NextFrame(const ImageData* const image_data); + + // An implementation of the Lucas-Kanade Optical Flow algorithm. + static bool FindFlowAtPoint_LK(const Image<uint8>& img_I, + const Image<uint8>& img_J, + const Image<int32>& I_x, + const Image<int32>& I_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y); + + // Pointwise flow using translational 2dof ESM. + static bool FindFlowAtPoint_ESM(const Image<uint8>& img_I, + const Image<uint8>& img_J, + const Image<int32>& I_x, + const Image<int32>& I_y, + const Image<int32>& J_x, + const Image<int32>& J_y, + const float p_x, + const float p_y, + float* out_g_x, + float* out_g_y); + + // Finds the flow using a specific level, in either direction. + // If reversed, the coordinates are in the context of the latest + // frame, not the frame before it. + // All coordinates used in parameters are global, not scaled. + bool FindFlowAtPointReversible( + const int level, const float u_x, const float u_y, + const bool reverse_flow, + float* final_x, float* final_y) const; + + // Finds the flow using a specific level, filterable by forward-backward + // error. All coordinates used in parameters are global, not scaled. + bool FindFlowAtPointSingleLevel(const int level, + const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const; + + // Pyramidal optical-flow using all levels. + bool FindFlowAtPointPyramidal(const float u_x, const float u_y, + const bool filter_by_fb_error, + float* flow_x, float* flow_y) const; + + private: + const OpticalFlowConfig* const config_; + + const ImageData* frame1_; + const ImageData* frame2_; + + // Size of the internally allocated images (after original is downsampled). + const Size working_size_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpticalFlow); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/examples/android/jni/object_tracking/sprite.h new file mode 100755 index 0000000000..6240591cf2 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/sprite.h @@ -0,0 +1,205 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ + +#include <GLES/gl.h> +#include <GLES/glext.h> + +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" + +#ifndef __RENDER_OPENGL__ +#error sprite.h should not included if OpenGL is not enabled by platform.h +#endif + +namespace tf_tracking { + +// This class encapsulates the logic necessary to load an render image data +// at the same aspect ratio as the original source. +class Sprite { + public: + // Only create Sprites when you have an OpenGl context. + explicit Sprite(const Image<uint8>& image) { + LoadTexture(image, NULL); + } + + Sprite(const Image<uint8>& image, const BoundingBox* const area) { + LoadTexture(image, area); + } + + // Also, try to only delete a Sprite when holding an OpenGl context. + ~Sprite() { + glDeleteTextures(1, &texture_); + } + + inline int GetWidth() const { + return actual_width_; + } + + inline int GetHeight() const { + return actual_height_; + } + + // Draw the sprite at 0,0 - original width/height in the current reference + // frame. Any transformations desired must be applied before calling this + // function. + void Draw() const { + const float float_width = static_cast<float>(actual_width_); + const float float_height = static_cast<float>(actual_height_); + + // Where it gets rendered to. + const float vertices[] = { 0.0f, 0.0f, 0.0f, + 0.0f, float_height, 0.0f, + float_width, 0.0f, 0.0f, + float_width, float_height, 0.0f, + }; + + // The coordinates the texture gets drawn from. + const float max_x = float_width / texture_width_; + const float max_y = float_height / texture_height_; + const float textureVertices[] = { + 0, 0, + 0, max_y, + max_x, 0, + max_x, max_y, + }; + + glEnable(GL_TEXTURE_2D); + glBindTexture(GL_TEXTURE_2D, texture_); + + glEnableClientState(GL_VERTEX_ARRAY); + glEnableClientState(GL_TEXTURE_COORD_ARRAY); + + glVertexPointer(3, GL_FLOAT, 0, vertices); + glTexCoordPointer(2, GL_FLOAT, 0, textureVertices); + + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + + glDisableClientState(GL_VERTEX_ARRAY); + glDisableClientState(GL_TEXTURE_COORD_ARRAY); + } + + private: + inline int GetNextPowerOfTwo(const int number) const { + int power_of_two = 1; + while (power_of_two < number) { + power_of_two *= 2; + } + return power_of_two; + } + + // TODO(andrewharp): Allow sprites to have their textures reloaded. + void LoadTexture(const Image<uint8>& texture_source, + const BoundingBox* const area) { + glEnable(GL_TEXTURE_2D); + + glGenTextures(1, &texture_); + + glBindTexture(GL_TEXTURE_2D, texture_); + + int left = 0; + int top = 0; + + if (area != NULL) { + // If a sub-region was provided to pull the texture from, use that. + left = area->left_; + top = area->top_; + actual_width_ = area->GetWidth(); + actual_height_ = area->GetHeight(); + } else { + actual_width_ = texture_source.GetWidth(); + actual_height_ = texture_source.GetHeight(); + } + + // The textures must be a power of two, so find the sizes that are large + // enough to contain the image data. + texture_width_ = GetNextPowerOfTwo(actual_width_); + texture_height_ = GetNextPowerOfTwo(actual_height_); + + bool allocated_data = false; + uint8* texture_data; + + // Except in the lucky case where we're not using a sub-region of the + // original image AND the source data has dimensions that are power of two, + // care must be taken to copy data at the appropriate source and destination + // strides so that the final block can be copied directly into texture + // memory. + // TODO(andrewharp): Figure out if data can be pulled directly from the + // source image with some alignment modifications. + if (left != 0 || top != 0 || + actual_width_ != texture_source.GetWidth() || + actual_height_ != texture_source.GetHeight()) { + texture_data = new uint8[actual_width_ * actual_height_]; + + for (int y = 0; y < actual_height_; ++y) { + memcpy(texture_data + actual_width_ * y, + texture_source[top + y] + left, + actual_width_ * sizeof(uint8)); + } + allocated_data = true; + } else { + // Cast away const-ness because for some reason glTexSubImage2D wants + // a non-const data pointer. + texture_data = const_cast<uint8*>(texture_source.data()); + } + + glTexImage2D(GL_TEXTURE_2D, + 0, + GL_LUMINANCE, + texture_width_, + texture_height_, + 0, + GL_LUMINANCE, + GL_UNSIGNED_BYTE, + NULL); + + glPixelStorei(GL_UNPACK_ALIGNMENT, 1); + glTexSubImage2D(GL_TEXTURE_2D, + 0, + 0, + 0, + actual_width_, + actual_height_, + GL_LUMINANCE, + GL_UNSIGNED_BYTE, + texture_data); + + if (allocated_data) { + delete(texture_data); + } + + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR); + } + + // The id for the texture on the GPU. + GLuint texture_; + + // The width and height to be used for display purposes, referring to the + // dimensions of the original texture. + int actual_width_; + int actual_height_; + + // The allocated dimensions of the texture data, which must be powers of 2. + int texture_width_; + int texture_height_; + + TF_DISALLOW_COPY_AND_ASSIGN(Sprite); +}; + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.cc b/tensorflow/examples/android/jni/object_tracking/time_log.cc new file mode 100644 index 0000000000..cb1f3c23c8 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/time_log.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/time_log.h" + +using namespace tensorflow; + +#ifdef LOG_TIME +// Storage for logging functionality. +int num_time_logs = 0; +LogEntry time_logs[NUM_LOGS]; + +int num_avg_entries = 0; +AverageEntry avg_entries[NUM_LOGS]; +#endif diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.h b/tensorflow/examples/android/jni/object_tracking/time_log.h new file mode 100644 index 0000000000..ec539a1b3b --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/time_log.h @@ -0,0 +1,138 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions for performance profiling. + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/examples/android/jni/object_tracking/log_streaming.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +#ifdef LOG_TIME + +// Blend constant for running average. +#define ALPHA 0.98f +#define NUM_LOGS 100 + +struct LogEntry { + const char* id; + int64 time_stamp; +}; + +struct AverageEntry { + const char* id; + float average_duration; +}; + +// Storage for keeping track of this frame's values. +extern int num_time_logs; +extern LogEntry time_logs[NUM_LOGS]; + +// Storage for keeping track of average values (each entry may not be printed +// out each frame). +extern AverageEntry avg_entries[NUM_LOGS]; +extern int num_avg_entries; + +// Call this at the start of a logging phase. +inline static void ResetTimeLog() { + num_time_logs = 0; +} + + +// Log a message to be printed out when printTimeLog is called, along with the +// amount of time in ms that has passed since the last call to this function. +inline static void TimeLog(const char* const str) { + LOGV("%s", str); + if (num_time_logs >= NUM_LOGS) { + LOGE("Out of log entries!"); + return; + } + + time_logs[num_time_logs].id = str; + time_logs[num_time_logs].time_stamp = CurrentThreadTimeNanos(); + ++num_time_logs; +} + + +inline static float Blend(float old_val, float new_val) { + return ALPHA * old_val + (1.0f - ALPHA) * new_val; +} + + +inline static float UpdateAverage(const char* str, const float new_val) { + for (int entry_num = 0; entry_num < num_avg_entries; ++entry_num) { + AverageEntry* const entry = avg_entries + entry_num; + if (str == entry->id) { + entry->average_duration = Blend(entry->average_duration, new_val); + return entry->average_duration; + } + } + + if (num_avg_entries >= NUM_LOGS) { + LOGE("Too many log entries!"); + } + + // If it wasn't there already, add it. + avg_entries[num_avg_entries].id = str; + avg_entries[num_avg_entries].average_duration = new_val; + ++num_avg_entries; + + return new_val; +} + + +// Prints out all the timeLog statements in chronological order with the +// interval that passed between subsequent statements. The total time between +// the first and last statements is printed last. +inline static void PrintTimeLog() { + LogEntry* last_time = time_logs; + + float average_running_total = 0.0f; + + for (int i = 0; i < num_time_logs; ++i) { + LogEntry* const this_time = time_logs + i; + + const float curr_time = + (this_time->time_stamp - last_time->time_stamp) / 1000000.0f; + + const float avg_time = UpdateAverage(this_time->id, curr_time); + average_running_total += avg_time; + + LOGD("%32s: %6.3fms %6.4fms", this_time->id, curr_time, avg_time); + last_time = this_time; + } + + const float total_time = + (last_time->time_stamp - time_logs->time_stamp) / 1000000.0f; + + LOGD("TOTAL TIME: %6.3fms %6.4fms\n", + total_time, average_running_total); + LOGD(" "); +} +#else +inline static void ResetTimeLog() {} + +inline static void TimeLog(const char* const str) { + LOGV("%s", str); +} + +inline static void PrintTimeLog() {} +#endif + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.cc b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc new file mode 100644 index 0000000000..823fb3a90e --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.cc @@ -0,0 +1,163 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" + +namespace tf_tracking { + +static const float kInitialDistance = 20.0f; + +static void InitNormalized(const Image<uint8>& src_image, + const BoundingBox& position, + Image<float>* const dst_image) { + BoundingBox scaled_box(position); + CopyArea(src_image, scaled_box, dst_image); + NormalizeImage(dst_image); +} + +TrackedObject::TrackedObject(const std::string& id, + const Image<uint8>& image, + const BoundingBox& bounding_box, + ObjectModelBase* const model) + : id_(id), + last_known_position_(bounding_box), + last_detection_position_(bounding_box), + position_last_computed_time_(-1), + object_model_(model), + last_detection_thumbnail_(kNormalizedThumbnailSize, + kNormalizedThumbnailSize), + last_frame_thumbnail_(kNormalizedThumbnailSize, kNormalizedThumbnailSize), + tracked_correlation_(0.0f), + tracked_match_score_(0.0), + num_consecutive_frames_below_threshold_(0), + allowable_detection_distance_(Square(kInitialDistance)) { + InitNormalized(image, bounding_box, &last_detection_thumbnail_); +} + +TrackedObject::~TrackedObject() {} + +void TrackedObject::UpdatePosition(const BoundingBox& new_position, + const int64 timestamp, + const ImageData& image_data, + const bool authoratative) { + last_known_position_ = new_position; + position_last_computed_time_ = timestamp; + + InitNormalized(*image_data.GetImage(), new_position, &last_frame_thumbnail_); + + const float last_localization_correlation = ComputeCrossCorrelation( + last_detection_thumbnail_.data(), + last_frame_thumbnail_.data(), + last_frame_thumbnail_.data_size_); + LOGV("Tracked correlation to last localization: %.6f", + last_localization_correlation); + + // Correlation to object model, if it exists. + if (object_model_ != NULL) { + tracked_correlation_ = + object_model_->GetMaxCorrelation(last_frame_thumbnail_); + LOGV("Tracked correlation to model: %.6f", + tracked_correlation_); + + tracked_match_score_ = + object_model_->GetMatchScore(new_position, image_data); + LOGV("Tracked match score with model: %.6f", + tracked_match_score_.value); + } else { + // If there's no model to check against, set the tracked correlation to + // simply be the correlation to the last set position. + tracked_correlation_ = last_localization_correlation; + tracked_match_score_ = MatchScore(0.0f); + } + + // Determine if it's still being tracked. + if (tracked_correlation_ >= kMinimumCorrelationForTracking && + tracked_match_score_ >= kMinimumMatchScore) { + num_consecutive_frames_below_threshold_ = 0; + + if (object_model_ != NULL) { + object_model_->TrackStep(last_known_position_, *image_data.GetImage(), + *image_data.GetIntegralImage(), authoratative); + } + } else if (tracked_match_score_ < kMatchScoreForImmediateTermination) { + if (num_consecutive_frames_below_threshold_ < 1000) { + LOGD("Tracked match score is way too low (%.6f), aborting track.", + tracked_match_score_.value); + } + + // Add an absurd amount of missed frames so that all heuristics will + // consider it a lost track. + num_consecutive_frames_below_threshold_ += 1000; + + if (object_model_ != NULL) { + object_model_->TrackLost(); + } + } else { + ++num_consecutive_frames_below_threshold_; + allowable_detection_distance_ *= 1.1f; + } +} + +void TrackedObject::OnDetection(ObjectModelBase* const model, + const BoundingBox& detection_position, + const MatchScore match_score, + const int64 timestamp, + const ImageData& image_data) { + const float overlap = detection_position.PascalScore(last_known_position_); + if (overlap > kPositionOverlapThreshold) { + // If the position agreement with the current tracked position is good + // enough, lock all the current unlocked examples. + object_model_->TrackConfirmed(); + num_consecutive_frames_below_threshold_ = 0; + } + + // Before relocalizing, make sure the new proposed position is better than + // the existing position by a small amount to prevent thrashing. + if (match_score <= tracked_match_score_ + kMatchScoreBuffer) { + LOGI("Not relocalizing since new match is worse: %.6f < %.6f + %.6f", + match_score.value, tracked_match_score_.value, + kMatchScoreBuffer.value); + return; + } + + LOGI("Relocalizing! From (%.1f, %.1f)[%.1fx%.1f] to " + "(%.1f, %.1f)[%.1fx%.1f]: %.6f > %.6f", + last_known_position_.left_, last_known_position_.top_, + last_known_position_.GetWidth(), last_known_position_.GetHeight(), + detection_position.left_, detection_position.top_, + detection_position.GetWidth(), detection_position.GetHeight(), + match_score.value, tracked_match_score_.value); + + if (overlap < kPositionOverlapThreshold) { + // The path might be good, it might be bad, but it's no longer a path + // since we're moving the box to a new position, so just nuke it from + // orbit to be safe. + object_model_->TrackLost(); + } + + object_model_ = model; + + // Reset the last detected appearance. + InitNormalized( + *image_data.GetImage(), detection_position, &last_detection_thumbnail_); + + num_consecutive_frames_below_threshold_ = 0; + last_detection_position_ = detection_position; + + UpdatePosition(detection_position, timestamp, image_data, false); + allowable_detection_distance_ = Square(kInitialDistance); +} + +} // namespace tf_tracking diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.h b/tensorflow/examples/android/jni/object_tracking/tracked_object.h new file mode 100644 index 0000000000..5580cd2b89 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.h @@ -0,0 +1,191 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ + +#ifdef __RENDER_OPENGL__ +#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h" +#endif +#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" + +namespace tf_tracking { + +// A TrackedObject is a specific instance of an ObjectModel, with a known +// position in the world. +// It provides the last known position and number of recent detection failures, +// in addition to the more general appearance data associated with the object +// class (which is in ObjectModel). +// TODO(andrewharp): Make getters/setters follow styleguide. +class TrackedObject { + public: + TrackedObject(const std::string& id, + const Image<uint8>& image, + const BoundingBox& bounding_box, + ObjectModelBase* const model); + + ~TrackedObject(); + + void UpdatePosition(const BoundingBox& new_position, + const int64 timestamp, + const ImageData& image_data, + const bool authoratative); + + // This method is called when the tracked object is detected at a + // given position, and allows the associated Model to grow and/or prune + // itself based on where the detection occurred. + void OnDetection(ObjectModelBase* const model, + const BoundingBox& detection_position, + const MatchScore match_score, + const int64 timestamp, + const ImageData& image_data); + + // Called when there's no detection of the tracked object. This will cause + // a tracking failure after enough consecutive failures if the area under + // the current bounding box also doesn't meet a minimum correlation threshold + // with the model. + void OnDetectionFailure() {} + + inline bool IsVisible() const { + return tracked_correlation_ >= kMinimumCorrelationForTracking || + num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures; + } + + inline float GetCorrelation() { + return tracked_correlation_; + } + + inline MatchScore GetMatchScore() { + return tracked_match_score_; + } + + inline BoundingBox GetPosition() const { + return last_known_position_; + } + + inline BoundingBox GetLastDetectionPosition() const { + return last_detection_position_; + } + + inline const ObjectModelBase* GetModel() const { + return object_model_; + } + + inline const std::string& GetName() const { + return id_; + } + + inline void Draw() const { +#ifdef __RENDER_OPENGL__ + if (tracked_correlation_ < kMinimumCorrelationForTracking) { + glColor4f(MAX(0.0f, -tracked_correlation_), + MAX(0.0f, tracked_correlation_), + 0.0f, + 1.0f); + } else { + glColor4f(MAX(0.0f, -tracked_correlation_), + MAX(0.0f, tracked_correlation_), + 1.0f, + 1.0f); + } + + // Render the box itself. + BoundingBox temp_box(last_known_position_); + DrawBox(temp_box); + + // Render a box inside this one (in case the actual box is hidden). + const float kBufferSize = 1.0f; + temp_box.left_ -= kBufferSize; + temp_box.top_ -= kBufferSize; + temp_box.right_ += kBufferSize; + temp_box.bottom_ += kBufferSize; + DrawBox(temp_box); + + // Render one outside as well. + temp_box.left_ -= -2.0f * kBufferSize; + temp_box.top_ -= -2.0f * kBufferSize; + temp_box.right_ += -2.0f * kBufferSize; + temp_box.bottom_ += -2.0f * kBufferSize; + DrawBox(temp_box); +#endif + } + + // Get current object's num_consecutive_frames_below_threshold_. + inline int64 GetNumConsecutiveFramesBelowThreshold() { + return num_consecutive_frames_below_threshold_; + } + + // Reset num_consecutive_frames_below_threshold_ to 0. + inline void resetNumConsecutiveFramesBelowThreshold() { + num_consecutive_frames_below_threshold_ = 0; + } + + inline float GetAllowableDistanceSquared() const { + return allowable_detection_distance_; + } + + private: + // The unique id used throughout the system to identify this + // tracked object. + const std::string id_; + + // The last known position of the object. + BoundingBox last_known_position_; + + // The last known position of the object. + BoundingBox last_detection_position_; + + // When the position was last computed. + int64 position_last_computed_time_; + + // The object model this tracked object is representative of. + ObjectModelBase* object_model_; + + Image<float> last_detection_thumbnail_; + + Image<float> last_frame_thumbnail_; + + // The correlation of the object model with the preview frame at its last + // tracked position. + float tracked_correlation_; + + MatchScore tracked_match_score_; + + // The number of consecutive frames that the tracked position for this object + // has been under the correlation threshold. + int num_consecutive_frames_below_threshold_; + + float allowable_detection_distance_; + + friend std::ostream& operator<<(std::ostream& stream, + const TrackedObject& tracked_object); + + TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject); +}; + +inline std::ostream& operator<<(std::ostream& stream, + const TrackedObject& tracked_object) { + stream << tracked_object.id_ + << " " << tracked_object.last_known_position_ + << " " << tracked_object.position_last_computed_time_ + << " " << tracked_object.num_consecutive_frames_below_threshold_ + << " " << tracked_object.object_model_ + << " " << tracked_object.tracked_correlation_; + return stream; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/utils.h b/tensorflow/examples/android/jni/object_tracking/utils.h new file mode 100644 index 0000000000..cbdfc408c6 --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/utils.h @@ -0,0 +1,386 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ + +#include <math.h> +#include <stdlib.h> +#include <time.h> + +#include <cmath> // for std::abs(float) + +#ifndef HAVE_CLOCK_GETTIME +// Use gettimeofday() instead of clock_gettime(). +#include <sys/time.h> +#endif // ifdef HAVE_CLOCK_GETTIME + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using namespace tensorflow; + +// TODO(andrewharp): clean up these macros to use the codebase statndard. + +// A very small number, generally used as the tolerance for accumulated +// floating point errors in bounds-checks. +#define EPSILON 0.00001f + +#define SAFE_DELETE(pointer) {\ + if ((pointer) != NULL) {\ + LOGV("Safe deleting pointer: %s", #pointer);\ + delete (pointer);\ + (pointer) = NULL;\ + } else {\ + LOGV("Pointer already null: %s", #pointer);\ + }\ +} + + +#ifdef __GOOGLE__ + +#define CHECK_ALWAYS(condition, format, ...) {\ + CHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\ +} + +#define SCHECK(condition, format, ...) {\ + DCHECK(condition) << StringPrintf(format, ##__VA_ARGS__);\ +} + +#else + +#define CHECK_ALWAYS(condition, format, ...) {\ + if (!(condition)) {\ + LOGE("CHECK FAILED (%s): " format, #condition, ##__VA_ARGS__);\ + abort();\ + }\ +} + +#ifdef SANITY_CHECKS +#define SCHECK(condition, format, ...) {\ + CHECK_ALWAYS(condition, format, ##__VA_ARGS__);\ +} +#else +#define SCHECK(condition, format, ...) {} +#endif // SANITY_CHECKS + +#endif // __GOOGLE__ + + +#ifndef MAX +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) +#endif +#ifndef MIN +#define MIN(a, b) (((a) > (b)) ? (b) : (a)) +#endif + + + +inline static int64 CurrentThreadTimeNanos() { +#ifdef HAVE_CLOCK_GETTIME + struct timespec tm; + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &tm); + return tm.tv_sec * 1000000000LL + tm.tv_nsec; +#else + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec * 1000000000 + tv.tv_usec * 1000; +#endif +} + + +inline static int64 CurrentRealTimeMillis() { +#ifdef HAVE_CLOCK_GETTIME + struct timespec tm; + clock_gettime(CLOCK_MONOTONIC, &tm); + return tm.tv_sec * 1000LL + tm.tv_nsec / 1000000LL; +#else + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec * 1000 + tv.tv_usec / 1000; +#endif +} + + +template<typename T> +inline static T Square(const T a) { + return a * a; +} + + +template<typename T> +inline static T Clip(const T a, const T floor, const T ceil) { + SCHECK(ceil >= floor, "Bounds mismatch!"); + return (a <= floor) ? floor : ((a >= ceil) ? ceil : a); +} + + +template<typename T> +inline static int Floor(const T a) { + return static_cast<int>(a); +} + + +template<typename T> +inline static int Ceil(const T a) { + return Floor(a) + 1; +} + + +template<typename T> +inline static bool InRange(const T a, const T min, const T max) { + return (a >= min) && (a <= max); +} + + +inline static bool ValidIndex(const int a, const int max) { + return (a >= 0) && (a < max); +} + + +inline bool NearlyEqual(const float a, const float b, const float tolerance) { + return std::abs(a - b) < tolerance; +} + + +inline bool NearlyEqual(const float a, const float b) { + return NearlyEqual(a, b, EPSILON); +} + + +template<typename T> +inline static int Round(const float a) { + return (a - static_cast<float>(floor(a) > 0.5f) ? ceil(a) : floor(a)); +} + + +template<typename T> +inline static void Swap(T* const a, T* const b) { + // Cache out the VALUE of what's at a. + T tmp = *a; + *a = *b; + + *b = tmp; +} + + +static inline float randf() { + return rand() / static_cast<float>(RAND_MAX); +} + +static inline float randf(const float min_value, const float max_value) { + return randf() * (max_value - min_value) + min_value; +} + +static inline uint16 RealToFixed115(const float real_number) { + SCHECK(InRange(real_number, 0.0f, 2048.0f), + "Value out of range! %.2f", real_number); + + static const float kMult = 32.0f; + const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f; + return static_cast<uint16>(real_number * kMult + round_add); +} + +static inline float FixedToFloat115(const uint16 fp_number) { + const float kDiv = 32.0f; + return (static_cast<float>(fp_number) / kDiv); +} + +static inline int RealToFixed1616(const float real_number) { + static const float kMult = 65536.0f; + SCHECK(InRange(real_number, -kMult, kMult), + "Value out of range! %.2f", real_number); + + const float round_add = (real_number > 0.0f) ? 0.5f : -0.5f; + return static_cast<int>(real_number * kMult + round_add); +} + +static inline float FixedToFloat1616(const int fp_number) { + const float kDiv = 65536.0f; + return (static_cast<float>(fp_number) / kDiv); +} + +template<typename T> +// produces numbers in range [0,2*M_PI] (rather than -PI,PI) +inline T FastAtan2(const T y, const T x) { + static const T coeff_1 = (T)(M_PI / 4.0); + static const T coeff_2 = (T)(3.0 * coeff_1); + const T abs_y = fabs(y); + T angle; + if (x >= 0) { + T r = (x - abs_y) / (x + abs_y); + angle = coeff_1 - coeff_1 * r; + } else { + T r = (x + abs_y) / (abs_y - x); + angle = coeff_2 - coeff_1 * r; + } + static const T PI_2 = 2.0 * M_PI; + return y < 0 ? PI_2 - angle : angle; +} + +#define NELEMS(X) (sizeof(X) / sizeof(X[0])) + +namespace tf_tracking { + +#ifdef __ARM_NEON +float ComputeMeanNeon(const float* const values, const int num_vals); + +float ComputeStdDevNeon(const float* const values, const int num_vals, + const float mean); + +float ComputeWeightedMeanNeon(const float* const values, + const float* const weights, const int num_vals); + +float ComputeCrossCorrelationNeon(const float* const values1, + const float* const values2, + const int num_vals); +#endif + +inline float ComputeMeanCpu(const float* const values, const int num_vals) { + // Get mean. + float sum = values[0]; + for (int i = 1; i < num_vals; ++i) { + sum += values[i]; + } + return sum / static_cast<float>(num_vals); +} + + +inline float ComputeMean(const float* const values, const int num_vals) { + return +#ifdef __ARM_NEON + (num_vals >= 8) ? ComputeMeanNeon(values, num_vals) : +#endif + ComputeMeanCpu(values, num_vals); +} + + +inline float ComputeStdDevCpu(const float* const values, + const int num_vals, + const float mean) { + // Get Std dev. + float squared_sum = 0.0f; + for (int i = 0; i < num_vals; ++i) { + squared_sum += Square(values[i] - mean); + } + return sqrt(squared_sum / static_cast<float>(num_vals)); +} + + +inline float ComputeStdDev(const float* const values, + const int num_vals, + const float mean) { + return +#ifdef __ARM_NEON + (num_vals >= 8) ? ComputeStdDevNeon(values, num_vals, mean) : +#endif + ComputeStdDevCpu(values, num_vals, mean); +} + + +// TODO(andrewharp): Accelerate with NEON. +inline float ComputeWeightedMean(const float* const values, + const float* const weights, + const int num_vals) { + float sum = 0.0f; + float total_weight = 0.0f; + for (int i = 0; i < num_vals; ++i) { + sum += values[i] * weights[i]; + total_weight += weights[i]; + } + return sum / num_vals; +} + + +inline float ComputeCrossCorrelationCpu(const float* const values1, + const float* const values2, + const int num_vals) { + float sxy = 0.0f; + for (int offset = 0; offset < num_vals; ++offset) { + sxy += values1[offset] * values2[offset]; + } + + const float cross_correlation = sxy / num_vals; + + return cross_correlation; +} + + +inline float ComputeCrossCorrelation(const float* const values1, + const float* const values2, + const int num_vals) { + return +#ifdef __ARM_NEON + (num_vals >= 8) ? ComputeCrossCorrelationNeon(values1, values2, num_vals) + : +#endif + ComputeCrossCorrelationCpu(values1, values2, num_vals); +} + + +inline void NormalizeNumbers(float* const values, const int num_vals) { + // Find the mean and then subtract so that the new mean is 0.0. + const float mean = ComputeMean(values, num_vals); + VLOG(2) << "Mean is " << mean; + float* curr_data = values; + for (int i = 0; i < num_vals; ++i) { + *curr_data -= mean; + curr_data++; + } + + // Now divide by the std deviation so the new standard deviation is 1.0. + // The numbers might all be identical (and thus shifted to 0.0 now), + // so only scale by the standard deviation if this is not the case. + const float std_dev = ComputeStdDev(values, num_vals, 0.0f); + if (std_dev > 0.0f) { + VLOG(2) << "Std dev is " << std_dev; + curr_data = values; + for (int i = 0; i < num_vals; ++i) { + *curr_data /= std_dev; + curr_data++; + } + } +} + + +// Returns the determinant of a 2x2 matrix. +template<class T> +inline T FindDeterminant2x2(const T* const a) { + // Determinant: (ad - bc) + return a[0] * a[3] - a[1] * a[2]; +} + + +// Finds the inverse of a 2x2 matrix. +// Returns true upon success, false if the matrix is not invertible. +template<class T> +inline bool Invert2x2(const T* const a, float* const a_inv) { + const float det = static_cast<float>(FindDeterminant2x2(a)); + if (fabs(det) < EPSILON) { + return false; + } + const float inv_det = 1.0f / det; + + a_inv[0] = inv_det * static_cast<float>(a[3]); // d + a_inv[1] = inv_det * static_cast<float>(-a[1]); // -b + a_inv[2] = inv_det * static_cast<float>(-a[2]); // -c + a_inv[3] = inv_det * static_cast<float>(a[0]); // a + + return true; +} + +} // namespace tf_tracking + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_ diff --git a/tensorflow/examples/android/jni/object_tracking/utils_neon.cc b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc new file mode 100755 index 0000000000..5a5250e32e --- /dev/null +++ b/tensorflow/examples/android/jni/object_tracking/utils_neon.cc @@ -0,0 +1,151 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// NEON implementations of Image methods for compatible devices. Control +// should never enter this compilation unit on incompatible devices. + +#ifdef __ARM_NEON + +#include <arm_neon.h> + +#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" +#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/examples/android/jni/object_tracking/utils.h" + +namespace tf_tracking { + +inline static float GetSum(const float32x4_t& values) { + static float32_t summed_values[4]; + vst1q_f32(summed_values, values); + return summed_values[0] + + summed_values[1] + + summed_values[2] + + summed_values[3]; +} + + +float ComputeMeanNeon(const float* const values, const int num_vals) { + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); + + const float32_t* const arm_vals = (const float32_t* const) values; + float32x4_t accum = vdupq_n_f32(0.0f); + + int offset = 0; + for (; offset <= num_vals - 4; offset += 4) { + accum = vaddq_f32(accum, vld1q_f32(&arm_vals[offset])); + } + + // Pull the accumulated values into a single variable. + float sum = GetSum(accum); + + // Get the remaining 1 to 3 values. + for (; offset < num_vals; ++offset) { + sum += values[offset]; + } + + const float mean_neon = sum / static_cast<float>(num_vals); + +#ifdef SANITY_CHECKS + const float mean_cpu = ComputeMeanCpu(values, num_vals); + SCHECK(NearlyEqual(mean_neon, mean_cpu, EPSILON * num_vals), + "Neon mismatch with CPU mean! %.10f vs %.10f", + mean_neon, mean_cpu); +#endif + + return mean_neon; +} + + +float ComputeStdDevNeon(const float* const values, + const int num_vals, const float mean) { + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); + + const float32_t* const arm_vals = (const float32_t* const) values; + const float32x4_t mean_vec = vdupq_n_f32(-mean); + + float32x4_t accum = vdupq_n_f32(0.0f); + + int offset = 0; + for (; offset <= num_vals - 4; offset += 4) { + const float32x4_t deltas = + vaddq_f32(mean_vec, vld1q_f32(&arm_vals[offset])); + + accum = vmlaq_f32(accum, deltas, deltas); + } + + // Pull the accumulated values into a single variable. + float squared_sum = GetSum(accum); + + // Get the remaining 1 to 3 values. + for (; offset < num_vals; ++offset) { + squared_sum += Square(values[offset] - mean); + } + + const float std_dev_neon = sqrt(squared_sum / static_cast<float>(num_vals)); + +#ifdef SANITY_CHECKS + const float std_dev_cpu = ComputeStdDevCpu(values, num_vals, mean); + SCHECK(NearlyEqual(std_dev_neon, std_dev_cpu, EPSILON * num_vals), + "Neon mismatch with CPU std dev! %.10f vs %.10f", + std_dev_neon, std_dev_cpu); +#endif + + return std_dev_neon; +} + + +float ComputeCrossCorrelationNeon(const float* const values1, + const float* const values2, + const int num_vals) { + SCHECK(num_vals >= 8, "Not enough values to merit NEON: %d", num_vals); + + const float32_t* const arm_vals1 = (const float32_t* const) values1; + const float32_t* const arm_vals2 = (const float32_t* const) values2; + + float32x4_t accum = vdupq_n_f32(0.0f); + + int offset = 0; + for (; offset <= num_vals - 4; offset += 4) { + accum = vmlaq_f32(accum, + vld1q_f32(&arm_vals1[offset]), + vld1q_f32(&arm_vals2[offset])); + } + + // Pull the accumulated values into a single variable. + float sxy = GetSum(accum); + + // Get the remaining 1 to 3 values. + for (; offset < num_vals; ++offset) { + sxy += values1[offset] * values2[offset]; + } + + const float cross_correlation_neon = sxy / num_vals; + +#ifdef SANITY_CHECKS + const float cross_correlation_cpu = + ComputeCrossCorrelationCpu(values1, values2, num_vals); + SCHECK(NearlyEqual(cross_correlation_neon, cross_correlation_cpu, + EPSILON * num_vals), + "Neon mismatch with CPU cross correlation! %.10f vs %.10f", + cross_correlation_neon, cross_correlation_cpu); +#endif + + return cross_correlation_neon; +} + +} // namespace tf_tracking + +#endif // __ARM_NEON diff --git a/tensorflow/examples/android/proto/box_coder.proto b/tensorflow/examples/android/proto/box_coder.proto new file mode 100644 index 0000000000..8576294110 --- /dev/null +++ b/tensorflow/examples/android/proto/box_coder.proto @@ -0,0 +1,42 @@ +syntax = "proto2"; + +package org_tensorflow_demo; + +// Prior for a single feature (like minimum x coordinate, width, area, etc.) +message BoxCoderPrior { + optional float mean = 1 [default = 0.0]; + optional float stddev = 2 [default = 1.0]; +}; + +// Box encoding/decoding configuration for a single box. +message BoxCoderOptions { + // Number of priors must match the number of values used to encoded + // values which is derived from the use_... flags below. + repeated BoxCoderPrior priors = 1; + + // Minimum/maximum X/Y of the four corners are used as features. + // Order: MinX, MinY, MaxX, MaxY. + // Number of values: 4. + optional bool use_corners = 2 [default = true]; + + // Width and height of the box in this order. + // Number of values: 2. + optional bool use_width_height = 3 [default = false]; + + // Coordinates of the center of the box. + // Order: X, Y. + // Number of values: 2. + optional bool use_center = 4 [default = false]; + + // Area of the box. + // Number of values: 1. + optional bool use_area = 5 [default = false]; +}; + +// Options for MultiBoxCoder which is a encoder/decoder for a fixed number of +// boxes. +// A list of BoxCoderOptions that allows for storing multiple box coder options +// in a single file. +message MultiBoxCoderOptions { + repeated BoxCoderOptions box_coder = 1; +}; diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml new file mode 100644 index 0000000000..674f25785a --- /dev/null +++ b/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml @@ -0,0 +1,30 @@ +<?xml version="1.0" encoding="utf-8"?><!-- + Copyright 2016 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android" + android:layout_width="match_parent" + android:layout_height="match_parent"> + + <org.tensorflow.demo.AutoFitTextureView + android:id="@+id/texture" + android:layout_width="wrap_content" + android:layout_height="wrap_content"/> + + <org.tensorflow.demo.OverlayView + android:id="@+id/overlay" + android:layout_width="match_parent" + android:layout_height="match_parent"/> + +</FrameLayout> diff --git a/tensorflow/examples/android/res/values/base-strings.xml b/tensorflow/examples/android/res/values/base-strings.xml index 93cfe0dac2..f6c57d5030 100644 --- a/tensorflow/examples/android/res/values/base-strings.xml +++ b/tensorflow/examples/android/res/values/base-strings.xml @@ -1,6 +1,6 @@ <?xml version="1.0" encoding="UTF-8"?> <!-- - Copyright 2013 The TensorFlow Authors. All Rights Reserved. + Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,5 +17,6 @@ <resources> <string name="app_name">TensorFlow Demo</string> - <string name="activity_name_classification">TF Classification</string> + <string name="activity_name_classification">TF Classify</string> + <string name="activity_name_detection">TF Detect</string> </resources> diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java index e498c9e28f..2f16ded6c2 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java @@ -17,7 +17,6 @@ package org.tensorflow.demo; import android.graphics.Bitmap; import android.graphics.RectF; - import java.util.List; /** @@ -44,10 +43,8 @@ public interface Classifier { */ private final Float confidence; - /** - * Optional location within the source image for the location of the recognized object. - */ - private final RectF location; + /** Optional location within the source image for the location of the recognized object. */ + private RectF location; public Recognition( final String id, final String title, final Float confidence, final RectF location) { @@ -73,6 +70,10 @@ public interface Classifier { return new RectF(location); } + public void setLocation(RectF location) { + this.location = location; + } + @Override public String toString() { String resultString = ""; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java new file mode 100644 index 0000000000..d75136485a --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -0,0 +1,317 @@ +/* + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.demo; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import android.graphics.Paint.Style; +import android.graphics.RectF; +import android.media.Image; +import android.media.Image.Plane; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Size; +import android.util.TypedValue; +import android.view.Display; +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Vector; +import org.tensorflow.demo.OverlayView.DrawCallback; +import org.tensorflow.demo.env.BorderedText; +import org.tensorflow.demo.env.ImageUtils; +import org.tensorflow.demo.env.Logger; +import org.tensorflow.demo.tracking.MultiBoxTracker; + +/** + * An activity that uses a TensorFlowMultiboxDetector and ObjectTracker to detect and then track + * objects. + */ +public class DetectorActivity extends CameraActivity implements OnImageAvailableListener { + private static final Logger LOGGER = new Logger(); + + private static final int NUM_LOCATIONS = 784; + private static final int INPUT_SIZE = 224; + private static final int IMAGE_MEAN = 128; + private static final float IMAGE_STD = 128; + private static final String INPUT_NAME = "ResizeBilinear"; + private static final String OUTPUT_NAMES = "output_locations/Reshape,output_scores/Reshape"; + + private static final String MODEL_FILE = "file:///android_asset/multibox_model.pb"; + private static final String LOCATION_FILE = "file:///android_asset/multibox_location_priors.pb"; + + // Minimum detection confidence to track a detection. + private static final float MINIMUM_CONFIDENCE = 0.1f; + + private static final boolean SAVE_PREVIEW_BITMAP = false; + + private static final boolean MAINTAIN_ASPECT = false; + + private static final float TEXT_SIZE_DIP = 18; + + private Integer sensorOrientation; + + private TensorFlowMultiBoxDetector detector; + + private int previewWidth = 0; + private int previewHeight = 0; + private byte[][] yuvBytes; + private int[] rgbBytes = null; + private Bitmap rgbFrameBitmap = null; + private Bitmap croppedBitmap = null; + + private boolean computing = false; + + private long timestamp = 0; + + private Matrix frameToCropTransform; + private Matrix cropToFrameTransform; + + private Bitmap cropCopyBitmap; + + private MultiBoxTracker tracker; + + private byte[] luminance; + + private BorderedText borderedText; + + private long lastProcessingTimeMs; + + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + final float textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + borderedText = new BorderedText(textSizePx); + + tracker = new MultiBoxTracker(getResources().getDisplayMetrics()); + + detector = new TensorFlowMultiBoxDetector(); + try { + detector.initializeTensorFlow( + getAssets(), + MODEL_FILE, + LOCATION_FILE, + NUM_LOCATIONS, + INPUT_SIZE, + IMAGE_MEAN, + IMAGE_STD, + INPUT_NAME, + OUTPUT_NAMES); + } catch (final IOException e) { + LOGGER.e(e, "Exception!"); + } + + previewWidth = size.getWidth(); + previewHeight = size.getHeight(); + + final Display display = getWindowManager().getDefaultDisplay(); + final int screenOrientation = display.getRotation(); + + LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation); + + sensorOrientation = rotation + screenOrientation; + + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); + rgbBytes = new int[previewWidth * previewHeight]; + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); + croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888); + + frameToCropTransform = + ImageUtils.getTransformationMatrix( + previewWidth, previewHeight, + INPUT_SIZE, INPUT_SIZE, + sensorOrientation, MAINTAIN_ASPECT); + + cropToFrameTransform = new Matrix(); + frameToCropTransform.invert(cropToFrameTransform); + yuvBytes = new byte[3][]; + + addCallback( + new DrawCallback() { + @Override + public void drawCallback(final Canvas canvas) { + final Bitmap copy = cropCopyBitmap; + + tracker.draw(canvas); + + if (!isDebug()) { + return; + } + + tracker.drawDebug(canvas); + + if (copy != null) { + final Matrix matrix = new Matrix(); + final float scaleFactor = 2; + matrix.postScale(scaleFactor, scaleFactor); + matrix.postTranslate( + canvas.getWidth() - copy.getWidth() * scaleFactor, + canvas.getHeight() - copy.getHeight() * scaleFactor); + canvas.drawBitmap(copy, matrix, new Paint()); + + final Vector<String> lines = new Vector<String>(); + lines.add("Frame: " + previewWidth + "x" + previewHeight); + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); + lines.add("Rotation: " + sensorOrientation); + lines.add("Inference time: " + lastProcessingTimeMs + "ms"); + + int lineNum = 0; + for (final String line : lines) { + borderedText.drawText( + canvas, + 10, + canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum, + line); + ++lineNum; + } + } + } + }); + } + + @Override + public void onImageAvailable(final ImageReader reader) { + Image image = null; + + ++timestamp; + final long currTimestamp = timestamp; + + try { + image = reader.acquireLatestImage(); + + if (image == null) { + return; + } + + Trace.beginSection("imageAvailable"); + + final Plane[] planes = image.getPlanes(); + fillBytes(planes, yuvBytes); + + tracker.onFrame( + previewWidth, + previewHeight, + planes[0].getRowStride(), + sensorOrientation, + yuvBytes[0], + timestamp); + + requestRender(); + + // No mutex needed as this method is not reentrant. + if (computing) { + image.close(); + return; + } + computing = true; + + final int yRowStride = planes[0].getRowStride(); + final int uvRowStride = planes[1].getRowStride(); + final int uvPixelStride = planes[1].getPixelStride(); + ImageUtils.convertYUV420ToARGB8888( + yuvBytes[0], + yuvBytes[1], + yuvBytes[2], + rgbBytes, + previewWidth, + previewHeight, + yRowStride, + uvRowStride, + uvPixelStride, + false); + + image.close(); + } catch (final Exception e) { + if (image != null) { + image.close(); + } + LOGGER.e(e, "Exception!"); + Trace.endSection(); + return; + } + + rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight); + final Canvas canvas = new Canvas(croppedBitmap); + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); + + // For examining the actual TF input. + if (SAVE_PREVIEW_BITMAP) { + ImageUtils.saveBitmap(croppedBitmap); + } + + if (luminance == null) { + luminance = new byte[yuvBytes[0].length]; + } + System.arraycopy(yuvBytes[0], 0, luminance, 0, luminance.length); + + runInBackground( + new Runnable() { + @Override + public void run() { + final long startTime = SystemClock.uptimeMillis(); + final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap); + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); + final Canvas canvas = new Canvas(cropCopyBitmap); + final Paint paint = new Paint(); + paint.setColor(Color.RED); + paint.setStyle(Style.STROKE); + paint.setStrokeWidth(2.0f); + + final List<Classifier.Recognition> mappedRecognitions = + new LinkedList<Classifier.Recognition>(); + + for (final Classifier.Recognition result : results) { + final RectF location = result.getLocation(); + if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE) { + canvas.drawRect(location, paint); + + cropToFrameTransform.mapRect(location); + result.setLocation(location); + mappedRecognitions.add(result); + } + } + + tracker.trackResults(mappedRecognitions, luminance, currTimestamp); + + requestRender(); + computing = false; + } + }); + + Trace.endSection(); + } + + @Override + protected int getLayoutId() { + return R.layout.camera_connection_fragment_tracking; + } + + @Override + protected int getDesiredPreviewFrameSize() { + return INPUT_SIZE; + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java new file mode 100644 index 0000000000..66e25304d3 --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -0,0 +1,218 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.demo; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.Trace; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; +import org.tensorflow.demo.env.Logger; + +/** + * A detector for general purpose object detection as described in Scalable Object Detection using + * Deep Neural Networks (https://arxiv.org/abs/1312.2249). + */ +public class TensorFlowMultiBoxDetector implements Classifier { + private static final Logger LOGGER = new Logger(); + + static { + System.loadLibrary("tensorflow_demo"); + } + + // Only return this many results with at least this confidence. + private static final int MAX_RESULTS = Integer.MAX_VALUE; + + // Config values. + private String inputName; + private int inputSize; + private int imageMean; + private float imageStd; + + // Pre-allocated buffers. + private int[] intValues; + private float[] floatValues; + private float[] outputLocations; + private float[] outputScores; + private String[] outputNames; + private int numLocations; + + private TensorFlowInferenceInterface inferenceInterface; + + private float[] boxPriors; + + /** + * Initializes a native TensorFlow session for classifying images. + * + * @param assetManager The asset manager to be used to load assets. + * @param modelFilename The filepath of the model GraphDef protocol buffer. + * @param locationFilename The filepath of label file for classes. + * @param inputSize The input size. A square image of inputSize x inputSize is assumed. + * @param imageMean The assumed mean of the image values. + * @param imageStd The assumed std of the image values. + * @param inputName The label of the image input node. + * @param outputName The label of the output node. + * @return The native return value, 0 indicating success. + * @throws IOException + */ + public int initializeTensorFlow( + final AssetManager assetManager, + final String modelFilename, + final String locationFilename, + final int numLocations, + final int inputSize, + final int imageMean, + final float imageStd, + final String inputName, + final String outputName) + throws IOException { + this.inputName = inputName; + this.inputSize = inputSize; + this.imageMean = imageMean; + this.imageStd = imageStd; + this.numLocations = numLocations; + + this.boxPriors = new float[numLocations * 8]; + + loadCoderOptions(assetManager, locationFilename, boxPriors); + + // Pre-allocate buffers. + outputNames = outputName.split(","); + intValues = new int[inputSize * inputSize]; + floatValues = new float[inputSize * inputSize * 3]; + outputScores = new float[numLocations]; + outputLocations = new float[numLocations * 4]; + + inferenceInterface = new TensorFlowInferenceInterface(); + + return inferenceInterface.initializeTensorFlow(assetManager, modelFilename); + } + + // Load BoxCoderOptions from native code. + private native void loadCoderOptions( + AssetManager assetManager, String locationFilename, float[] boxPriors); + + private float[] decodeLocationsEncoding(final float[] locationEncoding) { + final float[] locations = new float[locationEncoding.length]; + boolean nonZero = false; + for (int i = 0; i < numLocations; ++i) { + for (int j = 0; j < 4; ++j) { + final float currEncoding = locationEncoding[4 * i + j]; + nonZero = nonZero || currEncoding != 0.0f; + + final float mean = boxPriors[i * 8 + j * 2]; + final float stdDev = boxPriors[i * 8 + j * 2 + 1]; + float currentLocation = currEncoding * stdDev + mean; + currentLocation = Math.max(currentLocation, 0.0f); + currentLocation = Math.min(currentLocation, 1.0f); + locations[4 * i + j] = currentLocation; + } + } + + if (!nonZero) { + LOGGER.w("No non-zero encodings; check log for inference errors."); + } + return locations; + } + + private float[] decodeScoresEncoding(final float[] scoresEncoding) { + final float[] scores = new float[scoresEncoding.length]; + for (int i = 0; i < scoresEncoding.length; ++i) { + scores[i] = 1 / ((float) (1 + Math.exp(-scoresEncoding[i]))); + } + return scores; + } + + @Override + public List<Recognition> recognizeImage(final Bitmap bitmap) { + // Log this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + Trace.beginSection("preprocessBitmap"); + // Preprocess the image data from 0-255 int to normalized float based + // on the provided parameters. + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + + for (int i = 0; i < intValues.length; ++i) { + floatValues[i * 3 + 0] = ((intValues[i] & 0xFF) - imageMean) / imageStd; + floatValues[i * 3 + 1] = (((intValues[i] >> 8) & 0xFF) - imageMean) / imageStd; + floatValues[i * 3 + 2] = (((intValues[i] >> 16) & 0xFF) - imageMean) / imageStd; + } + Trace.endSection(); // preprocessBitmap + + // Copy the input data into TensorFlow. + Trace.beginSection("fillNodeFloat"); + inferenceInterface.fillNodeFloat( + inputName, new int[] {1, inputSize, inputSize, 3}, floatValues); + Trace.endSection(); + + // Run the inference call. + Trace.beginSection("runInference"); + inferenceInterface.runInference(outputNames); + Trace.endSection(); + + // Copy the output Tensor back into the output array. + Trace.beginSection("readNodeFloat"); + final float[] outputScoresEncoding = new float[numLocations]; + final float[] outputLocationsEncoding = new float[numLocations * 4]; + inferenceInterface.readNodeFloat(outputNames[0], outputLocationsEncoding); + inferenceInterface.readNodeFloat(outputNames[1], outputScoresEncoding); + Trace.endSection(); + + outputLocations = decodeLocationsEncoding(outputLocationsEncoding); + outputScores = decodeScoresEncoding(outputScoresEncoding); + + // Find the best detections. + final PriorityQueue<Recognition> pq = + new PriorityQueue<Recognition>( + 1, + new Comparator<Recognition>() { + @Override + public int compare(final Recognition lhs, final Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + // Scale them back to the input size. + for (int i = 0; i < outputScores.length; ++i) { + final RectF detection = + new RectF( + outputLocations[4 * i] * inputSize, + outputLocations[4 * i + 1] * inputSize, + outputLocations[4 * i + 2] * inputSize, + outputLocations[4 * i + 3] * inputSize); + pq.add(new Recognition("" + i, "" + i, outputScores[i], detection)); + } + + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); + for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) { + recognitions.add(pq.poll()); + } + Trace.endSection(); // "recognizeImage" + return recognitions; + } + + @Override + public void close() { + inferenceInterface.close(); + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java new file mode 100644 index 0000000000..24e5cb57df --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java @@ -0,0 +1,381 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.demo.tracking; + +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import android.graphics.Paint.Cap; +import android.graphics.Paint.Join; +import android.graphics.Paint.Style; +import android.graphics.RectF; +import android.util.DisplayMetrics; +import android.util.Pair; +import android.util.TypedValue; +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; + +import org.tensorflow.demo.Classifier.Recognition; +import org.tensorflow.demo.env.BorderedText; +import org.tensorflow.demo.env.ImageUtils; +import org.tensorflow.demo.env.Logger; + +/** + * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing + * objects to new detections. + */ +public class MultiBoxTracker { + private final Logger logger = new Logger(); + + private static final float TEXT_SIZE_DIP = 18; + + // Maximum percentage of a box that can be overlapped by another box at detection time. Otherwise + // the lower scored box (new or old) will be removed. + private static final float MAX_OVERLAP = 0.35f; + + private static final float MIN_SIZE = 16.0f; + + // Allow replacement of the tracked box with new results if + // correlation has dropped below this level. + private static final float MARGINAL_CORRELATION = 0.75f; + + // Consider object to be lost if correlation falls below this threshold. + private static final float MIN_CORRELATION = 0.3f; + + private static final int[] COLORS = { + Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA + }; + + private final Queue<Integer> availableColors = new LinkedList<Integer>(); + + public ObjectTracker objectTracker; + + final List<Pair<Float, RectF>> screenRects = new LinkedList<Pair<Float, RectF>>(); + + private static class TrackedRecognition { + ObjectTracker.TrackedObject trackedObject; + float detectionConfidence; + int color; + } + + private final List<TrackedRecognition> trackedObjects = new LinkedList<TrackedRecognition>(); + + private final Paint boxPaint = new Paint(); + + private final float textSizePx; + private final BorderedText borderedText; + + private Matrix frameToCanvasMatrix; + + private int frameWidth; + private int frameHeight; + + private int sensorOrientation; + + public MultiBoxTracker(final DisplayMetrics metrics) { + for (final int color : COLORS) { + availableColors.add(color); + } + + boxPaint.setColor(Color.RED); + boxPaint.setStyle(Style.STROKE); + boxPaint.setStrokeWidth(12.0f); + boxPaint.setStrokeCap(Cap.ROUND); + boxPaint.setStrokeJoin(Join.ROUND); + boxPaint.setStrokeMiter(100); + + textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics); + borderedText = new BorderedText(textSizePx); + } + + private Matrix getFrameToCanvasMatrix() { + return frameToCanvasMatrix; + } + + public synchronized void drawDebug(final Canvas canvas) { + final Paint textPaint = new Paint(); + textPaint.setColor(Color.WHITE); + textPaint.setTextSize(60.0f); + + final Paint boxPaint = new Paint(); + boxPaint.setColor(Color.RED); + boxPaint.setAlpha(200); + boxPaint.setStyle(Style.STROKE); + + for (final Pair<Float, RectF> detection : screenRects) { + final RectF rect = detection.second; + canvas.drawRect(rect, boxPaint); + canvas.drawText("" + detection.first, rect.left, rect.top, textPaint); + borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first); + } + + if (objectTracker == null) { + return; + } + + // Draw correlations. + for (final TrackedRecognition recognition : trackedObjects) { + final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; + + final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame(); + + if (getFrameToCanvasMatrix().mapRect(trackedPos)) { + final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation()); + borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString); + } + } + + final Matrix matrix = getFrameToCanvasMatrix(); + objectTracker.drawDebug(canvas, matrix); + } + + public synchronized void trackResults( + final List<Recognition> results, final byte[] frame, final long timestamp) { + logger.i("Processing %d results from %d", results.size(), timestamp); + processResults(timestamp, results, frame); + } + + public synchronized void draw(final Canvas canvas) { + if (objectTracker == null) { + return; + } + + // TODO(andrewharp): This may not work for non-90 deg rotations. + final float multiplier = + Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth); + frameToCanvasMatrix = + ImageUtils.getTransformationMatrix( + frameWidth, + frameHeight, + (int) (multiplier * frameHeight), + (int) (multiplier * frameWidth), + sensorOrientation, + false); + + for (final TrackedRecognition recognition : trackedObjects) { + final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; + + final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame(); + + if (getFrameToCanvasMatrix().mapRect(trackedPos)) { + boxPaint.setColor(recognition.color); + + final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f; + canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint); + + final String labelString = String.format("%.2f", recognition.detectionConfidence); + borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString); + } + } + } + + public synchronized void onFrame( + final int w, + final int h, + final int rowStride, + final int sensorOrienation, + final byte[] frame, + final long timestamp) { + if (objectTracker == null) { + ObjectTracker.clearInstance(); + + logger.i("Initializing ObjectTracker: %dx%d", w, h); + objectTracker = ObjectTracker.getInstance(w, h, rowStride, true); + frameWidth = w; + frameHeight = h; + this.sensorOrientation = sensorOrienation; + } + + objectTracker.nextFrame(frame, null, timestamp, null, true); + + // Clean up any objects not worth tracking any more. + final LinkedList<TrackedRecognition> copyList = + new LinkedList<TrackedRecognition>(trackedObjects); + for (final TrackedRecognition recognition : copyList) { + final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; + final float correlation = trackedObject.getCurrentCorrelation(); + if (correlation < MIN_CORRELATION) { + logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation); + trackedObject.stopTracking(); + trackedObjects.remove(recognition); + + availableColors.add(recognition.color); + } + } + } + + private void processResults( + final long timestamp, final List<Recognition> results, final byte[] originalFrame) { + final List<Pair<Float, RectF>> rectsToTrack = new LinkedList<Pair<Float, RectF>>(); + + screenRects.clear(); + final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix()); + + for (final Recognition result : results) { + if (result.getLocation() == null) { + continue; + } + final RectF detectionFrameRect = new RectF(result.getLocation()); + + final RectF detectionScreenRect = new RectF(); + rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect); + + logger.v( + "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect); + + screenRects.add(new Pair<Float, RectF>(result.getConfidence(), detectionScreenRect)); + + if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) { + logger.w("Degenerate rectangle! " + detectionFrameRect); + continue; + } + + rectsToTrack.add(new Pair<Float, RectF>(result.getConfidence(), detectionFrameRect)); + } + + if (rectsToTrack.isEmpty()) { + logger.v("Nothing to track, aborting."); + return; + } + + if (objectTracker == null) { + logger.w("No ObjectTracker, can't track anything!"); + return; + } + + logger.i("%d rects to track", rectsToTrack.size()); + for (final Pair<Float, RectF> potential : rectsToTrack) { + handleDetection(originalFrame, timestamp, potential); + } + } + + private void handleDetection( + final byte[] frameCopy, final long timestamp, final Pair<Float, RectF> potential) { + final ObjectTracker.TrackedObject potentialObject = + objectTracker.trackObject(potential.second, timestamp, frameCopy); + + final float potentialCorrelation = potentialObject.getCurrentCorrelation(); + logger.v( + "Tracked object went from %s to %s with correlation %.2f", + potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation); + + if (potentialCorrelation < MARGINAL_CORRELATION) { + logger.v("Correlation too low to begin tracking %s.", potentialObject); + potentialObject.stopTracking(); + return; + } + + final List<TrackedRecognition> removeList = new LinkedList<TrackedRecognition>(); + + float maxIntersect = 0.0f; + + // This is the current tracked object whose color we will take. If left null we'll take the + // first one from the color queue. + TrackedRecognition recogToReplace = null; + + // Look for intersections that will be overridden by this object or an intersection that would + // prevent this one from being placed. + for (final TrackedRecognition trackedRecognition : trackedObjects) { + final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame(); + final RectF b = potentialObject.getTrackedPositionInPreviewFrame(); + final RectF intersection = new RectF(); + final boolean intersects = intersection.setIntersect(a, b); + + final float intersectAmount = + intersection.width() + * intersection.height() + / Math.min(a.width() * a.height(), b.width() * b.height()); + + // If there is an intersection with this currently tracked box above the maximum overlap + // percentage allowed, either the new recognition needs to be dismissed or the old + // recognition needs to be removed and possibly replaced with the new one. + if (intersects && intersectAmount > MAX_OVERLAP) { + if (potential.first < trackedRecognition.detectionConfidence + && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) { + // If track for the existing object is still going strong and the detection score was + // good, reject this new object. + potentialObject.stopTracking(); + return; + } else { + removeList.add(trackedRecognition); + + // Let the previously tracked object with max intersection amount donate its color to + // the new object. + if (intersectAmount > maxIntersect) { + maxIntersect = intersectAmount; + recogToReplace = trackedRecognition; + } + } + } + } + + // If we're already tracking the max object and no intersections were found to bump off, + // pick the worst current tracked object to remove, if it's also worse than this candidate + // object. + if (availableColors.isEmpty() && removeList.isEmpty()) { + for (final TrackedRecognition candidate : trackedObjects) { + if (candidate.detectionConfidence < potential.first) { + if (recogToReplace == null + || candidate.detectionConfidence < recogToReplace.detectionConfidence) { + // Save it so that we use this color for the new object. + recogToReplace = candidate; + } + } + } + if (recogToReplace != null) { + logger.v("Found non-intersecting object to remove."); + removeList.add(recogToReplace); + } else { + logger.v("No non-intersecting object found to remove"); + } + } + + // Remove everything that got intersected. + for (final TrackedRecognition trackedRecognition : removeList) { + logger.v( + "Removing tracked object %s with detection confidence %.2f, correlation %.2f", + trackedRecognition.trackedObject, + trackedRecognition.detectionConfidence, + trackedRecognition.trackedObject.getCurrentCorrelation()); + trackedRecognition.trackedObject.stopTracking(); + trackedObjects.remove(trackedRecognition); + if (trackedRecognition != recogToReplace) { + availableColors.add(trackedRecognition.color); + } + } + + if (recogToReplace == null && availableColors.isEmpty()) { + logger.e("No room to track this object, aborting."); + potentialObject.stopTracking(); + return; + } + + // Finally safe to say we can track this object. + logger.v( + "Tracking object %s with detection confidence %.2f at position %s", + potentialObject, potential.first, potential.second); + final TrackedRecognition trackedRecognition = new TrackedRecognition(); + trackedRecognition.detectionConfidence = potential.first; + trackedRecognition.trackedObject = potentialObject; + + // Use the color from a replaced object before taking one from the color queue. + trackedRecognition.color = + recogToReplace != null ? recogToReplace.color : availableColors.poll(); + trackedObjects.add(trackedRecognition); + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java new file mode 100644 index 0000000000..211d8077a3 --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java @@ -0,0 +1,649 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.demo.tracking; + +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import android.graphics.PointF; +import android.graphics.RectF; +import android.graphics.Typeface; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Vector; +import javax.microedition.khronos.opengles.GL10; +import org.tensorflow.demo.env.Logger; +import org.tensorflow.demo.env.Size; + +/** + * True object detector/tracker class that tracks objects across consecutive preview frames. + * It provides a simplified Java interface to the analogous native object defined by + * jni/client_vision/tracking/object_tracker.*. + * + * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must + * be allocated by ObjectTracker.getInstance(). In addition, release() should be called + * as soon as the ObjectTracker is no longer needed, and before a new one is created. + * + * nextFrame() should be called as new frames become available, preferably as often as possible. + * + * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects + * are associated with the ObjectTracker that created them, and are only valid while that + * ObjectTracker still exists. + */ +public class ObjectTracker { + private final Logger logger = new Logger(); + + private static final boolean DRAW_TEXT = false; + + /** + * How many history points to keep track of and draw in the red history line. + */ + private static final int MAX_DEBUG_HISTORY_SIZE = 30; + + /** + * How many frames of optical flow deltas to record. + * TODO(andrewharp): Push this down to the native level so it can be polled + * efficiently into a an array for upload, instead of keeping a duplicate + * copy in Java. + */ + private static final int MAX_FRAME_HISTORY_SIZE = 200; + + private static final int DOWNSAMPLE_FACTOR = 2; + + private final byte[] downsampledFrame; + + protected static ObjectTracker instance; + + private final Map<String, TrackedObject> trackedObjects; + + private long lastTimestamp; + + private FrameChange lastKeypoints; + + private final Vector<PointF> debugHistory; + + private final LinkedList<TimestampedDeltas> timestampedDeltas; + + protected final int frameWidth; + protected final int frameHeight; + private final int rowStride; + protected final boolean alwaysTrack; + + private static class TimestampedDeltas { + final long timestamp; + final byte[] deltas; + + public TimestampedDeltas(final long timestamp, final byte[] deltas) { + this.timestamp = timestamp; + this.deltas = deltas; + } + } + + /** + * A simple class that records keypoint information, which includes + * local location, score and type. This will be used in calculating + * FrameChange. + */ + public static class Keypoint { + public final float x; + public final float y; + public final float score; + public final int type; + + public Keypoint(final float x, final float y) { + this.x = x; + this.y = y; + this.score = 0; + this.type = -1; + } + + public Keypoint(final float x, final float y, final float score, final int type) { + this.x = x; + this.y = y; + this.score = score; + this.type = type; + } + + Keypoint delta(final Keypoint other) { + return new Keypoint(this.x - other.x, this.y - other.y); + } + } + + /** + * A simple class that could calculate Keypoint delta. + * This class will be used in calculating frame translation delta + * for optical flow. + */ + public static class PointChange { + public final Keypoint keypointA; + public final Keypoint keypointB; + Keypoint pointDelta; + private final boolean wasFound; + + public PointChange(final float x1, final float y1, + final float x2, final float y2, + final float score, final int type, + final boolean wasFound) { + this.wasFound = wasFound; + + keypointA = new Keypoint(x1, y1, score, type); + keypointB = new Keypoint(x2, y2); + } + + public Keypoint getDelta() { + if (pointDelta == null) { + pointDelta = keypointB.delta(keypointA); + } + return pointDelta; + } + } + + /** A class that records a timestamped frame translation delta for optical flow. */ + public static class FrameChange { + public static final int KEYPOINT_STEP = 7; + + public final Vector<PointChange> pointDeltas; + + private final float minScore; + private final float maxScore; + + public FrameChange(final float[] framePoints) { + float minScore = 100.0f; + float maxScore = -100.0f; + + pointDeltas = new Vector<PointChange>(framePoints.length / KEYPOINT_STEP); + + for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) { + final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR; + final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR; + + final boolean wasFound = framePoints[i + 2] > 0.0f; + + final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR; + final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR; + final float score = framePoints[i + 5]; + final int type = (int) framePoints[i + 6]; + + minScore = Math.min(minScore, score); + maxScore = Math.max(maxScore, score); + + pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound)); + } + + this.minScore = minScore; + this.maxScore = maxScore; + } + } + + public static synchronized ObjectTracker getInstance( + final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { + if (instance == null) { + instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack); + instance.init(); + } else { + throw new RuntimeException( + "Tried to create a new objectracker before releasing the old one!"); + } + return instance; + } + + public static synchronized void clearInstance() { + if (instance != null) { + instance.release(); + } + } + + protected ObjectTracker( + final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { + this.frameWidth = frameWidth; + this.frameHeight = frameHeight; + this.rowStride = rowStride; + this.alwaysTrack = alwaysTrack; + this.timestampedDeltas = new LinkedList<TimestampedDeltas>(); + + trackedObjects = new HashMap<String, TrackedObject>(); + + debugHistory = new Vector<PointF>(MAX_DEBUG_HISTORY_SIZE); + + downsampledFrame = + new byte + [(frameWidth + DOWNSAMPLE_FACTOR - 1) + / DOWNSAMPLE_FACTOR + * (frameWidth + DOWNSAMPLE_FACTOR - 1) + / DOWNSAMPLE_FACTOR]; + } + + protected void init() { + // The native tracker never sees the full frame, so pre-scale dimensions + // by the downsample factor. + initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack); + } + + private final float[] matrixValues = new float[9]; + + private long downsampledTimestamp; + + @SuppressWarnings("unused") + public synchronized void drawOverlay(final GL10 gl, + final Size cameraViewSize, final Matrix matrix) { + final Matrix tempMatrix = new Matrix(matrix); + tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR); + tempMatrix.getValues(matrixValues); + drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues); + } + + public synchronized void nextFrame( + final byte[] frameData, final byte[] uvData, + final long timestamp, final float[] transformationMatrix, + final boolean updateDebugInfo) { + if (downsampledTimestamp != timestamp) { + ObjectTracker.downsampleImageNative( + frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame); + downsampledTimestamp = timestamp; + } + + // Do Lucas Kanade using the fullframe initializer. + nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix); + + timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR))); + while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) { + timestampedDeltas.removeFirst(); + } + + for (final TrackedObject trackedObject : trackedObjects.values()) { + trackedObject.updateTrackedPosition(); + } + + if (updateDebugInfo) { + updateDebugHistory(); + } + + lastTimestamp = timestamp; + } + + public synchronized void release() { + releaseMemoryNative(); + synchronized (ObjectTracker.class) { + instance = null; + } + } + + private void drawHistoryDebug(final Canvas canvas) { + drawHistoryPoint( + canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2); + } + + private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) { + final Paint p = new Paint(); + p.setAntiAlias(false); + p.setTypeface(Typeface.SERIF); + + p.setColor(Color.RED); + p.setStrokeWidth(2.0f); + + // Draw the center circle. + p.setColor(Color.GREEN); + canvas.drawCircle(startX, startY, 3.0f, p); + + p.setColor(Color.RED); + + // Iterate through in backwards order. + synchronized (debugHistory) { + final int numPoints = debugHistory.size(); + float lastX = startX; + float lastY = startY; + for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) { + final PointF delta = debugHistory.get(numPoints - keypointNum - 1); + final float newX = lastX + delta.x; + final float newY = lastY + delta.y; + canvas.drawLine(lastX, lastY, newX, newY, p); + lastX = newX; + lastY = newY; + } + } + } + + private static int floatToChar(final float value) { + return Math.max(0, Math.min((int) (value * 255.999f), 255)); + } + + private void drawKeypointsDebug(final Canvas canvas) { + final Paint p = new Paint(); + if (lastKeypoints == null) { + return; + } + final int keypointSize = 3; + + final float minScore = lastKeypoints.minScore; + final float maxScore = lastKeypoints.maxScore; + + for (final PointChange keypoint : lastKeypoints.pointDeltas) { + if (keypoint.wasFound) { + final int r = + floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore)); + final int b = + floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore)); + + final int color = 0xFF000000 | (r << 16) | b; + p.setColor(color); + + final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y, + keypoint.keypointB.x, keypoint.keypointB.y}; + canvas.drawRect(screenPoints[2] - keypointSize, + screenPoints[3] - keypointSize, + screenPoints[2] + keypointSize, + screenPoints[3] + keypointSize, p); + p.setColor(Color.CYAN); + canvas.drawLine(screenPoints[2], screenPoints[3], + screenPoints[0], screenPoints[1], p); + + if (DRAW_TEXT) { + p.setColor(Color.WHITE); + canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score, + keypoint.keypointA.x, keypoint.keypointA.y, p); + } + } else { + p.setColor(Color.YELLOW); + final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y}; + canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p); + } + } + } + + private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX, + final float positionY, final float radius) { + final RectF currPosition = getCurrentPosition(timestamp, + new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius)); + return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY); + } + + private synchronized RectF getCurrentPosition(final long timestamp, final RectF + oldPosition) { + final RectF downscaledFrameRect = downscaleRect(oldPosition); + + final float[] delta = new float[4]; + getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top, + downscaledFrameRect.right, downscaledFrameRect.bottom, delta); + + final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]); + + return upscaleRect(newPosition); + } + + private void updateDebugHistory() { + lastKeypoints = new FrameChange(getKeypointsNative(false)); + + if (lastTimestamp == 0) { + return; + } + + final PointF delta = + getAccumulatedDelta( + lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100); + + synchronized (debugHistory) { + debugHistory.add(delta); + + while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) { + debugHistory.remove(0); + } + } + } + + public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) { + canvas.save(); + canvas.setMatrix(frameToCanvas); + + drawHistoryDebug(canvas); + drawKeypointsDebug(canvas); + + canvas.restore(); + } + + public Vector<String> getDebugText() { + final Vector<String> lines = new Vector<String>(); + + if (lastKeypoints != null) { + lines.add("Num keypoints " + lastKeypoints.pointDeltas.size()); + lines.add("Min score: " + lastKeypoints.minScore); + lines.add("Max score: " + lastKeypoints.maxScore); + } + + return lines; + } + + public synchronized List<byte[]> pollAccumulatedFlowData(final long endFrameTime) { + final List<byte[]> frameDeltas = new ArrayList<byte[]>(); + while (timestampedDeltas.size() > 0) { + final TimestampedDeltas currentDeltas = timestampedDeltas.peek(); + if (currentDeltas.timestamp <= endFrameTime) { + frameDeltas.add(currentDeltas.deltas); + timestampedDeltas.removeFirst(); + } else { + break; + } + } + + return frameDeltas; + } + + private RectF downscaleRect(final RectF fullFrameRect) { + return new RectF( + fullFrameRect.left / DOWNSAMPLE_FACTOR, + fullFrameRect.top / DOWNSAMPLE_FACTOR, + fullFrameRect.right / DOWNSAMPLE_FACTOR, + fullFrameRect.bottom / DOWNSAMPLE_FACTOR); + } + + private RectF upscaleRect(final RectF downsampledFrameRect) { + return new RectF( + downsampledFrameRect.left * DOWNSAMPLE_FACTOR, + downsampledFrameRect.top * DOWNSAMPLE_FACTOR, + downsampledFrameRect.right * DOWNSAMPLE_FACTOR, + downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR); + } + + /** + * A TrackedObject represents a native TrackedObject, and provides access to the + * relevant native tracking information available after every frame update. They may + * be safely passed around and acessed externally, but will become invalid after + * stopTracking() is called or the related creating ObjectTracker is deactivated. + * + * @author andrewharp@google.com (Andrew Harp) + */ + public class TrackedObject { + private final String id; + + private long lastExternalPositionTime; + + private RectF lastTrackedPosition; + private boolean visibleInLastFrame; + + private boolean isDead; + + TrackedObject(final RectF position, final long timestamp, final byte[] data) { + isDead = false; + + id = Integer.toString(this.hashCode()); + + lastExternalPositionTime = timestamp; + + synchronized (ObjectTracker.this) { + registerInitialAppearance(position, data); + setPreviousPosition(position, timestamp); + trackedObjects.put(id, this); + } + } + + public void stopTracking() { + checkValidObject(); + + synchronized (ObjectTracker.this) { + isDead = true; + forgetNative(id); + trackedObjects.remove(id); + } + } + + public float getCurrentCorrelation() { + checkValidObject(); + return ObjectTracker.this.getCurrentCorrelation(id); + } + + void registerInitialAppearance(final RectF position, final byte[] data) { + final RectF externalPosition = downscaleRect(position); + registerNewObjectWithAppearanceNative(id, + externalPosition.left, externalPosition.top, + externalPosition.right, externalPosition.bottom, + data); + } + + synchronized void setPreviousPosition(final RectF position, final long timestamp) { + checkValidObject(); + synchronized (ObjectTracker.this) { + if (lastExternalPositionTime > timestamp) { + logger.w("Tried to use older position time!"); + return; + } + final RectF externalPosition = downscaleRect(position); + lastExternalPositionTime = timestamp; + + setPreviousPositionNative(id, + externalPosition.left, externalPosition.top, + externalPosition.right, externalPosition.bottom, + lastExternalPositionTime); + + updateTrackedPosition(); + } + } + + void setCurrentPosition(final RectF position) { + checkValidObject(); + final RectF downsampledPosition = downscaleRect(position); + synchronized (ObjectTracker.this) { + setCurrentPositionNative(id, + downsampledPosition.left, downsampledPosition.top, + downsampledPosition.right, downsampledPosition.bottom); + } + } + + private synchronized void updateTrackedPosition() { + checkValidObject(); + + final float[] delta = new float[4]; + getTrackedPositionNative(id, delta); + lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]); + + visibleInLastFrame = isObjectVisible(id); + } + + public synchronized RectF getTrackedPositionInPreviewFrame() { + checkValidObject(); + + if (lastTrackedPosition == null) { + return null; + } + return upscaleRect(lastTrackedPosition); + } + + synchronized long getLastExternalPositionTime() { + return lastExternalPositionTime; + } + + public synchronized boolean visibleInLastPreviewFrame() { + return visibleInLastFrame; + } + + private void checkValidObject() { + if (isDead) { + throw new RuntimeException("TrackedObject already removed from tracking!"); + } else if (ObjectTracker.this != instance) { + throw new RuntimeException("TrackedObject created with another ObjectTracker!"); + } + } + } + + public synchronized TrackedObject trackObject( + final RectF position, final long timestamp, final byte[] frameData) { + if (downsampledTimestamp != timestamp) { + ObjectTracker.downsampleImageNative( + frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame); + downsampledTimestamp = timestamp; + } + return new TrackedObject(position, timestamp, downsampledFrame); + } + + public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) { + return new TrackedObject(position, lastTimestamp, frameData); + } + + /*********************** NATIVE CODE *************************************/ + + /** + * This will contain an opaque pointer to the native ObjectTracker + */ + private int nativeObjectTracker; + + private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack); + + protected native void registerNewObjectWithAppearanceNative( + String objectId, float x1, float y1, float x2, float y2, byte[] data); + + protected native void setPreviousPositionNative( + String objectId, float x1, float y1, float x2, float y2, long timestamp); + + protected native void setCurrentPositionNative( + String objectId, float x1, float y1, float x2, float y2); + + protected native void forgetNative(String key); + + protected native String getModelIdNative(String key); + + protected native boolean haveObject(String key); + protected native boolean isObjectVisible(String key); + protected native float getCurrentCorrelation(String key); + + protected native float getMatchScore(String key); + + protected native void getTrackedPositionNative(String key, float[] points); + + protected native void nextFrameNative( + byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix); + + protected native void releaseMemoryNative(); + + protected native void getCurrentPositionNative(long timestamp, + final float positionX1, final float positionY1, + final float positionX2, final float positionY2, + final float[] delta); + + protected native byte[] getKeypointsPacked(float scaleFactor); + + protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints); + + protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas); + + protected static native void downsampleImageNative( + int width, int height, int rowStride, byte[] input, int factor, byte[] output); + + static { + System.loadLibrary("tensorflow_demo"); + } +} diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py index e9ca4e5520..ef3d21767a 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py @@ -91,7 +91,7 @@ def run_training(): sess.run(init_op) # Instantiate a SummaryWriter to output summaries and the Graph. - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) + summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph) # Start input enqueue threads. coord = tf.train.Coordinator() diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py index 4e41ab18e3..392309d543 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py @@ -101,7 +101,7 @@ def run_training(): feed_dict={labels_initializer: data_sets.train.labels}) # Instantiate a SummaryWriter to output summaries and the Graph. - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) + summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph) # Start input enqueue threads. coord = tf.train.Coordinator() diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 6b34e57b8f..ca8c9358b3 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -85,14 +85,6 @@ from tensorflow.python.util import compat FLAGS = None -# Input and output file flags. - -# Details of the training configuration. - -# File-system cache locations. - -# Controls the distortions used during training. - # These are all parameters that are tied to the particular model architecture # we're using for Inception v3. These include things like tensor names and their # sizes. If you want to adapt this script to work with another model, you will @@ -455,7 +447,8 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category, Args: sess: Current TensorFlow Session. image_lists: Dictionary of training images for each label. - how_many: The number of bottleneck values to return. + how_many: If positive, a random sample of this size will be chosen. + If negative, all bottlenecks will be retrieved. category: Name string of which set to pull from - training, testing, or validation. bottleneck_dir: Folder string holding cached files of bottleneck values. @@ -465,24 +458,47 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category, bottleneck_tensor: The bottleneck output layer of the CNN graph. Returns: - List of bottleneck arrays and their corresponding ground truths. + List of bottleneck arrays, their corresponding ground truths, and the + relevant filenames. """ class_count = len(image_lists.keys()) bottlenecks = [] ground_truths = [] - for unused_i in range(how_many): - label_index = random.randrange(class_count) - label_name = list(image_lists.keys())[label_index] - image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) - bottleneck = get_or_create_bottleneck(sess, image_lists, label_name, - image_index, image_dir, category, - bottleneck_dir, jpeg_data_tensor, - bottleneck_tensor) - ground_truth = np.zeros(class_count, dtype=np.float32) - ground_truth[label_index] = 1.0 - bottlenecks.append(bottleneck) - ground_truths.append(ground_truth) - return bottlenecks, ground_truths + filenames = [] + if how_many >= 0: + # Retrieve a random sample of bottlenecks. + for unused_i in range(how_many): + label_index = random.randrange(class_count) + label_name = list(image_lists.keys())[label_index] + image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) + image_name = get_image_path(image_lists, label_name, image_index, + image_dir, category) + bottleneck = get_or_create_bottleneck(sess, image_lists, label_name, + image_index, image_dir, category, + bottleneck_dir, jpeg_data_tensor, + bottleneck_tensor) + ground_truth = np.zeros(class_count, dtype=np.float32) + ground_truth[label_index] = 1.0 + bottlenecks.append(bottleneck) + ground_truths.append(ground_truth) + filenames.append(image_name) + else: + # Retrieve all bottlenecks. + for label_index, label_name in enumerate(image_lists.keys()): + for image_index, image_name in enumerate( + image_lists[label_name][category]): + image_name = get_image_path(image_lists, label_name, image_index, + image_dir, category) + bottleneck = get_or_create_bottleneck(sess, image_lists, label_name, + image_index, image_dir, category, + bottleneck_dir, jpeg_data_tensor, + bottleneck_tensor) + ground_truth = np.zeros(class_count, dtype=np.float32) + ground_truth[label_index] = 1.0 + bottlenecks.append(bottleneck) + ground_truths.append(ground_truth) + filenames.append(image_name) + return bottlenecks, ground_truths, filenames def get_random_distorted_bottlenecks( @@ -729,16 +745,17 @@ def add_evaluation_step(result_tensor, ground_truth_tensor): into. Returns: - Nothing. + Tuple of (evaluation step, prediction). """ with tf.name_scope('accuracy'): with tf.name_scope('correct_prediction'): - correct_prediction = tf.equal(tf.argmax(result_tensor, 1), \ - tf.argmax(ground_truth_tensor, 1)) + prediction = tf.argmax(result_tensor, 1) + correct_prediction = tf.equal( + prediction, tf.argmax(ground_truth_tensor, 1)) with tf.name_scope('accuracy'): evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.summary.scalar('accuracy', evaluation_step) - return evaluation_step + return evaluation_step, prediction def main(_): @@ -788,13 +805,14 @@ def main(_): bottleneck_tensor) # Create the operations we need to evaluate the accuracy of our new layer. - evaluation_step = add_evaluation_step(final_tensor, ground_truth_input) + evaluation_step, prediction = add_evaluation_step( + final_tensor, ground_truth_input) # Merge all the summaries and write them out to /tmp/retrain_logs (by default) merged = tf.summary.merge_all() - train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', - sess.graph) - validation_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/validation') + train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', + sess.graph) + validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation') # Set up all our weights to their initial default values. init = tf.global_variables_initializer() @@ -810,7 +828,7 @@ def main(_): FLAGS.image_dir, distorted_jpeg_data_tensor, distorted_image_tensor, resized_image_tensor, bottleneck_tensor) else: - train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks( + train_bottlenecks, train_ground_truth, _ = get_random_cached_bottlenecks( sess, image_lists, FLAGS.train_batch_size, 'training', FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, bottleneck_tensor) @@ -832,7 +850,7 @@ def main(_): train_accuracy * 100)) print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i, cross_entropy_value)) - validation_bottlenecks, validation_ground_truth = ( + validation_bottlenecks, validation_ground_truth, _ = ( get_random_cached_bottlenecks( sess, image_lists, FLAGS.validation_batch_size, 'validation', FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, @@ -844,20 +862,29 @@ def main(_): feed_dict={bottleneck_input: validation_bottlenecks, ground_truth_input: validation_ground_truth}) validation_writer.add_summary(validation_summary, i) - print('%s: Step %d: Validation accuracy = %.1f%%' % - (datetime.now(), i, validation_accuracy * 100)) + print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' % + (datetime.now(), i, validation_accuracy * 100, + len(validation_bottlenecks))) # We've completed all our training, so run a final test evaluation on # some new images we haven't used before. - test_bottlenecks, test_ground_truth = get_random_cached_bottlenecks( - sess, image_lists, FLAGS.test_batch_size, 'testing', - FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, - bottleneck_tensor) - test_accuracy = sess.run( - evaluation_step, + test_bottlenecks, test_ground_truth, test_filenames = ( + get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size, + 'testing', FLAGS.bottleneck_dir, + FLAGS.image_dir, jpeg_data_tensor, + bottleneck_tensor)) + test_accuracy, predictions = sess.run( + [evaluation_step, prediction], feed_dict={bottleneck_input: test_bottlenecks, ground_truth_input: test_ground_truth}) - print('Final test accuracy = %.1f%%' % (test_accuracy * 100)) + print('Final test accuracy = %.1f%% (N=%d)' % ( + test_accuracy * 100, len(test_bottlenecks))) + + if FLAGS.print_misclassified_test_images: + print('=== MISCLASSIFIED TEST IMAGES ===') + for i, test_filename in enumerate(test_filenames): + if predictions[i] != test_ground_truth[i].argmax(): + print('%70s %s' % (test_filename, image_lists.keys()[predictions[i]])) # Write out the trained graph and labels with the weights stored as constants. output_graph_def = graph_util.convert_variables_to_constants( @@ -933,10 +960,12 @@ if __name__ == '__main__': parser.add_argument( '--test_batch_size', type=int, - default=500, + default=-1, help="""\ - How many images to test on at a time. This test set is only used - infrequently to verify the overall accuracy of the model.\ + How many images to test on. This test set is only used once, to evaluate + the final accuracy of the model after training completes. + A value of -1 causes the entire test set to be used, which leads to more + stable results across runs.\ """ ) parser.add_argument( @@ -946,10 +975,21 @@ if __name__ == '__main__': help="""\ How many images to use in an evaluation batch. This validation set is used much more often than the test set, and is an early indicator of how - accurate the model is during training.\ + accurate the model is during training. + A value of -1 causes the entire validation set to be used, which leads to + more stable results across training iterations, but may be slower on large + training sets.\ """ ) parser.add_argument( + '--print_misclassified_test_images', + default=False, + help="""\ + Whether to print out a list of all misclassified test images.\ + """, + action='store_true' + ) + parser.add_argument( '--model_dir', type=str, default='/tmp/imagenet', diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py index 888bf33b48..1fe12f7b76 100644 --- a/tensorflow/examples/learn/wide_n_deep_tutorial.py +++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py @@ -56,7 +56,7 @@ def maybe_download(): train_file_name = FLAGS.train_data else: train_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long + urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long train_file_name = train_file.name train_file.close() print("Training data is downloaded to %s" % train_file_name) @@ -65,7 +65,7 @@ def maybe_download(): test_file_name = FLAGS.test_data else: test_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long + urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long test_file_name = test_file.name test_file.close() print("Test data is downloaded to %s" % test_file_name) diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py index f24a558dc2..be50f4529f 100644 --- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py +++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py @@ -152,7 +152,7 @@ def run_training(): sess = tf.Session() # Instantiate a SummaryWriter to output summaries and the Graph. - summary_writer = tf.train.SummaryWriter(FLAGS.log_dir, sess.graph) + summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph) # And then after everything is built: diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py index 0e12a6571b..83879d0807 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py +++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py @@ -137,9 +137,8 @@ def train(): # Merge all the summaries and write them out to /tmp/mnist_logs (by default) merged = tf.summary.merge_all() - train_writer = tf.train.SummaryWriter(FLAGS.log_dir + '/train', - sess.graph) - test_writer = tf.train.SummaryWriter(FLAGS.log_dir + '/test') + train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph) + test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test') tf.global_variables_initializer().run() # Train the model, and also write summaries. diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md index 76571a0aff..fbd1bf5808 100644 --- a/tensorflow/g3doc/api_docs/python/client.md +++ b/tensorflow/g3doc/api_docs/python/client.md @@ -103,8 +103,8 @@ and evaluate every `Tensor` in `fetches`, substituting the values in `feed_dict` for the corresponding input values. The `fetches` argument may be a single graph element, or an arbitrarily -nested list, tuple, namedtuple, or dict containing graph elements at its -leaves. A graph element can be one of the following types: +nested list, tuple, namedtuple, dict, or OrderedDict containing graph +elements at its leaves. A graph element can be one of the following types: * An [`Operation`](../../api_docs/python/framework.md#Operation). The corresponding fetched value will be `None`. diff --git a/tensorflow/g3doc/api_docs/python/contrib.distributions.md b/tensorflow/g3doc/api_docs/python/contrib.distributions.md index bd1a9db7bb..da86d2cad1 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.distributions.md +++ b/tensorflow/g3doc/api_docs/python/contrib.distributions.md @@ -20891,7 +20891,7 @@ log_normal = ds.TransformedDistribution( forward_fn=tf.exp, inverse_fn=tf.log, inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(x), reduction_indices=-1)), + lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)), name="LogNormalTransformedDistribution") ``` @@ -20913,7 +20913,7 @@ Construct a Transformed Distribution. ##### Args: -* <b>`distribution`</b>: The base distribution class to transform. Typically an +* <b>`distribution`</b>: The base distribution instance to transform. Typically an instance of `Distribution`. * <b>`bijector`</b>: The object responsible for calculating the transformation. Typically an instance of `Bijector`. @@ -20987,10 +20987,10 @@ cdf(x) := P[X <= x] Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -21128,10 +21128,10 @@ a more accurate answer than simply taking the logarithm of the `cdf` when Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -21212,10 +21212,10 @@ Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`, Also raises a `ValueError` if `inverse` was not provided to the distribution and `y` was not returned from `sample`. -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -21251,10 +21251,10 @@ survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`. Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -21404,10 +21404,10 @@ Implements `p(g^{-1}(y)) det|J(g^{-1}(y))|`, where `g^{-1}` is the Also raises a `ValueError` if `inverse` was not provided to the distribution and `y` was not returned from `sample`. -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -21458,10 +21458,10 @@ Additional documentation from `TransformedDistribution`: Samples from the base distribution and then passes through the bijector's forward transform. -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -21507,10 +21507,10 @@ survival_function(x) = P[X > x] Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: diff --git a/tensorflow/g3doc/api_docs/python/contrib.framework.md b/tensorflow/g3doc/api_docs/python/contrib.framework.md index 0b8690cf2d..49892fdcaf 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.framework.md +++ b/tensorflow/g3doc/api_docs/python/contrib.framework.md @@ -37,14 +37,15 @@ be `dtypes.float32` or `dtypes.float64`. If neither `tensors` nor - - - -### `tf.contrib.framework.assert_scalar_int(tensor)` {#assert_scalar_int} +### `tf.contrib.framework.assert_scalar_int(tensor, name=None)` {#assert_scalar_int} Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`. ##### Args: -* <b>`tensor`</b>: Tensor to test. +* <b>`tensor`</b>: `Tensor` to test. +* <b>`name`</b>: Name of the op and of the new `Tensor` if one is created. ##### Returns: @@ -309,7 +310,7 @@ to the rest of the docstring. - - - -### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names)` {#deprecated_args} +### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names_or_tuples)` {#deprecated_args} Decorator for marking specific function arguments as deprecated. @@ -333,7 +334,10 @@ prepended to the rest of the docstring. ISO 8601 (YYYY-MM-DD). * <b>`instructions`</b>: String. Instructions on how to update code using the deprecated function. -* <b>`*deprecated_arg_names`</b>: String. The deprecated arguments. +* <b>`*deprecated_arg_names_or_tuples`</b>: String. or 2-Tuple(String, + [ok_vals]). The string is the deprecated argument name. + Optionally, an ok-value may be provided. If the user provided + argument equals this value, the warning is suppressed. ##### Returns: @@ -342,8 +346,10 @@ prepended to the rest of the docstring. ##### Raises: -* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are empty, or - the deprecated arguments are not present in the function signature. +* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are + empty, the deprecated arguments are not present in the function + signature, or the second element of a deprecated_tuple is not a + list. - - - @@ -865,6 +871,11 @@ Gets an existing model variable with these parameters or creates a new one. device. * <b>`device`</b>: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. +* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape` + and dtype of the `Variable` to be created, and returns a list of + partitions for each axis (currently only one axis can be partitioned). +* <b>`custom_getter`</b>: Callable that allows overwriting the internal + get_variable method and has to have the same signature. ##### Returns: @@ -896,6 +907,11 @@ Gets an existing variable with these parameters or creates a new one. device. * <b>`device`</b>: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. +* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape` + and dtype of the `Variable` to be created, and returns a list of + partitions for each axis (currently only one axis can be partitioned). +* <b>`custom_getter`</b>: Callable that allows overwriting the internal + get_variable method and has to have the same signature. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/contrib.layers.md b/tensorflow/g3doc/api_docs/python/contrib.layers.md index a6fcf6f270..3babfa0cab 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.layers.md +++ b/tensorflow/g3doc/api_docs/python/contrib.layers.md @@ -635,53 +635,6 @@ to produce the end result. - - - -### `tf.stack(values, axis=0, name='stack')` {#stack} - -Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. - -Packs the list of tensors in `values` into a tensor with rank one higher than -each tensor in `values`, by packing them along the `axis` dimension. -Given a list of length `N` of tensors of shape `(A, B, C)`; - -if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. -if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. -Etc. - -For example: - -```prettyprint -# 'x' is [1, 4] -# 'y' is [2, 5] -# 'z' is [3, 6] -stack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -stack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] -``` - -This is the opposite of unstack. The numpy equivalent is - - tf.stack([x, y, z]) = np.asarray([x, y, z]) - -##### Args: - - -* <b>`values`</b>: A list of `Tensor` objects with the same shape and type. -* <b>`axis`</b>: An `int`. The axis to stack along. Defaults to the first dimension. - Supports negative indexes. -* <b>`name`</b>: A name for this operation (optional). - -##### Returns: - - -* <b>`output`</b>: A stacked `Tensor` with the same type as `values`. - -##### Raises: - - -* <b>`ValueError`</b>: If `axis` is out of the range [-(R+1), R+1). - - -- - - - ### `tf.contrib.layers.unit_norm(*args, **kwargs)` {#unit_norm} Normalizes the given input across the specified dimension to unit length. @@ -710,6 +663,9 @@ Note that the rank of `input` must be known. Aliases for fully_connected which set a default activation function are available: `relu`, `relu6` and `linear`. +`stack` operation is also available. It builds a stack of layers by applying +a layer repeatedly. + ## Regularizers Regularization can help prevent overfitting. These have the signature @@ -1230,7 +1186,7 @@ Creates a _CrossedColumn for performing feature crosses. - - - -### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#embedding_column} +### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#embedding_column} Creates an `_EmbeddingColumn` for feeding sparse data into a DNN. @@ -1258,6 +1214,8 @@ Creates an `_EmbeddingColumn` for feeding sparse data into a DNN. * <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. +* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to + the value of max_norm. ##### Returns: @@ -1582,7 +1540,7 @@ Creates a `_RealValuedColumn` for dense numeric data. - - - -### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#shared_embedding_columns} +### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#shared_embedding_columns} Creates a list of `_EmbeddingColumn` sharing the same embedding. @@ -1613,6 +1571,8 @@ Creates a list of `_EmbeddingColumn` sharing the same embedding. * <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. +* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to + the value of max_norm. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.md b/tensorflow/g3doc/api_docs/python/contrib.learn.md index 682d6ff930..1808bb94e2 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.learn.md +++ b/tensorflow/g3doc/api_docs/python/contrib.learn.md @@ -459,6 +459,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi - - - +#### `tf.contrib.learn.Estimator.export_savedmodel(*args, **kwargs)` {#Estimator.export_savedmodel} + +Exports inference graph as a SavedModel into given dir. (experimental) + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + input_fn: A function that takes no argument and + returns an `InputFnOps`. + default_output_alternative_key: the name of the head to serve when none is + specified. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel. Each key should give the destination + path (including the filename) relative to the assets.extra directory. + The corresponding value gives the full path of the source file to be + copied. For example, the simple case of copying a single file without + renaming it is specified as + `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + as_text: whether to write the SavedModel proto in text format. + exports_to_keep: Number of exports to keep. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if an unrecognized export_type is requested. + + +- - - + #### `tf.contrib.learn.Estimator.fit(*args, **kwargs)` {#Estimator.fit} See `Trainable`. (deprecated arguments) @@ -842,7 +875,7 @@ Input of `fit` and `evaluate` should have following features, whose `value` is a `Tensor`. - - - -#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None)` {#DNNClassifier.__init__} +#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None, embedding_lr_multipliers=None)` {#DNNClassifier.__init__} Initializes a DNNClassifier instance. @@ -882,6 +915,9 @@ Initializes a DNNClassifier instance. labels which are the output of `input_fn` and returns features and labels which will be fed into the model. +* <b>`embedding_lr_multipliers`</b>: Optional. A dictionary from `EmbeddingColumn` to + a `float` multiplier. Multiplier will be used to multiply with + learning rate for the embedding variables. ##### Returns: @@ -927,6 +963,15 @@ See BaseEstimator.export. - - - +#### `tf.contrib.learn.DNNClassifier.export_savedmodel(*args, **kwargs)` {#DNNClassifier.export_savedmodel} + +EXPERIMENTAL FUNCTION + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + +- - - + #### `tf.contrib.learn.DNNClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#DNNClassifier.fit} See trainable.Trainable. Note: Labels must be integer class indices. @@ -1556,6 +1601,15 @@ See BaseEstimator.export. - - - +#### `tf.contrib.learn.LinearClassifier.export_savedmodel(*args, **kwargs)` {#LinearClassifier.export_savedmodel} + +EXPERIMENTAL FUNCTION + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + +- - - + #### `tf.contrib.learn.LinearClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearClassifier.fit} See trainable.Trainable. Note: Labels must be integer class indices. @@ -1746,6 +1800,15 @@ See BaseEstimator.export. - - - +#### `tf.contrib.learn.LinearRegressor.export_savedmodel(*args, **kwargs)` {#LinearRegressor.export_savedmodel} + +EXPERIMENTAL FUNCTION + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + +- - - + #### `tf.contrib.learn.LinearRegressor.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearRegressor.fit} See trainable.Trainable. @@ -1928,6 +1991,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi - - - +#### `tf.contrib.learn.LogisticRegressor.export_savedmodel(*args, **kwargs)` {#LogisticRegressor.export_savedmodel} + +Exports inference graph as a SavedModel into given dir. (experimental) + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + input_fn: A function that takes no argument and + returns an `InputFnOps`. + default_output_alternative_key: the name of the head to serve when none is + specified. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel. Each key should give the destination + path (including the filename) relative to the assets.extra directory. + The corresponding value gives the full path of the source file to be + copied. For example, the simple case of copying a single file without + renaming it is specified as + `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + as_text: whether to write the SavedModel proto in text format. + exports_to_keep: Number of exports to keep. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if an unrecognized export_type is requested. + + +- - - + #### `tf.contrib.learn.LogisticRegressor.fit(*args, **kwargs)` {#LogisticRegressor.fit} See `Trainable`. (deprecated arguments) diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md b/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md index 122c4e3551..d4d200399e 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md +++ b/tensorflow/g3doc/api_docs/python/contrib.learn.monitors.md @@ -887,9 +887,9 @@ The signature of the input_fn accepted by export is changing to be consistent wi `None`). input_feature_key: String key into the features dict returned by `input_fn` that corresponds to the raw `Example` strings `Tensor` that - the exported model will take as input. Can only be `None` if you're - using a custom `signature_fn` that does not use the first arg - (examples). + the exported model will take as input. Should be `None` if and only if + you're passing in a `signature_fn` that does not use the first arg + (`Tensor` of `Example` strings). exports_to_keep: int, number of exports to keep. signature_fn: Function that returns a default signature and a named signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s diff --git a/tensorflow/g3doc/api_docs/python/contrib.linalg.md b/tensorflow/g3doc/api_docs/python/contrib.linalg.md index e678edfe72..d5edaa3e82 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.linalg.md +++ b/tensorflow/g3doc/api_docs/python/contrib.linalg.md @@ -106,6 +106,19 @@ FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE ### Performance FILL THIS IN + +### Matrix property hints + +This `LinearOperator` is initialized with boolean flags of the form `is_X`, +for `X = non_singular, self_adjoint` etc... +These have the following meaning +* If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. +* If `is_X == False`, callers should expect the operator to not have `X`. +* If `is_X == None` (the default), callers should have no expectation either + way. - - - #### `tf.contrib.linalg.LinearOperator.__init__(dtype, graph_parents=None, is_non_singular=None, is_self_adjoint=None, is_positive_definite=None, name=None)` {#LinearOperator.__init__} @@ -115,16 +128,6 @@ Initialize the `LinearOperator`. **This is a private method for subclass use.** **Subclasses should copy-paste this `__init__` documentation.** -For `X = non_singular, self_adjoint` etc... -`is_X` is a Python `bool` initialization argument with the following meaning -* If `is_X == True`, callers should expect the operator to have the - attribute `X`. This is a promise that should be fulfilled, but is *not* a - runtime assert. Issues, such as floating point error, could mean the - operator violates this promise. -* If `is_X == False`, callers should expect the operator to not have `X`. -* If `is_X == None` (the default), callers should have no expectation either - way. - ##### Args: @@ -135,8 +138,12 @@ For `X = non_singular, self_adjoint` etc... * <b>`is_non_singular`</b>: Expect that this operator is non-singular. * <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian transpose. If `dtype` is real, this is equivalent to being symmetric. -* <b>`is_positive_definite`</b>: Expect that this operator is positive definite. -* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name. +* <b>`is_positive_definite`</b>: Expect that this operator is positive definite, + meaning the real part of all eigenvalues is positive. We do not require + the operator to be self-adjoint to be positive-definite. See: +* <b>`https`</b>: //en.wikipedia.org/wiki/Positive-definite_matrix\ + #Extension_for_non_symmetric_matrices +* <b>`name`</b>: A name for this `LinearOperator`. ##### Raises: @@ -146,6 +153,23 @@ For `X = non_singular, self_adjoint` etc... - - - +#### `tf.contrib.linalg.LinearOperator.add_to_tensor(x, name='add_to_tensor')` {#LinearOperator.add_to_tensor} + +Add matrix represented by this operator to `x`. Equivalent to `A + x`. + +##### Args: + + +* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`. +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + A `Tensor` with broadcast shape and same `dtype` as `self`. + + +- - - + #### `tf.contrib.linalg.LinearOperator.apply(x, adjoint=False, name='apply')` {#LinearOperator.apply} Transform `x` with left multiplication: `x --> Ax`. @@ -176,6 +200,25 @@ Returns an `Op` that asserts this operator is non singular. Returns an `Op` that asserts this operator is positive definite. +Here, positive definite means the real part of all eigenvalues is positive. +We do not require the operator to be self-adjoint. + +##### Args: + + +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + An `Op` that asserts this operator is positive definite. + + +- - - + +#### `tf.contrib.linalg.LinearOperator.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperator.assert_self_adjoint} + +Returns an `Op` that asserts this operator is self-adjoint. + - - - @@ -493,7 +536,7 @@ Return a dense (batch) matrix representing this operator. This operator acts like a [batch] matrix `A` with shape `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is -an `m x n` matrix. Again, this matrix `A` may not be materialized, but for +an `N x N` matrix. This matrix `A` is not materialized, but for purposes of broadcasting this shape will be relevant. `LinearOperatorDiag` is initialized with a (batch) vector. @@ -507,7 +550,7 @@ operator.to_dense() ==> [[1., 0.] [0., -1.]] -operator.shape() +operator.shape ==> [2, 2] operator.log_determinant() @@ -542,7 +585,7 @@ and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] ### Performance -Suppose `operator` is a `LinearOperatorDiag` is of shape `[N, N]`, +Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`, and `x.shape = [N, R]`. Then * `operator.apply(x)` involves `N*R` multiplications. @@ -551,43 +594,68 @@ and `x.shape = [N, R]`. Then If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`. -- - - -#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__} +### Matrix property hints -Initialize a `LinearOperatorDiag`. - -For `X = non_singular, self_adjoint` etc... -`is_X` is a Python `bool` initialization argument with the following meaning +This `LinearOperator` is initialized with boolean flags of the form `is_X`, +for `X = non_singular, self_adjoint` etc... +These have the following meaning * If `is_X == True`, callers should expect the operator to have the - attribute `X`. This is a promise that should be fulfilled, but is *not* a - runtime assert. Issues, such as floating point error, could mean the - operator violates this promise. + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. * If `is_X == False`, callers should expect the operator to not have `X`. * If `is_X == None` (the default), callers should have no expectation either way. +- - - + +#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__} + +Initialize a `LinearOperatorDiag`. ##### Args: -* <b>`diag`</b>: Shape `[B1,...,Bb, N]` real float type `Tensor` with `b >= 0`, - `N >= 0`. The diagonal of the operator. +* <b>`diag`</b>: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. + The diagonal of the operator. Allowed dtypes: `float32`, `float64`, + `complex64`, `complex128`. * <b>`is_non_singular`</b>: Expect that this operator is non-singular. * <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian transpose. Since this is a real (not complex) diagonal operator, it is always self adjoint. -* <b>`is_positive_definite`</b>: Expect that this operator is positive definite. -* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name. +* <b>`is_positive_definite`</b>: Expect that this operator is positive definite, + meaning the real part of all eigenvalues is positive. We do not require + the operator to be self-adjoint to be positive-definite. See: +* <b>`https`</b>: //en.wikipedia.org/wiki/Positive-definite_matrix + #Extension_for_non_symmetric_matrices +* <b>`name`</b>: A name for this `LinearOperator`. ##### Raises: -* <b>`ValueError`</b>: If `diag.dtype` is not floating point. +* <b>`TypeError`</b>: If `diag.dtype` is not an allowed type. * <b>`ValueError`</b>: If `is_self_adjoint` is not `True`. - - - +#### `tf.contrib.linalg.LinearOperatorDiag.add_to_tensor(x, name='add_to_tensor')` {#LinearOperatorDiag.add_to_tensor} + +Add matrix represented by this operator to `x`. Equivalent to `A + x`. + +##### Args: + + +* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`. +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + A `Tensor` with broadcast shape and same `dtype` as `self`. + + +- - - + #### `tf.contrib.linalg.LinearOperatorDiag.apply(x, adjoint=False, name='apply')` {#LinearOperatorDiag.apply} Transform `x` with left multiplication: `x --> Ax`. @@ -618,6 +686,25 @@ Returns an `Op` that asserts this operator is non singular. Returns an `Op` that asserts this operator is positive definite. +Here, positive definite means the real part of all eigenvalues is positive. +We do not require the operator to be self-adjoint. + +##### Args: + + +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + An `Op` that asserts this operator is positive definite. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorDiag.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperatorDiag.assert_self_adjoint} + +Returns an `Op` that asserts this operator is self-adjoint. + - - - diff --git a/tensorflow/g3doc/api_docs/python/contrib.losses.md b/tensorflow/g3doc/api_docs/python/contrib.losses.md index cc6a14f891..e6b0e136a9 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.losses.md +++ b/tensorflow/g3doc/api_docs/python/contrib.losses.md @@ -67,6 +67,7 @@ Instructions for updating: Args: losses: A tensor of size [batch_size, d1, ... dN]. weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + scope: the scope for the operations performed in computing the loss. weight: Deprecated alias for `weights`. Returns: diff --git a/tensorflow/g3doc/api_docs/python/contrib.training.md b/tensorflow/g3doc/api_docs/python/contrib.training.md index 935c163e06..67ce73a347 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.training.md +++ b/tensorflow/g3doc/api_docs/python/contrib.training.md @@ -900,7 +900,7 @@ batch. - - - -### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, warmup=10, seed=None)` {#weighted_resample} +### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, seed=None)` {#weighted_resample} Performs an approximate weighted resampling of `inputs`. @@ -917,9 +917,6 @@ rate of selection across all inputs (and many invocations!) is * <b>`overall_rate`</b>: Desired overall rate of resampling. * <b>`scope`</b>: Scope to use for the op. * <b>`mean_decay`</b>: How quickly to decay the running estimate of the mean weight. -* <b>`warmup`</b>: Until the resulting tensor has been evaluated `warmup` - times, the resampling menthod uses the true mean over all calls - as its weight estimate, rather than a decayed mean. * <b>`seed`</b>: Random seed. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md index 2c9b4bed34..31435d5ec3 100644 --- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md +++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md @@ -596,7 +596,7 @@ Returns the truth value of (x >= y) element-wise. - - - -### `tf.select(condition, t, e, name=None)` {#select} +### `tf.select(*args, **kwargs)` {#select} Selects elements from `t` or `e`, depending on `condition`. diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md index 2601bd99ff..7c799283fd 100644 --- a/tensorflow/g3doc/api_docs/python/framework.md +++ b/tensorflow/g3doc/api_docs/python/framework.md @@ -1438,8 +1438,19 @@ dynamic condition of the `Tensor`. #### `tf.Tensor.__div__(x, y)` {#Tensor.__div__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. + +##### Args: +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `x / y` returns the quotient of x and y. + - - - @@ -1847,7 +1858,18 @@ Returns the truth value of x AND y element-wise. #### `tf.Tensor.__rdiv__(y, x)` {#Tensor.__rdiv__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. + +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). +##### Returns: + + `x / y` returns the quotient of x and y. - - - @@ -1998,34 +2020,7 @@ Returns x - y element-wise. #### `tf.Tensor.__rtruediv__(y, x)` {#Tensor.__rtruediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. - -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - `x / y` evaluated in floating point. - -##### Raises: - - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - @@ -2067,34 +2062,7 @@ Returns x - y element-wise. #### `tf.Tensor.__truediv__(x, y)` {#Tensor.__truediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. - -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - - `x / y` evaluated in floating point. - -##### Raises: - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - @@ -2827,7 +2795,7 @@ The following standard keys are defined: for more details. * `SUMMARIES`: the summary `Tensor` objects that have been created in the graph. See - [`tf.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries) + [`tf.contrib.deprecated.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries) for more details. * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to produce input for a computation. See diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md index 4dd9bbc0f8..3f81ac9fc1 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.framework.deprecated_args.md @@ -1,4 +1,4 @@ -### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names)` {#deprecated_args} +### `tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names_or_tuples)` {#deprecated_args} Decorator for marking specific function arguments as deprecated. @@ -22,7 +22,10 @@ prepended to the rest of the docstring. ISO 8601 (YYYY-MM-DD). * <b>`instructions`</b>: String. Instructions on how to update code using the deprecated function. -* <b>`*deprecated_arg_names`</b>: String. The deprecated arguments. +* <b>`*deprecated_arg_names_or_tuples`</b>: String. or 2-Tuple(String, + [ok_vals]). The string is the deprecated argument name. + Optionally, an ok-value may be provided. If the user provided + argument equals this value, the warning is suppressed. ##### Returns: @@ -31,6 +34,8 @@ prepended to the rest of the docstring. ##### Raises: -* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are empty, or - the deprecated arguments are not present in the function signature. +* <b>`ValueError`</b>: If date is not in ISO 8601 format, instructions are + empty, the deprecated arguments are not present in the function + signature, or the second element of a deprecated_tuple is not a + list. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md index 02cf9a8674..81405d1ab5 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.LinearRegressor.md @@ -113,6 +113,15 @@ See BaseEstimator.export. - - - +#### `tf.contrib.learn.LinearRegressor.export_savedmodel(*args, **kwargs)` {#LinearRegressor.export_savedmodel} + +EXPERIMENTAL FUNCTION + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + +- - - + #### `tf.contrib.learn.LinearRegressor.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearRegressor.fit} See trainable.Trainable. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md index ebe0d9a3c2..f4360dc46c 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.linalg.LinearOperatorDiag.md @@ -3,7 +3,7 @@ This operator acts like a [batch] matrix `A` with shape `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is -an `m x n` matrix. Again, this matrix `A` may not be materialized, but for +an `N x N` matrix. This matrix `A` is not materialized, but for purposes of broadcasting this shape will be relevant. `LinearOperatorDiag` is initialized with a (batch) vector. @@ -17,7 +17,7 @@ operator.to_dense() ==> [[1., 0.] [0., -1.]] -operator.shape() +operator.shape ==> [2, 2] operator.log_determinant() @@ -52,7 +52,7 @@ and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] ### Performance -Suppose `operator` is a `LinearOperatorDiag` is of shape `[N, N]`, +Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`, and `x.shape = [N, R]`. Then * `operator.apply(x)` involves `N*R` multiplications. @@ -61,43 +61,68 @@ and `x.shape = [N, R]`. Then If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`. -- - - -#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__} +### Matrix property hints -Initialize a `LinearOperatorDiag`. - -For `X = non_singular, self_adjoint` etc... -`is_X` is a Python `bool` initialization argument with the following meaning +This `LinearOperator` is initialized with boolean flags of the form `is_X`, +for `X = non_singular, self_adjoint` etc... +These have the following meaning * If `is_X == True`, callers should expect the operator to have the - attribute `X`. This is a promise that should be fulfilled, but is *not* a - runtime assert. Issues, such as floating point error, could mean the - operator violates this promise. + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. * If `is_X == False`, callers should expect the operator to not have `X`. * If `is_X == None` (the default), callers should have no expectation either way. +- - - + +#### `tf.contrib.linalg.LinearOperatorDiag.__init__(diag, is_non_singular=None, is_self_adjoint=True, is_positive_definite=None, name='LinearOperatorDiag')` {#LinearOperatorDiag.__init__} + +Initialize a `LinearOperatorDiag`. ##### Args: -* <b>`diag`</b>: Shape `[B1,...,Bb, N]` real float type `Tensor` with `b >= 0`, - `N >= 0`. The diagonal of the operator. +* <b>`diag`</b>: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. + The diagonal of the operator. Allowed dtypes: `float32`, `float64`, + `complex64`, `complex128`. * <b>`is_non_singular`</b>: Expect that this operator is non-singular. * <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian transpose. Since this is a real (not complex) diagonal operator, it is always self adjoint. -* <b>`is_positive_definite`</b>: Expect that this operator is positive definite. -* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name. +* <b>`is_positive_definite`</b>: Expect that this operator is positive definite, + meaning the real part of all eigenvalues is positive. We do not require + the operator to be self-adjoint to be positive-definite. See: +* <b>`https`</b>: //en.wikipedia.org/wiki/Positive-definite_matrix + #Extension_for_non_symmetric_matrices +* <b>`name`</b>: A name for this `LinearOperator`. ##### Raises: -* <b>`ValueError`</b>: If `diag.dtype` is not floating point. +* <b>`TypeError`</b>: If `diag.dtype` is not an allowed type. * <b>`ValueError`</b>: If `is_self_adjoint` is not `True`. - - - +#### `tf.contrib.linalg.LinearOperatorDiag.add_to_tensor(x, name='add_to_tensor')` {#LinearOperatorDiag.add_to_tensor} + +Add matrix represented by this operator to `x`. Equivalent to `A + x`. + +##### Args: + + +* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`. +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + A `Tensor` with broadcast shape and same `dtype` as `self`. + + +- - - + #### `tf.contrib.linalg.LinearOperatorDiag.apply(x, adjoint=False, name='apply')` {#LinearOperatorDiag.apply} Transform `x` with left multiplication: `x --> Ax`. @@ -128,6 +153,25 @@ Returns an `Op` that asserts this operator is non singular. Returns an `Op` that asserts this operator is positive definite. +Here, positive definite means the real part of all eigenvalues is positive. +We do not require the operator to be self-adjoint. + +##### Args: + + +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + An `Op` that asserts this operator is positive definite. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorDiag.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperatorDiag.assert_self_adjoint} + +Returns an `Op` that asserts this operator is self-adjoint. + - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md index 4977793e37..903cad838b 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.training.weighted_resample.md @@ -1,4 +1,4 @@ -### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, warmup=10, seed=None)` {#weighted_resample} +### `tf.contrib.training.weighted_resample(inputs, weights, overall_rate, scope=None, mean_decay=0.999, seed=None)` {#weighted_resample} Performs an approximate weighted resampling of `inputs`. @@ -15,9 +15,6 @@ rate of selection across all inputs (and many invocations!) is * <b>`overall_rate`</b>: Desired overall rate of resampling. * <b>`scope`</b>: Scope to use for the op. * <b>`mean_decay`</b>: How quickly to decay the running estimate of the mean weight. -* <b>`warmup`</b>: Until the resulting tensor has been evaluated `warmup` - times, the resampling menthod uses the true mean over all calls - as its weight estimate, rather than a decayed mean. * <b>`seed`</b>: Random seed. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.FileWriterCache.clear.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.FileWriterCache.clear.md new file mode 100644 index 0000000000..e3c7027813 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.FileWriterCache.clear.md @@ -0,0 +1,4 @@ +#### `tf.summary.FileWriterCache.clear()` {#FileWriterCache.clear} + +Clear cached summary writers. Currently only used for unit tests. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md index 3c8fd4c447..f1f89fff93 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.ExponentialMovingAverage.md @@ -80,7 +80,7 @@ saver.restore(...checkpoint filename...) - - - -#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__} +#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, zero_debias=False, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__} Creates a new ExponentialMovingAverage object. @@ -100,6 +100,8 @@ move faster. If passed, the actual decay rate used is: * <b>`decay`</b>: Float. The decay to use. * <b>`num_updates`</b>: Optional count of number of updates applied to variables. +* <b>`zero_debias`</b>: If `True`, zero debias moving-averages that are initialized + with tensors. * <b>`name`</b>: String. Optional prefix name to use for the name of ops added in `apply()`. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md index 2ce825eb7b..aac10e4396 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md @@ -299,8 +299,19 @@ dynamic condition of the `Tensor`. #### `tf.Tensor.__div__(x, y)` {#Tensor.__div__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. + +##### Args: +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `x / y` returns the quotient of x and y. + - - - @@ -708,7 +719,18 @@ Returns the truth value of x AND y element-wise. #### `tf.Tensor.__rdiv__(y, x)` {#Tensor.__rdiv__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. + +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). +##### Returns: + + `x / y` returns the quotient of x and y. - - - @@ -859,34 +881,7 @@ Returns x - y element-wise. #### `tf.Tensor.__rtruediv__(y, x)` {#Tensor.__rtruediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. - -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - `x / y` evaluated in floating point. - -##### Raises: - - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - @@ -928,34 +923,7 @@ Returns x - y element-wise. #### `tf.Tensor.__truediv__(x, y)` {#Tensor.__truediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. - -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - - `x / y` evaluated in floating point. - -##### Raises: - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md index 4b4f4413b5..746339a662 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.TransformedDistribution.md @@ -87,7 +87,7 @@ log_normal = ds.TransformedDistribution( forward_fn=tf.exp, inverse_fn=tf.log, inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(x), reduction_indices=-1)), + lambda y: -tf.reduce_sum(tf.log(y), reduction_indices=-1)), name="LogNormalTransformedDistribution") ``` @@ -109,7 +109,7 @@ Construct a Transformed Distribution. ##### Args: -* <b>`distribution`</b>: The base distribution class to transform. Typically an +* <b>`distribution`</b>: The base distribution instance to transform. Typically an instance of `Distribution`. * <b>`bijector`</b>: The object responsible for calculating the transformation. Typically an instance of `Bijector`. @@ -183,10 +183,10 @@ cdf(x) := P[X <= x] Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -324,10 +324,10 @@ a more accurate answer than simply taking the logarithm of the `cdf` when Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -408,10 +408,10 @@ Implements `(log o p o g^{-1})(y) + (log o det o J o g^{-1})(y)`, Also raises a `ValueError` if `inverse` was not provided to the distribution and `y` was not returned from `sample`. -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -447,10 +447,10 @@ survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`. Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -600,10 +600,10 @@ Implements `p(g^{-1}(y)) det|J(g^{-1}(y))|`, where `g^{-1}` is the Also raises a `ValueError` if `inverse` was not provided to the distribution and `y` was not returned from `sample`. -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -654,10 +654,10 @@ Additional documentation from `TransformedDistribution`: Samples from the base distribution and then passes through the bijector's forward transform. -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: @@ -703,10 +703,10 @@ survival_function(x) = P[X > x] Additional documentation from `TransformedDistribution`: -##### <b>`condition_kwargs`</b>: +##### `condition_kwargs`: -* <b>`bijector_kwargs`</b>: Python dictionary of arg names/values forwarded to the bijector. -* <b>`distribution_kwargs`</b>: Python dictionary of arg names/values forwarded to the distribution. +* `bijector_kwargs`: Python dictionary of arg names/values forwarded to the bijector. +* `distribution_kwargs`: Python dictionary of arg names/values forwarded to the distribution. ##### Args: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md index e4457e4be4..09a78073d6 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.embedding_column.md @@ -1,4 +1,4 @@ -### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#embedding_column} +### `tf.contrib.layers.embedding_column(sparse_id_column, dimension, combiner=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#embedding_column} Creates an `_EmbeddingColumn` for feeding sparse data into a DNN. @@ -26,6 +26,8 @@ Creates an `_EmbeddingColumn` for feeding sparse data into a DNN. * <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. +* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to + the value of max_norm. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md index e5c7d7edf3..fbbab399f3 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.LinearClassifier.md @@ -141,6 +141,15 @@ See BaseEstimator.export. - - - +#### `tf.contrib.learn.LinearClassifier.export_savedmodel(*args, **kwargs)` {#LinearClassifier.export_savedmodel} + +EXPERIMENTAL FUNCTION + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + +- - - + #### `tf.contrib.learn.LinearClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#LinearClassifier.fit} See trainable.Trainable. Note: Labels must be integer class indices. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md index 9cbade4389..8cd3b0d69f 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.compute_weighted_loss.md @@ -9,6 +9,7 @@ Instructions for updating: Args: losses: A tensor of size [batch_size, d1, ... dN]. weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + scope: the scope for the operations performed in computing the loss. weight: Deprecated alias for `weights`. Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md index 94d5355e10..469566f7b8 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.framework.assert_scalar_int.md @@ -1,11 +1,12 @@ -### `tf.contrib.framework.assert_scalar_int(tensor)` {#assert_scalar_int} +### `tf.contrib.framework.assert_scalar_int(tensor, name=None)` {#assert_scalar_int} Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`. ##### Args: -* <b>`tensor`</b>: Tensor to test. +* <b>`tensor`</b>: `Tensor` to test. +* <b>`name`</b>: Name of the op and of the new `Tensor` if one is created. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md index 830f1bd352..0550580a0e 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.shared_embedding_columns.md @@ -1,4 +1,4 @@ -### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None)` {#shared_embedding_columns} +### `tf.contrib.layers.shared_embedding_columns(sparse_id_columns, dimension, combiner=None, shared_embedding_name=None, initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None)` {#shared_embedding_columns} Creates a list of `_EmbeddingColumn` sharing the same embedding. @@ -29,6 +29,8 @@ Creates a list of `_EmbeddingColumn` sharing the same embedding. * <b>`tensor_name_in_ckpt`</b>: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. +* <b>`max_norm`</b>: (Optional). If not None, embedding values are l2-normalized to + the value of max_norm. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md index 647ef7e955..53992bdf4f 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.learn.monitors.ExportMonitor.md @@ -18,9 +18,9 @@ The signature of the input_fn accepted by export is changing to be consistent wi `None`). input_feature_key: String key into the features dict returned by `input_fn` that corresponds to the raw `Example` strings `Tensor` that - the exported model will take as input. Can only be `None` if you're - using a custom `signature_fn` that does not use the first arg - (examples). + the exported model will take as input. Should be `None` if and only if + you're passing in a `signature_fn` that does not use the first arg + (`Tensor` of `Example` strings). exports_to_keep: int, number of exports to keep. signature_fn: Function that returns a default signature and a named signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md index 25607d1619..08ccc5f104 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.string_split.md @@ -8,7 +8,8 @@ containing the splitted tokens. Empty tokens are ignored. If `delimiter` is an empty string, each element of the `source` is split into individual strings, each containing one byte. (This includes splitting -multibyte sequences of UTF-8.) +multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is +treated as a set of delimiters with each considered a potential split point. For example: N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output @@ -29,14 +30,14 @@ st.values = ['hello', 'world', 'a', 'b', 'c'] * <b>`delimiter`</b>: `0-D` string `Tensor`, the delimiter character, the string should be length 0 or 1. +##### Raises: + + +* <b>`ValueError`</b>: If delimiter is not a string. + ##### Returns: A `SparseTensor` of rank `2`, the strings split according to the delimiter. The first column of the indices corresponds to the row in `source` and the second column corresponds to the index of the split component in this row. -##### Raises: - - -* <b>`ValueError`</b>: If delimiter is not a single-byte character. - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md index aa3c101dbf..0f3006d9ca 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.Estimator.md @@ -140,6 +140,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi - - - +#### `tf.contrib.learn.Estimator.export_savedmodel(*args, **kwargs)` {#Estimator.export_savedmodel} + +Exports inference graph as a SavedModel into given dir. (experimental) + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + input_fn: A function that takes no argument and + returns an `InputFnOps`. + default_output_alternative_key: the name of the head to serve when none is + specified. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel. Each key should give the destination + path (including the filename) relative to the assets.extra directory. + The corresponding value gives the full path of the source file to be + copied. For example, the simple case of copying a single file without + renaming it is specified as + `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + as_text: whether to write the SavedModel proto in text format. + exports_to_keep: Number of exports to keep. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if an unrecognized export_type is requested. + + +- - - + #### `tf.contrib.learn.Estimator.fit(*args, **kwargs)` {#Estimator.fit} See `Trainable`. (deprecated arguments) diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md index e9034e9115..f9fa7c70cb 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.DNNClassifier.md @@ -51,7 +51,7 @@ Input of `fit` and `evaluate` should have following features, whose `value` is a `Tensor`. - - - -#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None)` {#DNNClassifier.__init__} +#### `tf.contrib.learn.DNNClassifier.__init__(hidden_units, feature_columns, model_dir=None, n_classes=2, weight_column_name=None, optimizer=None, activation_fn=relu, dropout=None, gradient_clip_norm=None, enable_centered_bias=False, config=None, feature_engineering_fn=None, embedding_lr_multipliers=None)` {#DNNClassifier.__init__} Initializes a DNNClassifier instance. @@ -91,6 +91,9 @@ Initializes a DNNClassifier instance. labels which are the output of `input_fn` and returns features and labels which will be fed into the model. +* <b>`embedding_lr_multipliers`</b>: Optional. A dictionary from `EmbeddingColumn` to + a `float` multiplier. Multiplier will be used to multiply with + learning rate for the embedding variables. ##### Returns: @@ -136,6 +139,15 @@ See BaseEstimator.export. - - - +#### `tf.contrib.learn.DNNClassifier.export_savedmodel(*args, **kwargs)` {#DNNClassifier.export_savedmodel} + +EXPERIMENTAL FUNCTION + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + +- - - + #### `tf.contrib.learn.DNNClassifier.fit(x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None)` {#DNNClassifier.fit} See trainable.Trainable. Note: Labels must be integer class indices. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md index af9d8dbb76..594855e8a8 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.select.md @@ -1,4 +1,4 @@ -### `tf.select(condition, t, e, name=None)` {#select} +### `tf.select(*args, **kwargs)` {#select} Selects elements from `t` or `e`, depending on `condition`. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md index fba1f9071c..79081d4e9f 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.framework.variable.md @@ -21,6 +21,11 @@ Gets an existing variable with these parameters or creates a new one. device. * <b>`device`</b>: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. +* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape` + and dtype of the `Variable` to be created, and returns a list of + partitions for each axis (currently only one axis can be partitioned). +* <b>`custom_getter`</b>: Callable that allows overwriting the internal + get_variable method and has to have the same signature. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md index 0c0913dc09..8c25e24373 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.div.md @@ -1,4 +1,23 @@ ### `tf.div(x, y, name=None)` {#div} +Divides x / y elementwise (using Python 2 division operator semantics). +NOTE: Prefer using the Tensor division operator or tf.divide which obey Python +division operator semantics. + +This function divides `x` and `y`, forcing Python 2.7 semantics. That is, +if one of `x` or `y` is a float, then the result will be a float. +Otherwise, the output will be an integer type. Flooring semantics are used +for integer division. + +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `x / y` returns the quotient of x and y. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md index 1f1e960f48..8e628e6067 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.make_template.md @@ -1,4 +1,4 @@ -### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, **kwargs)` {#make_template} +### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, custom_getter_=None, **kwargs)` {#make_template} Given an arbitrary function, wrap it so that it does variable sharing. @@ -89,6 +89,9 @@ reduce the likelihood of collisions with kwargs. * <b>`unique_name_`</b>: When used, it overrides name_ and is not made unique. If a template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. +* <b>`custom_getter_`</b>: Optional custom getter for variables used in `func_`. See + the [`get_variable`](#get_variable) `custom_getter` documentation for + more information. * <b>`**kwargs`</b>: Keyword arguments to apply to `func_`. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.FileWriterCache.get.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.FileWriterCache.get.md new file mode 100644 index 0000000000..0f416a5909 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.FileWriterCache.get.md @@ -0,0 +1,13 @@ +#### `tf.summary.FileWriterCache.get(logdir)` {#FileWriterCache.get} + +Returns the FileWriter for the specified directory. + +##### Args: + + +* <b>`logdir`</b>: str, name of the directory. + +##### Returns: + + A `FileWriter`. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md index a7bb580232..d5fee8f7b4 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.SummaryWriterCache.get.md @@ -1,6 +1,6 @@ #### `tf.train.SummaryWriterCache.get(logdir)` {#SummaryWriterCache.get} -Returns the SummaryWriter for the specified directory. +Returns the FileWriter for the specified directory. ##### Args: @@ -9,5 +9,5 @@ Returns the SummaryWriter for the specified directory. ##### Returns: - A `SummaryWriter`. + A `FileWriter`. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md index 1ecf1822c9..6e80a4a562 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriter.md @@ -29,7 +29,7 @@ the graph from the session in which you launched it: # Launch the graph in a session. sess = tf.Session() # Create a summary writer, add the 'graph' to the event file. -writer = tf.train.SummaryWriter(<some-directory>, sess.graph) +writer = tf.summary.FileWriter(<some-directory>, sess.graph) ``` The other arguments to the constructor control the asynchronous writes to diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriterCache.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriterCache.md new file mode 100644 index 0000000000..3c6c8773b3 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.summary.FileWriterCache.md @@ -0,0 +1,26 @@ +Cache for file writers. + +This class caches file writers, one per directory. +- - - + +#### `tf.summary.FileWriterCache.clear()` {#FileWriterCache.clear} + +Clear cached summary writers. Currently only used for unit tests. + + +- - - + +#### `tf.summary.FileWriterCache.get(logdir)` {#FileWriterCache.get} + +Returns the FileWriter for the specified directory. + +##### Args: + + +* <b>`logdir`</b>: str, name of the directory. + +##### Returns: + + A `FileWriter`. + + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md index cb96358aa1..e9bdda200f 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md @@ -1,131 +1,109 @@ -Writes `Summary` protocol buffers to event files. - -The `FileWriter` class provides a mechanism to create an event file in a -given directory and add summaries and events to it. The class updates the -file contents asynchronously. This allows a training program to call methods -to add data to the file directly from the training loop, without slowing down -training. - - - -#### `tf.train.SummaryWriter.__init__(logdir, graph=None, max_queue=10, flush_secs=120, graph_def=None)` {#SummaryWriter.__init__} - -Creates a `FileWriter` and an event file. +#### `tf.train.SummaryWriter.__init__(*args, **kwargs)` {#SummaryWriter.__init__} -On construction the summary writer creates a new event file in `logdir`. -This event file will contain `Event` protocol buffers constructed when you -call one of the following functions: `add_summary()`, `add_session_log()`, -`add_event()`, or `add_graph()`. +Creates a `SummaryWriter` and an event file. (deprecated) -If you pass a `Graph` to the constructor it is added to -the event file. (This is equivalent to calling `add_graph()` later). +THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. +Instructions for updating: +Please switch to tf.summary.FileWriter. The interface and behavior is the same; this is just a rename. -TensorBoard will pick the graph from the file and display it graphically so -you can interactively explore the graph you built. You will usually pass -the graph from the session in which you launched it: + This class is deprecated, and should be replaced with tf.summary.FileWriter. -```python -...create a graph... -# Launch the graph in a session. -sess = tf.Session() -# Create a summary writer, add the 'graph' to the event file. -writer = tf.train.SummaryWriter(<some-directory>, sess.graph) -``` + On construction the summary writer creates a new event file in `logdir`. + This event file will contain `Event` protocol buffers constructed when you + call one of the following functions: `add_summary()`, `add_session_log()`, + `add_event()`, or `add_graph()`. -The other arguments to the constructor control the asynchronous writes to -the event file: + If you pass a `Graph` to the constructor it is added to + the event file. (This is equivalent to calling `add_graph()` later). -* `flush_secs`: How often, in seconds, to flush the added summaries - and events to disk. -* `max_queue`: Maximum number of summaries or events pending to be - written to disk before one of the 'add' calls block. + TensorBoard will pick the graph from the file and display it graphically so + you can interactively explore the graph you built. You will usually pass + the graph from the session in which you launched it: -##### Args: + ```python + ...create a graph... + # Launch the graph in a session. + sess = tf.Session() + # Create a summary writer, add the 'graph' to the event file. + writer = tf.train.SummaryWriter(<some-directory>, sess.graph) + ``` + The other arguments to the constructor control the asynchronous writes to + the event file: -* <b>`logdir`</b>: A string. Directory where event file will be written. -* <b>`graph`</b>: A `Graph` object, such as `sess.graph`. -* <b>`max_queue`</b>: Integer. Size of the queue for pending events and summaries. -* <b>`flush_secs`</b>: Number. How often, in seconds, to flush the - pending events and summaries to disk. -* <b>`graph_def`</b>: DEPRECATED: Use the `graph` argument instead. + * `flush_secs`: How often, in seconds, to flush the added summaries + and events to disk. + * `max_queue`: Maximum number of summaries or events pending to be + written to disk before one of the 'add' calls block. + Args: + logdir: A string. Directory where event file will be written. + graph: A `Graph` object, such as `sess.graph`. + max_queue: Integer. Size of the queue for pending events and summaries. + flush_secs: Number. How often, in seconds, to flush the + pending events and summaries to disk. + graph_def: DEPRECATED: Use the `graph` argument instead. - - - -#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary} - -Adds a `Summary` protocol buffer to the event file. - -This method wraps the provided summary in an `Event` protocol buffer -and adds it to the event file. +#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event} -You can pass the result of evaluating any summary op, using -[`Session.run()`](client.md#Session.run) or -[`Tensor.eval()`](framework.md#Tensor.eval), to this -function. Alternatively, you can pass a `tf.Summary` protocol -buffer that you populate with your own data. The latter is -commonly done to report evaluation results in event files. +Adds an event to the event file. ##### Args: -* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string. -* <b>`global_step`</b>: Number. Optional global step value to record with the - summary. +* <b>`event`</b>: An `Event` protocol buffer. - - - -#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log} +#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph} -Adds a `SessionLog` protocol buffer to the event file. +Adds a `Graph` to the event file. -This method wraps the provided session in an `Event` protocol buffer -and adds it to the event file. +The graph described by the protocol buffer will be displayed by +TensorBoard. Most users pass a graph in the constructor instead. ##### Args: -* <b>`session_log`</b>: A `SessionLog` protocol buffer. -* <b>`global_step`</b>: Number. Optional global step value to record with the - summary. - - -- - - - -#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event} - -Adds an event to the event file. +* <b>`graph`</b>: A `Graph` object, such as `sess.graph`. +* <b>`global_step`</b>: Number. Optional global step counter to record with the + graph. +* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead. -##### Args: +##### Raises: -* <b>`event`</b>: An `Event` protocol buffer. +* <b>`ValueError`</b>: If both graph and graph_def are passed to the method. - - - -#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph} +#### `tf.train.SummaryWriter.add_meta_graph(meta_graph_def, global_step=None)` {#SummaryWriter.add_meta_graph} -Adds a `Graph` to the event file. +Adds a `MetaGraphDef` to the event file. -The graph described by the protocol buffer will be displayed by -TensorBoard. Most users pass a graph in the constructor instead. +The `MetaGraphDef` allows running the given graph via +`saver.import_meta_graph()`. ##### Args: -* <b>`graph`</b>: A `Graph` object, such as `sess.graph`. +* <b>`meta_graph_def`</b>: A `MetaGraphDef` object, often as retured by + `saver.export_meta_graph()`. * <b>`global_step`</b>: Number. Optional global step counter to record with the graph. -* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead. ##### Raises: -* <b>`ValueError`</b>: If both graph and graph_def are passed to the method. +* <b>`TypeError`</b>: If both `meta_graph_def` is not an instance of `MetaGraphDef`. - - - @@ -150,20 +128,43 @@ Adds a metadata information for a single session.run() call. - - - -#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir} +#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log} -Returns the directory where event file will be written. +Adds a `SessionLog` protocol buffer to the event file. + +This method wraps the provided session in an `Event` protocol buffer +and adds it to the event file. + +##### Args: +* <b>`session_log`</b>: A `SessionLog` protocol buffer. +* <b>`global_step`</b>: Number. Optional global step value to record with the + summary. + - - - -#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush} +#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary} -Flushes the event file to disk. +Adds a `Summary` protocol buffer to the event file. -Call this method to make sure that all pending events have been written to -disk. +This method wraps the provided summary in an `Event` protocol buffer +and adds it to the event file. + +You can pass the result of evaluating any summary op, using +[`Session.run()`](client.md#Session.run) or +[`Tensor.eval()`](framework.md#Tensor.eval), to this +function. Alternatively, you can pass a `tf.Summary` protocol +buffer that you populate with your own data. The latter is +commonly done to report evaluation results in event files. + +##### Args: + + +* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string. +* <b>`global_step`</b>: Number. Optional global step value to record with the + summary. - - - @@ -175,8 +176,23 @@ Flushes the event file to disk and close the file. Call this method when you do not need the summary writer anymore. +- - - + +#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush} + +Flushes the event file to disk. + +Call this method to make sure that all pending events have been written to +disk. + + +- - - + +#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir} + +Returns the directory where event file will be written. + -#### Other Methods - - - #### `tf.train.SummaryWriter.reopen()` {#SummaryWriter.reopen} diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md index 4782bfac68..01136ac630 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriterCache.md @@ -1,6 +1,6 @@ -Cache for summary writers. +Cache for file writers. -This class caches summary writers, one per directory. +This class caches file writers, one per directory. - - - #### `tf.train.SummaryWriterCache.clear()` {#SummaryWriterCache.clear} @@ -12,7 +12,7 @@ Clear cached summary writers. Currently only used for unit tests. #### `tf.train.SummaryWriterCache.get(logdir)` {#SummaryWriterCache.get} -Returns the SummaryWriter for the specified directory. +Returns the FileWriter for the specified directory. ##### Args: @@ -21,6 +21,6 @@ Returns the SummaryWriter for the specified directory. ##### Returns: - A `SummaryWriter`. + A `FileWriter`. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md index 0ccb1b2217..7a0c7a4aac 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.truediv.md @@ -1,12 +1,15 @@ ### `tf.truediv(x, y, name=None)` {#truediv} -Divides x / y elementwise, always producing floating point results. +Divides x / y elementwise (using Python 3 division operator semantics). -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. +NOTE: Prefer using the Tensor operator or tf.divide which obey Python +division operator semantics. + +This function forces Python 3 division operator semantics where all integer +arguments are cast to floating types first. This op is generated by normal +`x / y` division in Python 3 and in Python 2.7 with +`from __future__ import division`. If you want integer division that rounds +down, use `x // y` or `tf.floordiv`. `x` and `y` must have the same numeric type. If the inputs are floating point, the output will have the same type. If the inputs are integral, the diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md index f647b524d4..7ec18e834c 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.GraphKeys.md @@ -27,7 +27,7 @@ The following standard keys are defined: for more details. * `SUMMARIES`: the summary `Tensor` objects that have been created in the graph. See - [`tf.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries) + [`tf.contrib.deprecated.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries) for more details. * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to produce input for a computation. See diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md index d9de06d5d0..1c183cb120 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Session.md @@ -87,8 +87,8 @@ and evaluate every `Tensor` in `fetches`, substituting the values in `feed_dict` for the corresponding input values. The `fetches` argument may be a single graph element, or an arbitrarily -nested list, tuple, namedtuple, or dict containing graph elements at its -leaves. A graph element can be one of the following types: +nested list, tuple, namedtuple, dict, or OrderedDict containing graph +elements at its leaves. A graph element can be one of the following types: * An [`Operation`](../../api_docs/python/framework.md#Operation). The corresponding fetched value will be `None`. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md index 648606c3db..9b58355a38 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.Variable.md @@ -56,7 +56,7 @@ all the variables. You then run that Op after launching the graph. ```python # Add an Op to initialize global variables. -init_op = tf.global_variable_initializers() +init_op = tf.global_variables_initializer() # Launch the graph in a session. with tf.Session() as sess: @@ -154,6 +154,10 @@ Returns the value of the initialized variable. You should use this instead of the variable itself to initialize another variable with a value that depends on the value of this variable. +Beware of using initialized_value except during initialization: +initialized_value causes the Variable's initializer op to be run, so running +this op resets the variable to the initial value. + ```python # Initialize 'v' with a random tensor. v = tf.Variable(tf.truncated_normal([10, 40])) @@ -455,7 +459,18 @@ Returns the truth value of x AND y element-wise. #### `tf.Variable.__div__(a, *args)` {#Variable.__div__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. + +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + `x / y` returns the quotient of x and y. - - - @@ -807,7 +822,18 @@ Returns the truth value of x AND y element-wise. #### `tf.Variable.__rdiv__(a, *args)` {#Variable.__rdiv__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `x / y` returns the quotient of x and y. - - - @@ -951,34 +977,7 @@ Returns x - y element-wise. #### `tf.Variable.__rtruediv__(a, *args)` {#Variable.__rtruediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - - `x / y` evaluated in floating point. - -##### Raises: - - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - @@ -1020,34 +1019,7 @@ Returns x - y element-wise. #### `tf.Variable.__truediv__(a, *args)` {#Variable.__truediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. - -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - - `x / y` evaluated in floating point. - -##### Raises: - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md index 2bbd4d5077..daa96911d9 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.framework.model_variable.md @@ -22,6 +22,11 @@ Gets an existing model variable with these parameters or creates a new one. device. * <b>`device`</b>: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. +* <b>`partitioner`</b>: Optional callable that accepts a fully defined `TensorShape` + and dtype of the `Variable` to be created, and returns a list of + partitions for each axis (currently only one axis can be partitioned). +* <b>`custom_getter`</b>: Callable that allows overwriting the internal + get_variable method and has to have the same signature. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md index dcbf0fbb1c..82a42aaf22 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.LogisticRegressor.md @@ -125,6 +125,39 @@ The signature of the input_fn accepted by export is changing to be consistent wi - - - +#### `tf.contrib.learn.LogisticRegressor.export_savedmodel(*args, **kwargs)` {#LogisticRegressor.export_savedmodel} + +Exports inference graph as a SavedModel into given dir. (experimental) + +THIS FUNCTION IS EXPERIMENTAL. It may change or be removed at any time, and without warning. + + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + input_fn: A function that takes no argument and + returns an `InputFnOps`. + default_output_alternative_key: the name of the head to serve when none is + specified. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel. Each key should give the destination + path (including the filename) relative to the assets.extra directory. + The corresponding value gives the full path of the source file to be + copied. For example, the simple case of copying a single file without + renaming it is specified as + `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + as_text: whether to write the SavedModel proto in text format. + exports_to_keep: Number of exports to keep. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if an unrecognized export_type is requested. + + +- - - + #### `tf.contrib.learn.LogisticRegressor.fit(*args, **kwargs)` {#LogisticRegressor.fit} See `Trainable`. (deprecated arguments) diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md index 6624ad3ce9..eb8a200558 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.linalg.LinearOperator.md @@ -84,6 +84,19 @@ FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE ### Performance FILL THIS IN + +### Matrix property hints + +This `LinearOperator` is initialized with boolean flags of the form `is_X`, +for `X = non_singular, self_adjoint` etc... +These have the following meaning +* If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. +* If `is_X == False`, callers should expect the operator to not have `X`. +* If `is_X == None` (the default), callers should have no expectation either + way. - - - #### `tf.contrib.linalg.LinearOperator.__init__(dtype, graph_parents=None, is_non_singular=None, is_self_adjoint=None, is_positive_definite=None, name=None)` {#LinearOperator.__init__} @@ -93,16 +106,6 @@ Initialize the `LinearOperator`. **This is a private method for subclass use.** **Subclasses should copy-paste this `__init__` documentation.** -For `X = non_singular, self_adjoint` etc... -`is_X` is a Python `bool` initialization argument with the following meaning -* If `is_X == True`, callers should expect the operator to have the - attribute `X`. This is a promise that should be fulfilled, but is *not* a - runtime assert. Issues, such as floating point error, could mean the - operator violates this promise. -* If `is_X == False`, callers should expect the operator to not have `X`. -* If `is_X == None` (the default), callers should have no expectation either - way. - ##### Args: @@ -113,8 +116,12 @@ For `X = non_singular, self_adjoint` etc... * <b>`is_non_singular`</b>: Expect that this operator is non-singular. * <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian transpose. If `dtype` is real, this is equivalent to being symmetric. -* <b>`is_positive_definite`</b>: Expect that this operator is positive definite. -* <b>`name`</b>: A name for this `LinearOperator`. Default: subclass name. +* <b>`is_positive_definite`</b>: Expect that this operator is positive definite, + meaning the real part of all eigenvalues is positive. We do not require + the operator to be self-adjoint to be positive-definite. See: +* <b>`https`</b>: //en.wikipedia.org/wiki/Positive-definite_matrix\ + #Extension_for_non_symmetric_matrices +* <b>`name`</b>: A name for this `LinearOperator`. ##### Raises: @@ -124,6 +131,23 @@ For `X = non_singular, self_adjoint` etc... - - - +#### `tf.contrib.linalg.LinearOperator.add_to_tensor(x, name='add_to_tensor')` {#LinearOperator.add_to_tensor} + +Add matrix represented by this operator to `x`. Equivalent to `A + x`. + +##### Args: + + +* <b>`x`</b>: `Tensor` with same `dtype` and shape broadcastable to `self.shape`. +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + A `Tensor` with broadcast shape and same `dtype` as `self`. + + +- - - + #### `tf.contrib.linalg.LinearOperator.apply(x, adjoint=False, name='apply')` {#LinearOperator.apply} Transform `x` with left multiplication: `x --> Ax`. @@ -154,6 +178,25 @@ Returns an `Op` that asserts this operator is non singular. Returns an `Op` that asserts this operator is positive definite. +Here, positive definite means the real part of all eigenvalues is positive. +We do not require the operator to be self-adjoint. + +##### Args: + + +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + An `Op` that asserts this operator is positive definite. + + +- - - + +#### `tf.contrib.linalg.LinearOperator.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperator.assert_self_adjoint} + +Returns an `Op` that asserts this operator is self-adjoint. + - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md index 625ff785cb..766341a73f 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.zero_fraction.md @@ -8,7 +8,7 @@ This is useful in summaries to measure and report sparsity. For example, ```python z = tf.Relu(...) - summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z)) + summ = tf.contrib.deprecated.scalar_summary('sparsity', tf.nn.zero_fraction(z)) ``` ##### Args: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md index 5702571441..f998e62046 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md @@ -18,7 +18,7 @@ Example: Print selected summary values. # This example supposes that the events file contains summaries with a # summary value tag 'loss'. These could have been added by calling # `add_summary()`, passing the output of a scalar summary op created with -# with: `tf.scalar_summary(['loss'], loss_tensor)`. +# with: `tf.summary.scalar('loss', loss_tensor)`. for e in tf.train.summary_iterator(path to events file): for v in e.summary.value: if v.tag == 'loss': diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index a4490f22db..f3127013bf 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -687,6 +687,7 @@ * **[Summary Operations](../../api_docs/python/summary.md)**: * [`audio`](../../api_docs/python/summary.md#audio) * [`FileWriter`](../../api_docs/python/summary.md#FileWriter) + * [`FileWriterCache`](../../api_docs/python/summary.md#FileWriterCache) * [`get_summary_description`](../../api_docs/python/summary.md#get_summary_description) * [`histogram`](../../api_docs/python/summary.md#histogram) * [`image`](../../api_docs/python/summary.md#image) @@ -957,7 +958,6 @@ * [`sparse_column_with_hash_bucket`](../../api_docs/python/contrib.layers.md#sparse_column_with_hash_bucket) * [`sparse_column_with_integerized_feature`](../../api_docs/python/contrib.layers.md#sparse_column_with_integerized_feature) * [`sparse_column_with_keys`](../../api_docs/python/contrib.layers.md#sparse_column_with_keys) - * [`stack`](../../api_docs/python/contrib.layers.md#stack) * [`sum_regularizer`](../../api_docs/python/contrib.layers.md#sum_regularizer) * [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation) * [`summarize_activations`](../../api_docs/python/contrib.layers.md#summarize_activations) diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md index 27e8f7ed82..994ff48c7a 100644 --- a/tensorflow/g3doc/api_docs/python/math_ops.md +++ b/tensorflow/g3doc/api_docs/python/math_ops.md @@ -108,7 +108,26 @@ multiply with arbitrary tensors. ### `tf.div(x, y, name=None)` {#div} +Divides x / y elementwise (using Python 2 division operator semantics). +NOTE: Prefer using the Tensor division operator or tf.divide which obey Python +division operator semantics. + +This function divides `x` and `y`, forcing Python 2.7 semantics. That is, +if one of `x` or `y` is a float, then the result will be a float. +Otherwise, the output will be an integer type. Flooring semantics are used +for integer division. + +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `x / y` returns the quotient of x and y. - - - @@ -122,13 +141,16 @@ Computes Python style division of `x` by `y`. ### `tf.truediv(x, y, name=None)` {#truediv} -Divides x / y elementwise, always producing floating point results. +Divides x / y elementwise (using Python 3 division operator semantics). + +NOTE: Prefer using the Tensor operator or tf.divide which obey Python +division operator semantics. -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. +This function forces Python 3 division operator semantics where all integer +arguments are cast to floating types first. This op is generated by normal +`x / y` division in Python 3 and in Python 2.7 with +`from __future__ import division`. If you want integer division that rounds +down, use `x // y` or `tf.floordiv`. `x` and `y` must have the same numeric type. If the inputs are floating point, the output will have the same type. If the inputs are integral, the diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 7e440b887d..d5173bdb19 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -71,7 +71,7 @@ all the variables. You then run that Op after launching the graph. ```python # Add an Op to initialize global variables. -init_op = tf.global_variable_initializers() +init_op = tf.global_variables_initializer() # Launch the graph in a session. with tf.Session() as sess: @@ -169,6 +169,10 @@ Returns the value of the initialized variable. You should use this instead of the variable itself to initialize another variable with a value that depends on the value of this variable. +Beware of using initialized_value except during initialization: +initialized_value causes the Variable's initializer op to be run, so running +this op resets the variable to the initial value. + ```python # Initialize 'v' with a random tensor. v = tf.Variable(tf.truncated_normal([10, 40])) @@ -470,7 +474,18 @@ Returns the truth value of x AND y element-wise. #### `tf.Variable.__div__(a, *args)` {#Variable.__div__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. + +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + `x / y` returns the quotient of x and y. - - - @@ -822,7 +837,18 @@ Returns the truth value of x AND y element-wise. #### `tf.Variable.__rdiv__(a, *args)` {#Variable.__rdiv__} +Divide two values using Python 2 semantics. Used for Tensor.__div__. +##### Args: + + +* <b>`x`</b>: `Tensor` numerator of real numeric type. +* <b>`y`</b>: `Tensor` denominator of real numeric type. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `x / y` returns the quotient of x and y. - - - @@ -966,34 +992,7 @@ Returns x - y element-wise. #### `tf.Variable.__rtruediv__(a, *args)` {#Variable.__rtruediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - - `x / y` evaluated in floating point. - -##### Raises: - - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - @@ -1035,34 +1034,7 @@ Returns x - y element-wise. #### `tf.Variable.__truediv__(a, *args)` {#Variable.__truediv__} -Divides x / y elementwise, always producing floating point results. - -The same as `tf.div` for floating point arguments, but casts integer arguments -to floating point before dividing so that the result is always floating point. -This op is generated by normal `x / y` division in Python 3 and in Python 2.7 -with `from __future__ import division`. If you want integer division that -rounds down, use `x // y` or `tf.floordiv`. - -`x` and `y` must have the same numeric type. If the inputs are floating -point, the output will have the same type. If the inputs are integral, the -inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` -and `int64` (matching the behavior of Numpy). - -##### Args: - - -* <b>`x`</b>: `Tensor` numerator of numeric type. -* <b>`y`</b>: `Tensor` denominator of numeric type. -* <b>`name`</b>: A name for the operation (optional). - -##### Returns: - - `x / y` evaluated in floating point. - -##### Raises: - -* <b>`TypeError`</b>: If `x` and `y` have different dtypes. - - - @@ -2268,7 +2240,7 @@ Returns the current variable scope. - - - -### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, **kwargs)` {#make_template} +### `tf.make_template(name_, func_, create_scope_now_=False, unique_name_=None, custom_getter_=None, **kwargs)` {#make_template} Given an arbitrary function, wrap it so that it does variable sharing. @@ -2359,6 +2331,9 @@ reduce the likelihood of collisions with kwargs. * <b>`unique_name_`</b>: When used, it overrides name_ and is not made unique. If a template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. +* <b>`custom_getter_`</b>: Optional custom getter for variables used in `func_`. See + the [`get_variable`](#get_variable) `custom_getter` documentation for + more information. * <b>`**kwargs`</b>: Keyword arguments to apply to `func_`. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/string_ops.md b/tensorflow/g3doc/api_docs/python/string_ops.md index 86878ca664..7e75148891 100644 --- a/tensorflow/g3doc/api_docs/python/string_ops.md +++ b/tensorflow/g3doc/api_docs/python/string_ops.md @@ -194,7 +194,8 @@ containing the splitted tokens. Empty tokens are ignored. If `delimiter` is an empty string, each element of the `source` is split into individual strings, each containing one byte. (This includes splitting -multibyte sequences of UTF-8.) +multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is +treated as a set of delimiters with each considered a potential split point. For example: N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output @@ -215,17 +216,17 @@ st.values = ['hello', 'world', 'a', 'b', 'c'] * <b>`delimiter`</b>: `0-D` string `Tensor`, the delimiter character, the string should be length 0 or 1. +##### Raises: + + +* <b>`ValueError`</b>: If delimiter is not a string. + ##### Returns: A `SparseTensor` of rank `2`, the strings split according to the delimiter. The first column of the indices corresponds to the row in `source` and the second column corresponds to the index of the split component in this row. -##### Raises: - - -* <b>`ValueError`</b>: If delimiter is not a single-byte character. - - - - diff --git a/tensorflow/g3doc/api_docs/python/summary.md b/tensorflow/g3doc/api_docs/python/summary.md index 90598f7e8c..f20f876ca3 100644 --- a/tensorflow/g3doc/api_docs/python/summary.md +++ b/tensorflow/g3doc/api_docs/python/summary.md @@ -41,7 +41,7 @@ the graph from the session in which you launched it: # Launch the graph in a session. sess = tf.Session() # Create a summary writer, add the 'graph' to the event file. -writer = tf.train.SummaryWriter(<some-directory>, sess.graph) +writer = tf.summary.FileWriter(<some-directory>, sess.graph) ``` The other arguments to the constructor control the asynchronous writes to @@ -202,6 +202,37 @@ Does nothing if the EventFileWriter was not closed. +- - - + +### `class tf.summary.FileWriterCache` {#FileWriterCache} + +Cache for file writers. + +This class caches file writers, one per directory. +- - - + +#### `tf.summary.FileWriterCache.clear()` {#FileWriterCache.clear} + +Clear cached summary writers. Currently only used for unit tests. + + +- - - + +#### `tf.summary.FileWriterCache.get(logdir)` {#FileWriterCache.get} + +Returns the FileWriter for the specified directory. + +##### Args: + + +* <b>`logdir`</b>: str, name of the directory. + +##### Returns: + + A `FileWriter`. + + + ### Summary Ops - - - diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index b336b8cfc9..fb4ff94f4f 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -1372,7 +1372,7 @@ saver.restore(...checkpoint filename...) - - - -#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__} +#### `tf.train.ExponentialMovingAverage.__init__(decay, num_updates=None, zero_debias=False, name='ExponentialMovingAverage')` {#ExponentialMovingAverage.__init__} Creates a new ExponentialMovingAverage object. @@ -1392,6 +1392,8 @@ move faster. If passed, the actual decay rate used is: * <b>`decay`</b>: Float. The decay to use. * <b>`num_updates`</b>: Optional count of number of updates applied to variables. +* <b>`zero_debias`</b>: If `True`, zero debias moving-averages that are initialized + with tensors. * <b>`name`</b>: String. Optional prefix name to use for the name of ops added in `apply()`. @@ -4048,7 +4050,7 @@ This is useful in summaries to measure and report sparsity. For example, ```python z = tf.Relu(...) - summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z)) + summ = tf.contrib.deprecated.scalar_summary('sparsity', tf.nn.zero_fraction(z)) ``` ##### Args: @@ -4122,134 +4124,112 @@ overview of summaries, event files, and visualization in TensorBoard. ### `class tf.train.SummaryWriter` {#SummaryWriter} -Writes `Summary` protocol buffers to event files. - -The `FileWriter` class provides a mechanism to create an event file in a -given directory and add summaries and events to it. The class updates the -file contents asynchronously. This allows a training program to call methods -to add data to the file directly from the training loop, without slowing down -training. - - - -#### `tf.train.SummaryWriter.__init__(logdir, graph=None, max_queue=10, flush_secs=120, graph_def=None)` {#SummaryWriter.__init__} +#### `tf.train.SummaryWriter.__init__(*args, **kwargs)` {#SummaryWriter.__init__} -Creates a `FileWriter` and an event file. +Creates a `SummaryWriter` and an event file. (deprecated) -On construction the summary writer creates a new event file in `logdir`. -This event file will contain `Event` protocol buffers constructed when you -call one of the following functions: `add_summary()`, `add_session_log()`, -`add_event()`, or `add_graph()`. +THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. +Instructions for updating: +Please switch to tf.summary.FileWriter. The interface and behavior is the same; this is just a rename. -If you pass a `Graph` to the constructor it is added to -the event file. (This is equivalent to calling `add_graph()` later). + This class is deprecated, and should be replaced with tf.summary.FileWriter. -TensorBoard will pick the graph from the file and display it graphically so -you can interactively explore the graph you built. You will usually pass -the graph from the session in which you launched it: + On construction the summary writer creates a new event file in `logdir`. + This event file will contain `Event` protocol buffers constructed when you + call one of the following functions: `add_summary()`, `add_session_log()`, + `add_event()`, or `add_graph()`. -```python -...create a graph... -# Launch the graph in a session. -sess = tf.Session() -# Create a summary writer, add the 'graph' to the event file. -writer = tf.train.SummaryWriter(<some-directory>, sess.graph) -``` + If you pass a `Graph` to the constructor it is added to + the event file. (This is equivalent to calling `add_graph()` later). -The other arguments to the constructor control the asynchronous writes to -the event file: + TensorBoard will pick the graph from the file and display it graphically so + you can interactively explore the graph you built. You will usually pass + the graph from the session in which you launched it: -* `flush_secs`: How often, in seconds, to flush the added summaries - and events to disk. -* `max_queue`: Maximum number of summaries or events pending to be - written to disk before one of the 'add' calls block. - -##### Args: + ```python + ...create a graph... + # Launch the graph in a session. + sess = tf.Session() + # Create a summary writer, add the 'graph' to the event file. + writer = tf.train.SummaryWriter(<some-directory>, sess.graph) + ``` + The other arguments to the constructor control the asynchronous writes to + the event file: -* <b>`logdir`</b>: A string. Directory where event file will be written. -* <b>`graph`</b>: A `Graph` object, such as `sess.graph`. -* <b>`max_queue`</b>: Integer. Size of the queue for pending events and summaries. -* <b>`flush_secs`</b>: Number. How often, in seconds, to flush the - pending events and summaries to disk. -* <b>`graph_def`</b>: DEPRECATED: Use the `graph` argument instead. + * `flush_secs`: How often, in seconds, to flush the added summaries + and events to disk. + * `max_queue`: Maximum number of summaries or events pending to be + written to disk before one of the 'add' calls block. + Args: + logdir: A string. Directory where event file will be written. + graph: A `Graph` object, such as `sess.graph`. + max_queue: Integer. Size of the queue for pending events and summaries. + flush_secs: Number. How often, in seconds, to flush the + pending events and summaries to disk. + graph_def: DEPRECATED: Use the `graph` argument instead. - - - -#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary} - -Adds a `Summary` protocol buffer to the event file. - -This method wraps the provided summary in an `Event` protocol buffer -and adds it to the event file. +#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event} -You can pass the result of evaluating any summary op, using -[`Session.run()`](client.md#Session.run) or -[`Tensor.eval()`](framework.md#Tensor.eval), to this -function. Alternatively, you can pass a `tf.Summary` protocol -buffer that you populate with your own data. The latter is -commonly done to report evaluation results in event files. +Adds an event to the event file. ##### Args: -* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string. -* <b>`global_step`</b>: Number. Optional global step value to record with the - summary. +* <b>`event`</b>: An `Event` protocol buffer. - - - -#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log} +#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph} -Adds a `SessionLog` protocol buffer to the event file. +Adds a `Graph` to the event file. -This method wraps the provided session in an `Event` protocol buffer -and adds it to the event file. +The graph described by the protocol buffer will be displayed by +TensorBoard. Most users pass a graph in the constructor instead. ##### Args: -* <b>`session_log`</b>: A `SessionLog` protocol buffer. -* <b>`global_step`</b>: Number. Optional global step value to record with the - summary. - - -- - - - -#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event} - -Adds an event to the event file. +* <b>`graph`</b>: A `Graph` object, such as `sess.graph`. +* <b>`global_step`</b>: Number. Optional global step counter to record with the + graph. +* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead. -##### Args: +##### Raises: -* <b>`event`</b>: An `Event` protocol buffer. +* <b>`ValueError`</b>: If both graph and graph_def are passed to the method. - - - -#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph} +#### `tf.train.SummaryWriter.add_meta_graph(meta_graph_def, global_step=None)` {#SummaryWriter.add_meta_graph} -Adds a `Graph` to the event file. +Adds a `MetaGraphDef` to the event file. -The graph described by the protocol buffer will be displayed by -TensorBoard. Most users pass a graph in the constructor instead. +The `MetaGraphDef` allows running the given graph via +`saver.import_meta_graph()`. ##### Args: -* <b>`graph`</b>: A `Graph` object, such as `sess.graph`. +* <b>`meta_graph_def`</b>: A `MetaGraphDef` object, often as retured by + `saver.export_meta_graph()`. * <b>`global_step`</b>: Number. Optional global step counter to record with the graph. -* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead. ##### Raises: -* <b>`ValueError`</b>: If both graph and graph_def are passed to the method. +* <b>`TypeError`</b>: If both `meta_graph_def` is not an instance of `MetaGraphDef`. - - - @@ -4274,20 +4254,43 @@ Adds a metadata information for a single session.run() call. - - - -#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir} +#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log} -Returns the directory where event file will be written. +Adds a `SessionLog` protocol buffer to the event file. +This method wraps the provided session in an `Event` protocol buffer +and adds it to the event file. + +##### Args: + + +* <b>`session_log`</b>: A `SessionLog` protocol buffer. +* <b>`global_step`</b>: Number. Optional global step value to record with the + summary. - - - -#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush} +#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary} -Flushes the event file to disk. +Adds a `Summary` protocol buffer to the event file. -Call this method to make sure that all pending events have been written to -disk. +This method wraps the provided summary in an `Event` protocol buffer +and adds it to the event file. + +You can pass the result of evaluating any summary op, using +[`Session.run()`](client.md#Session.run) or +[`Tensor.eval()`](framework.md#Tensor.eval), to this +function. Alternatively, you can pass a `tf.Summary` protocol +buffer that you populate with your own data. The latter is +commonly done to report evaluation results in event files. + +##### Args: + + +* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string. +* <b>`global_step`</b>: Number. Optional global step value to record with the + summary. - - - @@ -4299,8 +4302,23 @@ Flushes the event file to disk and close the file. Call this method when you do not need the summary writer anymore. +- - - + +#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush} + +Flushes the event file to disk. + +Call this method to make sure that all pending events have been written to +disk. + + +- - - + +#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir} + +Returns the directory where event file will be written. + -#### Other Methods - - - #### `tf.train.SummaryWriter.reopen()` {#SummaryWriter.reopen} @@ -4318,9 +4336,9 @@ Does nothing if the EventFileWriter was not closed. ### `class tf.train.SummaryWriterCache` {#SummaryWriterCache} -Cache for summary writers. +Cache for file writers. -This class caches summary writers, one per directory. +This class caches file writers, one per directory. - - - #### `tf.train.SummaryWriterCache.clear()` {#SummaryWriterCache.clear} @@ -4332,7 +4350,7 @@ Clear cached summary writers. Currently only used for unit tests. #### `tf.train.SummaryWriterCache.get(logdir)` {#SummaryWriterCache.get} -Returns the SummaryWriter for the specified directory. +Returns the FileWriter for the specified directory. ##### Args: @@ -4341,7 +4359,7 @@ Returns the SummaryWriter for the specified directory. ##### Returns: - A `SummaryWriter`. + A `FileWriter`. @@ -4367,7 +4385,7 @@ Example: Print selected summary values. # This example supposes that the events file contains summaries with a # summary value tag 'loss'. These could have been added by calling # `add_summary()`, passing the output of a scalar summary op created with -# with: `tf.scalar_summary(['loss'], loss_tensor)`. +# with: `tf.summary.scalar('loss', loss_tensor)`. for e in tf.train.summary_iterator(path to events file): for v in e.summary.value: if v.tag == 'loss': diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 875f19be74..f4177dc47a 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -293,7 +293,7 @@ packages needed by TensorFlow. * Activate the conda environment and install TensorFlow in it. * After the install you will activate the conda environment each time you want to use TensorFlow. -* Optionally install ipython and other packages into the conda environment +* Optionally install ipython and other packages into the conda environment. Install Anaconda: diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md index 15a1e68d5f..88d0cf9e1c 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/index.md +++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md @@ -37,7 +37,7 @@ any [attrs](#attrs) the Op might require. To see how this works, suppose you'd like to create an Op that takes a tensor of `int32`s and outputs a copy of the tensor, with all but the first element set to -zero. Create file [`tensorflow/core/user_ops`][user_ops]`/zero_out.cc` and +zero. Create file `tensorflow/core/user_ops/zero_out.cc` and add a call to the `REGISTER_OP` macro that defines the interface for such an Op: ```c++ @@ -321,11 +321,10 @@ using the `Attr` method, which expects a spec of the form: where `<name>` begins with a letter and can be composed of alphanumeric characters and underscores, and `<attr-type-expr>` is a type expression of the -form [described below](#attr-types) +form [described below](#attr-types). For example, if you'd like the `ZeroOut` Op to preserve a user-specified index, instead of only the 0th element, you can register the Op like so: - <code class="lang-c++"><pre> REGISTER\_OP("ZeroOut") <b>.Attr("preserve\_index: int")</b> @@ -335,7 +334,6 @@ REGISTER\_OP("ZeroOut") Your kernel can then access this attr in its constructor via the `context` parameter: - <code class="lang-c++"><pre> class ZeroOutOp : public OpKernel { public: @@ -357,7 +355,6 @@ class ZeroOutOp : public OpKernel { </pre></code> which can then be used in the `Compute` method: - <code class="lang-c++"><pre> void Compute(OpKernelContext\* context) override { // ... @@ -512,7 +509,6 @@ you would then register an `OpKernel` for each supported type. For instance, if you'd like the `ZeroOut` Op to work on `float`s in addition to `int32`s, your Op registration might look like: - <code class="lang-c++"><pre> REGISTER\_OP("ZeroOut") <b>.Attr("T: {float, int32}")</b> @@ -632,7 +628,6 @@ REGISTER\_KERNEL\_BUILDER( > </pre></code> Lets say you wanted to add more types, say `double`: - <code class="lang-c++"><pre> REGISTER\_OP("ZeroOut") <b>.Attr("T: {float, <b>double,</b> int32}")</b> @@ -643,7 +638,6 @@ REGISTER\_OP("ZeroOut") Instead of writing another `OpKernel` with redundant code as above, often you will be able to use a C++ template instead. You will still have one kernel registration (`REGISTER_KERNEL_BUILDER` call) per overload. - <code class="lang-c++"><pre> <b>template <typename T></b> class ZeroOutOp : public OpKernel { diff --git a/tensorflow/g3doc/how_tos/graph_viz/index.md b/tensorflow/g3doc/how_tos/graph_viz/index.md index c94afd70b5..d09769e274 100644 --- a/tensorflow/g3doc/how_tos/graph_viz/index.md +++ b/tensorflow/g3doc/how_tos/graph_viz/index.md @@ -33,9 +33,9 @@ with tf.name_scope('hidden') as scope: This results in the following three op names: -* *hidden*/alpha -* *hidden*/weights -* *hidden*/biases +* `hidden/alpha` +* `hidden/weights` +* `hidden/biases` By default, the visualization will collapse all three into a node labeled `hidden`. The extra detail isn't lost. You can double-click, or click @@ -253,7 +253,7 @@ The images below show the CIFAR-10 model with tensor shape information: Often it is useful to collect runtime metadata for a run, such as total memory usage, total compute time, and tensor shapes for nodes. The code example below is a snippet from the train and test section of a modification of the -[simple MNIST tutorial](http://tensorflow.org/tutorials/mnist/beginners/index.md), +[simple MNIST tutorial](../../tutorials/mnist/beginners/index.md), in which we have recorded summaries and runtime statistics. See the [Summaries Tutorial](../../how_tos/summaries_and_tensorboard/index.md#serializing-the-data) for details on how to record summaries. Full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py). diff --git a/tensorflow/g3doc/how_tos/hadoop/index.md b/tensorflow/g3doc/how_tos/hadoop/index.md index f55d6d182f..a2dd67babd 100644 --- a/tensorflow/g3doc/how_tos/hadoop/index.md +++ b/tensorflow/g3doc/how_tos/hadoop/index.md @@ -29,7 +29,7 @@ be set: set this environment variable by running: ```shell -source $HADOOP_HOME/libexec/hadoop-config.sh +source ${HADOOP_HOME}/libexec/hadoop-config.sh ``` * **LD_LIBRARY_PATH**: To include the path to libjvm.so, and optionally the path @@ -37,16 +37,16 @@ source $HADOOP_HOME/libexec/hadoop-config.sh `$HADOOP_HDFS_HOME/lib/native`. On Linux: ```shell -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$JAVA_HOME/jre/lib/amd64/server +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${JAVA_HOME}/jre/lib/amd64/server ``` * **CLASSPATH**: The Hadoop jars must be added prior to running your TensorFlow program. The CLASSPATH set by - `$HADOOP_HOME/libexec/hadoop-config.sh` is insufficient. Globs must be + `${HADOOP_HOME}/libexec/hadoop-config.sh` is insufficient. Globs must be expanded as described in the libhdfs documentation: ```shell -CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) python your_script.py +CLASSPATH=$($HADOOP_HDFS_HOME}/bin/hadoop classpath --glob) python your_script.py ``` If you are running [Distributed TensorFlow](../distributed/index.md), then all diff --git a/tensorflow/g3doc/how_tos/image_retraining/index.md b/tensorflow/g3doc/how_tos/image_retraining/index.md index c6a0467fb3..d721f61810 100644 --- a/tensorflow/g3doc/how_tos/image_retraining/index.md +++ b/tensorflow/g3doc/how_tos/image_retraining/index.md @@ -290,11 +290,32 @@ usual split is to put 80% of the images into the main training set, keep 10% aside to run as validation frequently during training, and then have a final 10% that are used less often as a testing set to predict the real-world performance of the classifier. These ratios can be controlled using the -`--testing_percentage` and `--validation_percentage` flags. One subtle thing -that the script does is it uses the filename of the image to determine which set -it is put into. This is designed to ensure that images don't get moved between -training and testing sets on different runs, since that could be a problem if -images that had been used for training a model were subsequently used in a -validation set. In general you should be able to leave these values at their -defaults, since you won't usually find any advantage to training to adjusting -them. +`--testing_percentage` and `--validation_percentage` flags. In general +you should be able to leave these values at their defaults, since you won't +usually find any advantage to training to adjusting them. + +Note that the script uses the image filenames (rather than a completely random +function) to divide the images among the training, validation, and test sets. +This is done to ensure that images don't get moved between training and testing +sets on different runs, since that could be a problem if images that had been +used for training a model were subsequently used in a validation set. + +You might notice that the validation accuracy fluctuates among iterations. Much +of this fluctuation arises from the fact that a random subset of the validation +set is chosen for each validation accuracy measurement. The fluctuations can be +greatly reduced, at the cost of some increase in training time, by choosing +`--validation_batch_size=-1`, which uses the entire validation set for each +accuracy computation. + +Once training is complete, you may find it insightful to examine misclassified +images in the test set. This can be done by adding the flag +`--print_misclassified_test_images`. This may help you get a feeling for which +types of images were most confusing for the model, and which categories were +most difficult to distinguish. For instance, you might discover that some +subtype of a particular category, or some unusual photo angle, is particularly +difficult to identify, which may encourage you to add more training images of +that subtype. Oftentimes, examining misclassified images can also point to +errors in the input data set, such as mislabeled, low-quality, or ambiguous +images. However, one should generally avoid point-fixing individual errors in +the test set, since they are likely to merely reflect more general problems in +the (much larger) training set. diff --git a/tensorflow/g3doc/how_tos/threading_and_queues/index.md b/tensorflow/g3doc/how_tos/threading_and_queues/index.md index 46444a02db..639ad116c9 100644 --- a/tensorflow/g3doc/how_tos/threading_and_queues/index.md +++ b/tensorflow/g3doc/how_tos/threading_and_queues/index.md @@ -28,7 +28,7 @@ creating these operations. Now that you have a bit of a feel for queues, let's dive into the details... -## Queue Use Overview +## Queue use overview Queues, such as `FIFOQueue` and `RandomShuffleQueue`, are important TensorFlow objects for computing tensors asynchronously in a graph. @@ -149,7 +149,7 @@ coord.request_stop() coord.join(enqueue_threads) ``` -## Handling Exceptions +## Handling exceptions Threads started by queue runners do more than just run the enqueue ops. They also catch and handle exceptions generated by queues, including diff --git a/tensorflow/g3doc/how_tos/variable_scope/index.md b/tensorflow/g3doc/how_tos/variable_scope/index.md index 4e01ce1259..bb1b3e53f4 100644 --- a/tensorflow/g3doc/how_tos/variable_scope/index.md +++ b/tensorflow/g3doc/how_tos/variable_scope/index.md @@ -69,7 +69,7 @@ def my_image_filter(input_images, variables_dict): strides=[1, 1, 1, 1], padding='SAME') return tf.nn.relu(conv2 + variables_dict["conv2_biases"]) -# The 2 calls to my_image_filter() now use the same variables +# Both calls to my_image_filter() now use the same variables result1 = my_image_filter(image1, variables_dict) result2 = my_image_filter(image2, variables_dict) ``` @@ -90,7 +90,7 @@ while constructing a graph. ## Variable Scope Example -Variable Scope mechanism in TensorFlow consists of 2 main functions: +Variable Scope mechanism in TensorFlow consists of two main functions: * `tf.get_variable(<name>, <shape>, <initializer>)`: Creates or returns a variable with a given name. @@ -280,9 +280,9 @@ when opening a new variable scope. ```python with tf.variable_scope("foo") as foo_scope: v = tf.get_variable("v", [1]) -with tf.variable_scope(foo_scope) +with tf.variable_scope(foo_scope): w = tf.get_variable("w", [1]) -with tf.variable_scope(foo_scope, reuse=True) +with tf.variable_scope(foo_scope, reuse=True): v1 = tf.get_variable("v", [1]) w1 = tf.get_variable("w", [1]) assert v1 is v @@ -296,7 +296,7 @@ different one. This is fully independent of where we do it. ```python with tf.variable_scope("foo") as foo_scope: assert foo_scope.name == "foo" -with tf.variable_scope("bar") +with tf.variable_scope("bar"): with tf.variable_scope("baz") as other_scope: assert other_scope.name == "bar/baz" with tf.variable_scope(foo_scope) as foo_scope2: diff --git a/tensorflow/g3doc/tutorials/seq2seq/index.md b/tensorflow/g3doc/tutorials/seq2seq/index.md index 7e8c3cb929..4cfcc56b29 100644 --- a/tensorflow/g3doc/tutorials/seq2seq/index.md +++ b/tensorflow/g3doc/tutorials/seq2seq/index.md @@ -35,7 +35,7 @@ File | What's in it? `models/rnn/translate/translate.py` | Binary that trains and runs the translation model. -## Sequence-to-Sequence Basics +## Sequence-to-sequence basics A basic sequence-to-sequence model, as introduced in [Cho et al., 2014](http://arxiv.org/abs/1406.1078) @@ -69,7 +69,7 @@ attention mechanism in the decoder looks like this. <img style="width:100%" src="../../images/attention_seq2seq.png" /> </div> -## TensorFlow seq2seq Library +## TensorFlow seq2seq library As you can see above, there are many different sequence-to-sequence models. Each of these models can use different RNN cells, but all @@ -148,7 +148,7 @@ more sequence-to-sequence models in `seq2seq.py`, take a look there. They all have similar interfaces, so we will not describe them in detail. We will use `embedding_attention_seq2seq` for our translation model below. -## Neural Translation Model +## Neural translation model While the core of the sequence-to-sequence model is constructed by the functions in `python/ops/seq2seq.py`, there are still a few tricks @@ -238,7 +238,7 @@ with encoder inputs representing `[PAD PAD "." "go" "I"]` and decoder inputs `[GO "Je" "vais" "." EOS PAD PAD PAD PAD PAD]`. -## Let's Run It +## Let's run it To train the model described above, we need to a large English-French corpus. We will use the *10^9-French-English corpus* from the @@ -312,7 +312,7 @@ Reading model parameters from /tmp/translate.ckpt-340000 Qui est le président des États-Unis ? ``` -## What Next? +## What next? The example above shows how you can build your own English-to-French translator, end-to-end. Run it and see how the model performs for yourself. diff --git a/tensorflow/g3doc/tutorials/wide/index.md b/tensorflow/g3doc/tutorials/wide/index.md index 4d76f85628..d30ad11374 100644 --- a/tensorflow/g3doc/tutorials/wide/index.md +++ b/tensorflow/g3doc/tutorials/wide/index.md @@ -63,8 +63,8 @@ import tempfile import urllib train_file = tempfile.NamedTemporaryFile() test_file = tempfile.NamedTemporaryFile() -urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) -urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) +urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) +urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) ``` Once the CSV files are downloaded, let's read them into diff --git a/tensorflow/g3doc/tutorials/wide_and_deep/index.md b/tensorflow/g3doc/tutorials/wide_and_deep/index.md index f1928bdca4..4928dd41a3 100644 --- a/tensorflow/g3doc/tutorials/wide_and_deep/index.md +++ b/tensorflow/g3doc/tutorials/wide_and_deep/index.md @@ -215,8 +215,8 @@ CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss", # test_file to your own paths. train_file = tempfile.NamedTemporaryFile() test_file = tempfile.NamedTemporaryFile() -urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) -urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) +urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) +urllib.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # Read the training and test data sets into Pandas dataframe. df_train = pd.read_csv(train_file, names=COLUMNS, skipinitialspace=True) diff --git a/tensorflow/g3doc/tutorials/word2vec/index.md b/tensorflow/g3doc/tutorials/word2vec/index.md index 15653474df..936cb24a23 100644 --- a/tensorflow/g3doc/tutorials/word2vec/index.md +++ b/tensorflow/g3doc/tutorials/word2vec/index.md @@ -102,7 +102,7 @@ $$ \begin{align} P(w_t | h) &= \text{softmax}(\text{score}(w_t, h)) \\ &= \frac{\exp \{ \text{score}(w_t, h) \} } - {\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} }. + {\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} } \end{align} $$ @@ -115,7 +115,7 @@ $$ \begin{align} J_\text{ML} &= \log P(w_t | h) \\ &= \text{score}(w_t, h) - - \log \left( \sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} \right) + \log \left( \sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} \right). \end{align} $$ diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md index 9817469006..5e08372cf9 100644 --- a/tensorflow/go/README.md +++ b/tensorflow/go/README.md @@ -76,5 +76,4 @@ go test -v github.com/tensorflow/tensorflow/tensorflow/go This API has been built on top of the [C API](https://www.tensorflow.org/code/tensorflow/c/c_api.h), which is intended for building language bindings for TensorFlow functionality. -However, this is far from complete. Contributions are welcome. To monitor -progress follow [issue 10](https://github.com/tensorflow/tensorflow/issues/10). +However, this is far from complete. Contributions are welcome. diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD new file mode 100644 index 0000000000..2aa077c2b7 --- /dev/null +++ b/tensorflow/java/BUILD @@ -0,0 +1,61 @@ +# Description: +# TensorFlow Java API. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +java_library( + name = "tensorflow", + srcs = glob(["src/main/java/org/tensorflow/*.java"]), + data = [":libtensorflow-jni"], + visibility = ["//visibility:public"], +) + +java_test( + name = "TensorFlowTest", + srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"], + test_class = "org.tensorflow.TensorFlowTest", + deps = [ + ":tensorflow", + "//external:junit", + ], +) + +filegroup( + name = "libtensorflow-jni", + srcs = select({ + "//tensorflow:darwin": [":libtensorflow-jni.dylib"], + "//conditions:default": [":libtensorflow-jni.so"], + }), +) + +cc_binary( + name = "libtensorflow-jni.so", + linkshared = 1, + linkstatic = 1, + deps = ["//tensorflow/java/src/main/native"], +) + +# System.loadLibrary() on OS X looks for ".dylib" or ".jnilib" +# and no ".so". If and when https://github.com/bazelbuild/bazel/issues/914 +# is resolved, perhaps this workaround rule can be removed. +genrule( + name = "darwin-compat", + srcs = [":libtensorflow-jni.so"], + outs = ["libtensorflow-jni.dylib"], + cmd = "cp $< $@", + output_to_bindir = 1, +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md new file mode 100644 index 0000000000..d9bee5e342 --- /dev/null +++ b/tensorflow/java/README.md @@ -0,0 +1,72 @@ +# TensorFlow for Java + +Java bindings for TensorFlow. + +> *WARNING*: The TensorFlow Java API is incomplete and experimental and can +> change without notice. Progress can be followed in +> [issue #5](https://github.com/tensorflow/tensorflow/issues/5). +> +> Till then, for using TensorFlow on Android refer to +> [contrib/android](https://www.tensorflow.org/code/tensorflow/contrib/android), +> [makefile](https://www.tensorflow.org/code/tensorflow/contrib/makefile#android) +> and/or the [Android camera +> demo](https://www.tensorflow.org/code/tensorflow/examples/android). + +## Requirements + +- [bazel](https://www.bazel.build/versions/master/docs/install.html) +- Environment to build TensorFlow from source code + ([Linux](https://www.tensorflow.org/versions/master/get_started/os_setup.html#prepare-environment-for-linux) + or [Mac OS + X](https://www.tensorflow.org/versions/master/get_started/os_setup.html#prepare-environment-for-mac-os-x)). + If you'd like to skip reading those details and do not care about GPU + support, try the following: + + ```sh + # On Linux + sudo apt-get install python swig python-numpy + + # On Mac OS X with homebrew + brew install swig + ``` + +## Installation + +Build the Java Archive and native library: + +```sh +bazel build -c opt \ + //tensorflow/java:libtensorflow.jar \ + //tensorflow/java:libtensorflow-jni +``` + +## Example Usage + +### With bazel + +Add a dependency on `//tensorflow/java:tensorflow` to the `java_binary` or +`java_library` rule. For example: + +```sh +bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example +``` + +### With `javac` + +- Add `libtensorflow.jar` to classpath for compilation. For example: + + ```sh + javac \ + -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar \ + ./src/main/java/org/tensorflow/examples/Example.java + ``` + +- Make `libtensorflow.jar` and `libtensorflow-jni.so` + (`libtensorflow-jni.dylib` on OS X) available during execution. For example: + + ```sh + java \ + -Djava.library.path=../../bazel-bin/tensorflow/java \ + -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar:./src/main/java \ + org.tensorflow.examples.Example + ``` diff --git a/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java new file mode 100644 index 0000000000..dc7f87b928 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java @@ -0,0 +1,28 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** Static utility methods describing the TensorFlow runtime. */ +public final class TensorFlow { + private TensorFlow() {} + + static { + System.loadLibrary("tensorflow-jni"); + } + + /** Returns the version of the underlying TensorFlow runtime. */ + public static native String getVersion(); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/.gitignore b/tensorflow/java/src/main/java/org/tensorflow/examples/.gitignore new file mode 100644 index 0000000000..8dc1579ef5 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/examples/.gitignore @@ -0,0 +1,3 @@ +# .class files generated when building examples using javac +# as described in README.md +*.class diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD new file mode 100644 index 0000000000..529287a038 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD @@ -0,0 +1,25 @@ +# Description: +# TensorFlow Java examples. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +java_binary( + name = "example", + srcs = ["Example.java"], + main_class = "org.tensorflow.examples.Example", + deps = ["//tensorflow/java:tensorflow"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java b/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java new file mode 100644 index 0000000000..f61c44b4ab --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.examples; + +import org.tensorflow.TensorFlow; + +/** + * Sample usage of the TensorFlow Java library. + * + * <p>This sample should become more useful as functionality is added to the API. + */ +public class Example { + public static void main(String[] args) { + System.out.println("TensorFlow version: " + TensorFlow.getVersion()); + } +} diff --git a/tensorflow/java/src/main/native/BUILD b/tensorflow/java/src/main/native/BUILD new file mode 100644 index 0000000000..3a2d0cbbfb --- /dev/null +++ b/tensorflow/java/src/main/native/BUILD @@ -0,0 +1,66 @@ +# Description: +# Java Native Interface (JNI) library intended for implementing the +# TensorFlow Java API using the TensorFlow C library. + +package(default_visibility = ["//tensorflow/java:__pkg__"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cuda_library") + +tf_cuda_library( + name = "native", + srcs = [ + "tensorflow.cc", + ":jni.h", + ":jni_md.h", + ], + hdrs = ["tensorflow.h"], + includes = ["."], + deps = [ + "//tensorflow/c:c_api", + ], + alwayslink = 1, +) + +# Silly rules to make +# #include <jni.h> +# in the source headers work +# (in combination with the "includes" attribute of the tf_cuda_library rule +# above). +# +# Inspired from: +# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD +# but hopefully there is a simpler alternative to this. +# +# TODO(ashankar): This should not be necessary for Android builds as the +# toolchain makes <jni.h> available. Perhaps remove ":jni.h" and ":jni_md.h" +# from "srcs" and make these genrules a no-op when building for Android? +genrule( + name = "copy_jni_h", + srcs = ["@bazel_tools//tools/jdk:jni_header"], + outs = ["jni.h"], + cmd = "cp -f $< $@", +) + +genrule( + name = "copy_jni_md_h", + srcs = select({ + "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], + }), + outs = ["jni_md.h"], + cmd = "cp -f $< $@", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/python/client/net_lib.i b/tensorflow/java/src/main/native/tensorflow.cc index 333e2abbc5..55de5771dd 100644 --- a/tensorflow/python/client/net_lib.i +++ b/tensorflow/java/src/main/native/tensorflow.cc @@ -13,18 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -%include "tensorflow/python/platform/base.i" +#include "tensorflow/java/src/main/native/tensorflow.h" +#include "tensorflow/c/c_api.h" -%{ -#include "tensorflow/core/platform/net.h" -%} - -%ignoreall - -%unignore tensorflow; -%unignore tensorflow::internal; -%unignore tensorflow::internal::PickUnusedPortOrDie; - -%include "tensorflow/core/platform/net.h" - -%unignoreall +JNIEXPORT jstring JNICALL +Java_org_tensorflow_TensorFlow_getVersion(JNIEnv* env, jclass clazz) { + return env->NewStringUTF(TF_Version()); +} diff --git a/tensorflow/java/src/main/native/tensorflow.h b/tensorflow/java/src/main/native/tensorflow.h new file mode 100644 index 0000000000..897a000ac0 --- /dev/null +++ b/tensorflow/java/src/main/native/tensorflow.h @@ -0,0 +1,36 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_JNI_H_ +#define TENSORFLOW_JAVA_JNI_H_ + +#include <jni.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: TensorFlow + * Method: getVersion + * Signature: ()Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_getVersion(JNIEnv*, + jclass); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_JNI_H_ diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java new file mode 100644 index 0000000000..94fd0582c1 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java @@ -0,0 +1,31 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.junit.Assert.assertTrue; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.TensorFlow}. */ +@RunWith(JUnit4.class) +public class TensorFlowTest { + @Test + public void version() { + assertTrue(TensorFlow.getVersion().length() > 0); + } +} diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py index e463e300c1..4b36554716 100644 --- a/tensorflow/models/embedding/word2vec.py +++ b/tensorflow/models/embedding/word2vec.py @@ -365,7 +365,7 @@ class Word2Vec(object): self._word2id[w] = i true_logits, sampled_logits = self.forward(examples, labels) loss = self.nce_loss(true_logits, sampled_logits) - tf.scalar_summary("NCE loss", loss) + tf.contrib.deprecated.scalar_summary("NCE loss", loss) self._loss = loss self.optimize(loss) @@ -396,8 +396,8 @@ class Word2Vec(object): initial_epoch, initial_words = self._session.run([self._epoch, self._words]) - summary_op = tf.merge_all_summaries() - summary_writer = tf.train.SummaryWriter(opts.save_path, self._session.graph) + summary_op = tf.contrib.deprecated.merge_all_summaries() + summary_writer = tf.summary.FileWriter(opts.save_path, self._session.graph) workers = [] for _ in xrange(opts.concurrent_steps): t = threading.Thread(target=self._train_thread_body) diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py index 1c51b76f09..55c34ba84b 100644 --- a/tensorflow/models/image/cifar10/cifar10.py +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -91,8 +91,9 @@ def _activation_summary(x): # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training # session. This helps the clarity of presentation on tensorboard. tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name) - tf.histogram_summary(tensor_name + '/activations', x) - tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) + tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x) + tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity', + tf.nn.zero_fraction(x)) def _variable_on_cpu(name, shape, initializer): @@ -316,8 +317,8 @@ def _add_loss_summaries(total_loss): for l in losses + [total_loss]: # Name each loss as '(raw)' and name the moving average version of the loss # as the original loss name. - tf.scalar_summary(l.op.name +' (raw)', l) - tf.scalar_summary(l.op.name, loss_averages.average(l)) + tf.contrib.deprecated.scalar_summary(l.op.name + ' (raw)', l) + tf.contrib.deprecated.scalar_summary(l.op.name, loss_averages.average(l)) return loss_averages_op @@ -345,7 +346,7 @@ def train(total_loss, global_step): decay_steps, LEARNING_RATE_DECAY_FACTOR, staircase=True) - tf.scalar_summary('learning_rate', lr) + tf.contrib.deprecated.scalar_summary('learning_rate', lr) # Generate moving averages of all losses and associated summaries. loss_averages_op = _add_loss_summaries(total_loss) @@ -360,12 +361,12 @@ def train(total_loss, global_step): # Add histograms for trainable variables. for var in tf.trainable_variables(): - tf.histogram_summary(var.op.name, var) + tf.contrib.deprecated.histogram_summary(var.op.name, var) # Add histograms for gradients. for grad, var in grads: if grad is not None: - tf.histogram_summary(var.op.name + '/gradients', grad) + tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients', grad) # Track the moving averages of all trainable variables. variable_averages = tf.train.ExponentialMovingAverage( @@ -394,5 +395,5 @@ def maybe_download_and_extract(): print() statinfo = os.stat(filepath) print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') - + tarfile.open(filepath, 'r:gz').extractall(dest_directory) diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py index 19bf74c477..c2329380d6 100644 --- a/tensorflow/models/image/cifar10/cifar10_eval.py +++ b/tensorflow/models/image/cifar10/cifar10_eval.py @@ -134,9 +134,9 @@ def evaluate(): saver = tf.train.Saver(variables_to_restore) # Build the summary operation based on the TF collection of Summaries. - summary_op = tf.merge_all_summaries() + summary_op = tf.contrib.deprecated.merge_all_summaries() - summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g) + summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g) while True: eval_once(saver, summary_writer, top_k_op, summary_op) diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py index 14ea94d72f..b00859b262 100644 --- a/tensorflow/models/image/cifar10/cifar10_input.py +++ b/tensorflow/models/image/cifar10/cifar10_input.py @@ -130,7 +130,7 @@ def _generate_image_and_label_batch(image, label, min_queue_examples, capacity=min_queue_examples + 3 * batch_size) # Display the training images in the visualizer. - tf.image_summary('images', images) + tf.contrib.deprecated.image_summary('images', images) return images, tf.reshape(label_batch, [batch_size]) diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py index 53ae7d5c74..a59e13d5e3 100644 --- a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py +++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py @@ -93,7 +93,7 @@ def tower_loss(scope): # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training # session. This helps the clarity of presentation on tensorboard. loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name) - tf.scalar_summary(loss_name, l) + tf.contrib.deprecated.scalar_summary(loss_name, l) return total_loss @@ -187,20 +187,22 @@ def train(): grads = average_gradients(tower_grads) # Add a summary to track the learning rate. - summaries.append(tf.scalar_summary('learning_rate', lr)) + summaries.append(tf.contrib.deprecated.scalar_summary('learning_rate', lr)) # Add histograms for gradients. for grad, var in grads: if grad is not None: summaries.append( - tf.histogram_summary(var.op.name + '/gradients', grad)) + tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients', + grad)) # Apply the gradients to adjust the shared variables. apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) # Add histograms for trainable variables. for var in tf.trainable_variables(): - summaries.append(tf.histogram_summary(var.op.name, var)) + summaries.append( + tf.contrib.deprecated.histogram_summary(var.op.name, var)) # Track the moving averages of all trainable variables. variable_averages = tf.train.ExponentialMovingAverage( @@ -214,7 +216,7 @@ def train(): saver = tf.train.Saver(tf.all_variables()) # Build the summary operation from the last tower summaries. - summary_op = tf.merge_summary(summaries) + summary_op = tf.contrib.deprecated.merge_summary(summaries) # Build an initialization operation to run below. init = tf.global_variables_initializer() @@ -230,7 +232,7 @@ def train(): # Start the queue runners. tf.train.start_queue_runners(sess=sess) - summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) + summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph) for step in xrange(FLAGS.max_steps): start_time = time.time() diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py index 45c0bbd9f0..eab499fc3e 100644 --- a/tensorflow/models/image/cifar10/cifar10_train.py +++ b/tensorflow/models/image/cifar10/cifar10_train.py @@ -118,3 +118,4 @@ def main(argv=None): # pylint: disable=unused-argument if __name__ == '__main__': tf.app.run() + diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py index 020edbd5e5..45fb1e774a 100644 --- a/tensorflow/models/rnn/ptb/ptb_word_lm.py +++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py @@ -328,14 +328,14 @@ def main(_): train_input = PTBInput(config=config, data=train_data, name="TrainInput") with tf.variable_scope("Model", reuse=None, initializer=initializer): m = PTBModel(is_training=True, config=config, input_=train_input) - tf.scalar_summary("Training Loss", m.cost) - tf.scalar_summary("Learning Rate", m.lr) + tf.contrib.deprecated.scalar_summary("Training Loss", m.cost) + tf.contrib.deprecated.scalar_summary("Learning Rate", m.lr) with tf.name_scope("Valid"): valid_input = PTBInput(config=config, data=valid_data, name="ValidInput") with tf.variable_scope("Model", reuse=True, initializer=initializer): mvalid = PTBModel(is_training=False, config=config, input_=valid_input) - tf.scalar_summary("Validation Loss", mvalid.cost) + tf.contrib.deprecated.scalar_summary("Validation Loss", mvalid.cost) with tf.name_scope("Test"): test_input = PTBInput(config=eval_config, data=test_data, name="TestInput") diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e1dda0674b..f1fae16bb0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1204,7 +1204,10 @@ py_library( py_library( name = "rnn_cell", - srcs = ["ops/rnn_cell.py"], + srcs = [ + "ops/rnn_cell.py", + "ops/rnn_cell_impl.py", + ], srcs_version = "PY2AND3", deps = [ ":array_ops", @@ -1906,28 +1909,6 @@ cuda_py_tests( ], ) -py_library( - name = "net_lib", - testonly = 1, - srcs = ["util/net_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":pywrap_tensorflow", - ], -) - -py_tests( - name = "net_lib_test", - size = "small", - srcs = [ - "util/net_lib_test.py", - ], - additional_deps = [ - ":net_lib", - "//tensorflow:tensorflow_py", - ], -) - tf_cuda_library( name = "tf_session_helper", srcs = ["client/tf_session_helper.cc"], @@ -1954,7 +1935,6 @@ tf_py_wrap_cc( swig_includes = [ "client/device_lib.i", "client/events_writer.i", - "client/net_lib.i", "client/quantize_training.i", "client/tf_session.i", "framework/cpp_shape_inference.i", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index ae8d0e02f1..2a7a76c396 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -82,7 +82,7 @@ from tensorflow.python.ops.standard_ops import * # pylint: enable=wildcard-import # Bring in subpackages. -from tensorflow.python import layers +from tensorflow.python.layers import layers from tensorflow.python.ops import nn from tensorflow.python.ops import resources from tensorflow.python.ops import sdca_ops as sdca diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 71c931037e..591cc5afbc 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -364,6 +364,7 @@ class _DictFetchMapper(_FetchMapper): Args: fetches: Dict of fetches. """ + self._fetch_type = type(fetches) self._keys = fetches.keys() self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches.values()] @@ -373,7 +374,7 @@ class _DictFetchMapper(_FetchMapper): return self._unique_fetches def build_results(self, values): - results = {} + results = self._fetch_type() for k, m, vi in zip(self._keys, self._mappers, self._value_indices): results[k] = m.build_results([values[j] for j in vi]) return results @@ -661,8 +662,8 @@ class BaseSession(SessionInterface): `feed_dict` for the corresponding input values. The `fetches` argument may be a single graph element, or an arbitrarily - nested list, tuple, namedtuple, or dict containing graph elements at its - leaves. A graph element can be one of the following types: + nested list, tuple, namedtuple, dict, or OrderedDict containing graph + elements at its leaves. A graph element can be one of the following types: * An [`Operation`](../../api_docs/python/framework.md#Operation). The corresponding fetched value will be `None`. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index a20376b91d..0c602a9014 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -254,6 +254,18 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(None, res['b']) self.assertEqual(44.0, res['c']) + def testFetchOrderedDict(self): + with session.Session() as sess: + a = constant_op.constant(42.0) + b = control_flow_ops.no_op() # An op, not a tensor. + c = constant_op.constant(44.0) + res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)])) + self.assertTrue(isinstance(res, collections.OrderedDict)) + self.assertEqual([3, 2, 1], list(res.keys())) + self.assertEqual(42.0, res[3]) + self.assertEqual(None, res[2]) + self.assertEqual(44.0, res[1]) + def testFetchNestingEmptyOneLevel(self): with session.Session() as sess: a_val = 11.0 diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py index ea64d74d6d..f3ba4244ce 100644 --- a/tensorflow/python/client/timeline.py +++ b/tensorflow/python/client/timeline.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import copy import json +import re # The timeline target is usually imported as part of BUILD target # "platform_test", which includes also includes the "platform" @@ -384,12 +385,15 @@ class Timeline(object): def _parse_op_label(self, label): """Parses the fields in a node timeline label.""" - nn, rest = label.split(' = ') - op, rest = rest.split('(') - if rest == ')': + # Expects labels of the form: name = op(arg, arg, ...). + match = re.match(r'(.*) = (.*)\((.*)\)', label) + if match is None: + return 'unknown', 'unknown', [] + nn, op, inputs = match.groups() + if not inputs: inputs = [] else: - inputs = rest[:-1].split(', ') + inputs = inputs.split(', ') return nn, op, inputs def _assign_lanes(self): @@ -421,11 +425,14 @@ class Timeline(object): start = nodestats.all_start_micros duration = nodestats.all_end_rel_micros tid = nodestats.thread_id + inputs = [] if is_gputrace: # Node names should always have the form 'name:op'. fields = node_name.split(':') + ['unknown'] node_name, op = fields[:2] - inputs = [] + elif node_name == 'RecvTensor': + # RPC tracing does not use the standard timeline_label format. + op = 'RecvTensor' else: _, op, inputs = self._parse_op_label(nodestats.timeline_label) args = {'name': node_name, 'op': op} @@ -518,7 +525,7 @@ class Timeline(object): end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros self._emit_op(node_stats, device_pid, is_gputrace) - if is_gputrace: + if is_gputrace or node_stats.node_name == 'RecvTensor': continue _, _, inputs = self._parse_op_label(node_stats.timeline_label) diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py index 7c9d7847b7..46984f694f 100644 --- a/tensorflow/python/client/timeline_test.py +++ b/tensorflow/python/client/timeline_test.py @@ -109,6 +109,23 @@ class TimelineTest(tf.test.TestCase): show_dataflow=False) self._validateTrace(ctf) + def testTimelineWithRPCs(self): + """Tests that Timeline can handle RPC tracing.""" + metadata = tf.RunMetadata() + step_stats = metadata.step_stats + dev_stats = step_stats.dev_stats.add() + dev_stats.device = '/job:worker/replica:0/task:0/cpu:0' + node_stats = dev_stats.node_stats.add() + node_stats.node_name = 'RecvTensor' + node_stats.all_start_micros = 12345 + node_stats.op_end_rel_micros = 42 + node_stats.timeline_label = ('[1024B] edge_160_conv2/biases/read from ' + '/job:ps/replica:0/task:3/cpu:0 to ' + '/job:worker/replica:0/task:0/cpu:0') + tl = timeline.Timeline(step_stats) + ctf = tl.generate_chrome_trace_format() + self._validateTrace(ctf) + def testAnalysisAndAllocations(self): run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() diff --git a/tensorflow/python/framework/graph_util.py b/tensorflow/python/framework/graph_util.py index 402a5ebf0e..a666630e44 100644 --- a/tensorflow/python/framework/graph_util.py +++ b/tensorflow/python/framework/graph_util.py @@ -26,7 +26,6 @@ from tensorflow.python.framework.graph_util_impl import convert_variables_to_con from tensorflow.python.framework.graph_util_impl import extract_sub_graph from tensorflow.python.framework.graph_util_impl import must_run_on_cpu from tensorflow.python.framework.graph_util_impl import remove_training_nodes -from tensorflow.python.framework.graph_util_impl import set_cpu0 from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -36,7 +35,6 @@ _allowed_symbols = [ "convert_variables_to_constants", "extract_sub_graph", "must_run_on_cpu", - "set_cpu0", "tensor_shape_from_node_def_name", "remove_training_nodes", ] diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index ba693503c2..587f883260 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -25,7 +25,6 @@ import re from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 -from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -49,23 +48,6 @@ def _is_variable_op(op): return op in _VARIABLE_OPS -def set_cpu0(device_string): - """Creates a new device string based on `device_string' but using /CPU:0. - - If the device is already on /CPU:0, this is a no-op. - - Args: - device_string: A device string. - - Returns: - A device string. - """ - parsed_device = pydev.DeviceSpec.from_string(device_string) - parsed_device.device_type = "CPU" - parsed_device.device_index = 0 - return parsed_device.to_string() - - def must_run_on_cpu(node, pin_variables_on_cpu=False): """Returns True if the given node_def must run on CPU, otherwise False. diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 24169b57db..d1edf43193 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3980,7 +3980,7 @@ class GraphKeys(object): for more details. * `SUMMARIES`: the summary `Tensor` objects that have been created in the graph. See - [`tf.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries) + [`tf.contrib.deprecated.merge_all_summaries()`](../../api_docs/python/train.md#merge_all_summaries) for more details. * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to produce input for a computation. See diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index 02c71d3032..e779dc7c69 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -461,11 +461,17 @@ class ConcatOpTest(tf.test.TestCase): with self.test_session(use_gpu=True): t1 = [[1, 2, 3], [4, 5, 6]] t2 = [[7, 8, 9], [10, 11, 12]] - output = tf.concat(-2, [t1, t2]).eval() + + c = tf.concat(-2, [t1, t2]) + output = c.eval() + self.assertEqual([4, 3], c.get_shape().as_list()) self.assertAllEqual( [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], output) - output = tf.concat(-1, [t1, t2]).eval() + + c = tf.concat(-1, [t1, t2]) + self.assertEqual([2, 6], c.get_shape().as_list()) + output = c.eval() self.assertAllEqual( [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output) @@ -488,11 +494,17 @@ class ConcatOpTest(tf.test.TestCase): with self.test_session(use_gpu=True): t1 = [[1, 2, 3], [4, 5, 6]] t2 = [[7, 8, 9], [10, 11, 12]] - output = gen_array_ops._concat_v2([t1, t2], -2).eval() + + c = gen_array_ops._concat_v2([t1, t2], -2) + self.assertEqual([4, 3], c.get_shape().as_list()) + output = c.eval() self.assertAllEqual( [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], output) - output = gen_array_ops._concat_v2([t1, t2], -1).eval() + + c = gen_array_ops._concat_v2([t1, t2], -1) + self.assertEqual([2, 6], c.get_shape().as_list()) + output = c.eval() self.assertAllEqual( [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 8fc4be8e6e..732a604dc2 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1269,6 +1269,31 @@ class ControlFlowTest(tf.test.TestCase): tf.global_variables_initializer().run() self.assertAllClose(216.0, r[0].eval()) + def testWhileGradInCond(self): + with self.test_session(): + n = tf.convert_to_tensor(1.0, name="n") + x = tf.placeholder(tf.float32, shape=None) + c = lambda n: tf.less(n, 10.0) + b = lambda n: tf.add(n, x) + def fn1(): + r = tf.while_loop(c, b, [n], [tensor_shape.unknown_shape()]) + return tf.gradients(r, x) + r = tf.cond(tf.less(1, 2), fn1, lambda: x) + self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) + + def testWhileGradInWhile(self): + with self.test_session(): + n = tf.convert_to_tensor(1.0, name="n") + x = tf.placeholder(tf.float32, shape=None) + c = lambda n: tf.less(n, 10.0) + b = lambda n: tf.add(n, x) + def b1(n): + r = tf.while_loop(c, b, [n], [tensor_shape.unknown_shape()]) + return tf.gradients(r, x) + r = tf.while_loop(lambda n: n < 6.0, b1, [n], + [tensor_shape.unknown_shape()]) + self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) + def testWhile_NestedInput(self): with self.test_session() as sess: named = collections.namedtuple("named", ("a", "b")) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index f6397226af..aa31c03e19 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -1297,7 +1297,7 @@ class SelectOpTest(tf.test.TestCase): def _compare(self, c, x, y, use_gpu): np_ans = np.where(c, x, y) with self.test_session(use_gpu=use_gpu): - out = tf.select(c, x, y) + out = tf.where(c, x, y) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) self.assertShapeEqual(np_ans, out) @@ -1306,7 +1306,7 @@ class SelectOpTest(tf.test.TestCase): with self.test_session(): inx = tf.convert_to_tensor(x) iny = tf.convert_to_tensor(y) - out = tf.select(c, inx, iny) + out = tf.where(c, inx, iny) s = list(np.shape(c)) jacob_t, jacob_n = tf.test.compute_gradient(inx, s, @@ -1318,7 +1318,7 @@ class SelectOpTest(tf.test.TestCase): yf = y.astype(numeric_gradient_type) inxf = tf.convert_to_tensor(xf) inyf = tf.convert_to_tensor(yf) - outf = tf.select(c, inxf, inyf) + outf = tf.where(c, inxf, inyf) _, jacob_n = tf.test.compute_gradient(inxf, s, outf, @@ -1336,7 +1336,7 @@ class SelectOpTest(tf.test.TestCase): with self.test_session(): inx = tf.convert_to_tensor(x) iny = tf.convert_to_tensor(y) - out = tf.select(c, inx, iny) + out = tf.where(c, inx, iny) s = list(np.shape(c)) jacob_t, jacob_n = tf.test.compute_gradient(iny, s, @@ -1349,7 +1349,7 @@ class SelectOpTest(tf.test.TestCase): yf = y.astype(numeric_gradient_type) inxf = tf.convert_to_tensor(xf) inyf = tf.convert_to_tensor(yf) - outf = tf.select(c, inxf, inyf) + outf = tf.where(c, inxf, inyf) _, jacob_n = tf.test.compute_gradient(inyf, s, outf, @@ -1415,7 +1415,7 @@ class SelectOpTest(tf.test.TestCase): xt = x.astype(t) yt = y.astype(t) with self.assertRaises(ValueError): - tf.select(c, xt, yt) + tf.where(c, xt, yt) def testEmptyTensor(self): c = np.random.randint(0, 3, 0).astype(np.bool).reshape(1, 3, 0) @@ -1425,7 +1425,7 @@ class SelectOpTest(tf.test.TestCase): with self.test_session(): xt = x.astype(np.float32) yt = y.astype(np.float32) - z = tf.select(c, xt, yt).eval() + z = tf.where(c, xt, yt).eval() self.assertAllEqual(z_expected, z) def testNan(self): @@ -1434,7 +1434,7 @@ class SelectOpTest(tf.test.TestCase): for c in False, True: for a in 7.0, np.nan: for b in 5.0, np.nan: - x = tf.select(c, a, b).eval() + x = tf.where(c, a, b).eval() y = a if c else b self.assertEqual(np.isnan(x), np.isnan(y)) @@ -1447,7 +1447,7 @@ class BatchSelectOpTest(tf.test.TestCase): [x_i if c_i else y_i for c_i, x_i, y_i in zip(c, x, y)]).transpose( [2, 0, 1]) with self.test_session(use_gpu=use_gpu): - out = tf.select(c, x, y) + out = tf.where(c, x, y) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) self.assertShapeEqual(np_ans, out) @@ -1456,7 +1456,7 @@ class BatchSelectOpTest(tf.test.TestCase): with self.test_session(): inx = tf.convert_to_tensor(x) iny = tf.convert_to_tensor(y) - out = tf.select(c, inx, iny) + out = tf.where(c, inx, iny) s = list(np.shape(x)) jacob_t, jacob_n = tf.test.compute_gradient(inx, s, @@ -1468,7 +1468,7 @@ class BatchSelectOpTest(tf.test.TestCase): yf = y.astype(numeric_gradient_type) inxf = tf.convert_to_tensor(xf) inyf = tf.convert_to_tensor(yf) - outf = tf.select(c, inxf, inyf) + outf = tf.where(c, inxf, inyf) _, jacob_n = tf.test.compute_gradient(inxf, s, outf, @@ -1486,7 +1486,7 @@ class BatchSelectOpTest(tf.test.TestCase): with self.test_session(): inx = tf.convert_to_tensor(x) iny = tf.convert_to_tensor(y) - out = tf.select(c, inx, iny) + out = tf.where(c, inx, iny) s = list(np.shape(x)) jacob_t, jacob_n = tf.test.compute_gradient(iny, s, @@ -1498,7 +1498,7 @@ class BatchSelectOpTest(tf.test.TestCase): yf = y.astype(numeric_gradient_type) inxf = tf.convert_to_tensor(xf) inyf = tf.convert_to_tensor(yf) - outf = tf.select(c, inxf, inyf) + outf = tf.where(c, inxf, inyf) _, jacob_n = tf.test.compute_gradient(inyf, s, outf, @@ -1552,7 +1552,7 @@ class BatchSelectOpTest(tf.test.TestCase): xt = x.astype(t) yt = y.astype(t) with self.assertRaises(ValueError): - tf.select(c, xt, yt) + tf.where(c, xt, yt) class MinMaxOpTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py index 889f14cd53..cbc5ee278e 100644 --- a/tensorflow/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/parsing_ops_test.py @@ -258,6 +258,82 @@ class ParseExampleTest(tf.test.TestCase): } }, expected_output) + def testSerializedContainingSparseFeature(self): + original = [ + example(features=features({ + "val": float_feature([3, 4]), + "idx": int64_feature([5, 10]) + })), + example(features=features({ + "val": float_feature([]), # empty float list + "idx": int64_feature([]) + })), + example(features=features({ + "val": feature(), # feature with nothing in it + # missing idx feature + })), + example(features=features({ + "val": float_feature([1, 2, -1]), + "idx": int64_feature([0, 9, 3]) # unsorted + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_sp = ( # indices, values, shape + np.array([[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64), + np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), + np.array([4, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + expected_output = { + "sp": expected_sp, + } + + self._test({ + "serialized": tf.convert_to_tensor(serialized), + "features": { + "sp": tf.SparseFeature("idx", "val", tf.float32, 13) + } + }, expected_output) + + def testSerializedContainingSparseFeatureReuse(self): + original = [ + example(features=features({ + "val1": float_feature([3, 4]), + "val2": float_feature([5, 6]), + "idx": int64_feature([5, 10]) + })), + example(features=features({ + "val1": float_feature([]), # empty float list + "idx": int64_feature([]) + })), + ] + + serialized = [m.SerializeToString() for m in original] + + expected_sp1 = ( # indices, values, shape + np.array([[0, 5], [0, 10]], dtype=np.int64), + np.array([3.0, 4.0], dtype=np.float32), + np.array([2, 13], dtype=np.int64)) # batch == 2, max_elems = 13 + + expected_sp2 = ( # indices, values, shape + np.array([[0, 5], [0, 10]], dtype=np.int64), + np.array([5.0, 6.0], dtype=np.float32), + np.array([2, 7], dtype=np.int64)) # batch == 2, max_elems = 13 + + expected_output = { + "sp1": expected_sp1, + "sp2": expected_sp2, + } + + self._test({ + "serialized": tf.convert_to_tensor(serialized), + "features": { + "sp1": tf.SparseFeature("idx", "val1", tf.float32, 13), + "sp2": tf.SparseFeature("idx", "val2", tf.float32, 7) + } + }, expected_output) + def testSerializedContainingDense(self): aname = "a" bname = "b*has+a:tricky_name" @@ -400,7 +476,7 @@ class ParseExampleTest(tf.test.TestCase): }, expected_output) - def testSerializedContainingSparseAndDenseWithNoDefault(self): + def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self): expected_st_a = ( # indices, values, shape np.empty( (0, 2), dtype=np.int64), # indices @@ -408,12 +484,20 @@ class ParseExampleTest(tf.test.TestCase): (0,), dtype=np.int64), # sp_a is DT_INT64 np.array( [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 + expected_sp = ( # indices, values, shape + np.array([[0, 0], [0, 3], [1, 7]], dtype=np.int64), + np.array(["a", "b", "c"], dtype="|S"), + np.array([2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 original = [ example(features=features({ - "c": float_feature([3, 4]) + "c": float_feature([3, 4]), + "val": bytes_feature([b"a", b"b"]), + "idx": int64_feature([0, 3]) })), example(features=features({ - "c": float_feature([1, 2]) + "c": float_feature([1, 2]), + "val": bytes_feature([b"c"]), + "idx": int64_feature([7]) })) ] @@ -424,6 +508,7 @@ class ParseExampleTest(tf.test.TestCase): b_default = np.random.rand(3, 3).astype(bytes) expected_output = { "st_a": expected_st_a, + "sp": expected_sp, "a": np.array(2 * [[a_default]]), "b": np.array(2 * [b_default]), "c": np.array( @@ -436,6 +521,7 @@ class ParseExampleTest(tf.test.TestCase): "serialized": tf.convert_to_tensor(serialized), "features": { "st_a": tf.VarLenFeature(tf.int64), + "sp": tf.SparseFeature("idx", "val", tf.string, 13), "a": tf.FixedLenFeature( (1, 3), tf.int64, default_value=a_default), "b": tf.FixedLenFeature( @@ -446,6 +532,46 @@ class ParseExampleTest(tf.test.TestCase): }, expected_output) + def testSerializedContainingSparseAndSparseFeatureWithReuse(self): + expected_idx = ( # indices, values, shape + np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64), + np.array([0, 3, 7, 1]), + np.array([2, 2], dtype=np.int64)) # batch == 4, max_elems = 2 + + expected_sp = ( # indices, values, shape + np.array([[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), + np.array(["a", "b", "d", "c"], dtype="|S"), + np.array([2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + original = [ + example(features=features({ + "val": bytes_feature([b"a", b"b"]), + "idx": int64_feature([0, 3]) + })), example(features=features({ + "val": bytes_feature([b"c", b"d"]), + "idx": int64_feature([7, 1]) + })) + ] + + names = ["in1", "in2"] + serialized = [m.SerializeToString() for m in original] + + expected_output = { + "idx": expected_idx, + "sp": expected_sp, + } + + self._test( + { + "example_names": names, + "serialized": tf.convert_to_tensor(serialized), + "features": { + "idx": tf.VarLenFeature(tf.int64), + "sp": tf.SparseFeature("idx", "val", tf.string, 13), + } + }, + expected_output) + class ParseSingleExampleTest(tf.test.TestCase): @@ -473,8 +599,10 @@ class ParseSingleExampleTest(tf.test.TestCase): self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,)) self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (1,)) - def testSingleExampleWithSparseAndDense(self): + def testSingleExampleWithSparseAndSparseFeatureAndDense(self): original = example(features=features({"c": float_feature([3, 4]), + "val": bytes_feature([b"a", b"b"]), + "idx": int64_feature([0, 3]), "st_a": float_feature([3.0, 4.0])})) serialized = original.SerializeToString() @@ -486,10 +614,16 @@ class ParseSingleExampleTest(tf.test.TestCase): np.array( [2], dtype=np.int64)) # shape: max_values = 2 + expected_sp = ( # indices, values, shape + np.array([[0], [3]], dtype=np.int64), + np.array(["a", "b"], dtype="|S"), + np.array([13], dtype=np.int64)) # max_values = 13 + a_default = [1, 2, 3] b_default = np.random.rand(3, 3).astype(bytes) expected_output = { "st_a": expected_st_a, + "sp": expected_sp, "a": [a_default], "b": b_default, "c": np.array( @@ -502,6 +636,7 @@ class ParseSingleExampleTest(tf.test.TestCase): "serialized": tf.convert_to_tensor(serialized), "features": { "st_a": tf.VarLenFeature(tf.float32), + "sp": tf.SparseFeature("idx", "val", tf.string, 13), "a": tf.FixedLenFeature( (1, 3), tf.int64, default_value=a_default), "b": tf.FixedLenFeature( diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 9a1f96b6fe..6fe112b6be 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -62,8 +62,8 @@ def GetTestConfigs(): all the valid test configs as tuples of data_format and use_gpu. """ test_configs = [("NHWC", False), ("NHWC", True)] - if tf.test.is_gpu_available(): - # "NCHW" format is not currently supported on CPU. + if tf.test.is_gpu_available(cuda_only=True): + # "NCHW" format is currently supported exclusively on CUDA GPUs. test_configs += [("NCHW", True)] return test_configs diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index bfb3b3a56b..4af5c3c8a2 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -741,6 +741,21 @@ class TFRecordIteratorTest(tf.test.TestCase): actual.append(r) self.assertEqual(actual, original) + def testBadFile(self): + """Verify that tf_record_iterator throws an exception on bad TFRecords.""" + fn = os.path.join(self.get_temp_dir(), "bad_file") + with tf.python_io.TFRecordWriter(fn) as writer: + writer.write(b"123") + fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated") + with open(fn, "rb") as f: + with open(fn_truncated, "wb") as f2: + # DataLossError requires that we've written the header, so this must + # be at least 12 bytes. + f2.write(f.read(14)) + with self.assertRaises(tf.errors.DataLossError): + for _ in tf.python_io.tf_record_iterator(fn_truncated): + pass + class AsyncReaderTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index bea631f038..776d9b6665 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -300,6 +300,36 @@ class EluTest(tf.test.TestCase): print("elu (float64) gradient of gradient err = ", err) self.assertLess(err, 1e-6) - + +class CreluTest(tf.test.TestCase): + + def testCreluShape(self): + f = tf.random_normal([50, 5, 7, 10]) + t = tf.nn.crelu(f) + self.assertEqual([50, 5, 7, 20], t.get_shape()) + + def _testCrelu(self, np_features, use_gpu=False): + np_relu = np.maximum(np_features, np.zeros_like(np_features)) + np_neg_relu = np.maximum(-np_features, np.zeros_like(np_features)) + np_crelu = np.concatenate( + (np_relu, np_neg_relu), len(np_features.shape) - 1) + + with self.test_session(use_gpu=use_gpu): + crelu = tf.nn.crelu(np_features) + tf_relu = crelu.eval() + + self.assertAllClose(np_crelu, tf_relu) + self.assertShapeEqual(np_crelu, crelu) + + def testNumbers(self): + for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: + self._testCrelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + use_gpu=False) + if t in [np.float16, np.float32, np.float64]: + self._testCrelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + use_gpu=True) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index e4e239169a..cc60e796ba 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -23,9 +23,10 @@ import functools import numpy as np import tensorflow as tf +from tensorflow.python.ops import rnn_cell_impl # TODO(ebrevdo): Remove once _linear is fully deprecated. # pylint: disable=protected-access -from tensorflow.python.ops.rnn_cell import _linear as linear +from tensorflow.python.ops.rnn_cell_impl import _linear as linear # pylint: enable=protected-access @@ -367,7 +368,7 @@ class SlimRNNCellTest(tf.test.TestCase): m = tf.zeros([1, 2]) my_cell = functools.partial(basic_rnn_cell, num_units=2) # pylint: disable=protected-access - g, _ = tf.nn.rnn_cell._SlimRNNCell(my_cell)(x, m) + g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m) # pylint: enable=protected-access sess.run([tf.global_variables_initializer()]) res = sess.run([g], {x.name: np.array([[1., 1.]]), @@ -384,7 +385,7 @@ class SlimRNNCellTest(tf.test.TestCase): _, initial_state = basic_rnn_cell(inputs, None, num_units) my_cell = functools.partial(basic_rnn_cell, num_units=num_units) # pylint: disable=protected-access - slim_cell = tf.nn.rnn_cell._SlimRNNCell(my_cell) + slim_cell = rnn_cell_impl._SlimRNNCell(my_cell) # pylint: enable=protected-access slim_outputs, slim_state = slim_cell(inputs, initial_state) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units) diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index 6ec5274873..1b1810e175 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -227,6 +227,24 @@ class ScatterNdTest(tf.test.TestCase): tf.scatter_nd_update(ref, indices, updates).get_shape().as_list(), shape) + def testExtraIndicesDimensions(self): + indices = tf.zeros([1, 1, 2], tf.int32) + updates = tf.zeros([1, 1], tf.int32) + shape = np.array([2, 2]) + scatter = tf.scatter_nd(indices, updates, shape) + self.assertAllEqual(scatter.get_shape().as_list(), shape) + expected_result = np.zeros([2, 2], dtype=np.int32) + with self.test_session(): + self.assertAllEqual(expected_result, scatter.eval()) + + ref = tf.Variable(tf.zeros(shape, tf.int32)) + scatter_update = tf.scatter_nd_update(ref, indices, updates) + self.assertAllEqual(scatter_update.get_shape().as_list(), shape) + + with self.test_session(): + ref.initializer.run() + self.assertAllEqual(expected_result, scatter_update.eval()) + def testUndefinedIndicesShape(self): indices = tf.placeholder(tf.int32, shape=None) updates = tf.placeholder(tf.int32, shape=[2, 2, 2]) diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py index 227832b18e..5aa1390a9a 100644 --- a/tensorflow/python/kernel_tests/string_split_op_test.py +++ b/tensorflow/python/kernel_tests/string_split_op_test.py @@ -63,9 +63,6 @@ class StringSplitOpTest(tf.test.TestCase): with self.test_session() as sess: self.assertRaises( - ValueError, tf.string_split, strings, delimiter="delimiter") - - self.assertRaises( ValueError, tf.string_split, strings, delimiter=["|", ""]) self.assertRaises(ValueError, tf.string_split, strings, delimiter=["a"]) @@ -76,6 +73,12 @@ class StringSplitOpTest(tf.test.TestCase): self.assertAllEqual(values, [b"hello", b"world", b"hello world"]) self.assertAllEqual(shape, [2, 2]) + tokens = tf.string_split(strings, delimiter="| ") + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]]) + self.assertAllEqual(values, [b"hello", b"world", b"hello", b"world"]) + self.assertAllEqual(shape, [2, 2]) + def testStringSplitWithDelimiterTensor(self): strings = ["hello|world", "hello world"] @@ -88,14 +91,31 @@ class StringSplitOpTest(tf.test.TestCase): sess.run(tokens, feed_dict={delimiter: ["a", "b"]}) with self.assertRaises(tf.errors.InvalidArgumentError): sess.run(tokens, feed_dict={delimiter: ["a"]}) - with self.assertRaises(tf.errors.InvalidArgumentError): - sess.run(tokens, feed_dict={delimiter: "abc"}) indices, values, shape = sess.run(tokens, feed_dict={delimiter: "|"}) self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0]]) self.assertAllEqual(values, [b"hello", b"world", b"hello world"]) self.assertAllEqual(shape, [2, 2]) + def testStringSplitWithDelimitersTensor(self): + strings = ["hello.cruel,world", "hello cruel world"] + + with self.test_session() as sess: + delimiter = tf.placeholder(tf.string) + + tokens = tf.string_split(strings, delimiter=delimiter) + + with self.assertRaises(tf.errors.InvalidArgumentError): + sess.run(tokens, feed_dict={delimiter: ["a", "b"]}) + with self.assertRaises(tf.errors.InvalidArgumentError): + sess.run(tokens, feed_dict={delimiter: ["a"]}) + indices, values, shape = sess.run(tokens, feed_dict={delimiter: ".,"}) + + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [1, 0]]) + self.assertAllEqual(values, [b"hello", b"cruel", b"world", + b"hello cruel world"]) + self.assertAllEqual(shape, [2, 3]) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py index 7c434f4561..06148eefa4 100644 --- a/tensorflow/python/kernel_tests/summary_ops_test.py +++ b/tensorflow/python/kernel_tests/summary_ops_test.py @@ -30,7 +30,8 @@ class SummaryOpsTest(tf.test.TestCase): def testScalarSummary(self): with self.test_session() as sess: const = tf.constant([10.0, 20.0]) - summ = tf.scalar_summary(["c1", "c2"], const, name="mysumm") + summ = tf.contrib.deprecated.scalar_summary( + ["c1", "c2"], const, name="mysumm") value = sess.run(summ) self.assertEqual([], summ.get_shape()) self.assertProtoEquals(""" @@ -41,7 +42,7 @@ class SummaryOpsTest(tf.test.TestCase): def testScalarSummaryDefaultName(self): with self.test_session() as sess: const = tf.constant([10.0, 20.0]) - summ = tf.scalar_summary(["c1", "c2"], const) + summ = tf.contrib.deprecated.scalar_summary(["c1", "c2"], const) value = sess.run(summ) self.assertEqual([], summ.get_shape()) self.assertProtoEquals(""" @@ -53,7 +54,7 @@ class SummaryOpsTest(tf.test.TestCase): with self.test_session() as sess: const = tf.constant(10.0) summ1 = tf.summary.histogram("h", const) - summ2 = tf.scalar_summary("c", const) + summ2 = tf.contrib.deprecated.scalar_summary("c", const) merge = tf.summary.merge([summ1, summ2]) value = sess.run(merge) self.assertEqual([], merge.get_shape()) @@ -88,11 +89,12 @@ class SummaryOpsTest(tf.test.TestCase): self.assertEqual(2, len(merge.op.inputs)) self.assertEqual(summ1, merge.op.inputs[0]) self.assertEqual(summ3, merge.op.inputs[1]) - merge = tf.merge_all_summaries("foo_key") + merge = tf.contrib.deprecated.merge_all_summaries("foo_key") self.assertEqual("MergeSummary", merge.op.type) self.assertEqual(1, len(merge.op.inputs)) self.assertEqual(summ2, merge.op.inputs[0]) - self.assertTrue(tf.merge_all_summaries("bar_key") is None) + self.assertTrue( + tf.contrib.deprecated.merge_all_summaries("bar_key") is None) def testHistogramSummaryTypes(self): with tf.Graph().as_default(): diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 17acbeffc4..f9508d4709 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -272,5 +272,34 @@ class TemplateTest(tf.test.TestCase): # Template is called at the top level, so there is no preceding "foo_2". self.assertEqual(tc.var_scope.name, "blah") + def test_custom_getter(self): + # Custom getter that maintains call count and forwards to true getter + custom_getter_count = [0] + def custom_getter(getter, name, *args, **kwargs): + custom_getter_count[0] += 1 + return getter(name, *args, **kwargs) + + # Test that custom getter is called both when variables are created and + # subsequently accessed + tmpl1 = template.make_template("s1", var_scoped_function, + custom_getter_=custom_getter) + self.assertEqual(custom_getter_count[0], 0) + tmpl1() + self.assertEqual(custom_getter_count[0], 1) + tmpl1() + self.assertEqual(custom_getter_count[0], 2) + + # Test that custom getter is called when the variable scope is created + # during construction + custom_getter_count[0] = 0 + tmpl2 = template.make_template("s2", var_scoped_function, + custom_getter_=custom_getter, + create_scope_now_=True) + self.assertEqual(custom_getter_count[0], 0) + tmpl2() + self.assertEqual(custom_getter_count[0], 1) + tmpl2() + self.assertEqual(custom_getter_count[0], 2) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index b21dcaf8e8..16f2585fec 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -367,9 +367,9 @@ class TensorArrayTest(tf.test.TestCase): # Test reading wrong datatype r0_bad = gen_data_flow_ops._tensor_array_read_v2( - handle=w0.handle, index=0, dtype=tf.int64, flow_in=w0.flow) + handle=w0.handle, index=0, dtype=tf.float64, flow_in=w0.flow) with self.assertRaisesOpError( - "TensorArray dtype is float but Op requested dtype int64."): + "TensorArray dtype is float but Op requested dtype double."): r0_bad.eval() # Test reading from a different index than the one we wrote to diff --git a/tensorflow/python/layers/__init__.py b/tensorflow/python/layers/__init__.py index e0e7658513..e69de29bb2 100644 --- a/tensorflow/python/layers/__init__.py +++ b/tensorflow/python/layers/__init__.py @@ -1,32 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# pylint: disable=line-too-long -"""This library provides a set of high-level neural networks layers. - -## Core layers - -@@fully_connected - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=g-bad-import-order,unused-import - -# Core layers. -from tensorflow.python.layers.core import fully_connected diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index df538a5bd0..8d875477f6 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -24,6 +24,7 @@ from __future__ import division from __future__ import print_function import functools +import inspect import re from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np @@ -84,18 +85,33 @@ class _Layer(object): self._reuse_weights = kwargs.get('_reuse_weights') self._dtype = dtype - # Determine name. + # Determine base name (non-unique). + base_name = name if not name: - prefix = _to_snake_case(self.__class__.__name__) - name = ops.get_default_graph().unique_name(prefix, mark_as_used=False) - self.name = name + base_name = _to_snake_case(self.__class__.__name__) # Determine variable scope. scope = kwargs.get('_scope') if scope: - self._scope = scope + self._scope = next(vs.variable_scope(scope).gen) else: - self._scope = next(vs.variable_scope(None, default_name=self.name).gen) + self._scope = next(vs.variable_scope(None, default_name=base_name).gen) + + # Unique name is borrowed from scope to match variable names. + self._name = self._scope.name + + def __setattr__(self, name, value): + if hasattr(self, name): + # Only allow self to update its own attributes + stack_0_locals = inspect.stack()[1][0].f_locals + called_from_layer = stack_0_locals.get('self', None) is self + if not called_from_layer: + raise AttributeError('Read-only property cannot be set: %s' % name) + super(_Layer, self).__setattr__(name, value) + + @property + def name(self): + return self._name @property def trainable_weights(self): @@ -135,16 +151,17 @@ class _Layer(object): """ self._built = True - def call(self, inputs): + def call(self, inputs, **kwargs): """The logic of the layer lives here. Arguments: inputs: input tensor(s). + **kwargs: additional keyword arguments. Returns: Output tensor(s). """ - return inputs + raise NotImplementedError def _add_weight(self, name, shape, dtype=None, initializer=None, regularizer=None, trainable=True, @@ -186,18 +203,23 @@ class _Layer(object): regularization, ops.GraphKeys.REGULARIZATION_LOSSES) return variable - def __call__(self, inputs): + def __call__(self, inputs, **kwargs): """Wraps `call`, applying pre- and post-processing steps. Arguments: inputs: input tensor(s). + **kwargs: additional keyword arguments to be passed to `self.call`. Returns: Output tensor(s). """ - # Define a custom to override tf.get_variable when creating layer weights. + # Define a custom getter to override tf.get_variable when creating layer + # weights. We respect current custom getter, if one is set. + current_custom_getter = vs.get_variable_scope().custom_getter def variable_getter(getter, name, shape, dtype=None, initializer=None, regularizer=None, trainable=True, **kwargs): + if current_custom_getter is not None: + getter = functools.partial(current_custom_getter, getter) return self._add_weight( name, shape, initializer=initializer, regularizer=regularizer, dtype=dtype, trainable=trainable, @@ -215,7 +237,7 @@ class _Layer(object): else: self.build(input_shapes) self._built = True - outputs = self.call(inputs) + outputs = self.call(inputs, **kwargs) # Apply activity regularization. # Note that it should be applied every time the layer creates a new @@ -233,23 +255,29 @@ class _Layer(object): _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) return outputs - def apply(self, inputs): + def apply(self, inputs, **kwargs): """Apply the layer on a input. This simply wraps `self.__call__`. Arguments: - inputs: input tensor(s). + inputs: Input tensor(s). + **kwargs: additional keyword arguments to be passed to `self.call`. Returns: Output tensor(s). """ - return self.__call__(inputs) + return self.__call__(inputs, **kwargs) def _to_snake_case(name): intermediate = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', intermediate).lower() + insecure = re.sub('([a-z0-9])([A-Z])', r'\1_\2', intermediate).lower() + # If the class is private the name starts with "_" which is not secure + # for creating scopes. We prefix the name with "private" in this case. + if insecure[0] != '_': + return insecure + return 'private' + insecure def _to_list(x): diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index fd9ebd33d1..9262db2fc7 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -80,6 +80,9 @@ class BaseLayerTest(tf.test.TestCase): self.w = tf.get_variable('my_var', [2, 2], initializer=tf.zeros_initializer) + def call(self, inputs): + return inputs + layer = MyLayer(name='my_layer') inputs = tf.random_uniform((5,), seed=1) _ = layer.apply(inputs) @@ -98,6 +101,38 @@ class BaseLayerTest(tf.test.TestCase): self.assertEqual(layer.built, True) self.assertEqual(outputs.op.name, 'my_layer/Square') + def testNaming(self): + default_layer = base_layers._Layer() + self.assertEqual(default_layer.name, 'private__layer') + default_layer1 = base_layers._Layer() + self.assertEqual(default_layer1.name, 'private__layer_1') + my_layer = base_layers._Layer(name='my_layer') + self.assertEqual(my_layer.name, 'my_layer') + my_layer1 = base_layers._Layer(name='my_layer') + self.assertEqual(my_layer1.name, 'my_layer_1') + # New graph has fully orthogonal names. + with tf.Graph().as_default(): + my_layer_other_graph = base_layers._Layer(name='my_layer') + self.assertEqual(my_layer_other_graph.name, 'my_layer') + my_layer2 = base_layers._Layer(name='my_layer') + self.assertEqual(my_layer2.name, 'my_layer_2') + # Name scope shouldn't affect names. + with tf.name_scope('some_name_scope'): + default_layer2 = base_layers._Layer() + self.assertEqual(default_layer2.name, 'private__layer_2') + my_layer3 = base_layers._Layer(name='my_layer') + self.assertEqual(my_layer3.name, 'my_layer_3') + other_layer = base_layers._Layer(name='other_layer') + self.assertEqual(other_layer.name, 'other_layer') + # Variable scope gets added to names. + with tf.variable_scope('var_scope'): + default_layer_scoped = base_layers._Layer() + self.assertEqual(default_layer_scoped.name, 'var_scope/private__layer') + my_layer_scoped = base_layers._Layer(name='my_layer') + self.assertEqual(my_layer_scoped.name, 'var_scope/my_layer') + my_layer_scoped1 = base_layers._Layer(name='my_layer') + self.assertEqual(my_layer_scoped1.name, 'var_scope/my_layer_1') + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index b0c17a46af..f3ffbf33b9 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -27,11 +27,14 @@ from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import control_flow_ops from tensorflow.python.layers import base @@ -51,40 +54,41 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access flattened prior to the initial matrix multiply by `w`. Arguments: - output_dim: Integer or Long, dimensionality of the output space. + units: Integer or Long, dimensionality of the output space. activation: Activation function (callable). Set it to None to maintain a linear activation. use_bias: Boolean, whether the layer uses a bias. - w_initializer: Initializer function for the weight matrix. + weights_initializer: Initializer function for the weight matrix. bias_initializer: Initializer function for the bias. - w_regularizer: Regularizer function for the weight matrix. + weights_regularizer: Regularizer function for the weight matrix. bias_regularizer: Regularizer function for the bias. activity_regularizer: Regularizer function for the output. trainable: Boolean, if `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - name: String, the name of the layer. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such cases. reuse: Boolean, whether to reuse the weights of a previous layer by the same name. Properties: - output_dim: Integer or Long, dimensionality of the output space. + units: Python integer, dimensionality of the output space. activation: Activation function (callable). use_bias: Boolean, whether the layer uses a bias. - w_initializer: Initializer instance (or name) for the weight matrix. + weights_initializer: Initializer instance (or name) for the weight matrix. bias_initializer: Initializer instance (or name) for the bias. - w_regularizer: Regularizer instance for the weight matrix (callable) + weights_regularizer: Regularizer instance for the weight matrix (callable) bias_regularizer: Regularizer instance for the bias (callable). activity_regularizer: Regularizer instance for the output (callable) - w: Weight matrix (TensorFlow variable or tensor). + weights: Weight matrix (TensorFlow variable or tensor). bias: Bias vector, if applicable (TensorFlow variable or tensor). """ - def __init__(self, output_dim, + def __init__(self, units, activation=None, use_bias=True, - w_initializer=None, + weights_initializer=None, bias_initializer=init_ops.zeros_initializer, - w_regularizer=None, + weights_regularizer=None, bias_regularizer=None, activity_regularizer=None, trainable=True, @@ -92,19 +96,22 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access **kwargs): super(FullyConnected, self).__init__(trainable=trainable, name=name, **kwargs) - self.output_dim = output_dim + self.units = units self.activation = activation self.use_bias = use_bias - self.w_initializer = w_initializer + self.weights_initializer = weights_initializer self.bias_initializer = bias_initializer - self.w_regularizer = w_regularizer + self.weights_regularizer = weights_regularizer self.bias_regularizer = bias_regularizer self.activity_regularizer = activity_regularizer def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape.ndims is None: + raise ValueError('Inputs to `FullyConnected` should have known rank.') if len(input_shape) < 2: raise ValueError('Inputs to `FullyConnected` should have rank >= 2.') - if input_shape[-1] is None: + if input_shape[-1].value is None: raise ValueError('The last dimension of the inputs to `FullyConnected` ' 'should be defined. Found `None`.') # Note that we set `trainable=True` because this is a trainable @@ -112,14 +119,14 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access # (self.trainable = False), the variable will not be added to # tf.trainable_variables(), and self.trainable_weights will be empty. self.w = vs.get_variable('weights', - shape=[input_shape[-1], self.output_dim], - initializer=self.w_initializer, - regularizer=self.w_regularizer, + shape=[input_shape[-1].value, self.units], + initializer=self.weights_initializer, + regularizer=self.weights_regularizer, dtype=self._dtype, trainable=True) if self.use_bias: self.bias = vs.get_variable('biases', - shape=[self.output_dim,], + shape=[self.units,], initializer=self.bias_initializer, regularizer=self.bias_regularizer, dtype=self._dtype, @@ -130,11 +137,11 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access def call(self, inputs): shape = inputs.get_shape().as_list() input_dim = shape[-1] - output_shape = shape[:-1] + [self.output_dim] + output_shape = shape[:-1] + [self.units] if len(output_shape) > 2: # Reshape the input to 2D. output_shape_tensors = array_ops.unpack(array_ops.shape(inputs)) - output_shape_tensors[-1] = self.output_dim + output_shape_tensors[-1] = self.units output_shape_tensor = array_ops.pack(output_shape_tensors) inputs = array_ops.reshape(inputs, [-1, input_dim]) @@ -148,17 +155,17 @@ class FullyConnected(base._Layer): # pylint: disable=protected-access outputs.set_shape(output_shape) if self.activation is not None: - return self.activation(outputs) + return self.activation(outputs) # pylint: disable=not-callable return outputs def fully_connected( - inputs, output_dim, + inputs, units, activation=None, use_bias=True, - w_initializer=None, + weights_initializer=None, bias_initializer=init_ops.zeros_initializer, - w_regularizer=None, + weights_regularizer=None, bias_regularizer=None, activity_regularizer=None, trainable=True, @@ -176,13 +183,13 @@ def fully_connected( Arguments: inputs: Tensor input. - output_dim: Integer or Long, dimensionality of the output space. + units: Integer or Long, dimensionality of the output space. activation: Activation function (callable). Set it to None to maintain a linear activation. use_bias: Boolean, whether the layer uses a bias. - w_initializer: Initializer function for the weight matrix. + weights_initializer: Initializer function for the weight matrix. bias_initializer: Initializer function for the bias. - w_regularizer: Regularizer function for the weight matrix. + weights_regularizer: Regularizer function for the weight matrix. bias_regularizer: Regularizer function for the bias. activity_regularizer: Regularizer function for the output. trainable: Boolean, if `True` also add variables to the graph collection @@ -194,15 +201,105 @@ def fully_connected( Returns: Output tensor. """ - layer = FullyConnected(output_dim, + layer = FullyConnected(units, activation=activation, use_bias=use_bias, - w_initializer=w_initializer, + weights_initializer=weights_initializer, bias_initializer=bias_initializer, - w_regularizer=w_regularizer, + weights_regularizer=weights_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, trainable=trainable, name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, _reuse_weights=reuse) return layer.apply(inputs) + + +class Dropout(base._Layer): # pylint: disable=protected-access + """Applies Dropout to the input. + + Dropout consists in randomly setting a fraction `rate` of input units to 0 + at each update during training time, which helps prevent overfitting. + The units that are kept are scaled by `1 / (1 - rate)`, so that their + sum is unchanged at training time and inference time. + + Arguments: + rate: The dropout rate, between 0 and 1. E.g. "rate=0.1" would drop out + 10% of input units. + noise_shape: 1D tensor of type `int32` representing the shape of the + binary dropout mask that will be multiplied with the input. + For instance, if your inputs have shape + `(batch_size, timesteps, features)`, and you want the dropout mask + to be the same for all timesteps, you can use + `noise_shape=[batch_size, 1, features]`. + seed: A Python integer. Used to create random seeds. See + [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) + for behavior. + name: The name of the layer (string). + """ + + def __init__(self, rate=0.5, + noise_shape=None, + seed=None, + name=None, + **kwargs): + super(Dropout, self).__init__(name=name, **kwargs) + self.rate = rate + self.noise_shape = noise_shape + self.seed = seed + + def call(self, inputs, training=False): + if isinstance(training, bool): + training_bool = training + else: + training_bool = tensor_util.constant_value(training) + if training_bool is False: + return array_ops.identity(inputs) + dropped_inputs = nn.dropout(inputs, 1 - self.rate, + noise_shape=self.noise_shape, + seed=self.seed) + if training_bool is True: + return dropped_inputs + return control_flow_ops.cond(training, + lambda: dropped_inputs, + lambda: inputs) + + +def dropout(inputs, + rate=0.5, + noise_shape=None, + seed=None, + training=False, + name=None): + """Applies Dropout to the input. + + Dropout consists in randomly setting a fraction `rate` of input units to 0 + at each update during training time, which helps prevent overfitting. + The units that are kept are scaled by `1 / (1 - rate)`, so that their + sum is unchanged at training time and inference time. + + Arguments: + inputs: Tensor input. + rate: The dropout rate, between 0 and 1. E.g. "rate=0.1" would drop out + 10% of input units. + noise_shape: 1D tensor of type `int32` representing the shape of the + binary dropout mask that will be multiplied with the input. + For instance, if your inputs have shape + `(batch_size, timesteps, features)`, and you want the dropout mask + to be the same for all timesteps, you can use + `noise_shape=[batch_size, 1, features]`. + seed: A Python integer. Used to create random seeds. See + [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) + for behavior. + training: Either a Python boolean, or a TensorFlow boolean scalar tensor + (e.g. a placeholder). Whether to return the output in training mode + (apply dropout) or in inference mode (return the input untouched). + name: The name of the layer (string). + + Returns: + Output tensor. + """ + layer = Dropout(rate, noise_shape=noise_shape, seed=seed, name=name) + return layer.apply(inputs, training=training) diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index 588887ab16..710fd37fd0 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np import tensorflow as tf from tensorflow.python.layers import core as core_layers @@ -27,9 +28,9 @@ class FullyConnectedTest(tf.test.TestCase): def testFCProperties(self): fc = core_layers.FullyConnected(2, activation=tf.nn.relu, name='fc') - self.assertEqual(fc.output_dim, 2) + self.assertEqual(fc.units, 2) self.assertEqual(fc.activation, tf.nn.relu) - self.assertEqual(fc.w_regularizer, None) + self.assertEqual(fc.weights_regularizer, None) self.assertEqual(fc.bias_regularizer, None) self.assertEqual(fc.activity_regularizer, None) self.assertEqual(fc.use_bias, True) @@ -141,7 +142,7 @@ class FullyConnectedTest(tf.test.TestCase): def testWeightsRegularizer(self): regularizer = lambda x: tf.reduce_sum(x) * 1e-3 fc = core_layers.FullyConnected(2, name='fc', - w_regularizer=regularizer) + weights_regularizer=regularizer) inputs = tf.random_uniform((5, 3), seed=1) _ = fc(inputs) loss_keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) @@ -167,6 +168,107 @@ class FullyConnectedTest(tf.test.TestCase): self.assertEqual(outputs.op.name, 'fc/Relu') self.assertEqual(outputs.get_shape().as_list(), [5, 2]) + def testFunctionalFCTwice(self): + inputs = tf.random_uniform((5, 3), seed=1) + core_layers.fully_connected(inputs, 2) + vars1 = tf.trainable_variables() + core_layers.fully_connected(inputs, 2) + vars2 = tf.trainable_variables() + self.assertEqual(len(vars1), 2) + self.assertEqual(len(vars2), 4) + + def testFunctionalFCTwiceReuse(self): + inputs = tf.random_uniform((5, 3), seed=1) + core_layers.fully_connected(inputs, 2, name='fc') + vars1 = tf.trainable_variables() + core_layers.fully_connected(inputs, 2, name='fc', reuse=True) + vars2 = tf.trainable_variables() + self.assertEqual(vars1, vars2) + + def testFunctionalFCWithCustomGetter(self): + called = [0] + def custom_getter(getter, *args, **kwargs): + called[0] += 1 + return getter(*args, **kwargs) + with tf.variable_scope('test', custom_getter=custom_getter): + inputs = tf.random_uniform((5, 3), seed=1) + core_layers.fully_connected(inputs, 2) + self.assertEqual(called[0], 2) + + def testFunctionalFCInScope(self): + with tf.variable_scope('test'): + inputs = tf.random_uniform((5, 3), seed=1) + core_layers.fully_connected(inputs, 2, name='fc') + var = tf.trainable_variables()[0] + self.assertEqual(var.name, 'test/fc/weights:0') + with tf.variable_scope('test1') as scope: + inputs = tf.random_uniform((5, 3), seed=1) + core_layers.fully_connected(inputs, 2, name=scope) + var = tf.trainable_variables()[2] + self.assertEqual(var.name, 'test1/weights:0') + with tf.variable_scope('test2'): + inputs = tf.random_uniform((5, 3), seed=1) + core_layers.fully_connected(inputs, 2) + var = tf.trainable_variables()[4] + self.assertEqual(var.name, 'test2/fully_connected/weights:0') + + +class DropoutTest(tf.test.TestCase): + + def testDropoutProperties(self): + dp = core_layers.Dropout(0.5) + self.assertEqual(dp.rate, 0.5) + self.assertEqual(dp.name, 'dropout') + self.assertEqual(dp.noise_shape, None) + + def testBooleanLearningPhase(self): + with self.test_session() as sess: + dp = core_layers.Dropout(0.5) + inputs = tf.ones((5, 3)) + dropped = dp.apply(inputs, training=True) + sess.run(tf.global_variables_initializer()) + np_output = sess.run(dropped) + self.assertAlmostEqual(0., np_output.min()) + dropped = dp.apply(inputs, training=False) + np_output = sess.run(dropped) + self.assertAllClose(np.ones((5, 3)), np_output) + + def testDynamicLearningPhase(self): + with self.test_session() as sess: + dp = core_layers.Dropout(0.5, seed=1) + inputs = tf.ones((5, 5)) + training = tf.placeholder(dtype='bool') + dropped = dp.apply(inputs, training=training) + sess.run(tf.global_variables_initializer()) + np_output = sess.run(dropped, feed_dict={training: True}) + self.assertAlmostEqual(0., np_output.min()) + np_output = sess.run(dropped, feed_dict={training: False}) + self.assertAllClose(np.ones((5, 5)), np_output) + + def testCustomNoiseShape(self): + with self.test_session() as sess: + inputs = tf.ones((5, 3, 2)) + noise_shape = [5, 1, 2] + dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1) + dropped = dp.apply(inputs, training=True) + sess.run(tf.global_variables_initializer()) + np_output = sess.run(dropped) + self.assertAlmostEqual(0., np_output.min()) + self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :]) + + def testFunctionalDropout(self): + with self.test_session() as sess: + inputs = tf.ones((5, 5)) + training = tf.placeholder(dtype='bool') + dropped = core_layers.dropout(inputs, 0.5, training=training, seed=1) + self.assertEqual(dropped.op.name, 'dropout/cond/Merge') + + sess.run(tf.global_variables_initializer()) + np_output = sess.run(dropped, feed_dict={training: True}) + self.assertAlmostEqual(0., np_output.min()) + np_output = sess.run(dropped, feed_dict={training: False}) + self.assertAllClose(np.ones((5, 5)), np_output) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py new file mode 100644 index 0000000000..1466487164 --- /dev/null +++ b/tensorflow/python/layers/layers.py @@ -0,0 +1,39 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# pylint: disable=line-too-long +"""This library provides a set of high-level neural networks layers. + +## Core layers + +@@fully_connected + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.all_util import remove_undocumented + +# pylint: disable=g-bad-import-order,unused-import + +# Core layers. +from tensorflow.python.layers.core import fully_connected +# pylint: enable=g-bad-import-order,unused-import + +_allowed_symbols = [] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py index 2653219c5a..d3b06188ea 100644 --- a/tensorflow/python/lib/io/file_io.py +++ b/tensorflow/python/lib/io/file_io.py @@ -287,7 +287,9 @@ def create_dir(dirname): def recursive_create_dir(dirname): - """Create a directory and all parent/intermediate directories. + """Creates a directory and all parent/intermediate directories. + + It succeeds if dirname already exists and is writable. Args: dirname: string, name of the directory to be created diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc index d3f557506e..5fcb51b3b2 100644 --- a/tensorflow/python/lib/io/py_record_reader.cc +++ b/tensorflow/python/lib/io/py_record_reader.cc @@ -55,10 +55,14 @@ PyRecordReader::~PyRecordReader() { delete file_; } -bool PyRecordReader::GetNext() { - if (reader_ == nullptr) return false; +void PyRecordReader::GetNext(TF_Status* status) { + if (reader_ == nullptr) { + Set_TF_Status_from_Status(status, + errors::FailedPrecondition("Reader is closed.")); + return; + } Status s = reader_->ReadRecord(&offset_, &record_); - return s.ok(); + Set_TF_Status_from_Status(status, s); } void PyRecordReader::Close() { diff --git a/tensorflow/python/lib/io/py_record_reader.h b/tensorflow/python/lib/io/py_record_reader.h index 0da74ee948..b7ecc928d2 100644 --- a/tensorflow/python/lib/io/py_record_reader.h +++ b/tensorflow/python/lib/io/py_record_reader.h @@ -42,10 +42,12 @@ class PyRecordReader { ~PyRecordReader(); - // Attempt to get the next record at "current_offset()". If - // successful, returns true, and the record contents can be retrieved - // with "this->record()". Otherwise, returns false. - bool GetNext(); + // Attempt to get the next record at "current_offset()". Populates status + // with OK on success, OUT_OF_RANGE for end of file, DATA_LOSS for some + // kinds of truncated reads, or another code for other errors + // (e.g., filesystem errors). + void GetNext(TF_Status* status); + // Return the current record contents. Only valid after the preceding call // to GetNext() returned true string record() const { return record_; } diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 9dc3ac52c2..d02baeb6cd 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -71,7 +71,12 @@ def tf_record_iterator(path, options=None): if reader is None: raise IOError("Could not open %s." % path) - while reader.GetNext(): + while True: + try: + with errors.raise_exception_on_not_ok_status() as status: + reader.GetNext(status) + except errors.OutOfRangeError: + break yield reader.record() reader.Close() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 5ea35e8e04..1d7827bb98 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2423,6 +2423,6 @@ def where(condition, x=None, y=None, name=None): if x is None and y is None: return gen_array_ops.where(input=condition, name=name) elif x is not None and y is not None: - return gen_math_ops.select(condition=condition, t=x, e=y, name=name) + return gen_math_ops._select(condition=condition, t=x, e=y, name=name) else: raise ValueError("x and y must both be non-None or both be None.") diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 8d29de1f89..ce22ffccba 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -2236,7 +2236,9 @@ class WhileContext(ControlFlowContext): if self.outer_context: self.outer_context.Exit() else: value = op.inputs[0] - if self.outer_context: + if (isinstance(self.outer_context, WhileContext) and + self.outer_context.grad_state is not None): + # We are in a nested while loop. forward_ctxt = self.grad_state.forward_context forward_ctxt.outer_context.Enter() zeros_shape = array_ops.shape_internal(value, optimize=False) @@ -2250,8 +2252,10 @@ class WhileContext(ControlFlowContext): acc = array_ops.zeros(real_shape, grad.dtype) self.outer_context.Exit() else: + if self.outer_context: self.outer_context.Enter() zeros_shape = array_ops.shape_internal(value, optimize=False) acc = array_ops.zeros(zeros_shape, grad.dtype) + if self.outer_context: self.outer_context.Exit() acc._shape = grad.get_shape() # pylint: disable=protected-access self.Enter() diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 5afe22e32e..ab30c8cf19 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -184,12 +184,16 @@ BatchIFFT2D BatchIFFT3D Complex Conj +FloorDiv +FloorMod Max Mean Min Pow Prod Range +RealDiv +Select SparseMatMul Sum MatMul @@ -201,6 +205,8 @@ InvGrad ReciprocalGrad SqrtGrad RsqrtGrad +TruncateDiv +TruncateMod # nn_ops AvgPoolGrad # "*Grad" accessible through nn_grad instead of nn_ops. diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index a901106e85..99f992ff5f 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -63,6 +63,7 @@ here](https://www.tensorflow.org/code/tensorflow/core/example/feature.proto). @@VarLenFeature @@FixedLenFeature @@FixedLenSequenceFeature +@@SparseFeature @@parse_example @@parse_single_example @@parse_tensor diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 3502f11892..5a490b5395 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -211,7 +211,7 @@ def _SegmentMinOrMaxGrad(op, grad): weighted_grads = math_ops.div(grad, num_selected) gathered_grads = array_ops.gather(weighted_grads, op.inputs[1]) - return math_ops.select(is_selected, gathered_grads, zeros), None + return array_ops.where(is_selected, gathered_grads, zeros), None @ops.RegisterGradient("SegmentMin") @@ -674,11 +674,11 @@ def _PowGrad(op, grad): # Avoid false singularity at x = 0 if x.dtype.is_complex: # real(x) < 0 is fine for the complex case - log_x = math_ops.select( + log_x = array_ops.where( math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x)) else: # There's no sensible real value to return if x < 0, so return 0 - log_x = math_ops.select(x > 0, math_ops.log(x), array_ops.zeros_like(x)) + log_x = array_ops.where(x > 0, math_ops.log(x), array_ops.zeros_like(x)) gy = array_ops.reshape( math_ops.reduce_sum(grad * z * log_x, ry), sy) return gx, gy @@ -695,8 +695,8 @@ def _MaximumMinimumGrad(op, grad, selector_op): zeros = array_ops.zeros(gradshape, gdtype) xmask = selector_op(x, y) rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - xgrad = math_ops.select(xmask, grad, zeros) - ygrad = math_ops.select(math_ops.logical_not(xmask), grad, zeros) + xgrad = array_ops.where(xmask, grad, zeros) + ygrad = array_ops.where(math_ops.logical_not(xmask), grad, zeros) gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx) gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy) return (gx, gy) @@ -750,8 +750,8 @@ def _SelectGrad(op, grad): c = op.inputs[0] x = op.inputs[1] zeros = array_ops.zeros_like(x) - return (None, math_ops.select(c, grad, zeros), - math_ops.select(c, zeros, grad)) + return (None, array_ops.where(c, grad, zeros), + array_ops.where(c, zeros, grad)) @ops.RegisterGradient("MatMul") diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 9cf092edd7..21bfe205ef 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -235,6 +235,7 @@ from tensorflow.python.ops import state_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import +from tensorflow.python.util.deprecation import deprecated # Aliases for some automatically-generated names. linspace = gen_math_ops.lin_space @@ -889,31 +890,7 @@ def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None): sp_indices, sp_values, sp_shape, y, name=name) -def truediv(x, y, name=None): - """Divides x / y elementwise, always producing floating point results. - - The same as `tf.div` for floating point arguments, but casts integer arguments - to floating point before dividing so that the result is always floating point. - This op is generated by normal `x / y` division in Python 3 and in Python 2.7 - with `from __future__ import division`. If you want integer division that - rounds down, use `x // y` or `tf.floordiv`. - - `x` and `y` must have the same numeric type. If the inputs are floating - point, the output will have the same type. If the inputs are integral, the - inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` - and `int64` (matching the behavior of Numpy). - - Args: - x: `Tensor` numerator of numeric type. - y: `Tensor` denominator of numeric type. - name: A name for the operation (optional). - - Returns: - `x / y` evaluated in floating point. - - Raises: - TypeError: If `x` and `y` have different dtypes. - """ +def _truediv_python3(x, y, name=None): with ops.name_scope(name, "truediv", [x, y]) as name: x = ops.convert_to_tensor(x, name="x") y = ops.convert_to_tensor(y, name="y") @@ -929,11 +906,21 @@ def truediv(x, y, name=None): if dtype is not None: x = cast(x, dtype) y = cast(y, dtype) - return gen_math_ops.real_div(x, y, name=name) + return gen_math_ops._real_div(x, y, name=name) -def div(x, y, name=None): - with ops.name_scope(name, "truediv", [x, y]) as name: +def _div_python2(x, y, name=None): + """Divide two values using Python 2 semantics. Used for Tensor.__div__. + + Args: + x: `Tensor` numerator of real numeric type. + y: `Tensor` denominator of real numeric type. + name: A name for the operation (optional). + Returns: + `x / y` returns the quotient of x and y. + """ + + with ops.name_scope(name, "div", [x, y]) as name: x = ops.convert_to_tensor(x, name="x") y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype) x_dtype = x.dtype.base_dtype @@ -942,20 +929,65 @@ def div(x, y, name=None): raise TypeError("x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype)) if x_dtype.is_floating or x_dtype.is_complex: - return gen_math_ops.real_div(x, y, name=name) + return gen_math_ops._real_div(x, y, name=name) else: - return gen_math_ops.floor_div(x, y, name=name) + return gen_math_ops._floor_div(x, y, name=name) + + +def truediv(x, y, name=None): + """Divides x / y elementwise (using Python 3 division operator semantics). + + NOTE: Prefer using the Tensor operator or tf.divide which obey Python + division operator semantics. + + This function forces Python 3 division operator semantics where all integer + arguments are cast to floating types first. This op is generated by normal + `x / y` division in Python 3 and in Python 2.7 with + `from __future__ import division`. If you want integer division that rounds + down, use `x // y` or `tf.floordiv`. + + `x` and `y` must have the same numeric type. If the inputs are floating + point, the output will have the same type. If the inputs are integral, the + inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32` + and `int64` (matching the behavior of Numpy). + + Args: + x: `Tensor` numerator of numeric type. + y: `Tensor` denominator of numeric type. + name: A name for the operation (optional). + + Returns: + `x / y` evaluated in floating point. + + Raises: + TypeError: If `x` and `y` have different dtypes. + """ + return _truediv_python3(x, y, name) -def div_deprecated(x, y, name=None): - return gen_math_ops.div(x, y, name) +def div(x, y, name=None): + """Divides x / y elementwise (using Python 2 division operator semantics). + NOTE: Prefer using the Tensor division operator or tf.divide which obey Python + division operator semantics. -mod = gen_math_ops.floor_mod + This function divides `x` and `y`, forcing Python 2.7 semantics. That is, + if one of `x` or `y` is a float, then the result will be a float. + Otherwise, the output will be an integer type. Flooring semantics are used + for integer division. + Args: + x: `Tensor` numerator of real numeric type. + y: `Tensor` denominator of real numeric type. + name: A name for the operation (optional). + Returns: + `x / y` returns the quotient of x and y. + """ + return _div_python2(x, y, name) -def mod_deprecated(x, y, name=None): - return gen_math_ops.mod(x, y, name) + +# TODO(aselle): This should be removed +mod = gen_math_ops._floor_mod # TODO(aselle): Deprecate this once all internal functionality uses @@ -987,29 +1019,15 @@ def floordiv(x, y, name=None): TypeError: If the inputs are complex. """ with ops.name_scope(name, "floordiv", [x, y]) as name: - return gen_math_ops.floor_div(x, y, name=name) + return gen_math_ops._floor_div(x, y, name=name) -def floordiv_deprecated(x, y, name=None): - with ops.name_scope(name, "floordiv", [x, y]) as name: - x = ops.convert_to_tensor(x, name="x") - dtype = x.dtype - if dtype.is_floating: - return gen_math_ops.floor(gen_math_ops.div(x, y), name=name) - else: - if not dtype.is_integer: - raise TypeError("Expected floating point or integer, got %r" % dtype) - # TODO(aselle): Switch to math_ops.floor_div() when ready - # return gen_math_ops.floor_div(x, y, name=name) - return gen_math_ops.div(x, y, name=name) - - -realdiv = gen_math_ops.real_div -truncatediv = gen_math_ops.truncate_div +realdiv = gen_math_ops._real_div +truncatediv = gen_math_ops._truncate_div # TODO(aselle): Rename this to floordiv when we can. -floor_div = gen_math_ops.floor_div -truncatemod = gen_math_ops.truncate_mod -floormod = gen_math_ops.floor_mod +floor_div = gen_math_ops._floor_div +truncatemod = gen_math_ops._truncate_mod +floormod = gen_math_ops._floor_mod def _mul_dispatch(x, y, name=None): @@ -1023,7 +1041,9 @@ def _mul_dispatch(x, y, name=None): y.shape, x, name) return sparse_tensor.SparseTensor(y.indices, new_vals, y.shape) - +# NOTE(aselle): When integer division is added for sparse_dense_cwise, +# div, truediv, and floordiv should be delegated appropriately for +# Python sematnics, analogous to dense cwise tensor operations. _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div", sparse_tensor.SparseTensor) _OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv", @@ -1034,12 +1054,12 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul", _OverrideBinaryOperatorHelper(gen_math_ops.add, "add") _OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub") _OverrideBinaryOperatorHelper(_mul_dispatch, "mul") -_OverrideBinaryOperatorHelper(div, "div") -_OverrideBinaryOperatorHelper(truediv, "truediv") +_OverrideBinaryOperatorHelper(_div_python2, "div") +_OverrideBinaryOperatorHelper(_truediv_python3, "truediv") _OverrideBinaryOperatorHelper(floordiv, "floordiv") # TODO(aselle): Switch mod to floor_mod when ready # _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod") -_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod") +_OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod") _OverrideBinaryOperatorHelper(pow, "pow") @@ -2146,3 +2166,13 @@ def reduced_shape(input_shape, axes): input_shape, # [2, 3, 5, 7] array_ops.fill(axes_shape, 1) ]) # [1, 1] + + +@deprecated( + "2016-12-07", + "This op will be removed after the deprecation date. " + "Please switch to tf.where().") +def select(condition, x, y, name=None): + return gen_math_ops._select(condition, x, y, name) +select.__doc__ = gen_math_ops._select.__doc__ + diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 197ddb6a75..b2fc3a84d4 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -309,7 +309,7 @@ class DivAndModTest(test_util.TensorFlowTestCase): def testComplexDiv(self): foo = array_ops.constant([1.+3.j]) with self.test_session(): - _ = math_ops.div_deprecated(foo, 1.).eval() + _ = math_ops.divide(foo, 1.).eval() _ = math_ops.div(foo, 2.).eval() def testFloorDivGrad(self): @@ -318,7 +318,7 @@ class DivAndModTest(test_util.TensorFlowTestCase): b = variables.Variable(4.) with self.test_session() as sess: sess.run(variables.initialize_all_variables()) - c_grad = gradients.gradients(math_ops.div_deprecated(a, b), [a, b]) + c_grad = gradients.gradients(math_ops.divide(a, b), [a, b]) self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125]) c_grad = gradients.gradients(math_ops.div(a, b), [a, b]) self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125]) @@ -330,7 +330,7 @@ class DivAndModTest(test_util.TensorFlowTestCase): nums, divs = self.intTestData() with self.test_session(): tf_result = ( - math_ops.floor_div(nums, divs) * divs + math_ops.floor_mod(nums, divs) + math_ops.floor_div(nums, divs) * divs + math_ops.floormod(nums, divs) ).eval() tf_nums = array_ops.constant(nums) tf_divs = array_ops.constant(divs) diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 149bde451a..bfa15f9401 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -25,7 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import gen_nn_ops -from tensorflow.python.ops import gen_math_ops + @ops.RegisterGradient("Conv2DBackpropInput") def _Conv2DBackpropInputGrad(op, grad): @@ -271,9 +271,10 @@ def _ReluGrad(op, grad): @ops.RegisterGradient("EluGrad") def _EluGradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops._elu_grad(grad, op.outputs[0]), - gen_math_ops.select(x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + 1), - array_ops.zeros(shape = array_ops.shape(x), dtype = x.dtype))) + return (gen_nn_ops._elu_grad(grad, op.outputs[0]), + array_ops.where( + x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + 1), + array_ops.zeros(shape = array_ops.shape(x), dtype = x.dtype))) @ops.RegisterGradient("Relu6") diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 4ef95c1146..afacef7acd 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -92,7 +92,7 @@ def log_poisson_loss(log_input, targets, compute_full_loss=False, name=None): zeros = array_ops.zeros_like(targets, dtype=targets.dtype) ones = array_ops.ones_like(targets, dtype=targets.dtype) cond = math_ops.logical_and(targets >= zeros, targets <= ones) - result += math_ops.select(cond, zeros, stirling_approx) + result += array_ops.where(cond, zeros, stirling_approx) return result @@ -157,8 +157,8 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None): # abs functions. zeros = array_ops.zeros_like(logits, dtype=logits.dtype) cond = (logits >= zeros) - relu_logits = math_ops.select(cond, logits, zeros) - neg_abs_logits = math_ops.select(cond, -logits, logits) + relu_logits = array_ops.where(cond, logits, zeros) + neg_abs_logits = array_ops.where(cond, -logits, logits) return math_ops.add(relu_logits - logits * targets, math_ops.log1p(math_ops.exp(neg_abs_logits)), name=name) @@ -292,7 +292,7 @@ def zero_fraction(value, name=None): ```python z = tf.Relu(...) - summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z)) + summ = tf.contrib.deprecated.scalar_summary('sparsity', tf.nn.zero_fraction(z)) ``` Args: diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 31db4e9d56..35610cc554 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1218,8 +1218,8 @@ def crelu(features, name=None): """ with ops.name_scope(name, "CRelu", [features]) as name: features = ops.convert_to_tensor(features, name="features") - return gen_nn_ops.relu(array_ops.concat(array_ops.rank(features) - 1, - [features, -features], name=name)) + c = array_ops.concat(-1, [features, -features], name=name) + return gen_nn_ops.relu(c) def relu6(features, name=None): diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index 21b957380a..fa99e3a49b 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -22,6 +22,7 @@ import collections import re from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape @@ -29,6 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_parsing_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops # go/tf-wildcard-import # pylint: disable=wildcard-import,undefined-variable from tensorflow.python.ops.gen_parsing_ops import * @@ -49,6 +51,28 @@ class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])): pass +class SparseFeature( + collections.namedtuple( + "SparseFeature", + ["index_key", "value_key", "dtype", "size", "already_sorted"])): + """Configuration for parsing a sparse input feature. + + Fields: + index_key: Name of index feature. The underlying feature's type must + be `int64` and its length must always match that of the `value_key` + feature. + value_key: Name of value feature. The underlying feature's type must + be `dtype` and its length must always match that of the `index_key` + feature. + dtype: Data type of the `value_key` feature. + size: Each value in the `index_key` feature must be in `[0, size)`. + already_sorted: A boolean to specify whether the values in `index_key` are + already sorted. If so skip sorting, False by default (optional). + """ + pass +SparseFeature.__new__.__defaults__ = (False,) + + class FixedLenFeature(collections.namedtuple( "FixedLenFeature", ["shape", "dtype", "default_value"])): """Configuration for parsing a fixed-length input feature. @@ -91,7 +115,7 @@ def _features_to_raw_params(features, types): Args: features: A `dict` mapping feature keys to objects of a type in `types`. types: Type of features to allow, among `FixedLenFeature`, `VarLenFeature`, - and `FixedLenSequenceFeature`. + `SparseFeature`, and `FixedLenSequenceFeature`. Returns: Tuple of `sparse_keys`, `sparse_types`, `dense_keys`, `dense_types`, @@ -118,6 +142,34 @@ def _features_to_raw_params(features, types): raise ValueError("Missing type for feature %s." % key) sparse_keys.append(key) sparse_types.append(feature.dtype) + elif isinstance(feature, SparseFeature): + if SparseFeature not in types: + raise ValueError("Unsupported SparseFeature %s.", feature) + if not feature.index_key: + raise ValueError( + "Missing index_key for SparseFeature %s.", feature) + if not feature.value_key: + raise ValueError( + "Missing value_key for SparseFeature %s.", feature) + if not feature.dtype: + raise ValueError("Missing type for feature %s." % key) + if feature.index_key in sparse_keys: + dtype = sparse_types[sparse_keys.index(feature.index_key)] + if dtype != dtypes.int64: + raise ValueError("Conflicting type %s vs int64 for feature %s." % ( + dtype, feature.index_key)) + else: + sparse_keys.append(feature.index_key) + sparse_types.append(dtypes.int64) + + if feature.value_key in sparse_keys: + dtype = sparse_types[sparse_keys.index(feature.value_key)] + if dtype != feature.dtype: + raise ValueError("Conflicting type %s vs %s for feature %s." % ( + dtype, feature.dtype, feature.value_key)) + else: + sparse_keys.append(feature.value_key) + sparse_types.append(feature.dtype) elif isinstance(feature, FixedLenFeature): if FixedLenFeature not in types: raise ValueError("Unsupported FixedLenFeature %s.", feature) @@ -149,6 +201,38 @@ def _features_to_raw_params(features, types): dense_shapes) +def _construct_sparse_tensors_for_sparse_features(features, tensor_dict): + """Merges SparseTensors of indices and values of SparseFeatures. + + Updates `tensor_dict`. For `SparseFeatures` in the values of `features` + expects their `index_key`s and `index_value`s to be present in `tensor_dict` + mapping to `SparseTensor`s. Removes those, constructs a single `SparseTensor` + from them, and adds it to `tensor_dict` with the key from `features`. + + Args: + features: A `dict` mapping feature keys to `SparseFeature` values. + Values of other types will be ignored. + tensor_dict: A `dict` mapping feature keys to `Tensor` and `SparseTensor` + values. Expected to contain keys of the `SparseFeature`s' `index_key`s and + `value_key`s and mapping them to `SparseTensor`s. + """ + # Construct SparseTensors for SparseFeatures. + for key in sorted(features.keys()): + feature = features[key] + if isinstance(feature, SparseFeature): + sp_ids = tensor_dict[feature.index_key] + sp_values = tensor_dict[feature.value_key] + tensor_dict[key] = sparse_ops.sparse_merge( + sp_ids, + sp_values, + feature.size, + feature.already_sorted) + # Remove tensors from dictionary that were only used to construct + # SparseTensors for SparseFeature. + for key in set(tensor_dict.keys()) - set(features.keys()): + del tensor_dict[key] + + def parse_example(serialized, features, name=None, example_names=None): # pylint: disable=line-too-long """Parses `Example` protos into a `dict` of tensors. @@ -158,18 +242,27 @@ def parse_example(serialized, features, name=None, example_names=None): `example_names` may contain descriptive names for the corresponding serialized protos. These may be useful for debugging purposes, but they have no effect on - the output. If not `None`, `example_names` must be the same length as `serialized`. + the output. If not `None`, `example_names` must be the same length as + `serialized`. This op parses serialized examples into a dictionary mapping keys to `Tensor` - and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature` - and `FixedLenFeature` objects. Each `VarLenFeature` is mapped to a - `SparseTensor`, and each `FixedLenFeature` is mapped to a `Tensor`. + and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`, + `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature` + and `SparseFeature` is mapped to a `SparseTensor`, and each + `FixedLenFeature` is mapped to a `Tensor`. Each `VarLenFeature` maps to a `SparseTensor` of the specified type representing a ragged matrix. Its indices are `[batch, index]` where `batch` is the batch entry the value is from in `serialized`, and `index` is the value's index in the list of values associated with that feature and example. + Each `SparseFeature` maps to a `SparseTensor` of the specified type + representing a sparse matrix of shape + `(serialized.size(), SparseFeature.size)`. Its indices are `[batch, index]` + where `batch` is the batch entry the value is from in `serialized`, and + `index` is the value's index is given by the values in the + `SparseFeature.index_key` feature column. + Each `FixedLenFeature` `df` maps to a `Tensor` of the specified type (or `tf.float32` if not specified) and shape `(serialized.size(),) + df.shape`. @@ -281,11 +374,46 @@ def parse_example(serialized, features, name=None, example_names=None): } ``` + Given two `Example` input protos in `serialized`: + + ``` + [ + features { + feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } } + feature { key: "ix" value { int64_list { value: [ 3, 20 ] } } } + }, + features { + feature { key: "val" value { float_list { value: [ 0.0 ] } } } + feature { key: "ix" value { int64_list { value: [ 42 ] } } } + } + ] + ``` + + And arguments + + ``` + example_names: ["input0", "input1"], + features: { + "sparse": SparseFeature("ix", "val", tf.float32, 100), + } + ``` + + Then the output is a dictionary: + + ```python + { + "sparse": SparseTensor( + indices=[[0, 3], [0, 20], [1, 42]], + values=[0.5, -1.0, 0.0] + shape=[2, 100]), + } + ``` + Args: serialized: A vector (1-D Tensor) of strings, a batch of binary serialized `Example` protos. - features: A `dict` mapping feature keys to `FixedLenFeature` or - `VarLenFeature` values. + features: A `dict` mapping feature keys to `FixedLenFeature`, + `VarLenFeature`, and `SparseFeature` values. name: A name for this operation (optional). example_names: A vector (1-D Tensor) of strings (optional), the names of the serialized protos in the batch. @@ -300,10 +428,12 @@ def parse_example(serialized, features, name=None, example_names=None): raise ValueError("Missing: features was %s." % features) (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, dense_shapes) = _features_to_raw_params( - features, [VarLenFeature, FixedLenFeature]) - return _parse_example_raw( + features, [VarLenFeature, SparseFeature, FixedLenFeature]) + outputs = _parse_example_raw( serialized, example_names, sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, dense_shapes, name) + _construct_sparse_tensors_for_sparse_features(features, outputs) + return outputs def _parse_example_raw(serialized, @@ -410,8 +540,7 @@ def _parse_example_raw(serialized, sparse_tensor.SparseTensor(ix, val, shape) for (ix, val, shape) in zip(sparse_indices, sparse_values, sparse_shapes)] - return dict( - zip(sparse_keys + dense_keys, sparse_tensors + dense_values)) + return dict(zip(sparse_keys + dense_keys, sparse_tensors + dense_values)) def parse_single_example(serialized, features, name=None, example_names=None): @@ -447,10 +576,12 @@ def parse_single_example(serialized, features, name=None, example_names=None): raise ValueError("Missing features.") (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, dense_shapes) = _features_to_raw_params( - features, [VarLenFeature, FixedLenFeature]) - return _parse_single_example_raw( + features, [VarLenFeature, FixedLenFeature, SparseFeature]) + outputs = _parse_single_example_raw( serialized, example_names, sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, dense_shapes, name) + _construct_sparse_tensors_for_sparse_features(features, outputs) + return outputs def _parse_single_example_raw(serialized, @@ -514,15 +645,16 @@ def _parse_single_example_raw(serialized, name="NamesDependencies") names = array_ops.expand_dims(names, 0) - outputs = _parse_example_raw(serialized, - names=names, - sparse_keys=sparse_keys, - sparse_types=sparse_types, - dense_keys=dense_keys, - dense_types=dense_types, - dense_defaults=dense_defaults, - dense_shapes=dense_shapes, - name=name) + outputs = _parse_example_raw( + serialized, + names=names, + sparse_keys=sparse_keys, + sparse_types=sparse_types, + dense_keys=dense_keys, + dense_types=dense_types, + dense_defaults=dense_defaults, + dense_shapes=dense_shapes, + name=name) if dense_keys is not None: for d in dense_keys: d_name = re.sub("[^A-Za-z0-9_.\\-/]", "_", d) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index b1270a1937..61536ab4a0 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -27,13 +27,14 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import nest # pylint: disable=protected-access -_state_size_with_prefix = rnn_cell._state_size_with_prefix +_state_size_with_prefix = rnn_cell_impl._state_size_with_prefix # pylint: enable=protected-access @@ -365,7 +366,7 @@ def _rnn_step( def _copy_one_through(output, new_output): copy_cond = (time >= sequence_length) - return math_ops.select(copy_cond, output, new_output) + return array_ops.where(copy_cond, output, new_output) def _copy_some_through(flat_new_output, flat_new_state): # Use broadcasting select to determine which values should get @@ -1298,7 +1299,7 @@ def raw_rnn(cell, loop_fn, current_flat = nest.flatten(current) candidate_flat = nest.flatten(candidate) result_flat = [ - math_ops.select(elements_finished, current_i, candidate_i) + array_ops.where(elements_finished, current_i, candidate_i) for (current_i, candidate_i) in zip(current_flat, candidate_flat)] return nest.pack_sequence_as( structure=current, flat_sequence=result_flat) diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py index d620177e90..b6da265ae0 100644 --- a/tensorflow/python/ops/rnn_cell.py +++ b/tensorflow/python/ops/rnn_cell.py @@ -42,854 +42,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import math +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.rnn_cell_impl import * +# pylint: enable=wildcard-import +# TODO(drpng): remove this once internal use has been eradicated. +# pylint: disable=unused-import +from tensorflow.python.ops.rnn_cell_impl import _linear +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import partitioned_variables -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops.math_ops import sigmoid -from tensorflow.python.ops.math_ops import tanh +_allowed_symbols = [] -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import nest - - -def _state_size_with_prefix(state_size, prefix=None): - """Helper function that enables int or TensorShape shape specification. - - This function takes a size specification, which can be an integer or a - TensorShape, and converts it into a list of integers. One may specify any - additional dimensions that precede the final state size specification. - - Args: - state_size: TensorShape or int that specifies the size of a tensor. - prefix: optional additional list of dimensions to prepend. - - Returns: - result_state_size: list of dimensions the resulting tensor size. - """ - result_state_size = tensor_shape.as_shape(state_size).as_list() - if prefix is not None: - if not isinstance(prefix, list): - raise TypeError("prefix of _state_size_with_prefix should be a list.") - result_state_size = prefix + result_state_size - return result_state_size - - -class RNNCell(object): - """Abstract object representing an RNN cell. - - The definition of cell in this package differs from the definition used in the - literature. In the literature, cell refers to an object with a single scalar - output. The definition in this package refers to a horizontal array of such - units. - - An RNN cell, in the most abstract setting, is anything that has - a state and performs some operation that takes a matrix of inputs. - This operation results in an output matrix with `self.output_size` columns. - If `self.state_size` is an integer, this operation also results in a new - state matrix with `self.state_size` columns. If `self.state_size` is a - tuple of integers, then it results in a tuple of `len(state_size)` state - matrices, each with a column size corresponding to values in `state_size`. - - This module provides a number of basic commonly used RNN cells, such as - LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number - of operators that allow add dropouts, projections, or embeddings for inputs. - Constructing multi-layer cells is supported by the class `MultiRNNCell`, - or by calling the `rnn` ops several times. Every `RNNCell` must have the - properties below and and implement `__call__` with the following signature. - """ - - def __call__(self, inputs, state, scope=None): - """Run this RNN cell on inputs, starting from the given state. - - Args: - inputs: `2-D` tensor with shape `[batch_size x input_size]`. - state: if `self.state_size` is an integer, this should be a `2-D Tensor` - with shape `[batch_size x self.state_size]`. Otherwise, if - `self.state_size` is a tuple of integers, this should be a tuple - with shapes `[batch_size x s] for s in self.state_size`. - scope: VariableScope for the created subgraph; defaults to class name. - - Returns: - A pair containing: - - - Output: A `2-D` tensor with shape `[batch_size x self.output_size]`. - - New state: Either a single `2-D` tensor, or a tuple of tensors matching - the arity and shapes of `state`. - """ - raise NotImplementedError("Abstract method") - - @property - def state_size(self): - """size(s) of state(s) used by this cell. - - It can be represented by an Integer, a TensorShape or a tuple of Integers - or TensorShapes. - """ - raise NotImplementedError("Abstract method") - - @property - def output_size(self): - """Integer or TensorShape: size of outputs produced by this cell.""" - raise NotImplementedError("Abstract method") - - def zero_state(self, batch_size, dtype): - """Return zero-filled state tensor(s). - - Args: - batch_size: int, float, or unit Tensor representing the batch size. - dtype: the data type to use for the state. - - Returns: - If `state_size` is an int or TensorShape, then the return value is a - `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. - - If `state_size` is a nested list or tuple, then the return value is - a nested list or tuple (of the same structure) of `2-D` tensors with - the shapes `[batch_size x s]` for each s in `state_size`. - """ - state_size = self.state_size - if nest.is_sequence(state_size): - state_size_flat = nest.flatten(state_size) - zeros_flat = [ - array_ops.zeros( - array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])), - dtype=dtype) - for s in state_size_flat] - for s, z in zip(state_size_flat, zeros_flat): - z.set_shape(_state_size_with_prefix(s, prefix=[None])) - zeros = nest.pack_sequence_as(structure=state_size, - flat_sequence=zeros_flat) - else: - zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size]) - zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype) - zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None])) - - return zeros - - -class BasicRNNCell(RNNCell): - """The most basic RNN cell.""" - - def __init__(self, num_units, input_size=None, activation=tanh): - if input_size is not None: - logging.warn("%s: The input_size parameter is deprecated.", self) - self._num_units = num_units - self._activation = activation - - @property - def state_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - def __call__(self, inputs, state, scope=None): - """Most basic RNN: output = new_state = act(W * input + U * state + B).""" - with vs.variable_scope(scope or "basic_rnn_cell"): - output = self._activation( - _linear([inputs, state], self._num_units, True, scope=scope)) - return output, output - - -class GRUCell(RNNCell): - """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" - - def __init__(self, num_units, input_size=None, activation=tanh): - if input_size is not None: - logging.warn("%s: The input_size parameter is deprecated.", self) - self._num_units = num_units - self._activation = activation - - @property - def state_size(self): - return self._num_units - - @property - def output_size(self): - return self._num_units - - def __call__(self, inputs, state, scope=None): - """Gated recurrent unit (GRU) with nunits cells.""" - with vs.variable_scope(scope or "gru_cell"): - with vs.variable_scope("gates"): # Reset gate and update gate. - # We start with bias of 1.0 to not reset and not update. - r, u = array_ops.split( - 1, 2, _linear([inputs, state], 2 * self._num_units, True, 1.0, - scope=scope)) - r, u = sigmoid(r), sigmoid(u) - with vs.variable_scope("candidate"): - c = self._activation(_linear([inputs, r * state], - self._num_units, True, - scope=scope)) - new_h = u * state + (1 - u) * c - return new_h, new_h - - -_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) - - -class LSTMStateTuple(_LSTMStateTuple): - """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. - - Stores two elements: `(c, h)`, in that order. - - Only used when `state_is_tuple=True`. - """ - __slots__ = () - - @property - def dtype(self): - (c, h) = self - if not c.dtype == h.dtype: - raise TypeError("Inconsistent internal state: %s vs %s" % - (str(c.dtype), str(h.dtype))) - return c.dtype - - -class BasicLSTMCell(RNNCell): - """Basic LSTM recurrent network cell. - - The implementation is based on: http://arxiv.org/abs/1409.2329. - - We add forget_bias (default: 1) to the biases of the forget gate in order to - reduce the scale of forgetting in the beginning of the training. - - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - - For advanced models, please use the full LSTMCell that follows. - """ - - def __init__(self, num_units, forget_bias=1.0, input_size=None, - state_is_tuple=True, activation=tanh): - """Initialize the basic LSTM cell. - - Args: - num_units: int, The number of units in the LSTM cell. - forget_bias: float, The bias added to forget gates (see above). - input_size: Deprecated and unused. - state_is_tuple: If True, accepted and returned states are 2-tuples of - the `c_state` and `m_state`. If False, they are concatenated - along the column axis. The latter behavior will soon be deprecated. - activation: Activation function of the inner states. - """ - if not state_is_tuple: - logging.warn("%s: Using a concatenated state is slower and will soon be " - "deprecated. Use state_is_tuple=True.", self) - if input_size is not None: - logging.warn("%s: The input_size parameter is deprecated.", self) - self._num_units = num_units - self._forget_bias = forget_bias - self._state_is_tuple = state_is_tuple - self._activation = activation - - @property - def state_size(self): - return (LSTMStateTuple(self._num_units, self._num_units) - if self._state_is_tuple else 2 * self._num_units) - - @property - def output_size(self): - return self._num_units - - def __call__(self, inputs, state, scope=None): - """Long short-term memory cell (LSTM).""" - with vs.variable_scope(scope or "basic_lstm_cell"): - # Parameters of gates are concatenated into one multiply for efficiency. - if self._state_is_tuple: - c, h = state - else: - c, h = array_ops.split(1, 2, state) - concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope) - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = array_ops.split(1, 4, concat) - - new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * - self._activation(j)) - new_h = self._activation(new_c) * sigmoid(o) - - if self._state_is_tuple: - new_state = LSTMStateTuple(new_c, new_h) - else: - new_state = array_ops.concat(1, [new_c, new_h]) - return new_h, new_state - - -class LSTMCell(RNNCell): - """Long short-term memory unit (LSTM) recurrent network cell. - - The default non-peephole implementation is based on: - - http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf - - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. - - The peephole implementation is based on: - - https://research.google.com/pubs/archive/43905.pdf - - Hasim Sak, Andrew Senior, and Francoise Beaufays. - "Long short-term memory recurrent neural network architectures for - large scale acoustic modeling." INTERSPEECH, 2014. - - The class uses optional peep-hole connections, optional cell clipping, and - an optional projection layer. - """ - - def __init__(self, num_units, input_size=None, - use_peepholes=False, cell_clip=None, - initializer=None, num_proj=None, proj_clip=None, - num_unit_shards=None, num_proj_shards=None, - forget_bias=1.0, state_is_tuple=True, - activation=tanh): - """Initialize the parameters for an LSTM cell. - - Args: - num_units: int, The number of units in the LSTM cell - input_size: Deprecated and unused. - use_peepholes: bool, set True to enable diagonal/peephole connections. - cell_clip: (optional) A float value, if provided the cell state is clipped - by this value prior to the cell output activation. - initializer: (optional) The initializer to use for the weight and - projection matrices. - num_proj: (optional) int, The output dimensionality for the projection - matrices. If None, no projection is performed. - proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is - provided, then the projected values are clipped elementwise to within - `[-proj_clip, proj_clip]`. - num_unit_shards: Deprecated, will be removed by Jan. 2017. - Use a variable_scope partitioner instead. - num_proj_shards: Deprecated, will be removed by Jan. 2017. - Use a variable_scope partitioner instead. - forget_bias: Biases of the forget gate are initialized by default to 1 - in order to reduce the scale of forgetting at the beginning of - the training. - state_is_tuple: If True, accepted and returned states are 2-tuples of - the `c_state` and `m_state`. If False, they are concatenated - along the column axis. This latter behavior will soon be deprecated. - activation: Activation function of the inner states. - """ - if not state_is_tuple: - logging.warn("%s: Using a concatenated state is slower and will soon be " - "deprecated. Use state_is_tuple=True.", self) - if input_size is not None: - logging.warn("%s: The input_size parameter is deprecated.", self) - if num_unit_shards is not None or num_proj_shards is not None: - logging.warn( - "%s: The num_unit_shards and proj_unit_shards parameters are " - "deprecated and will be removed in Jan 2017. " - "Use a variable scope with a partitioner instead.", self) - - self._num_units = num_units - self._use_peepholes = use_peepholes - self._cell_clip = cell_clip - self._initializer = initializer - self._num_proj = num_proj - self._proj_clip = proj_clip - self._num_unit_shards = num_unit_shards - self._num_proj_shards = num_proj_shards - self._forget_bias = forget_bias - self._state_is_tuple = state_is_tuple - self._activation = activation - - if num_proj: - self._state_size = ( - LSTMStateTuple(num_units, num_proj) - if state_is_tuple else num_units + num_proj) - self._output_size = num_proj - else: - self._state_size = ( - LSTMStateTuple(num_units, num_units) - if state_is_tuple else 2 * num_units) - self._output_size = num_units - - @property - def state_size(self): - return self._state_size - - @property - def output_size(self): - return self._output_size - - def __call__(self, inputs, state, scope=None): - """Run one step of LSTM. - - Args: - inputs: input Tensor, 2D, batch x num_units. - state: if `state_is_tuple` is False, this must be a state Tensor, - `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a - tuple of state Tensors, both `2-D`, with column sizes `c_state` and - `m_state`. - scope: VariableScope for the created subgraph; defaults to "lstm_cell". - - Returns: - A tuple containing: - - - A `2-D, [batch x output_dim]`, Tensor representing the output of the - LSTM after reading `inputs` when previous state was `state`. - Here output_dim is: - num_proj if num_proj was set, - num_units otherwise. - - Tensor(s) representing the new state of LSTM after reading `inputs` when - the previous state was `state`. Same type and shape(s) as `state`. - - Raises: - ValueError: If input size cannot be inferred from inputs via - static shape inference. - """ - num_proj = self._num_units if self._num_proj is None else self._num_proj - - if self._state_is_tuple: - (c_prev, m_prev) = state - else: - c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) - m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) - - dtype = inputs.dtype - input_size = inputs.get_shape().with_rank(2)[1] - if input_size.value is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with vs.variable_scope(scope or "lstm_cell", - initializer=self._initializer) as unit_scope: - if self._num_unit_shards is not None: - unit_scope.set_partitioner( - partitioned_variables.fixed_size_partitioner( - self._num_unit_shards)) - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True, - scope=scope) - i, j, f, o = array_ops.split(1, 4, lstm_matrix) - - # Diagonal connections - if self._use_peepholes: - with vs.variable_scope(unit_scope) as projection_scope: - if self._num_unit_shards is not None: - projection_scope.set_partitioner(None) - w_f_diag = vs.get_variable( - "w_f_diag", shape=[self._num_units], dtype=dtype) - w_i_diag = vs.get_variable( - "w_i_diag", shape=[self._num_units], dtype=dtype) - w_o_diag = vs.get_variable( - "w_o_diag", shape=[self._num_units], dtype=dtype) - - if self._use_peepholes: - c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + - sigmoid(i + w_i_diag * c_prev) * self._activation(j)) - else: - c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * - self._activation(j)) - - if self._cell_clip is not None: - # pylint: disable=invalid-unary-operand-type - c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) - # pylint: enable=invalid-unary-operand-type - - if self._use_peepholes: - m = sigmoid(o + w_o_diag * c) * self._activation(c) - else: - m = sigmoid(o) * self._activation(c) - - if self._num_proj is not None: - with vs.variable_scope("projection") as proj_scope: - if self._num_proj_shards is not None: - proj_scope.set_partitioner( - partitioned_variables.fixed_size_partitioner( - self._num_proj_shards)) - m = _linear(m, self._num_proj, bias=False, scope=scope) - - if self._proj_clip is not None: - # pylint: disable=invalid-unary-operand-type - m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) - # pylint: enable=invalid-unary-operand-type - - new_state = (LSTMStateTuple(c, m) if self._state_is_tuple - else array_ops.concat(1, [c, m])) - return m, new_state - - -class OutputProjectionWrapper(RNNCell): - """Operator adding an output projection to the given cell. - - Note: in many cases it may be more efficient to not use this wrapper, - but instead concatenate the whole sequence of your outputs in time, - do the projection on this batch-concatenated sequence, then split it - if needed or directly feed into a softmax. - """ - - def __init__(self, cell, output_size): - """Create a cell with output projection. - - Args: - cell: an RNNCell, a projection to output_size is added to it. - output_size: integer, the size of the output after projection. - - Raises: - TypeError: if cell is not an RNNCell. - ValueError: if output_size is not positive. - """ - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not RNNCell.") - if output_size < 1: - raise ValueError("Parameter output_size must be > 0: %d." % output_size) - self._cell = cell - self._output_size = output_size - - @property - def state_size(self): - return self._cell.state_size - - @property - def output_size(self): - return self._output_size - - def __call__(self, inputs, state, scope=None): - """Run the cell and output projection on inputs, starting from state.""" - output, res_state = self._cell(inputs, state) - # Default scope: "OutputProjectionWrapper" - with vs.variable_scope(scope or "output_projection_wrapper"): - projected = _linear(output, self._output_size, True, scope=scope) - return projected, res_state - - -class InputProjectionWrapper(RNNCell): - """Operator adding an input projection to the given cell. - - Note: in many cases it may be more efficient to not use this wrapper, - but instead concatenate the whole sequence of your inputs in time, - do the projection on this batch-concatenated sequence, then split it. - """ - - def __init__(self, cell, num_proj, input_size=None): - """Create a cell with input projection. - - Args: - cell: an RNNCell, a projection of inputs is added before it. - num_proj: Python integer. The dimension to project to. - input_size: Deprecated and unused. - - Raises: - TypeError: if cell is not an RNNCell. - """ - if input_size is not None: - logging.warn("%s: The input_size parameter is deprecated.", self) - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not RNNCell.") - self._cell = cell - self._num_proj = num_proj - - @property - def state_size(self): - return self._cell.state_size - - @property - def output_size(self): - return self._cell.output_size - - def __call__(self, inputs, state, scope=None): - """Run the input projection and then the cell.""" - # Default scope: "InputProjectionWrapper" - with vs.variable_scope(scope or "input_projection_wrapper"): - projected = _linear(inputs, self._num_proj, True, scope=scope) - return self._cell(projected, state) - - -class DropoutWrapper(RNNCell): - """Operator adding dropout to inputs and outputs of the given cell.""" - - def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, - seed=None): - """Create a cell with added input and/or output dropout. - - Dropout is never used on the state. - - Args: - cell: an RNNCell, a projection to output_size is added to it. - input_keep_prob: unit Tensor or float between 0 and 1, input keep - probability; if it is float and 1, no input dropout will be added. - output_keep_prob: unit Tensor or float between 0 and 1, output keep - probability; if it is float and 1, no output dropout will be added. - seed: (optional) integer, the randomness seed. - - Raises: - TypeError: if cell is not an RNNCell. - ValueError: if keep_prob is not between 0 and 1. - """ - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not a RNNCell.") - if (isinstance(input_keep_prob, float) and - not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)): - raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d" - % input_keep_prob) - if (isinstance(output_keep_prob, float) and - not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)): - raise ValueError("Parameter output_keep_prob must be between 0 and 1: %d" - % output_keep_prob) - self._cell = cell - self._input_keep_prob = input_keep_prob - self._output_keep_prob = output_keep_prob - self._seed = seed - - @property - def state_size(self): - return self._cell.state_size - - @property - def output_size(self): - return self._cell.output_size - - def __call__(self, inputs, state, scope=None): - """Run the cell with the declared dropouts.""" - if (not isinstance(self._input_keep_prob, float) or - self._input_keep_prob < 1): - inputs = nn_ops.dropout(inputs, self._input_keep_prob, seed=self._seed) - output, new_state = self._cell(inputs, state, scope) - if (not isinstance(self._output_keep_prob, float) or - self._output_keep_prob < 1): - output = nn_ops.dropout(output, self._output_keep_prob, seed=self._seed) - return output, new_state - - -class EmbeddingWrapper(RNNCell): - """Operator adding input embedding to the given cell. - - Note: in many cases it may be more efficient to not use this wrapper, - but instead concatenate the whole sequence of your inputs in time, - do the embedding on this batch-concatenated sequence, then split it and - feed into your RNN. - """ - - def __init__(self, cell, embedding_classes, embedding_size, initializer=None): - """Create a cell with an added input embedding. - - Args: - cell: an RNNCell, an embedding will be put before its inputs. - embedding_classes: integer, how many symbols will be embedded. - embedding_size: integer, the size of the vectors we embed into. - initializer: an initializer to use when creating the embedding; - if None, the initializer from variable scope or a default one is used. - - Raises: - TypeError: if cell is not an RNNCell. - ValueError: if embedding_classes is not positive. - """ - if not isinstance(cell, RNNCell): - raise TypeError("The parameter cell is not RNNCell.") - if embedding_classes <= 0 or embedding_size <= 0: - raise ValueError("Both embedding_classes and embedding_size must be > 0: " - "%d, %d." % (embedding_classes, embedding_size)) - self._cell = cell - self._embedding_classes = embedding_classes - self._embedding_size = embedding_size - self._initializer = initializer - - @property - def state_size(self): - return self._cell.state_size - - @property - def output_size(self): - return self._cell.output_size - - def __call__(self, inputs, state, scope=None): - """Run the cell on embedded inputs.""" - with vs.variable_scope(scope or "embedding_wrapper"): # "EmbeddingWrapper" - with ops.device("/cpu:0"): - if self._initializer: - initializer = self._initializer - elif vs.get_variable_scope().initializer: - initializer = vs.get_variable_scope().initializer - else: - # Default initializer for embeddings should have variance=1. - sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. - initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3) - - if type(state) is tuple: - data_type = state[0].dtype - else: - data_type = state.dtype - - embedding = vs.get_variable( - "embedding", [self._embedding_classes, self._embedding_size], - initializer=initializer, - dtype=data_type) - embedded = embedding_ops.embedding_lookup( - embedding, array_ops.reshape(inputs, [-1])) - return self._cell(embedded, state) - - -class MultiRNNCell(RNNCell): - """RNN cell composed sequentially of multiple simple cells.""" - - def __init__(self, cells, state_is_tuple=True): - """Create a RNN cell composed sequentially of a number of RNNCells. - - Args: - cells: list of RNNCells that will be composed in this order. - state_is_tuple: If True, accepted and returned states are n-tuples, where - `n = len(cells)`. If False, the states are all - concatenated along the column axis. This latter behavior will soon be - deprecated. - - Raises: - ValueError: if cells is empty (not allowed), or at least one of the cells - returns a state tuple but the flag `state_is_tuple` is `False`. - """ - if not cells: - raise ValueError("Must specify at least one cell for MultiRNNCell.") - self._cells = cells - self._state_is_tuple = state_is_tuple - if not state_is_tuple: - if any(nest.is_sequence(c.state_size) for c in self._cells): - raise ValueError("Some cells return tuples of states, but the flag " - "state_is_tuple is not set. State sizes are: %s" - % str([c.state_size for c in self._cells])) - - @property - def state_size(self): - if self._state_is_tuple: - return tuple(cell.state_size for cell in self._cells) - else: - return sum([cell.state_size for cell in self._cells]) - - @property - def output_size(self): - return self._cells[-1].output_size - - def __call__(self, inputs, state, scope=None): - """Run this multi-layer cell on inputs, starting from state.""" - with vs.variable_scope(scope or "multi_rnn_cell"): - cur_state_pos = 0 - cur_inp = inputs - new_states = [] - for i, cell in enumerate(self._cells): - with vs.variable_scope("cell_%d" % i): - if self._state_is_tuple: - if not nest.is_sequence(state): - raise ValueError( - "Expected state to be a tuple of length %d, but received: %s" - % (len(self.state_size), state)) - cur_state = state[i] - else: - cur_state = array_ops.slice( - state, [0, cur_state_pos], [-1, cell.state_size]) - cur_state_pos += cell.state_size - cur_inp, new_state = cell(cur_inp, cur_state) - new_states.append(new_state) - new_states = (tuple(new_states) if self._state_is_tuple - else array_ops.concat(1, new_states)) - return cur_inp, new_states - - -class _SlimRNNCell(RNNCell): - """A simple wrapper for slim.rnn_cells.""" - - def __init__(self, cell_fn): - """Create a SlimRNNCell from a cell_fn. - - Args: - cell_fn: a function which takes (inputs, state, scope) and produces the - outputs and the new_state. Additionally when called with inputs=None and - state=None it should return (initial_outputs, initial_state). - - Raises: - TypeError: if cell_fn is not callable - ValueError: if cell_fn cannot produce a valid initial state. - """ - if not callable(cell_fn): - raise TypeError("cell_fn %s needs to be callable", cell_fn) - self._cell_fn = cell_fn - self._cell_name = cell_fn.func.__name__ - init_output, init_state = self._cell_fn(None, None) - output_shape = init_output.get_shape() - state_shape = init_state.get_shape() - self._output_size = output_shape.with_rank(2)[1].value - self._state_size = state_shape.with_rank(2)[1].value - if self._output_size is None: - raise ValueError("Initial output created by %s has invalid shape %s" % - (self._cell_name, output_shape)) - if self._state_size is None: - raise ValueError("Initial state created by %s has invalid shape %s" % - (self._cell_name, state_shape)) - - @property - def state_size(self): - return self._state_size - - @property - def output_size(self): - return self._output_size - - def __call__(self, inputs, state, scope=None): - scope = scope or self._cell_name - output, state = self._cell_fn(inputs, state, scope=scope) - return output, state - - -def _linear(args, output_size, bias, bias_start=0.0, scope=None): - """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. - - Args: - args: a 2D Tensor or a list of 2D, batch x n, Tensors. - output_size: int, second dimension of W[i]. - bias: boolean, whether to add a bias term or not. - bias_start: starting value to initialize the bias; 0 by default. - scope: (optional) Variable scope to create parameters in. - - Returns: - A 2D Tensor with shape [batch x output_size] equal to - sum_i(args[i] * W[i]), where W[i]s are newly created matrices. - - Raises: - ValueError: if some of the arguments has unspecified or wrong shape. - """ - if args is None or (nest.is_sequence(args) and not args): - raise ValueError("`args` must be specified") - if not nest.is_sequence(args): - args = [args] - - # Calculate the total size of arguments on dimension 1. - total_arg_size = 0 - shapes = [a.get_shape() for a in args] - for shape in shapes: - if shape.ndims != 2: - raise ValueError("linear is expecting 2D arguments: %s" % shapes) - if shape[1].value is None: - raise ValueError("linear expects shape[1] to be provided for shape %s, " - "but saw %d" % (shape, shape[1])) - else: - total_arg_size += shape[1].value - - dtype = [a.dtype for a in args][0] - - # Now the computation. - scope = vs.get_variable_scope() - with vs.variable_scope(scope) as outer_scope: - weights = vs.get_variable( - "weights", [total_arg_size, output_size], dtype=dtype) - if len(args) == 1: - res = math_ops.matmul(args[0], weights) - else: - res = math_ops.matmul(array_ops.concat(1, args), weights) - if not bias: - return res - with vs.variable_scope(outer_scope) as inner_scope: - inner_scope.set_partitioner(None) - biases = vs.get_variable( - "biases", [output_size], - dtype=dtype, - initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) - return nn_ops.bias_add(res, biases) +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py new file mode 100644 index 0000000000..81d510de28 --- /dev/null +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -0,0 +1,872 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Module implementing RNN Cells.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import math + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope as vs + +from tensorflow.python.ops.math_ops import sigmoid +from tensorflow.python.ops.math_ops import tanh + +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + + +def _state_size_with_prefix(state_size, prefix=None): + """Helper function that enables int or TensorShape shape specification. + + This function takes a size specification, which can be an integer or a + TensorShape, and converts it into a list of integers. One may specify any + additional dimensions that precede the final state size specification. + + Args: + state_size: TensorShape or int that specifies the size of a tensor. + prefix: optional additional list of dimensions to prepend. + + Returns: + result_state_size: list of dimensions the resulting tensor size. + """ + result_state_size = tensor_shape.as_shape(state_size).as_list() + if prefix is not None: + if not isinstance(prefix, list): + raise TypeError("prefix of _state_size_with_prefix should be a list.") + result_state_size = prefix + result_state_size + return result_state_size + + +class RNNCell(object): + """Abstract object representing an RNN cell. + + The definition of cell in this package differs from the definition used in the + literature. In the literature, cell refers to an object with a single scalar + output. The definition in this package refers to a horizontal array of such + units. + + An RNN cell, in the most abstract setting, is anything that has + a state and performs some operation that takes a matrix of inputs. + This operation results in an output matrix with `self.output_size` columns. + If `self.state_size` is an integer, this operation also results in a new + state matrix with `self.state_size` columns. If `self.state_size` is a + tuple of integers, then it results in a tuple of `len(state_size)` state + matrices, each with a column size corresponding to values in `state_size`. + + This module provides a number of basic commonly used RNN cells, such as + LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number + of operators that allow add dropouts, projections, or embeddings for inputs. + Constructing multi-layer cells is supported by the class `MultiRNNCell`, + or by calling the `rnn` ops several times. Every `RNNCell` must have the + properties below and and implement `__call__` with the following signature. + """ + + def __call__(self, inputs, state, scope=None): + """Run this RNN cell on inputs, starting from the given state. + + Args: + inputs: `2-D` tensor with shape `[batch_size x input_size]`. + state: if `self.state_size` is an integer, this should be a `2-D Tensor` + with shape `[batch_size x self.state_size]`. Otherwise, if + `self.state_size` is a tuple of integers, this should be a tuple + with shapes `[batch_size x s] for s in self.state_size`. + scope: VariableScope for the created subgraph; defaults to class name. + + Returns: + A pair containing: + + - Output: A `2-D` tensor with shape `[batch_size x self.output_size]`. + - New state: Either a single `2-D` tensor, or a tuple of tensors matching + the arity and shapes of `state`. + """ + raise NotImplementedError("Abstract method") + + @property + def state_size(self): + """size(s) of state(s) used by this cell. + + It can be represented by an Integer, a TensorShape or a tuple of Integers + or TensorShapes. + """ + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer or TensorShape: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + def zero_state(self, batch_size, dtype): + """Return zero-filled state tensor(s). + + Args: + batch_size: int, float, or unit Tensor representing the batch size. + dtype: the data type to use for the state. + + Returns: + If `state_size` is an int or TensorShape, then the return value is a + `N-D` tensor of shape `[batch_size x state_size]` filled with zeros. + + If `state_size` is a nested list or tuple, then the return value is + a nested list or tuple (of the same structure) of `2-D` tensors with + the shapes `[batch_size x s]` for each s in `state_size`. + """ + state_size = self.state_size + if nest.is_sequence(state_size): + state_size_flat = nest.flatten(state_size) + zeros_flat = [ + array_ops.zeros( + array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])), + dtype=dtype) + for s in state_size_flat] + for s, z in zip(state_size_flat, zeros_flat): + z.set_shape(_state_size_with_prefix(s, prefix=[None])) + zeros = nest.pack_sequence_as(structure=state_size, + flat_sequence=zeros_flat) + else: + zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size]) + zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype) + zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None])) + + return zeros + + +class BasicRNNCell(RNNCell): + """The most basic RNN cell.""" + + def __init__(self, num_units, input_size=None, activation=tanh): + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + self._num_units = num_units + self._activation = activation + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None): + """Most basic RNN: output = new_state = act(W * input + U * state + B).""" + with vs.variable_scope(scope or "basic_rnn_cell"): + output = self._activation( + _linear([inputs, state], self._num_units, True, scope=scope)) + return output, output + + +class GRUCell(RNNCell): + """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" + + def __init__(self, num_units, input_size=None, activation=tanh): + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + self._num_units = num_units + self._activation = activation + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None): + """Gated recurrent unit (GRU) with nunits cells.""" + with vs.variable_scope(scope or "gru_cell"): + with vs.variable_scope("gates"): # Reset gate and update gate. + # We start with bias of 1.0 to not reset and not update. + r, u = array_ops.split( + 1, 2, _linear([inputs, state], 2 * self._num_units, True, 1.0, + scope=scope)) + r, u = sigmoid(r), sigmoid(u) + with vs.variable_scope("candidate"): + c = self._activation(_linear([inputs, r * state], + self._num_units, True, + scope=scope)) + new_h = u * state + (1 - u) * c + return new_h, new_h + + +_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) + + +class LSTMStateTuple(_LSTMStateTuple): + """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. + + Stores two elements: `(c, h)`, in that order. + + Only used when `state_is_tuple=True`. + """ + __slots__ = () + + @property + def dtype(self): + (c, h) = self + if not c.dtype == h.dtype: + raise TypeError("Inconsistent internal state: %s vs %s" % + (str(c.dtype), str(h.dtype))) + return c.dtype + + +class BasicLSTMCell(RNNCell): + """Basic LSTM recurrent network cell. + + The implementation is based on: http://arxiv.org/abs/1409.2329. + + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + + For advanced models, please use the full LSTMCell that follows. + """ + + def __init__(self, num_units, forget_bias=1.0, input_size=None, + state_is_tuple=True, activation=tanh): + """Initialize the basic LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + forget_bias: float, The bias added to forget gates (see above). + input_size: Deprecated and unused. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + activation: Activation function of the inner states. + """ + if not state_is_tuple: + logging.warn("%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + self._num_units = num_units + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + self._activation = activation + + @property + def state_size(self): + return (LSTMStateTuple(self._num_units, self._num_units) + if self._state_is_tuple else 2 * self._num_units) + + @property + def output_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None): + """Long short-term memory cell (LSTM).""" + with vs.variable_scope(scope or "basic_lstm_cell"): + # Parameters of gates are concatenated into one multiply for efficiency. + if self._state_is_tuple: + c, h = state + else: + c, h = array_ops.split(1, 2, state) + concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split(1, 4, concat) + + new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * + self._activation(j)) + new_h = self._activation(new_c) * sigmoid(o) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = array_ops.concat(1, [new_c, new_h]) + return new_h, new_state + + +class LSTMCell(RNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + The default non-peephole implementation is based on: + + http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf + + S. Hochreiter and J. Schmidhuber. + "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + + The peephole implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + + Hasim Sak, Andrew Senior, and Francoise Beaufays. + "Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + + The class uses optional peep-hole connections, optional cell clipping, and + an optional projection layer. + """ + + def __init__(self, num_units, input_size=None, + use_peepholes=False, cell_clip=None, + initializer=None, num_proj=None, proj_clip=None, + num_unit_shards=None, num_proj_shards=None, + forget_bias=1.0, state_is_tuple=True, + activation=tanh): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell + input_size: Deprecated and unused. + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + num_unit_shards: Deprecated, will be removed by Jan. 2017. + Use a variable_scope partitioner instead. + num_proj_shards: Deprecated, will be removed by Jan. 2017. + Use a variable_scope partitioner instead. + forget_bias: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. This latter behavior will soon be deprecated. + activation: Activation function of the inner states. + """ + if not state_is_tuple: + logging.warn("%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + if num_unit_shards is not None or num_proj_shards is not None: + logging.warn( + "%s: The num_unit_shards and proj_unit_shards parameters are " + "deprecated and will be removed in Jan 2017. " + "Use a variable scope with a partitioner instead.", self) + + self._num_units = num_units + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializer + self._num_proj = num_proj + self._proj_clip = proj_clip + self._num_unit_shards = num_unit_shards + self._num_proj_shards = num_proj_shards + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + self._activation = activation + + if num_proj: + self._state_size = ( + LSTMStateTuple(num_units, num_proj) + if state_is_tuple else num_units + num_proj) + self._output_size = num_proj + else: + self._state_size = ( + LSTMStateTuple(num_units, num_units) + if state_is_tuple else 2 * num_units) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + def __call__(self, inputs, state, scope=None): + """Run one step of LSTM. + + Args: + inputs: input Tensor, 2D, batch x num_units. + state: if `state_is_tuple` is False, this must be a state Tensor, + `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a + tuple of state Tensors, both `2-D`, with column sizes `c_state` and + `m_state`. + scope: VariableScope for the created subgraph; defaults to "lstm_cell". + + Returns: + A tuple containing: + + - A `2-D, [batch x output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + num_proj = self._num_units if self._num_proj is None else self._num_proj + + if self._state_is_tuple: + (c_prev, m_prev) = state + else: + c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + + dtype = inputs.dtype + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + with vs.variable_scope(scope or "lstm_cell", + initializer=self._initializer) as unit_scope: + if self._num_unit_shards is not None: + unit_scope.set_partitioner( + partitioned_variables.fixed_size_partitioner( + self._num_unit_shards)) + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True, + scope=scope) + i, j, f, o = array_ops.split(1, 4, lstm_matrix) + + # Diagonal connections + if self._use_peepholes: + with vs.variable_scope(unit_scope) as projection_scope: + if self._num_unit_shards is not None: + projection_scope.set_partitioner(None) + w_f_diag = vs.get_variable( + "w_f_diag", shape=[self._num_units], dtype=dtype) + w_i_diag = vs.get_variable( + "w_i_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "w_o_diag", shape=[self._num_units], dtype=dtype) + + if self._use_peepholes: + c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + + sigmoid(i + w_i_diag * c_prev) * self._activation(j)) + else: + c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * + self._activation(j)) + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + + if self._use_peepholes: + m = sigmoid(o + w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + with vs.variable_scope("projection") as proj_scope: + if self._num_proj_shards is not None: + proj_scope.set_partitioner( + partitioned_variables.fixed_size_partitioner( + self._num_proj_shards)) + m = _linear(m, self._num_proj, bias=False, scope=scope) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + new_state = (LSTMStateTuple(c, m) if self._state_is_tuple + else array_ops.concat(1, [c, m])) + return m, new_state + + +class OutputProjectionWrapper(RNNCell): + """Operator adding an output projection to the given cell. + + Note: in many cases it may be more efficient to not use this wrapper, + but instead concatenate the whole sequence of your outputs in time, + do the projection on this batch-concatenated sequence, then split it + if needed or directly feed into a softmax. + """ + + def __init__(self, cell, output_size): + """Create a cell with output projection. + + Args: + cell: an RNNCell, a projection to output_size is added to it. + output_size: integer, the size of the output after projection. + + Raises: + TypeError: if cell is not an RNNCell. + ValueError: if output_size is not positive. + """ + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not RNNCell.") + if output_size < 1: + raise ValueError("Parameter output_size must be > 0: %d." % output_size) + self._cell = cell + self._output_size = output_size + + @property + def state_size(self): + return self._cell.state_size + + @property + def output_size(self): + return self._output_size + + def __call__(self, inputs, state, scope=None): + """Run the cell and output projection on inputs, starting from state.""" + output, res_state = self._cell(inputs, state) + # Default scope: "OutputProjectionWrapper" + with vs.variable_scope(scope or "output_projection_wrapper"): + projected = _linear(output, self._output_size, True, scope=scope) + return projected, res_state + + +class InputProjectionWrapper(RNNCell): + """Operator adding an input projection to the given cell. + + Note: in many cases it may be more efficient to not use this wrapper, + but instead concatenate the whole sequence of your inputs in time, + do the projection on this batch-concatenated sequence, then split it. + """ + + def __init__(self, cell, num_proj, input_size=None): + """Create a cell with input projection. + + Args: + cell: an RNNCell, a projection of inputs is added before it. + num_proj: Python integer. The dimension to project to. + input_size: Deprecated and unused. + + Raises: + TypeError: if cell is not an RNNCell. + """ + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not RNNCell.") + self._cell = cell + self._num_proj = num_proj + + @property + def state_size(self): + return self._cell.state_size + + @property + def output_size(self): + return self._cell.output_size + + def __call__(self, inputs, state, scope=None): + """Run the input projection and then the cell.""" + # Default scope: "InputProjectionWrapper" + with vs.variable_scope(scope or "input_projection_wrapper"): + projected = _linear(inputs, self._num_proj, True, scope=scope) + return self._cell(projected, state) + + +class DropoutWrapper(RNNCell): + """Operator adding dropout to inputs and outputs of the given cell.""" + + def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, + seed=None): + """Create a cell with added input and/or output dropout. + + Dropout is never used on the state. + + Args: + cell: an RNNCell, a projection to output_size is added to it. + input_keep_prob: unit Tensor or float between 0 and 1, input keep + probability; if it is float and 1, no input dropout will be added. + output_keep_prob: unit Tensor or float between 0 and 1, output keep + probability; if it is float and 1, no output dropout will be added. + seed: (optional) integer, the randomness seed. + + Raises: + TypeError: if cell is not an RNNCell. + ValueError: if keep_prob is not between 0 and 1. + """ + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not a RNNCell.") + if (isinstance(input_keep_prob, float) and + not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)): + raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d" + % input_keep_prob) + if (isinstance(output_keep_prob, float) and + not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)): + raise ValueError("Parameter output_keep_prob must be between 0 and 1: %d" + % output_keep_prob) + self._cell = cell + self._input_keep_prob = input_keep_prob + self._output_keep_prob = output_keep_prob + self._seed = seed + + @property + def state_size(self): + return self._cell.state_size + + @property + def output_size(self): + return self._cell.output_size + + def __call__(self, inputs, state, scope=None): + """Run the cell with the declared dropouts.""" + if (not isinstance(self._input_keep_prob, float) or + self._input_keep_prob < 1): + inputs = nn_ops.dropout(inputs, self._input_keep_prob, seed=self._seed) + output, new_state = self._cell(inputs, state, scope) + if (not isinstance(self._output_keep_prob, float) or + self._output_keep_prob < 1): + output = nn_ops.dropout(output, self._output_keep_prob, seed=self._seed) + return output, new_state + + +class EmbeddingWrapper(RNNCell): + """Operator adding input embedding to the given cell. + + Note: in many cases it may be more efficient to not use this wrapper, + but instead concatenate the whole sequence of your inputs in time, + do the embedding on this batch-concatenated sequence, then split it and + feed into your RNN. + """ + + def __init__(self, cell, embedding_classes, embedding_size, initializer=None): + """Create a cell with an added input embedding. + + Args: + cell: an RNNCell, an embedding will be put before its inputs. + embedding_classes: integer, how many symbols will be embedded. + embedding_size: integer, the size of the vectors we embed into. + initializer: an initializer to use when creating the embedding; + if None, the initializer from variable scope or a default one is used. + + Raises: + TypeError: if cell is not an RNNCell. + ValueError: if embedding_classes is not positive. + """ + if not isinstance(cell, RNNCell): + raise TypeError("The parameter cell is not RNNCell.") + if embedding_classes <= 0 or embedding_size <= 0: + raise ValueError("Both embedding_classes and embedding_size must be > 0: " + "%d, %d." % (embedding_classes, embedding_size)) + self._cell = cell + self._embedding_classes = embedding_classes + self._embedding_size = embedding_size + self._initializer = initializer + + @property + def state_size(self): + return self._cell.state_size + + @property + def output_size(self): + return self._cell.output_size + + def __call__(self, inputs, state, scope=None): + """Run the cell on embedded inputs.""" + with vs.variable_scope(scope or "embedding_wrapper"): # "EmbeddingWrapper" + with ops.device("/cpu:0"): + if self._initializer: + initializer = self._initializer + elif vs.get_variable_scope().initializer: + initializer = vs.get_variable_scope().initializer + else: + # Default initializer for embeddings should have variance=1. + sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. + initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3) + + if type(state) is tuple: + data_type = state[0].dtype + else: + data_type = state.dtype + + embedding = vs.get_variable( + "embedding", [self._embedding_classes, self._embedding_size], + initializer=initializer, + dtype=data_type) + embedded = embedding_ops.embedding_lookup( + embedding, array_ops.reshape(inputs, [-1])) + return self._cell(embedded, state) + + +class MultiRNNCell(RNNCell): + """RNN cell composed sequentially of multiple simple cells.""" + + def __init__(self, cells, state_is_tuple=True): + """Create a RNN cell composed sequentially of a number of RNNCells. + + Args: + cells: list of RNNCells that will be composed in this order. + state_is_tuple: If True, accepted and returned states are n-tuples, where + `n = len(cells)`. If False, the states are all + concatenated along the column axis. This latter behavior will soon be + deprecated. + + Raises: + ValueError: if cells is empty (not allowed), or at least one of the cells + returns a state tuple but the flag `state_is_tuple` is `False`. + """ + if not cells: + raise ValueError("Must specify at least one cell for MultiRNNCell.") + self._cells = cells + self._state_is_tuple = state_is_tuple + if not state_is_tuple: + if any(nest.is_sequence(c.state_size) for c in self._cells): + raise ValueError("Some cells return tuples of states, but the flag " + "state_is_tuple is not set. State sizes are: %s" + % str([c.state_size for c in self._cells])) + + @property + def state_size(self): + if self._state_is_tuple: + return tuple(cell.state_size for cell in self._cells) + else: + return sum([cell.state_size for cell in self._cells]) + + @property + def output_size(self): + return self._cells[-1].output_size + + def __call__(self, inputs, state, scope=None): + """Run this multi-layer cell on inputs, starting from state.""" + with vs.variable_scope(scope or "multi_rnn_cell"): + cur_state_pos = 0 + cur_inp = inputs + new_states = [] + for i, cell in enumerate(self._cells): + with vs.variable_scope("cell_%d" % i): + if self._state_is_tuple: + if not nest.is_sequence(state): + raise ValueError( + "Expected state to be a tuple of length %d, but received: %s" + % (len(self.state_size), state)) + cur_state = state[i] + else: + cur_state = array_ops.slice( + state, [0, cur_state_pos], [-1, cell.state_size]) + cur_state_pos += cell.state_size + cur_inp, new_state = cell(cur_inp, cur_state) + new_states.append(new_state) + new_states = (tuple(new_states) if self._state_is_tuple + else array_ops.concat(1, new_states)) + return cur_inp, new_states + + +class _SlimRNNCell(RNNCell): + """A simple wrapper for slim.rnn_cells.""" + + def __init__(self, cell_fn): + """Create a SlimRNNCell from a cell_fn. + + Args: + cell_fn: a function which takes (inputs, state, scope) and produces the + outputs and the new_state. Additionally when called with inputs=None and + state=None it should return (initial_outputs, initial_state). + + Raises: + TypeError: if cell_fn is not callable + ValueError: if cell_fn cannot produce a valid initial state. + """ + if not callable(cell_fn): + raise TypeError("cell_fn %s needs to be callable", cell_fn) + self._cell_fn = cell_fn + self._cell_name = cell_fn.func.__name__ + init_output, init_state = self._cell_fn(None, None) + output_shape = init_output.get_shape() + state_shape = init_state.get_shape() + self._output_size = output_shape.with_rank(2)[1].value + self._state_size = state_shape.with_rank(2)[1].value + if self._output_size is None: + raise ValueError("Initial output created by %s has invalid shape %s" % + (self._cell_name, output_shape)) + if self._state_size is None: + raise ValueError("Initial state created by %s has invalid shape %s" % + (self._cell_name, state_shape)) + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + def __call__(self, inputs, state, scope=None): + scope = scope or self._cell_name + output, state = self._cell_fn(inputs, state, scope=scope) + return output, state + + +def _linear(args, output_size, bias, bias_start=0.0, scope=None): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. + + Args: + args: a 2D Tensor or a list of 2D, batch x n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_start: starting value to initialize the bias; 0 by default. + scope: (optional) Variable scope to create parameters in. + + Returns: + A 2D Tensor with shape [batch x output_size] equal to + sum_i(args[i] * W[i]), where W[i]s are newly created matrices. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + if args is None or (nest.is_sequence(args) and not args): + raise ValueError("`args` must be specified") + if not nest.is_sequence(args): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape() for a in args] + for shape in shapes: + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %d" % (shape, shape[1])) + else: + total_arg_size += shape[1].value + + dtype = [a.dtype for a in args][0] + + # Now the computation. + scope = vs.get_variable_scope() + with vs.variable_scope(scope) as outer_scope: + weights = vs.get_variable( + "weights", [total_arg_size, output_size], dtype=dtype) + if len(args) == 1: + res = math_ops.matmul(args[0], weights) + else: + res = math_ops.matmul(array_ops.concat(1, args), weights) + if not bias: + return res + with vs.variable_scope(outer_scope) as inner_scope: + inner_scope.set_partitioner(None) + biases = vs.get_variable( + "biases", [output_size], + dtype=dtype, + initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) + return nn_ops.bias_add(res, biases) diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py index 9ec12583de..5bda634aee 100644 --- a/tensorflow/python/ops/seq2seq.py +++ b/tensorflow/python/ops/seq2seq.py @@ -71,11 +71,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest # TODO(ebrevdo): Remove once _linear is fully deprecated. -linear = rnn_cell._linear # pylint: disable=protected-access +linear = rnn_cell_impl._linear # pylint: disable=protected-access def _extract_argmax_and_embed(embedding, output_projection=None, diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index c46c24af9a..57e7742355 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -46,8 +46,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six - from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -70,7 +68,8 @@ def string_split(source, delimiter=" "): # pylint: disable=invalid-name If `delimiter` is an empty string, each element of the `source` is split into individual strings, each containing one byte. (This includes splitting - multibyte sequences of UTF-8.) + multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is + treated as a set of delimiters with each considered a potential split point. For example: N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output @@ -89,17 +88,14 @@ def string_split(source, delimiter=" "): # pylint: disable=invalid-name delimiter: `0-D` string `Tensor`, the delimiter character, the string should be length 0 or 1. + Raises: + ValueError: If delimiter is not a string. + Returns: A `SparseTensor` of rank `2`, the strings split according to the delimiter. The first column of the indices corresponds to the row in `source` and the second column corresponds to the index of the split component in this row. - - Raises: - ValueError: If delimiter is not a single-byte character. """ - if isinstance(delimiter, six.string_types) and len(delimiter) > 1: - raise ValueError("delimiter must be a single byte-character, got %s" % - delimiter) delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string) source = ops.convert_to_tensor(source, dtype=dtypes.string) diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index fca39e0ad5..09955e690c 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -30,7 +30,7 @@ __all__ = ["make_template"] def make_template(name_, func_, create_scope_now_=False, unique_name_=None, - **kwargs): + custom_getter_=None, **kwargs): """Given an arbitrary function, wrap it so that it does variable sharing. This wraps `func_` in a Template and partially evaluates it. Templates are @@ -118,6 +118,9 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None, unique_name_: When used, it overrides name_ and is not made unique. If a template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. + custom_getter_: Optional custom getter for variables used in `func_`. See + the [`get_variable`](#get_variable) `custom_getter` documentation for + more information. **kwargs: Keyword arguments to apply to `func_`. Returns: @@ -136,7 +139,7 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None, func_ = functools.partial(func_, **kwargs) return Template( name_, func_, create_scope_now=create_scope_now_, - unique_name=unique_name_) + unique_name=unique_name_, custom_getter=custom_getter_) def _skip_common_stack_elements(stacktrace, base_case): @@ -159,7 +162,8 @@ class Template(object): call. """ - def __init__(self, name, func, create_scope_now=False, unique_name=None): + def __init__(self, name, func, create_scope_now=False, unique_name=None, + custom_getter=None): """Creates a template for the given function. Args: @@ -179,6 +183,7 @@ class Template(object): unique_name: When used, it overrides name_ and is not made unique. If a template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. + custom_getter: optional custom getter to pass to variable_scope() Raises: ValueError: if the name is None. @@ -187,11 +192,13 @@ class Template(object): self._stacktrace = traceback.format_stack()[:-2] self._name = name self._unique_name = unique_name + self._custom_getter = custom_getter if name is None: raise ValueError("name cannot be None.") if create_scope_now: with variable_scope.variable_scope( - self._unique_name, self._name) as vs: + self._unique_name, self._name, + custom_getter=self._custom_getter) as vs: self._var_scope = vs else: self._var_scope = None @@ -262,7 +269,8 @@ class Template(object): # Subsequent calls should reuse variables. self._variables_created = True with variable_scope.variable_scope( - self._unique_name, self._name) as vs: + self._unique_name, self._name, + custom_getter=self._custom_getter) as vs: self._var_scope = vs return self._call_func(args, kwargs, check_for_new_variables=False) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 40c90dfba8..9f03ae6264 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -87,7 +87,7 @@ class Variable(object): ```python # Add an Op to initialize global variables. - init_op = tf.global_variable_initializers() + init_op = tf.global_variables_initializer() # Launch the graph in a session. with tf.Session() as sess: @@ -518,6 +518,10 @@ class Variable(object): You should use this instead of the variable itself to initialize another variable with a value that depends on the value of this variable. + Beware of using initialized_value except during initialization: + initialized_value causes the Variable's initializer op to be run, so running + this op resets the variable to the initial value. + ```python # Initialize 'v' with a random tensor. v = tf.Variable(tf.truncated_normal([10, 40])) diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py index bd58db7b45..a47d183e60 100644 --- a/tensorflow/python/platform/app.py +++ b/tensorflow/python/platform/app.py @@ -18,9 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys +import sys as _sys from tensorflow.python.platform import flags +from tensorflow.python.util.all_util import remove_undocumented def run(main=None, argv=None): @@ -36,8 +37,17 @@ def run(main=None, argv=None): flags_passthrough = f._parse_flags(args=args) # pylint: enable=protected-access - main = main or sys.modules['__main__'].main + main = main or _sys.modules['__main__'].main # Call the main function, passing through any arguments # to the final program. - sys.exit(main(sys.argv[:1] + flags_passthrough)) + _sys.exit(main(_sys.argv[:1] + flags_passthrough)) + + +_allowed_symbols = [ + 'run', + # Allowed submodule. + 'flags', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index db5768acb8..1663a1f251 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -2,7 +2,10 @@ # TensorFlow SavedModel. package( - default_visibility = ["//tensorflow/python/saved_model:__subpackages__"], + default_visibility = [ + "//tensorflow/contrib/learn:__subpackages__", + "//tensorflow/python/saved_model:__subpackages__", + ], ) licenses(["notice"]) # Apache 2.0 @@ -33,7 +36,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", - "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", @@ -48,7 +50,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", - "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", "//tensorflow/python:training", @@ -94,7 +95,9 @@ py_library( name = "utils", srcs = ["utils.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/core:protos_all_py"], + deps = [ + "//tensorflow/core:protos_all_py", + ], ) py_test( @@ -111,6 +114,31 @@ py_test( ], ) +py_library( + name = "signature_def_utils", + srcs = ["signature_def_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":signature_constants", + ":utils", + "//tensorflow/core:protos_all_py", + ], +) + +py_test( + name = "signature_def_utils_test", + size = "small", + srcs = [ + "signature_def_utils_test.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":signature_def_utils", + "//tensorflow:tensorflow_py", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/python/saved_model/example/BUILD b/tensorflow/python/saved_model/example/BUILD index 8198312109..5f4785676e 100644 --- a/tensorflow/python/saved_model/example/BUILD +++ b/tensorflow/python/saved_model/example/BUILD @@ -41,11 +41,13 @@ py_binary( ], ) -# TODO(b/32248363): change saved_model_half_plus_two.py to accept output -# location so that we can avoid writing to /tmp/ and copying the files from -# /tmp/. +# Genrule for SavedModel half-plus-two test data. Specifically, this genrule +# exports the test SavedModel to a versioned directory in order to be compatible +# with TensorFlow Serving model server requirements of a versioned subdirectory. +# Please note that SavedModel itself accepts any valid directory as the save +# location and does not perform any versioning. genrule( - name = "versioned_saved_model_half_plus_two_data", + name = "saved_model_half_plus_two_data", outs = [ "saved_model_half_plus_two/00000123/saved_model.pb", "saved_model_half_plus_two/00000123/assets/foo.txt", @@ -57,10 +59,8 @@ genrule( "saved_model_half_plus_two_pbtxt/00000123/variables/variables.index", ], cmd = - "rm -rf /tmp/saved_model; " + - "./$(locations :saved_model_half_plus_two); " + - "cp -r /tmp/saved_model/half_plus_two/* $(@D)/saved_model_half_plus_two/00000123; " + - "cp -r /tmp/saved_model/half_plus_two_pbtxt/* $(@D)/saved_model_half_plus_two_pbtxt/00000123", + "rm -rf $(@D)/saved_model_half_plus_two $(@D)/saved_model_half_plus_two_pbtxt; " + + "./$(locations :saved_model_half_plus_two) --output_dir=$(@D)/saved_model_half_plus_two/00000123 --output_dir_pbtxt=$(@D)/saved_model_half_plus_two_pbtxt/00000123", tools = [ ":saved_model_half_plus_two", ], diff --git a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py index 65eb0f2fd7..d0b7b80674 100644 --- a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py +++ b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py @@ -36,10 +36,18 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants -from tensorflow.python.saved_model import utils from tensorflow.python.util import compat +tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two", + "Directory where to ouput SavedModel.") +tf.app.flags.DEFINE_string("output_dir_pbtxt", + "/tmp/saved_model_half_plus_two_pbtxt", + "Directory where to ouput the text format of " + "SavedModel.") +FLAGS = tf.flags.FLAGS + def _write_assets(assets_directory, assets_filename): """Writes asset files to be used with SavedModel for half plus two. @@ -113,16 +121,31 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): output_tensor = meta_graph_pb2.TensorInfo() output_tensor.name = tf.identity(y).name signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor} - signature_def = utils.build_signature_def( + signature_def = signature_def_utils.build_signature_def( signature_inputs, signature_outputs, signature_constants.REGRESS_METHOD_NAME) + # Set up the signature for Predict with input and output tensor + # specification. + predict_input_tensor = meta_graph_pb2.TensorInfo() + predict_input_tensor.name = x.name + predict_signature_inputs = { + "x": predict_input_tensor + } + predict_signature_def = signature_def_utils.build_signature_def( + {"x": predict_input_tensor}, + {"y": output_tensor}, + signature_constants.PREDICT_METHOD_NAME) + # Initialize all variables and then save the SavedModel. sess.run(tf.global_variables_initializer()) builder.add_meta_graph_and_variables( sess, [tag_constants.SERVING], signature_def_map={ - signature_constants.REGRESS_METHOD_NAME: signature_def + signature_constants.REGRESS_METHOD_NAME: + signature_def, + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + predict_signature_def }, assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), legacy_init_op=tf.group(assign_filename_op)) @@ -130,13 +153,11 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): def main(_): - export_dir_pb = "/tmp/saved_model/half_plus_two" - _generate_saved_model_for_half_plus_two(export_dir_pb) - print("SavedModel generated at: %s" % export_dir_pb) + _generate_saved_model_for_half_plus_two(FLAGS.output_dir) + print("SavedModel generated at: %s" % FLAGS.output_dir) - export_dir_pbtxt = "/tmp/saved_model/half_plus_two_pbtxt" - _generate_saved_model_for_half_plus_two(export_dir_pbtxt, as_text=True) - print("SavedModel generated at: %s" % export_dir_pbtxt) + _generate_saved_model_for_half_plus_two(FLAGS.output_dir_pbtxt, as_text=True) + print("SavedModel generated at: %s" % FLAGS.output_dir_pbtxt) if __name__ == "__main__": diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 0f8ddfc65b..bf5b186b80 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -27,8 +27,8 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants -from tensorflow.python.saved_model import utils from tensorflow.python.util import compat @@ -315,7 +315,8 @@ class SavedModelTest(tf.test.TestCase): with self.test_session(graph=tf.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build and populate an empty SignatureDef for testing. - foo_signature = utils.build_signature_def(dict(), dict(), "foo") + foo_signature = signature_def_utils.build_signature_def( + dict(), dict(), "foo") builder.add_meta_graph_and_variables( sess, ["foo"], signature_def_map={"foo_key": foo_signature}) @@ -324,10 +325,12 @@ class SavedModelTest(tf.test.TestCase): with self.test_session(graph=tf.Graph()) as sess: self._init_and_validate_variable(sess, "v", 43) # Build and populate a different SignatureDef for testing. - bar_signature = utils.build_signature_def(dict(), dict(), "bar") + bar_signature = signature_def_utils.build_signature_def( + dict(), dict(), "bar") # Also, build a different SignatureDef corresponding to "foo_key" defined # in the previous graph. - foo_new_signature = utils.build_signature_def(dict(), dict(), "foo_new") + foo_new_signature = signature_def_utils.build_signature_def( + dict(), dict(), "foo_new") builder.add_meta_graph( ["bar"], signature_def_map={"bar_key": bar_signature, diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py new file mode 100644 index 0000000000..23e844adb2 --- /dev/null +++ b/tensorflow/python/saved_model/signature_def_utils.py @@ -0,0 +1,158 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SignatureDef utility functions. + +Utility functions for constructing SignatureDef protos. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import utils + + +def build_signature_def(inputs=None, outputs=None, method_name=None): + """Utility function to build a SignatureDef protocol buffer. + + Args: + inputs: Inputs of the SignatureDef defined as a proto map of string to + tensor info. + outputs: Outputs of the SignatureDef defined as a proto map of string to + tensor info. + method_name: Method name of the SignatureDef as a string. + + Returns: + A SignatureDef protocol buffer constructed based on the supplied arguments. + """ + signature_def = meta_graph_pb2.SignatureDef() + if inputs is not None: + for item in inputs: + signature_def.inputs[item].CopyFrom(inputs[item]) + if outputs is not None: + for item in outputs: + signature_def.outputs[item].CopyFrom(outputs[item]) + if method_name is not None: + signature_def.method_name = method_name + return signature_def + + +def regression_signature_def(examples, predictions): + """Creates regression signature from given examples and predictions. + + Args: + examples: `Tensor`. + predictions: `Tensor`. + + Returns: + A regression-flavored signature_def. + + Raises: + ValueError: If examples is `None`. + """ + if examples is None: + raise ValueError('examples cannot be None for regression.') + if predictions is None: + raise ValueError('predictions cannot be None for regression.') + + input_tensor_info = utils.build_tensor_info(examples) + signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info} + + output_tensor_info = utils.build_tensor_info(predictions) + signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info} + signature_def = build_signature_def( + signature_inputs, signature_outputs, + signature_constants.REGRESS_METHOD_NAME) + + return signature_def + + +def classification_signature_def(examples, classes, scores): + """Creates classification signature from given examples and predictions. + + Args: + examples: `Tensor`. + classes: `Tensor`. + scores: `Tensor`. + + Returns: + A classification-flavored signature_def. + + Raises: + ValueError: If examples is `None`. + """ + if examples is None: + raise ValueError('examples cannot be None for classification.') + if classes is None and scores is None: + raise ValueError('classes and scores cannot both be None for ' + 'classification.') + + input_tensor_info = utils.build_tensor_info(examples) + signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info} + + signature_outputs = {} + if classes is not None: + classes_tensor_info = utils.build_tensor_info(classes) + signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = ( + classes_tensor_info) + if scores is not None: + scores_tensor_info = utils.build_tensor_info(scores) + signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = ( + scores_tensor_info) + + signature_def = build_signature_def( + signature_inputs, signature_outputs, + signature_constants.CLASSIFY_METHOD_NAME) + + return signature_def + + +def predict_signature_def(inputs, outputs): + """Creates prediction signature from given inputs and outputs. + + Args: + inputs: dict of string to `Tensor`. + outputs: dict of string to `Tensor`. + + Returns: + A prediction-flavored signature_def. + + Raises: + ValueError: If inputs or outputs is `None`. + """ + if inputs is None or not inputs: + raise ValueError('inputs cannot be None or empty for prediction.') + if outputs is None: + raise ValueError('outputs cannot be None or empty for prediction.') + + # If there's only one input or output, we can standardize keys + if len(inputs) == 1: + (_, value), = inputs.items() + inputs = {signature_constants.PREDICT_INPUTS: value} + if len(outputs) == 1: + (_, value), = outputs.items() + outputs = {signature_constants.PREDICT_OUTPUTS: value} + + signature_inputs = {key: utils.build_tensor_info(tensor) + for key, tensor in inputs.items()} + signature_outputs = {key: utils.build_tensor_info(tensor) + for key, tensor in outputs.items()} + + signature_def = build_signature_def( + signature_inputs, signature_outputs, + signature_constants.PREDICT_METHOD_NAME) + + return signature_def diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py new file mode 100644 index 0000000000..6dfc4b2cd6 --- /dev/null +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -0,0 +1,156 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SignatureDef utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.core.framework import types_pb2 +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import utils + + +class SignatureDefUtilsTest(tf.test.TestCase): + + def testBuildSignatureDef(self): + x = tf.placeholder(tf.float32, 1, name="x") + x_tensor_info = utils.build_tensor_info(x) + inputs = dict() + inputs["foo-input"] = x_tensor_info + + y = tf.placeholder(tf.float32, name="y") + y_tensor_info = utils.build_tensor_info(y) + outputs = dict() + outputs["foo-output"] = y_tensor_info + + signature_def = signature_def_utils.build_signature_def( + inputs, outputs, "foo-method-name") + self.assertEqual("foo-method-name", signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + x_tensor_info_actual = signature_def.inputs["foo-input"] + self.assertEqual("x:0", x_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype) + self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim)) + self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size) + + # Check outputs in signature def. + self.assertEqual(1, len(signature_def.outputs)) + y_tensor_info_actual = signature_def.outputs["foo-output"] + self.assertEqual("y:0", y_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype) + self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) + + def testRegressionSignatureDef(self): + input1 = tf.constant("a", name="input-1") + output1 = tf.constant("b", name="output-1") + signature_def = signature_def_utils.regression_signature_def( + input1, output1) + + self.assertEqual(signature_constants.REGRESS_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + x_tensor_info_actual = ( + signature_def.inputs[signature_constants.REGRESS_INPUTS]) + self.assertEqual("input-1:0", x_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype) + self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(1, len(signature_def.outputs)) + y_tensor_info_actual = ( + signature_def.outputs[signature_constants.REGRESS_OUTPUTS]) + self.assertEqual("output-1:0", y_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, y_tensor_info_actual.dtype) + self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) + + def testClassificationSignatureDef(self): + input1 = tf.constant("a", name="input-1") + output1 = tf.constant("b", name="output-1") + output2 = tf.constant("c", name="output-2") + signature_def = signature_def_utils.classification_signature_def( + input1, output1, output2) + + self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + x_tensor_info_actual = ( + signature_def.inputs[signature_constants.CLASSIFY_INPUTS]) + self.assertEqual("input-1:0", x_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype) + self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(2, len(signature_def.outputs)) + classes_tensor_info_actual = ( + signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES]) + self.assertEqual("output-1:0", classes_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, classes_tensor_info_actual.dtype) + self.assertEqual(0, len(classes_tensor_info_actual.tensor_shape.dim)) + scores_tensor_info_actual = ( + signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES]) + self.assertEqual("output-2:0", scores_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, scores_tensor_info_actual.dtype) + self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim)) + + def testPredictionSignatureDef(self): + input1 = tf.constant("a", name="input-1") + input2 = tf.constant("b", name="input-2") + output1 = tf.constant("c", name="output-1") + output2 = tf.constant("d", name="output-2") + signature_def = signature_def_utils.predict_signature_def( + {"input-1": input1, "input-2": input2}, + {"output-1": output1, "output-2": output2}) + + self.assertEqual(signature_constants.PREDICT_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(2, len(signature_def.inputs)) + input1_tensor_info_actual = ( + signature_def.inputs["input-1"]) + self.assertEqual("input-1:0", input1_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype) + self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim)) + input2_tensor_info_actual = ( + signature_def.inputs["input-2"]) + self.assertEqual("input-2:0", input2_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype) + self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(2, len(signature_def.outputs)) + output1_tensor_info_actual = ( + signature_def.outputs["output-1"]) + self.assertEqual("output-1:0", output1_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, output1_tensor_info_actual.dtype) + self.assertEqual(0, len(output1_tensor_info_actual.tensor_shape.dim)) + output2_tensor_info_actual = ( + signature_def.outputs["output-2"]) + self.assertEqual("output-2:0", output2_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype) + self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim)) + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py index 550eed0fcc..ecc58fbc7a 100644 --- a/tensorflow/python/saved_model/utils.py +++ b/tensorflow/python/saved_model/utils.py @@ -23,6 +23,7 @@ from __future__ import print_function from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import dtypes + # TensorInfo helpers. @@ -40,30 +41,3 @@ def build_tensor_info(tensor): name=tensor.name, dtype=dtype_enum, tensor_shape=tensor.get_shape().as_proto()) - -# SignatureDef helpers. - - -def build_signature_def(inputs=None, outputs=None, method_name=None): - """Utility function to build a SignatureDef protocol buffer. - - Args: - inputs: Inputs of the SignatureDef defined as a proto map of string to - tensor info. - outputs: Outputs of the SignatureDef defined as a proto map of string to - tensor info. - method_name: Method name of the SignatureDef as a string. - - Returns: - A SignatureDef protocol buffer constructed based on the supplied arguments. - """ - signature_def = meta_graph_pb2.SignatureDef() - if inputs is not None: - for item in inputs: - signature_def.inputs[item].CopyFrom(inputs[item]) - if outputs is not None: - for item in outputs: - signature_def.outputs[item].CopyFrom(outputs[item]) - if method_name is not None: - signature_def.method_name = method_name - return signature_def diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py index 8ce7d1dea1..74f2624773 100644 --- a/tensorflow/python/saved_model/utils_test.py +++ b/tensorflow/python/saved_model/utils_test.py @@ -34,36 +34,6 @@ class UtilsTest(tf.test.TestCase): self.assertEqual(1, len(x_tensor_info.tensor_shape.dim)) self.assertEqual(1, x_tensor_info.tensor_shape.dim[0].size) - def testBuildSignatureDef(self): - x = tf.placeholder(tf.float32, 1, name="x") - x_tensor_info = utils.build_tensor_info(x) - inputs = dict() - inputs["foo-input"] = x_tensor_info - - y = tf.placeholder(tf.float32, name="y") - y_tensor_info = utils.build_tensor_info(y) - outputs = dict() - outputs["foo-output"] = y_tensor_info - - signature_def = utils.build_signature_def(inputs, outputs, - "foo-method-name") - self.assertEqual("foo-method-name", signature_def.method_name) - - # Check inputs in signature def. - self.assertEqual(1, len(signature_def.inputs)) - x_tensor_info_actual = signature_def.inputs["foo-input"] - self.assertEqual("x:0", x_tensor_info_actual.name) - self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype) - self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim)) - self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size) - - # Check outputs in signature def. - self.assertEqual(1, len(signature_def.outputs)) - y_tensor_info_actual = signature_def.outputs["foo-output"] - self.assertEqual("y:0", y_tensor_info_actual.name) - self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype) - self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) - if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py index aa85ea56ab..6d659e27e3 100644 --- a/tensorflow/python/summary/event_accumulator_test.py +++ b/tensorflow/python/summary/event_accumulator_test.py @@ -645,7 +645,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): ipt = tf.placeholder(tf.float32) tf.summary.scalar('scalar1', ipt) tf.summary.scalar('scalar2', ipt * ipt) - merged = tf.merge_all_summaries() + merged = tf.contrib.deprecated.merge_all_summaries() writer.add_graph(sess.graph) for i in xrange(10): summ = sess.run(merged, feed_dict={ipt: i}) @@ -692,7 +692,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): tf.summary.image('images', ipt, max_outputs=2) with tf.name_scope('3'): tf.summary.image('images', ipt, max_outputs=3) - merged = tf.merge_all_summaries() + merged = tf.contrib.deprecated.merge_all_summaries() writer.add_graph(sess.graph) for i in xrange(10): summ = sess.run(merged) @@ -736,7 +736,7 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): gfile.DeleteRecursively(directory) gfile.MkDir(directory) - writer = tf.train.SummaryWriter(directory, max_queue=100) + writer = tf.summary.FileWriter(directory, max_queue=100) with tf.Graph().as_default() as graph: _ = tf.constant([2.0, 1.0]) @@ -814,7 +814,7 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): gfile.DeleteRecursively(directory) gfile.MkDir(directory) - writer = tf.train.SummaryWriter(directory, max_queue=100) + writer = tf.summary.FileWriter(directory, max_queue=100) with tf.Graph().as_default() as graph: _ = tf.constant([2.0, 1.0]) diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py index dedebe5484..ccc61d4564 100644 --- a/tensorflow/python/summary/impl/event_file_loader.py +++ b/tensorflow/python/summary/impl/event_file_loader.py @@ -52,7 +52,15 @@ class EventFileLoader(object): Yields: All values that were written to disk that have not been yielded yet. """ - while self._reader.GetNext(): + while True: + try: + with errors.raise_exception_on_not_ok_status() as status: + self._reader.GetNext(status) + except (errors.DataLossError, errors.OutOfRangeError): + # We ignore partial read exceptions, because a record may be truncated. + # PyRecordReader holds the offset prior to the failed read, so retrying + # will succeed. + break event = event_pb2.Event() event.ParseFromString(self._reader.record()) yield event diff --git a/tensorflow/python/summary/impl/event_file_loader_test.py b/tensorflow/python/summary/impl/event_file_loader_test.py index f4d7cf218e..0b354d553d 100644 --- a/tensorflow/python/summary/impl/event_file_loader_test.py +++ b/tensorflow/python/summary/impl/event_file_loader_test.py @@ -78,6 +78,15 @@ class EventFileLoaderTest(test_util.TensorFlowTestCase): loader = self._LoaderForTestFile(filename) self.assertEqual(len(list(loader.Load())), 2) + def testMultipleWritesWithBadWrite(self): + filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + self._WriteToFile(filename, EventFileLoaderTest.RECORD) + # Test that we ignore partial record writes at the end of the file. + self._WriteToFile(filename, b'123') + loader = self._LoaderForTestFile(filename) + self.assertEqual(len(list(loader.Load())), 2) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 4e29bbd88d..2e653106f4 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -17,6 +17,7 @@ ### Class for writing Summaries @@FileWriter +@@FileWriterCache ### Summary Ops @@tensor_summary @@ -56,9 +57,10 @@ from tensorflow.python.ops import gen_logging_ops as _gen_logging_ops from tensorflow.python.ops.summary_ops import tensor_summary # pylint: enable=unused-import from tensorflow.python.platform import tf_logging as _logging -# exports FileWriter +# exports FileWriter, FileWriterCache # pylint: disable=unused-import from tensorflow.python.summary.writer.writer import FileWriter +from tensorflow.python.summary.writer.writer_cache import FileWriterCache # pylint: enable=unused-import from tensorflow.python.util import compat as _compat from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/python/summary/summary_iterator.py b/tensorflow/python/summary/summary_iterator.py index 9c3e8fcf4e..490ce141f1 100644 --- a/tensorflow/python/summary/summary_iterator.py +++ b/tensorflow/python/summary/summary_iterator.py @@ -79,7 +79,7 @@ class SummaryWriter(object): # Launch the graph in a session. sess = tf.Session() # Create a summary writer, add the 'graph' to the event file. - writer = tf.train.SummaryWriter(<some-directory>, sess.graph) + writer = tf.summary.FileWriter(<some-directory>, sess.graph) ``` The other arguments to the constructor control the asynchronous writes to @@ -342,7 +342,7 @@ def summary_iterator(path): # This example supposes that the events file contains summaries with a # summary value tag 'loss'. These could have been added by calling # `add_summary()`, passing the output of a scalar summary op created with - # with: `tf.scalar_summary(['loss'], loss_tensor)`. + # with: `tf.summary.scalar('loss', loss_tensor)`. for e in tf.train.summary_iterator(path to events file): for v in e.summary.value: if v.tag == 'loss': diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index fa1715bbcb..fc90d547dc 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -66,7 +66,7 @@ class SummaryToEventTransformer(object): # Launch the graph in a session. sess = tf.Session() # Create a summary writer, add the 'graph' to the event file. - writer = tf.train.SummaryWriter(<some-directory>, sess.graph) + writer = tf.summary.FileWriter(<some-directory>, sess.graph) ``` @@ -286,7 +286,7 @@ class FileWriter(SummaryToEventTransformer): # Launch the graph in a session. sess = tf.Session() # Create a summary writer, add the 'graph' to the event file. - writer = tf.train.SummaryWriter(<some-directory>, sess.graph) + writer = tf.summary.FileWriter(<some-directory>, sess.graph) ``` The other arguments to the constructor control the asynchronous writes to diff --git a/tensorflow/python/summary/writer/writer_cache.py b/tensorflow/python/summary/writer/writer_cache.py index 7655fc5ba4..21870e788e 100644 --- a/tensorflow/python/summary/writer/writer_cache.py +++ b/tensorflow/python/summary/writer/writer_cache.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Reads Summaries from and writes Summaries to event files.""" +"""A cache for FileWriters.""" from __future__ import absolute_import from __future__ import division @@ -21,38 +21,38 @@ from __future__ import print_function import threading from tensorflow.python.framework import ops -from tensorflow.python.summary.writer.writer import FileWriter as SummaryWriter +from tensorflow.python.summary.writer.writer import FileWriter -class SummaryWriterCache(object): - """Cache for summary writers. +class FileWriterCache(object): + """Cache for file writers. - This class caches summary writers, one per directory. + This class caches file writers, one per directory. """ # Cache, keyed by directory. _cache = {} - # Lock protecting _SUMMARY_WRITERS. + # Lock protecting _FILE_WRITERS. _lock = threading.RLock() @staticmethod def clear(): """Clear cached summary writers. Currently only used for unit tests.""" - with SummaryWriterCache._lock: - SummaryWriterCache._cache = {} + with FileWriterCache._lock: + FileWriterCache._cache = {} @staticmethod def get(logdir): - """Returns the SummaryWriter for the specified directory. + """Returns the FileWriter for the specified directory. Args: logdir: str, name of the directory. Returns: - A `SummaryWriter`. + A `FileWriter`. """ - with SummaryWriterCache._lock: - if logdir not in SummaryWriterCache._cache: - SummaryWriterCache._cache[logdir] = SummaryWriter( + with FileWriterCache._lock: + if logdir not in FileWriterCache._cache: + FileWriterCache._cache[logdir] = FileWriter( logdir, graph=ops.get_default_graph()) - return SummaryWriterCache._cache[logdir] + return FileWriterCache._cache[logdir] diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py index aeaebc2092..466f691dd6 100644 --- a/tensorflow/python/summary/writer/writer_test.py +++ b/tensorflow/python/summary/writer/writer_test.py @@ -83,7 +83,7 @@ class SummaryWriterTestCase(tf.test.TestCase): def testAddingSummaryGraphAndRunMetadata(self): test_dir = self._CleanTestDir("basics") - sw = tf.train.SummaryWriter(test_dir) + sw = tf.summary.FileWriter(test_dir) sw.add_session_log(tf.SessionLog(status=SessionLog.START), 1) sw.add_summary( @@ -154,7 +154,7 @@ class SummaryWriterTestCase(tf.test.TestCase): test_dir = self._CleanTestDir("basics_named_graph") with tf.Graph().as_default() as g: tf.constant([12], name="douze") - sw = tf.train.SummaryWriter(test_dir, graph=g) + sw = tf.summary.FileWriter(test_dir, graph=g) sw.close() self._assertEventsWithGraph(test_dir, g, True) @@ -162,7 +162,7 @@ class SummaryWriterTestCase(tf.test.TestCase): test_dir = self._CleanTestDir("basics_positional_graph") with tf.Graph().as_default() as g: tf.constant([12], name="douze") - sw = tf.train.SummaryWriter(test_dir, g) + sw = tf.summary.FileWriter(test_dir, g) sw.close() self._assertEventsWithGraph(test_dir, g, True) @@ -171,7 +171,7 @@ class SummaryWriterTestCase(tf.test.TestCase): with tf.Graph().as_default() as g: tf.constant([12], name="douze") gd = g.as_graph_def() - sw = tf.train.SummaryWriter(test_dir, graph_def=gd) + sw = tf.summary.FileWriter(test_dir, graph_def=gd) sw.close() self._assertEventsWithGraph(test_dir, g, False) @@ -180,7 +180,7 @@ class SummaryWriterTestCase(tf.test.TestCase): with tf.Graph().as_default() as g: tf.constant([12], name="douze") gd = g.as_graph_def() - sw = tf.train.SummaryWriter(test_dir, gd) + sw = tf.summary.FileWriter(test_dir, gd) sw.close() self._assertEventsWithGraph(test_dir, g, False) @@ -190,18 +190,18 @@ class SummaryWriterTestCase(tf.test.TestCase): with tf.Graph().as_default() as g: tf.constant([12], name="douze") gd = g.as_graph_def() - sw = tf.train.SummaryWriter(test_dir, graph=g, graph_def=gd) + sw = tf.summary.FileWriter(test_dir, graph=g, graph_def=gd) sw.close() def testNeitherGraphNorGraphDef(self): with self.assertRaises(TypeError): test_dir = self._CleanTestDir("basics_string_instead_of_graph") - sw = tf.train.SummaryWriter(test_dir, "string instead of graph object") + sw = tf.summary.FileWriter(test_dir, "string instead of graph object") sw.close() def testCloseAndReopen(self): test_dir = self._CleanTestDir("close_and_reopen") - sw = tf.train.SummaryWriter(test_dir) + sw = tf.summary.FileWriter(test_dir) sw.add_session_log(tf.SessionLog(status=SessionLog.START), 1) sw.close() # Sleep at least one second to make sure we get a new event file name. @@ -247,7 +247,7 @@ class SummaryWriterTestCase(tf.test.TestCase): # protocol buffers correctly. def testAddingSummariesFromSessionRunCalls(self): test_dir = self._CleanTestDir("global_step") - sw = tf.train.SummaryWriter(test_dir) + sw = tf.summary.FileWriter(test_dir) with self.test_session(): i = tf.constant(1, dtype=tf.int32, shape=[]) l = tf.constant(2, dtype=tf.int64, shape=[]) @@ -314,9 +314,9 @@ class SummaryWriterCacheTest(tf.test.TestCase): with tf.Graph().as_default(): dir1 = self._test_dir("test_cache_1") dir2 = self._test_dir("test_cache_2") - sw1 = tf.train.SummaryWriterCache.get(dir1) - sw2 = tf.train.SummaryWriterCache.get(dir2) - sw3 = tf.train.SummaryWriterCache.get(dir1) + sw1 = tf.summary.FileWriterCache.get(dir1) + sw2 = tf.summary.FileWriterCache.get(dir2) + sw3 = tf.summary.FileWriterCache.get(dir1) self.assertEqual(sw1, sw3) self.assertFalse(sw1 == sw2) sw1.close() @@ -331,9 +331,9 @@ class SummaryWriterCacheTest(tf.test.TestCase): def test_clear(self): with tf.Graph().as_default(): dir1 = self._test_dir("test_clear") - sw1 = tf.train.SummaryWriterCache.get(dir1) - tf.train.SummaryWriterCache.clear() - sw2 = tf.train.SummaryWriterCache.get(dir1) + sw1 = tf.summary.FileWriterCache.get(dir1) + tf.summary.FileWriterCache.clear() + sw2 = tf.summary.FileWriterCache.get(dir1) self.assertFalse(sw1 == sw2) diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 97c942320a..0f7deb7827 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -29,7 +29,6 @@ limitations under the License. %include "tensorflow/python/client/tf_session.i" %include "tensorflow/python/client/device_lib.i" -%include "tensorflow/python/client/net_lib.i" %include "tensorflow/python/client/quantize_training.i" %include "tensorflow/python/lib/io/file_io.i" diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index ca8c537d55..542396003c 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -314,7 +314,6 @@ class StepCounterHook(session_run_hook.SessionRunHook): every_n_secs=None, output_dir=None, summary_writer=None): - self._summary_tag = "global_step/sec" if (every_n_steps is None) == (every_n_secs is None): raise ValueError( @@ -328,6 +327,7 @@ class StepCounterHook(session_run_hook.SessionRunHook): def begin(self): self._global_step_tensor = training_util.get_global_step() + self._summary_tag = self._global_step_tensor.op.name + "/sec" if self._global_step_tensor is None: raise RuntimeError( "Global step should be created to use StepCounterHook.") diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index f7eab4a3b5..1b8ebd11f3 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -419,6 +419,34 @@ class StepCounterHookTest(tf.test.TestCase): self.assertEqual('global_step/sec', summary_value.tag) self.assertGreater(summary_value.simple_value, 0) + def test_global_step_name(self): + with tf.Graph().as_default() as g, tf.Session() as sess: + with tf.variable_scope('bar'): + foo_step = tf.get_variable('foo', initializer=0, trainable=False, + collections=[tf.GraphKeys.GLOBAL_STEP, + tf.GraphKeys.GLOBAL_VARIABLES]) + train_op = tf.assign_add(foo_step, 1) + summary_writer = testing.FakeSummaryWriter(self.log_dir, g) + hook = tf.train.StepCounterHook( + summary_writer=summary_writer, every_n_steps=1, every_n_secs=None) + + hook.begin() + sess.run(tf.global_variables_initializer()) + mon_sess = monitored_session._HookedSession(sess, [hook]) + mon_sess.run(train_op) + mon_sess.run(train_op) + hook.end(sess) + + summary_writer.assert_summaries( + test_case=self, + expected_logdir=self.log_dir, + expected_graph=g, + expected_summaries={}) + self.assertTrue(summary_writer.summaries, 'No summaries were created.') + self.assertItemsEqual([2], summary_writer.summaries.keys()) + summary_value = summary_writer.summaries[2][0].value[0] + self.assertEqual('bar/foo/sec', summary_value.tag) + class SummarySaverHookTest(tf.test.TestCase): @@ -581,7 +609,7 @@ class GlobalStepWaiterHookTest(tf.test.TestCase): hook = tf.train.GlobalStepWaiterHook(wait_until_step=1000) hook.begin() with tf.Session() as sess: - sess.run(tf.initialize_all_variables()) + sess.run(tf.global_variables_initializer()) waiter = threading.Thread( target=hook.before_run, args=(tf.train.SessionRunContext( diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index bac38ee689..45438b1342 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -288,7 +288,8 @@ class ExponentialMovingAverage(object): @@variables_to_restore """ - def __init__(self, decay, num_updates=None, name="ExponentialMovingAverage"): + def __init__(self, decay, num_updates=None, zero_debias=False, + name="ExponentialMovingAverage"): """Creates a new ExponentialMovingAverage object. The `apply()` method has to be called to create shadow variables and add @@ -305,11 +306,14 @@ class ExponentialMovingAverage(object): Args: decay: Float. The decay to use. num_updates: Optional count of number of updates applied to variables. + zero_debias: If `True`, zero debias moving-averages that are initialized + with tensors. name: String. Optional prefix name to use for the name of ops added in `apply()`. """ self._decay = decay self._num_updates = num_updates + self._zero_debias = zero_debias self._name = name self._averages = {} @@ -373,7 +377,8 @@ class ExponentialMovingAverage(object): var, self._name, colocate_with_primary=(var.op.type == "Variable")) - zero_debias_true.add(avg) + if self._zero_debias: + zero_debias_true.add(avg) self._averages[var] = avg with ops.name_scope(self._name) as scope: diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index a892912cc8..dae89fbefe 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -89,6 +89,11 @@ def _Repeat(value, dim): class ExponentialMovingAverageTest(tf.test.TestCase): def _CheckDecay(self, ema, actual_decay, dim): + def _Scale(dk, steps): + if ema._zero_debias: + return 1 - dk ** (steps + 1) + else: + return 1 tens = _Repeat(10.0, dim) thirties = _Repeat(30.0, dim) var0 = tf.Variable(tens, name="v0") @@ -133,7 +138,7 @@ class ExponentialMovingAverageTest(tf.test.TestCase): self.assertAllClose(expected, avg0.eval()) expected = _Repeat(30.0 * dk + 30.0 * (1 - dk), dim) self.assertAllClose(expected, avg1.eval()) - expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk) / (1 - dk ** 2), dim) + expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk) / _Scale(dk, 1), dim) self.assertAllClose(expected, avg2.eval()) # Again, update the averages and check. @@ -145,7 +150,7 @@ class ExponentialMovingAverageTest(tf.test.TestCase): dim) self.assertAllClose(expected, avg1.eval()) expected = _Repeat(((0.0 * dk + (10.0 + 30.0) * (1 - dk)) * dk + - (10.0 + 30.0) * (1 - dk)) / (1 - dk ** 3), + (10.0 + 30.0) * (1 - dk)) / _Scale(dk, 2), dim) self.assertAllClose(expected, avg2.eval()) @@ -154,23 +159,47 @@ class ExponentialMovingAverageTest(tf.test.TestCase): ema = tf.train.ExponentialMovingAverage(0.25) self._CheckDecay(ema, actual_decay=0.25, dim=1) + def testAverageVariablesNoNumUpdates_Scalar_Debias(self): + with self.test_session(): + ema = tf.train.ExponentialMovingAverage(0.25, zero_debias=True) + self._CheckDecay(ema, actual_decay=0.25, dim=1) + def testAverageVariablesNoNumUpdates_Vector(self): with self.test_session(): ema = tf.train.ExponentialMovingAverage(0.25) self._CheckDecay(ema, actual_decay=0.25, dim=5) + def testAverageVariablesNoNumUpdates_Vector_Debias(self): + with self.test_session(): + ema = tf.train.ExponentialMovingAverage(0.25, zero_debias=True) + self._CheckDecay(ema, actual_decay=0.25, dim=5) + def testAverageVariablesNumUpdates_Scalar(self): with self.test_session(): # With num_updates 1, the decay applied is 0.1818 ema = tf.train.ExponentialMovingAverage(0.25, num_updates=1) self._CheckDecay(ema, actual_decay=0.181818, dim=1) + def testAverageVariablesNumUpdates_Scalar_Debias(self): + with self.test_session(): + # With num_updates 1, the decay applied is 0.1818 + ema = tf.train.ExponentialMovingAverage( + 0.25, num_updates=1, zero_debias=True) + self._CheckDecay(ema, actual_decay=0.181818, dim=1) + def testAverageVariablesNumUpdates_Vector(self): with self.test_session(): # With num_updates 1, the decay applied is 0.1818 ema = tf.train.ExponentialMovingAverage(0.25, num_updates=1) self._CheckDecay(ema, actual_decay=0.181818, dim=5) + def testAverageVariablesNumUpdates_Vector_Debias(self): + with self.test_session(): + # With num_updates 1, the decay applied is 0.1818 + ema = tf.train.ExponentialMovingAverage( + 0.25, num_updates=1, zero_debias=True) + self._CheckDecay(ema, actual_decay=0.181818, dim=5) + def testAverageVariablesWithControlDeps(self): with self.test_session() as sess: v0 = tf.Variable(0, name="v0") @@ -195,14 +224,15 @@ class ExponentialMovingAverageTest(tf.test.TestCase): self.assertEqual(1, sess.run(v0)) self.assertEqual([17.5], sess.run(v1_avg)) - def testAverageVariablesNames(self): + def averageVariablesNamesHelper(self, zero_debias): with self.test_session(): v0 = tf.Variable(10.0, name="v0") v1 = tf.Variable(30.0, name="v1") # Add a non-trainable variable. v2 = tf.Variable(20.0, name="v2", trainable=False) tensor2 = v0 + v1 - ema = tf.train.ExponentialMovingAverage(0.25, name="foo") + ema = tf.train.ExponentialMovingAverage( + 0.25, zero_debias=zero_debias, name="foo") self.assertEqual("v0/foo", ema.average_name(v0)) self.assertEqual("v1/foo", ema.average_name(v1)) self.assertEqual("add/foo", ema.average_name(tensor2)) @@ -212,21 +242,30 @@ class ExponentialMovingAverageTest(tf.test.TestCase): # {v0/foo : v0, # v1/foo : v1, # add/foo : add/foo, - # add/foo/biased: add/foo/biased, - # add/foo/local_step: add/foo/local_step, # v2 : v2} + expected_names = [ema.average_name(v0), + ema.average_name(v1), + ema.average_name(tensor2), + v2.op.name] + if zero_debias: + # vars_to_restore should also contain the following: + # {add/foo/biased: add/foo/biased, + # add/foo/local_step: add/foo/local_step} + expected_names += [ema.average_name(tensor2) + "/biased", + ema.average_name(tensor2) + "/local_step"] self.assertEqual(sorted(vars_to_restore.keys()), - sorted([ema.average_name(v0), - ema.average_name(v1), - ema.average_name(tensor2), - ema.average_name(tensor2) + "/biased", - ema.average_name(tensor2) + "/local_step", - v2.op.name])) + sorted(expected_names)) self.assertEqual(ema.average_name(v0), ema.average(v0).op.name) self.assertEqual(ema.average_name(v1), ema.average(v1).op.name) self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name) - def testAverageVariablesNamesRespectScope(self): + def testAverageVariablesNames(self): + self.averageVariablesNamesHelper(zero_debias=True) + + def testAverageVariablesNamesNoDebias(self): + self.averageVariablesNamesHelper(zero_debias=False) + + def averageVariablesNamesRespectScopeHelper(self, zero_debias): # See discussion on #2740. with self.test_session(): with tf.variable_scope("scope1"): @@ -236,7 +275,8 @@ class ExponentialMovingAverageTest(tf.test.TestCase): v2 = tf.Variable(20.0, name="v2", trainable=False) tensor2 = v0 + v1 with tf.variable_scope("scope2"): - ema = tf.train.ExponentialMovingAverage(0.25, name="foo") + ema = tf.train.ExponentialMovingAverage( + 0.25, zero_debias=zero_debias, name="foo") self.assertEqual("scope2/scope1/v0/foo", ema.average_name(v0)) self.assertEqual("scope2/scope1/v1/foo", ema.average_name(v1)) self.assertEqual("scope2/scope1/add/foo", ema.average_name(tensor2)) @@ -246,22 +286,32 @@ class ExponentialMovingAverageTest(tf.test.TestCase): # {scope2/scope1/v0/foo : v0, # scope2/scope1/v1/foo : v1, # scope2/scope1/add/foo : add/foo, - # scope2/scope2/scope1/add/foo/biased: add/foo/biased, - # scope2/scope2/scope1/add/foo/local_step: add/foo/local_step, # scope1/v2 : v2} - sc = "scope2/" + expected_names = [ema.average_name(v0), + ema.average_name(v1), + ema.average_name(tensor2), + v2.op.name] + if zero_debias: + # vars_to_restore should also contain the following: + # {scope2/scope2/scope1/add/foo/biased: add/foo/biased, + # scope2/scope2/scope1/add/foo/local_step: add/foo/local_step} + sc = "scope2/" + expected_names += [sc + ema.average_name(tensor2) + "/biased", + sc + ema.average_name(tensor2) + "/local_step"] + self.assertEqual(sorted(vars_to_restore.keys()), - sorted([ema.average_name(v0), - ema.average_name(v1), - ema.average_name(tensor2), - sc + ema.average_name(tensor2) + "/biased", - sc + ema.average_name(tensor2) + "/local_step", - v2.op.name])) + sorted(expected_names)) self.assertEqual(ema.average_name(v0), ema.average(v0).op.name) self.assertEqual(ema.average_name(v1), ema.average(v1).op.name) self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name) + def testAverageVariablesNamesRespectScope(self): + self.averageVariablesNamesRespectScopeHelper(zero_debias=True) + + def testAverageVariablesNamesRespectScopeNoDebias(self): + self.averageVariablesNamesRespectScopeHelper(zero_debias=False) + def testSubsetAverageVariablesNames(self): with self.test_session(): v0 = tf.Variable(10.0, name="v0") diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 68e4bfd0f8..cb4e1de235 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -36,7 +36,6 @@ from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors -from tensorflow.python.framework import graph_util from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io @@ -62,6 +61,23 @@ _VARIABLE_OPS = set(["Variable", "ResourceGather"]) +def _set_cpu0(device_string): + """Creates a new device string based on `device_string` but using /CPU:0. + + If the device is already on /CPU:0, this is a no-op. + + Args: + device_string: A device string. + + Returns: + A device string. + """ + parsed_device = pydev.DeviceSpec.from_string(device_string) + parsed_device.device_type = "CPU" + parsed_device.device_index = 0 + return parsed_device.to_string() + + class BaseSaverBuilder(object): """Base class for Savers. @@ -380,8 +396,7 @@ class BaseSaverBuilder(object): # available on the GPU. # TODO(touts): Re-enable restore on GPU when we can support annotating # string tensors as "HostMemory" inputs. - with ops.device( - graph_util.set_cpu0(saveable.device) if saveable.device else None): + with ops.device(_set_cpu0(saveable.device) if saveable.device else None): with ops.control_dependencies(restore_control_inputs): tensors = self.restore_op(filename_tensor, saveable, preferred_shard) shapes = None diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py index cbbc527d85..4d1c3e7954 100644 --- a/tensorflow/python/training/summary_io.py +++ b/tensorflow/python/training/summary_io.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Reads Summaries from and writes Summaries to event files.""" from __future__ import absolute_import @@ -22,7 +21,7 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.python.summary.summary_iterator import summary_iterator from tensorflow.python.summary.writer.writer import FileWriter as _FileWriter -from tensorflow.python.summary.writer.writer_cache import SummaryWriterCache +from tensorflow.python.summary.writer.writer_cache import FileWriterCache as SummaryWriterCache # pylint: enable=unused-import from tensorflow.python.util.deprecation import deprecated diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 474d417947..dda0166aa6 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -418,7 +418,7 @@ class SupervisorTest(tf.test.TestCase): tf.summary.scalar("c2", tf.constant(2)) tf.summary.scalar("c3", tf.constant(3)) summ = tf.summary.merge_all() - sw = tf.train.SummaryWriter(logdir) + sw = tf.summary.FileWriter(logdir) sv = tf.train.Supervisor(logdir="", summary_op=None, summary_writer=sw) meta_graph_def = meta_graph.create_meta_graph_def() sess = sv.prepare_or_wait_for_session("") diff --git a/tensorflow/python/training/tensorboard_logging_test.py b/tensorflow/python/training/tensorboard_logging_test.py index dd0ee372f9..286062cab7 100644 --- a/tensorflow/python/training/tensorboard_logging_test.py +++ b/tensorflow/python/training/tensorboard_logging_test.py @@ -35,7 +35,7 @@ class EventLoggingTest(tf.test.TestCase): def setUp(self): self._work_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) - self._sw = tf.train.SummaryWriter(self._work_dir) + self._sw = tf.summary.FileWriter(self._work_dir) tensorboard_logging.set_summary_writer(self._sw) self.addCleanup(shutil.rmtree, self._work_dir) diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 561a58d594..8e0d9bbb06 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import inspect import re @@ -114,7 +115,11 @@ def deprecated(date, instructions): return deprecated_wrapper -def deprecated_args(date, instructions, *deprecated_arg_names): +DeprecatedArgSpec = collections.namedtuple( + 'DeprecatedArgSpec', ['position', 'has_ok_value', 'ok_value']) + + +def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): """Decorator for marking specific function arguments as deprecated. This decorator logs a deprecation warning whenever the decorated function is @@ -135,32 +140,77 @@ def deprecated_args(date, instructions, *deprecated_arg_names): ISO 8601 (YYYY-MM-DD). instructions: String. Instructions on how to update code using the deprecated function. - *deprecated_arg_names: String. The deprecated arguments. + *deprecated_arg_names_or_tuples: String. or 2-Tuple(String, + [ok_vals]). The string is the deprecated argument name. + Optionally, an ok-value may be provided. If the user provided + argument equals this value, the warning is suppressed. Returns: Decorated function or method. Raises: - ValueError: If date is not in ISO 8601 format, instructions are empty, or - the deprecated arguments are not present in the function signature. + ValueError: If date is not in ISO 8601 format, instructions are + empty, the deprecated arguments are not present in the function + signature, or the second element of a deprecated_tuple is not a + list. """ _validate_deprecation_args(date, instructions) - if not deprecated_arg_names: + if not deprecated_arg_names_or_tuples: raise ValueError('Specify which argument is deprecated.') + def _get_arg_names_to_ok_vals(): + """Returns a dict mapping arg_name to DeprecatedArgSpec w/o position.""" + d = {} + for name_or_tuple in deprecated_arg_names_or_tuples: + if isinstance(name_or_tuple, tuple): + d[name_or_tuple[0]] = DeprecatedArgSpec(-1, True, name_or_tuple[1]) + else: + d[name_or_tuple] = DeprecatedArgSpec(-1, False, None) + return d + + def _get_deprecated_positional_arguments(names_to_ok_vals, arg_spec): + """Builds a dictionary from deprecated arguments to thier spec. + + Returned dict is keyed by argument name. + Each value is a DeprecatedArgSpec with the following fields: + position: The zero-based argument position of the argument + within the signature. None if the argument isn't found in + the signature. + ok_values: Values of this argument for which warning will be + suppressed. + + Args: + names_to_ok_vals: dict from string arg_name to a list of values, + possibly empty, which should not elicit a warning. + arg_spec: Output from inspect.getargspec on the called function. + + Returns: + Dictionary from arg_name to DeprecatedArgSpec. + """ + arg_name_to_pos = dict( + (name, pos) for (pos, name) in enumerate(arg_spec.args)) + deprecated_positional_args = {} + for arg_name, spec in iter(names_to_ok_vals.items()): + if arg_name in arg_name_to_pos: + pos = arg_name_to_pos[arg_name] + deprecated_positional_args[arg_name] = DeprecatedArgSpec( + pos, spec.has_ok_value, spec.ok_value) + return deprecated_positional_args + def deprecated_wrapper(func): """Deprecation decorator.""" decorator_utils.validate_callable(func, 'deprecated_args') + deprecated_arg_names = _get_arg_names_to_ok_vals() arg_spec = inspect.getargspec(func) - deprecated_positions = [ - (i, arg_name) for (i, arg_name) in enumerate(arg_spec.args) - if arg_name in deprecated_arg_names] + deprecated_positions = _get_deprecated_positional_arguments( + deprecated_arg_names, arg_spec) + is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names is_kwargs_deprecated = arg_spec.keywords in deprecated_arg_names if (len(deprecated_positions) + is_varargs_deprecated + is_kwargs_deprecated - != len(deprecated_arg_names)): + != len(deprecated_arg_names_or_tuples)): known_args = arg_spec.args + [arg_spec.varargs, arg_spec.keywords] missing_args = [arg_name for arg_name in deprecated_arg_names if arg_name not in known_args] @@ -172,15 +222,21 @@ def deprecated_args(date, instructions, *deprecated_arg_names): def new_func(*args, **kwargs): """Deprecation wrapper.""" invalid_args = [] - for (i, arg_name) in deprecated_positions: - if i < len(args): + named_args = inspect.getcallargs(func, *args, **kwargs) + for arg_name, spec in iter(deprecated_positions.items()): + if (spec.position < len(args) and + not (spec.has_ok_value and + named_args[arg_name] == spec.ok_value)): invalid_args.append(arg_name) if is_varargs_deprecated and len(args) > len(arg_spec.args): invalid_args.append(arg_spec.varargs) if is_kwargs_deprecated and kwargs: invalid_args.append(arg_spec.keywords) for arg_name in deprecated_arg_names: - if arg_name in kwargs: + if (arg_name in kwargs and + not (deprecated_positions[arg_name].has_ok_value and + (named_args[arg_name] == + deprecated_positions[arg_name].ok_value))): invalid_args.append(arg_name) for arg_name in invalid_args: logging.warning( diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index 791593189f..75bd054d7f 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -538,6 +538,39 @@ class DeprecatedArgsTest(tf.test.TestCase): self.assertRegexpMatches(args1[0], r"deprecated and will be removed after") self._assert_subset(set([date, instructions, "d2"]), set(args2[1:])) + @tf.test.mock.patch.object(logging, "warning", autospec=True) + def test_positional_and_named_with_ok_vals(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + @deprecation.deprecated_args( + date, + instructions, + ("d1", None), + ("d2", "my_ok_val")) + def _fn(arg0, d1=None, arg1=2, d2=None): + return arg0 + arg1 if d1 else arg1 + arg0 if d2 else arg0 * arg1 + + # Assert calls without the deprecated arguments log nothing. + self.assertEqual(2, _fn(1, arg1=2)) + self.assertEqual(0, mock_warning.call_count) + + # Assert calls with the deprecated arguments log warnings. + self.assertEqual(2, _fn(1, False, 2, d2=False)) + self.assertEqual(2, mock_warning.call_count) + (args1, _) = mock_warning.call_args_list[0] + self.assertRegexpMatches(args1[0], r"deprecated and will be removed after") + self._assert_subset(set([date, instructions, "d1"]), set(args1[1:])) + (args2, _) = mock_warning.call_args_list[1] + self.assertRegexpMatches(args1[0], r"deprecated and will be removed after") + self._assert_subset(set([date, instructions, "d2"]), set(args2[1:])) + + # Assert calls with the deprecated arguments dont log warnings if + # the value matches the 'ok_val'. + mock_warning.reset_mock() + self.assertEqual(3, _fn(1, None, 2, d2="my_ok_val")) + self.assertEqual(0, mock_warning.call_count) + class DeprecatedArgValuesTest(tf.test.TestCase): diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 322d33ae26..6b31325694 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -349,31 +349,12 @@ bool CUDAExecutor::GetKernelMetadata(CUDAKernel *cuda_kernel, bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &kernel, - const std::vector<KernelArg> &args) { - CHECK_EQ(kernel.Arity(), args.size()); + const KernelArgsArrayBase &args) { + CHECK_EQ(kernel.Arity(), args.number_of_arguments()); CUstream custream = AsCUDAStreamValue(stream); const CUDAKernel *cuda_kernel = AsCUDAKernel(&kernel); CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue(); - std::vector<void *> addrs; - addrs.reserve(args.size()); - int shmem_bytes = 0; - for (size_t i = 0; i < args.size(); i++) { - switch (args[i].type) { - case KernelArg::kNormal: - addrs.push_back(const_cast<void *>( - static_cast<const void *>(args[i].data.begin()))); - break; - case KernelArg::kSharedMemory: - shmem_bytes += args[i].bytes; - break; - default: - LOG(ERROR) << "Invalid kernel arg type passed (" << args[i].type - << ") for arg " << i; - return false; - } - } - // Only perform/print the occupancy check 1x. launched_kernels_mu_.lock(); if (launched_kernels_.find(cufunc) == launched_kernels_.end()) { @@ -389,11 +370,15 @@ bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims, CUDADriver::FuncSetCacheConfig(cufunc, cuda_kernel->GetCUDACacheConfig()); } - if (!CUDADriver::LaunchKernel( - GetCudaContext(stream), cufunc, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - shmem_bytes, custream, addrs.data(), nullptr /* = extra */)) { - LOG(ERROR) << "failed to launch CUDA kernel with args: " << args.size() + void **kernel_params = const_cast<void **>(args.argument_addresses().data()); + + if (!CUDADriver::LaunchKernel(GetCudaContext(stream), cufunc, block_dims.x, + block_dims.y, block_dims.z, thread_dims.x, + thread_dims.y, thread_dims.z, + args.number_of_shared_bytes(), custream, + kernel_params, nullptr /* = extra */)) { + LOG(ERROR) << "failed to launch CUDA kernel with args: " + << args.number_of_arguments() << "; thread dim: " << thread_dims.ToString() << "; block dim: " << block_dims.ToString(); return false; @@ -849,18 +834,6 @@ bool CUDAExecutor::FillBlockDimLimit(BlockDim *block_dim_limit) const { return true; } -KernelArg CUDAExecutor::DeviceMemoryToKernelArg( - const DeviceMemoryBase &gpu_mem) const { - const void* arg = gpu_mem.opaque(); - const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg); - - KernelArg kernel_arg; - kernel_arg.type = KernelArg::kNormal; - kernel_arg.data = port::InlinedVector<uint8, 4>(arg_ptr, arg_ptr + sizeof(arg)); - kernel_arg.bytes = sizeof(arg); - return kernel_arg; -} - bool CUDAExecutor::SupportsBlas() const { return true; } bool CUDAExecutor::SupportsFft() const { return true; } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 9e01f48781..3959d04439 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -76,7 +76,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface { bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &k, - const std::vector<KernelArg> &args) override; + const KernelArgsArrayBase &args) override; void *Allocate(uint64 size) override; @@ -186,9 +186,6 @@ class CUDAExecutor : public internal::StreamExecutorInterface { // will be only partially populated as a result, and an error will be logged. bool FillBlockDimLimit(BlockDim *block_dim_limit) const; - KernelArg DeviceMemoryToKernelArg( - const DeviceMemoryBase &gpu_mem) const override; - bool SupportsBlas() const override; blas::BlasSupport *CreateBlas() override; diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index 7742e066c7..3e5453e4c9 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -76,9 +76,10 @@ limitations under the License. #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/kernel_cache_config.h" +#include "tensorflow/stream_executor/lib/array_slice.h" +#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" namespace perftools { namespace gputools { @@ -265,24 +266,220 @@ struct IsSharedDeviceMemory<SharedDeviceMemory<U>> { static constexpr bool value = true; }; -// KernelArg encapsulates the information necessary for a back-end executor to -// configure a kernel to launch using the given argument. +// Basic data about a kernel argument. struct KernelArg { - // Indicates the type of an argument: normal, to be passed to the kernel - // in the standard manner, or shared memory, which has distinct - // rules for specification per backend. - enum Type { - kNormal, - kSharedMemory, - } type; - - // The data to pass to the kernel - either a pointer to device memory, or the - // argument value. compact_array is used to prevent smaller args (ex. u8, u64) - // from requiring heap allocation. - port::InlinedVector<uint8, 4> data; - - // The size of this argument in bytes. - uint64 bytes; + bool is_shared; + const void *address; + size_t size; +}; + +// An iterator for traversing all the arguments of a KernelArgsArray. +class KernelArgIterator { + public: + KernelArgIterator(int number_of_argument_addresses, + int number_of_shared_memory_arguments, + const void *const *arg_addresses_data, + const size_t *arg_sizes_data, + const size_t *shmem_bytes_data, + const size_t *shmem_indices_data) + : arg_index_(0), + number_of_arguments_(number_of_argument_addresses + + number_of_shared_memory_arguments), + arg_address_iter_(arg_addresses_data), + arg_size_iter_(arg_sizes_data), + shmem_bytes_iter_(shmem_bytes_data), + shmem_indices_iter_(shmem_indices_data), + shmem_indices_end_(shmem_indices_data + + number_of_shared_memory_arguments) {} + + // Returns true if another argument is present in the iterator. + bool has_next() { return arg_index_ < number_of_arguments_; } + + // Returns the next argument in the iterator. + // + // Returns a default-constructed KernelArg if there is no next argument. + KernelArg next() { + KernelArg result; + if (!has_next()) { + return result; + } else if ((shmem_indices_iter_ != shmem_indices_end_) && + (arg_index_ == *shmem_indices_iter_)) { + result.is_shared = true; + result.address = nullptr; + result.size = *shmem_bytes_iter_; + ++shmem_indices_iter_; + ++shmem_bytes_iter_; + } else { + result.is_shared = false; + result.address = *arg_address_iter_; + result.size = *arg_size_iter_; + ++arg_address_iter_; + ++arg_size_iter_; + } + ++arg_index_; + return result; + } + + private: + int arg_index_; + int number_of_arguments_; + const void *const *arg_address_iter_; + const size_t *arg_size_iter_; + const size_t *shmem_bytes_iter_; + const size_t *shmem_indices_iter_; + const size_t *const shmem_indices_end_; +}; + +// Base class for KernelArgsArray. +// +// Supports all the getter methods that do not depend on the compile-time number +// of arguments template parameter. +// +// This class exists as a way to pass kernel arguments to +// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't +// be templated to accept any KernelArgsArray type, therfore a reference to this +// base type is passed instead. +// +// Performance is not a concern here because each of these methods will be +// called at most once per kernel launch. Past performance concerns with +// KernelArgsArray have been in reference to the argument packing routines which +// are called once per kernel argument. Those packing routines are now handled +// by the templated KernelArgsArray subclass of this class where they can take +// advantage of compile-time knowledge of the number of arguments in order to be +// very efficient. +class KernelArgsArrayBase { + public: + virtual ~KernelArgsArrayBase() = default; + + // Gets the number of arguments added so far, including shared memory + // arguments. + virtual size_t number_of_arguments() const = 0; + + // Gets the total number of shared memory bytes added so far. + virtual uint64 number_of_shared_bytes() const = 0; + + // Gets the list of argument addresses. + virtual port::ArraySlice<const void *> argument_addresses() const = 0; + + // Gets an iterator to the arguments in the array. + virtual KernelArgIterator arg_iterator() const = 0; +}; + +// A list of arguments for a kernel call. +// +// The template parameter kNumArgs is the maximum number of arguments which can +// be stored in the list. +// +// Contains a list of addresses for non-shared-memory arguments and a list of +// sizes for shared-memory arguments. Since the shared-memory arguments may be +// interspersed with the non-shared-memory arguments, it also stores a list of +// the indices at which the shared-memory arguments appeared. +// +// For example, if the argument address list contains {a, b, c, d, e}, the +// shared-memory arguments list contains the sizes of {A, B, C}, and the +// shared-memory indices list contains {0, 3, 5}, then the original list of +// arguments was {A, a, b, B, c, C, d, e}. +// +// This way of storing the arguments makes CUDA kernel calls efficient because +// they only require the argument address list and the total number of shared +// bytes, but it also makes it possible for OpenCL kernel calls because they +// depend on the location of each shared-memory argument and its size. +// +// Note that the code for adding arguments has been identified as a performance +// hotspot in some real-world applications so this structure has been optimized +// for the performance of argument adding. +template <size_t kNumArgs> +class KernelArgsArray : public KernelArgsArrayBase { + public: + explicit KernelArgsArray() + : total_shared_memory_bytes_(0), + number_of_argument_addresses_(0), + number_of_shared_memory_arguments_(0) {} + + // Adds an argument to the list. + // + // Note that the address of the argument is stored, so the input must not go + // out of scope before the instance of this class that calls this method does. + template <typename T> + void add_argument(const T &arg) { + argument_addresses_[number_of_argument_addresses_] = + static_cast<const void *>(&arg); + argument_sizes_[number_of_argument_addresses_] = sizeof(arg); + ++number_of_argument_addresses_; + } + + // Adds a device memory argument to the list. + void add_device_memory_argument(const DeviceMemoryBase &arg) { + const void **copy_ptr = + &device_memory_opaque_pointers_[number_of_argument_addresses_]; + *copy_ptr = arg.opaque(); + argument_addresses_[number_of_argument_addresses_] = copy_ptr; + argument_sizes_[number_of_argument_addresses_] = sizeof(void *); + ++number_of_argument_addresses_; + } + + // Adds a shared memory argument to the list. + // + // The only significant information about a shared argument is its size, so + // that is the only parameter in this function. + void add_shared_bytes(size_t number_of_bytes) { + shared_memory_indices_[number_of_shared_memory_arguments_] = + number_of_argument_addresses_ + number_of_shared_memory_arguments_; + shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes; + ++number_of_shared_memory_arguments_; + total_shared_memory_bytes_ += number_of_bytes; + } + + // Gets the number of arguments added so far, including shared memory + // arguments. + size_t number_of_arguments() const override { + return number_of_argument_addresses_ + number_of_shared_memory_arguments_; + } + + // Gets the total number of shared memory bytes added so far. + uint64 number_of_shared_bytes() const override { + return total_shared_memory_bytes_; + } + + // Gets the list of argument addresses. + port::ArraySlice<const void *> argument_addresses() const override { + return port::ArraySlice<const void *>(argument_addresses_.data(), + number_of_argument_addresses_); + } + + // Gets an iterator to the arguments in the array. + KernelArgIterator arg_iterator() const override { + return KernelArgIterator( + number_of_argument_addresses_, number_of_shared_memory_arguments_, + argument_addresses_.data(), argument_sizes_.data(), + shared_memory_bytes_.data(), shared_memory_indices_.data()); + } + + private: + // A place to store copies of opaque pointers from device memory arguments. + std::array<const void *, kNumArgs> device_memory_opaque_pointers_; + + // Addresses for non-shared-memory arguments. + std::array<const void *, kNumArgs> argument_addresses_; + + // Sizes for non-shared-memory arguments. + std::array<size_t, kNumArgs> argument_sizes_; + + // Size in bytes for each shared memory argument. + std::array<size_t, kNumArgs> shared_memory_bytes_; + + // Indices in the arguments array for shared memory arguments. + std::array<size_t, kNumArgs> shared_memory_indices_; + + // Total of all shared memory sizes. + size_t total_shared_memory_bytes_; + + // Number of significant entries in argument_addresses_ and argument_sizes_. + size_t number_of_argument_addresses_; + + // Number of significant entries in shared_memory_bytes_ and + // shared_memory_indices_. + size_t number_of_shared_memory_arguments_; }; // Typed variant of KernelBase, like a typed device function pointer. See the @@ -298,6 +495,8 @@ struct KernelArg { template <typename... Params> class TypedKernel : public KernelBase { public: + static constexpr size_t kNumberOfParameters = sizeof...(Params); + // Delegates to KernelBase::KernelBase(), see that constructor. explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {} @@ -318,13 +517,19 @@ class TypedKernel : public KernelBase { // // Const refs are taken as parameters on all of the handlers to avoid // implicit type promotion of integers. - void PackParams(std::vector<KernelArg> *args, Params... params) const { + // + // WARNING: as a performance optimization this method may store pointers to + // some of the input parameters in the kernel args structure, so any params + // passed into this method must live at least as long as the kernel args + // structure. + void PackParams(KernelArgsArray<kNumberOfParameters> *args, + Params &... params) const { PackOneParam(args, params...); } template <typename T, typename... RestOfParams> - void PackOneParam(std::vector<KernelArg> *args, const T &arg, - const RestOfParams... rest) const { + void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg, + const RestOfParams &... rest) const { PackOneParam(args, arg); PackOneParam(args, rest...); } @@ -334,7 +539,7 @@ class TypedKernel : public KernelBase { // separate implementation below. template <typename T> void PackOneParam( - std::vector<KernelArg> *args, const T &arg, + KernelArgsArray<kNumberOfParameters> *args, const T &arg, typename std::enable_if<!IsDeviceMemoryValueLike<T>::value && !IsDeviceMemoryPointer<T>::value && !IsSharedDeviceMemory<T>::value>::type * = @@ -343,44 +548,40 @@ class TypedKernel : public KernelBase { "cannot pass raw pointer to the device"); static_assert(!std::is_convertible<T, DeviceMemoryBase>::value, "cannot pass device memory as a normal value"); - const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg); - args->emplace_back(KernelArg{ - KernelArg::kNormal, - port::InlinedVector<uint8, 4>{arg_ptr, arg_ptr + sizeof(arg)}, sizeof(arg)}); + args->add_argument(arg); } // DeviceMemoryBase family reference override. template <typename T> void PackOneParam( - std::vector<KernelArg> *args, const T &arg, + KernelArgsArray<kNumberOfParameters> *args, const T &arg, typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * = nullptr) const { - args->emplace_back(parent()->DeviceMemoryToKernelArg(arg)); + args->add_device_memory_argument(arg); } // DeviceMemoryBase family pointer override. template <typename T> void PackOneParam( - std::vector<KernelArg> *args, T arg, + KernelArgsArray<kNumberOfParameters> *args, T arg, typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * = nullptr) const { DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg); - args->emplace_back(parent()->DeviceMemoryToKernelArg(*ptr)); + args->add_device_memory_argument(*ptr); } // Dynamic shared device memory has a size, but no associated allocation on // the host; internally, the device will allocate storage. template <typename T> void PackOneParam( - std::vector<KernelArg> *args, T arg, + KernelArgsArray<kNumberOfParameters> *args, T arg, typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * = nullptr) const { - args->emplace_back(KernelArg{KernelArg::kSharedMemory, - port::InlinedVector<uint8, 4>(), arg.size()}); + args->add_shared_bytes(arg.size()); } // Base case for variadic template expansion - nothing to do! - void PackOneParam(std::vector<KernelArg> *args) const {} + void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {} SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel); }; diff --git a/tensorflow/stream_executor/stream_executor.h b/tensorflow/stream_executor/stream_executor.h index dd4664849d..2995dccf46 100644 --- a/tensorflow/stream_executor/stream_executor.h +++ b/tensorflow/stream_executor/stream_executor.h @@ -18,34 +18,6 @@ limitations under the License. // * Loading/launching data-parallel-kernels // * Invoking pre-canned high-performance library routines (like matrix // multiply) -// -// The appropriately-typed kernel and "loader spec" are automatically generated -// for the user within a namespace by the gcudacc compiler output, so typical -// use looks like so: -// -// namespace gpu = ::perftools::gputools; -// namespace gcudacc = ::platforms::gpus::gcudacc; -// -// gpu::StreamExecutor stream_exec{PlatformKind::kCuda}; -// gcudacc::kernel::MyKernel my_kernel{&stream_exec}; -// bool ok = stream_exec.GetKernel(gcudacc::spec::MyKernelSpec(), -// &my_kernel); -// if (!ok) { ... } -// gpu::DeviceMemory<int> result = stream_exec.AllocateZeroed<int>(); -// if (result == nullptr) { ... } -// int host_result; -// gpu::Stream my_stream{&stream_exec}; -// my_stream -// .Init() -// .ThenLaunch(ThreadDim{1024}, BlockDim{1}, my_kernel, result) -// .ThenMemcpy(&host_result, result, sizeof(host_result)) -// .BlockHostUntilDone() -// if (!my_stream.ok()) { ... } -// printf("%d\n", host_result); -// -// Since the device may operate asynchronously to the host, the -// Stream::BlockHostUntilDone() call forces the calling host thread to wait for -// the chain of commands specified for the Stream to complete execution. #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index acdbb07cb7..57db7775a6 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -184,7 +184,7 @@ class StreamExecutorInterface { } virtual bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &k, - const std::vector<KernelArg> &args) { + const KernelArgsArrayBase &args) { return false; } virtual void *Allocate(uint64 size) = 0; @@ -258,9 +258,6 @@ class StreamExecutorInterface { // caller. virtual DeviceDescription *PopulateDeviceDescription() const = 0; - virtual KernelArg DeviceMemoryToKernelArg( - const DeviceMemoryBase &gpu_mem) const = 0; - // Attempts to register the provided TraceListener with the device-specific // Executor implementation. When this is called, the PIMPL interface has // already taken ownership of the object and is managing the generic tracing diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 2fdd1e4b49..7739d31662 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -394,7 +394,7 @@ rng::RngSupport *StreamExecutor::AsRng() { bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &kernel, - const std::vector<KernelArg> &args) { + const KernelArgsArrayBase &args) { SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims, kernel, args); @@ -659,11 +659,6 @@ bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const { return implementation_->DeviceMemoryUsage(free, total); } -KernelArg StreamExecutor::DeviceMemoryToKernelArg( - const DeviceMemoryBase &gpu_mem) const { - return implementation_->DeviceMemoryToKernelArg(gpu_mem); -} - void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) { background_threads_->Schedule(task); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 2b5a70f807..83fd27599e 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -392,7 +392,7 @@ class StreamExecutor { // implementation in StreamExecutorInterface::Launch(). bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &kernel, - const std::vector<KernelArg> &args); + const KernelArgsArrayBase &args); // Gets-or-creates (creates with memoization) a FftSupport datatype that can // be used to execute FFT routines on the current platform. @@ -427,10 +427,6 @@ class StreamExecutor { // previously registered. bool UnregisterTraceListener(TraceListener* listener); - // Converts a DeviceMemory object into a KernelArg object for passing to the - // device driver for kernel launch. - KernelArg DeviceMemoryToKernelArg(const DeviceMemoryBase &gpu_mem) const; - private: template <typename BeginCallT, typename CompleteCallT, typename ReturnT, typename... BeginArgsT> @@ -758,9 +754,9 @@ inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, // we pack the variadic parameters passed as ...args into the desired // tuple form and pass that packed form to the StreamExecutor::Launch() // implementation. - std::vector<KernelArg> kernel_args; - kernel_args.reserve(kernel.Arity()); + KernelArgsArray<sizeof...(args)> kernel_args; kernel.PackParams(&kernel_args, args...); + DCHECK(parent_ != nullptr); bool ok = parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args); if (!ok) { diff --git a/tensorflow/stream_executor/trace_listener.h b/tensorflow/stream_executor/trace_listener.h index 804c6ee8fa..88c54f982b 100644 --- a/tensorflow/stream_executor/trace_listener.h +++ b/tensorflow/stream_executor/trace_listener.h @@ -50,7 +50,7 @@ class TraceListener { virtual void LaunchSubmit(Stream* stream, const ThreadDim& thread_dims, const BlockDim& block_dims, const KernelBase& kernel, - const std::vector<KernelArg>& args) {} + const KernelArgsArrayBase& args) {} virtual void SynchronousMemcpyH2DBegin(int64 correlation_id, const void* host_src, int64 size, diff --git a/tensorflow/tensorboard/backend/server_test.py b/tensorflow/tensorboard/backend/server_test.py index 3dd7843e66..596fff2864 100644 --- a/tensorflow/tensorboard/backend/server_test.py +++ b/tensorflow/tensorboard/backend/server_test.py @@ -333,7 +333,7 @@ class TensorboardServerTest(tf.test.TestCase): self.addCleanup(shutil.rmtree, temp_dir) run1_path = os.path.join(temp_dir, 'run1') os.makedirs(run1_path) - writer = tf.train.SummaryWriter(run1_path) + writer = tf.summary.FileWriter(run1_path) histogram_value = tf.HistogramProto(min=0, max=2, diff --git a/tensorflow/tensorboard/components/vz_projector/data.ts b/tensorflow/tensorboard/components/vz_projector/data.ts index 34f275f546..4b5cb7a687 100644 --- a/tensorflow/tensorboard/components/vz_projector/data.ts +++ b/tensorflow/tensorboard/components/vz_projector/data.ts @@ -22,8 +22,7 @@ import {getSearchPredicate, runAsyncTask, shuffle} from './util'; import * as vector from './vector'; export type DistanceFunction = (a: number[], b: number[]) => number; -export type PointAccessor = (index: number) => number; -export type PointAccessors3D = [PointAccessor, PointAccessor, PointAccessor]; +export type ProjectionComponents3D = [string, string, string]; export interface PointMetadata { [key: string]: number|string; } @@ -187,25 +186,6 @@ export class DataSet { return traces; } - getPointAccessors(projection: ProjectionType, components: (number|string)[]): - [PointAccessor, PointAccessor, PointAccessor] { - if (components.length > 3) { - throw new RangeError('components length must be <= 3'); - } - const accessors: [PointAccessor, PointAccessor, PointAccessor] = - [null, null, null]; - const prefix = (projection === 'custom') ? 'linear' : projection; - for (let i = 0; i < components.length; ++i) { - if (components[i] == null) { - continue; - } - accessors[i] = - (index => - this.points[index].projections[prefix + '-' + components[i]]); - } - return accessors; - } - projectionCanBeRendered(projection: ProjectionType): boolean { if (projection !== 'tsne') { return true; @@ -222,8 +202,9 @@ export class DataSet { * @return A subset of the original dataset. */ getSubset(subset?: number[]): DataSet { - let pointsSubset = - subset && subset.length ? subset.map(i => this.points[i]) : this.points; + const pointsSubset = ((subset != null) && (subset.length > 0)) ? + subset.map(i => this.points[i]) : + this.points; let points = pointsSubset.map(dp => { return { metadata: dp.metadata, @@ -302,12 +283,13 @@ export class DataSet { } return newV; }); - for (let j = 0; j < NUM_PCA_COMPONENTS; j++) { - let label = 'pca-' + j; + for (let d = 0; d < NUM_PCA_COMPONENTS; d++) { + let label = 'pca-' + d; this.projections.add(label); - this.points.forEach((d, i) => { - d.projections[label] = pcaVectors[i][j]; - }); + for (let i = 0; i < pcaVectors.length; i++) { + let pointIndex = this.shuffledDataIndices[i]; + this.points[pointIndex].projections[label] = pcaVectors[i][d]; + } } }); } @@ -418,8 +400,8 @@ export type ProjectionType = 'tsne' | 'pca' | 'custom'; export class Projection { constructor( public projectionType: ProjectionType, - public pointAccessors: PointAccessors3D, public dimensionality: number, - public dataSet: DataSet) {} + public projectionComponents: ProjectionComponents3D, + public dimensionality: number, public dataSet: DataSet) {} } export interface ColorOption { @@ -489,6 +471,23 @@ export class State { selectedLabelOption: string; } +export function getProjectionComponents( + projection: ProjectionType, + components: (number|string)[]): ProjectionComponents3D { + if (components.length > 3) { + throw new RangeError('components length must be <= 3'); + } + const projectionComponents: [string, string, string] = [null, null, null]; + const prefix = (projection === 'custom') ? 'linear' : projection; + for (let i = 0; i < components.length; ++i) { + if (components[i] == null) { + continue; + } + projectionComponents[i] = prefix + '-' + components[i]; + } + return projectionComponents; +} + export function stateGetAccessorDimensions(state: State): Array<number|string> { let dimensions: Array<number|string>; switch (state.selectedProjection) { diff --git a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts index 7fa9924813..d00973935c 100644 --- a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts +++ b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet, DistanceFunction, PointAccessors3D, Projection, State} from './data'; +import {DataSet, DistanceFunction, Projection, ProjectionComponents3D, State} from './data'; import {NearestEntry} from './knn'; import {ProjectorEventContext} from './projectorEventContext'; import {LabelRenderParams} from './renderContext'; @@ -82,11 +82,14 @@ export class ProjectorScatterPlotAdapter { private selectedPointIndices: number[]; private neighborsOfFirstSelectedPoint: NearestEntry[]; private renderLabelsIn3D: boolean = false; - private labelPointAccessor: (index: number) => string; - private legendPointColorer: (index: number) => string; + private labelPointAccessor: (ds: DataSet, index: number) => string; + private legendPointColorer: (ds: DataSet, index: number) => string; private distanceMetric: DistanceFunction; + private spriteVisualizer: ScatterPlotVisualizerSprites; private labels3DVisualizer: ScatterPlotVisualizer3DLabels; + private canvasLabelsVisualizer: ScatterPlotVisualizerCanvasLabels; + private traceVisualizer: ScatterPlotVisualizerTraces; constructor( scatterPlotContainer: d3.Selection<any>, @@ -102,6 +105,7 @@ export class ProjectorScatterPlotAdapter { (selectedPointIndices, neighbors) => { this.selectedPointIndices = selectedPointIndices; this.neighborsOfFirstSelectedPoint = neighbors; + this.updateScatterPlotPositions(); this.updateScatterPlotAttributes(); this.scatterPlot.render(); }); @@ -124,6 +128,42 @@ export class ProjectorScatterPlotAdapter { this.scatterPlot.render(); } + setDataSet(dataSet: DataSet) { + if (this.projection != null) { + // TODO(nicholsonc): setDataSet needs to go away, the projection is the + // atomic unit of update. + this.projection.dataSet = dataSet; + } + if (this.traceVisualizer != null) { + this.traceVisualizer.setDataSet(dataSet); + } + if (this.canvasLabelsVisualizer != null) { + this.canvasLabelsVisualizer.setDataSet(dataSet); + } + if (this.labels3DVisualizer != null) { + this.labels3DVisualizer.setDataSet(dataSet); + } + if (this.spriteVisualizer == null) { + return; + } + this.spriteVisualizer.clearSpriteAtlas(); + if ((dataSet == null) || (dataSet.spriteAndMetadataInfo == null)) { + return; + } + const metadata = dataSet.spriteAndMetadataInfo; + if ((metadata.spriteImage == null) || (metadata.spriteMetadata == null)) { + return; + } + const n = dataSet.points.length; + const spriteIndices = new Float32Array(n); + for (let i = 0; i < n; ++i) { + spriteIndices[i] = dataSet.points[i].index; + } + this.spriteVisualizer.setSpriteAtlas( + metadata.spriteImage, metadata.spriteMetadata.singleImageDim, + spriteIndices); + } + set3DLabelMode(renderLabelsIn3D: boolean) { this.renderLabelsIn3D = renderLabelsIn3D; this.createVisualizers(renderLabelsIn3D); @@ -131,14 +171,17 @@ export class ProjectorScatterPlotAdapter { this.scatterPlot.render(); } - setLegendPointColorer(legendPointColorer: (index: number) => string) { + setLegendPointColorer( + legendPointColorer: (ds: DataSet, index: number) => string) { this.legendPointColorer = legendPointColorer; } - setLabelPointAccessor(labelPointAccessor: (index: number) => string) { + setLabelPointAccessor( + labelPointAccessor: (ds: DataSet, index: number) => string) { this.labelPointAccessor = labelPointAccessor; if (this.labels3DVisualizer != null) { - this.labels3DVisualizer.setLabelAccessor(labelPointAccessor); + this.labels3DVisualizer.setLabelStrings(this.generate3DLabelsArray( + this.projection.dataSet, labelPointAccessor)); } } @@ -157,10 +200,11 @@ export class ProjectorScatterPlotAdapter { updateScatterPlotPositions() { const ds = (this.projection == null) ? null : this.projection.dataSet; - const accessors = - (this.projection == null) ? null : this.projection.pointAccessors; - const newPositions = this.generatePointPositionArray(ds, accessors); - this.scatterPlot.setPointPositions(ds, newPositions); + const projectionComponents = + (this.projection == null) ? null : this.projection.projectionComponents; + const newPositions = + this.generatePointPositionArray(ds, projectionComponents); + this.scatterPlot.setPointPositions(newPositions); } updateScatterPlotAttributes() { @@ -198,10 +242,10 @@ export class ProjectorScatterPlotAdapter { this.scatterPlot.render(); } - generatePointPositionArray(ds: DataSet, pointAccessors: PointAccessors3D): - Float32Array { + generatePointPositionArray( + ds: DataSet, projectionComponents: ProjectionComponents3D): Float32Array { if (ds == null) { - return new Float32Array(0); + return null; } const xScaler: d3.scale.Linear<number, number> = d3.scale.linear(); @@ -209,8 +253,12 @@ export class ProjectorScatterPlotAdapter { let zScaler: d3.scale.Linear<number, number> = null; { // Determine max and min of each axis of our data. - const xExtent = d3.extent(ds.points, (p, i) => pointAccessors[0](i)); - const yExtent = d3.extent(ds.points, (p, i) => pointAccessors[1](i)); + const xExtent = d3.extent( + ds.points, + (p, i) => ds.points[i].projections[projectionComponents[0]]); + const yExtent = d3.extent( + ds.points, + (p, i) => ds.points[i].projections[projectionComponents[1]]); const range = [-SCATTER_PLOT_CUBE_LENGTH / 2, SCATTER_PLOT_CUBE_LENGTH / 2]; @@ -218,8 +266,10 @@ export class ProjectorScatterPlotAdapter { xScaler.domain(xExtent).range(range); yScaler.domain(yExtent).range(range); - if (pointAccessors[2] != null) { - const zExtent = d3.extent(ds.points, (p, i) => pointAccessors[2](i)); + if (projectionComponents[2] != null) { + const zExtent = d3.extent( + ds.points, + (p, i) => ds.points[i].projections[projectionComponents[2]]); zScaler = d3.scale.linear(); zScaler.domain(zExtent).range(range); } @@ -229,15 +279,18 @@ export class ProjectorScatterPlotAdapter { let dst = 0; ds.points.forEach((d, i) => { - positions[dst++] = xScaler(pointAccessors[0](i)); - positions[dst++] = yScaler(pointAccessors[1](i)); + positions[dst++] = + xScaler(ds.points[i].projections[projectionComponents[0]]); + positions[dst++] = + yScaler(ds.points[i].projections[projectionComponents[1]]); positions[dst++] = 0.0; }); if (zScaler) { dst = 2; ds.points.forEach((d, i) => { - positions[dst] = zScaler(pointAccessors[2](i)); + positions[dst] = + zScaler(ds.points[i].projections[projectionComponents[2]]); dst += 3; }); } @@ -253,7 +306,11 @@ export class ProjectorScatterPlotAdapter { return null; } - const n = selectedPointIndices.length + neighborsOfFirstPoint.length + + const selectedPointCount = + (selectedPointIndices == null) ? 0 : selectedPointIndices.length; + const neighborCount = + (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length; + const n = selectedPointCount + neighborCount + ((hoverPointIndex != null) ? 1 : 0); const visibleLabels = new Uint32Array(n); @@ -261,6 +318,7 @@ export class ProjectorScatterPlotAdapter { const opacityFlags = new Int8Array(n); const fillColors = new Uint8Array(n * 3); const strokeColors = new Uint8Array(n * 3); + const labelStrings: string[] = []; scale.fill(LABEL_SCALE_DEFAULT); opacityFlags.fill(1); @@ -268,6 +326,7 @@ export class ProjectorScatterPlotAdapter { let dst = 0; if (hoverPointIndex != null) { + labelStrings.push(this.labelPointAccessor(ds, hoverPointIndex)); visibleLabels[dst] = hoverPointIndex; scale[dst] = LABEL_SCALE_LARGE; opacityFlags[dst] = 0; @@ -282,11 +341,13 @@ export class ProjectorScatterPlotAdapter { // Selected points { - const n = selectedPointIndices.length; + const n = selectedPointCount; const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED); const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED); for (let i = 0; i < n; ++i) { - visibleLabels[dst] = selectedPointIndices[i]; + const labelIndex = selectedPointIndices[i]; + labelStrings.push(this.labelPointAccessor(ds, labelIndex)); + visibleLabels[dst] = labelIndex; scale[dst] = LABEL_SCALE_LARGE; opacityFlags[dst] = (n === 1) ? 0 : 1; packRgbIntoUint8Array( @@ -299,11 +360,13 @@ export class ProjectorScatterPlotAdapter { // Neighbors { - const n = neighborsOfFirstPoint.length; + const n = neighborCount; const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); for (let i = 0; i < n; ++i) { - visibleLabels[dst] = neighborsOfFirstPoint[i].index; + const labelIndex = neighborsOfFirstPoint[i].index; + labelStrings.push(this.labelPointAccessor(ds, labelIndex)); + visibleLabels[dst] = labelIndex; packRgbIntoUint8Array( fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); packRgbIntoUint8Array( @@ -313,8 +376,8 @@ export class ProjectorScatterPlotAdapter { } return new LabelRenderParams( - this.labelPointAccessor, visibleLabels, scale, opacityFlags, - LABEL_FONT_SIZE, fillColors, strokeColors); + visibleLabels, labelStrings, scale, opacityFlags, LABEL_FONT_SIZE, + fillColors, strokeColors); } generatePointScaleFactorArray( @@ -328,9 +391,14 @@ export class ProjectorScatterPlotAdapter { const scale = new Float32Array(ds.points.length); scale.fill(POINT_SCALE_DEFAULT); + const selectedPointCount = + (selectedPointIndices == null) ? 0 : selectedPointIndices.length; + const neighborCount = + (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length; + // Scale up all selected points. { - const n = selectedPointIndices.length; + const n = selectedPointCount; for (let i = 0; i < n; ++i) { const p = selectedPointIndices[i]; scale[p] = POINT_SCALE_SELECTED; @@ -339,7 +407,7 @@ export class ProjectorScatterPlotAdapter { // Scale up the neighbor points. { - const n = neighborsOfFirstPoint.length; + const n = neighborCount; for (let i = 0; i < n; ++i) { const p = neighborsOfFirstPoint[i].index; scale[p] = POINT_SCALE_NEIGHBOR; @@ -355,7 +423,7 @@ export class ProjectorScatterPlotAdapter { } generateLineSegmentColorMap( - ds: DataSet, legendPointColorer: (index: number) => string): + ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string): {[trace: number]: Float32Array} { let traceColorArrayMap: {[trace: number]: Float32Array} = {}; if (ds == null) { @@ -370,10 +438,10 @@ export class ProjectorScatterPlotAdapter { if (legendPointColorer) { for (let j = 0; j < dataTrace.pointIndices.length - 1; j++) { - const c1 = - new THREE.Color(legendPointColorer(dataTrace.pointIndices[j])); + const c1 = new THREE.Color( + legendPointColorer(ds, dataTrace.pointIndices[j])); const c2 = new THREE.Color( - legendPointColorer(dataTrace.pointIndices[j + 1])); + legendPointColorer(ds, dataTrace.pointIndices[j + 1])); colors[colorIndex++] = c1.r; colors[colorIndex++] = c1.g; colors[colorIndex++] = c1.b; @@ -408,7 +476,9 @@ export class ProjectorScatterPlotAdapter { return new Float32Array(0); } const opacities = new Float32Array(ds.traces.length); - if (selectedPoints.length > 0) { + const selectedPointCount = + (selectedPoints == null) ? 0 : selectedPoints.length; + if (selectedPointCount > 0) { opacities.fill(TRACE_DESELECTED_OPACITY); const i = ds.points[selectedPoints[0]].traceIndex; opacities[i] = TRACE_SELECTED_OPACITY; @@ -425,7 +495,9 @@ export class ProjectorScatterPlotAdapter { } const widths = new Float32Array(ds.traces.length); widths.fill(TRACE_DEFAULT_LINEWIDTH); - if (selectedPoints.length > 0) { + const selectedPointCount = + (selectedPoints == null) ? 0 : selectedPoints.length; + if (selectedPointCount > 0) { const i = ds.points[selectedPoints[0]].traceIndex; widths[i] = TRACE_SELECTED_LINEWIDTH; } @@ -433,7 +505,7 @@ export class ProjectorScatterPlotAdapter { } generatePointColorArray( - ds: DataSet, legendPointColorer: (index: number) => string, + ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string, distFunc: DistanceFunction, selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[], hoverPointIndex: number, label3dMode: boolean, spriteImageMode: boolean): Float32Array { @@ -441,6 +513,10 @@ export class ProjectorScatterPlotAdapter { return new Float32Array(0); } + const selectedPointCount = + (selectedPointIndices == null) ? 0 : selectedPointIndices.length; + const neighborCount = + (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length; const colors = new Float32Array(ds.points.length * 3); let unselectedColor = POINT_COLOR_UNSELECTED; @@ -460,7 +536,7 @@ export class ProjectorScatterPlotAdapter { { const n = ds.points.length; let dst = 0; - if (selectedPointIndices.length > 0) { + if (selectedPointCount > 0) { const c = new THREE.Color(unselectedColor); for (let i = 0; i < n; ++i) { colors[dst++] = c.r; @@ -470,7 +546,7 @@ export class ProjectorScatterPlotAdapter { } else { if (legendPointColorer != null) { for (let i = 0; i < n; ++i) { - const c = new THREE.Color(legendPointColorer(i)); + const c = new THREE.Color(legendPointColorer(ds, i)); colors[dst++] = c.r; colors[dst++] = c.g; colors[dst++] = c.b; @@ -488,7 +564,7 @@ export class ProjectorScatterPlotAdapter { // Color the selected points. { - const n = selectedPointIndices.length; + const n = selectedPointCount; const c = new THREE.Color(POINT_COLOR_SELECTED); for (let i = 0; i < n; ++i) { let dst = selectedPointIndices[i] * 3; @@ -500,7 +576,7 @@ export class ProjectorScatterPlotAdapter { // Color the neighbors. { - const n = neighborsOfFirstPoint.length; + const n = neighborCount; let minDist = n > 0 ? neighborsOfFirstPoint[0].dist : 0; for (let i = 0; i < n; ++i) { const c = new THREE.Color( @@ -524,6 +600,19 @@ export class ProjectorScatterPlotAdapter { return colors; } + generate3DLabelsArray( + ds: DataSet, accessor: (ds: DataSet, i: number) => string) { + if ((ds == null) || (accessor == null)) { + return null; + } + let labels: string[] = []; + const n = ds.points.length; + for (let i = 0; i < n; ++i) { + labels.push(accessor(ds, i).toString()); + } + return labels; + } + private updateScatterPlotWithNewProjection(projection: Projection) { if (projection != null) { this.scatterPlot.setDimensions(projection.dimensionality); @@ -543,16 +632,32 @@ export class ProjectorScatterPlotAdapter { const scatterPlot = this.scatterPlot; scatterPlot.removeAllVisualizers(); this.labels3DVisualizer = null; + this.canvasLabelsVisualizer = null; + this.spriteVisualizer = null; + this.traceVisualizer = null; if (inLabels3DMode) { this.labels3DVisualizer = new ScatterPlotVisualizer3DLabels(); - this.labels3DVisualizer.setLabelAccessor(this.labelPointAccessor); - scatterPlot.addVisualizer(this.labels3DVisualizer); + this.labels3DVisualizer.setLabelStrings(this.generate3DLabelsArray( + this.projection.dataSet, this.labelPointAccessor)); } else { - scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites()); - scatterPlot.addVisualizer( - new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer)); + this.spriteVisualizer = new ScatterPlotVisualizerSprites(); + scatterPlot.addVisualizer(this.spriteVisualizer); + this.canvasLabelsVisualizer = + new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer); + } + this.traceVisualizer = new ScatterPlotVisualizerTraces(); + const dataSet = (this.projection == null) ? null : this.projection.dataSet; + this.setDataSet(dataSet); + if (this.spriteVisualizer) { + scatterPlot.addVisualizer(this.spriteVisualizer); + } + if (this.labels3DVisualizer) { + scatterPlot.addVisualizer(this.labels3DVisualizer); + } + if (this.canvasLabelsVisualizer) { + scatterPlot.addVisualizer(this.canvasLabelsVisualizer); } - scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces()); + scatterPlot.addVisualizer(this.traceVisualizer); } private getSpriteImageMode(): boolean { diff --git a/tensorflow/tensorboard/components/vz_projector/renderContext.ts b/tensorflow/tensorboard/components/vz_projector/renderContext.ts index 27c1310992..2e7e254596 100644 --- a/tensorflow/tensorboard/components/vz_projector/renderContext.ts +++ b/tensorflow/tensorboard/components/vz_projector/renderContext.ts @@ -19,10 +19,10 @@ limitations under the License. */ export class LabelRenderParams { constructor( - public labelAccessor: (index: number) => string, - public pointIndices: Float32Array, public scaleFactors: Float32Array, - public useSceneOpacityFlags: Int8Array, public defaultFontSize: number, - public fillColors: Uint8Array, public strokeColors: Uint8Array) {} + public pointIndices: Float32Array, public labelStrings: string[], + public scaleFactors: Float32Array, public useSceneOpacityFlags: Int8Array, + public defaultFontSize: number, public fillColors: Uint8Array, + public strokeColors: Uint8Array) {} } /** Details about the camera projection being used to render the scene. */ diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts index 3a30a74503..9d9b0b5aff 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet} from './data'; import {ProjectorEventContext} from './projectorEventContext'; import {CameraType, LabelRenderParams, RenderContext} from './renderContext'; import {BoundingBox, ScatterPlotRectangleSelector} from './scatterPlotRectangleSelector'; @@ -73,7 +72,6 @@ export class CameraDef { * array of visualizers and dispatches application events to them. */ export class ScatterPlot { - private dataSet: DataSet; private projectorEventContext: ProjectorEventContext; private containerNode: HTMLElement; @@ -104,7 +102,6 @@ export class ScatterPlot { private pointColors: Float32Array; private pointScaleFactors: Float32Array; private labels: LabelRenderParams; - private traceColors: {[trace: number]: Float32Array}; private traceOpacities: Float32Array; private traceWidths: Float32Array; @@ -337,9 +334,6 @@ export class ScatterPlot { * hoverlisteners (usually called from embedding.ts) */ private onMouseMove(e: MouseEvent) { - if (!this.dataSet) { - return; - } this.isDragSequence = this.mouseIsDown; // Depending if we're selecting or just navigating, handle accordingly. if (this.selecting && this.mouseIsDown) { @@ -390,6 +384,10 @@ export class ScatterPlot { */ private getPointIndicesFromPickingTexture(boundingBox: BoundingBox): number[] { + if (this.worldSpacePointPositions == null) { + return null; + } + const pointCount = this.worldSpacePointPositions.length / 3; const dpr = window.devicePixelRatio || 1; const x = Math.floor(boundingBox.x * dpr); const y = Math.floor(boundingBox.y * dpr); @@ -411,7 +409,7 @@ export class ScatterPlot { for (let i = 0; i < width * height; i++) { const id = (pixelBuffer[i * 4] << 16) | (pixelBuffer[i * 4 + 1] << 8) | pixelBuffer[i * 4 + 2]; - if (id !== 0xffffff && (id < this.dataSet.points.length)) { + if (id !== 0xffffff && (id < pointCount)) { pointIndicesSelection[id] = 1; } } @@ -436,12 +434,10 @@ export class ScatterPlot { this.nearestPoint = null; return; } - - let boundingBox: + const boundingBox: BoundingBox = {x: e.offsetX, y: e.offsetY, width: 1, height: 1}; - - let pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); - this.nearestPoint = pointIndices[0]; + const pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); + this.nearestPoint = (pointIndices != null) ? pointIndices[0] : null; } private getLayoutValues(): Point2D { @@ -560,10 +556,7 @@ export class ScatterPlot { visualizer.setScene(this.scene); } visualizer.onResize(this.width, this.height); - if (this.dataSet) { - visualizer.onPointPositionsChanged( - this.worldSpacePointPositions, this.dataSet); - } + visualizer.onPointPositionsChanged(this.worldSpacePointPositions); this.visualizers.push(visualizer); } @@ -574,19 +567,13 @@ export class ScatterPlot { } /** Update scatter plot with a new array of packed xyz point positions. */ - setPointPositions(dataSet: DataSet, worldSpacePointPositions: Float32Array) { - this.dataSet = dataSet; + setPointPositions(worldSpacePointPositions: Float32Array) { this.worldSpacePointPositions = worldSpacePointPositions; - this.visualizers.forEach(v => { - v.onPointPositionsChanged(worldSpacePointPositions, this.dataSet); - }); + this.visualizers.forEach( + v => v.onPointPositionsChanged(worldSpacePointPositions)); } render() { - if (this.dataSet == null) { - return; - } - { const lightPos = this.camera.position.clone(); lightPos.x += 1; @@ -598,9 +585,12 @@ export class ScatterPlot { CameraType.Perspective : CameraType.Orthographic; - const cameraSpacePointExtents: [number, number] = util.getNearFarPoints( - this.worldSpacePointPositions, this.camera.position, - this.orbitCameraControls.target); + let cameraSpacePointExtents: [number, number] = [0, 0]; + if (this.worldSpacePointPositions != null) { + cameraSpacePointExtents = util.getNearFarPoints( + this.worldSpacePointPositions, this.camera.position, + this.orbitCameraControls.target); + } const rc = new RenderContext( this.camera, cameraType, this.orbitCameraControls.target, this.width, @@ -612,9 +602,7 @@ export class ScatterPlot { // with colors that are actually point ids, so that sampling the texture at // the mouse's current x,y coordinates will reveal the data point that the // mouse is over. - this.visualizers.forEach(v => { - v.onPickingRender(rc); - }); + this.visualizers.forEach(v => v.onPickingRender(rc)); { const axes = this.remove3dAxisFromScene(); @@ -625,9 +613,7 @@ export class ScatterPlot { } // Render second pass to color buffer, to be displayed on the canvas. - this.visualizers.forEach(v => { - v.onRender(rc); - }); + this.visualizers.forEach(v => v.onRender(rc)); this.renderer.render(this.scene, this.camera); } @@ -723,9 +709,7 @@ export class ScatterPlot { this.pickingTexture.texture.minFilter = THREE.LinearFilter; } - this.visualizers.forEach(v => { - v.onResize(newW, newH); - }); + this.visualizers.forEach(v => v.onResize(newW, newH)); if (render) { this.render(); diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts index 2d7f5cd640..b0974a2053 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet} from './data'; import {RenderContext} from './renderContext'; /** @@ -33,8 +32,7 @@ export interface ScatterPlotVisualizer { /** * Called when the positions of the scatter plot points have changed. */ - onPointPositionsChanged( - newWorldSpacePointPositions: Float32Array, dataSet: DataSet); + onPointPositionsChanged(newWorldSpacePointPositions: Float32Array); /** * Called immediately before the main scatter plot performs a picking * (selection) render. Set up render state for any geometry to use picking IDs diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts index 3811e10c57..ecd2e21403 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts @@ -98,7 +98,7 @@ type GlyphTexture = { export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { private dataSet: DataSet; private scene: THREE.Scene; - private labelAccessor: (index: number) => string; + private labelStrings: string[]; private geometry: THREE.BufferGeometry; private worldSpacePointPositions: Float32Array; private pickingColors: Float32Array; @@ -111,6 +111,10 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { private labelVertexMap: number[][]; private glyphTexture: GlyphTexture; + setDataSet(ds: DataSet) { + this.dataSet = ds; + } + private createGlyphTexture(): GlyphTexture { let canvas = document.createElement('canvas'); canvas.width = MAX_CANVAS_DIMENSION; @@ -139,11 +143,11 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { return {texture: tex, lengths: glyphLengths, offsets: glyphOffset}; } - private processLabelVerts() { + private processLabelVerts(pointCount: number) { let numTotalLetters = 0; this.labelVertexMap = []; - for (let i = 0; i < this.dataSet.points.length; i++) { - let label: string = this.labelAccessor(i).toString(); + for (let i = 0; i < pointCount; i++) { + const label = this.labelStrings[i]; let vertsArray: number[] = []; for (let j = 0; j < label.length; j++) { for (let k = 0; k < VERTICES_PER_GLYPH; k++) { @@ -156,13 +160,12 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { this.totalVertexCount = numTotalLetters * VERTICES_PER_GLYPH; } - private createColorBuffers() { - let numPoints = this.dataSet.points.length; + private createColorBuffers(pointCount: number) { this.pickingColors = new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); this.renderColors = new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); - for (let i = 0; i < numPoints; i++) { + for (let i = 0; i < pointCount; i++) { let color = new THREE.Color(i); this.labelVertexMap[i].forEach((j) => { this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j] = color.r; @@ -175,7 +178,16 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { } } - private createLabels(dataSet: DataSet) { + private createLabels() { + if ((this.labelStrings == null) || + (this.worldSpacePointPositions == null)) { + return; + } + const pointCount = + this.worldSpacePointPositions.length / XYZ_ELEMENTS_PER_ENTRY; + if (pointCount !== this.labelStrings.length) { + return; + } this.glyphTexture = this.createGlyphTexture(); this.uniforms = { @@ -190,8 +202,8 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { fragmentShader: FRAGMENT_SHADER, }); - this.processLabelVerts(); - this.createColorBuffers(); + this.processLabelVerts(pointCount); + this.createColorBuffers(pointCount); let positionArray = new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY); @@ -215,8 +227,8 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { this.geometry.addAttribute('color', colors); let lettersSoFar = 0; - for (let i = 0; i < dataSet.points.length; i++) { - let label: string = this.labelAccessor(i).toString(); + for (let i = 0; i < pointCount; i++) { + const label = this.labelStrings[i]; let leftOffset = 0; // Determine length of word in pixels. for (let j = 0; j < label.length; j++) { @@ -262,8 +274,7 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { } } - const n = dataSet.points.length; - for (let i = 0; i < n; i++) { + for (let i = 0; i < pointCount; i++) { const p = util.vector3FromPackedArray(this.worldSpacePointPositions, i); this.labelVertexMap[i].forEach((j) => { this.positions.setXYZ(j, p.x, p.y, p.z); @@ -276,7 +287,7 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { } private colorLabels(pointColors: Float32Array) { - if (this.labelAccessor == null || this.geometry == null || + if (this.labelStrings == null || this.geometry == null || this.dataSet == null || pointColors == null) { return; } @@ -319,40 +330,43 @@ export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { } } - setLabelAccessor(labelAccessor: (index: number) => string) { - this.labelAccessor = labelAccessor; - this.dispose(); - this.onPointPositionsChanged(this.worldSpacePointPositions, this.dataSet); - } - onPickingRender(rc: RenderContext) { + if (this.geometry == null) { + this.createLabels(); + } + if (this.geometry == null) { + return; + } this.material.uniforms.texture.value = this.glyphTexture.texture; this.material.uniforms.picking.value = true; - - let colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; + const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; colors.array = this.pickingColors; colors.needsUpdate = true; } onRender(rc: RenderContext) { + if (this.geometry == null) { + this.createLabels(); + } + if (this.geometry == null) { + return; + } this.colorLabels(rc.pointColors); - this.material.uniforms.texture.value = this.glyphTexture.texture; this.material.uniforms.picking.value = false; - const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; colors.array = this.renderColors; colors.needsUpdate = true; } - onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) { + onPointPositionsChanged(newPositions: Float32Array) { this.worldSpacePointPositions = newPositions; - this.dataSet = dataSet; this.dispose(); - if ((this.dataSet != null) && (this.labelAccessor != null) && - (this.worldSpacePointPositions != null)) { - this.createLabels(this.dataSet); - } + } + + setLabelStrings(labelStrings: string[]) { + this.labelStrings = labelStrings; + this.dispose(); } onResize(newWidth: number, newHeight: number) {} diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts index 959b077a4d..ef473eda6c 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts @@ -29,6 +29,7 @@ const LABEL_FILL_WIDTH = 6; */ export class ScatterPlotVisualizerCanvasLabels implements ScatterPlotVisualizer { + private dataSet: DataSet; private worldSpacePointPositions: Float32Array; private gc: CanvasRenderingContext2D; private canvas: HTMLCanvasElement; @@ -41,6 +42,10 @@ export class ScatterPlotVisualizerCanvasLabels implements this.canvas.style.pointerEvents = 'none'; } + setDataSet(ds: DataSet) { + this.dataSet = ds; + } + private removeAllLabels() { const pixelWidth = this.canvas.width * window.devicePixelRatio; const pixelHeight = this.canvas.height * window.devicePixelRatio; @@ -49,13 +54,18 @@ export class ScatterPlotVisualizerCanvasLabels implements /** Render all of the non-overlapping visible labels to the canvas. */ private makeLabels(rc: RenderContext) { + if (this.dataSet == null) { + return; + } if ((rc.labels == null) || (rc.labels.pointIndices.length === 0)) { return; } + if (this.worldSpacePointPositions == null) { + return; + } const lrc = rc.labels; const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective); - const labelHeight = parseInt(this.gc.font, 10); const dpr = window.devicePixelRatio; @@ -87,9 +97,11 @@ export class ScatterPlotVisualizerCanvasLabels implements const n = Math.min(MAX_LABELS_ON_SCREEN, lrc.pointIndices.length); for (let i = 0; i < n; ++i) { - const index = lrc.pointIndices[i]; - const point = - util.vector3FromPackedArray(this.worldSpacePointPositions, index); + let point: THREE.Vector3; + { + const pi = lrc.pointIndices[i]; + point = util.vector3FromPackedArray(this.worldSpacePointPositions, pi); + } // discard points that are behind the camera camToPoint.copy(camPos).sub(point); @@ -112,7 +124,7 @@ export class ScatterPlotVisualizerCanvasLabels implements }; if (grid.insert(textBoundingBox, true)) { - const text = lrc.labelAccessor(index); + const text = lrc.labelStrings[i]; const fontSize = lrc.defaultFontSize * lrc.scaleFactors[i] * dpr; this.gc.font = fontSize + 'px roboto'; @@ -160,7 +172,7 @@ export class ScatterPlotVisualizerCanvasLabels implements this.gc = null; } - onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) { + onPointPositionsChanged(newPositions: Float32Array) { this.worldSpacePointPositions = newPositions; this.removeAllLabels(); } diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts index db1fb691fa..1facddba1a 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet} from './data'; import {CameraType, RenderContext} from './renderContext'; import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; import * as util from './util'; @@ -30,7 +29,7 @@ const XYZ_NUM_ELEMENTS = 3; const VERTEX_SHADER = ` // Index of the specific vertex (passed in as bufferAttribute), and the // variable that will be used to pass it to the fragment shader. - attribute float vertexIndex; + attribute float spriteIndex; attribute vec3 color; attribute float scaleFactor; @@ -39,14 +38,14 @@ const VERTEX_SHADER = ` uniform bool sizeAttenuation; uniform float pointSize; - uniform float imageWidth; - uniform float imageHeight; + uniform float spritesPerRow; + uniform float spritesPerColumn; void main() { // Pass index and color values to fragment shader. vColor = color; - xyIndex = vec2(mod(vertexIndex, imageWidth), - floor(vertexIndex / imageWidth)); + xyIndex = vec2(mod(spriteIndex, spritesPerRow), + floor(spriteIndex / spritesPerColumn)); // Transform current vertex by modelViewMatrix (model world position and // camera world position matrix). @@ -93,8 +92,8 @@ const FRAGMENT_SHADER = ` varying vec3 vColor; uniform sampler2D texture; - uniform float imageWidth; - uniform float imageHeight; + uniform float spritesPerRow; + uniform float spritesPerColumn; uniform bool isImage; ${THREE.ShaderChunk['common']} @@ -104,7 +103,8 @@ const FRAGMENT_SHADER = ` void main() { if (isImage) { // Coordinates of the vertex within the entire sprite image. - vec2 coords = (gl_PointCoord + xyIndex) / vec2(imageWidth, imageHeight); + vec2 coords = + (gl_PointCoord + xyIndex) / vec2(spritesPerRow, spritesPerColumn); gl_FragColor = vec4(vColor, 1.0) * texture2D(texture, coords); } else { bool inside = point_in_unit_circle(gl_PointCoord); @@ -140,11 +140,13 @@ const FRAGMENT_SHADER_PICKING = ` * Uses GL point sprites to render the dataset. */ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { - private image: HTMLImageElement; - private scene: THREE.Scene; private fog: THREE.Fog; private texture: THREE.Texture = null; + private standinTextureForPoints: THREE.Texture; + private spritesPerRow: number; + private spritesPerColumn: number; + private spriteIndexBufferAttribute: THREE.BufferAttribute; private renderMaterial: THREE.ShaderMaterial; private pickingMaterial: THREE.ShaderMaterial; @@ -153,46 +155,47 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { private pickingColors: Float32Array; private renderColors: Float32Array; - /** - * Create points, set their locations and actually instantiate the - * geometry. - */ - private createPointSprites( - scene: THREE.Scene, positions: Float32Array, dataSet: DataSet) { - const geometry = - this.createGeometry(positions.length / XYZ_NUM_ELEMENTS, dataSet); + constructor() { + this.standinTextureForPoints = + util.createTexture(document.createElement('canvas')); + this.renderMaterial = this.createRenderMaterial(false); + this.pickingMaterial = this.createPickingMaterial(false); + } - const haveImage = (this.image != null); - this.fog = new THREE.Fog(0xFFFFFF); // unused value, gets overwritten. + private createTextureFromSpriteAtlas( + spriteAtlas: HTMLImageElement, spriteDimensions: [number, number], + spriteIndices: Float32Array) { + this.texture = util.createTexture(spriteAtlas); + this.spritesPerRow = spriteAtlas.width / spriteDimensions[0]; + this.spritesPerColumn = spriteAtlas.height / spriteDimensions[1]; - { - const image = this.image || document.createElement('canvas'); - this.texture = util.createTexture(image); - } + this.spriteIndexBufferAttribute = + new THREE.BufferAttribute(spriteIndices, INDEX_NUM_ELEMENTS); - let imageDim = [1, 1]; - { - const spriteMetadata = dataSet.spriteAndMetadataInfo.spriteMetadata; - if (haveImage && spriteMetadata) { - imageDim[0] = this.image.width / spriteMetadata.singleImageDim[0]; - imageDim[1] = this.image.height / spriteMetadata.singleImageDim[1]; - } + if (this.points != null) { + (this.points.geometry as THREE.BufferGeometry) + .addAttribute('spriteIndex', this.spriteIndexBufferAttribute); } + } - const uniforms = { + private createUniforms(): any { + return { texture: {type: 't'}, - imageWidth: {type: 'f', value: imageDim[0]}, - imageHeight: {type: 'f', value: imageDim[1]}, + spritesPerRow: {type: 'f'}, + spritesPerColumn: {type: 'f'}, fogColor: {type: 'c'}, fogNear: {type: 'f'}, fogFar: {type: 'f'}, - isImage: {type: 'bool', value: haveImage}, + isImage: {type: 'bool'}, sizeAttenuation: {type: 'bool'}, pointSize: {type: 'f'} }; + } - this.renderMaterial = new THREE.ShaderMaterial({ - uniforms: THREE.UniformsUtils.clone(uniforms), + private createRenderMaterial(haveImage: boolean): THREE.ShaderMaterial { + const uniforms = this.createUniforms(); + return new THREE.ShaderMaterial({ + uniforms: uniforms, vertexShader: VERTEX_SHADER, fragmentShader: FRAGMENT_SHADER, transparent: !haveImage, @@ -201,9 +204,12 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { fog: true, blending: THREE.MultiplyBlending, }); + } - this.pickingMaterial = new THREE.ShaderMaterial({ - uniforms: THREE.UniformsUtils.clone(uniforms), + private createPickingMaterial(haveImage: boolean): THREE.ShaderMaterial { + const uniforms = this.createUniforms(); + return new THREE.ShaderMaterial({ + uniforms: uniforms, vertexShader: VERTEX_SHADER, fragmentShader: FRAGMENT_SHADER_PICKING, transparent: true, @@ -212,17 +218,35 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { fog: false, blending: THREE.NormalBlending, }); + } + + /** + * Create points, set their locations and actually instantiate the + * geometry. + */ + private createPointSprites(scene: THREE.Scene, positions: Float32Array) { + const pointCount = + (positions != null) ? (positions.length / XYZ_NUM_ELEMENTS) : 0; + const geometry = this.createGeometry(pointCount); + + this.fog = new THREE.Fog(0xFFFFFF); // unused value, gets overwritten. this.points = new THREE.Points(geometry, this.renderMaterial); this.points.frustumCulled = false; + if (this.spriteIndexBufferAttribute != null) { + (this.points.geometry as THREE.BufferGeometry) + .addAttribute('spriteIndex', this.spriteIndexBufferAttribute); + } scene.add(this.points); } private calculatePointSize(sceneIs3D: boolean): number { - if (this.image != null) { + if (this.texture != null) { return IMAGE_SIZE; } - const n = this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS; + const n = (this.worldSpacePointPositions != null) ? + (this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS) : + 1; const SCALE = 200; const LOG_BASE = 8; const DIVISOR = 1.5; @@ -234,8 +258,7 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { /** * Set up buffer attributes to be used for the points/images. */ - private createGeometry(pointCount: number, dataSet: DataSet): - THREE.BufferGeometry { + private createGeometry(pointCount: number): THREE.BufferGeometry { const n = pointCount; // Fill pickingColors with each point's unique id as its color. @@ -250,13 +273,6 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { } } - const spriteIndexes = - new THREE.BufferAttribute(new Float32Array(n), INDEX_NUM_ELEMENTS); - - for (let i = 0; i < n; i++) { - spriteIndexes.setX(i, dataSet.points[i].index); - } - const geometry = new THREE.BufferGeometry(); geometry.addAttribute( 'position', new THREE.BufferAttribute(null, XYZ_NUM_ELEMENTS)); @@ -264,7 +280,6 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { 'color', new THREE.BufferAttribute(null, RGB_NUM_ELEMENTS)); geometry.addAttribute( 'scaleFactor', new THREE.BufferAttribute(null, INDEX_NUM_ELEMENTS)); - geometry.addAttribute('vertexIndex', spriteIndexes); return geometry; } @@ -286,55 +301,83 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { } dispose() { - this.scene.remove(this.points); - this.points.geometry.dispose(); - if (this.renderMaterial.uniforms.texture.value) { - this.renderMaterial.uniforms.texture.value.dispose(); + this.disposeGeometry(); + this.disposeTextureAtlas(); + this.worldSpacePointPositions = null; + } + + private disposeGeometry() { + if (this.points != null) { + this.scene.remove(this.points); + this.points.geometry.dispose(); + this.points = null; + } + } + + private disposeTextureAtlas() { + if (this.texture != null) { + this.texture.dispose(); } - this.points = null; + this.texture = null; this.renderMaterial = null; this.pickingMaterial = null; - this.worldSpacePointPositions = null; - this.image = null; } setScene(scene: THREE.Scene) { this.scene = scene; } - onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) { + setSpriteAtlas( + spriteImage: HTMLImageElement, spriteDimensions: [number, number], + spriteIndices: Uint8Array) { + this.disposeTextureAtlas(); + this.createTextureFromSpriteAtlas( + spriteImage, spriteDimensions, spriteIndices); + this.renderMaterial = this.createRenderMaterial(true); + this.pickingMaterial = this.createPickingMaterial(true); + } + + clearSpriteAtlas() { + this.disposeTextureAtlas(); + this.renderMaterial = this.createRenderMaterial(false); + this.pickingMaterial = this.createPickingMaterial(false); + } + + onPointPositionsChanged(newPositions: Float32Array) { + if ((newPositions == null) || (newPositions.length === 0)) { + this.disposeGeometry(); + return; + } + if (this.points != null) { - const notEnoughSpace = (this.pickingColors.length < newPositions.length); - const newImage = (dataSet != null) && - (this.image !== dataSet.spriteAndMetadataInfo.spriteImage); - if (notEnoughSpace || newImage) { - this.dispose(); + const notEnoughSpace = + (this.worldSpacePointPositions.length < newPositions.length); + if (notEnoughSpace) { + this.disposeGeometry(); } } - this.image = - (dataSet != null) ? dataSet.spriteAndMetadataInfo.spriteImage : null; this.worldSpacePointPositions = newPositions; if (this.points == null) { - this.createPointSprites(this.scene, newPositions, dataSet); + this.createPointSprites(this.scene, newPositions); } - if (newPositions) { - const positions = (this.points.geometry as THREE.BufferGeometry) - .getAttribute('position') as THREE.BufferAttribute; - positions.array = newPositions; - positions.needsUpdate = true; - } + const positions = (this.points.geometry as THREE.BufferGeometry) + .getAttribute('position') as THREE.BufferAttribute; + positions.array = newPositions; + positions.needsUpdate = true; } onPickingRender(rc: RenderContext) { - if (!this.points) { + if (this.points == null) { return; } const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective); + this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; + this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerColumn; this.pickingMaterial.uniforms.sizeAttenuation.value = sceneIs3D; this.pickingMaterial.uniforms.pointSize.value = this.calculatePointSize(sceneIs3D); @@ -367,7 +410,11 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { this.renderMaterial.uniforms.fogColor.value = this.scene.fog.color; this.renderMaterial.uniforms.fogNear.value = this.fog.near; this.renderMaterial.uniforms.fogFar.value = this.fog.far; - this.renderMaterial.uniforms.texture.value = this.texture; + this.renderMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; + this.renderMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn; + this.renderMaterial.uniforms.isImage.value = (this.texture != null); + this.renderMaterial.uniforms.texture.value = + (this.texture != null) ? this.texture : this.standinTextureForPoints; this.renderMaterial.uniforms.sizeAttenuation.value = sceneIs3D; this.renderMaterial.uniforms.pointSize.value = this.calculatePointSize(sceneIs3D); diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts index ec71a93414..a1ff747ff3 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerTraces.ts @@ -69,7 +69,7 @@ export class ScatterPlotVisualizerTraces implements ScatterPlotVisualizer { } dispose() { - if (!this.traces) { + if (this.traces == null) { return; } for (let i = 0; i < this.traces.length; i++) { @@ -85,32 +85,30 @@ export class ScatterPlotVisualizerTraces implements ScatterPlotVisualizer { this.scene = scene; } - onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) { + setDataSet(dataSet: DataSet) { this.dataSet = dataSet; - if (dataSet == null) { + } + + onPointPositionsChanged(newPositions: Float32Array) { + if ((newPositions == null) || (this.traces != null)) { + this.dispose(); + } + if ((newPositions == null) || (this.dataSet == null)) { return; } + // Set up the position buffer arrays for each trace. + for (let i = 0; i < this.dataSet.traces.length; i++) { + let dataTrace = this.dataSet.traces[i]; + const vertexCount = 2 * (dataTrace.pointIndices.length - 1); - if ((this.traces == null) || - (this.traces.length !== dataSet.traces.length)) { - if (this.traces != null) { - this.dispose(); - } - // Set up the position buffer arrays for each trace. - for (let i = 0; i < this.dataSet.traces.length; i++) { - let dataTrace = this.dataSet.traces[i]; - const vertexCount = 2 * (dataTrace.pointIndices.length - 1); - - let traces = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS); - this.tracePositionBuffer[i] = - new THREE.BufferAttribute(traces, XYZ_NUM_ELEMENTS); - - let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS); - this.traceColorBuffer[i] = - new THREE.BufferAttribute(colors, RGB_NUM_ELEMENTS); - } - } + let traces = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS); + this.tracePositionBuffer[i] = + new THREE.BufferAttribute(traces, XYZ_NUM_ELEMENTS); + let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS); + this.traceColorBuffer[i] = + new THREE.BufferAttribute(colors, RGB_NUM_ELEMENTS); + } for (let i = 0; i < this.dataSet.traces.length; i++) { const dataTrace = this.dataSet.traces[i]; let src = 0; diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts index 0de423cc9e..308e1685d2 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts @@ -64,11 +64,12 @@ export class BookmarkPanel extends BookmarkPanelPolymer { setSelectedTensor( run: string, tensorInfo: EmbeddingInfo, dataProvider: DataProvider) { + // Clear any existing bookmarks. + this.addStates(null); if (tensorInfo && tensorInfo.bookmarksPath) { - this.loadAllStates([]); // Get any bookmarks that may come when the projector starts up. dataProvider.getBookmarks(run, tensorInfo.tensorName, bookmarks => { - this.loadAllStates(bookmarks); + this.addStates(bookmarks); }); } } @@ -145,7 +146,7 @@ export class BookmarkPanel extends BookmarkPanelPolymer { // Verify the bookmarks match. if (this.savedStatesValid(savedStates)) { - this.loadAllStates(savedStates); + this.addStates(savedStates); this.loadSavedState(0); } else { logging.setWarningMessage( @@ -157,10 +158,14 @@ export class BookmarkPanel extends BookmarkPanelPolymer { }); } - loadAllStates(savedStates: State[]) { - for (let i = 0; i < savedStates.length; i++) { - savedStates[i].isSelected = false; - this.push('savedStates', savedStates[i] as any); + addStates(savedStates?: State[]) { + if (savedStates == null) { + this.savedStates = []; + } else { + for (let i = 0; i < savedStates.length; i++) { + savedStates[i].isSelected = false; + this.push('savedStates', savedStates[i] as any); + } } this.updateHasStates(); } @@ -168,34 +173,35 @@ export class BookmarkPanel extends BookmarkPanelPolymer { /** Deselects any selected state selection. */ clearStateSelection() { for (let i = 0; i < this.savedStates.length; i++) { - if (this.savedStates[i].isSelected) { - this.savedStates[i].isSelected = false; - this.notifyPath('savedStates.' + i + '.isSelected', false, false); - return; - } + this.setSelectionState(i, false); } } /** Handles a radio button click on a saved state. */ _radioButtonHandler(evt: Event) { - this.loadSavedState(this.getParentDataIndex(evt)); + const index = this.getParentDataIndex(evt); + this.loadSavedState(index); + this.setSelectionState(index, true); } loadSavedState(index: number) { for (let i = 0; i < this.savedStates.length; i++) { if (this.savedStates[i].isSelected) { - this.savedStates[i].isSelected = false; - this.notifyPath('savedStates.' + i + '.isSelected', false, false); + this.setSelectionState(i, false); } else if (index === i) { - this.savedStates[i].isSelected = true; - this.notifyPath('savedStates.' + i + '.isSelected', true, false); - + this.setSelectionState(i, true); this.ignoreNextProjectionEvent = true; this.projector.loadState(this.savedStates[i]); } } } + private setSelectionState(stateIndex: number, selected: boolean) { + this.savedStates[stateIndex].isSelected = selected; + const path = 'savedStates.' + stateIndex + '.isSelected'; + this.notifyPath(path, selected, false); + } + /** * Crawls up the DOM to find an ancestor with a data-index attribute. This is * used to match events to their bookmark index. diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts index 3b40ec27ce..32e9b0a724 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet, PCA_SAMPLE_DIM, PCA_SAMPLE_SIZE, Projection, ProjectionType, SpriteAndMetadataInfo, State, TSNE_SAMPLE_SIZE} from './data'; +import * as data from './data'; +import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data'; import * as vector from './vector'; import {Vector} from './vector'; import {Projector} from './vz-projector'; @@ -289,18 +290,17 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.dataSet = dataSet; this.originalDataSet = originalDataSet; this.dim = dim; - let perplexity = - Math.max(5, Math.ceil(Math.sqrt(dataSet.points.length) / 4)); + const pointCount = (dataSet == null) ? 0 : dataSet.points.length; + const perplexity = Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4)); this.perplexitySlider.value = perplexity.toString(); this.updateTSNEPerplexityFromSliderChange(); this.clearCentroids(); this.dom.select('#tsne-sampling') - .style( - 'display', - dataSet.points.length > TSNE_SAMPLE_SIZE ? null : 'none'); - let wasSampled = - dataSet.dim[0] > PCA_SAMPLE_SIZE || dataSet.dim[1] > PCA_SAMPLE_DIM; + .style('display', pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none'); + const wasSampled = + (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_DIM || + dataSet.dim[1] > data.PCA_SAMPLE_DIM); this.dom.select('#pca-sampling') .style('display', wasSampled ? null : 'none'); this.showTab('pca'); @@ -374,7 +374,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { return; } const accessors = - dataSet.getPointAccessors('tsne', [0, 1, this.tSNEis3d ? 2 : null]); + data.getProjectionComponents('tsne', [0, 1, this.tSNEis3d ? 2 : null]); const dimensionality = this.tSNEis3d ? 3 : 2; const projection = new Projection('tsne', accessors, dimensionality, dataSet); @@ -427,7 +427,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { } this.dataSet.projectPCA().then(() => { // Polymer properties are 1-based. - const accessors = this.dataSet.getPointAccessors( + const accessors = data.getProjectionComponents( 'pca', [this.pcaX, this.pcaY, this.pcaZ]); const dimensionality = this.pcaIs3d ? 3 : 2; @@ -459,7 +459,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { const yDir = vector.sub(this.centroids.yUp, this.centroids.yDown); this.dataSet.projectLinear(yDir, 'linear-y'); - const accessors = this.dataSet.getPointAccessors('custom', ['x', 'y']); + const accessors = data.getProjectionComponents('custom', ['x', 'y']); const projection = new Projection('custom', accessors, 2, this.dataSet); this.projector.setProjection(projection); } @@ -543,15 +543,15 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { } getPcaSampledDimText() { - return PCA_SAMPLE_DIM.toLocaleString(); + return data.PCA_SAMPLE_DIM.toLocaleString(); } getPcaSampleSizeText() { - return PCA_SAMPLE_SIZE.toLocaleString(); + return data.PCA_SAMPLE_SIZE.toLocaleString(); } getTsneSampleSizeText() { - return TSNE_SAMPLE_SIZE.toLocaleString(); + return data.TSNE_SAMPLE_SIZE.toLocaleString(); } } diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts index 655a75b7f4..14ea58b24a 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ import {AnalyticsLogger} from './analyticsLogger'; +import * as data from './data'; import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data'; import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider'; import {DemoDataProvider} from './data-provider-demo'; @@ -68,6 +69,7 @@ export class Projector extends ProjectorPolymer implements private distanceMetricChangedListeners: DistanceMetricChangedListener[]; private originalDataSet: DataSet; + private dataSetBeforeFilter: DataSet; private dom: d3.Selection<any>; private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; private dim: number; @@ -125,10 +127,8 @@ export class Projector extends ProjectorPolymer implements setSelectedLabelOption(labelOption: string) { this.selectedLabelOption = labelOption; - let labelAccessor = (i: number): string => { - return this.dataSet.points[i] - .metadata[this.selectedLabelOption] as string; - }; + const labelAccessor = (ds: DataSet, i: number): string => + ds.points[i].metadata[this.selectedLabelOption] as string; this.metadataCard.setLabelOption(this.selectedLabelOption); this.projectorScatterPlotAdapter.setLabelPointAccessor(labelAccessor); this.projectorScatterPlotAdapter.render(); @@ -152,30 +152,41 @@ export class Projector extends ProjectorPolymer implements metadataFile?: string) { this.dataSetFilterIndices = null; this.originalDataSet = ds; - if (this.projectorScatterPlotAdapter == null || ds == null) { - return; + if (ds != null) { + this.normalizeData = + this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; + spriteAndMetadata = spriteAndMetadata || {}; + if (spriteAndMetadata.pointsInfo == null) { + let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points); + spriteAndMetadata.pointsInfo = pointsInfo; + spriteAndMetadata.stats = stats; + } + ds.mergeMetadata(spriteAndMetadata); } - this.normalizeData = this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; - spriteAndMetadata = spriteAndMetadata || {}; - if (spriteAndMetadata.pointsInfo == null) { - let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points); - spriteAndMetadata.pointsInfo = pointsInfo; - spriteAndMetadata.stats = stats; + if (this.projectorScatterPlotAdapter != null) { + if (ds == null) { + this.setProjection(null); + } + this.projectorScatterPlotAdapter.updateScatterPlotPositions(); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); + this.projectorScatterPlotAdapter.resize(); + this.projectorScatterPlotAdapter.render(); + } + if (ds != null) { + this.dataPanel.setNormalizeData(this.normalizeData); + this.setCurrentDataSet(ds.getSubset()); + this.inspectorPanel.datasetChanged(); + + this.inspectorPanel.metadataChanged(spriteAndMetadata); + this.projectionsPanel.metadataChanged(spriteAndMetadata); + this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile); + // Set the container to a fixed height, otherwise in Colab the + // height can grow indefinitely. + let container = this.dom.select('#container'); + container.style('height', container.property('clientHeight') + 'px'); + } else { + this.setCurrentDataSet(null); } - ds.mergeMetadata(spriteAndMetadata); - this.dataPanel.setNormalizeData(this.normalizeData); - this.setCurrentDataSet(this.originalDataSet.getSubset()); - this.inspectorPanel.datasetChanged(); - - this.inspectorPanel.metadataChanged(spriteAndMetadata); - this.projectionsPanel.metadataChanged(spriteAndMetadata); - this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile); - // Set the container to a fixed height, otherwise in Colab the - // height can grow indefinitely. - let container = this.dom.select('#container'); - container.style('height', container.property('clientHeight') + 'px'); - this.projectorScatterPlotAdapter.resize(); - this.projectorScatterPlotAdapter.render(); } setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) { @@ -191,17 +202,26 @@ export class Projector extends ProjectorPolymer implements filterDataset(pointIndices: number[]) { const selectionSize = this.selectedPointIndices.length; + if (this.dataSetBeforeFilter == null) { + this.dataSetBeforeFilter = this.dataSet; + } this.setCurrentDataSet(this.dataSet.getSubset(pointIndices)); this.dataSetFilterIndices = pointIndices; + this.projectorScatterPlotAdapter.updateScatterPlotPositions(); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); this.adjustSelectionAndHover(d3.range(selectionSize)); } resetFilterDataset() { - let originalPointIndices = this.selectedPointIndices.map(localIndex => { - return this.dataSet.points[localIndex].index; - }); - this.setCurrentDataSet(this.originalDataSet.getSubset()); + const originalPointIndices = this.selectedPointIndices.map( + filteredIndex => this.dataSet.points[filteredIndex].index); + this.setCurrentDataSet(this.dataSetBeforeFilter); + if (this.projection != null) { + this.projection.dataSet = this.dataSetBeforeFilter; + } + this.dataSetBeforeFilter = null; this.projectorScatterPlotAdapter.updateScatterPlotPositions(); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); this.dataSetFilterIndices = []; this.adjustSelectionAndHover(originalPointIndices); } @@ -306,13 +326,12 @@ export class Projector extends ProjectorPolymer implements } private getLegendPointColorer(colorOption: ColorOption): - (index: number) => string { + (ds: DataSet, index: number) => string { if ((colorOption == null) || (colorOption.map == null)) { return null; } - const colorer = (i: number) => { - let value = - this.dataSet.points[i].metadata[this.selectedColorOption.name]; + const colorer = (ds: DataSet, i: number) => { + let value = ds.points[i].metadata[this.selectedColorOption.name]; if (value == null) { return POINT_COLOR_MISSING; } @@ -347,19 +366,19 @@ export class Projector extends ProjectorPolymer implements if (this.dataSet != null) { this.dataSet.stopTSNE(); } - this.dataSet = ds; - if (this.normalizeData) { - this.dataSet.normalize(); + if ((ds != null) && this.normalizeData) { + ds.normalize(); } - this.dim = this.dataSet.dim[1]; - this.dom.select('span.numDataPoints').text(this.dataSet.dim[0]); - this.dom.select('span.dim').text(this.dataSet.dim[1]); + this.dim = (ds == null) ? 0 : ds.dim[1]; + this.dom.select('span.numDataPoints').text((ds == null) ? '0' : ds.dim[0]); + this.dom.select('span.dim').text((ds == null) ? '0' : ds.dim[1]); - this.projection = null; + this.dataSet = ds; this.projectionsPanel.dataSetUpdated( this.dataSet, this.originalDataSet, this.dim); + this.projectorScatterPlotAdapter.setDataSet(this.dataSet); this.projectorScatterPlotAdapter.scatterPlot .setCameraParametersForNextCameraCreation(null, true); } @@ -494,7 +513,9 @@ export class Projector extends ProjectorPolymer implements this.setProjection(null); { this.projectionsPanel.disablePolymerChangesTriggerReprojection(); - this.resetFilterDataset(); + if (this.dataSetBeforeFilter != null) { + this.resetFilterDataset(); + } if (state.filteredPoints != null) { this.filterDataset(state.filteredPoints); } @@ -517,10 +538,11 @@ export class Projector extends ProjectorPolymer implements this.projectorScatterPlotAdapter.restoreUIFromBookmark(state); { const dimensions = stateGetAccessorDimensions(state); - const accessors = - this.dataSet.getPointAccessors(state.selectedProjection, dimensions); + const components = + data.getProjectionComponents(state.selectedProjection, dimensions); const projection = new Projection( - state.selectedProjection, accessors, dimensions.length, this.dataSet); + state.selectedProjection, components, dimensions.length, + this.dataSet); this.setProjection(projection); } this.notifySelectionChanged(state.selectedPoints); diff --git a/tensorflow/tensorboard/scripts/generate_testdata.py b/tensorflow/tensorboard/scripts/generate_testdata.py index 78739aa7cf..2d305a3b84 100644 --- a/tensorflow/tensorboard/scripts/generate_testdata.py +++ b/tensorflow/tensorboard/scripts/generate_testdata.py @@ -110,7 +110,7 @@ def WriteImageSeries(writer, tag, n_images=1): step = 0 session = tf.Session() p = tf.placeholder("uint8", (1, 4, 4, 3)) - s = tf.image_summary(tag, p) + s = tf.contrib.deprecated.image_summary(tag, p) for _ in xrange(n_images): im = np.random.random_integers(0, 255, (1, 4, 4, 3)) summ = session.run(s, feed_dict={p: im}) @@ -133,7 +133,7 @@ def WriteAudioSeries(writer, tag, n_audio=1): p = tf.placeholder("float32", (frequencies_per_run, duration_frames, num_channels)) - s = tf.audio_summary(tag, p, sample_rate) + s = tf.contrib.deprecated.audio_summary(tag, p, sample_rate) for _ in xrange(n_audio): # Generate a different frequency for each channel to show stereo works. @@ -158,7 +158,7 @@ def GenerateTestData(path): """Generates the test data directory.""" run1_path = os.path.join(path, "run1") os.makedirs(run1_path) - writer1 = tf.train.SummaryWriter(run1_path) + writer1 = tf.summary.FileWriter(run1_path) WriteScalarSeries(writer1, "foo/square", lambda x: x * x) WriteScalarSeries(writer1, "bar/square", lambda x: x * x) WriteScalarSeries(writer1, "foo/sin", math.sin) @@ -171,7 +171,7 @@ def GenerateTestData(path): run2_path = os.path.join(path, "run2") os.makedirs(run2_path) - writer2 = tf.train.SummaryWriter(run2_path) + writer2 = tf.summary.FileWriter(run2_path) WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2) WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3) WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 4951d9da81..502d698468 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -147,8 +147,9 @@ def if_not_windows(a): return select({ "//tensorflow:windows": [], "//conditions:default": a, - }) + }) +# LINT.IfChange def tf_copts(): return (["-DEIGEN_AVOID_STL_ARRAY", "-Iexternal/gemmlowp", @@ -179,6 +180,7 @@ def tf_opts_nortti_if_android(): "-DGOOGLE_PROTOBUF_NO_RTTI", "-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER", ]) +# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate a library for that file. @@ -552,29 +554,6 @@ def tf_kernel_library(name, prefix=None, srcs=None, gpu_srcs=None, hdrs=None, deps = deps, **kwargs) -def tf_kernel_libraries(name, prefixes, deps=None, libs=None, **kwargs): - """Makes one target per prefix, and one target that includes them all. - - Args: - name: The name of the omnibus cc_library target that depends on each - generated tf_kernel_library target. - prefixes: A list of source file name prefixes used to generate individual - libraries. See the definition of tf_kernel_library for details. - deps: The dependencies list associated with each generated target. - libs: Additional tf_kernel_library targets that should be included in the - omnibus cc_library target but not as deps of individual libraries. - This can be used, for example, if a library that was previously - generated by this rule is refactored into a separate definition - in order to specify more or fewer deps for it. - - Other attributes are forwarded to each individual target but not to the - omnibus cc_library target. - """ - for p in prefixes: - tf_kernel_library(name=p, prefix=p, deps=deps, **kwargs) - native.cc_library(name=name, - deps=[":" + p for p in prefixes] + (libs or [])) - # Bazel rules for building swig files. def _py_wrap_cc_impl(ctx): srcs = ctx.files.srcs diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index c3d9e865b1..c544829e5f 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -100,6 +100,11 @@ Status RunBenchmark(const std::vector<InputLayerInfo>& inputs, int_tensor = int_tensor.constant(0.0); break; } + case DT_UINT8: { + auto int_tensor = input_tensor.flat<uint8>(); + int_tensor = int_tensor.constant(0.0); + break; + } default: LOG(FATAL) << "Unsupported input type: " << input.data_type; } diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index e945df2c61..0d890f5684 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -291,6 +291,71 @@ do_buildifier(){ fi } +do_external_licenses_check(){ + echo "Running do_external_licenses_check" + echo "" + + EXTERNAL_LICENSES_CHECK_START_TIME=$(date +'%s') + + EXTERNAL_DEPENDENCIES_FILE="$(mktemp)_external_dependencies.log" + LICENSES_FILE="$(mktemp)_licenses.log" + MISSING_LICENSES_FILE="$(mktemp)_missing_licenses.log" + EXTRA_LICENSES_FILE="$(mktemp)_extra_licenses.log" + + echo "Getting external dependencies for //tensorflow/tools/pip_package:build_pip_package." + bazel query 'attr("licenses", "notice", deps(//tensorflow/tools/pip_package:build_pip_package))' --no_implicit_deps --no_host_deps --keep_going \ + | egrep -v "^//tensorflow" \ + | sed -e 's|:.*||' \ + | sort \ + | uniq 2>&1 \ + | tee ${EXTERNAL_DEPENDENCIES_FILE} + + echo + echo "Getting list of external licenses." + bazel query 'deps(//tensorflow/tools/pip_package:licenses)' --no_implicit_deps --no_host_deps --keep_going \ + | egrep -v "^//tensorflow" \ + | sed -e 's|:.*||' \ + | sort \ + | uniq 2>&1 \ + | tee ${LICENSES_FILE} + + echo + comm -1 -3 ${EXTERNAL_DEPENDENCIES_FILE} ${LICENSES_FILE} 2>&1 | tee ${EXTRA_LICENSES_FILE} + echo + comm -2 -3 ${EXTERNAL_DEPENDENCIES_FILE} ${LICENSES_FILE} 2>&1 | tee ${MISSING_LICENSES_FILE} + + EXTERNAL_LICENSES_CHECK_END_TIME=$(date +'%s') + + echo + echo "do_external_licenses_check took $((${EXTERNAL_LICENSES_CHECK_END_TIME} - ${EXTERNAL_LICENSES_CHECK_START_TIME})) s" + echo + + if [[ -s ${MISSING_LICENSES_FILE} ]] || [[ -s ${EXTRA_LICENSES_FILE} ]] ; then + echo "FAIL: pip package external dependencies vs licenses mismatch." + if [[ -s ${MISSING_LICENSES_FILE} ]] ; then + echo "Missing the licenses for the following external dependencies:" + cat ${MISSING_LICENSES_FILE} + fi + if [[ -s ${EXTRA_LICENSES_FILE} ]] ; then + echo "Please remove the licenses for the following external dependencies:" + cat ${EXTRA_LICENSES_FILE} + fi + rm -rf ${EXTERNAL_DEPENDENCIES_FILE} + rm -rf ${LICENSES_FILE} + rm -rf ${MISSING_LICENSES_FILE} + rm -rf ${EXTRA_LICENSES_FILE} + return 1 + else + echo "PASS: all external licenses included." + rm -rf ${EXTERNAL_DEPENDENCIES_FILE} + rm -rf ${LICENSES_FILE} + rm -rf ${MISSING_LICENSES_FILE} + rm -rf ${EXTRA_LICENSES_FILE} + return 0 + fi +} + + # Run bazel build --nobuild to test the validity of the BUILD files do_bazel_nobuild() { BUILD_TARGET="//tensorflow/..." @@ -311,8 +376,8 @@ do_bazel_nobuild() { } # Supply all sanity step commands and descriptions -SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild") -SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild") +SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_external_licenses_check") +SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "external dependencies licenses check") INCREMENTAL_FLAG="" diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py index f5510c374a..309366e467 100644 --- a/tensorflow/tools/dist_test/python/census_widendeep.py +++ b/tensorflow/tools/dist_test/python/census_widendeep.py @@ -53,8 +53,8 @@ FLAGS = flags.FLAGS # Constants: Data download URLs -TRAIN_DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data" -TEST_DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test" +TRAIN_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data" +TEST_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test" # Define features for the model diff --git a/tensorflow/tools/dist_test/server/Dockerfile.test b/tensorflow/tools/dist_test/server/Dockerfile.test index 22438f3984..e2feb2227b 100644 --- a/tensorflow/tools/dist_test/server/Dockerfile.test +++ b/tensorflow/tools/dist_test/server/Dockerfile.test @@ -63,9 +63,9 @@ RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \ # Download Census data for Wide & Deep test RUN mkdir -p /tmp/census-data RUN curl -o /tmp/census-data/adult.data \ - https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data + http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data RUN curl -o /tmp/census-data/adult.test \ - https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test + http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test # Container entry point ENTRYPOINT ["/var/tf-k8s/server/grpc_tensorflow_server_wrapper.sh"] diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index d6a6d83f9b..d0c813e84f 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -78,12 +78,36 @@ py_binary( deps = ["//tensorflow:tensorflow_py"], ) +filegroup( + name = "licenses", + data = [ + "//third_party/eigen3:LICENSE", + "//third_party/hadoop:LICENSE.txt", + "@boringssl//:LICENSE", + "@com_googlesource_code_re2//:LICENSE", + "@eigen_archive//:COPYING.MPL2", + "@farmhash_archive//:COPYING", + "@gemmlowp//:LICENSE", + "@gif_archive//:COPYING", + "@grpc//:LICENSE", + "@highwayhash//:LICENSE", + "@jpeg//:LICENSE.md", + "@local_config_sycl//sycl:LICENSE.text", + "@nanopb_git//:LICENSE.txt", + "@png_archive//:LICENSE", + "@protobuf//:LICENSE", + "@six_archive//:LICENSE", + "@zlib_archive//:zlib.h", + ], +) + sh_binary( name = "build_pip_package", srcs = ["build_pip_package.sh"], data = select({ "//tensorflow:windows": [":simple_console_for_windows"], "//conditions:default": [ + ":licenses", "MANIFEST.in", "README", "setup.py", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 62e37f6ad4..9ca2ffc509 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -46,7 +46,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): name = "farmhash_archive", url = "http://github.com/google/farmhash/archive/92e897b282426729f4724d91a637596c7e2fe28f.zip", sha256 = "4c626d1f306bda2c6804ab955892f803f5245f4dcaecb4979dc08b091256da54", - strip_prefix = "farmhash-92e897b282426729f4724d91a637596c7e2fe28f/src", + strip_prefix = "farmhash-92e897b282426729f4724d91a637596c7e2fe28f", build_file = str(Label("//:farmhash.BUILD")), ) @@ -90,7 +90,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): name = "gif_archive", url = "http://cdimage.debian.org/mirror/xbmc.org/build-deps/sources/giflib-5.1.4.tar.gz", sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1", - strip_prefix = "giflib-5.1.4/lib", + strip_prefix = "giflib-5.1.4", build_file = str(Label("//:gif.BUILD")), ) @@ -248,3 +248,15 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): name = "zlib", actual = "@zlib_archive//:zlib", ) + + # Make junit-4.12 available as //external:junit + native.http_jar( + name = "junit_jar", + url = "https://github.com/junit-team/junit4/releases/download/r4.12/junit-4.12.jar", + sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a", + ) + + native.bind( + name = "junit", + actual = "@junit_jar//jar", + ) diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD index f697866bde..c2abf78e95 100644 --- a/third_party/eigen3/BUILD +++ b/third_party/eigen3/BUILD @@ -9,6 +9,8 @@ licenses([ "notice", # Portions BSD ]) +exports_files(["LICENSE"]) + cc_library( name = "eigen3", hdrs = glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + [ diff --git a/third_party/hadoop/BUILD b/third_party/hadoop/BUILD index f25208c416..9e98154400 100644 --- a/third_party/hadoop/BUILD +++ b/third_party/hadoop/BUILD @@ -2,6 +2,8 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE.txt"]) + filegroup( name = "all_files", srcs = glob( diff --git a/third_party/hadoop/LICENSE.txt b/third_party/hadoop/LICENSE.txt new file mode 100644 index 0000000000..6ccfd09277 --- /dev/null +++ b/third_party/hadoop/LICENSE.txt @@ -0,0 +1,284 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +APACHE HADOOP SUBCOMPONENTS: + +The Apache Hadoop project contains subcomponents with separate copyright +notices and license terms. Your use of the source code for the these +subcomponents is subject to the terms and conditions of the following +licenses. + +For the org.apache.hadoop.util.bloom.* classes: + +/** + * + * Copyright (c) 2005, European Commission project OneLab under contract + * 034819 (http://www.one-lab.org) + * All rights reserved. + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * - Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * - Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the distribution. + * - Neither the name of the University Catholique de Louvain - UCL + * nor the names of its contributors may be used to endorse or + * promote products derived from this software without specific prior + * written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN + * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +For portions of the native implementation of slicing-by-8 CRC calculation +in src/main/native/src/org/apache/hadoop/util: + +/** + * Copyright 2008,2009,2010 Massachusetts Institute of Technology. + * All rights reserved. Use of this source code is governed by a + * BSD-style license that can be found in the LICENSE file. + */ + + For src/main/native/src/org/apache/hadoop/io/compress/lz4/lz4.c: + +/* + LZ4 - Fast LZ compression algorithm + Copyright (C) 2011, Yann Collet. + BSD License + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 4924a49b57..e1c20e82a7 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -13,6 +13,8 @@ load( "cmake_var_string", ) +package(default_visibility = ["@//tensorflow/compiler/xla:internal"]) + llvm_host_triple = "x86_64-unknown-linux_gnu" llvm_targets = [ @@ -26,6 +28,7 @@ llvm_targets = [ llvm_target_asm_parsers = [ "AArch64", "ARM", + "NVPTX", "PowerPC", "X86", ] @@ -1334,6 +1337,28 @@ cc_library( ) cc_library( + name = "objc_arc", + srcs = glob([ + "lib/Transforms/ObjCARC/*.c", + "lib/Transforms/ObjCARC/*.cpp", + "lib/Transforms/ObjCARC/*.inc", + "lib/Transforms/ObjCARC/*.h", + ]), + hdrs = glob([ + "include/llvm/Transforms/ObjCARC/*.h", + "include/llvm/Transforms/ObjCARC/*.def", + "include/llvm/Transforms/ObjCARC/*.inc", + ]), + deps = [ + ":analysis", + ":config", + ":core", + ":support", + ":transform_utils", + ], +) + +cc_library( name = "orc_jit", srcs = glob([ "lib/ExecutionEngine/Orc/*.c", diff --git a/third_party/pcre.BUILD b/third_party/pcre.BUILD index d9ef246672..68aadd1d40 100644 --- a/third_party/pcre.BUILD +++ b/third_party/pcre.BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # BSD -exports_files(["LICENSE"]) +exports_files(["COPYING"]) cc_library( name = "pcre", diff --git a/third_party/sycl/sycl/BUILD.tpl b/third_party/sycl/sycl/BUILD.tpl index 9e83b1994c..c66a9f007d 100755 --- a/third_party/sycl/sycl/BUILD.tpl +++ b/third_party/sycl/sycl/BUILD.tpl @@ -7,6 +7,8 @@ load("platform", "readlink_command") package(default_visibility = ["//visibility:public"]) +exports_files(["LICENSE.text"]) + config_setting( name = "using_sycl", values = { diff --git a/third_party/sycl/sycl/LICENSE.text.tpl b/third_party/sycl/sycl/LICENSE.text.tpl new file mode 100644 index 0000000000..0c2955c4d7 --- /dev/null +++ b/third_party/sycl/sycl/LICENSE.text.tpl @@ -0,0 +1,268 @@ + +--------------------------------------------------------------------- + +SOFTWARE LICENSE AGREEMENT + +--------------------------------------------------------------------- +--------------------------------------------------------------------- + +By downloading, installing, copying, or otherwise using the +ComputeCpp Community Edition software, including any associated +components, media, printed materials, and electronic documentation +("Software"), the user agrees to the following terms and conditions +of this Software License Agreement ("Agreement"). Please read the +terms of this Agreement carefully before beginning your download, as +pressing the "I AGREE" button at the end of this Agreement will +confirm your assent. If you do not agree to these terms, then +Codeplay Software Limited is unwilling to license the Software to +you; so please press the "CANCEL" button to cancel your download. + + 1. License. Codeplay Software Ltd., a company incorporated in + England and Wales with registered number 04567874 and having its + registered office at Regent House, 316 Beulah Hill, London, + United Kingdom, SE19 3HF ("Codeplay") hereby grants the user, + free of charge, a non-exclusive worldwide license to use and + replicate (but not modify) the Software for any use, whether + commercial or non-commercial, in accordance with this Agreement. + Codeplay reserves all rights to the Software that are not + expressly granted by this Agreement. + 2. Redistribution. The user may copy and redistribute unmodified + copies of only those components of the Software which are + specified below ("Redistributable Components"), in object code + form, as part of the user’s software applications or libraries + ("Applications"). The user acknowledges and agrees that it has no + right to modify the Redistributable Components in any way. Any + use of the Redistributable Components within the user’s + Applications will continue to be subject to the terms and + conditions of this Agreement, and the user must also distribute a + copy of this Agreement and reproduce and include all notices of + copyrights or other proprietary rights in the Software. The + user’s redistribution of the Redistributable Components will not + entitle it to any payment from Codeplay. The user may not + transfer any of its rights or obligations under this Agreement. + ++-------------------------------------------+ +|Redistributable Component|File Name | +|-------------------------+-----------------| +|Runtime (for Linux) |libComputeCpp.so | +|-------------------------+-----------------| +|Runtime (for Windows) |libComputeCpp.dll| ++-------------------------------------------+ + + 3. Restrictions. The user shall not: + + a. circumvent or bypass any technological protection measures in + or relating to the Software; + b. use the Software to perform any unauthorized transfer of + information or for any illegal purpose; + c. de-compile, decrypt, disassemble, hack, emulate, exploit or + reverse-engineer the Software (other than to the limited + extent permitted by law); + d. copy or redistribute any components of the Software that are + not listed in the table of Redistributable Components; + e. publish, rent, lease, sell, export, import, or lend the + Software; + f. represent in any way that it is selling the Software itself + or any license to use the Software, nor refer to Codeplay or + ComputeCpp within its marketing materials, without the + express prior written permission of Codeplay. + 4. Support. Codeplay does not provide any guarantees of support for + the Software to the user. Codeplay will use reasonable endeavours + to respond to users' support requests, for the most recent + release only, via the community support website at https:// + computecpp.codeplay.com. + 5. Intellectual Property. The Software is owned by Codeplay or its + licensors, and is protected by the copyright laws of the United + Kingdom and other countries and international treaty provisions. + Codeplay (and/or its licensors, as the case may be) retains all + copyrights, trade secrets and other proprietary rights in the + Software, including the rights to make and license the use of all + copies. To the extent that any patents owned by Codeplay or its + licensors relate to any component of the Software, the licence + granted to the user in accordance with this Agreement allows for + the lawful use of such patents but only for the purposes of this + Agreement and not further or otherwise. Therefore, the user may + make no copies of the Software, or the written materials that + accompany the Software, or reproduce it in any way, except as set + forth above. + 6. Terms. This Agreement is effective until terminated. Codeplay or + the user may terminate it immediately at any time. Any violation + of the terms of this Agreement by the user will result in + immediate termination by Codeplay. Upon termination, the user + must return or destroy the Software and accompanying materials + and notify Codeplay of its actions by email to info@codeplay.com. + 7. NO WARRANTIES. Codeplay expressly disclaims any warranty for the + Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF + ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE + AND NON-INFRINGEMENT. IN NO EVENT SHALL CODEPLAY BE LIABLE FOR + ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + CONTRACT, DELICT OR TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. In particular, Codeplay provides no guarantees of + application performance on the target hardware. + 8. General. The invalidity of any portion or provision of this + Agreement shall not affect any other portions or provisions. This + Agreement shall be governed by the laws of Scotland. This + Agreement is the complete and exclusive agreement between the + user and Codeplay regarding the Software, and it supersedes any + prior agreement, oral or written, and any other communication + between the user and Codeplay relating to the subject matter of + the Agreement. Any amendment or modification of this Agreement + must be in writing and signed by both parties. If the user does + not agree to the terms of this Agreement, the user must not + install or use the Software. + 9. Third Party Licenses. The following licenses are for third-party + components included in the software. + + a. License for Clang/LLVM compiler technology components: + +============================================================================== + +LLVM Release License + +============================================================================== + +University of Illinois/NCSA + +Open Source License + +Copyright (c) 2007-2014 University of Illinois at Urbana-Champaign. + +All rights reserved. + +Developed by: + + LLVM Team + + University of Illinois at Urbana-Champaign + + http://llvm.org + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal with +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimers. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimers in the + documentation and/or other materials provided with the distribution. + + * Neither the names of the LLVM Team, University of Illinois at + Urbana-Champaign, nor the names of its contributors may be used to + endorse or promote products derived from this Software without specific + prior written permission. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE +SOFTWARE. + +============================================================================== + + b. License for OpenBSD regex components: + +$OpenBSD: COPYRIGHT,v 1.3 2003/06/02 20:18:36 millert Exp $ +Copyright 1992, 1993, 1994 Henry Spencer. All rights reserved. +This software is not subject to any license of the American Telephone +and Telegraph Company or of the Regents of the University of California. +Permission is granted to anyone to use this software for any purpose on +any computer system, and to alter it and redistribute it, subject +to the following restrictions: + +1. The author is not responsible for the consequences of use of this + software, no matter how awful, even if they arise from flaws in it. + +2. The origin of this software must not be misrepresented, either by + explicit claim or by omission. Since few users ever read sources, + credits must appear in the documentation. + +3. Altered versions must be plainly marked as such, and must not be + misrepresented as being the original software. Since few users + ever read sources, credits must appear in the documentation. + +4. This notice may not be removed or altered. + +=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= + +/*- + * Copyright (c) 1994 + * The Regents of the University of California. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * @(#)COPYRIGHT8.1 (Berkeley) 3/16/94 + */ + + c. License for MD5 components: + +/* + * This code is derived from (original license follows): + * + * This is an OpenSSL-compatible implementation of the RSA Data Security, Inc. + * MD5 Message-Digest Algorithm (RFC 1321). + * + * Homepage: + * http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5 + * + * Author: + * Alexander Peslyak, better known as Solar Designer <solar at openwall.com> + * + * This software was written by Alexander Peslyak in 2001. No copyright is + * claimed, and the software is hereby placed in the public domain. + * In case this attempt to disclaim copyright and place the software in the + * public domain is deemed null and void, then the software is + * Copyright (c) 2001 Alexander Peslyak and it is hereby released to the + * general public under the following terms: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted. + * + * There's ABSOLUTELY NO WARRANTY, express or implied. + * + * (This is a heavily cut-down "BSD license".) + * + * This differs from Colin Plumb's older public domain implementation in that + * no exactly 32-bit integer data type is required (any 32-bit or wider + * unsigned integer data type will do), there's no compile-time endianness + * configuration, and the function prototypes match OpenSSL's. No code from + * Colin Plumb's implementation has been reused; this comment merely compares + * the properties of the two independent implementations. + * + * The primary goals of this implementation are portability and ease of use. + * It is meant to be fast, but not as fast as possible. Some known + * optimizations are not included to reduce source code size and avoid + * compile-time configuration. + */ + + diff --git a/third_party/sycl/sycl_configure.bzl b/third_party/sycl/sycl_configure.bzl index 6102ed49c2..38bd7759de 100644 --- a/third_party/sycl/sycl_configure.bzl +++ b/third_party/sycl/sycl_configure.bzl @@ -135,6 +135,7 @@ def _create_dummy_repository(repository_ctx): # Set up BUILD file for sycl/. _file(repository_ctx, "sycl:build_defs.bzl") _tpl(repository_ctx, "sycl:BUILD") + _tpl(repository_ctx, "sycl:LICENSE.text") _tpl(repository_ctx, "sycl:platform.bzl") # Create dummy files for the SYCL toolkit since they are still required by |