diff options
author | 2018-06-18 17:54:49 -0700 | |
---|---|---|
committer | 2018-06-18 17:54:49 -0700 | |
commit | 84a1f27d79f444cd865b6c46787bc650c6ff90ec (patch) | |
tree | a14e6142ad6c0677f38e04296f77ffa1b8df583d | |
parent | 78cef7962be702532cb1998b291c6624f803aa3f (diff) |
Workaround Grappler funcdef optimization issue
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc | 12 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/test/test_tftrt.py | 1 |
3 files changed, 16 insertions, 1 deletion
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index f19a8cd4bd..c17ef5fdab 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -479,7 +479,7 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, node_builder.Device(info.device); } if (VLOG_IS_ON(1)) { - string ins(info.engine_name); + string ins=StrCat(info.engine_name," inputs= "); for (const auto& ii : inputs) { StrAppend(&ins, ii.node, ":", ii.index, " "); } @@ -623,6 +623,7 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( VLOG(7) << name << " Function_Def "; VLOG(7) << native_segment->DebugString(); } + VLOG(1)<<"Adding funcdef to graphlib"; TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib)); return tensorflow::Status::OK(); } @@ -813,6 +814,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { cudaSetDevice(old_cuda_device); graph.ToGraphDef(params.output_graph_def); for (auto tn : trt_nodes) delete tn; + VLOG(1)<<"Returning from conversion"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 6d0fd7a44b..ec9dbfa13b 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -191,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize( if (VLOG_IS_ON(1)) { PrintDebugInfo(cluster, item); } + // This is a hack to workaround optimizer issue. MetaOptimizer calls + // optimization passes on function objects as well, we should not modify + // generated funcdefs! This is fragile but we don't have any other option + // until framework fixes it. + if (item.id != "tf_graph") { + LOG(WARNING) << name_ + << " is probably called on funcdef! This optimizer must *NOT* " + "be called on function objects."; + *optimized_graph = item.graph; + return tensorflow::Status::OK(); + } int max_dim = -1; if (item.feed.size()) { for (const auto& f : item.feed) { @@ -235,6 +246,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( cp.max_cached_engines = max_cached_batches_; auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(2) << optimized_graph->DebugString(); + VLOG(1) << "Returning from " << name_; return status; } diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 85f37aa899..12e84f7d3c 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -236,6 +236,7 @@ def auto(multi_engine): orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) opt_config = rwpb2.RewriterConfig() + opt_config.meta_optimizer_iterations=opt_config.ONE opt_config.optimizers.extend(["constfold", "layout"]) custom_op = opt_config.custom_optimizers.add() custom_op.name = "TensorRTOptimizer" |