path: root/tensorflow/core/kernels
author    Xiaoqiang Zheng <zhengxq@google.com>            2016-10-28 10:29:28 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-10-28 11:38:26 -0700
commit    e2d51a87f0727f8537b46048d8241aeebb6e48d6 (patch)
tree      64c075f59bae00706a009e5d1ed15aaff6adc6ff /tensorflow/core/kernels
parent    f80ef2d696456c970956f47e7d5aa88bc7ccbdce (diff)
Merge changes from github.
Change: 137532946
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_select.cu.cc      |  24
-rw-r--r--  tensorflow/core/kernels/cwise_op_select.cc             |  35
-rw-r--r--  tensorflow/core/kernels/cwise_ops.h                    |   8
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc    |   4
-rw-r--r--  tensorflow/core/kernels/matrix_triangular_solve_op.cc  | 162
-rw-r--r--  tensorflow/core/kernels/range_sampler.h                |   2
6 files changed, 234 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
index d71fdac497..a54dbdfc24 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
@@ -16,6 +16,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace functor {
@@ -32,6 +33,28 @@ struct SelectFunctor<GPUDevice, T> {
};
template <typename T>
+struct SelectScalarFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<bool>::ConstScalar cond,
+ typename TTypes<T>::ConstFlat then_flat,
+ typename TTypes<T>::ConstFlat else_flat) {
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::array<int, 1> rank1{1};
+#else
+ Eigen::IndexList<Eigen::type2index<1>> rank1;
+#endif
+ const int size = then_flat.dimension(0);
+ Eigen::array<int, 1> broadcast_dims{size};
+
+ To32Bit(out).device(d) = cond.reshape(rank1)
+ .broadcast(broadcast_dims)
+ .select(then_flat, else_flat);
+
+ }
+};
+
+template <typename T>
struct BatchSelectFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d,
typename TTypes<T>::Matrix output_flat_outer_dims,
@@ -68,6 +91,7 @@ struct BatchSelectFunctor<GPUDevice, T> {
#define SELECT_FUNCTOR(T) \
template struct SelectFunctor<GPUDevice, T>; \
+ template struct SelectScalarFunctor<GPUDevice, T>; \
template struct BatchSelectFunctor<GPUDevice, T>;
SELECT_FUNCTOR(Eigen::half);
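
Note: the added SelectScalarFunctor broadcasts a rank-0 condition across the flat output before selecting elementwise. A minimal host-side sketch of that reshape/broadcast/select chain with plain Eigen tensors follows; it is illustrative only (default device, no To32Bit), not code from this commit.

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main() {
      // Rank-0 condition, analogous to TTypes<bool>::ConstScalar.
      Eigen::Tensor<bool, 0> cond;
      cond() = true;

      Eigen::Tensor<float, 1> then_flat(4), else_flat(4);
      then_flat.setConstant(1.f);
      else_flat.setConstant(2.f);

      // Reshape the scalar to rank 1, broadcast it to the element count,
      // then select elementwise -- the same chain the GPU functor assigns
      // through To32Bit(out).device(d).
      Eigen::array<Eigen::Index, 1> rank1{1};
      Eigen::array<Eigen::Index, 1> bcast{4};
      Eigen::Tensor<float, 1> out =
          cond.reshape(rank1).broadcast(bcast).select(then_flat, else_flat);

      std::cout << out << "\n";  // 1 1 1 1
    }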
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index fbfde88e61..8160fb74c2 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -41,6 +41,11 @@ class SelectOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+ if (TensorShapeUtils::IsScalar(cond->shape())) {
+ ComputeScalar(ctx, cond, then, else_);
+ return;
+ }
+
bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
!TensorShapeUtils::IsVector(then->shape()));
@@ -108,6 +113,25 @@ class SelectOp : public OpKernel {
}
}
+ void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
+ const Tensor* then, const Tensor* else_) {
+ OP_REQUIRES(
+ ctx, then->shape().IsSameSize(else_->shape()),
+ errors::InvalidArgument(
+ "'then' and 'else' must have the same size. but received: ",
+ then->shape().DebugString(), " vs. ",
+ else_->shape().DebugString()));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+
+ if (output->NumElements() > 0) {
+ functor::SelectScalarFunctor<Device, T> func;
+ TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
+ func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
+ then->flat<T>(), else_->flat<T>());
+ }
+ }
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
@@ -152,6 +176,17 @@ struct SelectFunctor<CPUDevice, T> {
}
};
+// CPU specialization of the Select functor for a scalar condition.
+template <typename T>
+struct SelectScalarFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ TTypes<bool>::ConstScalar cond,
+ typename TTypes<T>::ConstFlat then_flat,
+ typename TTypes<T>::ConstFlat else_flat) {
+ out.device(d) = cond() ? then_flat : else_flat;
+ }
+};
+
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
void operator()(const CPUDevice& d,
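
Note: on the CPU path the scalar condition never needs broadcasting: out.device(d) = cond() ? then_flat : else_flat; reads the bool once and assigns whichever source expression was chosen. A standalone sketch of that device-assignment pattern on an Eigen thread-pool device (pool size and names here are illustrative assumptions, not from this commit):

    #define EIGEN_USE_THREADS
    #include <unsupported/Eigen/CXX11/Tensor>
    #include <unsupported/Eigen/CXX11/ThreadPool>

    int main() {
      Eigen::ThreadPool pool(4);
      Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

      Eigen::Tensor<float, 1> then_flat(8), else_flat(8), out(8);
      then_flat.setConstant(1.f);
      else_flat.setConstant(-1.f);
      bool cond = false;

      // The ternary picks one operand expression up front; only that
      // branch is then evaluated, in parallel, by the device assignment.
      out.device(device) = cond ? then_flat : else_flat;
      return 0;
    }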
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 2a77376a42..572a729b34 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -720,6 +720,14 @@ struct SelectFunctor {
};
template <typename Device, typename T>
+struct SelectScalarFunctor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<bool>::ConstScalar cond,
+ typename TTypes<T>::ConstFlat then_flat,
+ typename TTypes<T>::ConstFlat else_flat);
+};
+
+template <typename Device, typename T>
struct BatchSelectFunctor {
void operator()(const Device& d,
typename TTypes<T>::Matrix output_flat_outer_dims,
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 0acf82c9de..b256d24517 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -21,7 +21,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
+#if !defined(_MSC_VER)
#define UNROLL _Pragma("unroll")
+#else
+#define UNROLL
+#endif
namespace tensorflow {
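
Note: _Pragma("unroll") is the operator form of #pragma unroll, which MSVC does not accept, so the guard turns UNROLL into a no-op there instead of breaking the build. A hypothetical use of the macro on a fixed-trip-count loop (not code from this diff; compilers without unroll support may still emit unknown-pragma warnings):

    #if !defined(_MSC_VER)
    #define UNROLL _Pragma("unroll")
    #else
    #define UNROLL
    #endif

    // Hypothetical helper: asks the compiler to unroll the four iterations;
    // under MSVC the macro expands to nothing and the loop compiles as-is.
    template <typename T>
    void AxpyFixed4(const T* x, T* y, T alpha) {
      UNROLL for (int i = 0; i < 4; ++i) {
        y[i] += alpha * x[i];
      }
    }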
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
index 09f75f2d5f..5f30a95108 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
@@ -25,8 +25,25 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+#endif // GOOGLE_CUDA
+
namespace tensorflow {
+#if GOOGLE_CUDA
+namespace {
+template <typename Scalar>
+perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
+ const Scalar* cuda_memory) {
+ perftools::gputools::DeviceMemoryBase wrapped(
+ const_cast<Scalar*>(cuda_memory));
+ perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
+ return typed;
+}
+} // namespace
+#endif // GOOGLE_CUDA
+
template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
public:
@@ -60,7 +77,9 @@ class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
- double cost = rows * rows * num_rhss;
+ double cost = rows * rows * num_rhss *
+ (Eigen::TensorOpCost::AddCost<Scalar>() +
+ Eigen::TensorOpCost::MulCost<Scalar>());
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
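
Note: for scale, a 1000 x 1000 triangular matrix with 10 right-hand sides and unit add/mul costs (Eigen's defaults for float) gives 1000 * 1000 * 10 * (1 + 1) = 2e10, the per-matrix estimate the LinearAlgebraOp base class uses when sharding batched solves across threads.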
@@ -103,6 +122,121 @@ class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};
+
+#ifdef GOOGLE_CUDA
+template <class Scalar>
+class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
+ public:
+ typedef LinearAlgebraOp<Scalar> Base;
+
+ explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
+ : Base(context), lower_(true), adjoint_(false) {
+ OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
+ OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
+ }
+
+ using TensorShapes = typename Base::TensorShapes;
+ using Matrix = typename Base::Matrix;
+ using MatrixMap = typename Base::MatrixMap;
+ using MatrixMaps = typename Base::MatrixMaps;
+ using ConstMatrixMap = typename Base::ConstMatrixMap;
+ using ConstMatrixMaps = typename Base::ConstMatrixMaps;
+
+ virtual void ValidateInputMatrixShapes(
+ OpKernelContext* context,
+ const TensorShapes& input_matrix_shapes) const final {
+ Base::ValidateSquareSolver(context, input_matrix_shapes);
+ }
+
+ TensorShapes GetOutputMatrixShapes(
+ const TensorShapes& input_matrix_shapes) const final {
+ return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
+ input_matrix_shapes[1].dim_size(1)})});
+ }
+
+ int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
+ double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
+ double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
+ double cost = rows * rows * num_rhss *
+ (Eigen::TensorOpCost::AddCost<Scalar>() +
+ Eigen::TensorOpCost::MulCost<Scalar>());
+ return cost >= static_cast<double>(kint64max) ? kint64max
+ : static_cast<int64>(cost);
+ }
+
+ void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
+ MatrixMaps* outputs) final {
+ const ConstMatrixMap& matrix = inputs[0];
+ const ConstMatrixMap& rhs = inputs[1];
+ MatrixMap& output = outputs->at(0);
+
+ if (matrix.rows() == 0 || rhs.cols() == 0) {
+ // To be consistent with the MatrixInverse op, we define the solution for
+ // an empty set of equations as the empty matrix.
+ return;
+ }
+
+ auto matrix_ptr = AsDeviceMemory(matrix.data());
+ auto rhs_ptr = AsDeviceMemory(rhs.data());
+ auto out_ptr = AsDeviceMemory(output.data());
+
+ auto* stream = context->op_device_context()->stream();
+ uint64 rhs_elems = rhs.rows() * rhs.cols();
+ bool copy_status =
+ stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
+ .ok();
+ if (!copy_status) {
+ context->SetStatus(
+ errors::Internal("Failed to copy rhs into output before solve"));
+ }
+
+ // cuBLAS solves
+ //   output = matrix \ rhs
+ // where matrix, rhs, and output are assumed to be column-major.
+ // Since our buffers are row-major, we instead compute
+ //   output' = rhs' / matrix'  (' denotes transpose),
+ // which requires swapping the upper/lower flag.
+
+ perftools::gputools::blas::UpperLower upper_lower_matrix;
+ perftools::gputools::blas::Transpose transpose_matrix;
+ if (lower_) {
+ upper_lower_matrix = perftools::gputools::blas::UpperLower::kUpper;
+ } else {
+ upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
+ }
+ if (adjoint_) {
+ transpose_matrix = perftools::gputools::blas::Transpose::kTranspose;
+ } else {
+ transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
+ }
+ uint64 leading_dim_matrix = matrix.cols();
+ uint64 leading_dim_output = output.cols();
+ uint64 colmajor_rows = output.cols();
+ uint64 colmajor_cols = output.rows();
+ bool blas_launch_status =
+ stream
+ ->ThenBlasTrsm(perftools::gputools::blas::Side::kRight /*side*/,
+ upper_lower_matrix /*uplo*/,
+ transpose_matrix /*trans*/,
+ perftools::gputools::blas::Diagonal::kNonUnit /*diag*/,
+ colmajor_rows /*m*/, colmajor_cols /*n*/,
+ Scalar(1.0) /*alpha*/,
+ matrix_ptr, leading_dim_matrix /*lda*/,
+ &out_ptr, leading_dim_output /*ldb*/)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal("Blas TRSM launch failed"));
+ }
+ }
+
+ private:
+ bool lower_;
+ bool adjoint_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
+};
+#endif // GOOGLE_CUDA
+
REGISTER_LINALG_OP("MatrixTriangularSolve", (MatrixTriangularSolveOp<float>),
float);
REGISTER_LINALG_OP("MatrixTriangularSolve", (MatrixTriangularSolveOp<double>),
@@ -112,4 +246,30 @@ REGISTER_LINALG_OP("BatchMatrixTriangularSolve",
REGISTER_LINALG_OP("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<double>), double);
+#ifdef GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MatrixTriangularSolve")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ MatrixTriangularSolveOpGPU<float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("MatrixTriangularSolve")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T"),
+ MatrixTriangularSolveOpGPU<double>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("BatchMatrixTriangularSolve")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ MatrixTriangularSolveOpGPU<float>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("BatchMatrixTriangularSolve")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T"),
+ MatrixTriangularSolveOpGPU<double>);
+#endif  // GOOGLE_CUDA
+
} // namespace tensorflow
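
Note: the upper/lower swap in ComputeMatrix follows from the identity A X = B <=> X' A' = B': a row-major buffer reinterpreted as column-major holds the transpose, and the transpose of a lower-triangular matrix is upper-triangular. A standalone check of that identity with Eigen's dense triangular solver (illustrative sketch only, not part of this commit):

    #include <Eigen/Dense>
    #include <iostream>

    int main() {
      // A lower-triangular system L * X = B.
      Eigen::Matrix3d A = Eigen::Matrix3d::Random();
      Eigen::Matrix3d L = A.triangularView<Eigen::Lower>();
      L.diagonal().array() += 3.0;  // keep the system well-conditioned
      Eigen::Matrix3d B = Eigen::Matrix3d::Random();

      Eigen::Matrix3d X1 = L.triangularView<Eigen::Lower>().solve(B);

      // Right-sided solve of the transposed system X' * L' = B'. Note that
      // L' is upper-triangular -- the reason the kernel swaps kLower/kUpper.
      Eigen::Matrix3d X2t = L.transpose()
                                .triangularView<Eigen::Upper>()
                                .solve<Eigen::OnTheRight>(B.transpose());

      std::cout << (X1 - X2t.transpose()).norm() << "\n";  // ~0
    }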
diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h
index 1372975dc9..3010666598 100644
--- a/tensorflow/core/kernels/range_sampler.h
+++ b/tensorflow/core/kernels/range_sampler.h
@@ -115,10 +115,12 @@ class AllSampler : public RangeSampler {
int64 Sample(random::SimplePhilox* rnd) const override {
LOG(FATAL) << "Should not be called";
+ return 0;
}
float Probability(int64 value) const override {
LOG(FATAL) << "Should not be called";
+ return 0;
}
void SampleBatchGetExpectedCountAvoid(
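
Note: the added return statements above are unreachable, since LOG(FATAL) aborts the process; they exist to silence compilers that warn about a control path missing a return value.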