about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/costs
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-09-26 22:55:11 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-09-26 22:58:44 -0700
commit: 40dee372e3ee844c4746baa914c07b9c582a2ce7 (patch)
tree: bd39a01c0aad8a6cfc8e5d4205674b5a8892133d /tensorflow/core/grappler/costs
parent: 680c2f5d988fb1f3b725fb8f0a67d1926be8169b (diff)
Define OpContext and use it for OpLevelCostEstimator.
This CL does not add any functionality (except that GraphDef's function library pointer is passed to OpContext), but we can later add additional fields to the OpContext struct for extending VirtualCluster, Scheduler, Placer, and others.
PiperOrigin-RevId: 170157235
Diffstat (limited to 'tensorflow/core/grappler/costs')
-rw-r--r-- tensorflow/core/grappler/costs/BUILD 12
-rw-r--r-- tensorflow/core/grappler/costs/analytical_cost_estimator.cc 11
-rw-r--r-- tensorflow/core/grappler/costs/op_context.h 39
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.cc 47
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator.h 23
-rw-r--r-- tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc 99
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.cc 22
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.h 12
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler_test.cc 21
9 files changed, 173 insertions, 113 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 678a37b5bc..1d0bd42372 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -195,12 +195,23 @@ tf_cc_test(
)
cc_library(
+ name = "op_context",
+ hdrs = ["op_context.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":op_performance_data_cc",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
name = "virtual_scheduler",
srcs = ["virtual_scheduler.cc"],
hdrs = ["virtual_scheduler.h"],
visibility = ["//visibility:public"],
deps = [
":graph_properties",
+ ":op_context",
":utils",
":virtual_placer",
"//tensorflow/core:framework",
@@ -256,6 +267,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":cost_estimator",
+ ":op_context",
":op_performance_data_cc",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index 569efaf96d..91b6686971 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -70,11 +70,10 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
Costs node_costs;
do {
- NodeInfo node_info = scheduler.GetCurrNodeInfo();
- auto& op_info = node_info.op_info;
- const string& op_name = node_info.name;
+ OpContext op_context = scheduler.GetCurrNode();
+ const string& op_name = op_context.name;
- node_costs = node_estimator_->PredictCosts(op_info);
+ node_costs = node_estimator_->PredictCosts(op_context);
if (node_costs.inaccurate) {
inaccurate_nodes.push_back(op_name);
}
@@ -87,14 +86,14 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
cost_node = cost_graph->add_node();
cost_node->set_name(op_name);
}
- cost_node->set_device(node_info.device_name);
+ cost_node->set_device(op_context.device_name);
cost_node->set_compute_cost(
node_costs.execution_time.asMicroSeconds().count());
cost_node->set_compute_time(
node_costs.compute_time.asMicroSeconds().count());
cost_node->set_memory_time(
node_costs.memory_time.asMicroSeconds().count());
- for (const auto& output : node_info.op_info.outputs()) {
+ for (const auto& output : op_context.op_info.outputs()) {
auto output_info = cost_node->add_output_info();
output_info->set_dtype(output.dtype());
auto shape = output_info->mutable_shape();
diff --git a/tensorflow/core/grappler/costs/op_context.h b/tensorflow/core/grappler/costs/op_context.h
new file mode 100644
index 0000000000..735a1e68ea
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_context.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// A structure to keep the context of op execution, including its shape,
+// execution context, and other relevant information.
+struct OpContext {
+ string name;
+ string device_name;
+ OpInfo op_info;
+ const FunctionDefLibrary* function_library; // Not owned.
+
+ OpContext() { function_library = nullptr; }
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index fbafed7c1f..b25def7612 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -142,10 +142,12 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpContext and
// returns a cost.
- typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpInfo& op_feature)
+ typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context)
const;
- auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpInfo&)> {
- return [this, impl](const OpInfo& op) { return (this->*impl)(op); };
+ auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpContext&)> {
+ return [this, impl](const OpContext& op_context) {
+ return (this->*impl)(op_context);
+ };
};
device_cost_impl_ = {
@@ -272,18 +274,19 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
compute_memory_overlap_ = false;
}
-Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
auto it = device_cost_impl_.find(op_features.op());
if (it == device_cost_impl_.end()) {
if (elementwise_ops_.find(op_features.op()) != elementwise_ops_.end()) {
- return PredictCwiseOp(op_features);
+ return PredictCwiseOp(op_context);
}
VLOG(1) << "Missing implementation for op: " << op_features.op();
- return DummyExecutionTime(op_features);
+ return DummyExecutionTime(op_context);
}
- std::function<Costs(const OpInfo&)> estimator = it->second;
- Costs costs = estimator(op_features);
+ std::function<Costs(const OpContext&)> estimator = it->second;
+ Costs costs = estimator(op_context);
VLOG(1) << "Operation " << op_features.op() << " takes "
<< costs.execution_time.count() << " ns.";
return costs;
@@ -336,7 +339,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
return std::make_pair(gflops, bandwidth);
}
-Costs OpLevelCostEstimator::PredictCwiseOp(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
// For unary or binary element-wise operations, op count is the element count
// of any input. We use the count for the largest input here to be more robust
@@ -369,9 +373,9 @@ Costs OpLevelCostEstimator::PredictCwiseOp(const OpInfo& op_features) const {
}
Costs OpLevelCostEstimator::DummyExecutionTime(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
// Use CwiseOp time as an estimation
- auto costs = PredictCwiseOp(op_features);
+ auto costs = PredictCwiseOp(op_context);
costs.inaccurate = true;
return costs;
}
@@ -806,7 +810,8 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
return total_output_size;
}
-Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs = PredictOpCountBasedCost(
CountConv2DOperations(op_features, &found_unknown_shapes), op_features);
@@ -815,7 +820,8 @@ Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const {
}
Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs =
PredictOpCountBasedCost(CountConv2DBackpropInputOperations(
@@ -826,7 +832,8 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
}
Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs =
PredictOpCountBasedCost(CountConv2DBackpropFilterOperations(
@@ -836,7 +843,8 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
return costs;
}
-Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
auto costs = PredictOpCountBasedCost(
CountMatMulOperations(op_features, &found_unknown_shapes), op_features);
@@ -844,13 +852,15 @@ Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const {
return costs;
}
-Costs OpLevelCostEstimator::PredictNoOp(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
return Costs::ZeroCosts();
}
Costs OpLevelCostEstimator::PredictBatchMatMul(
- const OpInfo& op_features) const {
+ const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
Costs costs = PredictOpCountBasedCost(
CountBatchMatMulOperations(op_features, &found_unknown_shapes),
@@ -859,7 +869,8 @@ Costs OpLevelCostEstimator::PredictBatchMatMul(
return costs;
}
-Costs OpLevelCostEstimator::PredictMetadata(const OpInfo& op_features) const {
+Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
Costs costs;
costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate);
// Metadata operations are so cheap we assume they take the minimum amount of
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index b4302dc9e1..0e63299bcb 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/util/padding.h"
@@ -32,7 +33,7 @@ class OpLevelCostEstimator {
OpLevelCostEstimator();
virtual ~OpLevelCostEstimator() {}
- virtual Costs PredictCosts(const OpInfo& op_features) const;
+ virtual Costs PredictCosts(const OpContext& op_context) const;
protected:
// Returns an estimate of device performance (in billions of operations
@@ -43,7 +44,7 @@ class OpLevelCostEstimator {
// For operations for which we haven't yet built estimates, returns a dummy
// value based on input size.
- Costs DummyExecutionTime(const OpInfo& op_features) const;
+ Costs DummyExecutionTime(const OpContext& op_context) const;
// Naive cost estimate based on operations divided by device ops/sec.
Costs PredictOpCountBasedCost(double operations,
@@ -122,14 +123,14 @@ class OpLevelCostEstimator {
// Implementation of costs other than
// execution_time is optional, depending on the
// device.
- Costs PredictConv2D(const OpInfo& op_features) const;
- Costs PredictCwiseOp(const OpInfo& op_features) const;
- Costs PredictConv2DBackpropInput(const OpInfo& op_features) const;
- Costs PredictConv2DBackpropFilter(const OpInfo& op_features) const;
- Costs PredictMatMul(const OpInfo& op_features) const;
- Costs PredictNoOp(const OpInfo& op_features) const;
- Costs PredictBatchMatMul(const OpInfo& op_features) const;
- Costs PredictMetadata(const OpInfo& op_features) const;
+ Costs PredictConv2D(const OpContext& op_context) const;
+ Costs PredictCwiseOp(const OpContext& op_context) const;
+ Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
+ Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
+ Costs PredictMatMul(const OpContext& op_context) const;
+ Costs PredictNoOp(const OpContext& op_context) const;
+ Costs PredictBatchMatMul(const OpContext& op_context) const;
+ Costs PredictMetadata(const OpContext& op_context) const;
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
@@ -148,7 +149,7 @@ class OpLevelCostEstimator {
protected:
std::map<string, int> elementwise_ops_;
- typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
+ typedef std::function<Costs(const OpContext& op_context)> CostImpl;
std::map<string, CostImpl> device_cost_impl_;
// If true, assume compute and memory overlap; hence, the op cost is max of
// compute_time and memory_time, instead of sum of those two.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 0cbfb10017..f19be4a0ee 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -24,7 +24,7 @@ namespace grappler {
namespace {
// Wrangles the minimum number of proto fields to set up a matrix.
-void DescribeMatrix(int rows, int columns, OpInfo *op_features) {
+void DescribeMatrix(int rows, int columns, OpInfo* op_features) {
auto input = op_features->add_inputs();
auto shape = input->mutable_shape();
auto shape_rows = shape->add_dim();
@@ -43,31 +43,31 @@ void SetCpuDevice(OpInfo* op_features) {
}
// Returns an OpInfo for MatMul with the minimum set of fields set up.
-OpInfo DescribeMatMul(int m, int n, int l, int k) {
- OpInfo op_features;
- SetCpuDevice(&op_features);
- op_features.set_op("MatMul");
+OpContext DescribeMatMul(int m, int n, int l, int k) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("MatMul");
- DescribeMatrix(m, l, &op_features);
- DescribeMatrix(k, n, &op_features);
- return op_features;
+ DescribeMatrix(m, l, &op_context.op_info);
+ DescribeMatrix(k, n, &op_context.op_info);
+ return op_context;
}
// Returns an OpInfo for MatMul with unknown input shapes.
-OpInfo DescribeMatMulUnknownShape() {
- OpInfo op_features;
- SetCpuDevice(&op_features);
- op_features.set_op("MatMul");
+OpContext DescribeMatMulUnknownShape() {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("MatMul");
- auto input = op_features.add_inputs();
+ auto input = op_context.op_info.add_inputs();
auto shape = input->mutable_shape();
shape->set_unknown_rank(true);
- input = op_features.add_inputs();
+ input = op_context.op_info.add_inputs();
shape = input->mutable_shape();
shape->set_unknown_rank(true);
- return op_features;
+ return op_context;
}
// Wrangles the minimum number of proto fields to set up an input of
@@ -83,21 +83,21 @@ void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
}
// Returns an OpInfo for a BatchMatMul
-OpInfo DescribeBatchMatMul(const std::vector<int>& dims_a,
- const std::vector<int>& dims_b) {
- OpInfo op_features;
- SetCpuDevice(&op_features);
- op_features.set_op("BatchMatMul");
+OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
+ const std::vector<int>& dims_b) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("BatchMatMul");
- DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_features);
- DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_features);
- return op_features;
+ DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
+ DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
+ return op_context;
}
// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
// estimation purposes.
void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
- OpInfo *op_features) {
+ OpInfo* op_features) {
auto input = op_features->add_inputs();
auto shape = input->mutable_shape();
shape->add_dim()->set_size(dim0);
@@ -108,26 +108,26 @@ void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
}
// Returns an OpInfo for Conv2D with the minimum set of fields set up.
-OpInfo DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, int kx,
- int ky, int oz) {
- OpInfo op_features;
- SetCpuDevice(&op_features);
- op_features.set_op("Conv2D");
+OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
+ int kx, int ky, int oz) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("Conv2D");
- DescribeTensor4D(batch, ix, iy, iz1, &op_features);
- DescribeTensor4D(kx, ky, iz2, oz, &op_features);
- return op_features;
+ DescribeTensor4D(batch, ix, iy, iz1, &op_context.op_info);
+ DescribeTensor4D(kx, ky, iz2, oz, &op_context.op_info);
+ return op_context;
}
-OpInfo DescribeOp(const string& op, int size1, int size2) {
- OpInfo op_features;
- SetCpuDevice(&op_features);
- op_features.set_op(op);
+OpContext DescribeOp(const string& op, int size1, int size2) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op(op);
- DescribeTensor4D(size1, 1, 1, 1, &op_features);
- DescribeTensor4D(2 * size1, size2, 1, 1, &op_features);
+ DescribeTensor4D(size1, 1, 1, 1, &op_context.op_info);
+ DescribeTensor4D(2 * size1, size2, 1, 1, &op_context.op_info);
- auto output = op_features.add_outputs();
+ auto output = op_context.op_info.add_outputs();
auto shape = output->mutable_shape();
shape->add_dim()->set_size(2 * size1);
shape->add_dim()->set_size(size2);
@@ -135,15 +135,15 @@ OpInfo DescribeOp(const string& op, int size1, int size2) {
shape->add_dim()->set_size(1);
output->set_dtype(DT_FLOAT);
- SetCpuDevice(&op_features);
- return op_features;
+ SetCpuDevice(&op_context.op_info);
+ return op_context;
}
} // namespace
class OpLevelCostEstimatorTest : public ::testing::Test {
protected:
- Costs PredictCosts(const OpInfo& op_features) const {
- return estimator_.PredictCosts(op_features);
+ Costs PredictCosts(const OpContext& op_context) const {
+ return estimator_.PredictCosts(op_context);
}
int64 CountMatMulOperations(const OpInfo& op_features,
@@ -228,20 +228,21 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) {
bool matmul_inaccurate = false;
bool batch_matmul_inaccurate = false;
EXPECT_EQ(
- CountMatMulOperations(DescribeMatMul(2, 2, 4, 4), &matmul_inaccurate),
- CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}),
+ CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
+ &matmul_inaccurate),
+ CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
&batch_matmul_inaccurate));
EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
- EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4),
+ EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
&matmul_inaccurate),
CountBatchMatMulOperations(
- DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}),
+ DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
&batch_matmul_inaccurate));
EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
- EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4),
+ EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
&matmul_inaccurate),
CountBatchMatMulOperations(
- DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}),
+ DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
&batch_matmul_inaccurate));
EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
}
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 16c434b0ad..4294c9e954 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -377,7 +377,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
return std::make_pair(send, recv);
}
-NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
+OpContext VirtualScheduler::GetCurrNode() const {
const NodeDef* node = ready_nodes_->GetCurrNode();
// Get the device from the placer.
@@ -389,12 +389,12 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
device.set_type(kChannelDevice);
}
- // Construct NodeInfo.
- NodeInfo node_info;
+ // Construct OpContext.
+ OpContext op_context;
const auto& node_state = node_map_.at(node);
- node_info.name = node->name();
- node_info.device_name = node_state.device_name;
- auto& op_info = node_info.op_info;
+ op_context.name = node->name();
+ op_context.device_name = node_state.device_name;
+ auto& op_info = op_context.op_info;
op_info.set_op(node->op());
*op_info.mutable_attr() = node->attr();
for (auto& input : node_state.input_properties) {
@@ -404,7 +404,11 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
*op_info.add_outputs() = output;
}
op_info.mutable_device()->Swap(&device);
- return node_info;
+
+ if (grappler_item_->graph.has_library()) {
+ op_context.function_library = &grappler_item_->graph.library();
+ }
+ return op_context;
}
NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
@@ -497,8 +501,8 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
const auto& op_name = node->op();
// Also keep track of op counts and times per op (with their shapes).
- NodeInfo node_info = GetCurrNodeInfo();
- string node_description = GetOpDescription(node_info.op_info);
+ OpContext op_context = GetCurrNode();
+ string node_description = GetOpDescription(op_context.op_info);
op_counts_[node_description] += 1;
op_costs_[node_description] =
node_costs.execution_time.asMicroSeconds().count();
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 0bbd2fd2eb..767b91677f 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -250,15 +251,6 @@ class FirstReadyManager : public ReadyNodeManager {
const std::unordered_map<const NodeDef*, NodeState>* node_state_;
};
-// A wrapper struct to OpInfo proto.
-// TODO(dyoon): once we extend OpInfo or implement a better interface, and then
-// delete this wrapper struct.
-struct NodeInfo {
- OpInfo op_info;
- string name;
- string device_name;
-};
-
// The virtual scheduler emulates execution of nodes in a graph, considering
// dependencies, device, etc.
class VirtualScheduler {
@@ -270,7 +262,7 @@ class VirtualScheduler {
// graph_properties_.
Status Init();
- NodeInfo GetCurrNodeInfo() const;
+ OpContext GetCurrNode() const;
// Returns true if there is any node to be scheduled.
bool MarkCurrNodeExecuted(const Costs& node_costs);
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index cea00b04f2..64fb626422 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -719,12 +719,12 @@ versions {
}
// Returns cost based on op.
- Costs SimplePredictCosts(const NodeInfo& info) const {
+ Costs SimplePredictCosts(const OpContext& op_context) const {
Costs c;
int64 exec_cost = 0;
- if (info.op_info.op() == "MatMul") {
+ if (op_context.op_info.op() == "MatMul") {
exec_cost = 2000000000;
- } else if (info.op_info.op() == "RandomUniform") {
+ } else if (op_context.op_info.op() == "RandomUniform") {
exec_cost = 1000000000;
} else {
exec_cost = 1000;
@@ -735,18 +735,19 @@ versions {
// Call this after init scheduler_. Scheduler stops after executing
// target_node.
- std::unordered_map<string, NodeInfo> RunScheduler(const string& target_node) {
+ std::unordered_map<string, OpContext> RunScheduler(
+ const string& target_node) {
Costs zero_costs = Costs::ZeroCosts();
- std::unordered_map<string, NodeInfo> ops_executed;
+ std::unordered_map<string, OpContext> ops_executed;
bool more_nodes = true;
do {
- NodeInfo node_info = scheduler_->GetCurrNodeInfo();
- ops_executed[node_info.name] = node_info;
+ OpContext op_context = scheduler_->GetCurrNode();
+ ops_executed[op_context.name] = op_context;
- Costs node_costs = SimplePredictCosts(node_info);
+ Costs node_costs = SimplePredictCosts(op_context);
// Check scheduling order.
- auto it = dependency_.find(node_info.name);
+ auto it = dependency_.find(op_context.name);
if (it != dependency_.end()) {
for (const auto& preceding_node : it->second) {
EXPECT_GT(ops_executed.count(preceding_node), 0);
@@ -754,7 +755,7 @@ versions {
}
more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
- if (node_info.name == target_node) {
+ if (op_context.name == target_node) {
// Scheduler has the state after executing the target node.
break;
}