aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/string_to_number_op.cc
blob: 8d23a4fdf8411248b70161d354c499f5221ae87e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// See docs in ../ops/parse_ops.cc.

#include <errno.h>
#include <string>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"

namespace tensorflow {

// Prefix of the InvalidArgument error message produced when a string cannot
// be converted; the Convert() specializations below append the offending
// input string to it.
static constexpr char kErrorMessage[] =
    "StringToNumberOp could not correctly convert string: ";

template <typename OutputType>
class StringToNumberOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // This is not a deep copy of the input tensor; they will share the same
    // underlying storage.
    const Tensor* input_tensor;
    OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor));
    const auto& input_flat = input_tensor->flat<string>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output", input_tensor->shape(),
                                            &output_tensor));
    auto output_flat = output_tensor->flat<OutputType>();

    for (int i = 0; i < input_flat.size(); ++i) {
      const char* s = input_flat(i).data();
      Convert(s, &output_flat(i), context);
    }
  }

 private:
  void Convert(const char* s, OutputType* output_data,
               OpKernelContext* context);
};

// float specialization: parses `s` with strings::safe_strtof, storing the
// result in `*output_data`. If the parser reports failure, an InvalidArgument
// status (kErrorMessage followed by the input string) is set on `context`.
template <>
void StringToNumberOp<float>::Convert(const char* s, float* output_data,
                                      OpKernelContext* context) {
  const bool parsed = strings::safe_strtof(s, output_data);
  OP_REQUIRES(context, parsed, errors::InvalidArgument(kErrorMessage, s));
}

// int32 specialization: parses `s` with strings::safe_strto32, storing the
// result in `*output_data`. If the parser reports failure, an InvalidArgument
// status (kErrorMessage followed by the input string) is set on `context`.
template <>
void StringToNumberOp<int32>::Convert(const char* s, int32* output_data,
                                      OpKernelContext* context) {
  const bool parsed = strings::safe_strto32(s, output_data);
  OP_REQUIRES(context, parsed, errors::InvalidArgument(kErrorMessage, s));
}

// Register one CPU kernel for the "StringToNumber" op per supported output
// type, selected through the op's "out_type" attribute. The helper macro is
// local to this file and is undefined again right after use.
#define REGISTER_STRING_TO_NUMBER_KERNEL(T)                   \
  REGISTER_KERNEL_BUILDER(Name("StringToNumber")              \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<T>("out_type"), \
                          StringToNumberOp<T>)
REGISTER_STRING_TO_NUMBER_KERNEL(float);
REGISTER_STRING_TO_NUMBER_KERNEL(int32);
#undef REGISTER_STRING_TO_NUMBER_KERNEL

}  // namespace tensorflow