aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/libsvm/ops
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-22 12:42:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-22 12:46:28 -0800
commite4532d20973c4c00854492362665317551661c18 (patch)
tree398527e29bd30d39237adb4785be5069fdb646fa /tensorflow/contrib/libsvm/ops
parent673641c2d6a27fa97ee05453d671853731a4c602 (diff)
Merge changes from github.
PiperOrigin-RevId: 179953488
Diffstat (limited to 'tensorflow/contrib/libsvm/ops')
-rw-r--r--tensorflow/contrib/libsvm/ops/libsvm_ops.cc58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/contrib/libsvm/ops/libsvm_ops.cc b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc
new file mode 100644
index 0000000000..dec946189e
--- /dev/null
+++ b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+REGISTER_OP("DecodeLibsvm")
+ .Input("input: string")
+ .Output("label: label_dtype")
+ .Output("feature_indices: int64")
+ .Output("feature_values: dtype")
+ .Output("feature_shape: int64")
+ .Attr("dtype: {float, double, int32, int64} = DT_FLOAT")
+ .Attr("label_dtype: {float, double, int32, int64} = DT_INT64")
+ .Attr("num_features: int >= 1")
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+
+ c->set_output(1, c->Matrix(InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim));
+ c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
+ c->set_output(3, c->Vector(InferenceContext::kUnknownDim));
+
+ return Status::OK();
+ })
+
+ .Doc(R"doc(
+Convert LibSVM input to tensors. The output consists of
+a label and a feature tensor. The shape of the label tensor
+is the same as input and the shape of the feature tensor is
+`[input_shape, num_features]`.
+
+input: Each string is a record in the LibSVM.
+label: A tensor of the same shape as input.
+feature_indices: A 2-D int64 tensor of dense_shape [N, ndims].
+feature_values: A 1-D tensor of any type and dense_shape [N].
+feature_shape: A 1-D int64 tensor of dense_shape [ndims].
+num_features: The number of features.
+)doc");
+
+} // namespace tensorflow