aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.gitmodules2
-rw-r--r--WORKSPACE4
-rw-r--r--eigen.BUILD2
-rw-r--r--tensorflow/core/framework/tensor_shape.cc4
-rw-r--r--tensorflow/core/graph/equal_graph_def_test.cc56
-rw-r--r--tensorflow/core/graph/graph_partition_test.cc36
-rw-r--r--tensorflow/core/kernels/cross_op.cc112
-rw-r--r--tensorflow/core/kernels/cross_op.h54
-rw-r--r--tensorflow/core/kernels/cross_op_gpu.cu.cc34
-rw-r--r--tensorflow/core/kernels/cross_op_test.cc102
-rw-r--r--tensorflow/core/kernels/transpose_op.cc6
-rw-r--r--tensorflow/core/ops/math_ops.cc19
-rw-r--r--tensorflow/core/ops/ops.pbtxt36
-rw-r--r--tensorflow/python/kernel_tests/cross_grad_test.py41
-rw-r--r--tensorflow/python/kernel_tests/seq2seq_test.py2
-rw-r--r--tensorflow/python/ops/array_ops.py4
-rw-r--r--tensorflow/python/ops/math_grad.py7
-rw-r--r--tensorflow/python/ops/math_ops.py2
-rw-r--r--third_party/eigen3/Eigen/Cholesky2
-rw-r--r--third_party/eigen3/Eigen/Core2
-rw-r--r--third_party/eigen3/Eigen/Eigenvalues3
-rw-r--r--third_party/eigen3/Eigen/LU2
-rw-r--r--third_party/eigen3/Eigen/QR2
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/Tensor3
24 files changed, 477 insertions, 60 deletions
diff --git a/.gitmodules b/.gitmodules
index 1b17ea57b9..0edca21239 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
[submodule "google/protobuf"]
path = google/protobuf
- url = https://github.com/google/protobuf.git
+ url = https://github.googlesource.com/google/protobuf.git
diff --git a/WORKSPACE b/WORKSPACE
index 0ca6ad8e58..a52fdf8345 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -21,8 +21,8 @@ new_http_archive(
new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/8cd7c2c.tar.gz",
- sha256 = "30b77010c49a28875c76f5941cab06d0e15c52dc193be9729def53b6ea1fdb57",
+ url = "https://bitbucket.org/eigen/eigen/get/726c779.tar.gz",
+ sha256 = "30e0c5d84cfefc6a0bf7ae1e682b22788b5b2e408e7db7d9ea2d2aa9f70a72a9",
build_file = "eigen.BUILD",
)
diff --git a/eigen.BUILD b/eigen.BUILD
index bb78603b7a..084689c6f4 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-8cd7c2c6e9e1"
+archive_dir = "eigen-eigen-726c779797e8"
cc_library(
name = "eigen",
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
index 697e170729..a6d540bbe1 100644
--- a/tensorflow/core/framework/tensor_shape.cc
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -143,7 +143,7 @@ void TensorShape::AddDim(int64 size) {
// to allow REP32.
bool can_be_rep32 = (vals.size() <= 3);
if (can_be_rep32) {
- for (int i = 0; i < vals.size(); i++) {
+ for (size_t i = 0; i < vals.size(); i++) {
if (vals[i] >= kMaxRep32) {
can_be_rep32 = false;
break;
@@ -152,7 +152,7 @@ void TensorShape::AddDim(int64 size) {
}
if (can_be_rep32) {
set_tag(REP32);
- for (int d = 0; d < vals.size(); d++) {
+ for (size_t d = 0; d < vals.size(); d++) {
as32()->dims_[d] = static_cast<int32>(vals[d]);
}
} else {
diff --git a/tensorflow/core/graph/equal_graph_def_test.cc b/tensorflow/core/graph/equal_graph_def_test.cc
index e7cbb9853f..77ff0d9bce 100644
--- a/tensorflow/core/graph/equal_graph_def_test.cc
+++ b/tensorflow/core/graph/equal_graph_def_test.cc
@@ -27,7 +27,7 @@ namespace {
REGISTER_OP("Input").Output("o: float");
REGISTER_OP("Alternate").Output("o: float");
-REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("Input", opts);
@@ -37,9 +37,9 @@ Node* Alternate(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("Alternate", opts);
}
-Node* Cross(ops::NodeOut a, ops::NodeOut b,
- const GraphDefBuilder::Options& opts) {
- return ops::BinaryOp("Cross", a, b, opts);
+Node* Combine(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("Combine", a, b, opts);
}
class EqualGraphDefTest : public ::testing::Test {
@@ -99,11 +99,11 @@ TEST_F(EqualGraphDefTest, ExtraNode) {
TEST_F(EqualGraphDefTest, NodeOrder) {
Node* a = Input(e_.opts().WithName("A"));
Node* b = Input(e_.opts().WithName("B"));
- Cross(a, b, e_.opts().WithName("C"));
+ Combine(a, b, e_.opts().WithName("C"));
b = Input(a_.opts().WithName("B"));
a = Input(a_.opts().WithName("A"));
- Cross(a, b, a_.opts().WithName("C"));
+ Combine(a, b, a_.opts().WithName("C"));
EXPECT_TRUE(Match()) << diff_;
}
@@ -141,11 +141,11 @@ TEST_F(EqualGraphDefTest, DeviceMismatch) {
TEST_F(EqualGraphDefTest, InputMismatch) {
Node* a = Input(e_.opts().WithName("A"));
Node* b = Input(e_.opts().WithName("B"));
- Cross(a, a, e_.opts().WithName("C"));
+ Combine(a, a, e_.opts().WithName("C"));
a = Input(a_.opts().WithName("A"));
b = Input(a_.opts().WithName("B"));
- Cross(b, b, a_.opts().WithName("C"));
+ Combine(b, b, a_.opts().WithName("C"));
EXPECT_FALSE(Match());
EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'",
diff_);
@@ -154,11 +154,11 @@ TEST_F(EqualGraphDefTest, InputMismatch) {
TEST_F(EqualGraphDefTest, InputOrderMismatch) {
Node* a = Input(e_.opts().WithName("A"));
Node* b = Input(e_.opts().WithName("B"));
- Cross(a, b, e_.opts().WithName("C"));
+ Combine(a, b, e_.opts().WithName("C"));
a = Input(a_.opts().WithName("A"));
b = Input(a_.opts().WithName("B"));
- Cross(b, a, a_.opts().WithName("C"));
+ Combine(b, a, a_.opts().WithName("C"));
EXPECT_FALSE(Match());
EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'",
diff_);
@@ -169,21 +169,21 @@ TEST_F(EqualGraphDefTest, ControlInputOrder) {
Node* b = Input(e_.opts().WithName("B"));
Node* c = Input(e_.opts().WithName("C"));
Node* d = Input(e_.opts().WithName("D"));
- Cross(a, a, e_.opts()
- .WithName("E")
- .WithControlInput(b)
- .WithControlInput(c)
- .WithControlInput(d));
+ Combine(a, a, e_.opts()
+ .WithName("E")
+ .WithControlInput(b)
+ .WithControlInput(c)
+ .WithControlInput(d));
a = Input(a_.opts().WithName("A"));
b = Input(a_.opts().WithName("B"));
c = Input(a_.opts().WithName("C"));
d = Input(a_.opts().WithName("D"));
- Cross(a, a, a_.opts()
- .WithName("E")
- .WithControlInput(c)
- .WithControlInput(d)
- .WithControlInput(b));
+ Combine(a, a, a_.opts()
+ .WithName("E")
+ .WithControlInput(c)
+ .WithControlInput(d)
+ .WithControlInput(b));
EXPECT_TRUE(Match()) << diff_;
}
@@ -192,13 +192,15 @@ TEST_F(EqualGraphDefTest, ControlInputMismatch) {
Node* b = Input(e_.opts().WithName("B"));
Node* c = Input(e_.opts().WithName("C"));
Node* d = Input(e_.opts().WithName("D"));
- Cross(a, a, e_.opts().WithName("E").WithControlInput(b).WithControlInput(c));
+ Combine(a, a,
+ e_.opts().WithName("E").WithControlInput(b).WithControlInput(c));
a = Input(a_.opts().WithName("A"));
b = Input(a_.opts().WithName("B"));
c = Input(a_.opts().WithName("C"));
d = Input(a_.opts().WithName("D"));
- Cross(a, a, a_.opts().WithName("E").WithControlInput(b).WithControlInput(d));
+ Combine(a, a,
+ a_.opts().WithName("E").WithControlInput(b).WithControlInput(d));
EXPECT_FALSE(Match());
EXPECT_EQ("Node named 'E' missing expected control input '^C'", diff_);
}
@@ -207,12 +209,13 @@ TEST_F(EqualGraphDefTest, ControlInputAdded) {
Node* a = Input(e_.opts().WithName("A"));
Node* b = Input(e_.opts().WithName("B"));
Node* c = Input(e_.opts().WithName("C"));
- Cross(a, a, e_.opts().WithName("D").WithControlInput(b));
+ Combine(a, a, e_.opts().WithName("D").WithControlInput(b));
a = Input(a_.opts().WithName("A"));
b = Input(a_.opts().WithName("B"));
c = Input(a_.opts().WithName("C"));
- Cross(a, a, a_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
+ Combine(a, a,
+ a_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
EXPECT_FALSE(Match());
EXPECT_EQ(
"Node named 'D' has inputs 'A, A, ^B, ^C' that don't match "
@@ -224,12 +227,13 @@ TEST_F(EqualGraphDefTest, ControlInputRemoved) {
Node* a = Input(e_.opts().WithName("A"));
Node* b = Input(e_.opts().WithName("B"));
Node* c = Input(e_.opts().WithName("C"));
- Cross(a, a, e_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
+ Combine(a, a,
+ e_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
a = Input(a_.opts().WithName("A"));
b = Input(a_.opts().WithName("B"));
c = Input(a_.opts().WithName("C"));
- Cross(a, a, a_.opts().WithName("D").WithControlInput(b));
+ Combine(a, a, a_.opts().WithName("D").WithControlInput(b));
EXPECT_FALSE(Match());
EXPECT_EQ(
"Node named 'D' has inputs 'A, A, ^B' that don't match "
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index 4bbfc5fe5c..ecdf9ed74f 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -131,7 +131,7 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
REGISTER_OP("Input").Output("o: float");
REGISTER_OP("BoolInput").Output("o: bool");
-REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("Input", opts);
@@ -141,9 +141,9 @@ Node* BoolInput(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("BoolInput", opts);
}
-Node* Cross(ops::NodeOut a, ops::NodeOut b,
- const GraphDefBuilder::Options& opts) {
- return ops::BinaryOp("Cross", a, b, opts);
+Node* Combine(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("Combine", a, b, opts);
}
class GraphPartitionTest : public ::testing::Test {
@@ -188,13 +188,13 @@ class GraphPartitionTest : public ::testing::Test {
TEST_F(GraphPartitionTest, SingleDevice) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
- Cross(a1, a1, in_.opts().WithName("A2"));
+ Combine(a1, a1, in_.opts().WithName("A2"));
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(1, partitions_.size());
a1 = Input(a_opts_.WithName("A1"));
- Cross(a1, a1, a_opts_.WithName("A2"));
+ Combine(a1, a1, a_opts_.WithName("A2"));
ExpectMatchA();
}
@@ -202,7 +202,7 @@ TEST_F(GraphPartitionTest, CrossDeviceData) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
- Cross(a1, b1, in_.opts().WithName("B2"));
+ Combine(a1, b1, in_.opts().WithName("B2"));
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
@@ -216,7 +216,7 @@ TEST_F(GraphPartitionTest, CrossDeviceData) {
b1 = Input(b_opts_.WithName("B1"));
Node* recv =
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
- Cross(recv, b1, b_opts_.WithName("B2"));
+ Combine(recv, b1, b_opts_.WithName("B2"));
ExpectMatchB();
}
@@ -224,7 +224,7 @@ TEST_F(GraphPartitionTest, CrossDeviceControl) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
- Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
+ Combine(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
@@ -240,7 +240,7 @@ TEST_F(GraphPartitionTest, CrossDeviceControl) {
_Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2"));
Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
b1 = Input(b_opts_.WithName("B1"));
- Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
+ Combine(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
ExpectMatchB();
}
@@ -248,8 +248,8 @@ TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
- Cross(a1, b1, in_.opts().WithName("B2"));
- Cross(a1, a1, in_.opts().WithName("B3"));
+ Combine(a1, b1, in_.opts().WithName("B2"));
+ Combine(a1, a1, in_.opts().WithName("B3"));
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
@@ -263,8 +263,8 @@ TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
Node* recv =
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
b1 = Input(b_opts_.WithName("B1"));
- Cross(recv, b1, b_opts_.WithName("B2"));
- Cross(recv, recv, b_opts_.WithName("B3"));
+ Combine(recv, b1, b_opts_.WithName("B2"));
+ Combine(recv, recv, b_opts_.WithName("B3"));
ExpectMatchB();
}
@@ -272,7 +272,7 @@ TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
- Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
+ Combine(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
Input(in_.opts().WithName("B3").WithControlInput(a1));
Partition(ToGraphDef(), &partitions_);
@@ -289,7 +289,7 @@ TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
b1 = Input(b_opts_.WithName("B1"));
- Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
+ Combine(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
Input(b_opts_.WithName("B3").WithControlInput(id));
ExpectMatchB();
}
@@ -298,7 +298,7 @@ TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
- Cross(a1, b1, in_.opts().WithName("B2"));
+ Combine(a1, b1, in_.opts().WithName("B2"));
Input(in_.opts().WithName("B3").WithControlInput(a1));
Partition(ToGraphDef(), &partitions_);
@@ -320,7 +320,7 @@ TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
Node* recv2 =
_Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5"));
b1 = Input(b_opts_.WithName("B1"));
- Cross(recv2, b1, b_opts_.WithName("B2"));
+ Combine(recv2, b1, b_opts_.WithName("B2"));
Input(b_opts_.WithName("B3").WithControlInput(id1));
ExpectMatchB();
}
diff --git a/tensorflow/core/kernels/cross_op.cc b/tensorflow/core/kernels/cross_op.cc
new file mode 100644
index 0000000000..4564940d0f
--- /dev/null
+++ b/tensorflow/core/kernels/cross_op.cc
@@ -0,0 +1,112 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/math_ops.cc.
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <cmath>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/cross_op.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// OpKernel for the "Cross" op: computes the pairwise cross product of
+// corresponding 3-element vectors in two tensors of identical shape.
+//
+// Device selects the Eigen device the functor runs on; Type is the element
+// type of both inputs and of the output.
+template <typename Device, typename Type>
+class CrossOp : public OpKernel {
+ public:
+  explicit CrossOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  // Validates the two inputs, allocates an output of the same shape and
+  // delegates the arithmetic to functor::Cross.
+  //
+  // Reports InvalidArgument if the input shapes differ or the inputs have
+  // rank 0, and FailedPrecondition if the innermost dimension is not 3.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& in0 = context->input(0);
+    const Tensor& in1 = context->input(1);
+    OP_REQUIRES(context, in0.shape() == in1.shape(),
+                errors::InvalidArgument("Both inputs must be of same shape: ",
+                                        in0.shape().DebugString(), " vs. ",
+                                        in1.shape().DebugString()));
+    // The message ends with ": " so the shape is not glued onto "1D".
+    OP_REQUIRES(context, in0.dims() >= 1,
+                errors::InvalidArgument("Input must be at least 1D: ",
+                                        in0.shape().DebugString()));
+
+    // Cross-products only really make sense for three and
+    // seven dimensions, and the latter is very obscure. If there is
+    // demand, we could perhaps allow 2D vectors where the last
+    // element is taken to be zero, but for now, we simply require
+    // that all are 3D.
+    auto inner_dim = in0.dim_size(in0.dims() - 1);
+    OP_REQUIRES(context, inner_dim == 3,
+                errors::FailedPrecondition(
+                    "Cross-products are only defined for 3-element vectors."));
+
+    // Create the output Tensor with the same dimensions as the input Tensors.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, in0.shape(), &output));
+
+    // Make a canonical tensor, maintaining the last (3-vector) dimension,
+    // while flattening all others to give the functor easy-to-work-with data.
+    typename TTypes<Type, 2>::ConstTensor in0_data =
+        in0.flat_inner_dims<Type>();
+    typename TTypes<Type, 2>::ConstTensor in1_data =
+        in1.flat_inner_dims<Type>();
+    typename TTypes<Type, 2>::Tensor output_data =
+        output->flat_inner_dims<Type>();
+
+    functor::Cross<Device, Type>()(context->eigen_device<Device>(), in0_data,
+                                   in1_data, output_data);
+  }
+};
+
+// Registers CrossOp<CPUDevice, type> as the DEVICE_CPU kernel for "Cross",
+// once for every type expanded by TF_CALL_REAL_NUMBER_TYPES.
+#define REGISTER_CPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cross").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ CrossOp<CPUDevice, type>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the function specializations for GPU (to prevent
+// building the GPU versions here, they will be built compiling _gpu.cu.cc).
+namespace functor {
+#define DECLARE_GPU_KERNEL(type) \
+ template <> \
+ void Cross<GPUDevice, type>::operator()( \
+ const GPUDevice& d, TTypes<type, 2>::ConstTensor in0_data, \
+ TTypes<type, 2>::ConstTensor in1_data, \
+ TTypes<type, 2>::Tensor output_data); \
+ extern template struct Cross<GPUDevice, type>;
+TF_CALL_REAL_NUMBER_TYPES(DECLARE_GPU_KERNEL);
+#undef DECLARE_GPU_KERNEL
+} // namespace functor
+// Registers CrossOp<GPUDevice, type> as the DEVICE_GPU kernel for "Cross",
+// for the same set of types as the CPU registration above.
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Cross").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ CrossOp<GPUDevice, type>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cross_op.h b/tensorflow/core/kernels/cross_op.h
new file mode 100644
index 0000000000..bac01aeaa7
--- /dev/null
+++ b/tensorflow/core/kernels/cross_op.h
@@ -0,0 +1,54 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_CROSS_OP_H_
+#define TENSORFLOW_KERNELS_CROSS_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+// Device-generic functor computing a batch of 3-element cross products.
+//
+// All three tensors are rank 2. Only indices 0, 1 and 2 of the second
+// dimension are accessed, so callers must pass tensors whose second
+// dimension has size 3. For every index i along the first dimension:
+//
+//   output_data(i, :) = in0_data(i, :) x in1_data(i, :)
+//
+// Arguments:
+//   d:           Eigen device on which the three assignments are evaluated.
+//   in0_data:    left-hand operands (u).
+//   in1_data:    right-hand operands (v).
+//   output_data: destination for the products (s).
+//
+// NOTE(review): the output columns are assigned one after another while the
+// later expressions still read the inputs, so output_data presumably must not
+// alias in0_data or in1_data -- confirm before relying on in-place use.
+template <typename Device, typename Type>
+struct Cross {
+  void operator()(const Device &d,
+                  typename TTypes<Type, 2>::ConstTensor in0_data,
+                  typename TTypes<Type, 2>::ConstTensor in1_data,
+                  typename TTypes<Type, 2>::Tensor output_data) {
+    // chip<1>(k) selects component k of every vector in the batch.
+    auto s1 = output_data.template chip<1>(0);
+    auto s2 = output_data.template chip<1>(1);
+    auto s3 = output_data.template chip<1>(2);
+
+    auto u1 = in0_data.template chip<1>(0);
+    auto u2 = in0_data.template chip<1>(1);
+    auto u3 = in0_data.template chip<1>(2);
+
+    auto v1 = in1_data.template chip<1>(0);
+    auto v2 = in1_data.template chip<1>(1);
+    auto v3 = in1_data.template chip<1>(2);
+
+    // Component-wise expansion of s = u x v.
+    s1.device(d) = u2 * v3 - u3 * v2;
+    s2.device(d) = u3 * v1 - u1 * v3;
+    s3.device(d) = u1 * v2 - u2 * v1;
+  }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CROSS_OP_H_
diff --git a/tensorflow/core/kernels/cross_op_gpu.cu.cc b/tensorflow/core/kernels/cross_op_gpu.cu.cc
new file mode 100644
index 0000000000..a20ef8cf20
--- /dev/null
+++ b/tensorflow/core/kernels/cross_op_gpu.cu.cc
@@ -0,0 +1,34 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/cross_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Explicitly instantiates functor::Cross for GPUDevice, once for every type
+// expanded by TF_CALL_REAL_NUMBER_TYPES. The class-key is `struct` to match
+// the declaration of Cross and the `extern template struct` in cross_op.cc.
+#define INSTANTIATE_GPU_KERNEL(type) \
+  template struct functor::Cross<GPUDevice, type>;
+TF_CALL_REAL_NUMBER_TYPES(INSTANTIATE_GPU_KERNEL);
+#undef INSTANTIATE_GPU_KERNEL
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cross_op_test.cc b/tensorflow/core/kernels/cross_op_test.cc
new file mode 100644
index 0000000000..0b179e6d5d
--- /dev/null
+++ b/tensorflow/core/kernels/cross_op_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+// Test fixture whose constructor builds a "Cross" node named "cross_op" with
+// two DT_FLOAT inputs and then initializes the op.
+class CrossOpTest : public OpsTestBase {
+ protected:
+ CrossOpTest() {
+ RequireDefaultOps();
+ TF_EXPECT_OK(NodeDefBuilder("cross_op", "Cross")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+// The cross product of two zero vectors is expected to be the zero vector.
+TEST_F(CrossOpTest, Zero) {
+  const TensorShape vec3({3});
+  AddInputFromArray<float>(vec3, {0, 0, 0});
+  AddInputFromArray<float>(vec3, {0, 0, 0});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor want(allocator(), DT_FLOAT, vec3);
+  test::FillValues<float>(&want, {0, 0, 0});
+  test::ExpectTensorEqual<float>(want, *GetOutput(0));
+}
+
+// Feeds a batch of two unit-vector pairs: x cross y should give +z, and
+// y cross x should give -z.
+TEST_F(CrossOpTest, RightHandRule) {
+  const TensorShape two_vec3({2, 3});
+  AddInputFromArray<float>(two_vec3, {1, 0, 0, /**/ 0, 1, 0});
+  AddInputFromArray<float>(two_vec3, {0, 1, 0, /**/ 1, 0, 0});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor want(allocator(), DT_FLOAT, two_vec3);
+  test::FillValues<float>(&want, {{0, 0, 1, /**/ 0, 0, -1}});
+  test::ExpectTensorEqual<float>(want, *GetOutput(0));
+}
+
+// Compares the kernel output for two non-integral vectors against the
+// cross-product formula evaluated directly in float, within 1e-6.
+TEST_F(CrossOpTest, ArbitraryNonintegral) {
+  const float ax = -0.669, ay = -0.509, az = 0.125;
+  const float bx = -0.477, by = 0.592, bz = -0.110;
+  // Reference result: (a x b), component by component.
+  const float cx = ay * bz - az * by;
+  const float cy = az * bx - ax * bz;
+  const float cz = ax * by - ay * bx;
+
+  const TensorShape vec3({3});
+  AddInputFromArray<float>(vec3, {ax, ay, az});
+  AddInputFromArray<float>(vec3, {bx, by, bz});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor want(allocator(), DT_FLOAT, vec3);
+  test::FillValues<float>(&want, {cx, cy, cz});
+  test::ExpectTensorNear<float>(want, *GetOutput(0), 1e-6);
+}
+
+// Test fixture whose constructor builds a "Cross" node named "cross_int_op"
+// with two DT_INT32 inputs and then initializes the op.
+class CrossOpIntTest : public OpsTestBase {
+ protected:
+ CrossOpIntTest() {
+ RequireDefaultOps();
+ TF_EXPECT_OK(NodeDefBuilder("cross_int_op", "Cross")
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_INT32))
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+// Integer variant of the right-hand-rule check, using axis vectors of
+// length 2 so the expected products have magnitude 4.
+TEST_F(CrossOpIntTest, RightHandRule) {
+  const TensorShape two_vec3({2, 3});
+  AddInputFromArray<int>(two_vec3, {2, 0, 0, /**/ 0, 2, 0});
+  AddInputFromArray<int>(two_vec3, {0, 2, 0, /**/ 2, 0, 0});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor want(allocator(), DT_INT32, two_vec3);
+  test::FillValues<int>(&want, {{0, 0, 4, /**/ 0, 0, -4}});
+  test::ExpectTensorEqual<int>(want, *GetOutput(0));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index afab885099..e7294cb3e0 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -68,6 +68,12 @@ class InvertPermutationOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("InvertPermutation").Device(DEVICE_CPU),
InvertPermutationOp);
+// Also registers InvertPermutationOp for DEVICE_GPU, with both "x" and "y"
+// declared as host memory.
+REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y"),
+ InvertPermutationOp);
+
// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
// of type T and rank N, and a permutation of 0, 1, ..., N-1. It
// shuffles the dimensions of the input tensor according to permutation.
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 56c6c8e920..603005126d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1180,4 +1180,23 @@ out: The inverse 2D Fourier Transform of `in`.
)doc");
+// --------------------------------------------------------------------------
+
+// Op definition for "Cross": two inputs `a` and `b` and one output `product`,
+// all of the same type T, where T is constrained to `realnumbertype`.
+REGISTER_OP("Cross")
+ .Input("a: T")
+ .Input("b: T")
+ .Output("product: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Compute the pairwise cross product.
+
+`a` and `b` must be the same shape; they can either be simple 3-element vectors,
+or any shape where the innermost dimension is 3. In the latter case, each pair
+of corresponding 3-element vectors is cross-multiplied independently.
+
+a: A tensor containing 3-element vectors.
+b: Another tensor, of same type and shape as `a`.
+product: Pairwise cross product of the vectors in `a` and `b`.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 1e32f3696d..85eb0965b3 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -1932,6 +1932,42 @@ op {
description: "This operation outputs \"ref\" after the update is done. This makes it\neasier to chain operations that need to use the updated value."
}
op {
+ name: "Cross"
+ input_arg {
+ name: "a"
+ description: "A tensor containing 3-element vectors."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "b"
+ description: "Another tensor, of same type and shape as `a`."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "product"
+ description: "Pairwise cross product of the vectors in `a` and `b`."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_UINT16
+ }
+ }
+ }
+ summary: "Compute the pairwise cross product."
+ description: "`a` and `b` must be the same shape; they can either be simple 3-element vectors,\nor any shape where the innermost dimension is 3. In the latter case, each pair\nof corresponding 3-element vectors is cross-multiplied independently."
+}
+op {
name: "DecodeCSV"
input_arg {
name: "records"
diff --git a/tensorflow/python/kernel_tests/cross_grad_test.py b/tensorflow/python/kernel_tests/cross_grad_test.py
new file mode 100644
index 0000000000..9d29b35129
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cross_grad_test.py
@@ -0,0 +1,41 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.math_ops.Cross."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class CrossOpTest(tf.test.TestCase):
+  """Gradient check for tf.cross."""
+
+  def testGradientRandomValues(self):
+    """Checks both Jacobians that compute_gradient returns for each input."""
+    with self.test_session():
+      shape = [2, 3]
+      lhs = tf.reshape([0.854, -0.616, 0.767, 0.725, -0.927, 0.159],
+                       shape=shape)
+      rhs = tf.reshape([-0.522, 0.755, 0.407, -0.652, 0.241, 0.247],
+                       shape=shape)
+      product = tf.cross(lhs, rhs)
+      # One result per input; presumably each holds the analytic and the
+      # numeric Jacobian (TODO confirm against compute_gradient's docs).
+      lhs_jacobians, rhs_jacobians = tf.test.compute_gradient(
+          [lhs, rhs], [shape, shape], product, shape)
+
+      self.assertAllClose(
+          lhs_jacobians[0], lhs_jacobians[1], rtol=1e-3, atol=1e-3)
+      self.assertAllClose(
+          rhs_jacobians[0], rhs_jacobians[1], rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/seq2seq_test.py b/tensorflow/python/kernel_tests/seq2seq_test.py
index cebc3244d8..b146390682 100644
--- a/tensorflow/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/python/kernel_tests/seq2seq_test.py
@@ -433,7 +433,7 @@ class Seq2SeqTest(tf.test.TestCase):
perplexities[bucket].append(math.exp(float(res[1])))
for bucket in range(len(buckets)):
if len(perplexities[bucket]) > 1: # Assert that perplexity went down.
- self.assertLess(perplexities[bucket][1], perplexities[bucket][0])
+ self.assertLess(perplexities[bucket][-1], perplexities[bucket][0])
if __name__ == "__main__":
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 39ef55ab2a..e9eeb2619b 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -560,8 +560,8 @@ def transpose(a, perm=None, name="transpose"):
"""
with ops.op_scope([a], name, "transpose") as name:
if perm is None:
- dims = gen_math_ops._range(0, gen_array_ops.rank(a), 1)
- perm = gen_array_ops.reverse(dims, [True])
+ rank = gen_array_ops.rank(a)
+ perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
ret = gen_array_ops.transpose(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function.
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 7a5b5e4c86..8b2779068b 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -567,3 +567,10 @@ def _FFT2DGrad(_, grad):
def _IFFT2DGrad(_, grad):
rsize = 1. / math_ops.cast(array_ops.size(grad), dtypes.float32)
return math_ops.fft2d(grad) * math_ops.complex(rsize, 0.)
+
+
+@ops.RegisterGradient("Cross")
+def _CrossGrad(op, grad):
+  """Returns the gradients of Cross with respect to its two inputs.
+
+  Args:
+    op: The Cross operation being differentiated.
+    grad: Gradient with respect to the op's output.
+
+  Returns:
+    The pair (cross(b, grad), cross(grad, a)), where a and b are the op's
+    first and second inputs.
+  """
+  lhs, rhs = op.inputs[0], op.inputs[1]
+  return (math_ops.cross(rhs, grad), math_ops.cross(grad, lhs))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index d878347863..90ed7ad876 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -25,6 +25,7 @@ operators to your graph.
@@truediv
@@floordiv
@@mod
+@@cross
## Basic Math Functions
@@ -1223,6 +1224,7 @@ ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
ops.RegisterShape("Ceil")(common_shapes.unchanged_shape)
ops.RegisterShape("Conj")(common_shapes.unchanged_shape)
ops.RegisterShape("Cos")(common_shapes.unchanged_shape)
+ops.RegisterShape("Cross")(common_shapes.unchanged_shape)
ops.RegisterShape("Exp")(common_shapes.unchanged_shape)
ops.RegisterShape("Floor")(common_shapes.unchanged_shape)
ops.RegisterShape("Imag")(common_shapes.unchanged_shape)
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index ef31fc971b..671ec3d4a6 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Cholesky" \ No newline at end of file
+#include "eigen-eigen-726c779797e8/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index a330b6166f..38f45037e6 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Core"
+#include "eigen-eigen-726c779797e8/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index 30158ba1ea..64f4200304 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1,2 @@
-#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Eigenvalues"
+#include "eigen-eigen-726c779797e8/Eigen/Eigenvalues"
+
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index 5637771a51..ab9e6cb4c5 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-8cd7c2c6e9e1/Eigen/LU"
+#include "eigen-eigen-726c779797e8/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index 360ba8e5e3..9ecf7be16d 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-8cd7c2c6e9e1/Eigen/QR"
+#include "eigen-eigen-726c779797e8/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index eb293afd04..a80816717b 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1,2 +1 @@
-
-#include "eigen-eigen-8cd7c2c6e9e1/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-726c779797e8/unsupported/Eigen/CXX11/Tensor"