Diffstat (limited to 'tensorflow/core/kernels/constant_op.cc')
-rw-r--r--  tensorflow/core/kernels/constant_op.cc  249
1 file changed, 249 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
new file mode 100644
index 0000000000..281bafd3df
--- /dev/null
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -0,0 +1,249 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/constant_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+ConstantOp::ConstantOp(OpKernelConstruction* ctx)
+    : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+  const TensorProto* proto = nullptr;
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
+  OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto(
+                          *proto, AllocatorAttributes(), &tensor_));
+  OP_REQUIRES(
+      ctx, ctx->output_type(0) == tensor_.dtype(),
+      errors::InvalidArgument("Type mismatch between value (",
+                              DataTypeString(tensor_.dtype()), ") and dtype (",
+                              DataTypeString(ctx->output_type(0)), ")"));
+}
+
+void ConstantOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, tensor_); }
+
+ConstantOp::~ConstantOp() {}
+
+REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);
+
+#if GOOGLE_CUDA
+#define REGISTER_KERNEL(D, TYPE)                                      \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Const").Device(DEVICE_##D).TypeConstraint<TYPE>("dtype"), \
+      ConstantOp);
+REGISTER_KERNEL(GPU, float);
+REGISTER_KERNEL(GPU, double);
+REGISTER_KERNEL(GPU, uint8);
+REGISTER_KERNEL(GPU, int8);
+REGISTER_KERNEL(GPU, int16);
+REGISTER_KERNEL(GPU, int64);
+REGISTER_KERNEL(GPU, complex64);
+REGISTER_KERNEL(GPU, bool);
+// Currently we do not support string constants on GPU
+#undef REGISTER_KERNEL
+#endif
+
+// HostConstantOp differs from ConstantOp in that its output is always
+// in host memory.
+class HostConstantOp : public OpKernel {
+ public:
+  explicit HostConstantOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+    const TensorProto* proto = nullptr;
+    AllocatorAttributes alloc_attr;
+    alloc_attr.set_on_host(true);
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
+    OP_REQUIRES_OK(
+        ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr, &tensor_));
+    OP_REQUIRES(
+        ctx, ctx->output_type(0) == tensor_.dtype(),
+        errors::InvalidArgument(
+            "Type mismatch between value (", DataTypeString(tensor_.dtype()),
+            ") and dtype (", DataTypeString(ctx->output_type(0)), ")"));
+  }
+
+  void Compute(OpKernelContext* ctx) override { ctx->set_output(0, tensor_); }
+
+  bool IsExpensive() override { return false; }
+
+  ~HostConstantOp() override {}
+
+ private:
+  Tensor tensor_;
+  TF_DISALLOW_COPY_AND_ASSIGN(HostConstantOp);
+};
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Const")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("output")
+                            .TypeConstraint<int32>("dtype"),
+                        HostConstantOp);
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// Partial specialization of FillFunctor<Device=CPUDevice, T>.
+template <typename T>
+struct FillFunctor<CPUDevice, T> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+                  typename TTypes<T>::ConstScalar in) {
+    out.device(d) = out.constant(in());
+  }
+};
+
+// Partial specialization of SetZeroFunctor<Device=CPUDevice, T>.
+template <typename T>
+struct SetZeroFunctor<CPUDevice, T> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out) {
+    out.device(d) = out.constant(0);
+  }
+};
+
+#define DEFINE_SETZERO_CPU(T) template struct SetZeroFunctor<CPUDevice, T>
+DEFINE_SETZERO_CPU(float);
+DEFINE_SETZERO_CPU(double);
+DEFINE_SETZERO_CPU(int32);
+DEFINE_SETZERO_CPU(complex64);
+#undef DEFINE_SETZERO_CPU
+
+}  // end namespace functor
+
+template <typename Device, typename T>
+class FillOp : public OpKernel {
+ public:
+  explicit FillOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& Tdims = context->input(0);
+    OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(Tdims.shape()),
+                errors::InvalidArgument("dims must be a vector of int32."));
+    const Tensor& Tvalue = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(Tvalue.shape()),
+                errors::InvalidArgument("value must be a scalar."));
+    auto dims = Tdims.flat<int32>();
+    for (int i = 0; i < dims.size(); i++) {
+      OP_REQUIRES(context, dims(i) >= 0,
+                  errors::InvalidArgument("dims[", i, "] = ", dims(i),
+                                          " must be nonnegative."));
+    }
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(
+            0, TensorShapeUtils::MakeShape(
+                   reinterpret_cast<const int32*>(dims.data()), dims.size()),
+            &out));
+    functor::FillFunctor<Device, T> functor;
+    functor(context->eigen_device<Device>(), out->flat<T>(),
+            Tvalue.scalar<T>());
+  }
+};
+
+#define REGISTER_KERNEL(D, TYPE)                         \
+  REGISTER_KERNEL_BUILDER(Name("Fill")                   \
+                              .Device(DEVICE_##D)        \
+                              .TypeConstraint<TYPE>("T") \
+                              .HostMemory("dims"),       \
+                          FillOp<D##Device, TYPE>);
+
+#define REGISTER_CPU_KERNEL(TYPE) REGISTER_KERNEL(CPU, TYPE)
+TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL(GPU, float);
+REGISTER_KERNEL(GPU, double);
+REGISTER_KERNEL(GPU, uint8);
+REGISTER_KERNEL(GPU, int8);
+REGISTER_KERNEL(GPU, int16);
+REGISTER_KERNEL(GPU, int64);
+// Currently we do not support filling strings and complex64 on GPU
+
+#endif  // GOOGLE_CUDA
+
+#undef REGISTER_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Fill")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int32>("T")
+                            .HostMemory("dims")
+                            .HostMemory("value")
+                            .HostMemory("output"),
+                        FillOp<CPUDevice, int32>);
+
+template <typename Device, typename T>
+class ZerosLikeOp : public OpKernel {
+ public:
+  explicit ZerosLikeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& input = ctx->input(0);
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
+    Tensor zero(DataTypeToEnum<T>::value, {1});
+    zero.scalar<T>().setZero();
+    const Tensor& zero_cref = zero;
+    functor::FillFunctor<Device, T> functor;
+    functor(ctx->eigen_device<Device>(), out->flat<T>(), zero_cref.scalar<T>());
+  }
+};
+
+#define REGISTER_KERNEL(type, dev)                                      \
+  REGISTER_KERNEL_BUILDER(                                              \
+      Name("ZerosLike").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
+      ZerosLikeOp<dev##Device, type>)
+
+#define REGISTER_CPU(type) REGISTER_KERNEL(type, CPU)
+TF_CALL_ALL_TYPES(REGISTER_CPU);
+#undef REGISTER_CPU
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL(float, GPU);
+REGISTER_KERNEL(double, GPU);
+#endif  // GOOGLE_CUDA
+
+#undef REGISTER_KERNEL
+
+class PlaceholderOp : public OpKernel {
+ public:
+  explicit PlaceholderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    if (expected_shape_.dims() > 0) {
+      OP_REQUIRES(ctx, false,
+                  errors::InvalidArgument(
+                      "You must feed a value for placeholder tensor '", name(),
+                      "' with dtype ", DataTypeString(output_type(0)),
+                      " and shape ", expected_shape_.DebugString()));
+    } else {
+      OP_REQUIRES(ctx, false,
+                  errors::InvalidArgument(
+                      "You must feed a value for placeholder tensor '", name(),
+                      "' with dtype ", DataTypeString(output_type(0))));
+    }
+  }
+
+ private:
+  TensorShape expected_shape_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_CPU), PlaceholderOp);
+
+}  // namespace tensorflow
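
The heart of the CPU path in this commit is the FillFunctor specialization: out.device(d) = out.constant(in()) broadcasts a rank-0 scalar across the flat output on an Eigen device. Below is a minimal standalone sketch of that Eigen idiom, not part of the commit; the tensor size and fill value are arbitrary, and the TensorFlow typedefs (TTypes<T>::Flat, TTypes<T>::ConstScalar) are stood in for by plain Eigen tensors.

// Standalone sketch (not from the commit) of the Eigen fill idiom used by
// FillFunctor<CPUDevice, T>. Builds against Eigen's unsupported Tensor module.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 1> out(8);  // stands in for TTypes<T>::Flat
  Eigen::Tensor<float, 0> in;      // rank-0 tensor, like TTypes<T>::ConstScalar
  in() = 3.5f;                     // the scalar fill value
  // out.constant(in()) is a lazily evaluated expression that broadcasts the
  // scalar. The kernel evaluates it with .device(d) on a ThreadPoolDevice;
  // plain assignment here runs the same expression single-threaded.
  out = out.constant(in());
  std::cout << out << "\n";  // prints eight values of 3.5
  return 0;
}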