aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-30 11:09:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-30 12:39:03 -0800
commit979f139dd94e2bf5fba4794536715973b55373c1 (patch)
tree23557ec1a9b3feb3b16789ad0aa49e9f0ff17441 /tensorflow
parenta694f0ca2682f53f89a75707ad1f6c2ddffeacde (diff)
Add py2tf to contrib_py.
PiperOrigin-RevId: 183860192
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/BUILD4
-rw-r--r--tensorflow/compiler/xla/executable_run_options.cc17
-rw-r--r--tensorflow/compiler/xla/executable_run_options.h7
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc8
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_text.cc9
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py8
6 files changed, 35 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index c22fd37129..38e39afdc0 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -444,6 +444,10 @@ cc_library(
srcs = ["executable_run_options.cc"],
hdrs = ["executable_run_options.h"],
visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ "//tensorflow/core:lib",
+ ],
)
cc_library(
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 392ad9010a..f8bb8e52c7 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
namespace xla {
ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
@@ -87,4 +89,19 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
return device_assignment_;
}
+string ExecutableRunOptions::ToString() const {
+ return tensorflow::strings::Printf(
+ "ExecutableRunOptions{allocator=%p, device_ordinal=%d, "
+ "device_assignment=%p, stream=%p, inter_op_thread_pool=%p, "
+ "intra_op_thread_pool=%p, execution_profile=%p}",
+ allocator_, device_ordinal_, device_assignment_, stream_,
+ inter_op_thread_pool_, intra_op_thread_pool_, execution_profile_);
+}
+
+std::ostream& operator<<(std::ostream& out,
+ const ExecutableRunOptions& options) {
+ out << options.ToString();
+ return out;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index d4fcbf0493..c7a20bb33c 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
#define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
+#include "tensorflow/compiler/xla/types.h"
+
// Intentionally forward declared so that ExecutableRunOptions can be linked
// into an XLA-compiled binary without having to link all of the pointed-to
// objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't
@@ -84,6 +86,8 @@ class ExecutableRunOptions {
DeviceAssignment* device_assignment);
const DeviceAssignment* device_assignment() const;
+ string ToString() const;
+
private:
DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
@@ -94,6 +98,9 @@ class ExecutableRunOptions {
ExecutionProfile* execution_profile_ = nullptr;
};
+std::ostream& operator<<(std::ostream& out,
+ const ExecutableRunOptions& options);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
index b82f1c81c8..4ad356d045 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
@@ -85,12 +85,10 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
}
- ExecutableBuildOptions build_options;
- build_options.set_device_ordinal(0);
- build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
- local_service->CompileExecutable(computation.handle(), layouts,
- build_options);
+ local_service->CompileExecutable(
+ computation.handle(), layouts, &program_shape->result(),
+ /*device_ordinal=*/0, /*device_allocator=*/nullptr);
const HloModule& module = executable.ValueOrDie()->module();
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
index 05c0fdf97d..5ebb75a31c 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -60,13 +60,10 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
for (int i = 0; i < program_shape->parameters_size(); ++i) {
layouts.push_back(&program_shape->parameters(i));
}
-
- ExecutableBuildOptions build_options;
- build_options.set_device_ordinal(0);
- build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
- local_service->CompileExecutable(computation.handle(), layouts,
- build_options);
+ local_service->CompileExecutable(
+ computation.handle(), layouts, &program_shape->result(),
+ /*device_ordinal=*/0, /*device_allocator=*/nullptr);
const HloModule& module = executable.ValueOrDie()->module();
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 1417772e04..9d0f95e6f3 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -275,7 +274,6 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
-@test_util.with_c_api
class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testNoneWeightRaisesValueError(self):
@@ -473,11 +471,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[0, 1], [2, 3]])
weights = constant_op.constant([1.2, 3.4, 5.6, 7.8])
- if ops._USE_C_API:
- error_type = ValueError
- else:
- error_type = errors_impl.InvalidArgumentError
- with self.assertRaises(error_type):
+ with self.assertRaises(errors_impl.InvalidArgumentError):
loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights=weights).eval()