aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-06 19:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-06 19:43:05 -0700
commit7d0f6385f8e7637e155ef9c340c19aded365a6ff (patch)
tree843a815c37f285ca49edef2e11c2c8a258f5af21 /tensorflow/compiler/xla/service/dfs_hlo_visitor.h
parent24101b35f3baebbfff3d8057ac223b325bc415ce (diff)
[BatchNorm] Skeleton code to implement BatchNormGrad
This CL sets up all the boilerplate code needed to implement BatchNormGrad. None of the backends bas been implemented yet. RELNOTES: n/a PiperOrigin-RevId: 161161713
Diffstat (limited to 'tensorflow/compiler/xla/service/dfs_hlo_visitor.h')
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h2
1 files changed, 2 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index fcc4f85f01..0deedcf184 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -227,6 +227,8 @@ class DfsHloVisitor {
virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0;
+ virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0;
+
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
virtual Status FinishVisit(HloInstruction* root) = 0;