-rw-r--r--  configure.py | 11
-rw-r--r--  tensorflow/c/c_api_internal.h | 2
-rw-r--r--  tensorflow/c/eager/c_api.cc | 9
-rw-r--r--  tensorflow/c/eager/c_api.h | 2
-rw-r--r--  tensorflow/c/eager/c_api_test.cc | 9
-rw-r--r--  tensorflow/cc/framework/gradients.cc | 36
-rw-r--r--  tensorflow/cc/framework/gradients_test.cc | 67
-rw-r--r--  tensorflow/cc/framework/testutil.cc | 14
-rw-r--r--  tensorflow/cc/framework/testutil.h | 12
-rw-r--r--  tensorflow/contrib/cmake/CMakeLists.txt | 7
-rw-r--r--  tensorflow/contrib/cmake/external/boringssl.cmake | 2
-rw-r--r--  tensorflow/contrib/cmake/external/snappy.cmake | 50
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py | 35
-rw-r--r--  tensorflow/contrib/data/python/ops/dataset_ops.py | 6
-rw-r--r--  tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py | 4
-rw-r--r--  tensorflow/contrib/layers/python/layers/optimizers.py | 4
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head.py | 12
-rw-r--r--  tensorflow/core/BUILD | 2
-rw-r--r--  tensorflow/core/framework/allocator.cc | 10
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc | 60
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass_test.cc | 12
-rw-r--r--  tensorflow/core/graph/mkl_tfconversion_pass_test.cc | 8
-rw-r--r--  tensorflow/core/kernels/BUILD | 17
-rw-r--r--  tensorflow/core/kernels/cwise_op_sub.cc | 5
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 8
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc | 27
-rw-r--r--  tensorflow/core/kernels/mkl_reshape_op.cc | 68
-rw-r--r--  tensorflow/core/kernels/parse_tensor_op.cc | 28
-rw-r--r--  tensorflow/core/kernels/parse_tensor_test.cc | 198
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.cc | 128
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.h | 22
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc | 121
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 2
-rw-r--r--  tensorflow/core/ops/nn_ops.cc | 4
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 19
-rw-r--r--  tensorflow/core/ops/parsing_ops.cc | 13
-rw-r--r--  tensorflow/core/platform/default/logging.cc | 60
-rw-r--r--  tensorflow/core/platform/default/logging.h | 46
-rw-r--r--  tensorflow/core/profiler/g3doc/command_line.md | 2
-rw-r--r--  tensorflow/docs_src/community/welcome.md | 1
-rw-r--r--  tensorflow/docs_src/get_started/estimator.md | 8
-rw-r--r--  tensorflow/docs_src/get_started/index.md | 2
-rw-r--r--  tensorflow/docs_src/get_started/input_fn.md | 2
-rw-r--r--  tensorflow/docs_src/get_started/leftnav_files | 1
-rw-r--r--  tensorflow/java/BUILD | 51
-rw-r--r--  tensorflow/java/src/gen/cc/op_gen_main.cc | 84
-rw-r--r--  tensorflow/java/src/gen/cc/op_generator.cc | 66
-rw-r--r--  tensorflow/java/src/gen/cc/op_generator.h | 51
-rw-r--r--  tensorflow/java/src/gen/gen_ops.bzl | 59
-rw-r--r--  tensorflow/python/eager/python_eager_op_gen.cc | 23
-rw-r--r--  tensorflow/python/eager/python_eager_op_gen.h | 5
-rw-r--r--  tensorflow/python/feature_column/feature_column.py | 3
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py | 30
-rw-r--r--  tensorflow/python/framework/python_op_gen_main.cc | 41
-rw-r--r--  tensorflow/python/framework/tensor_util.py | 3
-rw-r--r--  tensorflow/python/framework/tensor_util_test.py | 11
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 4
-rw-r--r--  tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 178
-rw-r--r--  tensorflow/python/ops/io_ops.py | 1
-rw-r--r--  tensorflow/python/ops/parsing_ops.py | 1
-rw-r--r--  tensorflow/python/profiler/model_analyzer.py | 4
-rw-r--r--  tensorflow/python/tools/import_pb_to_tensorboard.py | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt | 4
-rw-r--r--  tensorflow/workspace.bzl | 8
-rw-r--r--  third_party/boringssl/add_boringssl_s390x.patch | 4
65 files changed, 1521 insertions(+), 268 deletions(-)
diff --git a/configure.py b/configure.py
index 186fdc9ddc..ef5051d275 100644
--- a/configure.py
+++ b/configure.py
@@ -685,10 +685,13 @@ def set_tf_cudnn_version(environ_cp):
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
- cudnn_path_from_ldconfig).group(1)
- if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
- cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
- break
+ cudnn_path_from_ldconfig)
+ if cudnn_path_from_ldconfig:
+ cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
+ if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig,
+ tf_cudnn_version)):
+ cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
+ break
# Reset and Retry
print(
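
The change above guards against re.search returning None when libcudnn.so is absent from the ldconfig cache; previously .group(1) was called unconditionally and crashed with an AttributeError. A minimal standalone sketch of the same defensive pattern (hypothetical helper, not the actual configure.py control flow):

import os
import re

def find_cudnn_install_path(ldconfig_output, tf_cudnn_version):
    # re.search returns None when nothing matches; guard before .group().
    match = re.search(r'.*libcudnn.so .* => (.*)', ldconfig_output)
    if match:
        cudnn_path = match.group(1)
        if os.path.exists('%s.%s' % (cudnn_path, tf_cudnn_version)):
            return os.path.dirname(cudnn_path)
    return None  # caller falls through to its reset-and-retry logic
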
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index 6e44a72e2b..68c324f2b9 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -146,6 +146,8 @@ class TensorCApi {
}
};
+Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
+
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 01e251a1ac..e70539ceef 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -151,10 +151,11 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
return TF_SessionListDevices(ctx->session, status);
}
-TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) {
- return new TFE_TensorHandle(
- tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer),
- nullptr);
+TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
+ tensorflow::Tensor tensor;
+ status->status = tensorflow::TF_TensorToTensor(t, &tensor);
+ if (!status->status.ok()) return nullptr;
+ return new TFE_TensorHandle(tensor, nullptr);
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 476c9288f8..88a0dd343f 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -43,7 +43,7 @@ extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
-extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t);
+extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status);
extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h);
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 6f5c21c947..72e0fe8a15 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -34,8 +34,11 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
TF_Tensor* t = TF_AllocateTensor(
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TFE_TensorHandle* th = TFE_NewTensorHandle(t);
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
return th;
}
@@ -383,7 +386,9 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
- value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle);
+ value_handle(TFE_NewTensorHandle(t.get(), status),
+ TFE_DeleteTensorHandle);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index 66a943410e..82469261e5 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -78,6 +78,10 @@ class SymbolicGradientBuilder {
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs);
+ // Returns a vector, indexed by node id, indicating whether each node
+ // in the graph is reachable from outputs_.
+ std::vector<bool> GetReachableNodes();
+
const Scope& scope_;
const ops::GradOpRegistry* registry_;
const std::vector<Output>& outputs_;
@@ -143,11 +147,36 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
return Status::OK();
}
+std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
+ std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
+ std::deque<Node*> queue;
+ for (const Output& out : outputs_) {
+ if (!reachable_nodes[out.node()->id()]) {
+ queue.push_back(out.node());
+ reachable_nodes[out.node()->id()] = true;
+ }
+ }
+
+ while (!queue.empty()) {
+ Node* n = queue.front();
+ queue.pop_front();
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ queue.push_back(e->src());
+ reachable_nodes[e->src()->id()] = true;
+ }
+ }
+ return reachable_nodes;
+}
+
Status SymbolicGradientBuilder::Initialize() {
if (outputs_.size() != grad_inputs_.size()) {
return errors::InvalidArgument(
"Must specify a gradient input for each output.");
}
+ std::vector<bool> reachable_nodes = GetReachableNodes();
+ // TODO(theflofly) Check that inputs_ are reachable from
+ // outputs_ using reachable_nodes
grad_outputs_->clear();
grad_outputs_->resize(inputs_.size());
// Populate `output_nodes_` from node ids in `outputs_`.
@@ -188,12 +217,15 @@ Status SymbolicGradientBuilder::Initialize() {
if (output_nodes_.find(n->id()) == output_nodes_.end()) {
// Internal node: continue BFS along connected outputs.
for (const Edge* e : n->out_edges()) {
- if (e->IsControlEdge()) continue;
- ++num_expected_backprops;
+ // If a node is not reachable from outputs_,
+ // we don't expect it to receive a backpropagated gradient.
+ // It will not be counted in num_expected_backprops.
+ if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue;
if (visited.find(e->dst()) == visited.end()) {
queue.push_back(e->dst());
visited.insert(e->dst());
}
+ ++num_expected_backprops;
}
} else {
// Output node: stop BFS and update `num_expected_backprops` for
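
GetReachableNodes above is a reverse breadth-first search: starting from the requested outputs, it follows incoming non-control edges and marks every node it visits; nodes outside this set can never feed a gradient, so Initialize stops counting their edges in num_expected_backprops. A plain-Python sketch of the same traversal over a toy adjacency map (hypothetical graph encoding, not the C++ Graph API):

from collections import deque

def reachable_from_outputs(in_edges, outputs):
    # in_edges maps a node to the list of nodes feeding it (data edges only).
    reachable = set(outputs)
    queue = deque(outputs)
    while queue:
        node = queue.popleft()
        for src in in_edges.get(node, ()):
            if src not in reachable:
                reachable.add(src)
                queue.append(src)
    return reachable

# m1 = MatMul(x, y) and m2 = MatMul(y, z); asking only for m1 leaves z out.
in_edges = {'m1': ['x', 'y'], 'm2': ['y', 'z']}
assert reachable_from_outputs(in_edges, ['m1']) == {'m1', 'x', 'y'}
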
diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc
index 24af7d567b..032ab93623 100644
--- a/tensorflow/cc/framework/gradients_test.cc
+++ b/tensorflow/cc/framework/gradients_test.cc
@@ -364,6 +364,73 @@ TEST_F(GradientsTest, MultipleNodeOutputGrads) {
test::AsTensor<int>({60, 61, 62, 63, 66, 66, 66, 67}, {4, 2}));
}
+TEST_F(GradientsTest, UnreachableEdgeGradOneOutput) {
+ auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
+ auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
+ auto x_assign = Assign(scope_test_, x, x_const);
+
+ auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
+ auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
+ auto y_assign = Assign(scope_test_, y, y_const);
+
+ auto m1 = MatMul(scope_test_, x, y);
+
+ auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
+ auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
+ auto z_assign = Assign(scope_test_, z, z_const);
+
+ auto m2 = MatMul(scope_test_, y, z);
+
+ auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
+
+ std::vector<Output> grad_outputs;
+ TF_ASSERT_OK(
+ AddSymbolicGradients(scope_test_, {m1}, {y}, {dm1}, &grad_outputs));
+
+ std::vector<Tensor> outputs;
+ test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
+ {grad_outputs[0]}, &outputs);
+ // dz/dy = xT * dm1
+ test::ExpectTensorNear<double>(
+ outputs[0], test::AsTensor<double>({2.5, 3.5, 4.5}, {3, 1}), 1e-5);
+}
+
+TEST_F(GradientsTest, UnreachableEdgeGradTwoOutputs) {
+ auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE);
+ auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
+ auto x_assign = Assign(scope_test_, x, x_const);
+
+ auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE);
+ auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}});
+ auto y_assign = Assign(scope_test_, y, y_const);
+
+ auto m1 = MatMul(scope_test_, x, y);
+
+ auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE);
+ auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}});
+ auto z_assign = Assign(scope_test_, z, z_const);
+
+ auto m2 = MatMul(scope_test_, y, z);
+
+ auto dm1 = Const(scope_test_, {{0.5}, {0.5}});
+ auto dm2 =
+ Const(scope_test_, {{0.5, 0.5, 0.5}, {0.6, 0.7, 0.8}, {0.6, 0.7, 0.9}});
+
+ std::vector<Output> grad_outputs;
+ TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {m1, m2}, {y}, {dm1, dm2},
+ &grad_outputs));
+
+ std::vector<Tensor> outputs;
+ test::GetTensors(scope_test_, {x_assign, y_assign, z_assign},
+ {grad_outputs[0]}, &outputs);
+
+ // The gradients from m1 and m2 are summed to compute the gradient
+ // w.r.t. y:
+ // dz/dy = xT * dm1 + dm2 * zT
+ test::ExpectTensorNear<double>(
+ outputs[0], test::AsTensor<double>({17.5, 24.7, 26.8}, {3, 1}), 1e-5);
+}
+
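
The expected values in these tests follow from the matmul gradient rules: for m1 = x * y the gradient w.r.t. y is xT * dm1, and for m2 = y * z it is dm2 * zT; when both outputs are requested the two contributions are summed. A quick numpy check of the arithmetic behind {17.5, 24.7, 26.8} (numpy stands in for the C++ test machinery):

import numpy as np

x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # 2x3
z = np.array([[9.0, 10.0, 11.0]])                 # 1x3
dm1 = np.array([[0.5], [0.5]])                    # 2x1
dm2 = np.array([[0.5, 0.5, 0.5],
                [0.6, 0.7, 0.8],
                [0.6, 0.7, 0.9]])                 # 3x3

dy = x.T @ dm1 + dm2 @ z.T  # gradient w.r.t. y, shape 3x1
print(dy.ravel())           # [17.5 24.7 26.8], matching the test
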
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) gradients returned along multiple
// edges from a single node's output.
diff --git a/tensorflow/cc/framework/testutil.cc b/tensorflow/cc/framework/testutil.cc
index ca78f31db5..57d573e3c5 100644
--- a/tensorflow/cc/framework/testutil.cc
+++ b/tensorflow/cc/framework/testutil.cc
@@ -36,5 +36,19 @@ void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
*out = outputs[0];
}
+void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
+ const OutputList& tensors, std::vector<Tensor>* out) {
+ ClientSession session(scope);
+ TF_CHECK_OK(session.Run(assign_vars, nullptr));
+ TF_CHECK_OK(session.Run(tensors, out));
+}
+
+void GetTensor(const Scope& scope, const std::vector<Output>& assign_vars,
+ Output tensor, Tensor* out) {
+ std::vector<Tensor> outputs;
+ GetTensors(scope, assign_vars, {std::move(tensor)}, &outputs);
+ *out = outputs[0];
+}
+
} // end namespace test
} // end namespace tensorflow
diff --git a/tensorflow/cc/framework/testutil.h b/tensorflow/cc/framework/testutil.h
index d027ad3744..a3e19870ec 100644
--- a/tensorflow/cc/framework/testutil.h
+++ b/tensorflow/cc/framework/testutil.h
@@ -26,9 +26,21 @@ namespace test {
void GetTensors(const Scope& scope, OutputList tensors,
std::vector<Tensor>* out);
+// Computes the outputs listed in 'tensors', returns the tensors in 'out'.
+// assign_vars are extra outputs that should be run
+// e.g. to assign values to variables.
+void GetTensors(const Scope& scope, const std::vector<Output>& assign_vars,
+ const OutputList& tensors, std::vector<Tensor>* out);
+
/// Computes the output 'tensor', returning the resulting tensor in 'out'.
void GetTensor(const Scope& scope, Output tensor, Tensor* out);
+// Computes the output 'tensor', returning the resulting tensor in 'out'.
+// assign_vars are extra outputs that should be run
+// e.g. to assign values to variables.
+void GetTensor(const Scope& scope, const std::vector<Output>& assign_vars,
+ Output tensor, Tensor* out);
+
} // namespace test
} // namespace tensorflow
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 422df3063e..c249a28556 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -33,6 +33,7 @@ option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for cont
option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")
+option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON)
if (NOT WIN32)
# Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option
@@ -204,6 +205,12 @@ if(tensorflow_ENABLE_JEMALLOC_SUPPORT)
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc)
include_directories(${jemalloc_INCLUDE_DIRS})
endif()
+if(tensorflow_ENABLE_SNAPPY_SUPPORT)
+ include(snappy)
+ list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES})
+ list(APPEND tensorflow_EXTERNAL_DEPENDENCIES snappy)
+ include_directories(${snappy_INCLUDE_DIR})
+endif()
if(WIN32)
list(APPEND tensorflow_EXTERNAL_LIBRARIES wsock32 ws2_32 shlwapi)
endif()
diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake
index 04a9664701..dc27eadaca 100644
--- a/tensorflow/contrib/cmake/external/boringssl.cmake
+++ b/tensorflow/contrib/cmake/external/boringssl.cmake
@@ -17,7 +17,7 @@ include (ExternalProject)
set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include)
#set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src)
set(boringssl_URL https://boringssl.googlesource.com/boringssl)
-set(boringssl_TAG 17cf2cb1d226b0ba2401304242df7ddd3b6f1ff2)
+set(boringssl_TAG ee7aa02)
set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build)
#set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so)
set(boringssl_STATIC_LIBRARIES
diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake
new file mode 100644
index 0000000000..a35d8654fb
--- /dev/null
+++ b/tensorflow/contrib/cmake/external/snappy.cmake
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+include (ExternalProject)
+
+set(snappy_URL https://github.com/google/snappy.git)
+set(snappy_TAG "55924d11095df25ab25c405fadfe93d0a46f82eb")
+set(snappy_BUILD ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy)
+set(snappy_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy)
+
+if(WIN32)
+ set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib)
+else()
+ set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/libsnappy.a)
+endif()
+
+set(snappy_HEADERS
+ "${snappy_INCLUDE_DIR}/snappy.h"
+)
+
+ExternalProject_Add(snappy
+ PREFIX snappy
+ GIT_REPOSITORY ${snappy_URL}
+ GIT_TAG ${snappy_TAG}
+ DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
+ BUILD_IN_SOURCE 1
+ INSTALL_COMMAND ""
+ LOG_DOWNLOAD ON
+ LOG_CONFIGURE ON
+ LOG_BUILD ON
+ CMAKE_CACHE_ARGS
+ -DCMAKE_BUILD_TYPE:STRING=Release
+ -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
+ -DSNAPPY_BUILD_TESTS:BOOL=OFF
+ -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
+)
+
+# actually enables snappy in the source code
+add_definitions(-DSNAPPY)
\ No newline at end of file
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index 97b4ec44fc..7240fc7422 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from collections import namedtuple
import os
import threading
@@ -481,6 +482,40 @@ class MapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testMapNamedtuple(self, count=10):
+ # construct dataset of tuples
+ labels = dataset_ops.Dataset.range(count)
+ images = labels.map(lambda l: -l)
+ dataset_tuple = dataset_ops.Dataset.zip((labels, images))
+
+ # convert dataset of tuples to dataset of namedtuples
+ example = namedtuple("Example", ["label", "image"])
+ dataset_namedtuple = dataset_tuple.map(example)
+
+ def preprocess_tuple(label, image):
+ image = 2 * image
+ return label, image
+
+ def preprocess_namedtuple(example):
+ return example._replace(image=2 * example.image)
+
+ # preprocess both datasets
+ dataset_tuple = dataset_tuple.map(preprocess_tuple)
+ dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)
+
+ next_tuple = dataset_tuple.make_one_shot_iterator().get_next()
+ next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()
+
+ # make sure both datasets contain the same data
+ with self.test_session() as sess:
+ for i in range(count):
+ tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
+ self.assertEqual(tuple_, namedtuple_)
+ self.assertEqual(tuple_, (i, -2 * i))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_namedtuple)
+
def testUseStepContainerInMap(self):
row = np.arange(6)
iterator = (
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index abf7bcb384..0ee9acfc97 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -1921,7 +1921,7 @@ class DenseToSparseBatchDataset(Dataset):
def _should_unpack_args(args):
"""Returns `True` if `args` should be `*args` when passed to a callable."""
- return nest.is_sequence(args) and not isinstance(args, dict)
+ return type(args) is tuple # pylint: disable=unidiomatic-typecheck
class _ResourceDataset(Dataset):
@@ -2104,7 +2104,7 @@ class InterleaveDataset(Dataset):
nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- if nest.is_sequence(nested_args):
+ if _should_unpack_args(nested_args):
dataset = map_func(*nested_args)
else:
dataset = map_func(nested_args)
@@ -2413,7 +2413,7 @@ def rejection_resample(dataset,
shapes and types defined by `dataset.output_shapes` and
`dataset.output_types`) to a scalar `tf.int32` tensor. Values should
be in `[0, num_classes)`.
- target_dist: A floating point type tensor, shaped `[num_classes].
+ target_dist: A floating point type tensor, shaped `[num_classes]`.
initial_dist: (Optional.) A floating point type tensor, shaped
`[num_classes]`. If not provided, the true class distribution is
estimated live in a streaming fashion.
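
The _should_unpack_args change above is what makes the namedtuple test pass: type(args) is tuple is deliberately an exact type check, so a plain tuple is star-unpacked into the map function while a namedtuple (a tuple subclass) or a dict is passed through whole. A standalone illustration of the distinction:

from collections import namedtuple

def should_unpack_args(args):
    return type(args) is tuple  # exact check: namedtuple is a tuple subclass

Example = namedtuple('Example', ['label', 'image'])

assert should_unpack_args((1, 2))
assert not should_unpack_args(Example(label=1, image=2))
assert not should_unpack_args({'label': 1})

def apply_map(map_func, args):
    return map_func(*args) if should_unpack_args(args) else map_func(args)

assert apply_map(lambda label, image: label - image, (3, 1)) == 2
assert apply_map(lambda ex: ex._replace(image=2 * ex.image),
                 Example(label=1, image=-1)) == Example(label=1, image=-2)
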
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index da1cd72a6f..699cf45a73 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -150,7 +150,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
`N - 1` dimensions index into a batch of independent distributions and
the last dimension represents a vector of probabilities for each
class. Only one of `logits` or `probs` should be passed in.
- dtype: The type of the event samples (default: int32).
+ dtype: The type of the event samples (default: float32).
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@@ -388,7 +388,7 @@ class RelaxedOneHotCategorical(
dimensions index into a batch of independent distributions and the last
dimension represents a vector of probabilities for each class. Only one
of `logits` or `probs` should be passed in.
- dtype: The type of the event samples (default: int32).
+ dtype: The type of the event samples (default: float32).
validate_args: Unused in this distribution.
allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
exception if a statistic (e.g. mean/mode/etc...) is undefined for any
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 7eb410b4c7..33db93b970 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -156,9 +156,9 @@ def optimize_loss(loss,
loss = ops.convert_to_tensor(loss)
contrib_framework.assert_scalar(loss)
if global_step is None:
- global_step = contrib_framework.get_global_step()
+ global_step = train.get_global_step()
else:
- contrib_framework.assert_global_step(global_step)
+ train.assert_global_step(global_step)
with vs.variable_scope(name, "OptimizeLoss", [loss, global_step]):
# Update ops take UPDATE_OPS collection if not provided.
if update_ops is None:
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index c31d5d2d47..861db1f89e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -24,7 +24,6 @@ import six
from tensorflow.contrib import framework as framework_lib
from tensorflow.contrib import layers as layers_lib
-from tensorflow.contrib import lookup as lookup_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
@@ -35,6 +34,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import nn
@@ -1070,9 +1070,8 @@ class _MultiClassHead(_SingleHead):
labels_tensor = _to_labels_tensor(labels, self._label_name)
_check_no_sparse_tensor(labels_tensor)
if self._label_keys:
- table = lookup_lib.string_to_index_table_from_tensor(
- mapping=self._label_keys,
- name="label_id_lookup")
+ table = lookup_ops.index_table_from_tensor(
+ self._label_keys, name="label_id_lookup")
return {
"labels": labels_tensor,
"label_ids": table.lookup(labels_tensor),
@@ -1106,9 +1105,8 @@ class _MultiClassHead(_SingleHead):
class_ids = math_ops.argmax(
logits, 1, name=prediction_key.PredictionKey.CLASSES)
if self._label_keys:
- table = lookup_lib.index_to_string_table_from_tensor(
- mapping=self._label_keys,
- name="class_string_lookup")
+ table = lookup_ops.index_to_string_table_from_tensor(
+ self._label_keys, name="class_string_lookup")
classes = table.lookup(class_ids)
else:
classes = class_ids
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1915a89af7..9319928307 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2156,8 +2156,6 @@ tf_cc_tests(
"platform/port_test.cc",
"platform/profile_utils/cpu_utils_test.cc",
"platform/subprocess_test.cc",
- "platform/vmodule_benchmark_test.cc",
- "platform/vmodule_test.cc",
],
deps = [
":lib",
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index e7092f549b..f5dadf76da 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -117,16 +117,6 @@ class CPUAllocator : public Allocator {
TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
};
-namespace {
-Allocator* MakeCpuAllocator() {
- Allocator* allocator = new CPUAllocator;
- if (cpu_allocator_collect_full_stats || LogMemory::IsEnabled()) {
- allocator = new TrackingAllocator(allocator, true);
- }
- return allocator;
-}
-} // namespace
-
Allocator* cpu_allocator() {
static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator();
if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 2f9ceaa3bd..cf5d6e8baa 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -1099,6 +1099,44 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
CHECK_NOTNULL(workspace_tensors);
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ // TODO(nhasabni): Temporary solution to connect filter input of
+ // BackpropInput with the converted filter from Conv2D.
+ bool do_connect_conv2d_backprop_input_filter = false;
+ Node* conv2d_node = nullptr;
+ // Filter node is 2nd input (slot index 1) of Conv2D.
+ int kConv2DFilterInputSlotIdx = 1;
+ int kConv2DBackpropInputFilterInputSlotIdx = 1;
+ int kConv2DFilterOutputSlotIdx = 1;
+ if (old_node->type_string() == csinfo_.conv2d_grad_input) {
+ // We need to find Conv2D node from Conv2DBackpropInput.
+ // For that let's first find filter node that is 2nd input (slot 1)
+ // of BackpropInput.
+ Node* filter_node = nullptr;
+ old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node);
+ CHECK_NOTNULL(filter_node);
+
+ // Now check which nodes receive from filter_node. Filter feeds as
+ // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias.
+ for (const Edge* e : filter_node->out_edges()) {
+ if (e->dst()->type_string() == csinfo_.mkl_conv2d &&
+ e->dst_input() == kConv2DFilterInputSlotIdx
+ /* filter is 2nd input of Conv2D and _MklConv2D. */) {
+ if (conv2d_node != nullptr) {
+ VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
+ << " feeding multiple Conv2D nodes: "
+ << filter_node->DebugString();
+ // We will not connect filter input of Conv2DBackpropInput
+ // to be safe here.
+ do_connect_conv2d_backprop_input_filter = false;
+ break;
+ } else {
+ conv2d_node = e->dst();
+ do_connect_conv2d_backprop_input_filter = true;
+ }
+ }
+ }
+ }
+
// Number of input slots to original op
// Input slots are represented by .Input() calls in REGISTER_OP.
int old_node_input_slots = old_node->op_def().input_arg_size();
@@ -1122,7 +1160,13 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
- nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
+ // Special case for connecting filter input of Conv2DBackpropInput
+ if (do_connect_conv2d_backprop_input_filter &&
+ iidx == kConv2DBackpropInputFilterInputSlotIdx) {
+ nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
+ } else {
+ nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
+ }
iidx++;
nn_slot_idx++;
}
@@ -1157,9 +1201,17 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
} else {
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
- GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
- old_node_inputs[iidx].second,
- &mkl_node, &mkl_node_output_slot);
+ // Special case for connecting filter input of Conv2DBackpropInput
+ if (do_connect_conv2d_backprop_input_filter &&
+ iidx == kConv2DBackpropInputFilterInputSlotIdx) {
+ GetNodeProducingMklTensor(g, old_node, conv2d_node,
+ kConv2DFilterOutputSlotIdx, &mkl_node,
+ &mkl_node_output_slot);
+ } else {
+ GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
+ old_node_inputs[iidx].second, &mkl_node,
+ &mkl_node_output_slot);
+ }
nb->Input(mkl_node, mkl_node_output_slot);
iidx++;
nn_slot_idx++;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 482e339802..bd1d74368e 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -788,7 +788,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
- "C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
+ "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
// Conv2D with INT32 which is not supported by Mkl
@@ -917,7 +917,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
- "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;"
+ "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
"G:control->DMT/_4:control;H->I:1");
}
@@ -953,7 +953,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
- "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
+ "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;"
"G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
}
@@ -1023,8 +1023,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
"C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
- "DMT/_4->H:5;E->H;E:1->H:3;E:control->DMT/_4:control;F->H:1;"
- "F:1->H:4;G->H:2;H->I:1");
+ "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
+ "F:2->H:4;G->H:2;H->I:1");
}
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
@@ -1060,7 +1060,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
"A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
- "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;"
+ "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;"
"E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
"G->H:2;H->I:1");
}
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index 90bef11164..b01818f746 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -173,13 +173,13 @@ TEST_F(MklToTfConversionPass, Positive) {
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
"Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
- "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
+ "C:2->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
}
}
// MklConv2D followed by MklToTf op followed by Non-Mkl layer.
// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for interleaved)
-// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for contiguous)
+// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:2) F=Sub(D,E) (for contiguous)
// MklToTf node should not be inserted again.
TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
@@ -226,7 +226,7 @@ TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
"node { name: 'D' op: '_MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C:0', 'C:1']}"
+ " input: ['C:0', 'C:2']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
@@ -234,7 +234,7 @@ TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
"F(Sub);M(_MklInput);N(_MklInput)|"
- "A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
+ "A->C;B->C:1;C->D;C:2->D:1;D->F;E->F:1;M->C:2;N->C:3");
}
}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 082101ce11..8dd8900f28 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2590,7 +2590,9 @@ tf_kernel_library(
tf_kernel_library(
name = "segment_reduction_ops",
prefix = "segment_reduction_ops",
- deps = MATH_DEPS,
+ deps = MATH_DEPS + if_cuda([
+ ":cuda_solvers",
+ ]),
)
tf_kernel_library(
@@ -3344,6 +3346,19 @@ tf_kernel_library(
deps = PARSING_DEPS,
)
+tf_cc_test(
+ name = "parse_tensor_test",
+ srcs = ["parse_tensor_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":parse_tensor_op",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_kernel_library(
name = "string_to_number_op",
prefix = "string_to_number_op",
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
index eb173c7040..6adaecba04 100644
--- a/tensorflow/core/kernels/cwise_op_sub.cc
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -18,7 +18,10 @@ limitations under the License.
namespace tensorflow {
REGISTER7(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32,
int64, complex64, complex128);
-#if defined(__ANDROID_TYPES_SLIM__)
+#if !defined(__ANDROID_TYPES_SLIM__)
+// Sub op for int8, uint8, int16, uint16
+REGISTER4(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16);
+#else
// We only register the first type when we have multi-argument calls in the
// case where we're trying to reduce executable size, but it turns out that the
// int32 version of this op is needed, so explicitly include it.
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index ef7338e0e0..00884d0981 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -97,8 +97,12 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
errors::InvalidArgument(
"Conv2DCustomBackpropInput: size must be 4-dim"));
- MklSizesToTFSizes(context, data_format, mkl_context.filter_shape,
- &filter_shape);
+ const int64* filter_sizes =
+ (const int64*)mkl_context.filter_shape.GetSizes();
+ const int64 filter_dims = mkl_context.filter_shape.GetDimension();
+
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_sizes, filter_dims, &filter_shape));
} else {
filter_shape = filter.shape();
}
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 203e694631..5dfce5d5c6 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -265,6 +265,28 @@ class MklConv2DOp : public OpKernel {
sizeof(T));
AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
mkl_output_mkl_shape);
+ // Filter output to be used in the backprop_input
+ TensorShape mkl_filter_output_tf_shape;
+ MklShape mkl_filter_output_mkl_shape;
+ mkl_filter_output_mkl_shape.SetMklTensor(true);
+ mkl_filter_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd,
+ dnnResourceFilter);
+
+ size_t filter_sizes[4] = {filter.dim_size(0), filter.dim_size(1),
+ filter.dim_size(2), filter.dim_size(3)};
+ mkl_filter_output_mkl_shape.SetTfLayout(filter.dims(), filter_sizes,
+ mkl_context.filter_strides);
+
+ mkl_filter_output_mkl_shape.SetTfDimOrder(mkl_context.filter_dims,
+ data_format_);
+ mkl_filter_output_tf_shape.AddDim(
+ dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_filter_output_mkl_shape.GetMklLayout())) /
+ sizeof(T));
+ AllocateOutputSetMklShape(context, 1, &mkl_context.output_filter,
+ mkl_filter_output_tf_shape,
+ mkl_filter_output_mkl_shape);
+
mkl_context.conv_res[dnnResourceDst] =
static_cast<void*>(output->flat<T>().data());
@@ -303,6 +325,7 @@ class MklConv2DOp : public OpKernel {
dnnPrimitive_t prim_fwd;
void* conv_res[dnnResourceNumber];
dnnLayout_t lt_filter, lt_bias, lt_input;
+ Tensor* output_filter = nullptr;
// Create MKL dnnLayout_t objects for tensors coming into the layer
void MklCreateInputLayouts(OpKernelContext* context) {
@@ -383,8 +406,8 @@ class MklConv2DOp : public OpKernel {
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter,
mkl_lt_internal_filter),
E_SUCCESS);
- AllocTmpBuffer(context, mkl_tmp_filter_buf_tensor,
- mkl_lt_internal_filter, &mkl_buf_convert_filter);
+ mkl_buf_convert_filter = const_cast<void*>(
+ static_cast<const void*>(output_filter->flat<T>().data()));
CHECK_EQ(
dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter,
mkl_buf_convert_filter),
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index b3763f17bc..5e98582475 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -43,30 +43,26 @@ class MklReshapeOp : public OpKernel {
OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes.shape().DebugString()));
- const int64 num_dims = sizes.NumElements();
// Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one.
TensorShape shape;
int64 product = 1;
int unknown_index = -1;
- auto vec_size = sizes.flat<int32>();
- for (int d = 0; d < num_dims; ++d) {
- const int32 size = vec_size(d);
- if (size == -1) {
- OP_REQUIRES(
- context, unknown_index == -1,
- errors::InvalidArgument("only one input size may be -1, not both ",
- unknown_index, " and ", d));
- unknown_index = d;
- shape.AddDim(1);
- } else {
- OP_REQUIRES(context, size >= 0,
- errors::InvalidArgument(
- "size ", d, " must be non-negative, not ", size));
- shape.AddDim(size);
- product *= size;
- }
+ switch (sizes.dtype()) {
+ case DT_INT32:
+ OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
+ &unknown_index, &shape));
+ break;
+ case DT_INT64:
+ OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
+ &unknown_index, &shape));
+ break;
+ default:
+ context->CtxFailure(errors::InvalidArgument(
+ "desired shape must be a DT_INT32 or DT_INT64 vector, not a ",
+ DataTypeString(sizes.dtype())));
+ return;
}
if (unknown_index != -1) {
OP_REQUIRES(
@@ -132,6 +128,35 @@ class MklReshapeOp : public OpKernel {
CopyTfTensorInToOutWithShape(context, 0, 0, shape);
}
}
+
+ private:
+ template <typename Tshape>
+ Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
+ TensorShape* shape) {
+ *product = 1;
+ *unknown_index = -1;
+ const int64 num_dims = sizes.NumElements();
+ auto Svec = sizes.flat<Tshape>();
+ for (int d = 0; d < num_dims; ++d) {
+ const Tshape size = Svec(d);
+ if (size == -1) {
+ if (*unknown_index != -1) {
+ return errors::InvalidArgument(
+ "Only one input size may be -1, not both ", *unknown_index,
+ " and ", d);
+ }
+ *unknown_index = d;
+ shape->AddDim(1);
+ } else if (size < 0) {
+ return errors::InvalidArgument("Size ", d,
+ " must be non-negative, not ", size);
+ } else {
+ shape->AddDim(size);
+ (*product) *= size;
+ }
+ }
+ return Status::OK();
+ }
};
#define REGISTER_MKL_CPU(T) \
@@ -141,6 +166,13 @@ class MklReshapeOp : public OpKernel {
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tshape") \
.Label(mkl_op_registry::kMklOpLabel), \
+ MklReshapeOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklReshape") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Tshape") \
+ .Label(mkl_op_registry::kMklOpLabel), \
MklReshapeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
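
ValidateSizes mirrors the standard Reshape semantics: at most one -1 entry is allowed, other negative sizes are rejected, and the product of the known dimensions is accumulated so the caller can infer the unknown one from the element count. A minimal Python sketch of the same validation (hypothetical helper, not the C++ kernel):

def validate_sizes(sizes, num_elements):
    product, unknown_index, shape = 1, -1, []
    for d, size in enumerate(sizes):
        if size == -1:
            if unknown_index != -1:
                raise ValueError('only one input size may be -1, not both '
                                 '%d and %d' % (unknown_index, d))
            unknown_index = d
            shape.append(1)  # placeholder, resolved below
        elif size < 0:
            raise ValueError('size %d must be non-negative, not %d' % (d, size))
        else:
            shape.append(size)
            product *= size
    if unknown_index != -1:
        if product == 0 or num_elements % product != 0:
            raise ValueError('cannot infer the -1 dimension')
        shape[unknown_index] = num_elements // product
    return shape

assert validate_sizes([2, -1, 4], 24) == [2, 3, 4]
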
diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc
index 79199ff5c3..8e175fe8d4 100644
--- a/tensorflow/core/kernels/parse_tensor_op.cc
+++ b/tensorflow/core/kernels/parse_tensor_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/parsing_ops.cc.
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -65,4 +66,31 @@ class ParseTensorOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ParseTensor").Device(DEVICE_CPU), ParseTensorOp);
+template <typename T>
+class SerializeTensorOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor = context->input(0);
+ TensorProto proto;
+ if (tensor.dtype() == DT_STRING) {
+ tensor.AsProtoField(&proto);
+ } else {
+ tensor.AsProtoTensorContent(&proto);
+ }
+ Tensor* proto_string = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &proto_string));
+ CHECK(proto.SerializeToString(&proto_string->scalar<string>()()));
+ }
+};
+
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SerializeTensorOp<T>);
+TF_CALL_ALL_TYPES(REGISTER)
+#undef REGISTER
+
} // namespace tensorflow
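
SerializeTensor is the inverse of ParseTensor: it encodes a tensor as a scalar DT_STRING holding a serialized TensorProto, using AsProtoField for string tensors and the more compact AsProtoTensorContent for everything else. Assuming the op is exposed in Python as tf.serialize_tensor alongside the existing tf.parse_tensor (the golden API update in this change suggests as much), a round trip would look roughly like:

import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
serialized = tf.serialize_tensor(t)                 # scalar string tensor
restored = tf.parse_tensor(serialized, tf.float32)  # out_type must match t

with tf.Session() as sess:
    original, round_trip = sess.run([t, restored])
    assert (original == round_trip).all()
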
diff --git a/tensorflow/core/kernels/parse_tensor_test.cc b/tensorflow/core/kernels/parse_tensor_test.cc
new file mode 100644
index 0000000000..4a5fc07935
--- /dev/null
+++ b/tensorflow/core/kernels/parse_tensor_test.cc
@@ -0,0 +1,198 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+
+namespace tensorflow {
+namespace {
+
+class SerializeTensorOpTest : public OpsTestBase {
+ protected:
+ template <typename T>
+ void MakeOp(const TensorShape& input_shape, std::function<T(int)> functor) {
+ TF_ASSERT_OK(NodeDefBuilder("myop", "SerializeTensor")
+ .Input(FakeInput(DataTypeToEnum<T>::value))
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ AddInput<T>(input_shape, functor);
+ }
+ void ParseSerializedWithNodeDef(const NodeDef& parse_node_def,
+ Tensor* serialized, Tensor* parse_output) {
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ inputs.push_back({nullptr, serialized});
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
+ cpu_allocator(), parse_node_def,
+ TF_GRAPH_DEF_VERSION, &status));
+ TF_EXPECT_OK(status);
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.inputs = &inputs;
+ params.frame_iter = FrameAndIter(0, 0);
+ params.op_kernel = op.get();
+ std::vector<AllocatorAttributes> attrs;
+ test::SetOutputAttrs(&params, &attrs);
+ OpKernelContext ctx(&params);
+ op->Compute(&ctx);
+ TF_EXPECT_OK(status);
+ *parse_output = *ctx.mutable_output(0);
+ }
+ template <typename T>
+ void ParseSerializedOutput(Tensor* serialized, Tensor* parse_output) {
+ NodeDef parse;
+ TF_ASSERT_OK(NodeDefBuilder("parse", "ParseTensor")
+ .Input(FakeInput(DT_STRING))
+ .Attr("out_type", DataTypeToEnum<T>::value)
+ .Finalize(&parse));
+ ParseSerializedWithNodeDef(parse, serialized, parse_output);
+ }
+};
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) {
+ MakeOp<Eigen::half>(TensorShape({10}), [](int x) -> Eigen::half {
+ return static_cast<Eigen::half>(x / 10.);
+ });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<Eigen::half>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<Eigen::half>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) {
+ MakeOp<float>(TensorShape({1, 10}),
+ [](int x) -> float { return static_cast<float>(x / 10.); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<float>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<float>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) {
+ MakeOp<double>(TensorShape({5, 5}),
+ [](int x) -> double { return static_cast<double>(x / 10.); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<double>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<double>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) {
+ MakeOp<int64>(TensorShape({2, 3, 4}),
+ [](int x) -> int64 { return static_cast<int64>(x - 10); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<int64>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<int64>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) {
+ MakeOp<int32>(TensorShape({4, 2}),
+ [](int x) -> int32 { return static_cast<int32>(x + 7); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<int32>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<int32>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) {
+ MakeOp<int16>(TensorShape({8}),
+ [](int x) -> int16 { return static_cast<int16>(x + 18); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<int16>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<int16>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) {
+ MakeOp<int8>(TensorShape({2}),
+ [](int x) -> int8 { return static_cast<int8>(x + 8); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<int8>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<int8>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) {
+ MakeOp<uint16>(TensorShape({1, 3}),
+ [](int x) -> uint16 { return static_cast<uint16>(x + 2); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<uint16>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<uint16>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) {
+ MakeOp<uint8>(TensorShape({2, 1, 1}),
+ [](int x) -> uint8 { return static_cast<uint8>(x + 1); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<uint8>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<uint8>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) {
+ MakeOp<complex64>(TensorShape({}), [](int x) -> complex64 {
+ return complex64{static_cast<float>(x / 8.), static_cast<float>(x / 2.)};
+ });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<complex64>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<complex64>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) {
+ MakeOp<complex128>(TensorShape({3}), [](int x) -> complex128 {
+ return complex128{x / 3., x / 2.};
+ });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<complex128>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<complex128>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) {
+ MakeOp<bool>(TensorShape({1}),
+ [](int x) -> bool { return static_cast<bool>(x % 2); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<bool>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<bool>(parse_output, GetInput(0));
+}
+
+TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_string) {
+ MakeOp<string>(TensorShape({10}),
+ [](int x) -> string { return std::to_string(x / 10.); });
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor parse_output;
+ ParseSerializedOutput<string>(GetOutput(0), &parse_output);
+ test::ExpectTensorEqual<string>(parse_output, GetInput(0));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 9cdbe89457..5624d5cd1b 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -16,6 +16,9 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include <vector>
@@ -32,6 +35,14 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/platform/cuda.h"
+
+using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+#endif // GOOGLE_CUDA
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -183,6 +194,106 @@ class SegmentReductionOp : public OpKernel {
}
};
+#ifdef GOOGLE_CUDA
+// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
+// TODO: This implementation of SegmentSumGPUOp is sometimes slower than
+// its unsorted counterpart (mostly when problem size is small).
+// This is due to the following two main reasons and a cost-effective way
+// to resolve these problems is desirable.
+// 1. Sorted segment sum requires a memory transfer from device to host in
+// order to know the size of the output dimension whereas unsorted segment
+// sum receives the size of the output dimension as an input parameter.
+// 2. Sorted segment sum is essentially a tiled version of unsorted segment
+// sum and therefore such optimization comes at an inherent cost. However
+// such cost may not be justified when the problem size is small. When to
+// use the tiled version or the untiled version depends on many factors
+// including data alignments, ratio of calculation to memory traffic and
+// obviously, the problem sizes.
+template <class T, class Index>
+class SegmentSumGPUOp : public AsyncOpKernel {
+ public:
+ explicit SegmentSumGPUOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ const Tensor& input = context->input(0);
+ const Tensor& segment_ids = context->input(1);
+
+ OP_REQUIRES_ASYNC(
+ context, TensorShapeUtils::IsVector(segment_ids.shape()),
+ errors::InvalidArgument("segment_ids should be a vector."), done);
+
+ const int64 num_indices = segment_ids.NumElements();
+ OP_REQUIRES_ASYNC(
+ context, num_indices == input.dim_size(0),
+ errors::InvalidArgument(
+ "segment_ids should be the same size as dimension 0 of"
+ " input."),
+ done);
+
+ if (num_indices == 0) {
+ TensorShape output_shape = input.shape();
+ output_shape.set_dim(0, 0);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_output(0, output_shape, &output), done);
+ done();
+ return;
+ }
+
+ perftools::gputools::DeviceMemoryBase output_rows_device(
+ (void*)(segment_ids.template flat<Index>().data() + (num_indices - 1)));
+ ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
+
+ auto stream = context->op_device_context()->stream();
+ OP_REQUIRES_ASYNC(
+ context,
+ stream
+ ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
+ sizeof(Index))
+ .ok(),
+ errors::Internal(
+ "SegmentSumGPUOp: failed to copy output_rows from device"),
+ done);
+
+ functor::SegmentSumFunctor<T, Index> functor_;
+ auto create_and_check_output = [context, output_rows_host, &input,
+ &segment_ids, &functor_, done]() {
+ // Ensure that within the callback, the proper GPU settings are
+ // configured.
+ auto stream = context->op_device_context()->stream();
+ ScopedActivateExecutorContext scoped_activation{stream->parent()};
+
+ Index output_rows = *output_rows_host.data();
+ output_rows++;
+ OP_REQUIRES_ASYNC(context, output_rows > 0,
+ errors::InvalidArgument("segment ids must be >= 0"),
+ done);
+
+ TensorShape output_shape = input.shape();
+ output_shape.set_dim(0, output_rows);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_output(0, output_shape, &output), done);
+
+ auto output_flat = output->flat_outer_dims<T>();
+ auto data_ptr = input.template flat<T>().data();
+ auto segment_flat = segment_ids.flat<Index>();
+ functor_(context, context->eigen_device<GPUDevice>(), output_rows,
+ segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
+ output_flat);
+
+ done();
+ };
+
+ context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ stream, create_and_check_output);
+ }
+};
+#endif // GOOGLE_CUDA
+
#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
default_value) \
REGISTER_KERNEL_BUILDER( \
@@ -227,6 +338,23 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
+#if GOOGLE_CUDA
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentSumGPUOp<type, index_type>)
+
+#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
+ REGISTER_GPU_SORTED_KERNELS(type, int32); \
+ REGISTER_GPU_SORTED_KERNELS(type, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
+#undef REGISTER_GPU_SORTED_KERNELS
+#undef REGISTER_GPU_SORTED_KERNELS_ALL
+#endif // GOOGLE_CUDA
+
namespace functor {
// UnsortedSegmentSumFunctor implementation for CPUDevice.
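
For context, a minimal sketch of the trade-off described above, using the public Python APIs of this tree (`tf.segment_sum` / `tf.unsorted_segment_sum`; session-style usage assumed): the sorted variant's output row count is data-dependent, which is why SegmentSumGPUOp must copy the last segment id back to the host, while the unsorted variant is told `num_segments` up front.

```python
import tensorflow as tf

data = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
segment_ids = tf.constant([0, 0, 1])  # sorted; last id + 1 = output rows

# Sorted: output size is derived from the last segment id, which on the
# GPU requires a device-to-host copy before the output can be allocated.
sorted_sum = tf.segment_sum(data, segment_ids)

# Unsorted: the caller supplies num_segments, so no copy is needed.
unsorted_sum = tf.unsorted_segment_sum(data, segment_ids, num_segments=2)

with tf.Session() as sess:
    print(sess.run(sorted_sum))    # [[4. 6.] [5. 6.]]
    print(sess.run(unsorted_sum))  # [[4. 6.] [5. 6.]]
```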
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index ee09c213b7..412c1d601d 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -26,6 +26,28 @@ namespace tensorflow {
class OpKernelContext;
namespace functor {
+
+#ifdef GOOGLE_CUDA
+typedef Eigen::GpuDevice GPUDevice;
+// Functor for SegmentSumGPUOp.
+// 'output_rows': the number of output segments (unique segment ids in
+// 'segment_ids').
+// 'segment_ids_shape': shape of 'segment_ids' tensor.
+// 'segment_ids': a vector of sorted (non-decreasing) segment ids, mapping
+// each input row to the output segment to sum into.
+// 'data_size': size of input data tensor.
+// 'data': input data tensor.
+// 'output': output reshaped to {output_rows, output.size/output_rows}
+template <typename T, typename Index>
+struct SegmentSumFunctor {
+ void operator()(OpKernelContext* ctx, const GPUDevice& d,
+ const Index output_rows, const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output);
+};
+#endif
+
// BaseFunctor for definition of UnsortedSegmentReductionOp
// for usage without templates.
template <typename Device, typename T, typename Index>
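
To make the documented flattened view concrete, here is a small NumPy sketch (the helper name is mine, purely illustrative) of the computation the functor performs: `data` is viewed as {outer, inner} with outer equal to the number of segment ids, and `output` as {output_rows, inner}.

```python
import numpy as np

def segment_sum_flattened(data, segment_ids, output_rows):
    # View 'data' as (outer, inner), where outer == len(segment_ids).
    outer = segment_ids.shape[0]
    inner = data.size // outer
    flat = data.reshape(outer, inner)
    out = np.zeros((output_rows, inner), dtype=data.dtype)
    for row, seg in zip(flat, segment_ids):
        out[seg] += row  # each input row is summed into its output segment
    return out

data = np.arange(24, dtype=np.float32).reshape(4, 2, 3)
print(segment_sum_flattened(data, np.array([0, 0, 1, 1]), output_rows=2))
```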
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index b132b1e8f8..159fada621 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -54,6 +54,77 @@ __device__ __forceinline__ void AccumulateInto(
CudaAtomicAdd(dest_scalar + 1, value.imag());
}
+// SortedSegmentSumCustomKernel reduces input data just as
+// UnsortedSegmentSumCustomKernel does, except that the input data
+// is partitioned along the outer reduction dimension. This is
+// because consecutive rows (elements in a row share the same
+// outer dimension index) in the flattened 2D input data likely
+// belong to the same segment in a sorted segment sum operation.
+// This partitioning strategy therefore has two advantages over
+// the unsorted kernel:
+// 1. Each thread reduces across multiple rows before writing
+// answers to global memory, so reduction results are written
+// to global memory less often.
+// 2. The current thread may be known to be the only contributor
+// to an output element, because segment ids are non-decreasing.
+// In such cases, no atomic operations are needed to write the
+// results to global memory.
+// In the flattened view of the input data (with only an outer and
+// an inner dimension), every thread processes a stripe of input of
+// size OuterDimTileSize x 1. This stripe runs across multiple
+// rows of the input data, and all reduction elements in it share
+// one inner dimension index.
+template <typename T, typename Index, int OuterDimTileSize>
+__global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
+ const Index inner_dim_size,
+ const Index output_outer_dim_size,
+ const Index* segment_ids,
+ const T* input, T* output,
+ const Index total_stripe_count) {
+ CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) {
+ const Index segment_offset = stripe_index % inner_dim_size;
+ const Index input_outer_dim_index_base =
+ stripe_index / inner_dim_size * Index(OuterDimTileSize);
+
+ T sum = T(0);
+ Index first_segment_id = segment_ids[input_outer_dim_index_base];
+ Index last_output_segment_id = output_outer_dim_size;
+
+ const Index actual_stripe_height =
+ min(Index(OuterDimTileSize),
+ input_outer_dim_size - input_outer_dim_index_base);
+ for (Index j = 0; j < actual_stripe_height; j++) {
+ Index current_output_segment_id =
+ segment_ids[input_outer_dim_index_base + j];
+ // Decide whether to write result to global memory.
+ // Result is only written to global memory if we move
+ // to another segment. Otherwise we can keep accumulating
+ // locally.
+ if (current_output_segment_id > last_output_segment_id) {
+ const Index output_index =
+ last_output_segment_id * inner_dim_size + segment_offset;
+ // The stripe's first segment may also receive contributions from
+ // the previous stripe, so it needs an atomic add; any later
+ // segment is owned by this thread alone and a plain store suffices.
+ if (last_output_segment_id == first_segment_id) {
+ AccumulateInto<T>(output + output_index, sum);
+ } else {
+ *(output + output_index) = sum;
+ }
+ sum = T(0);
+ }
+ sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
+ segment_offset);
+ last_output_segment_id = current_output_segment_id;
+ }
+ // For the last result in a stripe, always write using atomic operations
+ // due to possible race conditions with threads computing
+ // the following stripe.
+ const Index output_index =
+ last_output_segment_id * inner_dim_size + segment_offset;
+ AccumulateInto<T>(output + output_index, sum);
+ }
+}
+
// UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements.
// Each element is mapped from input to output by a combination of its
// 'segment_ids' mapping and 'inner_dim_size'.
@@ -80,6 +151,47 @@ __global__ void UnsortedSegmentSumCustomKernel(
namespace functor {
+template <typename T, typename Index>
+void SegmentSumFunctor<T, Index>::operator()(
+ OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
+ const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
+ const T* data, typename TTypes<T, 2>::Tensor output) {
+ if (output.size() == 0) {
+ return;
+ }
+ // Set 'output' to zeros.
+ CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
+ SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ output.size(), output.data());
+ if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
+ return;
+ }
+
+ // Launch kernel to compute sorted segment sum.
+ // Notes:
+ // *) 'input_total_size' is the total number of elements to process.
+ // *) 'segment_ids.shape' is a prefix of data's shape.
+ // *) 'input_outer_dim_size' is the number of input rows (one segment id per row).
+ const Index input_total_size = data_size;
+ const Index input_outer_dim_size = segment_ids.dimension(0);
+ const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
+
+ const int OuterDimTileSize = 8;
+
+ const Index input_outer_dim_num_stripe =
+ Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
+
+ const Index total_stripe_count =
+ input_inner_dim_size * input_outer_dim_num_stripe;
+
+ config = GetCudaLaunchConfig(total_stripe_count, d);
+ SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ input_outer_dim_size, input_inner_dim_size, output_rows,
+ segment_ids.data(), data, output.data(), total_stripe_count);
+}
+
// UnsortedSegmentSumFunctor implementation for GPUDevice.
template <typename T, typename Index>
struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
@@ -117,6 +229,15 @@ struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFuncto
}
};
+#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
+ template struct SegmentSumFunctor<T, Index>
+
+#define DEFINE_SORTED_GPU_SPECS(T) \
+ DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
+ DEFINE_SORTED_GPU_SPECS_INDEX(T, int64);
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
+
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
template struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
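
A NumPy transcription of the stripe logic above (purely illustrative; the helper and its names are mine) may help: each CUDA "thread" owns one (stripe, inner-column) pair, accumulates locally while the segment id is unchanged, flushes on a segment change, and only the first and last partial sums can be shared with neighboring stripes, hence the atomic adds in the kernel.

```python
import numpy as np

def sorted_segment_sum_stripes(data2d, segment_ids, output_rows, tile=8):
    outer, inner = data2d.shape
    out = np.zeros((output_rows, inner), dtype=data2d.dtype)
    num_stripes = -(-outer // tile)  # Eigen::divup equivalent
    for stripe in range(num_stripes * inner):
        col = stripe % inner              # segment_offset
        base = (stripe // inner) * tile   # input_outer_dim_index_base
        height = min(tile, outer - base)  # actual_stripe_height
        acc = 0.0
        last_id = None
        for j in range(height):
            cur = segment_ids[base + j]
            if last_id is not None and cur > last_id:
                # Segment boundary inside the stripe: flush the local sum.
                # The kernel uses an atomic add only for the stripe's first
                # segment (possibly shared with the previous stripe) and a
                # plain store otherwise; '+=' into a zero-initialized output
                # is numerically equivalent for this sketch.
                out[last_id, col] += acc
                acc = 0.0
            acc += data2d[base + j, col]
            last_id = cur
        out[last_id, col] += acc  # last partial sum: always atomic on the GPU
    return out

ids = np.array([0, 0, 1, 1, 2])
print(sorted_segment_sum_stripes(np.ones((5, 3), np.float32), ids, 3))
```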
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 6ff05bd2a6..6eb05874aa 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -499,7 +499,7 @@ Returns x + y element-wise.
)doc");
REGISTER_OP("Sub")
- .BINARY_FEWER()
+ .BINARY_MORE()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns x - y element-wise.
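
Assuming BINARY_MORE registers the small integer types (e.g. uint8, int8, int16) that BINARY_FEWER omits, the user-visible effect, together with the kernel registrations added in cwise_op_sub.cc, is that subtraction now works on those dtypes; a quick sketch:

```python
import tensorflow as tf

x = tf.constant([5, 7, 9], dtype=tf.uint8)
y = tf.constant([1, 2, 3], dtype=tf.uint8)
with tf.Session() as sess:
    print(sess.run(x - y))  # [4 5 6]; Sub was previously undefined for uint8
```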
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 1ab1f1a736..8a2d5e8c05 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -2791,7 +2791,9 @@ REGISTER_OP("_MklConv2D")
.Input("mkl_input: uint8")
.Input("mkl_filter: uint8")
.Output("output: T")
+ .Output("filter_output: T")
.Output("mkl_output: uint8")
+ .Output("mkl_filter_output: uint8")
.Attr("T: {half, float, double}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
@@ -2813,7 +2815,9 @@ REGISTER_OP("_MklConv2DWithBias")
.Input("mkl_filter: uint8")
.Input("mkl_bias: uint8")
.Output("output: T")
+ .Output("filter_output: T")
.Output("mkl_output: uint8")
+ .Output("mkl_filter_output: uint8")
.Attr("T: {half, float, double}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 3a28ce3767..35c31c6cb8 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -15805,6 +15805,25 @@ op {
summary: "Transforms a serialized tensorflow.TensorProto proto into a Tensor."
}
op {
+ name: "SerializeTensor"
+ input_arg {
+ name: "tensor"
+ description: "A Tensor of type `T`."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "serialized"
+ description: "A serialized TensorProto proto of the input tensor."
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "The type of the input tensor."
+ }
+ summary: "Transforms a Tensor into a serialized TensorProto proto."
+}
+op {
name: "Placeholder"
output_arg {
name: "output"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index 2e605fdffc..1f7ebe91cf 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -292,6 +292,19 @@ out_type: The type of the serialized tensor. The provided type must match the
output: A Tensor of type `out_type`.
)doc");
+REGISTER_OP("SerializeTensor")
+ .Input("tensor: T")
+ .Output("serialized: string")
+ .Attr("T: type")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Transforms a Tensor into a serialized TensorProto proto.
+
+tensor: A Tensor of type `T`.
+T: The type of the input tensor.
+serialized: A serialized TensorProto proto of the input tensor.
+)doc");
+
REGISTER_OP("DecodeJSONExample")
.Input("json_examples: string")
.Output("binary_examples: string")
diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc
index ac0988e704..7127db3929 100644
--- a/tensorflow/core/platform/default/logging.cc
+++ b/tensorflow/core/platform/default/logging.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/default/logging.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/macros.h"
@@ -25,12 +24,8 @@ limitations under the License.
#endif
#include <stdlib.h>
-#include <string.h>
#include <time.h>
-#include <string>
-#include <unordered_map>
-
namespace tensorflow {
namespace internal {
@@ -129,48 +124,6 @@ int64 MinVLogLevelFromEnv() {
return LogLevelStrToInt(tf_env_var_val);
}
-using VmoduleMap = std::unordered_map<StringPiece, int, StringPiece::Hasher>;
-
-// Returns a mapping from module name to VLOG level, derived from the
-// TF_CPP_VMOUDLE environment variable; ownership is transferred to the caller.
-VmoduleMap* VmoduleRecordsFromEnv() {
- // The value of the env var is supposed to be of the form:
- // "foo=1,bar=2,baz=3"
- const char* tf_env_var_val = getenv("TF_CPP_VMODULE");
- auto* result = new VmoduleMap();
- if (tf_env_var_val == nullptr) return result;
- while (true) {
- const char* eq = strchr(tf_env_var_val, '=');
- if (eq == nullptr) break;
- const char* after_eq = eq + 1;
-
- // Comma either points at the next comma delimiter, or at a null terminator.
- // We check that the integer we parse ends at this delimiter.
- const char* comma = strchr(after_eq, ',');
- const char* new_tf_env_var_val;
- if (comma == nullptr) {
- comma = strchr(after_eq, '\0');
- new_tf_env_var_val = comma;
- } else {
- new_tf_env_var_val = comma + 1;
- }
-
- char* endptr = nullptr;
- int level = strtol(after_eq, &endptr, 10);
- if (endptr != comma) {
- fprintf(stderr,
- "warning: could not parse integer in vmodule specification in "
- "\"%s\".\n",
- after_eq);
- break;
- }
- StringPiece module(tf_env_var_val, eq - tf_env_var_val);
- tf_env_var_val = new_tf_env_var_val;
- (*result)[module] = level;
- }
- return result;
-}
-
} // namespace
LogMessage::~LogMessage() {
@@ -184,19 +137,6 @@ int64 LogMessage::MinVLogLevel() {
return min_vlog_level;
}
-bool LogMessage::VmoduleActivated(const char* fname, int lvl) {
- static VmoduleMap* vmodule_records = VmoduleRecordsFromEnv();
- const char* last_slash = strrchr(fname, '/');
- const char* module_start = last_slash == nullptr ? fname : last_slash + 1;
- const char* dot_after = strchr(module_start, '.');
- const char* module_limit =
- dot_after == nullptr ? strchr(fname, '\0') : dot_after;
- StringPiece module(module_start, module_limit - module_start);
- auto it = vmodule_records->find(module);
- if (it == vmodule_records->end()) return false;
- return it->second >= lvl;
-}
-
LogMessageFatal::LogMessageFatal(const char* file, int line)
: LogMessage(file, line, FATAL) {}
LogMessageFatal::~LogMessageFatal() {
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index c8c9b2da11..d5f7350cdd 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -46,16 +46,6 @@ class LogMessage : public std::basic_ostringstream<char> {
// but VLOG(3) will not. Defaults to 0.
static int64 MinVLogLevel();
- // Returns whether VLOG level lvl is activated for the file fname.
- //
- // E.g. if the environment variable TF_CPP_VMODULE contains foo=3 and fname is
- // foo.cc and lvl is <= 3, this will return true.
- //
- // It is expected that the result of this query will be cached in the VLOG-ing
- // call site to avoid repeated lookups. This routine performs a hash-map
- // access against the VLOG-ing specification provided by the env var.
- static bool VmoduleActivated(const char* fname, int lvl);
-
protected:
void GenerateLogMessage();
@@ -86,38 +76,18 @@ class LogMessageFatal : public LogMessage {
#define LOG(severity) _TF_LOG_##severity
-#if defined(IS_MOBILE_PLATFORM)
-
+#ifdef IS_MOBILE_PLATFORM
// Turn VLOG off when under mobile devices for considerations of binary size.
-#define _VLOG_IS_ON(lvl, file) ((lvl) <= 0)
-
-#elif defined(PLATFORM_WINDOWS)
-
-// TODO(b/64279502) The _VLOG_IS_ON definition below appears to cause MSVC to
-// fatal error, so we fall back to the vmodule-less implementation for now.
-#define _VLOG_IS_ON(lvl, file) \
- ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel())
-
+#define VLOG_IS_ON(lvl) ((lvl) <= 0)
#else
-
-// Otherwise, set TF_CPP_MIN_VLOG_LEVEL environment to update minimum log level
-// of VLOG, or TF_CPP_VMODULE to set the minimum log level for individual
-// translation units.
-#define _VLOG_IS_ON(lvl, file) \
- (([](int level, const char* fname) { \
- if (level <= ::tensorflow::internal::LogMessage::MinVLogLevel()) \
- return true; \
- static bool vmodule_activated = \
- ::tensorflow::internal::LogMessage::VmoduleActivated(fname, level); \
- return vmodule_activated; \
- })(lvl, file))
-
+// Otherwise, set the TF_CPP_MIN_VLOG_LEVEL environment variable to update
+// the minimum log level of VLOG.
+#define VLOG_IS_ON(lvl) \
+ ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel())
#endif
-#define VLOG_IS_ON(lvl) _VLOG_IS_ON(lvl, __FILE__)
-
-#define VLOG(lvl) \
- if (TF_PREDICT_FALSE(_VLOG_IS_ON(lvl, __FILE__))) \
+#define VLOG(lvl) \
+ if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \
::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO)
// CHECK dies with a fatal error if condition is not true. It is *not*
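
With the vmodule machinery removed, verbosity is controlled only by TF_CPP_MIN_VLOG_LEVEL, and MinVLogLevel() caches the value in a function-local static on first use, so the variable must be set before the runtime first logs; a sketch from the Python side:

```python
import os

# Must be set before TensorFlow is imported: MinVLogLevel() reads the
# environment once and caches the result.
os.environ["TF_CPP_MIN_VLOG_LEVEL"] = "1"

import tensorflow as tf  # C++ VLOG(1) messages are now emitted
```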
diff --git a/tensorflow/core/profiler/g3doc/command_line.md b/tensorflow/core/profiler/g3doc/command_line.md
index e2839a682f..fb4207c784 100644
--- a/tensorflow/core/profiler/g3doc/command_line.md
+++ b/tensorflow/core/profiler/g3doc/command_line.md
@@ -57,7 +57,7 @@ Note: this feature is not well maintained now.
```shell
# Build the tool.
-bazel build --config opt third_party/tensorflow/core/profiler/...
+bazel build --config opt tensorflow/core/profiler:profiler
# Help information, including detail 'option' instructions.
bazel-bin/tensorflow/core/profiler/profiler help
diff --git a/tensorflow/docs_src/community/welcome.md b/tensorflow/docs_src/community/welcome.md
index 194649a304..4991783a53 100644
--- a/tensorflow/docs_src/community/welcome.md
+++ b/tensorflow/docs_src/community/welcome.md
@@ -37,6 +37,7 @@ Asia:
* [TensorFlow Korea (TF-KR) User Group](https://www.facebook.com/groups/TensorFlowKR/) _(Korean language)_
* [TensorFlow User Group Tokyo](https://tfug-tokyo.connpass.com/) _(Japanese Language)_
* [Soleil Data Dojo](https://soleildatadojo.connpass.com/) _(Japanese language)_
+* [TensorFlow User Group Utsunomiya](https://tfug-utsunomiya.connpass.com/)
Europe:
diff --git a/tensorflow/docs_src/get_started/estimator.md b/tensorflow/docs_src/get_started/estimator.md
index a55454f8af..4f3a438d17 100644
--- a/tensorflow/docs_src/get_started/estimator.md
+++ b/tensorflow/docs_src/get_started/estimator.md
@@ -273,9 +273,7 @@ Then, the code creates a `DNNClassifier` model using the following arguments:
containing 10, 20, and 10 neurons, respectively.
* `n_classes=3`. Three target classes, representing the three Iris species.
* `model_dir=/tmp/iris_model`. The directory in which TensorFlow will save
- checkpoint data during model training. For more on logging and monitoring
- with TensorFlow, see
- @{$monitors$Logging and Monitoring Basics with tf.estimator}.
+ checkpoint data and TensorBoard summaries during model training.
## Describe the training input pipeline {#train-input}
@@ -315,9 +313,7 @@ classifier.train(input_fn=train_input_fn, steps=1000)
However, if you're looking to track the model while it trains, you'll likely
want to instead use a TensorFlow @{tf.train.SessionRunHook$`SessionRunHook`}
-to perform logging operations. See the tutorial
-@{$monitors$Logging and Monitoring Basics with tf.estimator}
-for more on this topic.
+to perform logging operations.
## Evaluate Model Accuracy {#evaluate-accuracy}
diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md
index 3e700daa30..003fac1a28 100644
--- a/tensorflow/docs_src/get_started/index.md
+++ b/tensorflow/docs_src/get_started/index.md
@@ -24,8 +24,6 @@ To learn about the high-level API, read the following guides:
API.
* @{$get_started/input_fn$Building Input Functions},
which takes you into a somewhat more sophisticated use of this API.
- * @{$get_started/monitors$Logging and Monitoring Basics with tf.contrib.learn},
- which explains how to audit the progress of model training.
TensorBoard is a utility to visualize different aspects of machine learning.
The following guides explain how to use TensorBoard:
diff --git a/tensorflow/docs_src/get_started/input_fn.md b/tensorflow/docs_src/get_started/input_fn.md
index 422f45c586..7706c07b1d 100644
--- a/tensorflow/docs_src/get_started/input_fn.md
+++ b/tensorflow/docs_src/get_started/input_fn.md
@@ -249,7 +249,7 @@ here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/input_fn/bos
### Importing the Housing Data
-To start, set up your imports (including `pandas` and `tensorflow`) and @{$monitors#enabling-logging-with-tensorflow$set logging verbosity} to
+To start, set up your imports (including `pandas` and `tensorflow`) and set logging verbosity to
`INFO` for more detailed log output:
```python
diff --git a/tensorflow/docs_src/get_started/leftnav_files b/tensorflow/docs_src/get_started/leftnav_files
index b656033f7e..bb67eaddda 100644
--- a/tensorflow/docs_src/get_started/leftnav_files
+++ b/tensorflow/docs_src/get_started/leftnav_files
@@ -5,7 +5,6 @@ mnist/pros.md
mnist/mechanics.md
estimator.md
input_fn.md
-monitors.md
summaries_and_tensorboard.md
graph_viz.md
tensorboard_histograms.md
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 64b3767735..4680e3ba16 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -5,7 +5,9 @@ package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
-load("build_defs", "JAVACOPTS")
+load(":build_defs.bzl", "JAVACOPTS")
+load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
+load("//tensorflow:tensorflow.bzl", "tf_copts")
java_library(
name = "tensorflow",
@@ -34,12 +36,57 @@ filegroup(
filegroup(
name = "java_op_sources",
- srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]),
+ srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [
+ ":java_op_gen_sources",
+ ],
visibility = [
"//tensorflow/java:__pkg__",
],
)
+tf_java_op_gen_srcjar(
+ name = "java_op_gen_sources",
+ gen_base_package = "org.tensorflow.op",
+ gen_tool = "java_op_gen_tool",
+ ops_libs = [
+ "array_ops",
+ "candidate_sampling_ops",
+ "control_flow_ops",
+ "data_flow_ops",
+ "image_ops",
+ "io_ops",
+ "linalg_ops",
+ "logging_ops",
+ "math_ops",
+ "nn_ops",
+ "no_op",
+ "parsing_ops",
+ "random_ops",
+ "sparse_ops",
+ "state_ops",
+ "string_ops",
+ "training_ops",
+ "user_ops",
+ ],
+)
+
+# Build the gen tool as a library, as it will be linked to a core/ops binary
+# file before making it an executable. See tf_java_op_gen_srcjar().
+cc_library(
+ name = "java_op_gen_tool",
+ srcs = glob([
+ "src/gen/cc/*.h",
+ "src/gen/cc/*.cc",
+ ]),
+ copts = tf_copts(),
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
java_library(
name = "testutil",
testonly = 1,
diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc
new file mode 100644
index 0000000000..a7c66dda89
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_gen_main.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/java/src/gen/cc/op_generator.h"
+
+namespace tensorflow {
+namespace op_gen {
+
+const char kUsageHeader[] =
+ "\n\nGenerator of operation wrappers in Java.\n\n"
+ "This executable generates wrappers for all registered operations it has "
+ "been compiled with. A wrapper exposes an intuitive and strongly-typed\n"
+ "interface for building its underlying operation and linking it into a "
+ "graph.\n\n"
+ "Operation wrappers are generated under the path specified by the "
+ "'--output_dir' argument. This path can be absolute or relative to the\n"
+ "current working directory and will be created if it does not exists.\n\n"
+ "The '--lib_name' argument is used to classify the set of operations. If "
+ "the chosen name contains more than one word, it must be provided in \n"
+ "snake_case. This value is declined into other meaningful names, such as "
+ "the group and package of the generated operations. For example,\n"
+ "'--lib_name=my_lib' generates the operations under the "
+ "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n"
+ "group.\n\n"
+ "Note that the operator group assigned to the generated wrappers is just "
+ "an annotation tag at this stage. Operations will not be available "
+ "through\n"
+ "the 'org.tensorflow.op.Ops' API as a group until the generated classes "
+ "are compiled using an appropriate annotation processor.\n\n"
+ "Finally, the '--base_package' overrides the default parent package "
+ "under which the generated subpackage and classes are to be located.\n\n";
+
+} // namespace op_gen
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ tensorflow::string lib_name;
+ tensorflow::string output_dir;
+ tensorflow::string base_package = "org.tensorflow.op";
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("output_dir", &output_dir,
+ "Root directory into which output files are generated"),
+ tensorflow::Flag(
+ "lib_name", &lib_name,
+ "A name, in snake_case, used to classify this set of operations"),
+ tensorflow::Flag(
+ "base_package", &base_package,
+ "Package parent to the generated subpackage and classes")};
+ tensorflow::string usage = tensorflow::op_gen::kUsageHeader;
+ usage += tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+ QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
+
+ tensorflow::OpGenerator generator;
+ tensorflow::OpList ops;
+ tensorflow::OpRegistry::Global()->Export(true, &ops);
+ tensorflow::Status status =
+ generator.Run(ops, lib_name, base_package, output_dir);
+ TF_QCHECK_OK(status);
+
+ return 0;
+}
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
new file mode 100644
index 0000000000..df130c32e6
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/op_generator.h"
+
+namespace tensorflow {
+namespace {
+
+string CamelCase(const string& str, char delimiter, bool upper) {
+ string result;
+ bool cap = upper;
+ for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
+ const char c = *it;
+ if (c == delimiter) {
+ cap = true;
+ } else if (cap) {
+ result += toupper(c);
+ cap = false;
+ } else {
+ result += c;
+ }
+ }
+ return result;
+}
+
+} // namespace
+
+OpGenerator::OpGenerator() : env(Env::Default()) {}
+
+OpGenerator::~OpGenerator() {}
+
+Status OpGenerator::Run(const OpList& ops, const string& lib_name,
+ const string& base_package, const string& output_dir) {
+ const string package =
+ base_package + '.' + str_util::StringReplace(lib_name, "_", "", true);
+ const string package_path =
+ output_dir + '/' + str_util::StringReplace(package, ".", "/", true);
+ const string group = CamelCase(lib_name, '_', false);
+
+ if (!env->FileExists(package_path).ok()) {
+ TF_CHECK_OK(env->RecursivelyCreateDir(package_path));
+ }
+
+ LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations";
+ // TODO(karllessard) generate wrappers from list of ops
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
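
The CamelCase() helper is what turns a snake_case lib_name into the operator group name mentioned in the usage header; a line-for-line Python mirror (illustrative only):

```python
def camel_case(s, delimiter="_", upper=False):
    # Mirror of OpGenerator's CamelCase(): drop each delimiter and
    # capitalize the character that follows; 'upper' capitalizes the start.
    out, cap = [], upper
    for c in s:
        if c == delimiter:
            cap = True
        elif cap:
            out.append(c.upper())
            cap = False
        else:
            out.append(c)
    return "".join(out)

print(camel_case("my_lib"))              # myLib  (operator group name)
print(camel_case("my_lib", upper=True))  # MyLib
```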
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
new file mode 100644
index 0000000000..eec1082b51
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
+#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+/// \brief A generator of Java operation wrappers.
+///
+/// Such a generator is normally run only once per executable, outputting
+/// wrappers for all registered operations it has been compiled with.
+/// Nonetheless, it is designed to support multiple runs, giving a different
+/// list of operations on each cycle.
+class OpGenerator {
+ public:
+ OpGenerator();
+ virtual ~OpGenerator();
+
+ /// \brief Generates wrappers for the given list of 'ops'.
+ ///
+ /// Output files are generated in <output_dir>/<base_package>/<lib_package>,
+ /// where 'lib_package' is derived from 'lib_name'.
+ Status Run(const OpList& ops, const string& lib_name,
+ const string& base_package, const string& output_dir);
+
+ private:
+ Env* env;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl
new file mode 100644
index 0000000000..e3710c49d0
--- /dev/null
+++ b/tensorflow/java/src/gen/gen_ops.bzl
@@ -0,0 +1,59 @@
+# -*- Python -*-
+
+load("//tensorflow:tensorflow.bzl", "tf_copts")
+
+# Given a list of "ops_libs" (a list of files in the core/ops directory
+# without their .cc extensions), generate Java wrapper code for all operations
+# found in the ops files.
+# Then, combine all those source files into a single archive (.srcjar).
+#
+# For example:
+# tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ])
+#
+# will create a genrule named "gen_sources" that first generates source files:
+# ops/src/main/java/my/package/array/*.java
+# ops/src/main/java/my/package/math/*.java
+#
+# and then archive those source files in:
+# ops/gen_sources.srcjar
+#
+def tf_java_op_gen_srcjar(name,
+ gen_tool,
+ gen_base_package,
+ ops_libs=[],
+ ops_libs_pkg="//tensorflow/core",
+ out_dir="ops/",
+ out_src_dir="src/main/java/",
+ visibility=["//tensorflow/java:__pkg__"]):
+
+ gen_tools = []
+ gen_cmds = ["rm -rf $(@D)"] # Always start fresh when generating source files
+
+ # Construct an op generator binary for each ops library.
+ for ops_lib in ops_libs:
+ gen_lib = ops_lib[:ops_lib.rfind("_")]
+ out_gen_tool = out_dir + ops_lib + "_gen_tool"
+
+ native.cc_binary(
+ name=out_gen_tool,
+ copts=tf_copts(),
+ linkopts=["-lm"],
+ linkstatic=1, # Statically link this one-time-use binary
+ deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
+
+ gen_tools += [":" + out_gen_tool]
+ gen_cmds += ["$(location :" + out_gen_tool + ")" +
+ " --output_dir=$(@D)/" + out_src_dir +
+ " --lib_name=" + gen_lib +
+ " --base_package=" + gen_base_package]
+
+ # Generate a source archive containing generated code for these ops.
+ gen_srcjar = out_dir + name + ".srcjar"
+ gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."]
+ gen_tools += ["@local_jdk//:jar"]
+
+ native.genrule(
+ name=name,
+ outs=[gen_srcjar],
+ tools=gen_tools,
+ cmd="&&".join(gen_cmds))
diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc
index c46a3d8db3..a526856794 100644
--- a/tensorflow/python/eager/python_eager_op_gen.cc
+++ b/tensorflow/python/eager/python_eager_op_gen.cc
@@ -659,14 +659,25 @@ void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
string GetEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
- bool require_shapes) {
+ bool require_shapes,
+ const string& source_file_name = "") {
string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
- strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops.
+ strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
-"""
+)");
+
+ // Mention the original source file so someone tracing back through generated
+ // Python code will know where to look next.
+ if (!source_file_name.empty()) {
+ strings::StrAppend(&result, "Original C++ source file: ");
+ strings::StrAppend(&result, source_file_name);
+ strings::StrAppend(&result, "\n");
+ }
+
+ strings::StrAppend(&result, R"("""
import collections as _collections
@@ -747,8 +758,10 @@ from tensorflow.python.framework import op_def_library as _op_def_library
void PrintEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
- bool require_shapes) {
- printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str());
+ bool require_shapes, const string& source_file_name) {
+ printf("%s",
+ GetEagerPythonOps(ops, hidden_ops, require_shapes, source_file_name)
+ .c_str());
}
string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
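
Concretely, when source_file_name is non-empty (e.g. inferred from the executable name, as wired up in python_op_gen_main.cc below), the generated module now opens with a header of this shape; math_ops.cc is an illustrative value:

```python
"""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
Original C++ source file: math_ops.cc
"""

import collections as _collections
```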
diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h
index 9a7ed28cf9..250623850f 100644
--- a/tensorflow/python/eager/python_eager_op_gen.h
+++ b/tensorflow/python/eager/python_eager_op_gen.h
@@ -24,9 +24,12 @@ namespace tensorflow {
// hidden_ops should be a list of Op names that should get a leading _
// in the output. Prints the output to stdout.
+// Optional fourth argument is the name of the original C++ source file
+// where the ops' REGISTER_OP() calls reside.
void PrintEagerPythonOps(const OpList& ops,
const std::vector<string>& hidden_ops,
- bool require_shapes);
+ bool require_shapes,
+ const string& source_file_name = "");
// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index a8434d0c99..965b35bc4c 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2474,6 +2474,9 @@ class _IndicatorColumn(_DenseColumn,
sp_ids=id_tensor,
sp_values=weight_tensor,
vocab_size=int(self._variable_shape[-1]))
+ # Remove (?, -1) index
+ weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
+ weighted_column.dense_shape)
return sparse_ops.sparse_tensor_to_dense(weighted_column)
dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
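
The slice matters because an out-of-vocabulary id comes out of sparse_merge at column -1, and slicing from [0, 0] drops those entries before densification. A small sketch of the mechanism via the public alias of sparse_ops.sparse_slice (shapes follow the new test below):

```python
import tensorflow as tf

# After sparse_merge, the out-of-vocabulary token sits at column -1.
sp = tf.SparseTensor(indices=[[0, -1], [0, 1], [0, 2]],
                     values=[6.0, 4.0, 2.0],
                     dense_shape=[1, 3])

# Slicing from [0, 0] keeps only entries with non-negative indices.
sliced = tf.sparse_slice(sp, [0, 0], sp.dense_shape)
dense = tf.sparse_tensor_to_dense(sliced)

with tf.Session() as sess:
    print(sess.run(dense))  # [[0. 4. 2.]]
```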
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 3057776391..626879f76a 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -3213,13 +3213,39 @@ class IndicatorColumnTest(test.TestCase):
weights = fc.weighted_categorical_column(ids, 'weights')
indicator = fc.indicator_column(weights)
features = {
- 'ids': constant_op.constant(['c', 'b', 'a'], shape=(1, 3)),
- 'weights': constant_op.constant([2., 4., 6.], shape=(1, 3))
+ 'ids': constant_op.constant([['c', 'b', 'a']]),
+ 'weights': constant_op.constant([[2., 4., 6.]])
}
indicator_tensor = _transform_features(features, [indicator])[indicator]
with _initialized_session():
self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
+ def test_transform_with_missing_value_in_weighted_column(self):
+ # Github issue 12583
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ weights = fc.weighted_categorical_column(ids, 'weights')
+ indicator = fc.indicator_column(weights)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'unknown']]),
+ 'weights': constant_op.constant([[2., 4., 6.]])
+ }
+ indicator_tensor = _transform_features(features, [indicator])[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0., 4., 2.]], indicator_tensor.eval())
+
+ def test_transform_with_missing_value_in_categorical_column(self):
+ # Github issue 12583
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ indicator = fc.indicator_column(ids)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'unknown']]),
+ }
+ indicator_tensor = _transform_features(features, [indicator])[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
+
def test_linear_model(self):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index 8366542288..f681daa7e4 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@@ -80,7 +81,29 @@ Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
return Status::OK();
}
-void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
+// Use the name of the current executable to infer the C++ source file
+// where the REGISTER_OP() calls for its operators can be found.
+// Returns the name of the file.
+// Returns an empty string if the current executable's name does not
+// follow a known pattern.
+string InferSourceFileName(const char* argv_zero) {
+ StringPiece command_str = io::Basename(argv_zero);
+
+ // For built-in ops, the Bazel build creates a separate executable
+ // with the name gen_<op type>_ops_py_wrappers_cc containing the
+ // operators defined in <op type>_ops.cc
+ const char* kExecPrefix = "gen_";
+ const char* kExecSuffix = "_py_wrappers_cc";
+ if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) {
+ command_str.remove_suffix(strlen(kExecSuffix));
+ return strings::StrCat(command_str, ".cc");
+ } else {
+ return string("");
+ }
+}
+
+void PrintAllPythonOps(const std::vector<string>& op_list,
+ const string& source_file_name, bool require_shapes,
bool op_list_is_whitelist) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
@@ -93,9 +116,9 @@ void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
*pruned_ops.mutable_op()->Add() = op_def;
}
}
- PrintEagerPythonOps(pruned_ops, {}, require_shapes);
+ PrintEagerPythonOps(pruned_ops, {}, require_shapes, source_file_name);
} else {
- PrintEagerPythonOps(ops, op_list, require_shapes);
+ PrintEagerPythonOps(ops, op_list, require_shapes, source_file_name);
}
}
@@ -105,20 +128,26 @@ void PrintAllPythonOps(const std::vector<string>& op_list, bool require_shapes,
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ tensorflow::string source_file_name =
+ tensorflow::InferSourceFileName(argv[0]);
+
// Usage:
// gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
if (argc == 2) {
- tensorflow::PrintAllPythonOps({}, {}, tensorflow::string(argv[1]) == "1");
+ tensorflow::PrintAllPythonOps({}, source_file_name,
+ tensorflow::string(argv[1]) == "1",
+ false /* op_list_is_whitelist */);
} else if (argc == 3) {
std::vector<tensorflow::string> hidden_ops;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops));
- tensorflow::PrintAllPythonOps(hidden_ops,
+ tensorflow::PrintAllPythonOps(hidden_ops, source_file_name,
tensorflow::string(argv[2]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 4) {
std::vector<tensorflow::string> op_list;
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list));
- tensorflow::PrintAllPythonOps(op_list, tensorflow::string(argv[2]) == "1",
+ tensorflow::PrintAllPythonOps(op_list, source_file_name,
+ tensorflow::string(argv[2]) == "1",
tensorflow::string(argv[3]) == "1");
} else {
return -1;
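
A Python mirror of InferSourceFileName() (illustrative), showing the executable-name convention it decodes:

```python
import os

def infer_source_file_name(argv0):
    # gen_<op type>_ops_py_wrappers_cc  ->  <op type>_ops.cc
    base = os.path.basename(argv0)
    prefix, suffix = "gen_", "_py_wrappers_cc"
    if base.startswith(prefix) and base.endswith(suffix):
        return base[len(prefix):-len(suffix)] + ".cc"
    return ""

print(infer_source_file_name("gen_math_ops_py_wrappers_cc"))  # math_ops.cc
print(infer_source_file_name("python"))                       # "" (no match)
```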
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index eea3d28a7e..745428e530 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -236,7 +236,8 @@ def _FilterTuple(v):
def _FilterInt(v):
if isinstance(v, (list, tuple)):
return _FirstNotNone([_FilterInt(x) for x in v])
- return None if isinstance(v, compat.integral_types) else _NotNone(v)
+ return None if isinstance(v, (compat.integral_types,
+ tensor_shape.Dimension)) else _NotNone(v)
def _FilterFloat(v):
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 2760f98a6b..f66af3adc6 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -314,6 +314,17 @@ class TensorUtilTest(test.TestCase):
shape=[3, 4],
dtype=dtype)))
+ def testIntMixedWithDimension(self):
+ # Github issue: 11974
+ dtype = dtypes.int32
+ nptype = np.int32
+ t = tensor_util.make_tensor_proto(
+ [10, tensor_shape.Dimension(20), 30], dtype=dtype)
+ self.assertEquals(dtype, t.dtype)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(nptype, a.dtype)
+ self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
+
def testLong(self):
t = tensor_util.make_tensor_proto(10, dtype=dtypes.int64)
self.assertProtoEquals("""
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 43827b0d10..d9c5f3bce9 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -684,13 +684,15 @@ cuda_py_test(
tf_py_test(
name = "segment_reduction_ops_test",
- size = "small",
+ size = "medium",
srcs = ["segment_reduction_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:variables",
"//tensorflow/python:nn_grad",
],
)
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 33269c9123..bf20f5d1a9 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -18,13 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import itertools
+
import numpy as np
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
-import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -107,19 +111,19 @@ class SegmentReductionOpTest(SegmentReductionHelper):
curr_ops_list = complex_ops_list
else:
curr_ops_list = ops_list
-
- with self.test_session(use_gpu=False):
- tf_x, np_x = self._input(shape, dtype=dtype)
- for np_op1, np_op2, tf_op in curr_ops_list:
- np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
- s = tf_op(data=tf_x, segment_ids=indices)
- tf_ans = s.eval()
- self.assertAllClose(np_ans, tf_ans)
- # NOTE(mrry): The static shape inference that computes
- # `tf_ans.shape` can only infer that sizes from dimension 1
- # onwards, because the size of dimension 0 is data-dependent
- # and may therefore vary dynamically.
- self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_x, np_x = self._input(shape, dtype=dtype)
+ for np_op1, np_op2, tf_op in curr_ops_list:
+ np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
+ s = tf_op(data=tf_x, segment_ids=indices)
+ tf_ans = s.eval()
+ self.assertAllClose(np_ans, tf_ans)
+ # NOTE(mrry): The static shape inference that computes
+ # `tf_ans.shape` can only infer that sizes from dimension 1
+ # onwards, because the size of dimension 0 is data-dependent
+ # and may therefore vary dynamically.
+ self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
def testSegmentIdsShape(self):
shape = [4, 4]
@@ -130,41 +134,45 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsSize(self):
shape = [4, 4]
- with self.test_session():
- tf_x, _ = self._input(shape)
- indices = [0, 1]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError("segment_ids should be the same size"):
- s.eval()
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape)
+ indices = [0, 1]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment_ids should be the same size"):
+ s.eval()
def testSegmentIdsValid(self):
# This is a baseline for the following SegmentIdsInvalid* tests.
shape = [4, 4]
- with self.test_session():
- tf_x, _ = self._input(shape)
- indices = [0, 0, 0, 1]
- result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
- self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 0, 1]
+ result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
+ self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
def testSegmentIdsGreaterThanZero(self):
shape = [4, 4]
- with self.test_session():
- tf_x, np_x = self._input(shape)
- indices = [1, 1, 2, 2]
- np_ans = self._segmentReduce(indices, np_x, np.add)
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- tf_ans = s.eval()
- self.assertAllClose(np_ans, tf_ans)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [1, 1, 2, 2]
+ np_ans = self._segmentReduce(indices, np_x, np.add)
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ tf_ans = s.eval()
+ self.assertAllClose(np_ans, tf_ans)
def testSegmentIdsHole(self):
shape = [4, 4]
- with self.test_session():
- tf_x, np_x = self._input(shape)
- indices = [0, 0, 3, 3]
- np_ans = self._segmentReduce(indices, np_x, np.add)
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- tf_ans = s.eval()
- self.assertAllClose(np_ans, tf_ans)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 3, 3]
+ np_ans = self._segmentReduce(indices, np_x, np.add)
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ tf_ans = s.eval()
+ self.assertAllClose(np_ans, tf_ans)
def testSegmentIdsInvalid1(self):
shape = [4, 4]
@@ -199,21 +207,23 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsInvalid4(self):
shape = [4, 4]
- with self.test_session():
- tf_x, _ = self._input(shape)
- indices = [0, 0, 0, -1]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError("segment ids must be >= 0"):
- s.eval()
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 0, -1]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment ids must be >= 0"):
+ s.eval()
def testSegmentIdsInvalid5(self):
shape = [4, 4]
- with self.test_session():
- tf_x, _ = self._input(shape)
- indices = [0, 0, 0, -2]
- s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
- with self.assertRaisesOpError("segment ids must be >= 0"):
- s.eval()
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+ indices = [0, 0, 0, -2]
+ s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+ with self.assertRaisesOpError("segment ids must be >= 0"):
+ s.eval()
def testGradient(self):
shape = [4, 4]
@@ -340,8 +350,8 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
shape = indices.shape + (num_cols,)
with self.test_session(use_gpu=True):
tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
- s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices,
- num_segments=num_segments)
+ s = math_ops.unsorted_segment_max(
+ data=tf_x, segment_ids=indices, num_segments=num_segments)
jacob_t, jacob_n = gradient_checker.compute_gradient(
tf_x,
shape,
@@ -636,5 +646,67 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
s.eval()
+class SegmentReductionOpBenchmark(test.Benchmark):
+ outer_dim_options = [2**x for x in range(9, 14, 2)]
+ ratio_options = [2**x for x in range(1, 6, 2)]
+ inner_dim_options = [2**x for x in range(9, 14, 2)]
+ # randomly generated sizes that are less well aligned
+ inner_dim_options += [
+ 1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584
+ ]
+ dtype_options = [np.float32, np.float64]
+ options = (outer_dim_options, ratio_options, inner_dim_options, dtype_options)
+ # pylint: disable=g-long-lambda
+ op_functors = [lambda vc, vs, seg_ids:
+ ("sorted", math_ops.segment_sum(vc, vs)),
+ lambda vc, vs, seg_ids:
+ ("unsorted",
+ math_ops.unsorted_segment_sum(vc, vs, seg_ids[-1]+1))]
+ # pylint: enable=g-long-lambda
+ repeat = 10
+
+ def _npTypeToStr(self, t):
+ if t == np.float32:
+ return "fp32"
+ if t == np.float64:
+ return "fp64"
+
+ def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype):
+ output_outer_dim = int(outer_dim / ratio)
+ const = np.random.randint(5, size=(outer_dim, inner_dim))
+ seg_ids = np.sort(np.random.randint(output_outer_dim, size=outer_dim))
+ vs = variables.Variable(seg_ids.astype(np.int32))
+ with ops.device("/gpu:0"):
+ vc = variables.Variable(const.astype(dtype))
+ name, op = op_functor(vc, vs, seg_ids)
+ with session.Session() as sess:
+ variables.global_variables_initializer().run()
+ r = self.run_op_benchmark(
+ sess,
+ op,
+ min_iters=self.repeat,
+ name="_".join(
+ map(str,
+ [name, outer_dim, ratio, inner_dim,
+ self._npTypeToStr(dtype)])))
+ return name, r["wall_time"]
+
+ def benchmarkSegmentSumGPU(self):
+ if not test.is_gpu_available(cuda_only=True):
+ return
+ for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
+ op_functor = self.op_functors[0]
+ with ops.Graph().as_default():
+ self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
+
+ def benchmarkUnsortedSegmentSumGPU(self):
+ if not test.is_gpu_available(cuda_only=True):
+ return
+ for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
+ op_functor = self.op_functors[1]
+ with ops.Graph().as_default():
+ self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 5cd5d7ba2f..bd879ac423 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -37,6 +37,7 @@ See the @{$python/io_ops} guide.
@@parse_example
@@parse_single_example
@@parse_tensor
+@@serialize_tensor
@@decode_json_example
@@QueueBase
@@FIFOQueue
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 803e0e7a1e..c5fd15bae4 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -40,6 +40,7 @@ from tensorflow.python.platform import tf_logging
ops.NotDifferentiable("DecodeRaw")
ops.NotDifferentiable("ParseTensor")
+ops.NotDifferentiable("SerializeTensor")
ops.NotDifferentiable("StringToNumber")
diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py
index 5345949664..a1fe47982f 100644
--- a/tensorflow/python/profiler/model_analyzer.py
+++ b/tensorflow/python/profiler/model_analyzer.py
@@ -117,7 +117,7 @@ class Profiler(object):
```python
Typical use case:
# Currently we are only allowed to create 1 profiler per process.
- profiler = Profile(sess.graph)
+ profiler = Profiler(sess.graph)
for i in xrange(total_steps):
if i % 10000 == 0:
@@ -174,7 +174,7 @@ class Profiler(object):
"""Add statistics of a step.
Args:
- step: A step uint64 used to identify the RunMetadata. Must be different
+ step: int, a step used to identify the RunMetadata. Must be different
across different AddStep() calls.
run_meta: RunMetadata proto that contains statistics of a session run.
"""
diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py
index a8712fc37e..00de044505 100644
--- a/tensorflow/python/tools/import_pb_to_tensorboard.py
+++ b/tensorflow/python/tools/import_pb_to_tensorboard.py
@@ -51,7 +51,7 @@ def import_to_tensorboard(model_dir, log_dir):
pb_visual_writer = summary.FileWriter(log_dir)
pb_visual_writer.add_graph(sess.graph)
print("Model Imported. Visualize by running: "
- "> tensorboard --logdir={}".format(log_dir))
+ "tensorboard --logdir={}".format(log_dir))
def main(unused_args):
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index ca867dbe3c..8935bcda3d 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1689,6 +1689,10 @@ tf_module {
argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "serialize_tensor"
+ argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "set_random_seed"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index e5342cba77..ef342fe127 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -573,11 +573,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
patched_http_archive(
name = "boringssl",
urls = [
- "http://mirror.bazel.build/github.com/google/boringssl/archive/bbcaa15b0647816b9a1a9b9e0d209cd6712f0105.tar.gz",
- "https://github.com/google/boringssl/archive/bbcaa15b0647816b9a1a9b9e0d209cd6712f0105.tar.gz", # 2016-07-11
+ "http://mirror.bazel.build/github.com/google/boringssl/archive/e3860009a091cd1bd2bc189cdbc3c6d095abde84.tar.gz",
+ "https://github.com/google/boringssl/archive/e3860009a091cd1bd2bc189cdbc3c6d095abde84.tar.gz", # 2017-07-07
],
- sha256 = "025264d6e9a7ad371f2f66d17a28b6627de0c9592dc2eb54afd062f68f1f9aa3",
- strip_prefix = "boringssl-bbcaa15b0647816b9a1a9b9e0d209cd6712f0105",
+ sha256 = "02f5950f93c4fd3691771c07c9d04cf2999ab01383ff99da345249e93b0fcfb2",
+ strip_prefix = "boringssl-e3860009a091cd1bd2bc189cdbc3c6d095abde84",
# Add patch to boringssl code to support s390x
patch_file = str(Label("//third_party/boringssl:add_boringssl_s390x.patch")),
)
diff --git a/third_party/boringssl/add_boringssl_s390x.patch b/third_party/boringssl/add_boringssl_s390x.patch
index 9a34a59a1d..8b42d10e68 100644
--- a/third_party/boringssl/add_boringssl_s390x.patch
+++ b/third_party/boringssl/add_boringssl_s390x.patch
@@ -3,9 +3,9 @@ index 7a3adfb..88012ad 100644
--- a/src/include/openssl/base.h
+++ b/src/include/openssl/base.h
@@ -94,6 +94,8 @@ extern "C" {
- #elif defined(__pnacl__)
- #define OPENSSL_32_BIT
#define OPENSSL_PNACL
+ #elif defined(__myriad2__)
+ #define OPENSSL_32_BIT
+#elif defined(__s390x__)
+#define OPENSSL_64_BIT
#else