aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/tensorrt_test.cc
diff options
context:
space:
mode:
authorGravatar Guangda Lai <laigd@google.com>2018-04-26 13:12:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 13:14:59 -0700
commit5f06514bff4061b839ee71847a299adbef9e7e03 (patch)
tree5875824ef64cf4c72c09a1e38223312ca2559cbc /tensorflow/contrib/tensorrt/tensorrt_test.cc
parent38244c353a7b91563b27c816105165833f5bb462 (diff)
Fix build by adding op_lib dependencies to trt_engine_op_loader, and remove
unnecessary dependency from the tf_gen_op_libs. PiperOrigin-RevId: 194442728
Diffstat (limited to 'tensorflow/contrib/tensorrt/tensorrt_test.cc')
-rw-r--r--tensorflow/contrib/tensorrt/tensorrt_test.cc8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc
index e11522ea5b..3712a9a6fe 100644
--- a/tensorflow/contrib/tensorrt/tensorrt_test.cc
+++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc
@@ -95,9 +95,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
}
// Executes the network.
-void Execute(nvinfer1::IExecutionContext& context, const float* input,
+void Execute(nvinfer1::IExecutionContext* context, const float* input,
float* output) {
- const nvinfer1::ICudaEngine& engine = context.getEngine();
+ const nvinfer1::ICudaEngine& engine = context->getEngine();
// We have two bindings: input and output.
ASSERT_EQ(engine.getNbBindings(), 2);
@@ -118,7 +118,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
// could be removed.
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
cudaMemcpyHostToDevice, stream));
- context.enqueue(1, buffers, stream, nullptr);
+ context->enqueue(1, buffers, stream, nullptr);
ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
cudaMemcpyDeviceToHost, stream));
cudaStreamSynchronize(stream);
@@ -143,7 +143,7 @@ TEST(TensorrtTest, BasicFunctions) {
// Execute the network.
float input = 1234;
float output;
- Execute(*context, &input, &output);
+ Execute(context, &input, &output);
EXPECT_EQ(output, input * 2 + 3);
// Destroy the engine.