diff options
-rw-r--r-- | tensorflow/core/kernels/string_to_hash_bucket_op.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/kernels/string_to_hash_bucket_op.h | 66 | ||||
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 16 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py | 24 | ||||
-rw-r--r-- | tensorflow/python/ops/string_ops.py | 4 |
5 files changed, 120 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc index 539648a676..3a2429d4cd 100644 --- a/tensorflow/core/kernels/string_to_hash_bucket_op.cc +++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc @@ -13,20 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <string> +#include "tensorflow/core/kernels/string_to_hash_bucket_op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/fingerprint.h" namespace tensorflow { -class StringToHashBucketOp : public OpKernel { +// Deprecated class. It also uses `string_tensor` as Op argument instead of +// `input`. +class LegacyStringToHashBuckeOp : public OpKernel { public: - explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + explicit LegacyStringToHashBuckeOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_)); } @@ -55,10 +54,14 @@ class StringToHashBucketOp : public OpKernel { private: int64 num_buckets_; - TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp); + TF_DISALLOW_COPY_AND_ASSIGN(LegacyStringToHashBuckeOp); }; +// StringToHashBucket is deprecated in favor of StringToHashBucketStable. REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU), - StringToHashBucketOp); + LegacyStringToHashBuckeOp); + +REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU), + StringToHashBucketOp<Fingerprint64>); } // namespace tensorflow diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.h b/tensorflow/core/kernels/string_to_hash_bucket_op.h new file mode 100644 index 0000000000..9c6c0a89e4 --- /dev/null +++ b/tensorflow/core/kernels/string_to_hash_bucket_op.h @@ -0,0 +1,66 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_ + +#include <string> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +template <uint64 hash(const string&)> +class StringToHashBucketOp : public OpKernel { + public: + explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat<string>(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat<int64>(); + + typedef decltype(input_flat.size()) Index; + for (Index i = 0; i < input_flat.size(); ++i) { + const uint64 input_hash = hash(input_flat(i)); + const uint64 bucket_id = input_hash % num_buckets_; + // The number of buckets is always in the positive range of int64 so is + // the resulting bucket_id. Casting the bucket_id from uint64 to int64 is + // safe. + output_flat(i) = static_cast<int64>(bucket_id); + } + } + + private: + int64 num_buckets_; + + TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_ diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 93a239c5c4..1a274f1e68 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -17,10 +17,26 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("StringToHashBucketFast") + .Input("input: string") + .Output("output: int64") + .Attr("num_buckets: int >= 1") + .Doc(R"doc( +Converts each string in the input Tensor to its hash mod by a number of buckets. + +The hash function is deterministic on the content of the string within the +process and will never change. However, it is not suitable for cryptography. + +input: The strings to assing a hash bucket. +num_buckets: The number of buckets. +output: A Tensor of the same shape as the input `string_tensor`. +)doc"); + REGISTER_OP("StringToHashBucket") .Input("string_tensor: string") .Output("output: int64") .Attr("num_buckets: int >= 1") + .Deprecated(10, "Use tf.string_to_hash_bucket_fast()") .Doc(R"doc( Converts each string in the input Tensor to its hash mod by a number of buckets. diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py index 17219f3483..379edbfbb0 100644 --- a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py +++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py @@ -23,7 +23,27 @@ import tensorflow as tf class StringToHashBucketOpTest(tf.test.TestCase): - def testStringToOneHashBucket(self): + def testStringToOneHashBucketFast(self): + with self.test_session(): + input_string = tf.placeholder(tf.string) + output = tf.string_to_hash_bucket_fast(input_string, 1) + result = output.eval(feed_dict={input_string: ['a', 'b', 'c']}) + + self.assertAllEqual([0, 0, 0], result) + + def testStringToHashBucketsFast(self): + with self.test_session(): + input_string = tf.placeholder(tf.string) + output = tf.string_to_hash_bucket_fast(input_string, 10) + result = output.eval(feed_dict={input_string: ['a', 'b', 'c', 'd']}) + + # Fingerprint64('a') -> 12917804110809363939 -> mod 10 -> 9 + # Fingerprint64('b') -> 11795596070477164822 -> mod 10 -> 2 + # Fingerprint64('c') -> 11430444447143000872 -> mod 10 -> 2 + # Fingerprint64('d') -> 4470636696479570465 -> mod 10 -> 5 + self.assertAllEqual([9, 2, 2, 5], result) + + def testStringToOneHashBucketLegacyHash(self): with self.test_session(): input_string = tf.placeholder(tf.string) output = tf.string_to_hash_bucket(input_string, 1) @@ -33,7 +53,7 @@ class StringToHashBucketOpTest(tf.test.TestCase): self.assertAllEqual([0, 0, 0], result) - def testStringToHashBuckets(self): + def testStringToHashBucketsLegacyHash(self): with self.test_session(): input_string = tf.placeholder(tf.string) output = tf.string_to_hash_bucket(input_string, 10) diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index d082377bd4..1cd38af3b8 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -18,6 +18,7 @@ String hashing ops take a string input tensor and map each element to an integer. +@@string_to_hash_bucket_fast @@string_to_hash_bucket ## Joining @@ -47,9 +48,11 @@ from tensorflow.python.ops.gen_string_ops import * # pylint: enable=wildcard-import ops.NoGradient("StringToHashBucket") +ops.NoGradient("StringToHashBucketFast") ops.NoGradient("ReduceJoin") ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape) +ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape) @ops.RegisterShape("ReduceJoin") @@ -94,4 +97,3 @@ def _ReduceJoinShape(op): returned_dims.append(dim) return [tensor_shape.TensorShape(returned_dims)] - |