aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-02-26 14:00:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 14:03:54 -0800
commit7765066e6a686c7d6b1bed44248fafaa859db4eb (patch)
treeae5ed756c88ec6862f655306bad8dda2b98fdb71
parentba2cc572f99b09ddd6a60e0557059cb1da51b356 (diff)
TFTS: Switch to using core feature columns
This fixes some shape issues that came up when using the tf.contrib.layers parsing functions. Adds a string -> embedding column API example to the LSTM example. PiperOrigin-RevId: 187076400
-rw-r--r--tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv200
-rw-r--r--tensorflow/contrib/timeseries/examples/known_anomaly.py8
-rw-r--r--tensorflow/contrib/timeseries/examples/lstm.py26
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py53
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/model.py38
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py10
6 files changed, 177 insertions, 158 deletions
diff --git a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv
index b49a0662c2..9b15b4f0b2 100644
--- a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv
+++ b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv
@@ -1,100 +1,100 @@
-0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0.
-1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0.
-2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0.
-3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0.
-4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0.
-5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0.
-6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0.
-7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0.
-8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0.
-9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0.
-10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0.
-11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0.
-12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0.
-13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0.
-14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0.
-15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0.
-16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0.
-17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0.
-18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0.
-19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0.
-20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0.
-21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0.
-22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0.
-23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0.
-24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0.
-25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0.
-26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0.
-27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0.
-28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0.
-29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0.
-30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0.
-31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0.
-32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0.
-33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0.
-34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0.
-35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0.
-36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0.
-37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0.
-38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0.
-39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0.
-40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0.
-41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0.
-42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0.
-43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0.
-44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0.
-45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0.
-46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0.
-47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0.
-48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0.
-49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0.
-50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0.
-51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0.
-52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0.
-53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0.
-54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0.
-55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0.
-56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0.
-57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0.
-58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0.
-59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0.
-60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0.
-61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0.
-62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0.
-63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0.
-64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0.
-65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0.
-66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0.
-67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0.
-68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0.
-69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0.
-70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0.
-71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0.
-72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0.
-73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0.
-74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0.
-75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0.
-76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0.
-77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0.
-78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0.
-79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0.
-80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0.
-81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0.
-82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0.
-83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0.
-84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0.
-85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0.
-86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0.
-87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0.
-88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0.
-89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0.
-90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0.
-91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0.
-92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0.
-93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0.
-94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0.
-95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0.
-96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0.
-97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0.
-98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0.
-99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0.
+0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0.,strkeya
+1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0.,strkeyb
+2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0.,strkey
+3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0.,strkey
+4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0.,strkey
+5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0.,strkey
+6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0.,strkey
+7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0.,strkey
+8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0.,strkey
+9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0.,strkey
+10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0.,strkeyc
+11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0.,strkey
+12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0.,strkey
+13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0.,strkey
+14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0.,strkey
+15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0.,strkey
+16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0.,strkey
+17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0.,strkey
+18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0.,strkey
+19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0.,strkey
+20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0.,strkey
+21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0.,strkey
+22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0.,strkey
+23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0.,strkey
+24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0.,strkey
+25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0.,strkey
+26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0.,strkey
+27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0.,strkey
+28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0.,strkey
+29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0.,strkey
+30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0.,strkey
+31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0.,strkey
+32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0.,strkey
+33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0.,strkey
+34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0.,strkey
+35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0.,strkey
+36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0.,strkey
+37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0.,strkeyd
+38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0.,strkey
+39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0.,strkey
+40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0.,strkey
+41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0.,strkey
+42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0.,strkey
+43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0.,strkey
+44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0.,strkey
+45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0.,strkey
+46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0.,strkey
+47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0.,strkey
+48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0.,strkey
+49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0.,strkey
+50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0.,strkey
+51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0.,strkey
+52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0.,strkey
+53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0.,strkey
+54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0.,strkey
+55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0.,strkey
+56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0.,strkey
+57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0.,strkey
+58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0.,strkey
+59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0.,strkey
+60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0.,strkey
+61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0.,strkey
+62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0.,strkey
+63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0.,strkey
+64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0.,strkey
+65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0.,strkey
+66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0.,strkey
+67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0.,strkey
+68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0.,strkey
+69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0.,strkey
+70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0.,strkey
+71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0.,strkey
+72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0.,strkey
+73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0.,strkey
+74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0.,strkey
+75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0.,strkey
+76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0.,strkey
+77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0.,strkey
+78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0.,strkey
+79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0.,strkey
+80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0.,strkey
+81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0.,strkey
+82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0.,strkey
+83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0.,strkey
+84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0.,strkey
+85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0.,strkey
+86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0.,strkey
+87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0.,strkey
+88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0.,strkey
+89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0.,strkey
+90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0.,strkey
+91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0.,strkey
+92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0.,strkey
+93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0.,strkey
+94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0.,strkey
+95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0.,strkey
+96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0.,strkey
+97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0.,strkey
+98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0.,strkey
+99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0.,strkey
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py
index 7659dd308a..c08c0b0acb 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py
@@ -46,12 +46,12 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
# Indicate the format of our exogenous feature, in this case a string
# representing a boolean value.
- string_feature = tf.contrib.layers.sparse_column_with_keys(
- column_name="is_changepoint", keys=["no", "yes"])
+ string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
+ key="is_changepoint", vocabulary_list=["no", "yes"])
# Specify the way this feature is presented to the model, here using a one-hot
# encoding.
- one_hot_feature = tf.contrib.layers.one_hot_column(
- sparse_id_column=string_feature)
+ one_hot_feature = tf.feature_column.indicator_column(
+ categorical_column=string_feature)
estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
periodicities=12,
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py
index f37cafcc50..2eee878196 100644
--- a/tensorflow/contrib/timeseries/examples/lstm.py
+++ b/tensorflow/contrib/timeseries/examples/lstm.py
@@ -59,10 +59,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
num_units: The number of units in the model's LSTMCell.
num_features: The dimensionality of the time series (features per
timestep).
- exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
- objects representing features which are inputs to the model but are
- not predicted by it. These must then be present for training,
- evaluation, and prediction.
+ exogenous_feature_columns: A list of `tf.feature_column`s representing
+ features which are inputs to the model but are not predicted by
+ it. These must then be present for training, evaluation, and
+ prediction.
dtype: The floating point data type to use.
"""
super(_LSTMModel, self).__init__(
@@ -189,12 +189,16 @@ def train_and_predict(
export_directory=None):
"""Train and predict using a custom time series model."""
# Construct an Estimator from our LSTM model.
+ categorical_column = tf.feature_column.categorical_column_with_hash_bucket(
+ key="categorical_exogenous_feature", hash_bucket_size=16)
exogenous_feature_columns = [
# Exogenous features are not part of the loss, but can inform
# predictions. In this example the features have no extra information, but
# are included as an API example.
- tf.contrib.layers.real_valued_column(
- "2d_exogenous_feature", dimension=2)]
+ tf.feature_column.numeric_column(
+ "2d_exogenous_feature", shape=(2,)),
+ tf.feature_column.embedding_column(
+ categorical_column=categorical_column, dimension=10)]
estimator = ts_estimators.TimeSeriesRegressor(
model=_LSTMModel(num_features=5, num_units=128,
exogenous_feature_columns=exogenous_feature_columns),
@@ -205,7 +209,11 @@ def train_and_predict(
csv_file_name,
column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,)
+ (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5
- + ("2d_exogenous_feature",) * 2))
+ + ("2d_exogenous_feature",) * 2
+ + ("categorical_exogenous_feature",)),
+ # Data types other than for `times` need to be specified if they aren't
+ # float32. In this case one of our exogenous features has string dtype.
+ column_dtypes=((tf.int64,) + (tf.float32,) * 7 + (tf.string,)))
train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
reader, batch_size=4, window_size=32)
estimator.train(input_fn=train_input_fn, steps=training_steps)
@@ -215,7 +223,9 @@ def train_and_predict(
predict_exogenous_features = {
"2d_exogenous_feature": numpy.concatenate(
[numpy.ones([1, 100, 1]), numpy.zeros([1, 100, 1])],
- axis=-1)}
+ axis=-1),
+ "categorical_exogenous_feature": numpy.array(
+ ["strkey"] * 100)[None, :, None]}
(predictions,) = tuple(estimator.predict(
input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
evaluation, steps=100,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index f8355f366f..8d13343e82 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.layers.python.layers import feature_column
-
from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
@@ -31,10 +29,12 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filterin
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator.export import export_lib
+from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.training import training as train
@@ -117,22 +117,29 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
dtype=self._model.dtype),
shape=(default_batch_size, default_series_length,
self._model.num_features)))
- with ops.Graph().as_default():
- # Default placeholders have only an unknown batch dimension. Make them
- # in a separate graph, then splice in the series length to the shapes
- # and re-create them in the outer graph.
- exogenous_feature_shapes = {
- key: (value.get_shape(), value.dtype) for key, value
- in feature_column.make_place_holder_tensors_for_base_features(
- self._model.exogenous_feature_columns).items()}
- for feature_key, (batch_only_feature_shape, value_dtype) in (
- exogenous_feature_shapes.items()):
- batch_only_feature_shape = batch_only_feature_shape.with_rank_at_least(
- 1).as_list()
- feature_shape = ([default_batch_size, default_series_length]
- + batch_only_feature_shape[1:])
- placeholders[feature_key] = array_ops.placeholder(
- dtype=value_dtype, name=feature_key, shape=feature_shape)
+ if self._model.exogenous_feature_columns:
+ with ops.Graph().as_default():
+ # Default placeholders have only an unknown batch dimension. Make them
+ # in a separate graph, then splice in the series length to the shapes
+ # and re-create them in the outer graph.
+ parsed_features = (
+ feature_column.make_parse_example_spec(
+ self._model.exogenous_feature_columns))
+ placeholder_features = parsing_ops.parse_example(
+ serialized=array_ops.placeholder(
+ shape=[None], dtype=dtypes.string),
+ features=parsed_features)
+ exogenous_feature_shapes = {
+ key: (value.get_shape(), value.dtype) for key, value
+ in placeholder_features.items()}
+ for feature_key, (batch_only_feature_shape, value_dtype) in (
+ exogenous_feature_shapes.items()):
+ batch_only_feature_shape = (
+ batch_only_feature_shape.with_rank_at_least(1).as_list())
+ feature_shape = ([default_batch_size, default_series_length]
+ + batch_only_feature_shape[1:])
+ placeholders[feature_key] = array_ops.placeholder(
+ dtype=value_dtype, name=feature_key, shape=feature_shape)
# Models may not know the shape of their state without creating some
# variables/ops. Avoid polluting the default graph by making a new one. We
# use only static metadata from the returned Tensors.
@@ -333,11 +340,11 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
determine the model size. Learning autoregressive coefficients
typically requires more steps and a smaller step size than other
components.
- exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
- objects (for example tf.contrib.layers.embedding_column) corresponding
- to exogenous features which provide extra information to the model but
- are not part of the series to be predicted. Passed to
- tf.contrib.layers.input_from_feature_columns.
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not part
+ of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
exogenous_update_condition: A function taking two Tensor arguments,
`times` (shape [batch size]) and `features` (a dictionary mapping
exogenous feature keys to Tensors with shapes [batch size, ...]), and
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py
index bac7d1ebf5..7644764a74 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model.py
@@ -21,18 +21,17 @@ from __future__ import print_function
import abc
import collections
-from tensorflow.contrib import layers
-from tensorflow.contrib.layers import feature_column
-
from tensorflow.contrib.timeseries.python.timeseries import math_utils
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures
+from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
@@ -66,11 +65,11 @@ class TimeSeriesModel(object):
Args:
num_features: Number of features for the time series
- exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
- objects (for example tf.contrib.layers.embedding_column) corresponding
- to exogenous features which provide extra information to the model but
- are not part of the series to be predicted. Passed to
- tf.contrib.layers.input_from_feature_columns.
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not
+ part of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
dtype: The floating point datatype to use.
"""
if exogenous_feature_columns:
@@ -86,7 +85,7 @@ class TimeSeriesModel(object):
@property
def exogenous_feature_columns(self):
- """`FeatureColumn` objects for features which are not predicted."""
+ """`tf.feature_colum`s for features which are not predicted."""
return self._exogenous_feature_columns
# TODO(allenl): Move more of the generic machinery for generating and
@@ -265,11 +264,14 @@ class TimeSeriesModel(object):
if not self._exogenous_feature_columns:
return (0,)
with ops.Graph().as_default():
- placeholder_features = (
- feature_column.make_place_holder_tensors_for_base_features(
+ parsed_features = (
+ feature_column.make_parse_example_spec(
self._exogenous_feature_columns))
- embedded = layers.input_from_feature_columns(
- columns_to_tensors=placeholder_features,
+ placeholder_features = parsing_ops.parse_example(
+ serialized=array_ops.placeholder(shape=[None], dtype=dtypes.string),
+ features=parsed_features)
+ embedded = feature_column.input_layer(
+ features=placeholder_features,
feature_columns=self._exogenous_feature_columns)
return embedded.get_shape().as_list()[1:]
@@ -308,13 +310,13 @@ class TimeSeriesModel(object):
# Avoid shape warnings when embedding "scalar" exogenous features (those
# with only batch and window dimensions); input_from_feature_columns
# expects input ranks to match the embedded rank.
- if tensor.get_shape().ndims == 1:
+ if tensor.get_shape().ndims == 1 and tensor.dtype != dtypes.string:
exogenous_features_single_batch_dimension[name] = tensor[:, None]
else:
exogenous_features_single_batch_dimension[name] = tensor
embedded_exogenous_features_single_batch_dimension = (
- layers.input_from_feature_columns(
- columns_to_tensors=exogenous_features_single_batch_dimension,
+ feature_column.input_layer(
+ features=exogenous_features_single_batch_dimension,
feature_columns=self._exogenous_feature_columns,
trainable=True))
exogenous_regressors = array_ops.reshape(
@@ -381,8 +383,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
may use _scale_back_data or _scale_back_variance to return predictions
to the input scale.
dtype: The floating point datatype to use.
- exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
- objects. See `TimeSeriesModel`.
+ exogenous_feature_columns: A list of `tf.feature_column`s objects. See
+ `TimeSeriesModel`.
exogenous_update_condition: A function taking two Tensor arguments `times`
(shape [batch size]) and `features` (a dictionary mapping exogenous
feature keys to Tensors with shapes [batch size, ...]) and returning a
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
index 6257002647..951c6546d5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
@@ -112,11 +112,11 @@ class StateSpaceModelConfiguration(
exogenous_noise_decreases: If True, exogenous regressors can "set" model
state, decreasing uncertainty. If both this parameter and
exogenous_noise_increases are False, exogenous regressors are ignored.
- exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
- objects (for example tf.contrib.layers.embedding_column) corresponding
- to exogenous features which provide extra information to the model but
- are not part of the series to be predicted. Passed to
- tf.contrib.layers.input_from_feature_columns.
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not part
+ of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
exogenous_update_condition: A function taking two Tensor arguments `times`
(shape [batch size]) and `features` (a dictionary mapping exogenous
feature keys to Tensors with shapes [batch size, ...]) and returning a