aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-03-29 10:50:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 10:52:55 -0700
commit63dffd5a3bc4e94e74cb140cbf7a68e0e5644ad6 (patch)
tree087e334cd7e71d44cdcd1ce5ebab483768fcf474 /tensorflow/core/kernels
parent9fbb5b3b8fef1caa2ee2ca4a0f8dde900d1f2aa5 (diff)
Automated g4 rollback of changelist 190858242
PiperOrigin-RevId: 190953197
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/mkl_fused_batch_norm_op.cc2
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h7
-rw-r--r--tensorflow/core/kernels/snapshot_op.cc30
-rw-r--r--tensorflow/core/kernels/snapshot_op.h26
-rw-r--r--tensorflow/core/kernels/snapshot_op_gpu.cu.cc9
-rw-r--r--tensorflow/core/kernels/xent_op.cc65
-rw-r--r--tensorflow/core/kernels/xent_op.h35
-rw-r--r--tensorflow/core/kernels/xent_op_gpu.cu.cc9
8 files changed, 124 insertions, 59 deletions
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 333a6570dc..62aafa7930 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -933,7 +933,7 @@ class MklFusedBatchNormOp : public OpKernel {
bool is_training_;
T* mean_values_;
T* variance_values_;
- size_t depth_; // batch normalization is done for per channel.
+ int depth_; // batch normalization is done for per channel.
void ExtractParams(OpKernelContext* context) {
const Tensor& input = MklGetInput(context, 0);
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 4abfbfb1a6..7badc00572 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -23,6 +23,13 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
diff --git a/tensorflow/core/kernels/snapshot_op.cc b/tensorflow/core/kernels/snapshot_op.cc
index 50157d5d48..fe04dcf72e 100644
--- a/tensorflow/core/kernels/snapshot_op.cc
+++ b/tensorflow/core/kernels/snapshot_op.cc
@@ -22,6 +22,26 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename Scalar>
+class SnapshotOp : public OpKernel {
+ public:
+ explicit SnapshotOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ Tensor* output = nullptr;
+ // Try to use buffer forwarding to avoid an explicit copy.
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {0}, 0, input.shape(), &output));
+ if (!output->SharesBufferWith(input)) {
+ functor::Snapshot<Device, Scalar> functor;
+ functor(context->eigen_device<Device>(), input.flat<Scalar>(),
+ output->flat<Scalar>());
+ }
+ }
+};
#define REGISTER_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
@@ -31,6 +51,16 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
TF_CALL_POD_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
+#if GOOGLE_CUDA
+#define REGISTER_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Snapshot").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+ SnapshotOp<GPUDevice, TYPE>);
+
+TF_CALL_POD_TYPES(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+#endif
+
#if TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SyclDevice;
#define REGISTER_SYCL_KERNEL(TYPE) \
diff --git a/tensorflow/core/kernels/snapshot_op.h b/tensorflow/core/kernels/snapshot_op.h
index b94834f159..a18065d42b 100644
--- a/tensorflow/core/kernels/snapshot_op.h
+++ b/tensorflow/core/kernels/snapshot_op.h
@@ -26,29 +26,19 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace functor {
+// Functor used by SnapshotOp.
template <typename Device, typename Scalar>
-class SnapshotOp : public OpKernel {
- public:
- explicit SnapshotOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- const Tensor& input = context->input(0);
- Tensor* output = nullptr;
- // Try to use buffer forwarding to avoid an explicit copy.
- OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {0}, 0, input.shape(), &output));
- if (!output->SharesBufferWith(input)) {
- // We had to allocate a new buffer since the refcount on the input was
- // greater than 1. Copy the input to the new buffer.
- const Device& device = context->eigen_device<Device>();
- device.memcpy(output->template flat<Scalar>().data(),
- input.template flat<Scalar>().data(),
- input.NumElements() * sizeof(Scalar));
- }
+struct Snapshot {
+ void operator()(const Device& device,
+ typename TTypes<Scalar>::ConstTensor input,
+ typename TTypes<Scalar>::Tensor output) {
+ device.memcpy(output.data(), input.data(), input.size() * sizeof(Scalar));
}
};
+} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_SNAPSHOT_OP_H_
diff --git a/tensorflow/core/kernels/snapshot_op_gpu.cu.cc b/tensorflow/core/kernels/snapshot_op_gpu.cu.cc
index 52070be838..e4e3bd5220 100644
--- a/tensorflow/core/kernels/snapshot_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/snapshot_op_gpu.cu.cc
@@ -24,13 +24,10 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-#define REGISTER_KERNEL(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("Snapshot").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
- SnapshotOp<GPUDevice, TYPE>);
+// Definition of the GPU implementations declared in softsign_op.cc.
+#define DEFINE_GPU_KERNELS(T) template struct functor::Snapshot<GPUDevice, T>;
-TF_CALL_POD_TYPES(REGISTER_KERNEL);
-#undef REGISTER_KERNEL
+TF_CALL_POD_TYPES(DEFINE_GPU_KERNELS);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index a6a71fdfaf..9a3612bd72 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -17,12 +17,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/xent_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/xent_op.h"
+#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -41,37 +43,56 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& logits_in = context->input(0);
const Tensor& labels_in = context->input(1);
- OP_REQUIRES(context, logits_in.IsSameSize(labels_in),
- errors::InvalidArgument(
- "logits and labels must be same size: logits_size=",
- logits_in.shape().DebugString(),
- " labels_size=", labels_in.shape().DebugString()));
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
- errors::InvalidArgument("logits must be 2-dimensional"));
- // As we already tested that both inputs have the same shape no need to
- // check that "labels" is a matrix too.
+
+ TensorShape shape_in = logits_in.shape();
+
+ BCast bcast(BCast::FromShape(logits_in.shape()),
+ BCast::FromShape(labels_in.shape()));
+ if (!logits_in.IsSameSize(labels_in)) {
+ OP_REQUIRES(context, bcast.IsValid(),
+ errors::InvalidArgument(
+ "logits and labels must be broadcastable: logits_size=",
+ logits_in.shape().DebugString(),
+ " labels_size=", labels_in.shape().DebugString()));
+ shape_in = BCast::ToShape(bcast.output_shape());
+ }
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in),
+ errors::InvalidArgument("logits and labels must be beither "
+ "2-dimensional, or roadcasted to "
+ "2-dimensional"));
// loss is 1-D (one per example), and size is batch_size.
Tensor scratch;
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({logits_in.dim_size(0), 1}),
+ TensorShape({shape_in.dim_size(0), 1}),
&scratch));
Tensor* loss_out = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
- 0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+ 0, TensorShape({shape_in.dim_size(0)}), &loss_out));
Tensor* back_out = nullptr;
// Try to reuse the logits_in buffer for the backprop output.
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {0}, 1, logits_in.shape(), &back_out));
- if (logits_in.dim_size(0) > 0) {
+ {0}, 1, shape_in, &back_out));
+ if (shape_in.dim_size(0) > 0) {
functor::XentFunctor<Device, T> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
- back_out->matrix<T>());
+ if (logits_in.IsSameSize(labels_in)) {
+ functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
+ Eigen::array<Eigen::DenseIndex, 2>{1, 1},
+ Eigen::array<Eigen::DenseIndex, 2>{1, 1}, logits_in.matrix<T>(),
+ labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ } else {
+ functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
+ BCast::ToIndexArray<2>(bcast.x_bcast()),
+ BCast::ToIndexArray<2>(bcast.y_bcast()),
+ logits_in.template shaped<T, 2>(bcast.x_reshape()),
+ labels_in.template shaped<T, 2>(bcast.y_reshape()),
+ scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>());
+ }
}
}
};
@@ -81,13 +102,17 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
namespace functor {
template <typename Device, typename T>
struct XentFunctorBase {
- void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ void operator()(const Device& d,
+ const Eigen::DSizes<Eigen::DenseIndex, 2>& shape,
+ const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast,
+ const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast,
+ typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop) {
- XentEigenImpl<Device, T>::Compute(d, logits, labels, scratch, loss,
- backprop);
+ XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast,
+ logits, labels, scratch, loss, backprop);
}
};
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h
index e689fca7ff..87be17fca9 100644
--- a/tensorflow/core/kernels/xent_op.h
+++ b/tensorflow/core/kernels/xent_op.h
@@ -18,6 +18,7 @@ limitations under the License.
// Functor definition for XentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
@@ -33,7 +34,11 @@ struct XentFunctor {
// scratch: temporary tensor, dims: batch_size, 1
// loss: output tensor for the loss, dims: batch_size.
// backprop: output tensor for the backprop, dims: batch_size, num_classes.
- void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ void operator()(const Device &d,
+ const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
+ const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
+ const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
+ typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
@@ -45,7 +50,11 @@ struct XentFunctor {
// specializations for both device types.
template <typename Device, typename T>
struct XentEigenImpl {
- static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ static void Compute(const Device &d,
+ const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
+ const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
+ const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
+ typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
@@ -57,8 +66,8 @@ struct XentEigenImpl {
const int kBatchDim = 0;
const int kClassDim = 1;
- const int batch_size = logits.dimension(kBatchDim);
- const int num_classes = logits.dimension(kClassDim);
+ const int batch_size = shape[kBatchDim];
+ const int num_classes = shape[kClassDim];
// These arrays are used to reduce along the class dimension, and broadcast
// the resulting value to all classes.
@@ -84,10 +93,12 @@ struct XentEigenImpl {
#endif
// max_logits along classes.
- scratch.reshape(batch_only).device(d) = logits.maximum(along_class);
+ scratch.reshape(batch_only).device(d) =
+ logits.broadcast(logits_bcast).maximum(along_class);
// logits - max_logits.
- backprop.device(d) = logits - scratch.broadcast(one_by_class);
+ backprop.device(d) =
+ logits.broadcast(logits_bcast) - scratch.broadcast(one_by_class);
// sum(exp(logits - max_logits)) along classes.
scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class);
@@ -99,15 +110,15 @@ struct XentEigenImpl {
// sum(-labels *
// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
// along classes
- loss.device(d) =
- (labels * (scratch.log().eval().broadcast(one_by_class) - backprop))
- .eval()
- .sum(along_class);
+ loss.device(d) = (labels.broadcast(labels_bcast) *
+ (scratch.log().eval().broadcast(one_by_class) - backprop))
+ .eval()
+ .sum(along_class);
// backprop: prob - labels, where
// prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
- backprop.device(d) =
- (backprop.exp() / scratch.broadcast(one_by_class)) - labels;
+ backprop.device(d) = (backprop.exp() / scratch.broadcast(one_by_class)) -
+ labels.broadcast(labels_bcast);
}
};
diff --git a/tensorflow/core/kernels/xent_op_gpu.cu.cc b/tensorflow/core/kernels/xent_op_gpu.cu.cc
index 05ee7da490..2c0c0b3a02 100644
--- a/tensorflow/core/kernels/xent_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/xent_op_gpu.cu.cc
@@ -31,12 +31,17 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
template <typename T>
struct XentFunctor<GPUDevice, T> {
- void operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+ void operator()(const GPUDevice &d,
+ const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
+ const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
+ const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
+ typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop) {
- XentEigenImpl<GPUDevice, T>::Compute(d, logits, labels, scratch, loss,
+ XentEigenImpl<GPUDevice, T>::Compute(d, shape, logits_bcast, labels_bcast,
+ logits, labels, scratch, loss,
backprop);
}
};