author    Wei Ho <weiho@google.com>    2017-05-31 17:07:05 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-05-31 17:11:03 -0700
commit    458f94c128fa5f72085be9a2489765615e1951a7 (patch)
tree      4309a23782c21194e58fb76c557e8547667ad1e7 /tensorflow/contrib/text
parent    faac0331c2701fa8c2d669089e531e785be8fddd (diff)
Open-source skip-gram ops
PiperOrigin-RevId: 157655970
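
For reference, a minimal usage sketch of the Python API exported by this change (assuming a TensorFlow build that includes tf.contrib.text; TF 1.x graph mode):

import tensorflow as tf

sentence = tf.constant([b"the", b"quick", b"brown", b"fox", b"jumps"])

# The window size for each token is drawn uniformly from [min_skips, max_skips].
tokens, labels = tf.contrib.text.skip_gram_sample(
    sentence, min_skips=1, max_skips=2, emit_self_as_target=False, seed=42)

with tf.Session() as sess:
  token_vals, label_vals = sess.run([tokens, labels])
  # Each (token_vals[i], label_vals[i]) is one skip-gram (token, label) pair.
  print(list(zip(token_vals, label_vals)))
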
Diffstat (limited to 'tensorflow/contrib/text')
-rw-r--r--  tensorflow/contrib/text/BUILD                              | 119
-rw-r--r--  tensorflow/contrib/text/__init__.py                        |  30
-rw-r--r--  tensorflow/contrib/text/kernels/skip_gram_kernels.cc       | 139
-rw-r--r--  tensorflow/contrib/text/ops/skip_gram_ops.cc               |  54
-rw-r--r--  tensorflow/contrib/text/python/ops/__init__.py             |  22
-rw-r--r--  tensorflow/contrib/text/python/ops/skip_gram_ops.py        | 428
-rw-r--r--  tensorflow/contrib/text/python/ops/skip_gram_ops_test.py   | 571
7 files changed, 1363 insertions, 0 deletions
diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD
new file mode 100644
index 0000000000..ff69c4e2cb
--- /dev/null
+++ b/tensorflow/contrib/text/BUILD
@@ -0,0 +1,119 @@
+# Description:
+# Contains parts of TensorFlow that are experimental or unstable and which
+# are not supported.
+
+package(default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow:__subpackages__",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+)
+
+tf_custom_op_py_library(
+ name = "text_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/__init__.py",
+ "python/ops/skip_gram_ops.py",
+ ],
+ dso = [
+ ":python/ops/_skip_gram_ops.so",
+ ],
+ kernels = [
+ ":all_kernels",
+ ":all_ops",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_skip_gram_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_kernel_library(
+ name = "skip_gram_kernels",
+ srcs = ["kernels/skip_gram_kernels.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "all_kernels",
+ deps = [":skip_gram_kernels"],
+)
+
+tf_custom_op_library(
+ name = "python/ops/_skip_gram_ops.so",
+ srcs = [
+ "kernels/skip_gram_kernels.cc",
+ "ops/skip_gram_ops.cc",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["skip_gram_ops"],
+)
+
+cc_library(
+ name = "all_ops",
+ deps = [":skip_gram_ops_op_lib"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_skip_gram_ops",
+ out = "python/ops/gen_skip_gram_ops.py",
+ deps = [":skip_gram_ops_op_lib"],
+)
+
+py_test(
+ name = "skip_gram_ops_test",
+ size = "medium",
+ srcs = ["python/ops/skip_gram_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":text_py",
+ "//tensorflow/contrib/lookup:lookup_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:training",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/contrib/text/__init__.py b/tensorflow/contrib/text/__init__.py
new file mode 100644
index 0000000000..35e6623189
--- /dev/null
+++ b/tensorflow/contrib/text/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Text-processing ops.
+
+@@skip_gram_sample
+@@skip_gram_sample_with_text_vocab
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.text.python.ops import *
+# pylint: enable=unused-import,wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/text/kernels/skip_gram_kernels.cc b/tensorflow/contrib/text/kernels/skip_gram_kernels.cc
new file mode 100644
index 0000000000..3cd0b5f72b
--- /dev/null
+++ b/tensorflow/contrib/text/kernels/skip_gram_kernels.cc
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SkipGramGenerateCandidatesOp : public OpKernel {
+ public:
+ explicit SkipGramGenerateCandidatesOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, generator_.Init(context));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
+ const auto input = input_tensor->flat<T>();
+
+ const Tensor* min_skips_tensor;
+ OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
+ const int min_skips = *(min_skips_tensor->scalar<int>().data());
+ const Tensor* max_skips_tensor;
+ OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
+ const int max_skips = *(max_skips_tensor->scalar<int>().data());
+
+ OP_REQUIRES(
+ context, min_skips >= 0 && max_skips >= 0,
+ errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
+ OP_REQUIRES(context, min_skips <= max_skips,
+ errors::InvalidArgument("min_skips must be <= max_skips."));
+
+ const Tensor* start_tensor;
+ OP_REQUIRES_OK(context, context->input("start", &start_tensor));
+ const int start = *(start_tensor->scalar<int>().data());
+ const Tensor* limit_tensor;
+ OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
+ const int limit = *(limit_tensor->scalar<int>().data());
+ const int end =
+ limit < 0 ? input.size()
+ : std::min(start + limit, static_cast<int>(input.size()));
+
+ const Tensor* emit_self_tensor;
+ OP_REQUIRES_OK(context,
+ context->input("emit_self_as_target", &emit_self_tensor));
+ const bool emit_self_as_target = *(emit_self_tensor->scalar<bool>().data());
+
+ std::vector<T> tokens;
+ std::vector<T> labels;
+
+ // Reserve one random number for each token between start and end.
+ random::PhiloxRandom local_gen =
+ generator_.ReserveSamples32(end - start + 1);
+ random::SimplePhilox rng(&local_gen);
+
+ // For each token in the sentence, pick a random skip, then generate
+ // (token, label) pairs for all labels whose distances from the token are
+ // within the range [-skip, skip].
+ for (int i = start; i < end; ++i) {
+ const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
+ for (int j = -skips; j <= skips; ++j) {
+ if ((i + j < start) || (i + j >= end) ||
+ (j == 0 && !emit_self_as_target)) {
+ continue;
+ }
+ tokens.push_back(input(i));
+ labels.push_back(input(i + j));
+ }
+ }
+
+ Tensor* tokens_output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ "tokens", TensorShape({static_cast<int>(tokens.size())}),
+ &tokens_output));
+ Tensor* labels_output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ "labels", TensorShape({static_cast<int>(labels.size())}),
+ &labels_output));
+ OP_REQUIRES(
+ context, tokens_output->IsSameSize(*labels_output),
+ errors::Internal(strings::StrCat(
+ "Mismatch between tokens_output shape of ",
+ tokens_output->shape().DebugString(),
+ " and labels_output shape of ",
+ labels_output->shape().DebugString(),
+ ". This should never happen - contact ami-team@ if it does.")));
+
+ // Copies results to output tensors.
+ for (int i = 0; i < tokens.size(); ++i) {
+ tokens_output->vec<T>()(i) = tokens[i];
+ labels_output->vec<T>()(i) = labels[i];
+ }
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("SkipGramGenerateCandidates") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T"), \
+ SkipGramGenerateCandidatesOp<type>)
+
+REGISTER_KERNEL(string);
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(int16);
+// TODO(weiho): Add other types if the need arises.
+
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/text/ops/skip_gram_ops.cc b/tensorflow/contrib/text/ops/skip_gram_ops.cc
new file mode 100644
index 0000000000..9a7a20d81a
--- /dev/null
+++ b/tensorflow/contrib/text/ops/skip_gram_ops.cc
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+REGISTER_OP("SkipGramGenerateCandidates")
+ .Input("input_tensor: T")
+ .Input("min_skips: int32")
+ .Input("max_skips: int32")
+ .Input("start: int32")
+ .Input("limit: int32")
+ .Input("emit_self_as_target: bool")
+ .Output("tokens: T")
+ .Output("labels: T")
+ .Attr("T: type")
+ // The seed attributes are needed by GuardedPhiloxRandom
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // input_tensor must be of rank-1.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
+ // All other args must be scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+
+ // Due to possible randomness in selecting skips, we only know that the
+ // outputs will be of rank-1, but not their sizes.
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ c->set_output(1, c->Vector(c->UnknownDim()));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Generates skip-gram token and label paired Tensors from the input tensor.
+See docs for the public-facing skip_gram_sample() Python op for more details.
+)doc");
+} // namespace tensorflow
diff --git a/tensorflow/contrib/text/python/ops/__init__.py b/tensorflow/contrib/text/python/ops/__init__.py
new file mode 100644
index 0000000000..bb47266dd2
--- /dev/null
+++ b/tensorflow/contrib/text/python/ops/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various contrib ops related to text-processing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample
+from tensorflow.contrib.text.python.ops.skip_gram_ops import skip_gram_sample_with_text_vocab
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops.py b/tensorflow/contrib/text/python/ops/skip_gram_ops.py
new file mode 100644
index 0000000000..410ee517e0
--- /dev/null
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops.py
@@ -0,0 +1,428 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+
+from tensorflow.contrib import lookup
+from tensorflow.contrib.text.python.ops import gen_skip_gram_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.training import input as input_ops
+
+_checkpoint_ops_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_skip_gram_ops.so"))
+
+ops.NotDifferentiable("SkipGramGenerateCandidates")
+
+
+def skip_gram_sample(input_tensor,
+ min_skips=1,
+ max_skips=5,
+ start=0,
+ limit=-1,
+ emit_self_as_target=False,
+ vocab_freq_table=None,
+ vocab_min_count=None,
+ vocab_subsampling=None,
+ corpus_size=None,
+ batch_size=None,
+ batch_capacity=None,
+ seed=None,
+ name=None):
+ """Generates skip-gram token and label paired Tensors from the input tensor.
+
+ Generates skip-gram `("token", "label")` pairs using each element in the
+ rank-1 `input_tensor` as a token. The window size used for each token will be
+ randomly selected from the range specified by `[min_skips, max_skips]`,
+ inclusive. See https://arxiv.org/abs/1301.3781 for more details about
+ skip-gram.
+
+ For example, given `input_tensor = ["the", "quick", "brown", "fox", "jumps"]`,
+ `min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, the output
+ `(tokens, labels)` pairs for the token "quick" will be randomly selected from
+ either `(tokens=["quick", "quick"], labels=["the", "brown"])` for 1 skip, or
+ `(tokens=["quick", "quick", "quick"], labels=["the", "brown", "fox"])` for 2
+ skips.
+
+ If `emit_self_as_target = True`, each token will also be emitted as a label
+ for itself. From the previous example, the output will be either
+ `(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` for 1
+ skip, or `(tokens=["quick", "quick", "quick", "quick"], labels=["the",
+ "quick", "brown", "fox"])` for 2 skips.
+
+ The same process is repeated for each element of `input_tensor` and
+ concatenated together into the two output rank-1 `Tensors` (one for all the
+ tokens, another for all the labels).
+
+ If `vocab_freq_table` is specified, tokens in `input_tensor` that are not
+ present in the vocabulary are discarded. Tokens whose frequency counts are
+ below `vocab_min_count` are also discarded. Tokens whose frequency proportions
+ in the corpus exceed `vocab_subsampling` may be randomly down-sampled. See
+ Eq. 5 in http://arxiv.org/abs/1310.4546 for more details about subsampling.
+
+ Due to the random window sizes used for each token, the lengths of the outputs
+ are non-deterministic, unless `batch_size` is specified to batch the outputs
+ to always return `Tensors` of length `batch_size`.
+
+ Args:
+ input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
+ min_skips: `int` or scalar `Tensor` specifying the minimum window size to
+ randomly use for each token. Must be >= 0 and <= `max_skips`. If
+ `min_skips` and `max_skips` are both 0, the only label emitted will be
+ the token itself when `emit_self_as_target = True`; otherwise, no labels
+ are emitted.
+ max_skips: `int` or scalar `Tensor` specifying the maximum window size to
+ randomly use for each token. Must be >= 0.
+ start: `int` or scalar `Tensor` specifying the position in
+ `input_tensor` from which to start generating skip-gram candidates.
+ limit: `int` or scalar `Tensor` specifying the maximum number of
+ elements in `input_tensor` to use in generating skip-gram candidates. -1
+ means to use the rest of the `Tensor` after `start`.
+ emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
+ each token as a label for itself.
+ vocab_freq_table: (Optional) A lookup table (subclass of
+ `lookup.InitializableLookupTableBase`) that maps tokens to their raw
+ frequency counts. If specified, any token in `input_tensor` that is not
+ found in `vocab_freq_table` will be filtered out before generating
+ skip-gram candidates. While this will typically map to integer raw
+ frequency counts, it could also map to float frequency proportions.
+ `vocab_min_count` and `corpus_size` should be in the same units as this.
+ vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
+ minimum frequency threshold (from `vocab_freq_table`) for a token to be
+ kept in `input_tensor`. If this is specified, `vocab_freq_table` must also
+ be specified - and they should both be in the same units.
+ vocab_subsampling: (Optional) `float` specifying frequency proportion
+ threshold for tokens from `input_tensor`. Tokens that occur more
+ frequently (based on the ratio of the token's `vocab_freq_table` value to
+ the `corpus_size`) will be randomly down-sampled. Reasonable starting
+ values may be around 1e-3 or 1e-5. If this is specified, both
+ `vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
+ in http://arxiv.org/abs/1310.4546 for more details.
+ corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
+ total number of tokens in the corpus (e.g., sum of all the frequency
+ counts of `vocab_freq_table`). Used with `vocab_subsampling` for
+ down-sampling frequently occurring tokens. If this is specified,
+ `vocab_freq_table` and `vocab_subsampling` must also be specified.
+ batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
+ batch_capacity: (Optional) `int` specifying batch capacity for the queue
+ used for batching returned `Tensors`. Only has an effect if
+ `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
+ seed: (Optional) `int` used to create a random seed for window size and
+ subsampling. See `set_random_seed` docs for behavior.
+ name: (Optional) A `string` name or a name scope for the operations.
+
+ Returns:
+ A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
+ rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
+ length `batch_size`; if `batch_size` is not specified, they will be of
+ random length, though they will be in sync with each other as long as they
+ are evaluated together.
+
+ Raises:
+ ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`,
+ `vocab_subsampling`, or `corpus_size` is specified. If `vocab_subsampling`
+ and `corpus_size` are not both present or both absent.
+ """
+
+ if vocab_freq_table is None and (vocab_min_count is not None or
+ vocab_subsampling is not None or
+ corpus_size is not None):
+ raise ValueError(
+ "vocab_freq_table is not provided, but vocab_min_count={}, "
+ "vocab_subsampling={}, or corpus_size={} is not None. These settings "
+ "are useless without a vocab_freq_table.".format(
+ vocab_min_count, vocab_subsampling, corpus_size))
+
+ if (vocab_subsampling is None) != (corpus_size is None):
+ raise ValueError(
+ "vocab_subsampling is {} while corpus_size is {} - both must be "
+ "provided in order for subsampling to work.".format(
+ vocab_subsampling, corpus_size))
+
+ with ops.name_scope(
+ name,
+ "skip_gram_sample",
+ values=[input_tensor, min_skips, max_skips, start, limit]):
+
+ input_tensor = _filter_input(
+ input_tensor=input_tensor,
+ vocab_freq_table=vocab_freq_table,
+ vocab_min_count=vocab_min_count,
+ vocab_subsampling=vocab_subsampling,
+ corpus_size=corpus_size,
+ seed=seed)
+
+ seed1, seed2 = random_seed.get_seed(seed)
+ tokens, labels = gen_skip_gram_ops.skip_gram_generate_candidates(
+ input_tensor=input_tensor,
+ min_skips=min_skips,
+ max_skips=max_skips,
+ start=start,
+ limit=limit,
+ emit_self_as_target=emit_self_as_target,
+ # Note that seed here should be seed1! This is due to
+ # GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2".
+ seed=seed1,
+ seed2=seed2)
+
+ # TODO(weiho): If the need arises, add support for sparse input_tensor that
+ # figures out sentence boundaries, then calls
+ # skip_gram_generate_candidates() on each sentence.
+
+ # Batches the (tokens, labels) outputs so that they will be of deterministic
+ # batch_size, to facilitate feeding them into the rest of the network.
+ if batch_size is not None and batch_size > 0:
+ batch_capacity = (batch_capacity
+ if (batch_capacity is not None and batch_capacity > 0)
+ else 100 * batch_size)
+ return input_ops.batch(
+ [tokens, labels],
+ batch_size,
+ capacity=batch_capacity,
+ enqueue_many=True)
+
+ return tokens, labels
+
+
+def skip_gram_sample_with_text_vocab(input_tensor,
+ vocab_freq_file,
+ vocab_token_index=0,
+ vocab_token_dtype=dtypes.string,
+ vocab_freq_index=1,
+ vocab_freq_dtype=dtypes.float64,
+ vocab_delimiter=",",
+ vocab_min_count=0,
+ vocab_subsampling=None,
+ min_skips=1,
+ max_skips=5,
+ start=0,
+ limit=-1,
+ emit_self_as_target=False,
+ batch_size=None,
+ batch_capacity=None,
+ seed=None,
+ name=None):
+ """Skip-gram sampling with a text vocabulary file.
+
+ Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The
+ vocabulary file is expected to be a plain-text file, with lines of
+ `vocab_delimiter`-separated columns. The `vocab_token_index` column should
+ contain the vocabulary term, while the `vocab_freq_index` column should
+ contain the number of times that term occurs in the corpus. For example, with
+ a text vocabulary file of:
+
+ ```
+ bonjour,fr,42
+ hello,en,777
+ hola,es,99
+ ```
+
+ You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
+ `vocab_freq_index=2`.
+
+ See `skip_gram_sample()` documentation for more details about the skip-gram
+ sampling process.
+
+ Args:
+ input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
+ vocab_freq_file: `string` specifying full file path to the text vocab file.
+ vocab_token_index: `int` specifying which column in the text vocab file
+ contains the tokens.
+ vocab_token_dtype: `DType` specifying the format of the tokens in the text
+ vocab file.
+ vocab_freq_index: `int` specifying which column in the text vocab file
+ contains the frequency counts of the tokens.
+ vocab_freq_dtype: `DType` specifying the format of the frequency counts in
+ the text vocab file.
+ vocab_delimiter: `string` specifying the delimiter used in the text vocab
+ file.
+ vocab_min_count: `int`, `float`, or scalar `Tensor` specifying
+ minimum frequency threshold (from `vocab_freq_file`) for a token to be
+ kept in `input_tensor`. This should correspond with `vocab_freq_dtype`.
+ vocab_subsampling: (Optional) `float` specifying frequency proportion
+ threshold for tokens from `input_tensor`. Tokens that occur more
+ frequently will be randomly down-sampled. Reasonable starting values may
+ be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
+ more details.
+ min_skips: `int` or scalar `Tensor` specifying the minimum window size to
+ randomly use for each token. Must be >= 0 and <= `max_skips`. If
+ `min_skips` and `max_skips` are both 0, the only label emitted will be
+ the token itself.
+ max_skips: `int` or scalar `Tensor` specifying the maximum window size to
+ randomly use for each token. Must be >= 0.
+ start: `int` or scalar `Tensor` specifying the position in `input_tensor`
+ from which to start generating skip-gram candidates.
+ limit: `int` or scalar `Tensor` specifying the maximum number of elements in
+ `input_tensor` to use in generating skip-gram candidates. -1 means to use
+ the rest of the `Tensor` after `start`.
+ emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
+ each token as a label for itself.
+ batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
+ batch_capacity: (Optional) `int` specifying batch capacity for the queue
+ used for batching returned `Tensors`. Only has an effect if
+ `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
+ seed: (Optional) `int` used to create a random seed for window size and
+ subsampling. See
+ [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
+ for behavior.
+ name: (Optional) A `string` name or a name scope for the operations.
+
+ Returns:
+ A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
+ rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
+ length `batch_size`; if `batch_size` is not specified, they will be of
+ random length, though they will be in sync with each other as long as they
+ are evaluated together.
+
+ Raises:
+ ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or
+ exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index`
+ and `vocab_freq_index` are both set to the same column. If any token in
+ `vocab_freq_file` has a negative frequency.
+ """
+
+ if vocab_token_index < 0 or vocab_freq_index < 0:
+ raise ValueError(
+ "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
+ format(vocab_token_index, vocab_freq_index))
+ if vocab_token_index == vocab_freq_index:
+ raise ValueError(
+ "vocab_token_index and vocab_freq_index should be different, but are "
+ "both {}.".format(vocab_token_index))
+
+ # Iterates through the vocab file and calculates the number of vocab terms as
+ # well as the total corpus size (by summing the frequency counts of all the
+ # vocab terms).
+ corpus_size = 0.0
+ vocab_size = 0
+ with gfile.GFile(vocab_freq_file, mode="r") as f:
+ reader = csv.reader(f, delimiter=vocab_delimiter)
+ for row in reader:
+ if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
+ raise ValueError(
+ "Row in vocab file only has {} columns, so vocab_token_index={} or "
+ "vocab_freq_index={} is out of bounds. Row content: {}".format(
+ len(row), vocab_token_index, vocab_freq_index, row))
+ vocab_size += 1
+ freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
+ if freq < 0:
+ raise ValueError(
+ "Row in vocab file has negative frequency of {}. Row content: {}".
+ format(freq, row))
+ # Note: tokens whose frequencies are below vocab_min_count will still
+ # contribute to the total corpus size used for vocab subsampling.
+ corpus_size += freq
+
+ vocab_freq_table = lookup.HashTable(
+ lookup.TextFileInitializer(
+ filename=vocab_freq_file,
+ key_dtype=vocab_token_dtype,
+ key_index=vocab_token_index,
+ value_dtype=vocab_freq_dtype,
+ value_index=vocab_freq_index,
+ vocab_size=vocab_size,
+ delimiter=vocab_delimiter),
+ # For vocab terms not in vocab file, use a default value of -1.
+ default_value=-1)
+
+ return skip_gram_sample(
+ input_tensor,
+ min_skips=min_skips,
+ max_skips=max_skips,
+ start=start,
+ limit=limit,
+ emit_self_as_target=emit_self_as_target,
+ vocab_freq_table=vocab_freq_table,
+ vocab_min_count=vocab_min_count,
+ vocab_subsampling=vocab_subsampling,
+ # corpus_size is not used unless vocab_subsampling is specified.
+ corpus_size=None if vocab_subsampling is None else corpus_size,
+ batch_size=batch_size,
+ batch_capacity=batch_capacity,
+ seed=seed,
+ name=name)
+
+
+def _filter_input(input_tensor, vocab_freq_table, vocab_min_count,
+ vocab_subsampling, corpus_size, seed):
+ """Filters input tensor based on vocab freq, threshold, and subsampling."""
+ if vocab_freq_table is None:
+ return input_tensor
+
+ if not isinstance(vocab_freq_table, lookup.InitializableLookupTableBase):
+ raise ValueError(
+ "vocab_freq_table must be a subclass of "
+ "InitializableLookupTableBase (such as HashTable) instead of type "
+ "{}.".format(type(vocab_freq_table)))
+
+ with ops.name_scope(
+ "filter_vocab", values=[vocab_freq_table, input_tensor, vocab_min_count]):
+ freq = vocab_freq_table.lookup(input_tensor)
+ # Filters out elements in input_tensor that are not found in
+ # vocab_freq_table (table returns a default value of -1 specified above when
+ # an element is not found).
+ mask = math_ops.not_equal(freq, vocab_freq_table.default_value)
+
+ # Filters out elements whose vocab frequencies are less than the threshold.
+ if vocab_min_count is not None:
+ cast_threshold = math_ops.cast(vocab_min_count, freq.dtype)
+ mask = math_ops.logical_and(mask,
+ math_ops.greater_equal(freq, cast_threshold))
+
+ input_tensor = array_ops.boolean_mask(input_tensor, mask)
+ freq = array_ops.boolean_mask(freq, mask)
+
+ if not vocab_subsampling:
+ return input_tensor
+
+ if vocab_subsampling < 0 or vocab_subsampling > 1:
+ raise ValueError(
+ "Invalid vocab_subsampling={} - it should be within range [0, 1].".
+ format(vocab_subsampling))
+
+ # Subsamples the input tokens based on vocabulary frequency and
+ # vocab_subsampling threshold (i.e., randomly discards commonly appearing
+ # tokens).
+ with ops.name_scope(
+ "subsample_vocab", values=[input_tensor, freq, vocab_subsampling]):
+ corpus_size = math_ops.cast(corpus_size, dtypes.float64)
+ freq = math_ops.cast(freq, dtypes.float64)
+ vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64)
+
+ # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
+ # supposed to correspond to Eq. 5 in http://arxiv.org/abs/1310.4546.
+ keep_prob = ((math_ops.sqrt(freq /
+ (vocab_subsampling * corpus_size)) + 1.0) *
+ (vocab_subsampling * corpus_size / freq))
+ random_prob = random_ops.random_uniform(
+ array_ops.shape(freq),
+ minval=0,
+ maxval=1,
+ dtype=dtypes.float64,
+ seed=seed)
+
+ mask = math_ops.less_equal(random_prob, keep_prob)
+ return array_ops.boolean_mask(input_tensor, mask)
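
For reference, the subsampling keep-probability computed in _filter_input above follows Eq. 5 of http://arxiv.org/abs/1310.4546. A standalone sketch of the same arithmetic (plain NumPy; the counts below are illustrative and mirror the fixtures used in the tests that follow):

import numpy as np

corpus_size = 100.0          # total token count across the vocabulary
vocab_subsampling = 0.05     # subsampling threshold
freq = np.array([40.0, 8.0, 30.0, 20.0])  # raw counts for "and", "life", "the", "to"

# Same expression as in _filter_input, written as (sqrt(r) + 1) / r with
# r = freq / (vocab_subsampling * corpus_size).
ratio = freq / (vocab_subsampling * corpus_size)
keep_prob = (np.sqrt(ratio) + 1.0) / ratio

# A token survives subsampling when a uniform [0, 1) draw is <= keep_prob.
kept = np.random.uniform(size=freq.shape) <= keep_prob
print(keep_prob)  # approximately [0.48, 1.42, 0.57, 0.75]
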
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
new file mode 100644
index 0000000000..d989942f73
--- /dev/null
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -0,0 +1,571 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Skip-gram sampling ops tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import csv
+import os
+
+from tensorflow.contrib import lookup
+from tensorflow.contrib import text
+from tensorflow.contrib.text.python.ops import skip_gram_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import coordinator
+from tensorflow.python.training import queue_runner_impl
+
+
+class SkipGramOpsTest(test.TestCase):
+
+ def _split_tokens_labels(self, output):
+ tokens = [x[0] for x in output]
+ labels = [x[1] for x in output]
+ return tokens, labels
+
+ def test_skip_gram_sample_skips_2(self):
+ """Tests skip-gram with min_skips = max_skips = 2."""
+ input_tensor = constant_op.constant(
+ [b"the", b"quick", b"brown", b"fox", b"jumps"])
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=2, max_skips=2)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"quick"),
+ (b"the", b"brown"),
+ (b"quick", b"the"),
+ (b"quick", b"brown"),
+ (b"quick", b"fox"),
+ (b"brown", b"the"),
+ (b"brown", b"quick"),
+ (b"brown", b"fox"),
+ (b"brown", b"jumps"),
+ (b"fox", b"quick"),
+ (b"fox", b"brown"),
+ (b"fox", b"jumps"),
+ (b"jumps", b"brown"),
+ (b"jumps", b"fox"),
+ ])
+ with self.test_session():
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def test_skip_gram_sample_emit_self(self):
+ """Tests skip-gram with emit_self_as_target = True."""
+ input_tensor = constant_op.constant(
+ [b"the", b"quick", b"brown", b"fox", b"jumps"])
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"the"),
+ (b"the", b"quick"),
+ (b"the", b"brown"),
+ (b"quick", b"the"),
+ (b"quick", b"quick"),
+ (b"quick", b"brown"),
+ (b"quick", b"fox"),
+ (b"brown", b"the"),
+ (b"brown", b"quick"),
+ (b"brown", b"brown"),
+ (b"brown", b"fox"),
+ (b"brown", b"jumps"),
+ (b"fox", b"quick"),
+ (b"fox", b"brown"),
+ (b"fox", b"fox"),
+ (b"fox", b"jumps"),
+ (b"jumps", b"brown"),
+ (b"jumps", b"fox"),
+ (b"jumps", b"jumps"),
+ ])
+ with self.test_session():
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def test_skip_gram_sample_skips_0(self):
+ """Tests skip-gram with min_skips = max_skips = 0."""
+ input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
+
+ # If emit_self_as_target is False (default), output will be empty.
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
+ with self.test_session():
+ self.assertEqual(0, tokens.eval().size)
+ self.assertEqual(0, labels.eval().size)
+
+ # If emit_self_as_target is True, each token will be its own label.
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"the"),
+ (b"quick", b"quick"),
+ (b"brown", b"brown"),
+ ])
+ with self.test_session():
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def test_skip_gram_sample_skips_exceed_length(self):
+ """Tests skip-gram when min/max_skips exceed length of input."""
+ input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=100, max_skips=100)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"quick"),
+ (b"the", b"brown"),
+ (b"quick", b"the"),
+ (b"quick", b"brown"),
+ (b"brown", b"the"),
+ (b"brown", b"quick"),
+ ])
+ with self.test_session():
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def test_skip_gram_sample_start_limit(self):
+ """Tests skip-gram over a limited portion of the input."""
+ input_tensor = constant_op.constant(
+ [b"foo", b"the", b"quick", b"brown", b"bar"])
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=1, max_skips=1, start=1, limit=3)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"quick"),
+ (b"quick", b"the"),
+ (b"quick", b"brown"),
+ (b"brown", b"quick"),
+ ])
+ with self.test_session():
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def test_skip_gram_sample_limit_exceeds(self):
+ """Tests skip-gram when limit exceeds the length of the input."""
+ input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"])
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=1, max_skips=1, start=1, limit=100)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"quick"),
+ (b"quick", b"the"),
+ (b"quick", b"brown"),
+ (b"brown", b"quick"),
+ ])
+ with self.test_session():
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def test_skip_gram_sample_random_skips(self):
+ """Tests skip-gram with min_skips != max_skips, with random output."""
+ # The number of outputs is non-deterministic in this case, so set random
+ # seed to help ensure the outputs remain constant for this test case.
+ random_seed.set_random_seed(42)
+
+ input_tensor = constant_op.constant(
+ [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=1, max_skips=2, seed=9)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"quick"),
+ (b"the", b"brown"),
+ (b"quick", b"the"),
+ (b"quick", b"brown"),
+ (b"quick", b"fox"),
+ (b"brown", b"the"),
+ (b"brown", b"quick"),
+ (b"brown", b"fox"),
+ (b"brown", b"jumps"),
+ (b"fox", b"brown"),
+ (b"fox", b"jumps"),
+ (b"jumps", b"fox"),
+ (b"jumps", b"over"),
+ (b"over", b"fox"),
+ (b"over", b"jumps"),
+ ])
+ with self.test_session() as sess:
+ tokens_eval, labels_eval = sess.run([tokens, labels])
+ self.assertAllEqual(expected_tokens, tokens_eval)
+ self.assertAllEqual(expected_labels, labels_eval)
+
+ def test_skip_gram_sample_random_skips_default_seed(self):
+ """Tests outputs are still random when no op-level seed is specified."""
+ # This is needed since tests set a graph-level seed by default. We want to
+ # explicitly avoid setting both graph-level seed and op-level seed, to
+ # simulate behavior under non-test settings when the user doesn't provide a
+ # seed to us. This results in random_seed.get_seed() returning None for both
+ # seeds, forcing the C++ kernel to execute its default seed logic.
+ random_seed.set_random_seed(None)
+
+ # Uses an input tensor with 10 words, with possible skip ranges in [1,
+ # 5]. Thus, the probability that two random samplings would result in the
+ # same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being
+ # flaky).
+ input_tensor = constant_op.constant([str(x) for x in range(10)])
+
+ # Do not provide an op-level seed here!
+ tokens_1, labels_1 = text.skip_gram_sample(
+ input_tensor, min_skips=1, max_skips=5)
+ tokens_2, labels_2 = text.skip_gram_sample(
+ input_tensor, min_skips=1, max_skips=5)
+
+ with self.test_session() as sess:
+ tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
+ [tokens_1, labels_1, tokens_2, labels_2])
+
+ if len(tokens_1_eval) == len(tokens_2_eval):
+ self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist())
+ if len(labels_1_eval) == len(labels_2_eval):
+ self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist())
+
+ def test_skip_gram_sample_batch(self):
+ """Tests skip-gram with batching."""
+ input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"])
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=1, max_skips=1, batch_size=3)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"quick"),
+ (b"quick", b"the"),
+ (b"quick", b"brown"),
+ (b"brown", b"quick"),
+ (b"brown", b"fox"),
+ (b"fox", b"brown"),
+ ])
+ with self.test_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
+
+ tokens_eval, labels_eval = sess.run([tokens, labels])
+ self.assertAllEqual(expected_tokens[:3], tokens_eval)
+ self.assertAllEqual(expected_labels[:3], labels_eval)
+ tokens_eval, labels_eval = sess.run([tokens, labels])
+ self.assertAllEqual(expected_tokens[3:6], tokens_eval)
+ self.assertAllEqual(expected_labels[3:6], labels_eval)
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_skip_gram_sample_non_string_input(self):
+ """Tests skip-gram with non-string input."""
+ input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16)
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=1, max_skips=1)
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (1, 2),
+ (2, 1),
+ (2, 3),
+ (3, 2),
+ ])
+ with self.test_session():
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def test_skip_gram_sample_errors(self):
+ """Tests various errors raised by skip_gram_sample()."""
+ input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
+
+ invalid_skips = (
+ # min_skips and max_skips must be >= 0.
+ (-1, 2),
+ (1, -2),
+ # min_skips must be <= max_skips.
+ (2, 1))
+ for min_skips, max_skips in invalid_skips:
+ tokens, labels = text.skip_gram_sample(
+ input_tensor, min_skips=min_skips, max_skips=max_skips)
+ with self.test_session() as sess, self.assertRaises(
+ errors.InvalidArgumentError):
+ sess.run([tokens, labels])
+
+ # input_tensor must be of rank 1.
+ with self.assertRaises(ValueError):
+ invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
+ text.skip_gram_sample(invalid_tensor)
+
+ # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
+ # or corpus_size is specified.
+ dummy_input = constant_op.constant([""])
+ with self.assertRaises(ValueError):
+ text.skip_gram_sample(
+ dummy_input, vocab_freq_table=None, vocab_min_count=1)
+ with self.assertRaises(ValueError):
+ text.skip_gram_sample(
+ dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
+ with self.assertRaises(ValueError):
+ text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
+ with self.assertRaises(ValueError):
+ text.skip_gram_sample(
+ dummy_input,
+ vocab_freq_table=None,
+ vocab_subsampling=1e-5,
+ corpus_size=100)
+
+ # vocab_subsampling and corpus_size must both be present or absent.
+ dummy_table = lookup.HashTable(
+ lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
+ with self.assertRaises(ValueError):
+ text.skip_gram_sample(
+ dummy_input,
+ vocab_freq_table=dummy_table,
+ vocab_subsampling=None,
+ corpus_size=100)
+ with self.assertRaises(ValueError):
+ text.skip_gram_sample(
+ dummy_input,
+ vocab_freq_table=dummy_table,
+ vocab_subsampling=1e-5,
+ corpus_size=None)
+
+ def test_filter_input_filter_vocab(self):
+ """Tests input filtering based on vocab frequency table and thresholds."""
+ input_tensor = constant_op.constant(
+ [b"the", b"answer", b"to", b"life", b"and", b"universe"])
+ keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
+ values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
+ vocab_freq_table = lookup.HashTable(
+ lookup.KeyValueTensorInitializer(keys, values), -1)
+
+ with self.test_session():
+ vocab_freq_table.init.run()
+
+ # No vocab_freq_table specified - output should be the same as input.
+ no_table_output = skip_gram_ops._filter_input(
+ input_tensor=input_tensor,
+ vocab_freq_table=None,
+ vocab_min_count=None,
+ vocab_subsampling=None,
+ corpus_size=None,
+ seed=None)
+ self.assertAllEqual(input_tensor.eval(), no_table_output.eval())
+
+ # vocab_freq_table specified, but no vocab_min_count - output should have
+ # filtered out tokens not in the table (b"answer").
+ table_output = skip_gram_ops._filter_input(
+ input_tensor=input_tensor,
+ vocab_freq_table=vocab_freq_table,
+ vocab_min_count=None,
+ vocab_subsampling=None,
+ corpus_size=None,
+ seed=None)
+ self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"],
+ table_output.eval())
+
+ # vocab_freq_table and vocab_min_count specified - output should have
+ # filtered out tokens whose frequencies are below the threshold
+ # (b"and": 0, b"life": 1).
+ threshold_output = skip_gram_ops._filter_input(
+ input_tensor=input_tensor,
+ vocab_freq_table=vocab_freq_table,
+ vocab_min_count=2,
+ vocab_subsampling=None,
+ corpus_size=None,
+ seed=None)
+ self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval())
+
+ def test_filter_input_subsample_vocab(self):
+ """Tests input filtering based on vocab subsampling."""
+ # The outputs are non-deterministic, so set random seed to help ensure that
+ # the outputs remain constant for testing.
+ random_seed.set_random_seed(42)
+
+ input_tensor = constant_op.constant([
+ # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
+ b"the",
+ b"answer", # Not in vocab. (Always discarded)
+ b"to", # keep_prob = 0.75.
+ b"life", # keep_prob > 1. (Always kept)
+ b"and", # keep_prob = 0.48.
+ b"universe" # Below vocab threshold of 3. (Always discarded)
+ ])
+ keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"])
+ values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64)
+ vocab_freq_table = lookup.HashTable(
+ lookup.KeyValueTensorInitializer(keys, values), -1)
+
+ with self.test_session():
+ vocab_freq_table.init.run()
+ output = skip_gram_ops._filter_input(
+ input_tensor=input_tensor,
+ vocab_freq_table=vocab_freq_table,
+ vocab_min_count=3,
+ vocab_subsampling=0.05,
+ corpus_size=math_ops.reduce_sum(values),
+ seed=9)
+ self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval())
+
+ def _make_text_vocab_freq_file(self):
+ filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt")
+ with open(filepath, "w") as f:
+ writer = csv.writer(f)
+ writer.writerows([
+ ["and", 40],
+ ["life", 8],
+ ["the", 30],
+ ["to", 20],
+ ["universe", 2],
+ ])
+ return filepath
+
+ def _make_text_vocab_float_file(self):
+ filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt")
+ with open(filepath, "w") as f:
+ writer = csv.writer(f)
+ writer.writerows([
+ ["and", 0.4],
+ ["life", 0.08],
+ ["the", 0.3],
+ ["to", 0.2],
+ ["universe", 0.02],
+ ])
+ return filepath
+
+ def test_skip_gram_sample_with_text_vocab_filter_vocab(self):
+ """Tests skip-gram sampling with text vocab and freq threshold filtering."""
+ input_tensor = constant_op.constant([
+ b"the",
+ b"answer", # Will be filtered before candidate generation.
+ b"to",
+ b"life",
+ b"and",
+ b"universe" # Will be filtered before candidate generation.
+ ])
+
+ # b"answer" is not in vocab file, and b"universe"'s frequency is below
+ # threshold of 3.
+ vocab_freq_file = self._make_text_vocab_freq_file()
+
+ tokens, labels = text.skip_gram_sample_with_text_vocab(
+ input_tensor=input_tensor,
+ vocab_freq_file=vocab_freq_file,
+ vocab_token_index=0,
+ vocab_freq_index=1,
+ vocab_min_count=3,
+ min_skips=1,
+ max_skips=1)
+
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"to"),
+ (b"to", b"the"),
+ (b"to", b"life"),
+ (b"life", b"to"),
+ (b"life", b"and"),
+ (b"and", b"life"),
+ ])
+ with self.test_session():
+ lookup_ops.tables_initializer().run()
+ self.assertAllEqual(expected_tokens, tokens.eval())
+ self.assertAllEqual(expected_labels, labels.eval())
+
+ def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
+ vocab_freq_dtype):
+ # The outputs are non-deterministic, so set random seed to help ensure that
+ # the outputs remain constant for testing.
+ random_seed.set_random_seed(42)
+
+ input_tensor = constant_op.constant([
+ # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57.
+ b"the",
+ b"answer", # Not in vocab. (Always discarded)
+ b"to", # keep_prob = 0.75.
+ b"life", # keep_prob > 1. (Always kept)
+ b"and", # keep_prob = 0.48.
+ b"universe" # Below vocab threshold of 3. (Always discarded)
+ ])
+ # keep_prob calculated from vocab file with relative frequencies of:
+ # and: 40
+ # life: 8
+ # the: 30
+ # to: 20
+ # universe: 2
+
+ tokens, labels = text.skip_gram_sample_with_text_vocab(
+ input_tensor=input_tensor,
+ vocab_freq_file=vocab_freq_file,
+ vocab_token_index=0,
+ vocab_freq_index=1,
+ vocab_freq_dtype=vocab_freq_dtype,
+ vocab_min_count=vocab_min_count,
+ vocab_subsampling=0.05,
+ min_skips=1,
+ max_skips=1,
+ seed=123)
+
+ expected_tokens, expected_labels = self._split_tokens_labels([
+ (b"the", b"to"),
+ (b"to", b"the"),
+ (b"to", b"life"),
+ (b"life", b"to"),
+ ])
+ with self.test_session() as sess:
+ lookup_ops.tables_initializer().run()
+ tokens_eval, labels_eval = sess.run([tokens, labels])
+ self.assertAllEqual(expected_tokens, tokens_eval)
+ self.assertAllEqual(expected_labels, labels_eval)
+
+ def test_skip_gram_sample_with_text_vocab_subsample_vocab(self):
+ """Tests skip-gram sampling with text vocab and vocab subsampling."""
+ # Vocab file frequencies
+ # and: 40
+ # life: 8
+ # the: 30
+ # to: 20
+ # universe: 2
+ self._text_vocab_subsample_vocab_helper(
+ vocab_freq_file=self._make_text_vocab_freq_file(),
+ vocab_min_count=3,
+ vocab_freq_dtype=dtypes.int64)
+
+ def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self):
+ """Tests skip-gram sampling with text vocab and subsampling with floats."""
+ # Vocab file frequencies
+ # and: 0.4
+ # life: 0.08
+ # the: 0.3
+ # to: 0.2
+ # universe: 0.02
+ self._text_vocab_subsample_vocab_helper(
+ vocab_freq_file=self._make_text_vocab_float_file(),
+ vocab_min_count=0.03,
+ vocab_freq_dtype=dtypes.float32)
+
+ def test_skip_gram_sample_with_text_vocab_errors(self):
+ """Tests various errors raised by skip_gram_sample_with_text_vocab()."""
+ dummy_input = constant_op.constant([""])
+ vocab_freq_file = self._make_text_vocab_freq_file()
+
+ invalid_indices = (
+ # vocab_token_index can't be negative.
+ (-1, 0),
+ # vocab_freq_index can't be negative.
+ (0, -1),
+ # vocab_token_index can't be equal to vocab_freq_index.
+ (0, 0),
+ (1, 1),
+ # vocab_freq_file only has two columns.
+ (0, 2),
+ (2, 0))
+
+ for vocab_token_index, vocab_freq_index in invalid_indices:
+ with self.assertRaises(ValueError):
+ text.skip_gram_sample_with_text_vocab(
+ input_tensor=dummy_input,
+ vocab_freq_file=vocab_freq_file,
+ vocab_token_index=vocab_token_index,
+ vocab_freq_index=vocab_freq_index)
+
+
+if __name__ == "__main__":
+ test.main()