m---------  google/protobuf                                        |  0
-rw-r--r--  tensorflow/core/BUILD                                  | 45
-rw-r--r--  tensorflow/core/framework/allocator.h                  |  6
-rw-r--r--  tensorflow/core/framework/tensor.cc                    |  4
-rw-r--r--  tensorflow/core/kernels/reshape_op.h                   | 19
-rw-r--r--  tensorflow/core/kernels/tensor_array.h                 |  2
-rw-r--r--  tensorflow/core/kernels/topk_op.cc                     |  1
-rw-r--r--  tensorflow/core/public/version.h                       |  3
-rw-r--r--  tensorflow/models/image/cifar10/cifar10.py             | 14
-rw-r--r--  tensorflow/models/image/mnist/convolutional.py         | 20
-rw-r--r--  tensorflow/python/ops/image_ops.py                     | 34
-rw-r--r--  tensorflow/python/ops/image_ops_test.py                | 52
-rw-r--r--  tensorflow/python/ops/nn_ops.py                        |  7
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.cc   |  9
-rw-r--r--  tensorflow/tensorflow.bzl                              |  2
-rw-r--r--  third_party/eigen3/Eigen/Cholesky                      |  2
-rw-r--r--  third_party/eigen3/Eigen/Core                          |  2
-rw-r--r--  third_party/eigen3/Eigen/Eigenvalues                   |  2
-rw-r--r--  third_party/eigen3/Eigen/LU                            |  2
-rw-r--r--  third_party/eigen3/Eigen/QR                            |  2
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/Tensor      |  3
-rw-r--r--  third_party/gpus/crosstool/CROSSTOOL                   |  1
22 files changed, 156 insertions, 76 deletions
diff --git a/google/protobuf b/google/protobuf
-Subproject bd8a476510d17d3841ff2509fbd67b7f4b543c1
+Subproject 0906f5d18a2548024b511eadcbb4cfc0ca56cd6
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index c5dfeb36b7..257280559b 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -29,7 +29,7 @@
# ":all_kernels" - The cpu-specific kernels, plus ":gpu_kernels" if
# built with Cuda
# ":tensorflow_opensource" - The complete open-source package, including
-# ":kernels", ":core", and a Session implementation.
+# ":all_kernels", ":core", and a Session implementation.
# ":tensorflow" - "tensorflow_opensource" plus some Google-internal libraries.
# ":testlib" - TensorFlow-specific test support, e.g. utilities for testing
# kernels.
@@ -457,19 +457,6 @@ tf_cuda_library(
],
)
-# DEPRECATED: Use either ":all_kernels" and/or ":kernel_lib" instead.
-# We need to get rid of this library before we can make the kernels
-# directory its own package, due to name conflicts.
-tf_cuda_library(
- name = "kernels",
- hdrs = glob(["kernels/**/*.h"]),
- deprecation = "use ':all_kernels' or ':kernel_lib' instead",
- visibility = ["//visibility:public"],
- deps = [
- ":kernel_lib",
- ],
-)
-
tf_cuda_library(
name = "tensorflow_opensource",
copts = tf_copts(),
@@ -1094,8 +1081,9 @@ tf_cc_tests(
# TODO(opensource): fix
"common_runtime/gpu/*_test.cc",
# Run by tests below
- "common_runtime/gpu/gpu_region_allocator_test.cc",
+ "common_runtime/gpu/gpu_allocator_retry_test.cc",
"common_runtime/gpu/gpu_bfc_allocator_test.cc",
+ "common_runtime/gpu/gpu_region_allocator_test.cc",
],
),
deps = [
@@ -1129,6 +1117,10 @@ tf_cc_tests(
"user_ops/**/*_test.cc",
"common_runtime/gpu/*_test.cc",
],
+ exclude = [
+ # Run by tests below
+ "common_runtime/gpu/gpu_allocator_retry_test.cc",
+ ],
),
deps = [
":all_kernels",
@@ -1149,6 +1141,29 @@ tf_cc_tests(
],
)
+tf_cc_tests(
+ linkstatic = tf_kernel_tests_linkstatic(),
+ tags = tf_cuda_tests_tags() + ["nomac"],
+ tests = ["common_runtime/gpu/gpu_allocator_retry_test.cc"],
+ deps = [
+ ":all_kernels",
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session",
+ ":framework",
+ ":framework_internal",
+ ":gpu_runtime",
+ ":kernel_lib",
+ ":lib",
+ ":lib_internal",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ ],
+)
+
# Test data
filegroup(
name = "image_testdata",
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index f98b7bb32b..d49d0d6005 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -108,6 +108,12 @@ class Allocator {
// TracksAllocationSizes is overridden to return true.
virtual bool TracksAllocationSizes() { return false; }
+ // Returns true if this allocator requires tensors with 0 elements
+ // to allocate buffers. This is false for most allocators, but may
+ // be used by special-case allocators that want to track tensor
+ // usage.
+ virtual bool ShouldAllocateEmptyTensors() { return false; }
+
// Returns the user-requested size of the data allocated at
// 'ptr'. Note that the actual buffer allocated might be larger
// than requested, but this function returns the size requested by
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 8526d26148..8eee2126d9 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -381,7 +381,7 @@ void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
: type_(type), shape_(shape), buf_(nullptr) {
CHECK_NOTNULL(a);
- if (shape_.num_elements() > 0) {
+ if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
}
}
@@ -390,7 +390,7 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
const AllocationAttributes& allocation_attr)
: type_(type), shape_(shape), buf_(nullptr) {
CHECK_NOTNULL(a);
- if (shape_.num_elements() > 0) {
+ if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
}
}
diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h
index 577ed057f6..27ae3561e4 100644
--- a/tensorflow/core/kernels/reshape_op.h
+++ b/tensorflow/core/kernels/reshape_op.h
@@ -66,19 +66,22 @@ class ReshapeOp : public OpKernel {
if (unknown_index != -1) {
OP_REQUIRES(
context, product > 0,
- errors::InvalidArgument("cannot infer the missing input size for "
- "an empty tensor unless all specified "
+ errors::InvalidArgument("Reshape cannot infer the missing input size "
+ "for an empty tensor unless all specified "
"input sizes are non-zero"));
const int32 missing = input.NumElements() / product;
- OP_REQUIRES(context, product * missing == input.NumElements(),
- errors::InvalidArgument("Input has ", input.NumElements(),
- " values, which isn't divisible by ",
- product));
+ OP_REQUIRES(
+ context, product * missing == input.NumElements(),
+ errors::InvalidArgument(
+ "Input to reshape is a tensor with ", input.NumElements(),
+ " values, but the requested shape requires a multiple of ",
+ product));
shape.set_dim(unknown_index, missing);
}
OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
- errors::InvalidArgument("Input has ", input.NumElements(),
- " values, which isn't the same as ",
+ errors::InvalidArgument("Input to reshape is a tensor with ",
+ input.NumElements(),
+ " values, but the requested shape has ",
shape.num_elements()));
// Actually produce the reshaped output.
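
The reworded messages above surface through tf.reshape when a -1 dimension cannot be inferred. A minimal Python sketch, assuming the TF 0.7-era tf.errors module, of a run that triggers the new divisibility error:

    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[None])  # element count unknown until run time
    y = tf.reshape(x, [-1, 3])

    with tf.Session() as sess:
        try:
            sess.run(y, feed_dict={x: np.zeros(10, np.float32)})
        except tf.errors.InvalidArgumentError as e:
            # "Input to reshape is a tensor with 10 values, but the requested
            # shape requires a multiple of 3"
            print(e)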
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 82026c3cb4..7ef04bee94 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -122,7 +122,7 @@ class TensorArray : public ResourceBase {
mutex_lock l(mu_);
values->clear();
values->resize(tensors_.size());
- for (int32 i = 0; i < tensors_.size(); ++i) {
+ for (std::size_t i = 0; i < tensors_.size(); ++i) {
TF_RETURN_IF_ERROR(LockedRead(i, &(*values)[i]));
}
return Status::OK();
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc
index 27c00e912c..d5ea5e15a8 100644
--- a/tensorflow/core/kernels/topk_op.cc
+++ b/tensorflow/core/kernels/topk_op.cc
@@ -33,6 +33,7 @@ class TopK : public OpKernel {
explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
if (num_inputs() < 2) { // k is an attr (TopK).
+ OP_DEPRECATED(context, 7, "Use TopKV2 instead");
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
} else { // k is an input (TopKV2), so we won't know it until Compute.
k_ = -1;
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 5d0ea8b750..9ec1aa9b7f 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -60,8 +60,9 @@ limitations under the License.
// 111635679, 7jan2016).
// 5. Graphs are wholly-validated during Session::Create() (7jan2016).
// 6. TensorFlow is scalar strict within Google (27jan2016).
+// 7. Remove TopK in favor of TopKV2 (5feb2016).
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 6
+#define TF_GRAPH_DEF_VERSION 7
#endif // TENSORFLOW_CORE_PUBLIC_VERSION_H_
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index e736fcb467..7493b092fe 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -265,18 +265,10 @@ def loss(logits, labels):
Returns:
Loss tensor of type float.
"""
- # Reshape the labels into a dense Tensor of
- # shape [batch_size, NUM_CLASSES].
- sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1])
- indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
- concated = tf.concat(1, [indices, sparse_labels])
- dense_labels = tf.sparse_to_dense(concated,
- [FLAGS.batch_size, NUM_CLASSES],
- 1.0, 0.0)
-
# Calculate the average cross entropy loss across the batch.
- cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
- logits, dense_labels, name='cross_entropy_per_example')
+ labels = tf.cast(labels, tf.int64)
+ cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits, labels, name='cross_entropy_per_example')
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
tf.add_to_collection('losses', cross_entropy_mean)
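
The CIFAR-10 loss no longer scatters the labels into a dense one-hot matrix; integer class IDs feed the sparse op directly. A before/after sketch, assuming the TF 0.7-era positional (logits, labels) signature used in this file and illustrative shape constants:

    import tensorflow as tf

    batch_size, num_classes = 128, 10
    logits = tf.placeholder(tf.float32, [batch_size, num_classes])
    labels = tf.placeholder(tf.int32, [batch_size])  # class IDs, not one-hot rows

    # Old path: densify the IDs into a [batch_size, num_classes] matrix.
    indices = tf.reshape(tf.range(batch_size), [batch_size, 1])
    concated = tf.concat(1, [indices, tf.reshape(labels, [batch_size, 1])])
    dense_labels = tf.sparse_to_dense(concated, [batch_size, num_classes], 1.0, 0.0)
    old_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits, dense_labels))

    # New path: pass the int64 class IDs straight to the sparse op.
    new_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits, tf.cast(labels, tf.int64)))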
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index df0ca22063..edceb2a1ec 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -81,14 +81,13 @@ def extract_data(filename, num_images):
def extract_labels(filename, num_images):
- """Extract the labels into a 1-hot matrix [image index, label index]."""
+ """Extract the labels into a vector of int64 label IDs."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
bytestream.read(8)
buf = bytestream.read(1 * num_images)
- labels = numpy.frombuffer(buf, dtype=numpy.uint8)
- # Convert to dense 1-hot representation.
- return (numpy.arange(NUM_LABELS) == labels[:, None]).astype(numpy.float32)
+ labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
+ return labels
def fake_data(num_images):
@@ -96,19 +95,19 @@ def fake_data(num_images):
data = numpy.ndarray(
shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
dtype=numpy.float32)
- labels = numpy.zeros(shape=(num_images, NUM_LABELS), dtype=numpy.float32)
+ labels = numpy.zeros(shape=(num_images,), dtype=numpy.int64)
for image in xrange(num_images):
label = image % 2
data[image, :, :, 0] = label - 0.5
- labels[image, label] = 1.0
+ labels[image] = label
return data, labels
def error_rate(predictions, labels):
- """Return the error rate based on dense predictions and 1-hot labels."""
+ """Return the error rate based on dense predictions and sparse labels."""
return 100.0 - (
100.0 *
- numpy.sum(numpy.argmax(predictions, 1) == numpy.argmax(labels, 1)) /
+ numpy.sum(numpy.argmax(predictions, 1) == labels) /
predictions.shape[0])
@@ -146,8 +145,7 @@ def main(argv=None): # pylint: disable=unused-argument
train_data_node = tf.placeholder(
tf.float32,
shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
- train_labels_node = tf.placeholder(tf.float32,
- shape=(BATCH_SIZE, NUM_LABELS))
+ train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))
eval_data = tf.placeholder(
tf.float32,
shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
@@ -222,7 +220,7 @@ def main(argv=None): # pylint: disable=unused-argument
# Training computation: logits + cross-entropy loss.
logits = model(train_data_node, True)
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
+ loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, train_labels_node))
# L2 regularization for the fully connected parameters.
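
With labels stored as a vector of class IDs, error_rate compares the argmax of the predictions against the labels directly instead of against a second argmax over one-hot rows. A small NumPy sketch with made-up values:

    import numpy as np

    predictions = np.array([[0.1, 0.9],   # predicted class 1
                            [0.8, 0.2],   # predicted class 0
                            [0.3, 0.7]])  # predicted class 1
    labels = np.array([1, 1, 1], dtype=np.int64)  # sparse class IDs

    error = 100.0 - (
        100.0 * np.sum(np.argmax(predictions, 1) == labels) / predictions.shape[0])
    print(error)  # ~33.3: one of the three predictions is wrong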
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index b65ecf9aeb..2eeef95d99 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -563,30 +563,53 @@ def resize_images(images,
_, height, width, depth = _ImageDimensions(images)
- if width == new_width and height == new_height:
+ # Handle tensor-valued sizes as well as Python integers.
+ try:
+ new_width = ops.convert_to_tensor(new_width, dtypes.int32,
+ name='new_width')
+ new_width.get_shape().assert_has_rank(0)
+ except (TypeError, ValueError):
+ raise ValueError('new_width must be a scalar integer')
+ try:
+ new_height = ops.convert_to_tensor(new_height, dtypes.int32,
+ name='new_height')
+ new_height.get_shape().assert_has_rank(0)
+ except (TypeError, ValueError):
+ raise ValueError('new_height must be a scalar integer')
+
+ new_width_const = tensor_util.constant_value(new_width)
+ new_height_const = tensor_util.constant_value(new_height)
+
+ if width == new_width_const and height == new_height_const:
if not is_batch:
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
+ new_size = array_ops.pack([new_height, new_width])
+
if method == ResizeMethod.BILINEAR:
images = gen_image_ops.resize_bilinear(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.NEAREST_NEIGHBOR:
images = gen_image_ops.resize_nearest_neighbor(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.BICUBIC:
images = gen_image_ops.resize_bicubic(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.AREA:
images = gen_image_ops.resize_area(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
else:
raise ValueError('Resize method is not implemented.')
+ # NOTE(mrry): The shape functions for the resize ops cannot unpack
+ # the packed values in `new_size`, so set the shape here.
+ images.set_shape([None, new_height_const, new_width_const, None])
+
if not is_batch:
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
@@ -779,6 +802,7 @@ ops.RegisterShape('AdjustContrastv2')(
def _ResizeShape(op):
"""Shape function for the resize_bilinear and resize_nearest_neighbor ops."""
input_shape = op.inputs[0].get_shape().with_rank(4)
+ unused_size_shape = op.inputs[1].get_shape().merge_with([2])
size = tensor_util.constant_value(op.inputs[1])
if size is not None:
height = size[0]
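
With the conversion logic above, resize_images accepts scalar int32 tensors as well as Python integers for the target size, packing them into a single size tensor for the underlying resize kernels. A brief usage sketch under the TF 0.7-era tf.image API, mirroring the test added below:

    import numpy as np
    import tensorflow as tf

    image = tf.placeholder(tf.uint8, shape=[1, 6, 4, 1])
    new_height = tf.placeholder(tf.int32)  # scalar tensors are now accepted
    new_width = tf.placeholder(tf.int32)

    resized = tf.image.resize_images(image, new_height, new_width,
                                     tf.image.ResizeMethod.BILINEAR)
    with tf.Session() as sess:
        out = sess.run(resized, feed_dict={image: np.zeros([1, 6, 4, 1], np.uint8),
                                           new_height: 12, new_width: 8})
        print(out.shape)  # (1, 12, 8, 1)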
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index f44fe344e6..d09004556c 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -654,6 +654,58 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
newshape = yshape.eval()
self.assertAllEqual(single_shape, newshape)
+ def testTensorArguments(self):
+ img_shape = [1, 6, 4, 1]
+ single_shape = [6, 4, 1]
+ # This test is also conducted with int8, so 127 is the maximum
+ # value that can be used.
+ data = [127, 127, 64, 64,
+ 127, 127, 64, 64,
+ 64, 64, 127, 127,
+ 64, 64, 127, 127,
+ 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ target_height = array_ops.placeholder(dtypes.int32)
+ target_width = array_ops.placeholder(dtypes.int32)
+
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ for opt in self.OPTIONS:
+ with self.test_session() as sess:
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ yshape = array_ops.shape(y)
+ resized, newshape = sess.run([y, yshape], {target_height: 6,
+ target_width: 4})
+ self.assertAllEqual(img_shape, newshape)
+ self.assertAllClose(resized, img_np, atol=1e-5)
+
+ # Resizing with a single image must leave the shape unchanged also.
+ with self.test_session():
+ img_single = img_np.reshape(single_shape)
+ image = constant_op.constant(img_single, shape=single_shape)
+ y = image_ops.resize_images(image, target_height, target_width,
+ self.OPTIONS[0])
+ yshape = array_ops.shape(y)
+ newshape = yshape.eval(feed_dict={target_height: 6, target_width: 4})
+ self.assertAllEqual(single_shape, newshape)
+
+ # Incorrect shape.
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, [12, 32], 4, image_ops.ResizeMethod.BILINEAR)
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, 6, [12, 32], image_ops.ResizeMethod.BILINEAR)
+
+ # Incorrect dtypes.
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, 6.0, 4, image_ops.ResizeMethod.BILINEAR)
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, 6, 4.0, image_ops.ResizeMethod.BILINEAR)
+
def testResizeDown(self):
# This test is also conducted with int8, so 127 is the maximum
# value that can be used.
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index b01d6d0cf9..06e625e695 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -614,12 +614,7 @@ def top_k(input, k=1, sorted=True, name=None):
values: The `k` largest elements along each last dimensional slice.
indices: The indices of `values` within the last dimension of `input`.
"""
- # TODO(irving): Always use v2 once the GraphDef mechanism is unstuck.
- if isinstance(k, ops.Tensor):
- op = gen_nn_ops._top_kv2
- else:
- op = gen_nn_ops._top_k
- return op(input, k=k, sorted=sorted, name=name)
+ return gen_nn_ops._top_kv2(input, k=k, sorted=sorted, name=name)
# pylint: enable=invalid-name
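
Since top_k now always emits TopKV2 (matching GraphDef version 7 above), k can be either a Python int or a scalar int32 tensor. A short sketch:

    import tensorflow as tf

    values = tf.constant([1.0, 5.0, 3.0, 4.0])

    top_static = tf.nn.top_k(values, k=2)   # k as a Python int
    k = tf.placeholder(tf.int32)
    top_dynamic = tf.nn.top_k(values, k=k)  # k as a scalar tensor

    with tf.Session() as sess:
        v, i = sess.run([top_static[0], top_static[1]])
        print(v, i)  # [5. 4.] [1 3]
        v, i = sess.run([top_dynamic[0], top_dynamic[1]], feed_dict={k: 3})
        print(v, i)  # [5. 4. 3.] [1 3 2]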
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index 2565078bb2..f98eec3b33 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -208,15 +208,6 @@ static string GetBinaryDir(bool strip_exe) {
return exe_path;
}
-// Returns the location of the runfiles directory.
-// This is the directory which "bazel run" sets as the current working directory
-// before the program starts.
-// N.B. This doesn't have to be running under "bazel run" in order to get the
-// appropriate runfiles directory.
-static string GetRunfilesDir() {
- return port::StrCat(GetBinaryDir(false), ".runfiles");
-}
-
bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
KernelBase *kernel) {
CUDAKernel *cuda_kernel = AsCUDAKernel(kernel);
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 21ed3a8369..c9bf1ac35a 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -260,7 +260,7 @@ def _py_wrap_cc_impl(ctx):
ctx.action(executable=ctx.executable.swig_binary,
arguments=args,
mnemonic="PythonSwig",
- inputs=list(set([src]) + cc_includes + ctx.files.swig_includes +
+ inputs=sorted(set([src]) + cc_includes + ctx.files.swig_includes +
ctx.attr.swig_deps.files),
outputs=outputs,
progress_message="SWIGing {input}".format(input=src.path))
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index afc87201eb..ef31fc971b 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/Cholesky"
+#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Cholesky" \ No newline at end of file
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index 6527c9f80a..a330b6166f 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/Core"
+#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index d8c61044c5..30158ba1ea 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/Eigenvalues"
+#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index a290471c8d..5637771a51 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/LU"
+#include "eigen-eigen-8cd7c2c6e9e1/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index c67724defb..360ba8e5e3 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/QR"
+#include "eigen-eigen-8cd7c2c6e9e1/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index 6f97c57e43..eb293afd04 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1,2 @@
-#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/unsupported/Eigen/CXX11/Tensor"
+
+#include "eigen-eigen-8cd7c2c6e9e1/unsupported/Eigen/CXX11/Tensor"
diff --git a/third_party/gpus/crosstool/CROSSTOOL b/third_party/gpus/crosstool/CROSSTOOL
index 629dc32f30..dfde7cd216 100644
--- a/third_party/gpus/crosstool/CROSSTOOL
+++ b/third_party/gpus/crosstool/CROSSTOOL
@@ -144,4 +144,5 @@ toolchain {
compiler_flag: "-fdata-sections"
linker_flag: "-Wl,--gc-sections"
}
+ linking_mode_flags { mode: DYNAMIC }
}