aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sami Kama <skama@nvidia.com>2018-06-18 17:54:49 -0700
committerGravatar Sami Kama <skama@nvidia.com>2018-06-18 17:54:49 -0700
commit84a1f27d79f444cd865b6c46787bc650c6ff90ec (patch)
treea14e6142ad6c0677f38e04296f77ffa1b8df583d
parent78cef7962be702532cb1998b291c6624f803aa3f (diff)
Workaround Grappler funcdef optimization issue
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc4
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc12
-rw-r--r--tensorflow/contrib/tensorrt/test/test_tftrt.py1
3 files changed, 16 insertions, 1 deletion
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index f19a8cd4bd..c17ef5fdab 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -479,7 +479,7 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
node_builder.Device(info.device);
}
if (VLOG_IS_ON(1)) {
- string ins(info.engine_name);
+ string ins=StrCat(info.engine_name," inputs= ");
for (const auto& ii : inputs) {
StrAppend(&ins, ii.node, ":", ii.index, " ");
}
@@ -623,6 +623,7 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
VLOG(7) << name << " Function_Def ";
VLOG(7) << native_segment->DebugString();
}
+ VLOG(1)<<"Adding funcdef to graphlib";
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
return tensorflow::Status::OK();
}
@@ -813,6 +814,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
cudaSetDevice(old_cuda_device);
graph.ToGraphDef(params.output_graph_def);
for (auto tn : trt_nodes) delete tn;
+ VLOG(1)<<"Returning from conversion";
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index 6d0fd7a44b..ec9dbfa13b 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -191,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize(
if (VLOG_IS_ON(1)) {
PrintDebugInfo(cluster, item);
}
+ // This is a hack to workaround optimizer issue. MetaOptimizer calls
+ // optimization passes on function objects as well, we should not modify
+ // generated funcdefs! This is fragile but we don't have any other option
+ // until framework fixes it.
+ if (item.id != "tf_graph") {
+ LOG(WARNING) << name_
+ << " is probably called on funcdef! This optimizer must *NOT* "
+ "be called on function objects.";
+ *optimized_graph = item.graph;
+ return tensorflow::Status::OK();
+ }
int max_dim = -1;
if (item.feed.size()) {
for (const auto& f : item.feed) {
@@ -235,6 +246,7 @@ tensorflow::Status TRTOptimizationPass::Optimize(
cp.max_cached_engines = max_cached_batches_;
auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp);
VLOG(2) << optimized_graph->DebugString();
+ VLOG(1) << "Returning from " << name_;
return status;
}
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 85f37aa899..12e84f7d3c 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -236,6 +236,7 @@ def auto(multi_engine):
orig_graph = get_simple_graph_def() # use a frozen graph for inference
dummy_input = np.random.random_sample(inp_dims)
opt_config = rwpb2.RewriterConfig()
+ opt_config.meta_optimizer_iterations=opt_config.ONE
opt_config.optimizers.extend(["constfold", "layout"])
custom_op = opt_config.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"