diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-12 18:12:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-12 18:16:13 -0700 |
commit | 9f10f60fbd9fefaf225c1985014010b6b2f738c1 (patch) | |
tree | 1054fd871c9b7703d74e9c6d017543ee30e907b3 /tensorflow/core/debug | |
parent | 995f5f4f40721e177f94cf954060619111e5cc4c (diff) |
Minor modernizations, mostly more <memory>
PiperOrigin-RevId: 158793461
Diffstat (limited to 'tensorflow/core/debug')
-rw-r--r-- | tensorflow/core/debug/debug_gateway_test.cc | 20 | ||||
-rw-r--r-- | tensorflow/core/debug/grpc_session_debug_test.cc | 22 |
2 files changed, 22 insertions, 20 deletions
diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc index adbb1b2116..f25d91a3c2 100644 --- a/tensorflow/core/debug/debug_gateway_test.cc +++ b/tensorflow/core/debug/debug_gateway_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <algorithm> #include <cstdlib> +#include <memory> #include <unordered_map> #include "tensorflow/core/debug/debug_graph_utils.h" @@ -29,14 +30,15 @@ limitations under the License. namespace tensorflow { namespace { -DirectSession* CreateSession() { +std::unique_ptr<DirectSession> CreateSession() { SessionOptions options; // Turn off graph optimizer so we can observe intermediate node states. options.config.mutable_graph_options() ->mutable_optimizer_options() ->set_opt_level(OptimizerOptions_Level_L0); - return dynamic_cast<DirectSession*>(NewSession(options)); + return std::unique_ptr<DirectSession>( + dynamic_cast<DirectSession*>(NewSession(options))); } class SessionDebugMinusAXTest : public ::testing::Test { @@ -85,7 +87,7 @@ class SessionDebugMinusAXTest : public ::testing::Test { TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) { Initialize({3, 2, -1, 0}); - std::unique_ptr<DirectSession> session(CreateSession()); + auto session = CreateSession(); ASSERT_TRUE(session != nullptr); DebugGateway debug_gateway(session.get()); @@ -220,7 +222,7 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) { TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) { // Tensor contains one count of NaN Initialize({3, std::numeric_limits<float>::quiet_NaN(), -1, 0}); - std::unique_ptr<DirectSession> session(CreateSession()); + auto session = CreateSession(); ASSERT_TRUE(session != nullptr); DebugGateway debug_gateway(session.get()); @@ -350,7 +352,7 @@ TEST_F(SessionDebugMinusAXTest, // Test concurrent Run() calls on a graph with different debug watches. Initialize({3, 2, -1, 0}); - std::unique_ptr<DirectSession> session(CreateSession()); + auto session = CreateSession(); ASSERT_TRUE(session != nullptr); TF_ASSERT_OK(session->Create(def_)); @@ -537,7 +539,7 @@ class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test { TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest, WatchSlotWithoutOutgoingEdge) { Initialize(); - std::unique_ptr<DirectSession> session(CreateSession()); + auto session = CreateSession(); ASSERT_TRUE(session != nullptr); DebugGateway debug_gateway(session.get()); @@ -662,7 +664,7 @@ class SessionDebugVariableTest : public ::testing::Test { TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) { Initialize(); - std::unique_ptr<DirectSession> session(CreateSession()); + auto session = CreateSession(); ASSERT_TRUE(session != nullptr); DebugGateway debug_gateway(session.get()); @@ -741,7 +743,7 @@ TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) { TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) { // Tensor contains one count of NaN Initialize(); - std::unique_ptr<DirectSession> session(CreateSession()); + auto session = CreateSession(); ASSERT_TRUE(session != nullptr); DebugGateway debug_gateway(session.get()); @@ -917,7 +919,7 @@ class SessionDebugGPUSwitchTest : public ::testing::Test { // Test for debug-watching tensors marked as HOST_MEMORY on GPU. TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) { Initialize(); - std::unique_ptr<DirectSession> session(CreateSession()); + auto session = CreateSession(); ASSERT_TRUE(session != nullptr); DebugGateway debug_gateway(session.get()); diff --git a/tensorflow/core/debug/grpc_session_debug_test.cc b/tensorflow/core/debug/grpc_session_debug_test.cc index 120c093c9b..3827596a67 100644 --- a/tensorflow/core/debug/grpc_session_debug_test.cc +++ b/tensorflow/core/debug/grpc_session_debug_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_session.h" +#include <memory> + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/debug/debug_io_utils.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" @@ -37,8 +39,9 @@ limitations under the License. #include "tensorflow/core/util/port.h" namespace tensorflow { +namespace { -static SessionOptions Devices(int num_cpus, int num_gpus) { +SessionOptions Devices(int num_cpus, int num_gpus) { SessionOptions result; (*result.config.mutable_device_count())["CPU"] = num_cpus; (*result.config.mutable_device_count())["GPU"] = num_gpus; @@ -67,13 +70,13 @@ void CreateGraphDef(GraphDef* graph_def, string node_names[3]) { // Asserts that "val" is a single float tensor. The only float is // "expected_val". -static void IsSingleFloatValue(const Tensor& val, float expected_val) { +void IsSingleFloatValue(const Tensor& val, float expected_val) { ASSERT_EQ(val.dtype(), DT_FLOAT); ASSERT_EQ(val.NumElements(), 1); ASSERT_EQ(val.flat<float>()(0), expected_val); } -static SessionOptions Options(const string& target, int placement_period) { +SessionOptions Options(const string& target, int placement_period) { SessionOptions options; // NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target // string. @@ -85,8 +88,8 @@ static SessionOptions Options(const string& target, int placement_period) { return options; } -static Session* NewRemote(const SessionOptions& options) { - return CHECK_NOTNULL(NewSession(options)); +std::unique_ptr<Session> NewRemote(const SessionOptions& options) { + return std::unique_ptr<Session>(CHECK_NOTNULL(NewSession(options))); } class GrpcSessionDebugTest : public ::testing::Test { @@ -149,9 +152,7 @@ TEST_F(GrpcSessionDebugTest, FileDebugURL) { std::unique_ptr<test::TestCluster> cluster; TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); - std::unique_ptr<Session> session( - NewRemote(Options(cluster->targets()[0], 1))); - ASSERT_TRUE(session != nullptr); + auto session = NewRemote(Options(cluster->targets()[0], 1)); TF_CHECK_OK(session->Create(graph)); // Iteration 0: No watch. @@ -220,9 +221,7 @@ void SetDevice(GraphDef* graph, const string& name, const string& dev) { TEST_F(GrpcSessionDebugTest, MultiDevices_String) { std::unique_ptr<test::TestCluster> cluster; TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster)); - std::unique_ptr<Session> session( - NewRemote(Options(cluster->targets()[0], 1000))); - ASSERT_TRUE(session != nullptr); + auto session = NewRemote(Options(cluster->targets()[0], 1000)); // b = a Graph graph(OpRegistry::Global()); @@ -289,4 +288,5 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) { } } +} // namespace } // namespace tensorflow |