aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kernel_methods
diff options
context:
space:
mode:
authorGravatar Petros Mol <pmol@google.com>2017-03-25 16:52:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-25 18:13:42 -0700
commitb04954acb444f7e799ff97c65aedb4d6bc06d298 (patch)
tree0cb6bb654f66f7ccde0e6b67266c3126d48168c4 /tensorflow/contrib/kernel_methods
parent396b6bd1af7bd5a9295b13f30c5ed34e7de42daa (diff)
This CL adds:
- an approximate kernel mapper for the RBF kernel (based on Random Fourier Features)
- a (canned) tf.learn kernel-based classifier.
Change: 151237967
Diffstat (limited to 'tensorflow/contrib/kernel_methods')
-rw-r--r--tensorflow/contrib/kernel_methods/BUILD76
-rw-r--r--tensorflow/contrib/kernel_methods/__init__.py29
-rw-r--r--tensorflow/contrib/kernel_methods/python/kernel_estimators.py335
-rw-r--r--tensorflow/contrib/kernel_methods/python/kernel_estimators_test.py269
-rw-r--r--tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py59
-rw-r--r--tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py157
-rw-r--r--tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py166
7 files changed, 1091 insertions, 0 deletions
diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD
new file mode 100644
index 0000000000..b37cbc119f
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/BUILD
@@ -0,0 +1,76 @@
+# Description:
+# Contains kernel methods for TensorFlow.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "kernel_methods",
+ srcs = [
+ "__init__.py",
+ "python/kernel_estimators.py",
+ "python/mappers/random_fourier_features.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dense_kernel_mapper_py",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/learn",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "dense_kernel_mapper_py",
+ srcs = ["python/mappers/dense_kernel_mapper.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "random_fourier_features_test",
+ srcs = ["python/mappers/random_fourier_features_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dense_kernel_mapper_py",
+ ":kernel_methods",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:ops",
+ ],
+)
+
+py_test(
+ name = "kernel_estimators_test",
+ srcs = ["python/kernel_estimators_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":kernel_methods",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/kernel_methods/__init__.py b/tensorflow/contrib/kernel_methods/__init__.py
new file mode 100644
index 0000000000..1a3a0ab77a
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kernel methods: explicit kernel mappers and kernel-based estimators.
+
+@@KernelLinearClassifier
+@@RandomFourierFeatureMapper
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kernel_methods.python.kernel_estimators import KernelLinearClassifier
+from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper
+from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py
new file mode 100644
index 0000000000..8037082487
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py
@@ -0,0 +1,335 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Estimators that combine explicit kernel mappings with linear models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib import layers
+from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper as dkm
+from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.contrib.learn.python.learn.estimators import linear
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging as logging
+
+_FEATURE_COLUMNS = "feature_columns"
+_KERNEL_MAPPERS = "kernel_mappers"
+_OPTIMIZER = "optimizer"
+
+
+def _check_valid_kernel_mappers(kernel_mappers):
+ """Checks that the input kernel_mappers are valid."""
+ if kernel_mappers is None:
+ return True
+ for kernel_mappers_list in six.itervalues(kernel_mappers):
+ for kernel_mapper in kernel_mappers_list:
+ if not isinstance(kernel_mapper, dkm.DenseKernelMapper):
+ return False
+ return True
+
+
+def _check_valid_head(head):
+ """Returns true if the provided head is supported."""
+ if head is None:
+ return False
+ # pylint: disable=protected-access
+ return isinstance(head, head_lib._BinaryLogisticHead) or isinstance(
+ head, head_lib._MultiClassHead)
+ # pylint: enable=protected-access
+
+
+def _update_features_and_columns(features, feature_columns,
+ kernel_mappers_dict):
+ """Updates features and feature_columns based on provided kernel mappers.
+
+ Currently supports the update of RealValuedColumns only.
+
+ Args:
+ features: Initial features dict. The key is a `string` (feature column name)
+ and the value is a tensor.
+ feature_columns: Initial iterable containing all the feature columns to be
+ consumed (possibly after being updated) by the model. All items should be
+ instances of classes derived from `FeatureColumn`.
+ kernel_mappers_dict: A dict from feature column (type: _FeatureColumn) to
+ lists of objects inheriting from the DenseKernelMapper class.
+
+ Returns:
+ updated features and feature_columns based on provided kernel_mappers_dict.
+ """
+ if kernel_mappers_dict is None:
+ return features, feature_columns
+
+ # First construct new columns and features affected by kernel_mappers_dict.
+ mapped_features = dict()
+ mapped_columns = set()
+ for feature_column in kernel_mappers_dict:
+ column_name = feature_column.name
+ # Currently only mappings over RealValuedColumns are supported.
+ if not isinstance(feature_column, layers.feature_column._RealValuedColumn): # pylint: disable=protected-access
+ logging.warning(
+ "Updates are currently supported on RealValuedColumns only. Metadata "
+ "for FeatureColumn {} will not be updated.".format(column_name))
+ continue
+ mapped_column_name = column_name + "_MAPPED"
+ # Construct new feature columns based on provided kernel_mappers.
+ column_kernel_mappers = kernel_mappers_dict[feature_column]
+ new_dim = sum([mapper.output_dim for mapper in column_kernel_mappers])
+ mapped_columns.add(
+ layers.feature_column.real_valued_column(mapped_column_name, new_dim))
+
+ # Get mapped features by concatenating mapped tensors (one mapped tensor
+ # per kernel mappers from the list of kernel mappers corresponding to each
+ # feature column).
+ output_tensors = []
+ for kernel_mapper in column_kernel_mappers:
+ output_tensors.append(kernel_mapper.map(features[column_name]))
+ tensor = array_ops.concat(output_tensors, 1)
+ mapped_features[mapped_column_name] = tensor
+
+ # Finally update features dict and feature_columns.
+ features = features.copy()
+ features.update(mapped_features)
+ feature_columns = set(feature_columns)
+ feature_columns.update(mapped_columns)
+
+ return features, feature_columns
+
+
+def _kernel_model_fn(features, labels, mode, params, config=None):
+ """model_fn for the Estimator using kernel methods.
+
+ Args:
+ features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`).
+ labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
+ dtype `int32` or `int64` in the range `[0, n_classes)`.
+ mode: Defines whether this is training, evaluation or prediction. See
+ `ModeKeys`.
+ params: A dict of hyperparameters.
+ The following hyperparameters are expected:
+ * head: A `Head` instance.
+ * feature_columns: An iterable containing all the feature columns used by
+ the model.
+ * optimizer: string, `Optimizer` object, or callable that defines the
+ optimizer to use for training. If `None`, will use a FTRL optimizer.
+ * kernel_mappers: Dictionary of kernel mappers to be applied to the input
+ features before training.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Returns:
+ A `ModelFnOps` instance.
+
+ Raises:
+ ValueError: If mode is not any of the `ModeKeys`.
+ """
+ feature_columns = params[_FEATURE_COLUMNS]
+ kernel_mappers = params[_KERNEL_MAPPERS]
+
+ updated_features, updated_columns = _update_features_and_columns(
+ features, feature_columns, kernel_mappers)
+ params[_FEATURE_COLUMNS] = updated_columns
+
+ return linear._linear_model_fn( # pylint: disable=protected-access
+ updated_features, labels, mode, params, config)
+
+
+class _KernelEstimator(estimator.Estimator):
+ """Generic kernel-based linear estimator."""
+
+ def __init__(self,
+ feature_columns=None,
+ model_dir=None,
+ weight_column_name=None,
+ head=None,
+ optimizer=None,
+ kernel_mappers=None,
+ config=None):
+ """Constructs a `_KernelEstimator` object."""
+ if not feature_columns and not kernel_mappers:
+ raise ValueError(
+ "You should set at least one of feature_columns, kernel_mappers.")
+ if not _check_valid_kernel_mappers(kernel_mappers):
+ raise ValueError("Invalid kernel mappers.")
+
+ if not _check_valid_head(head):
+ raise ValueError(
+ "head type: {} is not supported. Supported head types: "
+ "_BinaryLogisticHead, _MultiClassHead.".format(type(head)))
+
+ params = {
+ "head": head,
+ _FEATURE_COLUMNS: feature_columns or [],
+ _OPTIMIZER: optimizer,
+ _KERNEL_MAPPERS: kernel_mappers
+ }
+ super(_KernelEstimator, self).__init__(
+ model_fn=_kernel_model_fn,
+ model_dir=model_dir,
+ config=config,
+ params=params)
+
+
+class KernelLinearClassifier(_KernelEstimator):
+ """Linear classifier using kernel methods as feature preprocessing.
+
+ It trains a linear model after possibly mapping initial input features into
+ a mapped space using explicit kernel mappings. Due to the kernel mappings,
+ training a linear classifier in the mapped (output) space can detect
+ non-linearities in the input space.
+
+ The user can provide a list of kernel mappers to be applied to all or a subset
+ of existing feature_columns. This way, the user can effectively provide 2
+ types of feature columns:
+ - those passed as elements of feature_columns in the classifier's constructor
+ - those appearing as a key of the kernel_mappers dict.
+ If a column appears in feature_columns only, no mapping is applied to it. If
+ it appears as a key in kernel_mappers, the corresponding kernel mappers are
+ applied to it. Note that it is possible that a column appears in both places.
+ Currently kernel_mappers are supported for _RealValuedColumns only.
+
+ Example usage:
+ ```
+ real_column_a = real_valued_column(name='real_column_a',...)
+ sparse_column_b = sparse_column_with_hash_bucket(...)
+ kernel_mappers = {real_column_a : [RandomFourierFeatureMapper(...)]}
+ optimizer = ...
+
+ # real_column_a is used as a feature in both its initial and its transformed
+ # (mapped) form. sparse_column_b is not affected by kernel mappers.
+ kernel_classifier = KernelLinearClassifier(
+ feature_columns=[real_column_a, sparse_column_b],
+ model_dir=...,
+ optimizer=optimizer,
+ kernel_mappers=kernel_mappers)
+
+ # real_column_a is used as a feature in its transformed (mapped) form only.
+ # sparse_column_b is not affected by kernel mappers.
+ kernel_classifier = KernelLinearClassifier(
+ feature_columns=[sparse_column_b],
+ model_dir=...,
+ optimizer=optimizer,
+ kernel_mappers=kernel_mappers)
+
+ # Input builders
+ def train_input_fn: # returns x, y
+ ...
+ def eval_input_fn: # returns x, y
+ ...
+
+ kernel_classifier.fit(input_fn=train_input_fn)
+ kernel_classifier.evaluate(input_fn=eval_input_fn)
+ kernel_classifier.predict(...)
+ ```
+
+ Input of `fit` and `evaluate` should have the following features,
+ otherwise there will be a `KeyError`:
+ * if `weight_column_name` is not `None`, a feature with
+ `key=weight_column_name` whose value is a `Tensor`.
+ * for each `column` in `feature_columns`:
+ - if `column` is a `SparseColumn`, a feature with `key=column.name`
+ whose `value` is a `SparseTensor`.
+ - if `column` is a `WeightedSparseColumn`, two features: the first with
+ `key` the id column name, the second with `key` the weight column name.
+ Both features' `value` must be a `SparseTensor`.
+ - if `column` is a `RealValuedColumn`, a feature with `key=column.name`
+ whose `value` is a `Tensor`.
+ """
+
+ def __init__(self,
+              feature_columns=None,
+              model_dir=None,
+              n_classes=2,
+              weight_column_name=None,
+              optimizer=None,
+              kernel_mappers=None,
+              config=None):
+   """Construct a `KernelLinearClassifier` estimator object.
+
+   Args:
+     feature_columns: An iterable containing all the feature columns used by
+       the model. All items in the set should be instances of classes derived
+       from `FeatureColumn`.
+     model_dir: Directory to save model parameters, graph etc. This can also be
+       used to load checkpoints from the directory into an estimator to
+       continue training a previously saved model.
+     n_classes: number of label classes. Default is binary classification.
+       Note that class labels are integers representing the class index (i.e.
+       values from 0 to n_classes-1). For arbitrary label values (e.g. string
+       labels), convert to class indices first.
+     weight_column_name: A string defining feature column name representing
+       weights. It is used to down weight or boost examples during training. It
+       will be multiplied by the loss of the example.
+     optimizer: The optimizer used to train the model. If specified, it should
+       be an instance of `tf.Optimizer`. If `None`, Ftrl is used by default.
+     kernel_mappers: Dictionary of kernel mappers to be applied to the input
+       features before training a (linear) model. Keys are feature columns and
+       values are lists of mappers to be applied to the corresponding feature
+       column. Currently only _RealValuedColumns are supported and therefore
+       all mappers should conform to the `DenseKernelMapper` interface (see
+       ./mappers/dense_kernel_mapper.py).
+     config: `RunConfig` object to configure the runtime settings.
+
+   Returns:
+     A `KernelLinearClassifier` estimator.
+
+   Raises:
+     ValueError: if n_classes < 2.
+     ValueError: if neither feature_columns nor kernel_mappers are provided.
+     ValueError: if mappers provided as kernel_mappers values are invalid.
+   """
+   super(KernelLinearClassifier, self).__init__(
+       feature_columns=feature_columns,
+       model_dir=model_dir,
+       weight_column_name=weight_column_name,
+       head=head_lib.multi_class_head(
+           n_classes=n_classes, weight_column_name=weight_column_name),
+       optimizer=optimizer,
+       kernel_mappers=kernel_mappers,
+       config=config)
+
+ def predict_classes(self, input_fn=None):
+ """Runs inference to determine the predicted class per instance.
+
+ Args:
+ input_fn: The input function providing features.
+
+ Returns:
+ A generator of predicted classes for the features provided by input_fn.
+ Each predicted class is represented by its class index (i.e. integer from
+ 0 to n_classes-1)
+ """
+ key = prediction_key.PredictionKey.CLASSES
+ predictions = super(KernelLinearClassifier, self).predict(
+ input_fn=input_fn, outputs=[key])
+ return (pred[key] for pred in predictions)
+
+ def predict_proba(self, input_fn=None):
+ """Runs inference to determine the class probability predictions.
+
+ Args:
+ input_fn: The input function providing features.
+
+ Returns:
+ A generator of predicted class probabilities for the features provided by
+ input_fn.
+ """
+ key = prediction_key.PredictionKey.PROBABILITIES
+ predictions = super(KernelLinearClassifier, self).predict(
+ input_fn=input_fn, outputs=[key])
+ return (pred[key] for pred in predictions)
diff --git a/tensorflow/contrib/kernel_methods/python/kernel_estimators_test.py b/tensorflow/contrib/kernel_methods/python/kernel_estimators_test.py
new file mode 100644
index 0000000000..a461ba8134
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/python/kernel_estimators_test.py
@@ -0,0 +1,269 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for kernel_estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import layers
+from tensorflow.contrib.kernel_methods.python import kernel_estimators
+from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
+from tensorflow.contrib.learn.python.learn.estimators import test_data
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.platform import googletest
+
+
+def _linearly_separable_binary_input_fn():
+ """Returns linearly-separable data points (binary classification)."""
+ return {
+ 'feature1': constant_op.constant([[0.0], [1.0], [3.0]]),
+ 'feature2': constant_op.constant([[1.0], [-1.2], [1.0]]),
+ }, constant_op.constant([[1], [0], [1]])
+
+
+def _linearly_inseparable_binary_input_fn():
+ """Returns non-linearly-separable data points (binary classification)."""
+ return {
+ 'multi_dim_feature':
+ constant_op.constant([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0],
+ [-1.0, 1.0]]),
+ }, constant_op.constant([[1], [0], [1], [0]])
+
+
+class KernelLinearClassifierTest(TensorFlowTestCase):
+
+ def testNoFeatureColumnsOrKernelMappers(self):
+ """Tests that at least one of feature columns or kernels is provided."""
+ with self.assertRaises(ValueError):
+ _ = kernel_estimators.KernelLinearClassifier()
+
+ def testInvalidKernelMapper(self):
+ """ValueError raised when the kernel mappers provided have invalid type."""
+
+ class DummyKernelMapper(object):
+
+ def __init__(self):
+ pass
+
+ feature = layers.real_valued_column('feature')
+ kernel_mappers = {feature: [DummyKernelMapper()]}
+ with self.assertRaises(ValueError):
+ _ = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[feature], kernel_mappers=kernel_mappers)
+
+ def testInvalidNumberOfClasses(self):
+ """ValueError raised when n_classes is less than 2."""
+
+ feature = layers.real_valued_column('feature')
+ with self.assertRaises(ValueError):
+ _ = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[feature], n_classes=1)
+
+ def testLinearlySeparableBinaryDataNoKernels(self):
+ """Tests classifier w/o kernels (log. regression) for lin-separable data."""
+
+ feature1 = layers.real_valued_column('feature1')
+ feature2 = layers.real_valued_column('feature2')
+
+ logreg_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[feature1, feature2])
+ logreg_classifier.fit(
+ input_fn=_linearly_separable_binary_input_fn, steps=100)
+
+ metrics = logreg_classifier.evaluate(
+ input_fn=_linearly_separable_binary_input_fn, steps=1)
+ # Since the data is linearly separable, the classifier should have small
+ # loss and perfect accuracy.
+ self.assertLess(metrics['loss'], 0.1)
+ self.assertEqual(metrics['accuracy'], 1.0)
+
+ # As a result, it should assign higher probability to class 1 for the 1st
+ # and 3rd example and higher probability to class 0 for the second example.
+ logreg_prob_predictions = list(
+ logreg_classifier.predict_proba(input_fn=
+ _linearly_separable_binary_input_fn))
+ self.assertGreater(logreg_prob_predictions[0][1], 0.5)
+ self.assertGreater(logreg_prob_predictions[1][0], 0.5)
+ self.assertGreater(logreg_prob_predictions[2][1], 0.5)
+
+ def testLinearlyInseparableBinaryDataWithAndWithoutKernels(self):
+ """Tests classifier w/ and w/o kernels on non-linearly-separable data."""
+ multi_dim_feature = layers.real_valued_column(
+ 'multi_dim_feature', dimension=2)
+
+ # Data points are non-linearly separable so there will be at least one
+ # mis-classified sample (accuracy < 0.8). In fact, the loss is minimized for
+ # w1=w2=0.0, in which case each example incurs a loss of ln(2). The overall
+ # (average) loss should then be ln(2) and the logits should be approximately
+ # 0.0 for each sample.
+ logreg_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[multi_dim_feature])
+ logreg_classifier.fit(
+ input_fn=_linearly_inseparable_binary_input_fn, steps=50)
+ logreg_metrics = logreg_classifier.evaluate(
+ input_fn=_linearly_inseparable_binary_input_fn, steps=1)
+ logreg_loss = logreg_metrics['loss']
+ logreg_accuracy = logreg_metrics['accuracy']
+ logreg_predictions = logreg_classifier.predict(
+ input_fn=_linearly_inseparable_binary_input_fn, as_iterable=False)
+ self.assertAlmostEqual(logreg_loss, np.log(2), places=3)
+ self.assertLess(logreg_accuracy, 0.8)
+ self.assertAllClose(logreg_predictions['logits'], [[0.0], [0.0], [0.0],
+ [0.0]])
+
+ # Using kernel mappers makes it possible to discover non-linearities in
+ # data. Mapping the data to a higher dimensional feature space using approx
+ # RBF kernels substantially reduces the loss and leads to perfect
+ # classification accuracy.
+ kernel_mappers = {
+ multi_dim_feature: [RandomFourierFeatureMapper(2, 30, 0.6, 1, 'rffm')]
+ }
+ kernelized_logreg_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[], kernel_mappers=kernel_mappers)
+ kernelized_logreg_classifier.fit(
+ input_fn=_linearly_inseparable_binary_input_fn, steps=50)
+ kernelized_logreg_metrics = kernelized_logreg_classifier.evaluate(
+ input_fn=_linearly_inseparable_binary_input_fn, steps=1)
+ kernelized_logreg_loss = kernelized_logreg_metrics['loss']
+ kernelized_logreg_accuracy = kernelized_logreg_metrics['accuracy']
+ self.assertLess(kernelized_logreg_loss, 0.2)
+ self.assertEqual(kernelized_logreg_accuracy, 1.0)
+
+ def testVariablesWithAndWithoutKernels(self):
+ """Tests variables w/ and w/o kernel."""
+ multi_dim_feature = layers.real_valued_column(
+ 'multi_dim_feature', dimension=2)
+
+ linear_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[multi_dim_feature])
+ linear_classifier.fit(
+ input_fn=_linearly_inseparable_binary_input_fn, steps=50)
+ linear_variables = linear_classifier.get_variable_names()
+ self.assertIn('linear/multi_dim_feature/weight', linear_variables)
+ self.assertIn('linear/bias_weight', linear_variables)
+ linear_weights = linear_classifier.get_variable_value(
+ 'linear/multi_dim_feature/weight')
+ linear_bias = linear_classifier.get_variable_value('linear/bias_weight')
+
+ kernel_mappers = {
+ multi_dim_feature: [RandomFourierFeatureMapper(2, 30, 0.6, 1, 'rffm')]
+ }
+ kernel_linear_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[], kernel_mappers=kernel_mappers)
+ kernel_linear_classifier.fit(
+ input_fn=_linearly_inseparable_binary_input_fn, steps=50)
+ kernel_linear_variables = kernel_linear_classifier.get_variable_names()
+ self.assertIn('linear/multi_dim_feature_MAPPED/weight',
+ kernel_linear_variables)
+ self.assertIn('linear/bias_weight', kernel_linear_variables)
+ kernel_linear_weights = kernel_linear_classifier.get_variable_value(
+ 'linear/multi_dim_feature_MAPPED/weight')
+ kernel_linear_bias = kernel_linear_classifier.get_variable_value(
+ 'linear/bias_weight')
+
+ # The feature column used for linear classification (no kernels) has
+ # dimension 2 so the model will learn a 2-dimension weights vector (and a
+ # scalar for the bias). In the kernelized model, the features are mapped to
+ # a 30-dimensional feature space and so the weights variable will also have
+ # dimension 30.
+ self.assertEqual(2, len(linear_weights))
+ self.assertEqual(1, len(linear_bias))
+ self.assertEqual(30, len(kernel_linear_weights))
+ self.assertEqual(1, len(kernel_linear_bias))
+
+ def testClassifierWithAndWithoutKernelsNoRealValuedColumns(self):
+ """Tests kernels have no effect for non-real valued columns."""
+
+ def input_fn():
+ return {
+ 'price':
+ constant_op.constant([[0.4], [0.6], [0.3]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ }, constant_op.constant([[1], [0], [1]])
+
+ price = layers.real_valued_column('price')
+ country = layers.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+
+ linear_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[price, country])
+ linear_classifier.fit(input_fn=input_fn, steps=100)
+ linear_metrics = linear_classifier.evaluate(input_fn=input_fn, steps=1)
+ linear_loss = linear_metrics['loss']
+ linear_accuracy = linear_metrics['accuracy']
+
+ kernel_mappers = {
+ country: [RandomFourierFeatureMapper(2, 30, 0.6, 1, 'rffm')]
+ }
+
+ kernel_linear_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[price, country], kernel_mappers=kernel_mappers)
+ kernel_linear_classifier.fit(input_fn=input_fn, steps=100)
+ kernel_linear_metrics = kernel_linear_classifier.evaluate(
+ input_fn=input_fn, steps=1)
+ kernel_linear_loss = kernel_linear_metrics['loss']
+ kernel_linear_accuracy = kernel_linear_metrics['accuracy']
+
+ # The kernel mapping is applied to a non-real-valued feature column and so
+ # it should have no effect on the model. The loss and accuracy of the
+ # "kernelized" model should match the loss and accuracy of the initial model
+ # (without kernels).
+ self.assertAlmostEqual(linear_loss, kernel_linear_loss, delta=0.01)
+ self.assertAlmostEqual(linear_accuracy, kernel_linear_accuracy, delta=0.01)
+
+ def testMulticlassDataWithAndWithoutKernels(self):
+ """Tests classifier w/ and w/o kernels on multiclass data."""
+ feature_column = layers.real_valued_column('feature', dimension=4)
+
+ # Metrics for linear classifier (no kernels).
+ linear_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[feature_column], n_classes=3)
+ linear_classifier.fit(input_fn=test_data.iris_input_multiclass_fn, steps=50)
+ linear_metrics = linear_classifier.evaluate(
+ input_fn=test_data.iris_input_multiclass_fn, steps=1)
+ linear_loss = linear_metrics['loss']
+ linear_accuracy = linear_metrics['accuracy']
+
+ # Kernel mappers make it possible to discover non-linearities in data (via
+ # RBF kernel approximation), which reduces loss and increases accuracy.
+ kernel_mappers = {
+ feature_column: [
+ RandomFourierFeatureMapper(
+ input_dim=4, output_dim=50, stddev=1.0, name='rffm')
+ ]
+ }
+ kernel_linear_classifier = kernel_estimators.KernelLinearClassifier(
+ feature_columns=[], n_classes=3, kernel_mappers=kernel_mappers)
+ kernel_linear_classifier.fit(
+ input_fn=test_data.iris_input_multiclass_fn, steps=50)
+ kernel_linear_metrics = kernel_linear_classifier.evaluate(
+ input_fn=test_data.iris_input_multiclass_fn, steps=1)
+ kernel_linear_loss = kernel_linear_metrics['loss']
+ kernel_linear_accuracy = kernel_linear_metrics['accuracy']
+ self.assertLess(kernel_linear_loss, linear_loss)
+ self.assertGreater(kernel_linear_accuracy, linear_accuracy)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py b/tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py
new file mode 100644
index 0000000000..db38b47152
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/python/mappers/dense_kernel_mapper.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""API class for dense (approximate) kernel mappers.
+
+See ./random_fourier_features.py for a concrete instantiation of this class.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+
+class InvalidShapeError(Exception):
+ """Exception thrown when a tensor's shape deviates from an expected shape."""
+
+
+@six.add_metaclass(abc.ABCMeta)
+class DenseKernelMapper(object):
+ """Abstract class for a kernel mapper that maps dense inputs to dense outputs.
+
+ This class is abstract. Users should not create instances of this class.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def map(self, input_tensor):
+ """Main Dense-Tensor-In-Dense-Tensor-Out (DTIDTO) map method.
+
+ Should be implemented by subclasses.
+ Args:
+ input_tensor: The dense input tensor to be mapped using the (approximate)
+ kernel mapper.
+ """
+ raise NotImplementedError('map is not implemented for {}.'.format(self))
+
+ @abc.abstractproperty
+ def name(self):
+ """Returns the name of the kernel mapper."""
+ pass
+
+ @abc.abstractproperty
+ def output_dim(self):
+ """Returns the output dimension of the mapping."""
+ pass
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py
new file mode 100644
index 0000000000..270a243970
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Approximate kernel mapper for RBF kernel based on Random Fourier Features."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper as dkm
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import math_ops
+
+
+# TODO(sibyl-vie3Poto,felixyu): add an option to control whether the parameters in the
+# kernel map are trainable.
+class RandomFourierFeatureMapper(dkm.DenseKernelMapper):
+  """Class that implements Random Fourier Feature Mapping.
+
+  The RFFM mapping is used to approximate the Gaussian (RBF) kernel:
+    exp(-||x-y||_2^2 / (2 * sigma^2))
+
+  The implementation of RFFM is based on the following paper:
+  "Random Features for Large-Scale Kernel Machines" by Ali Rahimi and Ben Recht.
+  (link: https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf)
+
+  The mapping uses a matrix Omega in R^{d x D} and a bias vector b in R^D where
+  d is the input dimension (number of dense input features) and D is the output
+  dimension (i.e., dimension of the feature space the input is mapped to). Each
+  entry of Omega is sampled i.i.d. from a (scaled) Gaussian distribution and
+  each entry of the bias vector is sampled i.i.d. and uniformly from [0, 2*pi].
+
+  For a single input feature vector x in R^d, its RFFM is defined as:
+    sqrt(2/D) * cos(x * Omega + b)
+  where cos is the element-wise cosine function and x, b are represented as row
+  vectors. The aforementioned paper shows that the linear kernel of RFFM-mapped
+  vectors approximates the Gaussian kernel of the initial vectors.
+  """
+
+  def __init__(self, input_dim, output_dim, stddev=1.0, seed=1, name=None):
+    """Constructs a RandomFourierFeatureMapper instance.
+
+    Args:
+      input_dim: The dimension (number of features) of the tensors to be mapped.
+      output_dim: The output dimension of the mapping.
+      stddev: The standard deviation of the Gaussian kernel to be approximated.
+        The error of the classifier trained using this approximation is very
+        sensitive to this parameter.
+      seed: An integer used to initialize the parameters (Omega and bias) of the
+        mapper. For repeatable sequences across different invocations of the
+        mapper object (for instance, to ensure consistent mapping both at
+        training and eval/inference if these happen in different invocations),
+        set this to the same integer.
+      name: name for the mapper object.
+    """
+    # TODO(sibyl-vie3Poto): Maybe infer input_dim and/or output_dim (if not explicitly
+    # provided). input_dim can be inferred lazily, the first time map is called.
+    # output_dim can be inferred from input_dim using heuristics on the error of
+    # the approximation (and, by extension, the error of the classification
+    # based on the approximation).
+    self._input_dim = input_dim
+    self._output_dim = output_dim
+    self._stddev = stddev
+    self._seed = seed
+    self._name = name
+
+  @property
+  def name(self):
+    """Returns a name for the RandomFourierFeatureMapper instance.
+
+    If the name provided in the constructor is None, then the object's unique id
+    is returned.
+
+    Returns:
+      A name for the RandomFourierFeatureMapper instance.
+    """
+    return self._name or str(id(self))
+
+  @property
+  def input_dim(self):
+    """Returns the input dimension passed to the constructor."""
+    return self._input_dim
+
+  @property
+  def output_dim(self):
+    """Returns the output dimension passed to the constructor."""
+    return self._output_dim
+
+  def map(self, input_tensor):
+    """Maps each row of input_tensor using random Fourier features.
+
+    Omega and the bias are regenerated from the constructor's seed on every
+    call and added to the graph as constants, so repeated calls use the same
+    parameter values.
+
+    Args:
+      input_tensor: tensor containing input features. Its shape is
+        [batch_size, self._input_dim].
+
+    Returns:
+      A tensor of shape [batch_size, self._output_dim] containing RFFM-mapped
+      features.
+
+    Raises:
+      InvalidShapeError: if the shape of the input_tensor is inconsistent with
+        expected input dimension.
+    """
+    input_tensor_shape = input_tensor.get_shape()
+    # The input must be 2-D: [batch_size, input_dim].
+    if len(input_tensor_shape) != 2:
+      raise dkm.InvalidShapeError(
+          'The shape of the tensor should be 2. Got %d instead.' %
+          len(input_tensor_shape))
+
+    features_dim = input_tensor_shape[1]
+    if features_dim != self._input_dim:
+      raise dkm.InvalidShapeError(
+          'Invalid dimension: expected %d input features, got %d instead.' %
+          (self._input_dim, features_dim))
+
+    # Add ops that compute (deterministically) omega_matrix and bias based on
+    # the provided seed.
+    # TODO(sibyl-vie3Poto): Storing the mapper's parameters (omega_matrix and bias) as
+    # constants incurs no RPC calls to the parameter server during distributed
+    # training. However, if the parameters grow too large (for instance if they
+    # don't fit into memory or if they blow up the size of the GraphDef proto),
+    # storing them as constants is no longer an option. In this case, we should
+    # have a heuristic to choose out of one of the following alternatives:
+    # a) store them as variables (in the parameter server)
+    # b) store them as worker local variables
+    # c) generating on the fly the omega matrix at each step
+    #
+    # A private RandomState seeded with self._seed produces the same draws as
+    # seeding the global NumPy generator would, but does not overwrite global
+    # RNG state that other code in the process may depend on.
+    rng = np.random.RandomState(self._seed)
+    omega_matrix_shape = [self._input_dim, self._output_dim]
+    bias_shape = [self._output_dim]
+
+    # Omega is drawn before the bias; keeping this order keeps the generated
+    # parameters stable for a given seed.
+    omega_matrix = constant_op.constant(
+        rng.normal(scale=1.0 / self._stddev, size=omega_matrix_shape),
+        dtype=dtypes.float32)
+    bias = constant_op.constant(
+        rng.uniform(low=0.0, high=2 * np.pi, size=bias_shape),
+        dtype=dtypes.float32)
+
+    x_omega_plus_bias = math_ops.add(
+        math_ops.matmul(input_tensor, omega_matrix), bias)
+    return math.sqrt(2.0 / self._output_dim) * math_ops.cos(x_omega_plus_bias)
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
new file mode 100644
index 0000000000..200d00b663
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -0,0 +1,166 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for RandomFourierFeatureMapper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.contrib.kernel_methods.python.mappers import dense_kernel_mapper
+from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import googletest
+
+
+def _inner_product(x, y):
+  """Returns x * y^T for tensors x and y given in ROW representation.
+
+  Args:
+    x: input tensor in row format
+    y: input tensor in row format
+
+  Returns:
+    the inner product of x, y
+  """
+  # transpose_b=True multiplies x by the transpose of y.
+  product = math_ops.matmul(x, y, transpose_b=True)
+  return product
+
+
+def _compute_exact_rbf_kernel(x, y, stddev):
+  """Computes exp(-(x-y) * (x-y)^T / (2 * stddev^2)) for row tensors x and y."""
+  delta = math_ops.subtract(x, y)
+  # delta * delta^T: for single-row inputs this is the squared L2 distance.
+  squared_distance = math_ops.matmul(delta, delta, transpose_b=True)
+  denominator = 2 * stddev * stddev
+  return math_ops.exp(-squared_distance / denominator)
+
+
+class RandomFourierFeatureMapperTest(TensorFlowTestCase):
+  # Unit tests for RandomFourierFeatureMapper: shape validation, output shape,
+  # determinism of the mapping parameters and quality of the approximation.
+
+  def testInvalidInputShape(self):
+    # The input has 2 features while the mapper is built for 3.
+    two_feature_input = constant_op.constant([[2.0, 1.0]])
+
+    with self.test_session():
+      mapper = RandomFourierFeatureMapper(3, 10)
+      with self.assertRaisesWithPredicateMatch(
+          dense_kernel_mapper.InvalidShapeError,
+          r'Invalid dimension: expected 3 input features, got 2 instead.'):
+        mapper.map(two_feature_input)
+
+  def testMappedShape(self):
+    # The mapped tensor keeps the batch size and has output_dim columns.
+    single_row = constant_op.constant([[2.0, 1.0, 0.0]])
+    three_rows = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0],
+                                       [4.0, -2.0, -1.0]])
+
+    with self.test_session():
+      mapper = RandomFourierFeatureMapper(3, 10, 1.0)
+      mapped_single_row = mapper.map(single_row)
+      mapped_three_rows = mapper.map(three_rows)
+      self.assertEqual([1, 10], mapped_single_row.get_shape())
+      self.assertEqual([3, 10], mapped_three_rows.get_shape())
+
+  def testSameOmegaReused(self):
+    features = constant_op.constant([[2.0, 1.0, 0.0]])
+
+    with self.test_session():
+      mapper = RandomFourierFeatureMapper(3, 100)
+      first_mapping = mapper.map(features)
+      second_mapping = mapper.map(features)
+      # Two different evaluations of tensors output by map on the same input
+      # are identical because the same parameters are used for the mappings.
+      self.assertAllClose(
+          first_mapping.eval(), second_mapping.eval(), atol=0.001)
+
+  def testTwoMapperObjects(self):
+    x = constant_op.constant([[2.0, 1.0, 0.0]])
+    y = constant_op.constant([[1.0, -1.0, 2.0]])
+    stddev = 3.0
+
+    with self.test_session():
+      # Two independent mapper objects built with identical arguments are
+      # expected to yield (almost) the same approximate kernel value.
+      first_mapper = RandomFourierFeatureMapper(3, 100, stddev)
+      second_mapper = RandomFourierFeatureMapper(3, 100, stddev)
+      first_mapped_x = first_mapper.map(x)
+      first_mapped_y = first_mapper.map(y)
+      second_mapped_x = second_mapper.map(x)
+      second_mapped_y = second_mapper.map(y)
+
+      first_kernel_value = _inner_product(first_mapped_x, first_mapped_y)
+      second_kernel_value = _inner_product(second_mapped_x, second_mapped_y)
+      self.assertAllClose(
+          first_kernel_value.eval(), second_kernel_value.eval(), atol=0.01)
+
+  def testBadKernelApproximation(self):
+    x = constant_op.constant([[2.0, 1.0, 0.0]])
+    y = constant_op.constant([[1.0, -1.0, 2.0]])
+    stddev = 3.0
+
+    with self.test_session():
+      # The mapped dimension is fairly small, so the kernel approximation is
+      # very rough; hence the loose tolerance below.
+      mapper = RandomFourierFeatureMapper(3, 100, stddev, seed=0)
+      mapped_x = mapper.map(x)
+      mapped_y = mapper.map(y)
+      exact_value = _compute_exact_rbf_kernel(x, y, stddev)
+      approx_value = _inner_product(mapped_x, mapped_y)
+      self.assertAllClose(exact_value.eval(), approx_value.eval(), atol=0.2)
+
+  def testGoodKernelApproximationAmortized(self):
+    # Parameters.
+    num_points = 20
+    input_dim = 5
+    mapped_dim = 5000
+    stddev = 5.0
+
+    # TODO(sibyl-vie3Poto): Reduce test's running time before moving to third_party. One
+    # possible way to speed the test up is to compute both the approximate and
+    # the exact kernel matrix directly using matrix operations instead of
+    # computing the values for each pair of points separately.
+    points_shape = [1, input_dim]
+    raw_points = [
+        random_ops.random_uniform(shape=points_shape, maxval=1.0)
+        for _ in xrange(num_points)
+    ]
+
+    unit_points = [nn.l2_normalize(raw, dim=1) for raw in raw_points]
+    error_sum = 0.0
+    with self.test_session():
+      mapper = RandomFourierFeatureMapper(
+          input_dim, mapped_dim, stddev, seed=0)
+      # Map every point exactly once and look the result up afterwards, so the
+      # mappings are not recomputed for each pair.
+      mapping_of = {point: mapper.map(point) for point in unit_points}
+      for x in unit_points:
+        mapped_x = mapping_of[x]
+        for y in unit_points:
+          mapped_y = mapping_of[y]
+          exact_value = _compute_exact_rbf_kernel(x, y, stddev)
+          approx_value = _inner_product(mapped_x, mapped_y)
+          error_sum += math_ops.abs(exact_value - approx_value)
+      # The mean absolute error over all ordered pairs must be close to zero.
+      self.assertAllClose(
+          [[0.0]],
+          error_sum.eval() / (num_points * num_points),
+          atol=0.02)
+
+
+# Run the test suite only when this file is executed as a script.
+if __name__ == '__main__':
+  googletest.main()