aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/libsvm
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-22 12:42:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-22 12:46:28 -0800
commite4532d20973c4c00854492362665317551661c18 (patch)
tree398527e29bd30d39237adb4785be5069fdb646fa /tensorflow/contrib/libsvm
parent673641c2d6a27fa97ee05453d671853731a4c602 (diff)
Merge changes from github.
PiperOrigin-RevId: 179953488
Diffstat (limited to 'tensorflow/contrib/libsvm')
-rw-r--r--tensorflow/contrib/libsvm/BUILD102
-rw-r--r--tensorflow/contrib/libsvm/__init__.py32
-rw-r--r--tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc178
-rw-r--r--tensorflow/contrib/libsvm/ops/libsvm_ops.cc58
-rw-r--r--tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py71
-rw-r--r--tensorflow/contrib/libsvm/python/ops/libsvm_ops.py50
6 files changed, 491 insertions, 0 deletions
diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD
new file mode 100644
index 0000000000..df96402a4f
--- /dev/null
+++ b/tensorflow/contrib/libsvm/BUILD
@@ -0,0 +1,102 @@
+# Targets in this package are private unless a rule sets its own visibility.
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+# Shared object built from both the kernel source and the op registration
+# source; python/ops/libsvm_ops.py loads it as "_libsvm_ops.so".
+tf_custom_op_library(
+ name = "python/ops/_libsvm_ops.so",
+ srcs = [
+ "kernels/decode_libsvm_op.cc",
+ "ops/libsvm_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/core/kernels:bounds_check_lib",
+ ],
+)
+
+# Kernel-only library built from the same kernel source as the shared object
+# above; publicly visible.
+tf_kernel_library(
+ name = "libsvm_kernels",
+ srcs = ["kernels/decode_libsvm_op.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:bounds_check_lib",
+ ],
+)
+
+# Declares the op library for "libsvm_ops"; the wrapper rule below depends on
+# it through the :libsvm_ops_op_lib label.
+tf_gen_op_libs(
+ op_lib_names = ["libsvm_ops"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+# Generated Python wrapper for the ops in :libsvm_ops_op_lib (presumably the
+# gen_libsvm_ops module imported by python/ops/libsvm_ops.py -- confirm).
+tf_gen_op_wrapper_py(
+ name = "libsvm_ops",
+ deps = [":libsvm_ops_op_lib"],
+)
+
+# Public Python library: the package __init__ and the op wrapper module,
+# bundled with the custom-op shared object, the kernels and the op library.
+tf_custom_op_py_library(
+ name = "libsvm",
+ srcs = [
+ "__init__.py",
+ "python/ops/libsvm_ops.py",
+ ],
+ dso = [
+ ":python/ops/_libsvm_ops.so",
+ ],
+ kernels = [
+ ":libsvm_kernels",
+ ":libsvm_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libsvm_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ ],
+)
+
+# Python test target for python/kernel_tests/decode_libsvm_op_test.py; it
+# depends on the :libsvm library defined above.
+tf_py_test(
+ name = "decode_libsvm_op_test",
+ srcs = ["python/kernel_tests/decode_libsvm_op_test.py"],
+ additional_deps = [
+ ":libsvm",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+# Every file in this package except METADATA and OWNERS, visible to all
+# packages under //tensorflow.
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/libsvm/__init__.py b/tensorflow/contrib/libsvm/__init__.py
new file mode 100644
index 0000000000..a875863caa
--- /dev/null
+++ b/tensorflow/contrib/libsvm/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Libsvm decoder.
+
+@@decode_libsvm
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.libsvm.python.ops.libsvm_ops import decode_libsvm
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+# Symbols that must stay in the module namespace even if they are not listed
+# with "@@" in the module docstring.
+_allowed_symbols = [
+    "decode_libsvm",
+]
+
+# Remove every other name from the module namespace. The allowlist is passed
+# explicitly; previously it was defined but never used.
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc
new file mode 100644
index 0000000000..27ce55d568
--- /dev/null
+++ b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc
@@ -0,0 +1,178 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace {
+// Parses `s` into `*value` and returns whether parsing succeeded. Only the
+// explicit specializations defined after DecodeLibsvmOp (float, double,
+// int32, int64) are implemented.
+template <typename T>
+bool ConvertHelper(const string& s, T* value);
+}
+
+// Kernel for the DecodeLibsvm op.
+//
+// Every input string is one LibSVM record: "<label> <index>:<value> ...".
+// Outputs:
+//   0: labels, same shape as the input.
+//   1: indices of the sparse feature tensor, shape [nnz, input_rank + 1].
+//   2: values of the sparse feature tensor, shape [nnz].
+//   3: dense shape of the feature tensor: input shape + [num_features].
+//
+// T is the feature value type and Tlabel is the label type.
+template <typename T, typename Tlabel>
+class DecodeLibsvmOp : public OpKernel {
+ public:
+  explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_));
+    OP_REQUIRES(ctx, (num_features_ >= 1),
+                errors::InvalidArgument("Invalid number of features \"",
+                                        num_features_, "\""));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+    const auto& input_flat = input_tensor->flat<string>();
+    const TensorShape& input_shape = input_tensor->shape();
+    const int input_dims = input_shape.dims();
+
+    Tensor* label_tensor;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, input_shape, &label_tensor));
+    auto label = label_tensor->flat<Tlabel>();
+
+    std::vector<T> out_values;
+    // (flat input position, feature index) for every parsed feature.
+    std::vector<std::pair<int64, int64>> out_indices;
+    for (int64 i = 0; i < input_flat.size(); ++i) {
+      std::vector<string> entries =
+          str_util::Split(input_flat(i), " ", str_util::SkipEmpty());
+      OP_REQUIRES(ctx, !entries.empty(),
+                  errors::InvalidArgument("No entries found for input[", i,
+                                          "]: \"", input_flat(i), "\""));
+      // The first token is the label; every following token is index:value.
+      Tlabel label_value;
+      OP_REQUIRES(
+          ctx, ConvertHelper<Tlabel>(entries[0], &label_value),
+          errors::InvalidArgument("Label format incorrect: ", entries[0]));
+      label(i) = label_value;
+      for (size_t j = 1; j < entries.size(); j++) {
+        std::vector<string> pair = str_util::Split(entries[j], ":");
+        OP_REQUIRES(
+            ctx, (pair.size() == 2),
+            errors::InvalidArgument("Invalid feature \"", entries[j], "\""));
+        int64 feature_index;
+        OP_REQUIRES(
+            ctx, strings::safe_strto64(pair[0].c_str(), &feature_index),
+            errors::InvalidArgument("Feature format incorrect: ", entries[j]));
+        // An index outside [0, num_features) would produce a sparse tensor
+        // whose indices do not fit the dense shape reported in output 3.
+        OP_REQUIRES(
+            ctx, (feature_index >= 0 && feature_index < num_features_),
+            errors::InvalidArgument("Feature index should be >= 0 and < ",
+                                    num_features_, ", got ", feature_index));
+        T feature_value;
+        OP_REQUIRES(
+            ctx, ConvertHelper<T>(pair[1], &feature_value),
+            errors::InvalidArgument("Feature format incorrect: ", entries[j]));
+        out_values.emplace_back(feature_value);
+        out_indices.emplace_back(i, feature_index);
+      }
+    }
+
+    Tensor* indices_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                            1,
+                            TensorShape({static_cast<int64>(out_indices.size()),
+                                         input_dims + 1}),
+                            &indices_tensor));
+    auto indices = indices_tensor->matrix<int64>();
+    // Translate each flat input position into per-dimension coordinates, like
+    // np.unravel_index. factors[j] is the row-major stride of dimension j.
+    // A scalar input has no dimensions, so `factors` is empty and must not be
+    // written to.
+    std::vector<int64> factors(input_dims);
+    if (input_dims > 0) {
+      factors[input_dims - 1] = 1;
+      for (int j = input_dims - 2; j >= 0; j--) {
+        factors[j] = factors[j + 1] * input_shape.dim_size(j + 1);
+      }
+    }
+    const int64 num_entries = static_cast<int64>(out_indices.size());
+    for (int64 i = 0; i < num_entries; i++) {
+      int64 value = out_indices[i].first;
+      for (int j = 0; j < input_dims; j++) {
+        indices(i, j) = value / factors[j];
+        value = value % factors[j];
+      }
+      // The last column is the feature index within the record.
+      indices(i, input_dims) = out_indices[i].second;
+    }
+
+    Tensor* values_tensor;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(
+                       2, TensorShape({static_cast<int64>(out_values.size())}),
+                       &values_tensor));
+    auto values = values_tensor->vec<T>();
+    // data() instead of &values(0): nothing is indexed when no features were
+    // parsed and the output is empty.
+    std::copy_n(out_values.begin(), out_values.size(), values.data());
+
+    Tensor* shape_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                            3, TensorShape({input_dims + 1}), &shape_tensor));
+    auto shape = shape_tensor->flat<int64>();
+    for (int i = 0; i < input_dims; i++) {
+      shape(i) = input_shape.dim_size(i);
+    }
+    shape(input_dims) = num_features_;
+  }
+
+ private:
+  // Value of the "num_features" attr; size of the last dense dimension.
+  int64 num_features_;
+};
+
+namespace {
+// ConvertHelper specializations: each one forwards to the strings::safe_strto*
+// parser for its type and returns that parser's result.
+template <>
+bool ConvertHelper<float>(const string& s, float* value) {
+ return strings::safe_strtof(s.c_str(), value);
+}
+template <>
+bool ConvertHelper<double>(const string& s, double* value) {
+ return strings::safe_strtod(s.c_str(), value);
+}
+template <>
+bool ConvertHelper<int32>(const string& s, int32* value) {
+ return strings::safe_strto32(s.c_str(), value);
+}
+template <>
+bool ConvertHelper<int64>(const string& s, int64* value) {
+ return strings::safe_strto64(s.c_str(), value);
+}
+} // namespace
+
+// Registers the CPU DecodeLibsvm kernel for feature type `type` combined with
+// each supported label type (int32, int64, float, double). The macro is then
+// expanded once per supported feature type and undefined.
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype") \
+ .TypeConstraint<int32>("label_dtype"), \
+ DecodeLibsvmOp<type, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype") \
+ .TypeConstraint<int64>("label_dtype"), \
+ DecodeLibsvmOp<type, int64>); \
+ REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype") \
+ .TypeConstraint<float>("label_dtype"), \
+ DecodeLibsvmOp<type, float>); \
+ REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype") \
+ .TypeConstraint<double>("label_dtype"), \
+ DecodeLibsvmOp<type, double>);
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(int64);
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/libsvm/ops/libsvm_ops.cc b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc
new file mode 100644
index 0000000000..dec946189e
--- /dev/null
+++ b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+// Registers the DecodeLibsvm op: one string input and four outputs (label,
+// feature_indices, feature_values, feature_shape). Feature and label types
+// are selected with the `dtype` and `label_dtype` attrs.
+REGISTER_OP("DecodeLibsvm")
+ .Input("input: string")
+ .Output("label: label_dtype")
+ .Output("feature_indices: int64")
+ .Output("feature_values: dtype")
+ .Output("feature_shape: int64")
+ .Attr("dtype: {float, double, int32, int64} = DT_FLOAT")
+ .Attr("label_dtype: {float, double, int32, int64} = DT_INT64")
+ .Attr("num_features: int >= 1")
+ .SetShapeFn([](InferenceContext* c) {
+ // The label output has the same shape as the input.
+ c->set_output(0, c->input(0));
+
+ // The remaining outputs are declared with unknown sizes: a matrix for the
+ // indices, and vectors for the values and the dense shape.
+ c->set_output(1, c->Matrix(InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim));
+ c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
+ c->set_output(3, c->Vector(InferenceContext::kUnknownDim));
+
+ return Status::OK();
+ })
+
+ .Doc(R"doc(
+Convert LibSVM input to tensors. The output consists of
+a label and a feature tensor. The shape of the label tensor
+is the same as input and the shape of the feature tensor is
+`[input_shape, num_features]`.
+
+input: Each string is a record in the LibSVM.
+label: A tensor of the same shape as input.
+feature_indices: A 2-D int64 tensor of dense_shape [N, ndims].
+feature_values: A 1-D tensor of any type and dense_shape [N].
+feature_shape: A 1-D int64 tensor of dense_shape [ndims].
+num_features: The number of features.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
new file mode 100644
index 0000000000..423dcce8de
--- /dev/null
+++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DecodeLibsvm op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.libsvm.python.ops import libsvm_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class DecodeLibsvmOpTest(test.TestCase):
+  """Checks decode_libsvm on 1-D and 2-D batches of LibSVM records."""
+
+  def testBasic(self):
+    """Decodes a 1-D batch of records using the default dtypes."""
+    records = [
+        "1 1:3.4 2:0.5 4:0.231",
+        "1 2:2.5 3:inf 5:0.503",
+        "2 3:2.5 2:nan 1:0.105",
+    ]
+    expected_labels = [1, 1, 2]
+    expected_features = [
+        [0, 3.4, 0.5, 0, 0.231, 0],
+        [0, 0, 2.5, np.inf, 0, 0.503],
+        [0, 0.105, np.nan, 2.5, 0, 0],
+    ]
+    with self.test_session() as sess:
+      sparse_features, labels = libsvm_ops.decode_libsvm(
+          records, num_features=6)
+      dense_features = sparse_ops.sparse_tensor_to_dense(
+          sparse_features, validate_indices=False)
+
+      # The static label shape must mirror the input shape.
+      self.assertAllEqual(labels.get_shape().as_list(), [3])
+
+      dense_value, label_value = sess.run([dense_features, labels])
+      self.assertAllEqual(label_value, expected_labels)
+      self.assertAllClose(dense_value, expected_features)
+
+  def testNDimension(self):
+    """Decodes a 3x2 batch of records with float64 labels."""
+    rows = [
+        "1 1:3.4 2:0.5 4:0.231",
+        "1 2:2.5 3:inf 5:0.503",
+        "2 3:2.5 2:nan 1:0.105",
+    ]
+    row_features = [
+        [0, 3.4, 0.5, 0, 0.231, 0],
+        [0, 0, 2.5, np.inf, 0, 0.503],
+        [0, 0.105, np.nan, 2.5, 0, 0],
+    ]
+    # Every row appears twice, giving an input of shape [3, 2].
+    records = [[row, row] for row in rows]
+    expected_labels = [[1, 1], [1, 1], [2, 2]]
+    expected_features = [[list(f), list(f)] for f in row_features]
+    with self.test_session() as sess:
+      sparse_features, labels = libsvm_ops.decode_libsvm(
+          records, num_features=6, label_dtype=dtypes.float64)
+      dense_features = sparse_ops.sparse_tensor_to_dense(
+          sparse_features, validate_indices=False)
+
+      self.assertAllEqual(labels.get_shape().as_list(), [3, 2])
+
+      dense_value, label_value = sess.run([dense_features, labels])
+      self.assertAllEqual(label_value, expected_labels)
+      self.assertAllClose(dense_value, expected_features)
+
+# Run the test cases when this file is executed as a script.
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py
new file mode 100644
index 0000000000..b302250563
--- /dev/null
+++ b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Libsvm decoder."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.libsvm.ops import gen_libsvm_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import resource_loader
+
+
+# Load the custom-op shared object at import time; its path is resolved
+# through resource_loader.
+_libsvm_ops_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_libsvm_ops.so"))
+
+
+def decode_libsvm(content, num_features, dtype=None, label_dtype=None):
+  """Decode Libsvm records into a sparse feature tensor and a label tensor.
+
+  Args:
+    content: A `Tensor` of type `string`. Every element is one record (row)
+      in the Libsvm format.
+    num_features: The number of features.
+    dtype: Type of the returned feature values. Defaults to tf.float32.
+    label_dtype: Type of the returned labels. Defaults to tf.int64.
+
+  Returns:
+    features: A `SparseTensor` of the shape `[input_shape, num_features]`.
+    labels: A `Tensor` with the same shape as `content`.
+  """
+  decoded = gen_libsvm_ops.decode_libsvm(
+      content, num_features, dtype=dtype, label_dtype=label_dtype)
+  # The op returns labels first, followed by the three sparse components.
+  labels, feature_indices, feature_values, feature_shape = decoded
+  features = sparse_tensor.SparseTensor(
+      feature_indices, feature_values, feature_shape)
+  return features, labels
+
+
+# The op type is registered as "DecodeLibsvm"; the string must match it
+# exactly for the no-gradient registration to apply to this op.
+ops.NotDifferentiable("DecodeLibsvm")