aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/interpreter
diff options
context:
space:
mode:
authorGravatar Kay Zhu <kayzhu@google.com>2017-12-15 18:11:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 18:15:09 -0800
commit713d45278491d792c525344de6038a61ebcb2136 (patch)
tree4c4d4063b82d415a448e0ff4646fbf6f88636876 /tensorflow/compiler/xla/service/interpreter
parent9648f8040a559f6cf9bbe0501ba96f2b2c2864b1 (diff)
[XLA] Support Map in HloEvaluator, enable Interpreter to run
xla/tests:map_map_test which tests this change. Additionally: - templatize Evaluate* methods to specialize on both std::unique_ptr<Literal> and const Literal* type of input literal arguments. - add ResetVisitState to DfsHloVisitor such that a visitor instance can traverse the same HLO graph more than once. PiperOrigin-RevId: 179263540
Diffstat (limited to 'tensorflow/compiler/xla/service/interpreter')
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc7
1 files changed, 3 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 9183a1d1bf..293cc2007e 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -98,12 +98,10 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream(
// Create the arguments as an vector of XLA literals
std::vector<std::unique_ptr<Literal>> arg_literals;
- std::vector<Literal*> arg_literals_ptrs;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
// Create the input literal for the parameter
HloInstruction* param = computation->parameter_instruction(p);
arg_literals.emplace_back(Literal::CreateFromShape(param->shape()));
- arg_literals_ptrs.push_back(arg_literals.back().get());
// Copy in the data from the stream_executor buffers
void* buffer = arg_literals.back()->MutableInternalData();
@@ -113,8 +111,9 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream(
// Execute the graph using the HloEvaluator.
HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output,
- evaluator.Evaluate(*computation, arg_literals_ptrs));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Literal> output,
+ evaluator.Evaluate<std::unique_ptr<Literal>>(*computation, arg_literals));
// Copy the result into the return buffer
perftools::gputools::StreamExecutor* executor(stream->parent());