aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-04-23 21:19:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 21:21:38 -0700
commit22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
treed16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/tools/graph_transforms
parent24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/transform_graph.cc70
1 files changed, 62 insertions, 8 deletions
diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc
index 28387c2b48..8ce8f5e24b 100644
--- a/tensorflow/tools/graph_transforms/transform_graph.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph.cc
@@ -24,6 +24,9 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
+#if !defined(PLATFORM_WINDOWS)
+#include <pwd.h>
+#endif
namespace tensorflow {
namespace graph_transforms {
@@ -130,16 +133,64 @@ Status ParseTransformParameters(const string& transforms_string,
return Status::OK();
}
+std::string ExpandPath(const std::string& path_string) {
+#if defined(PLATFORM_WINDOWS)
+ return path_string;
+#else
+ if (path_string.empty() || path_string[0] != '~') {
+ return path_string;
+ }
+
+ const char* home = NULL;
+ std::string::size_type prefix = path_string.find_first_of('/');
+ if (path_string.length() == 1 || prefix == 1) {
+ // The value of $HOME, e.g., ~/foo
+ home = getenv("HOME");
+ if (!home) {
+ // If HOME is not available, get uid
+ struct passwd* pw = getpwuid(getuid());
+ if (pw) {
+ home = pw->pw_dir;
+ }
+ }
+ } else {
+ // The value of ~user, e.g., ~user/foo
+ std::string user(path_string, 1, (prefix == std::string::npos)
+ ? std::string::npos
+ : prefix - 1);
+ struct passwd* pw = getpwnam(user.c_str());
+ if (pw) {
+ home = pw->pw_dir;
+ }
+ }
+
+ if (!home) {
+ return path_string;
+ }
+
+ string path(home);
+ if (prefix == std::string::npos) {
+ return path;
+ }
+
+ if (path.length() == 0 || path[path.length() - 1] != '/') {
+ path += '/';
+ }
+ path += path_string.substr(prefix + 1);
+ return path;
+#endif
+}
+
int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
- string in_graph = "";
- string out_graph = "";
+ string in_graph_string = "";
+ string out_graph_string = "";
string inputs_string = "";
string outputs_string = "";
string transforms_string = "";
bool output_as_text = false;
std::vector<Flag> flag_list = {
- Flag("in_graph", &in_graph, "input graph file name"),
- Flag("out_graph", &out_graph, "output graph file name"),
+ Flag("in_graph", &in_graph_string, "input graph file name"),
+ Flag("out_graph", &out_graph_string, "output graph file name"),
Flag("inputs", &inputs_string, "inputs"),
Flag("outputs", &outputs_string, "outputs"),
Flag("transforms", &transforms_string, "list of transforms"),
@@ -166,11 +217,11 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
return -1;
}
- if (in_graph.empty()) {
+ if (in_graph_string.empty()) {
LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
return -1;
}
- if (out_graph.empty()) {
+ if (out_graph_string.empty()) {
LOG(ERROR) << "out_graph graph can't be empty.\n" << usage;
return -1;
}
@@ -179,6 +230,9 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
return -1;
}
+ string in_graph = ExpandPath(in_graph_string);
+ string out_graph = ExpandPath(out_graph_string);
+
std::vector<string> inputs = str_util::Split(inputs_string, ',');
std::vector<string> outputs = str_util::Split(outputs_string, ',');
TransformParameters transform_params;
@@ -197,7 +251,7 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
GraphDef graph_def;
Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
if (!load_status.ok()) {
- LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
+ LOG(ERROR) << "Loading graph '" << in_graph_string << "' failed with "
<< load_status.error_message();
LOG(ERROR) << usage;
return -1;
@@ -219,7 +273,7 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
}
if (!save_status.ok()) {
- LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
+ LOG(ERROR) << "Saving graph '" << out_graph_string << "' failed with "
<< save_status.error_message();
return -1;
}