From e2379a9bee71d6adc2cafd211674b66060c9df2a Mon Sep 17 00:00:00 2001 From: Ilya Lasy Date: Fri, 26 Feb 2021 11:24:54 +0300 Subject: [PATCH 1/6] compatibility with latest pytorch --- model/central.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/model/central.py b/model/central.py index 847aac3..af8d1da 100644 --- a/model/central.py +++ b/model/central.py @@ -134,18 +134,14 @@ class LeftMMFixed(torch.autograd.Function): Implementation of matrix multiplication of a Sparse Variable with a Dense Variable, returning a Dense one. This is added because there's no autograd for sparse yet. No gradient computed on the sparse weights. """ - - def __init__(self): - super(LeftMMFixed, self).__init__() - self.sparse_weights = None - - def forward(self, sparse_weights, x): - if self.sparse_weights is None: - self.sparse_weights = sparse_weights - return torch.mm(self.sparse_weights, x) - - def backward(self, grad_output): - sparse_weights = self.sparse_weights + @staticmethod + def forward(ctx, sparse_weights, x): + ctx.sparse_weights = sparse_weights + return torch.mm(ctx.sparse_weights, x) + + @staticmethod + def backward(ctx, grad_output): + sparse_weights = ctx.sparse_weights return None, torch.mm(sparse_weights.t(), grad_output) I = X._indices() @@ -156,6 +152,5 @@ def backward(self, grad_output): lookup = Y[I[0, :], I[2, :], :] X_I = torch.stack((I[0, :] * M + I[1, :], use_cuda(torch.arange(Z).type(torch.LongTensor))), 0) S = use_cuda(Variable(torch.sparse.FloatTensor(X_I, V, torch.Size([B * M, Z])), requires_grad=False)) - prod_op = LeftMMFixed() - prod = prod_op(S, lookup) + prod = LeftMMFixed.apply(S, lookup) return prod.view(B, M, K) From 73a86da1a66af6913f7ac1f19842a23bed29a541 Mon Sep 17 00:00:00 2001 From: Ilya Lasy Date: Fri, 26 Feb 2021 11:31:07 +0300 Subject: [PATCH 2/6] cleared files --- .DS_Store | Bin 10244 -> 0 bytes .gitignore | 36 +++++++++++++++++++ __pycache__/preprocession.cpython-37.pyc | Bin 7147 -> 0 bytes model/.DS_Store | Bin 6148 -> 0 bytes model/__pycache__/__init__.cpython-37.pyc | Bin 234 -> 0 bytes model/__pycache__/central.cpython-37.pyc | Bin 6153 -> 0 bytes model/__pycache__/conceptflow.cpython-37.pyc | Bin 10130 -> 0 bytes model/__pycache__/embedding.cpython-37.pyc | Bin 2038 -> 0 bytes model/__pycache__/model.cpython-37.pyc | Bin 9818 -> 0 bytes model/__pycache__/outer.cpython-37.pyc | Bin 1588 -> 0 bytes training_output/.DS_Store | Bin 6148 -> 0 bytes utils/.DS_Store | Bin 6148 -> 0 bytes utils/__pycache__/__init__.cpython-37.pyc | Bin 237 -> 0 bytes utils/__pycache__/utils.cpython-37.pyc | Bin 1852 -> 0 bytes 14 files changed, 36 insertions(+) delete mode 100644 .DS_Store create mode 100644 .gitignore delete mode 100644 __pycache__/preprocession.cpython-37.pyc delete mode 100644 model/.DS_Store delete mode 100644 model/__pycache__/__init__.cpython-37.pyc delete mode 100644 model/__pycache__/central.cpython-37.pyc delete mode 100644 model/__pycache__/conceptflow.cpython-37.pyc delete mode 100644 model/__pycache__/embedding.cpython-37.pyc delete mode 100644 model/__pycache__/model.cpython-37.pyc delete mode 100644 model/__pycache__/outer.cpython-37.pyc delete mode 100644 training_output/.DS_Store delete mode 100644 utils/.DS_Store delete mode 100644 utils/__pycache__/__init__.cpython-37.pyc delete mode 100644 utils/__pycache__/utils.cpython-37.pyc diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 07cc71b557fe1c1edd7be74fef97a8fb0bfcbeee..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10244 zcmeHMOKVd>6h70Y=C(c>q&{%t-3SU&o3^$KQDRNOg71hfthPxL(?F6NlN;OGR1dUki>k>EvZAb^n(iMQ!yV;%{Zk>i==5V2p9wm0tNwtfI;AIK>%kqCu^fe zO=%D?2p9yK2ylF`k(nkGZFp3M4ixeV02)EFOsL1W0%|LoP_*GuaFB$j3hJo}b;S^# zI@)a|mr%6fQBR$Ox_k&VvQRe^AxB65wiG9k@Te&b0tSIr1o-YAp?=C!fn5Fm-nr%N zJDpc@lF4k{ogw=0>EYMLZ9gx&Ca*HP`XkZFfBTeB`p#_?uH1aIvxMiZ~%w&pA z1`@o94`CC}f0K!os0We`mFX&Gn#YV6s6<7-PRAC=-P3d;l*I@xrTubr3w`enlhYB9 zqcS?Gk5vRd9=US-#cA+0-U%R#7&1ov9QHHFsd(z3AwSbZv-QR*%)Qx#x0f)|pGWX}4Y* zOII)Pw=-#XHlyn29jDl=r!yA{xw)xW-!Z4`rVHg{Ix9h2Dx#*Phm=og?r4uwk z$7q~RQHsX>ZbIy4do?%d!q#%&_PLe~wiVqS2$%d{T}K1M=Om(c0kCOY%^l+Bo~)SCTsks#)ToBYz|k=ziipSPO*d~hor5DBr6ipM?T4|m zeuof)0gad(?;0XZLDot#YPlG-ACYIiCGPNucQrSOLehbz8#+H8?TIFBL9U9p%`q}$ zJ@hg#yIxJ6b@Dh(){Ket87{m{8?BA@|LWFlmhE$O;EmKqLv#O?$ktY)xh2-A;oBzb znpI%ybjhQ#Kt-zgUA|)y-F?yALX-U4TwsI`#gU&s6tyM#&Asmtz|gpwY2#-b+b4qD zrWr%Hev?f@Rqr`q9_E^45HJWB1PlVd8i7u40Gr?c$8kqmpLBJsJ2N4jF zWGXp^39P)J1n0tU?Kbi>GAEX6cvPyOkk9d;@;M$~_&I(rGRqDy<%D`d(S}FSg7n+} Y4ET|!BkhLo`;~0&|K|Sx55I_h2M{iLm;e9( diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c0b679c --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Jupyter Notebook +.ipynb_checkpoints + +# VS Code +.vscode/ + +# MAC +.DS_Store + + +data/* \ No newline at end of file diff --git a/__pycache__/preprocession.cpython-37.pyc b/__pycache__/preprocession.cpython-37.pyc deleted file mode 100644 index 4816465e1063693d593dc7a6b0bd41cd542b28ff..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7147 zcmbtYO>kVtaenju_WuWq#p0JB1QL`1Q=p1c$%%_4OB6}l5(6pvqiZdBy~I2$7g+2r z<~@)AYMu{*oI|RV&{aC6asYg@%D(vIlMg!j;PQDYhg4GKD&~|!PCO*}dfpEfAn7pn zZoT>G>FMe2>7JhH>!niOz_0zQsn)MA8^*s=C;3@uyn)Yu8wEEw>lsa^YqM$U+G<*; z&7QsOG@Wg?>27D58D?;c+n*cFEYEY7XFg}m9MAF`AbHL{FqR9CK+jlPwp3===DgeQ zsIj=w3Pg9<^IBc5s&{w0J>J^7+iLU8)^2Zl7cIc>BZ)D~(NwLNWw7HS9m3~Cp37Ig-7u48lz z(+KU*0rkB29oEUF>#T)r?^aNd?2z3sKK=G5X6S|)>GIr`DZYHqKr7!ccmcg%21Qk4QxK^?aGX!nb!KCood_iL@Y;=3xLa*ER7MA!y8T;cHk*zD7yVXx$Pkq;LBO%t`GVf-Hu`#igA?_?FTKb0hQN_#{{A_SX5S*?poUD;Nx~c-#Hn@XA|*{<=2|ZuJHat_+1Y6oYlo z_q&7s>Tq8b3DOo`i|7;6{t61is-jl&aDf+jiI;h$lL?)`;ZVY$mD!`ru@UBA zty!7V_AN&36qLt^a>oespjCSgS_PR0t%9cY+cUJriB{JLi=b7HX$`^Gy_By{V!kqb z0^ZozG9NN=8C*=tOze*l$0b>U21-XcK6T8r7RnN^sL!V-afWDUMOI`{maiEP*-n0x z8%R{$yDn`y3&OH2%i@-cwf>2)r;VdLv{LEUg2HAIdM=?Z^O-Nqup+CFu0qm*;!o}v z7y&FDuY%XvSYu<5a!wl3+A#GFyaq{EQ%M&`ci?fobLZM>%BK`G_BTm-$nE0yB{dK_0ifpA1<59+Vb{Q zMtJw(@7}?=l!!h(_s@`VT5IpF-E6u@Oof*&_BIQe7HwLddB2!oRTM^dF#PI_>k+M0m_};;J&x|SFgT&ef5=>U*Y6KV}P$=mcM8%HhHO&^wq;)<`5jS*4ot-&*pC{$rLqm`(6cggBW| z!;AzT&siQ5H7uxHe;@*nE34z-=Z@yMvxDD7c6*LUMeHKw_NQ_C~kQ0U7qX zfin9;?w96|=hbnx#|<8dgj+GxsAg74wx%rq*;Lh;lSTsu9-2y6)3`_y!!rVdF_agMkuM7Jhw;7b7beG~?Q%N#4SGU^&?1Qs1| z>?hh~R>WA1H85@?gNJeRb3?+74-uw#6pjo8sGOWKThIfaFd=@&SS& z4iIo4?G2OX1DhAVu)@L!+!=uvalDuCmGM>RsK+-pg6jge7}KZ$?nJZ$7gWt~d;~Ky zLbW~ur`1X^g^6dt8(4J`-xSY>(+{1!uSkavollR#nZ4UVCdkSej$nhpZt9@?8bZQs zI49?%se{o!NK?*XydbCL?1p*G=vQ&ot&Bntgf?-|Ce*4Jb52n^C(p&y{sq*|VZ8VR zwG;wszrnGLFbaLla8Qf(n*tK#nBS8@UXP5?i01iTI15bl;ncPxbc4a5gAG9zci! zt{mlK=np!?lR{rzUVO?po}{yv6YU&LxJU@vIQ zlFLa?{D&T2gghE&+UPLoU{PL>mvrlbycoAG$|a43xiQuf=6nN`YOt$IWTo;Fw=nw> z(S(*6E)N&9r*xPu* zCf|R;CT;QihxX2|`>&o@p4N;l{;1(|9G4gHPx6o0=Cncb5C*L8@8fkQ(mlMd*Fj&~ zehIekh*uvvJ2SAg86LwiO4jd)H5$=2Ix-gTfCKx-Xn%n?I6YhY$VB3^ghh51u4r#) zJjELv@f%?Me8d}8;9-vV13}>pE97CPya90S4J+`5Z^{+z4O#39`mr~x;9{DSE2%eJ zj=cdGBi?XXuEgHp?p%b%da=ePQ;q#VYYhEVW3JeE=(W3Kp&Mxc|+u};!) zt&>Ko6JX)gn#8G9l@_1Ukzb{c2k zdt|jo6?(+zx)|5hV2tE;{Rj?cwWmukdLN*EqP?@%pX1$k5!wm5ZLXtJiwSD)ju<(MTqU zi$MOfxWoJ39}bXZNI4WYKug@Df}Z3j_aAW!-66i1wLfdU@piA5WJ_9uzSp`p7(VCH zJI`2}au#|cNC|o)=0-Pyp68s=mP{J)rkJWuZ0K7a@-zt_NFiXW67H5wy=BVUZttP} zC$%D;!P@iMNOsWg?YDvlgXdI*{}--TCe35HSk~uUHF(OZl*de^Ttr82z)VO=haVvP z7ef$fN(kPtX-uEKu=X zDk!TezE8#XP;~I&_n+Ts&QY%OUMn&B79Kd>Ub6qP!PY23?;{7=c_LV!hyNM^Nen<;4&FC5J zAfwpvbZS++LeO+W#Z~H(3kfnMRn%g0F&GYZgXLAcrE%Wt=ko~Hcr~vNIJ8VJ_yYIR zPBR~ETnh=>WkWc&OkZDzd5ezQ2T<-TarluxMml;(|fuJ*8Mv~xNqpfY_}746wg!g`Ica{7qY zuYY6+qR`dS-Rh~`47P{OVypx04|e-O)AH|bDvPYvf+akuc3N*Tl`tllFj5>9Ulm6q{uxhH)wABc1V*bDeiAyr``4Pqmq0XW91p#EbkrDZPaHsRZB26Qjc`p=+Dvj~&07Ogz} zkt&HY_$GMB<4gV~qkk4}Qb>=maGt7ya^M5QeeJEH2#|7+xLd!coZX;{e3y34)bDiW lUuRY{l@WF83y{wJWUp-x_-@a8ou>LkGUA|aWzXer{VxHv_lN)h diff --git a/model/.DS_Store b/model/.DS_Store deleted file mode 100644 index 5ec44c2b12d7451c6f4669a5943947b7fc082334..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKy-veG47N*!3Pq?Jj5jhNF|$Nf8JKwhAZ>t3fkdbk>XwCf;1zg29soYurz(G} z2#_85zMSJPzWWm8azw<__2P_ZLPR6D;9w6;kI8lMjtydD7sz^#C9UWh$0?PaY!CcK z1?1U{^_eZ{iR#a9bD0xn@CVA|K1~*#BloSw+DB%F zN-KH>E7v!rRY^JZN{8SS;0Ut>^A5a%Ui!Fv)EeahIFFQ~UEq*8-uLn7y~P`$!Wy+Q zj8op`))@JmAXves|sja{wt-S=bQujo1j`Q(j}Xb6_14cCQKrYzgkv lR5b%r>YDwXVWyIx8Rx204dlR!zQ^_{U!V_t>TeIANGoZRP{Q8b-rnBs-rerr z-EK5Y1;4*En$D$*it;aNOny2lZ{th40G{Hpo?_=EvsJX!o)+r1&T?ESEZb#S*LtO} zVjHOIUa40NP1_7>b}g*i^{`<#!lvD1%KM5}_9|Z}Ud1nO%^hfV%VQ6ew($+-Qr6l` z7@xR2aMyc2>ZL&%^pbBFy4oetcocK*PPp!SUJz}FxjRuBq)(;@W03exXW+SF{*yZ& zed;{8|IWR8&f3TKZ{PVy%t_SRo%ahwhvu97w4#!{f{L%$%u{UjOQp*^=BZy8w&rP` z{)J*gFRzTUzcPHw^p3gg#f~SF!KAI!>2UJOwlx50+ zI#S09Y=yd`Wh$Val>kdw8L*sbfR#)KG%~hTMOj5@qBPUmRy|X;D)qjk`}wtRXmKCjZR^4eqW&Otl3?E#waviu{HSQGyJ=_B=>|MWMPohi0`7Ox zAdcJ~HahEbI(4XNUJwd%=Z5d}g2;DyTNSk+a=N~o4!EC)dJOi5Ez>(B6sCV;rx-+a zZ>BzBCYsyp&Zh4k_E9@^gWiOe+KJ!mP9{GiO39$_b57nu7>*M}LFzbyZ3}ZFikzPN z#ODbwp_iA5q~_E4bOs6(J$QIu z)FYf8$IlOusJzGHJ~S5fUXq56SG1~!Em3(8chk__6K0ZE9Phel=C(Veug$cD{*iz0 zW1(}uHxMNVwBZZW^*kqxy#Ykgl9Y?`z1;YQ(8(O4+=-)Zu))uhq85SY0PN;2Jdz_- zCYNNl-~!;1rWnBjtZCJ8Iz2Z@9Otjff8JRAIDySpJKPO60)O>AKiN*>{;D5Fz5eRE zan$kq>D^wuy9$B)-fAazD=Ymc{2b<_g-BKa6xL+rnLmRW>LP0a+hV4=z%JnLD*b6x z+duWJds>++Wk<#VuzL+(4d1)?l7C`Dlb=e}k(QOmY@~Z^3}5JJPt^l9G{2)yFZ#-v zxmEJiZ8f=@mVXC7G=di!vtbQ2oWW5wtBnb3dZi5IR&A_})z8#n9baSE%oAwT*vwkle0FR@!z_zv&t=C^I}I7t ztrl8~*=bhL*9@m#hH42{mYUI6w?b^$c1XFk=u@@Zvw5fYruE@T%MiSs4Q zdoEj=NRwTfa2~6$vXo=^Jh&`p&!^|T+EW!Rx$2N-X=IIy`4P|ab-ye%FM;zF#EieE z&q4DGMa=ly@M?N7Q@nb5DZ5HB<3Js@C8jZhc|l^DGnf}8W^M*^O=4Ozn3p7Geg?B5 zF~?>ws}gg32J>=uWq3VnL+0nR#u)1}x{N)5NK1QecmsN0$*yNNx@vAy&?=*=*-Ey2 zK>HJR%U0#sD>Qbj%C)*VwP8VHX3w8drgmLWhOcBdvsX~wEcjh1_?^pMmi%tc@Ou@~ zkc8*5bJ;7?RiTxqo<$`yX_w%{<Hj_UGJ_8PF$v8jL*YC ze4_$zau;n?B1C1|_xp%l6vDzF;>U>K1c8$TP7zonaE8EX0!skx7da(LocM{7lm&r- z$DtDq!u}JXr%w=mZc%UZodk}PLgpRLW1a_{%QU!6-H4J1O?k+b-dR+Voq_Ko5LXA$ zq-x%aqm5ioQJYa(l#{;86JO{L5W4eNR_c;b5S#jl0y{hv+Ikp@stjuV7_z)blUHeS zEg6K+8(~QE_xgxMl9@yFCAj{u&nX2IMjAWF2r%?z8cNYbR5tw7O;gTq5b7mDmBGV} zJfU^ml&=!{0)dMF?bkTvsKU(C37KcKFt5{#5l9g?DgFq9Fce;qNeq1c#7X@M>UQdWeqx_|OWs!~kf zp$wURRMK&hnaVX(4zm&R?ixk+1NL3Gt_q*W*4p~i=kT{MUX*a;;{5MD46q96P#!SE9({~_4SnFURUb8w@#vsy;{HlJ z%ng-tEV?7hY_ai`a-blyLvEMVkQK~jbI8i*o?9Jjba!?y!xXpaeq@D%UK;d!Pprol zJfQWdwY%waAFUv-OIL zy|on!nRXH+X+rpPH!i4>YtJwtMjza|xo2B&6G6vMu1T>zxpf^VxQT(EOl2VAPj6Z8 zk%RyCq-%R0PDwxVahr>dNGu(l@4W#U>e0Z68v@rHpj%?a#+*)~4dZXm6m3w|Sy?_61u;+XG|2`B% zS|;rG{1ATBrKn7Xgy=ljo2+7h_*Hbclf{_4?pV7)x@ozEwIBZF2p>RXs}qN~ zse=I|$Jhzot{-e{rbkw_L#Yq_s3fZ&$Cpq#lPX`~DFcr=Og`qI?%;_mcQYD43krE_ z6V1Z11^YIWu7)o6Gv?*56aAs%l_UR1Gv!XE$@X9WaqBg`ArmT?!9Kz00SX@!@jD-Y!>fR3p=9{1vEG>i)^ zFxK79HU+ysV_;!-Xz*k0KH<_f=+Pu{Lm!VOrc=Bzpk8;Joq^jcdMb|N#U00yE6&M; zyhnmJiid-a-|L~5b_RCNMqbFxLIGMLzCb0RsA%ZA+SHAb#7-0UHU3j{-ofXx7)QoVQFi6a34a^R_#FbYSp0ngcM1Fq zK%pA=P}j_ z$v0JImFH6HUt;KQ@g;N)mFnVgynrCpXtFwcf!)Se#_roc{Xh3!+u$Jv@`yl8fZkVf z>r;5=Gy&fsAdh5W{i)e0l2i7do1N9q8NV=e`BK)NFY*d`iQzv7k!a=@t%>~{&b%hGH{;FxVe2hg Pprj5!cB;$2v+92VRwsH% diff --git a/model/__pycache__/conceptflow.cpython-37.pyc b/model/__pycache__/conceptflow.cpython-37.pyc deleted file mode 100644 index e8b3893e0ada050db25dce904bed346ae0341e7a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10130 zcmb_iOK=-kcI_V;K;uIY1i_CeiJ~M*uw_vnOSU|cY{{1F*b?pWPct1)Iv~11fx?IE zZbp>ydOk&GCL5=eNmWwWq`=)I%Pg~)RVs^Fq-K#tsu$U0r7W_@DwR!g?(Igi0V-LA ztiJEP`|i8%)A!w9<8G;BNcekJE?c{wOVYnmWAvw?@*W=V#{jlu%U#KgO4(G_xi1Y?YO`cKi%~mcht^=rAGtTzSrAuY`fFn45#k({f__a5Ru#U9ILf!H^b?# z?|uG5>x+kX9z3wtzj}E0-sj;Iqt@?z7$aJ=qR}6n75#bl0UXJcZ3$;o{)N<*ZN*l< z&6_Hn#b7AQ(9YULyM!|L3rUmAJoY>l7TzE9Th6wBzdLw(zm?$` zB#Fu_9u<%G=h(b2`LdAqm7#JV$&%27WT`eVmL}w_49W~j9i@&ki!zHchcbsUk1~(a zKxy~|p$KV9-v_h`eo?4sWw=%JO9DKJo_G=8tpc9nK5;JjastzRiT}KN9&m8 zmKL|FXw86{xL`F~IXh$P-_fdTaSvhjIq69MEc)v)Hfzgv_B*W4nv1a+8$7nMHArgQ zGfzE?%>X;uyYRn`TjX&#A>u|b;dm?_Ln}1w{C9LtOCucKWP~|xOL%{U_4_W#1q4yG zJH3|*>R&cA)9CcKcYVw2Ja^2(wmbTGpdM|k*#ht<&ZUx{aNk`*EnDKVE9oH?Q73M6f)9?6}73Mph zrfIuNz6fD6?3tX+_KVZouR5;rmtPLI|X8k_xoyw))5J?m`@x}Hl? z>E;NKjJi5as`UG2_KQK=?=|;A!;31GeK}RXVyYj0{#97?oju>eGU)b(S=-@?sXyrS zou(U>n!ew+T2?n|mZPRK(H=~+{iNL*^m&4Ke5pIw#3o1GQ_*n@+!)cwgc?xoP{+7- zXVWyELU&kvB4H7V28ttUxpiz4?>KK+t!~rvEbFhO|GKsIm50~3)^eMj&5pD7f#W^# z2it2-uixEXOI_==9(LPZ!z6Uk+qmc#x1YJQU|GiFT>_BgvaHLxl9Q8vr3~B&xk9C~ zC@(8@c@gbJWl^ok&uc0E7f0{BMdC+)-^WwJ^FAK$U+~I=Y|9H$prBMxswh>IP#(61 z63WBQpoH?Ub(ByZb`~X+hn+(SLd!hnAdmsa| zm{8}AQ0J09#X2eq;7Q~EQ-%R$rF9Sy=Vacw~OVnf64^+FhAhoHc1Sf?w zEQuLW6ALgSn$Sg79(@v{&!ScnE2yoCmEat&M;gwG0i~*PPK_l+3i!-n2e=T?& z+~x#MG`J8`&Wj7LpnM~^C{71!d?pKajpgOiSnKjw>&jT`s=qj_U@wqdI&C`7U~zBN zKPG1OWytzs@TOfCi{i}%$*zeKm4Bq%lwPPiKYk=Jr>s~Uy*-xHh#%xX;dku_ziTnS za*E$SCj73Y_+8^UXoY-wyu0hDOD{CgX}kI@OJQS@E(u7CEgu7m0OIxE#4c$-hPqU`2nOQlC&k3**iyM>`p9WnMB6) zohylq-8nqN;7)RaOQJ?MJ-EB~b#enf4&L9po1EJ9-~*8t??b9Sz>S!HqP)=Ezj>kW z{7HWv_vG%V=a2?>1-*E)`p!==yFNT6?%H#Bhe~iSH7|4(=Dq)a&RZC9=Xta4%8Trd z&hy?M^@xN9^P-pM{m*#b`|-RV^8Lwre~kCBA})y&;*Pj1u8MnD|A*rKLH?V3@KNwP z!NS-qwo8c#NodI9xzv*lZ^(+z9rg|8YrMT5R$Jsxa_ zQ-gjNb3Pr6*w0RoMZIB_1`wI=cDB2iyJLql>2-}3Q=ecMT z%1a1&`hrss;+$3;| zz-<65%)Sy!p^6g@<7qfXsvvT}5>tXITjmTQCj`Qr5G8sMM{tmAGe?6-e8cH9Ur4$4 z#EdTs)x$y@gpjD2g*Y^!M*6PcBsLx2EFRvTj%`-P7$a;jtLg9tAs1pmb7l;Ou#S-9 zgr48@9kVpLBpiU4rO~>WO;~AJ$;-A_K&xrYA2hqWkh|#??(#?`BC?q=R6;j6N|#1O zbmr{D@yEf@G_=C_H0I10DxN0U1Me^_kG2$@LRg}ubvv8=9^}G96DvqJG09<5?u50; z)$j|lXoe{M_QT2u5lv}bv?Ws3B~Ktyq2r*AuZVQv&ennjKl z84E7P2oL(=Xi^q7=rq!n6!t;|cS!S+poj8CDBEF`@=@bG=E9K?qm75NW8?`snUSu! zncLo`+vFjQ;DI$o{T5K+bh>5H^26#FDq7g=818Tm-f!R8q&pWPwx-)UYF0*MBaC~I zEg_&4NbRl`XBN6=hG09_lEN)5qwSw62+Fqa|Exc$fIU8o7G-qkj|>r+369tQ z85Lg|%7`7wzd;l`q?(L)Cll!CQ3a*`bVN;AksWIA6f$`91H|o{h}rjSKkMg2mSQc$ z9$yA|V2!=^BCJ3hANO299G|7OKrC;|h}xg;*uoeh+CRiENC&bk?bwebkwXt6DYQO@ zFJw^aZ(j+DdnZH=^VE?AC=My1SlUNcVJ;f?w-{F(mMCe-3A4;Po8ITO=(-^F6sJWl z`fyz2W+)qtvP{JvDQmwI^DUi=HOM6{@UP;Rf zv2BrKK!4;pB*JnRt_d=`czp=+k}&%%XtoYicbw4_iE! zD5d&=N%iv_Vk*@S%uNaF$blp8012z|FptY&;3?pOl#zZ)w<{3Nbi0VFGL5H-XJ+po zAj~aM$%4zT?F?GPNfSsLOmNbe(*`){%&C%anggdfaGDqM%t>*PUq>tAG=Ip6)`hD& zAO{Ob;mio5ts;QX*V$8!bJ;HC4%s_4HS4SOPa0MCE@93S_>{m& z0u2ImRYMidmiv%uIvgcxyyB|5KWEYoKH!j3Z5oG;J7<*}i_s+=pwrzZutR_yNOnHi z$K)&gLv)0Cl-+ebYUJ5CZM%aFQ|GF0j{AMwRUh-%;>!%rKXgXfs7;+4aB&BPg^P!l*&07d_-eGE&9k+3w<7TXI`4@6~!8| z&KRdp_COKIuvC-WrXn>6v9I7V;IkBR#ua3XGeSXL2f5n=b(p6zm_|Q7l%XFgLqLs! z3JTR$HdXYi$h1^BH(ibCtGIU2NMuFhS`9pkxYYEaD#KZdS;k~UKEc$*^x^)X#AbzF zlkh>TE^<@`bL0KE-b{gO%M`&P@&{E?t)nwy9y5R=J~h-MUdL#5JSAHf$H+2|&ZhwD zTthopN>R)mr5GwusVe1&U?`N}c#3*{f*J~Yg8D>?dSQaPe1!UBih6N^8p?ix_i~DQ zCf<21+O?>r&t`&CwChu1Nu0u2uMii3{5i7Hr)~W$DLBo1QxeEB8p|eD5!d)M&P243 z#WeP8jLnTikFsoyj2CwfN*7q9M&ZG7^*BYo@5NE_W?Ba|# zE6$@li#%)tCGzLu^f8z$5xpC4;AOFyp!*(X5h`;6( zoX%uZj7kVOrXy6yHR2!Uqd1R($s$z~Vk0s$Y{lX_=YO8lO8=06AP|hNnPW$dFS+Z7vCdJ=zJmhqQnx2^JYU z7h_?07<&x!<un9N>ewV-#0;DI*Vj{Ve z9n7{Y&a#_1BveMVK~nRRn#J*?XXZhdR0oGu{4#_BPR}eNhnJ+dBJK8=CMOk#(G(oB z189KJVSO;FyImrlxIF7LmM#y4ywT<13xF%Y9v*x=z6$XLqDxg*7AHKpB)iv;?>7wk zSxLXybMUhg!?N&mhTSgJixzg=>_$EOg9i6!v>*!F+2M5QBD=JpuoVAvqUChEmUX!1 zhU}7$=(3YaKVd21@yY;5ovh@3BAw06<}T&#@XuQsvr%e*a{!-!#w-_Pm_cF?sRbQ0__i?m-B{{pE-YE(&bp)* iLj57{o|kjDi>O9>2y-{1^SMn+@hSkMa*E0SlJ-BSHtLxG diff --git a/model/__pycache__/embedding.cpython-37.pyc b/model/__pycache__/embedding.cpython-37.pyc deleted file mode 100644 index 2d0f57b3c5125e41c68778c3e240202265997bbe..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2038 zcmb7FOK%)S5bo}I?Cj3&`VkU45ky2h2Za3tcnD%lF1CeYCy0?yYck!oXPKQ{chAJJ zwVZ4zr%3#S^??JwfCJ)pAesXQPC0Plh=fd4&+a;LkPyA<>gwu8RabrA?CnmcMWFrp zUC4j%3Hco-)8>NmF?5}S5Jb?N#PvvH26L7>eK&S#ZS(qm>|5Q*y?!GOKzHM&@Zy&6 zMdOsj;T<9Z(L5!hDg9pi#ECm#X^9Xl;l$Dvbelx&r{IxnMbrfM6O|!W@kh#j>wpvk!>=9E}%zr zD{_oi4ph2t+_VyjX=j=z50fmPICLH+>U*LBxaY6;SMO=5^lGY-Y%i0mpGy68IT)_W zevuDX*9Ju@hvjGa;Bd7+5Hbh*11UsS?5zxsOi*db(M5=HL7pMf77duf%mG1;Yp3U(UAj&OtvxpW^Bm&6`C!!vX1MVXaH>r;>E^b#D=8IJl}5G5(; zh@3FY2qhynBHSt6o;RXBU(oAN(j~}G&jDx769-lsWiW1z=o$Tng4Ti_3OLns&e(|V zGC|K6thnQL*_rMz?cEM=A}*NMH~u=VEi3f=z%g$nv}*nmTI6)^YcU{eOZDZtfcV| zWwy6pnvRkORR&SII9Qp!Iz>ugLm*&ZgRYlB04azqFO+0VM+^$h zMyw_WQ@BID3LsWVLFquvK5c2(WYkZHLy9!erJk z57_<>(>h-al$ZhL-+8tJcm5a8EDwOt6OOs#plp_{afqdu@uw}2FO6QCcPv-jnk!I? z4!BTV6muYUqOL+vs0al@#DtbNHOkdhG+#rpglpYHc*8T`ei>H~y%y%`I%-%vbrZy- z@TPTfNWFvhNwQ~kw|?J%6^&N`G?`BkmdDqgOWBK%L}+3u)}Y(7Q%Kezr0QK*h?vEn zX&qEj9r02=Dlg&AV&gr$!(yEtUwz*13~dU7RBxhq3kBxO9#(IoW*_tdoG`NpzA2mn zo9k$<{xW#*)q;;aD~6_P|4vqtsvIEm6$WF1+E^5IDe$!>C<<=>&dkF7^}Ap4?afa% zHu%=P&D-mD&73uDt>2j#CLfngOAU+Uy4JH`FFp|__@HLcN9bB3E&>V!z#ns|%L0G# F%0Dxf;C282 diff --git a/model/__pycache__/model.cpython-37.pyc b/model/__pycache__/model.cpython-37.pyc deleted file mode 100644 index 838e5f19f435d2965e0bf9b8cd9a7da3a0e2a1ba..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9818 zcmcIqO_1A0cE%q;68sqs$>IFU8jVJ>XvzAuC0mwcOR}W3Esg)JEY=RfAwY^F5}W}r z)<|h2m6e@6sY;bgk{c%{C&yGyNo7x|oVRjGf*f*4B~{w0-1d~Jl=HpD&w!+oTgbxe z*RNl{erR;Rp9XG}N}7P*kv}UK|MamS{2Nu`p917PT<*^SOu-cUf*zKlE}<^<<$>`-MSKFLFHFFAd6ic`&2Tpq|q!W?r8)HM8)Y zpjYn*X3;EtCzvHm-Kss6^|~p3E;P$8pog&06ocGXEvMJo>|4MyyI!yF0x93LY|m-+ zZ`tjkX*n-MFoRtDc-OO>h}e{Z(*2=h-WqILrrEQ*!OSh&>v@k8Lica}_8_WKmI5@_|M8WLYqXFvisZr8O85?>~rQ~*KL#GK@mfVHU_|Ktd9uK=+G zT{Hz<`i0ODP0^IT%jvQyn+m31p+|)>V`j}9O4ZDp8p`Z11VzwuxV1u1NNLojbwWQX zv$$v=_c`Xu6FiX#PbHcmJZ0d9%)~7PrGhepGJ{e@siMrH%%aSp%%RMq%zGM>n6Q<3 z3TS9vfyt;TyjJjv3^IwIWRcu04OihQNi2G$h*sd-C9k|SV@js{jP$7RniAD!QNt7D zQqXrROH=TS^sGfKM9s|O>5{go_Bz_6JJ0&uE^OV;Xi;NLk@)vaoU6$>8Z=TgYZ_@PgLz7fnUi zdiM6NXSls5mR@+!Gfm45E1JDKF#4^>Seklav*ooP7@eNudRW?po^5ojmbdFzZcrSe zds0(2doA|`Er+J$%%X6r0551=G!F$CcX!)z99pkI&M9ySdN9I9u%;V zZP(aEn;R4?+O1I~6VlXBsX*)6cBEjCbsrBlhke(fMd;)RkPZ%Qia@e$J^T5v;|*GS zf#!yYVV+Cj=XLqcXI}MboFk}wpvb5YI&Y*w2gjPEr(TWqCT9c zdr7@LwD|y0|5AV0#ZAV|Gp^OQ+OQfCp{WE4NS#1MzfP~KYmZ>Q3_c@45e5&oE3Dy} z<_)9WZ@I2vaHHZB>F(zN8o1ok0D@Q+RXHc(FNLrS$}3_;LMbnat5QQ;LVZbEQtIN9 zdTN`WjvrN+W2Fmb`Jv?e0Lly~Gbk%4E1vjN@Mrhd zScPoP`@%OBzsjm?mQ`4FQMe-P{Ly{ER!s0PWlR2A@@sybX&jsL8>|RS@#p+`f5D$) zxrn~#DPG1?M_H!qi=&)5yDyqmv-V7SD*B42FvYAtltG`{7eQG{xN}q7xui{+YLk*S zd8$p0+VU(j)s}$<(wGLziKP4#cRp#;rrNZmtuWPANZN{1ZAB(;$z}udzrv&lJH%xF z70emQR7V9?GaKFHo}Y@Kp%+^*8w&!pn2mketS<^3!X*DNylIKmSe-4x`YTLjS(anV z`+0vAEoBBt{t;j^>`0_n1#4m!J%p|5zWfx=VPC@2UGtBEs|I?J&F{BCGD^ICn0T~VMhH^5$6Is^%Bn4 z{nPBIe}>P=qFLwb3EyYieuhvi=_CO-<;*`9Lk-TzB-ajPrE~eyN zgWi&Y0l$mN&fBBxUftZ_s;)0ZXuS=$IUtw$ovf-_fFix5*|M8Z}_4Q8FcU>~rL z*j;vyZLn*jqx-^W9oMmafql$AVV|;h)3lC;EaiZob95-IU)%_^on6a$Y*2s_6cB+u zvYfC|><`uFM0KC&`dxm56 zgF0FfLPlZB(6)>R!|h;ZX!kMZqhTz5Zh|dr4XV_E*gU7V-N)EHGpMEYij`?h2#n}o z1X&86&ZApCQ~~Id3Qy5a-yPcBFD%;~I=7pmUUICR-5&1J=sGRap}go~)Kf!G=eoZ>y+ zwgNSj7|3nRpMo2uuG{^d7lyDrqFp25nJ1Prbe$EFwM>AbTW6iXF#^X4@YwfNU_r(j zY~z7R_I8j#NZqx9{I2cdDO-?plsNJTU}-jnHUj9dA;=#@;t_uyf zcY{o*yXp=5J!Au)IVT}X*TM|JYBWN$(=+wTm|(>6dNrMMAP+(|rq{-3 zNU(@F&ggnA&(cfrzTjDbUW)bNVuDKBNS?64*JL$K`L@;H#bTRnWBbMnA!MtKkrKY) zG+&w&^67IE(;sClvoHkX!|1g!QZ!6*5Bv^;a(qkSECeM|t>5d~1FQ=dRcMf|VvwY& z7rM^yQEXboUQnM@jX%w7l$Qmf7tF-Ng=w)DRANFnz5$jo4Gd488PkDtq_qGc%-njm zY3&8MIL*_GJfGuKNYP?C_>1wNeC5HWk+r166ehSwmX}t2AZ`Yt8C2;MXuL&VoazzZ zcrZ7{o{*CfX`}18?QPmx*z!deqA4#oKniBlHG|X-s$-;3vAHqY!5I90ZL3QgFhKT8 zTQ{y&Vz!tP=CO1_#*^6~+Zj5KT8>E;!u>UXK&kA@{}PVD@K=%GRwV^B4VWgch*fbV z`X`u7t%`VAUCLJWftCvDD>6wWNmVYpDyq`DR35KEi#6~=C-F%wos38G&?ZYyKRS#g zC&6>^8}KwZKf|dLax$D5jYNi1N)x#P;!m8ajc`UaLRKErK-IhqPPQ~00m|P?rr1IL zKFU%L&W$`0r&W@vJfp0t;1|cZr4+Z6;w$;(F}_TSPfqd8_?7XISpjEUDP67{)c9De zlxhbh)y~IYGpTk^t_aX02Z{U+aK2C;r#Ow)giq zby7 zo>GG~nDQ_SDZeyT)JT%Tkl{^8%5X`Wkfd@+l|<4!B+Wz80$bpcBqs++3z4LS14*P8 z;?5Z_LzyJc(2%W-kUe7^EF$BsF|8vb9?U-ZOKmm!#kMFmfR|T)7CL+K!lzcxom!<+ zs?+>%?wh37gpr)+H0A~2H&gw_P3J_v#7J<1_n2JnL02BD@xTl(3Y-~i2_B&7it**f zN6o5pohVHL_Xw;Kc#QxBW`PXn?|ej<3dc>APpOSgx}47md>FQEZ(Ft*$UVGppwLd& z4#G!%`q*6J5s}lUJ_7Qa^~P?kU=QQmA19L) zoJnSwgp&iD5bn#PJeB?|+VMsN?XYwV)*oE(1E$oK(JmvyuksVnYQ$fKwFrCS>@C9T zkWqv!%EDr2cy>I;DH)beC=E7yaDTA9a|~}O@V2DEa#Z^BcB%;%QQN!lBu%gdDzEzehLE3B<^}8#W6q4F_Or)(Afn1o8%u#aW718 z!_Q1`zmno!oZv1`aUV``FHLa6M@`6HO>x(vJFka#&8q3SnUECT^%1thj$p3WND9L# zEzUlVn(8INKg#8%Rv=%a(b|L-5yod@J>Ar`#rt`xLxUhfjfRZ@wP-gElqk1a=6R1Udvf z0KFbLv@mImB7A3;n&u~(Vk64khx1cnDTEf7vIRFzF?of@=&=h4DV#O7*&!*B_>}bd zjwF_nY=o@7UQViEcO(MQyQ;XaJ{ze?J||32lAGeBgpnr~L<%@Y@`HSsa8OoLBs5{x zLQlgTS%gzw;yfZv_Xs>D@MK(Pp2R3Tfr-v9_=!Od2O8mw$bBKS9pR-NjE%4bCxH!Fnxd4aC^J)(iYHToL4i74pae%mb(z6L@_Lom zGZd)vdX3lN#-e&14l9F0XbO-Cs{(6qjDpm8j^XtQ97)4dS$>2U_T*^cNE>Vc6fH$r z;Iil78sUr?Pf(!IBBwSYD!q&+XRySn^AVL^+!J*f4ebAn-c`S_C!;(1LRw5TGzqFD5IO&aAR+gCF)(7IVPIGH%~Hli~YU4x(ld>G?SkC-Pm&t2+En};lFxOw;n zz!u;^9>P4n6`~D9o2nr#O$2kvnNu@Q-yGPjfrW1lG{eBxG`oGmiw17D)el?vOLOP< zNE?bZ9ooPSDH4>TSL|)8-;ZCjlfuX~=>2;5%DoJ14wp+W;B#W}Snl_Qlesr?*ZD^e z%{iWHgh|B*;LywA$#QGg8}j33r0@99CM4Rna~M?z;U;MK*nxlUgl7j!i;tc7@oK9=0@All2j4;HvtQl%?5D=(_paZ)YixXZ@5Zgq^chatxOFEY@I)q1@9_!Y iil&|feIRdcajAV$)5QGlSAl(g9p)Bc8KqI#OM7LU#$`Y4mxFYG z(KEz6=HDRZbMIpBwVe(bK0`@x2fom0f_3ncikzMo9CWu)c_HsGm=ag_f35|al;@nW zyqf8~Nu_di^$rLc$w}5Qs`p<`p8rbDPQE`rCZ{h>o=l$Wy$(8^Jlg=7@wXlJMv-5E z!cmGDO073&iWz3sO_18mX3h;toj1rqsSEMCI{v9qTuiF0W?W?38_a6zXIA+O2!#}G z(Q9jE-6D)yt3||a0U?fZFT56AI1G=W zD)ey9DI<#J1u53ZL0xg8E^ETO(r;n^te&cp{-GneRtWnhv9%+4F%|p9$8d9iZ`p&$ zLPD~FN=e9H=-)@<7m^D(&IHY8IUoPPDl_R47;gUrA(f>&^B|7I`UrML9K=C*F@Pc1;LWdM6%cICIhErzop=t1j& z7J(&fqc-f$Tgx6;qP7Q?xU+IUfk0Q{c6$ zv8!s4&4KAfkt=OyRB1=iYA#0R+yMw}%cc~c7%gln`KsaknkNJC&>+z{RoR@#{F>_r zTb`B37X_CjXCGpuYRdG}5C6C0S39P8(m?Qns=TgDD0z{FyOv51b}H!_a9WF!vvjmQ z@g5p*a>?`AT%|EB7IV7U>Hao=&awF1WZH>Cladd2UQ)q6f{lCt0^u*LLp;PoE3$?- zvOc!1NAGpX=0M_^M^e#}6QUzRpy!&x=rJL$8d|I^UGIsnji>&mmor`z@UC`oBv^a~ zc45Nl9$yHGG%Xn5IP8#QUwjFB?Lt8pSMUHAdeAX>L>pDl1ZBd!a~-U~suCs>VVXe) zRYkf_qz-qL7w|bpU4IDE-XbuH1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T05Ukb<7ECrYoG%dkgXK7bi|+^IG!Yg~&RIt4bz}CmdTc8Kf(H(K}VQIeazOajm7?I91UT}kFyluzB!y@~9z_}fcc*PUO z@BICE*mJM{nA!P56qN!}Knh3!DIf)Yt$^2FT6dMGQ3^-_De$F$e;*p%aVVS;Rxc&ogja7s)%ENild85Z9_O{T(eLS=^Gs6wc7wvzp~i_XEn#h++gMp#Hllj)xaCfW_9o; zqc2Fm@di=JuurA^wZ{YRH1shy=UEV@Bek^p6gM#rG14B0KpGSJ0ElV_f_iibZ^AHU=sLy zM}9iZ;$+XCMPagc7)K*N$vz9?S9?WTCEZOfWi9a-3!({*jHE@fl(#S+86fIY;gzyT zyn!qxBsUhnW+X$}SNM;}H$Ua(;`>q}%x!F>kcPRn7{8MW>rOMwYl13EB@;CaGPSr6 zWLXONJ*&s$;;A~c_oXM~;t|IB30ZtBR6Yl}EsSeT*s#}Bbfy>7*d`Ml$E_oGTOGGL z>KjVQ#qTG|>VIv!B(DoY)ZwSW&$#tNiOiui-Qlz0v9wf;+|MiiMa94D>(b<2H1T;2 zM<^{XN&JZQ8?uh_1=;06z@(M&APIe`NAnr4!_28i{;Ntz14+*%ok|@7r-5R*rKTu* zNt!D3T2;qN?FLMV(@hXU4Z2P@=|kG3A80+TK@aE-Vxc&gq0? z8lyjJxweG#C1g}cFD%U^owbB&OUPdGtN%h5t;`VQ+*DO4`sZ!lKu0vT$xH2}o;UMW z-UjdFEBR_>iS`u*UA>}HE&VgAy;caad@Vz^*S1(=x+p4qgKo^%v0DAQA=YjvGsOCj zSR>znT@y7W=PmedTv5>|_N(6%JBCgdO|h{}fU##_L$oT4T?3nSJT9sh<1p!$aP2GNW#V)7kZ=oR*l$ZMN~Gy=?p;bf zz>8vz!x-__pq6HO=E1(`+|H8}?~96j%z4ZUJ1`5)&ElDC!frM%m*(rr*%;5mu-}!< z>8b0na~D~rvNiTbnag;b#PiH)D%rkUWa-pOuzx!lIOexrnC_QWgZowkx3U_zmDPZ& zK6fm{*oU&@1IIvw5-ArYRUS$@aBSQ*SEb{$)okkfNeLS>FP*wU#Qep7@7U_;JP4Wl zZb#}0rMn(JuAuNi6Rs9w3YWxvA`YmZ@F&~{t6A2`jdfz?a+h9MQ#cUq>pQvbQ2gF8Z LE%j@*Zm;|cwIj%> From 52479c8333e772ad2960d18330dec940a2a41a4c Mon Sep 17 00:00:00 2001 From: Ilya Lasy Date: Fri, 26 Feb 2021 11:34:01 +0300 Subject: [PATCH 3/6] create smaller dataset with fixed size in GB --- split_trainset.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 split_trainset.py diff --git a/split_trainset.py b/split_trainset.py new file mode 100644 index 0000000..ddf045b --- /dev/null +++ b/split_trainset.py @@ -0,0 +1,15 @@ +import os + +def to_gb(b): + return round(b/float(1<<30),1) + +max_gb = 1 + +with open(f'./data/trainset_{max_gb}gb.txt','w') as f1: + with open('./data/trainset.txt','r') as f2: + for line in f2: + f1.write(line) + if to_gb(os.fstat(f1.fileno()).st_size) >= max_gb: + break + + \ No newline at end of file From 242a48d29513fc5daac2ed1cc523fc84041915cf Mon Sep 17 00:00:00 2001 From: Ilya Lasy Date: Fri, 26 Feb 2021 14:18:49 +0300 Subject: [PATCH 4/6] switched to IterableDataset to avoid RAM consuption --- dataset.py | 11 ++ preprocession.py | 381 ++++++++++++++++++++++++----------------------- train.py | 20 ++- 3 files changed, 214 insertions(+), 198 deletions(-) create mode 100644 dataset.py diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..c678238 --- /dev/null +++ b/dataset.py @@ -0,0 +1,11 @@ +import json +from torch.utils.data import IterableDataset + +class ConceptFlowDataset(IterableDataset): + def __init__(self, txt_file, config): + self.root_dir = config.data_dir + self.txt_file = txt_file + + def __iter__(self): + f = open(f'{self.root_dir}/{self.txt_file}') + return map(json.loads, f) \ No newline at end of file diff --git a/preprocession.py b/preprocession.py index 8d65461..02c97ec 100644 --- a/preprocession.py +++ b/preprocession.py @@ -3,8 +3,10 @@ import json import torch from utils import padding, padding_triple_id, build_kb_adj_mat +from dataset import ConceptFlowDataset +from torch.utils.data import DataLoader + - def prepare_data(config): global csk_entities, csk_triples, kb_dict, dict_csk_entities, dict_csk_triples @@ -17,22 +19,9 @@ def prepare_data(config): kb_dict = d['dict_csk'] dict_csk_entities = d['dict_csk_entities'] dict_csk_triples = d['dict_csk_triples'] - - data_train, data_test = [], [] - - if config.is_train: - with open('%s/trainset.txt' % config.data_dir) as f: - for idx, line in enumerate(f): - if idx % 100000 == 0: print('read train file line %d' % idx) - data_train.append(json.loads(line)) + return raw_vocab - with open('%s/testset.txt' % config.data_dir) as f: - for line in f: - data_test.append(json.loads(line)) - - return raw_vocab, data_train, data_test - def build_vocab(path, raw_vocab, config, trans='transE'): print("Creating word vocabulary...") @@ -103,180 +92,198 @@ def build_vocab(path, raw_vocab, config, trans='transE'): return word2id, entity2id, vocab_list, embed, entity_list, entity_embed, relation_list, relation_embed, entity_relation_embed -def gen_batched_data(data, config, word2id, entity2id): - global csk_entities, csk_triples, kb_dict, dict_csk_entities, dict_csk_triples - encoder_len = max([len(item['post']) for item in data])+1 - - decoder_len = max([len(item['response']) for item in data])+1 - triple_num = max([len(item['all_triples_one_hop']) for item in data]) - entity_len = max([len(item['all_entities_one_hop']) + max(item['post_triples']) for item in data]) - only_two_entity_len = max([len(item['only_two']) for item in data]) - triple_num_one_two = max([len(item['one_two_triple']) for item in data]) - triple_len_one_two = max([len(tri) for item in data for tri in item['one_two_triple']]) - posts_id = np.full((len(data), encoder_len), 0, dtype=int) - responses_id = np.full((len(data), decoder_len), 0, dtype=int) - responses_length = [] - # posts_length = [] - local_entity_length = [] - only_two_entity_length = [] - local_entity = [] - only_two_entity = [] - kb_fact_rels = np.full((len(data), triple_num), 2, dtype=int) - kb_adj_mats = np.empty(len(data), dtype=object) - q2e_adj_mats = np.full((len(data), entity_len), 0, dtype=int) - match_entity_one_hop = np.full((len(data), decoder_len), -1, dtype=int) - match_entity_only_two = np.full((len(data), decoder_len), -1, dtype=int) - one_two_triples_id = [] - g2l_only_two_list = [] - # o2t_entity_index_list = [] - - next_id = 0 - for item in data: - # posts - for i, post_word in enumerate(padding(item['post'], encoder_len)): - if post_word in word2id: - posts_id[next_id, i] = word2id[post_word] - - else: - posts_id[next_id, i] = word2id['_UNK'] +def get_data(config, word2id, entity2id): + collate = _build_collate(config, word2id, entity2id) + + train_loader = None + if config.is_train: + train_dataset = ConceptFlowDataset('trainset.txt',config) + train_loader = DataLoader(train_dataset,batch_size=config.batch_size,collate_fn=collate) + + test_dataset = ConceptFlowDataset('testset.txt',config) + test_loader = DataLoader(test_dataset,batch_size=config.batch_size,collate_fn=collate) + + return train_loader,test_loader + +def _build_collate(config, word2id, entity2id): + + def gen_batched_data(data): + global csk_entities, csk_triples, kb_dict, dict_csk_entities, dict_csk_triples + + encoder_len = max([len(item['post']) for item in data])+1 + + decoder_len = max([len(item['response']) for item in data])+1 + triple_num = max([len(item['all_triples_one_hop']) for item in data]) + entity_len = max([len(item['all_entities_one_hop']) + max(item['post_triples']) for item in data]) + only_two_entity_len = max([len(item['only_two']) for item in data]) + triple_num_one_two = max([len(item['one_two_triple']) for item in data]) + triple_len_one_two = max([len(tri) for item in data for tri in item['one_two_triple']]) + posts_id = np.full((len(data), encoder_len), 0, dtype=int) + responses_id = np.full((len(data), decoder_len), 0, dtype=int) + responses_length = [] + # posts_length = [] + local_entity_length = [] + only_two_entity_length = [] + local_entity = [] + only_two_entity = [] + kb_fact_rels = np.full((len(data), triple_num), 2, dtype=int) + kb_adj_mats = np.empty(len(data), dtype=object) + q2e_adj_mats = np.full((len(data), entity_len), 0, dtype=int) + match_entity_one_hop = np.full((len(data), decoder_len), -1, dtype=int) + match_entity_only_two = np.full((len(data), decoder_len), -1, dtype=int) + one_two_triples_id = [] + g2l_only_two_list = [] + # o2t_entity_index_list = [] + + next_id = 0 + for item in data: + # posts + for i, post_word in enumerate(padding(item['post'], encoder_len)): + if post_word in word2id: + posts_id[next_id, i] = word2id[post_word] + + else: + posts_id[next_id, i] = word2id['_UNK'] + + # responses + for i, response_word in enumerate(padding(item['response'], decoder_len)): + if response_word in word2id: + responses_id[next_id, i] = word2id[response_word] + + else: + responses_id[next_id, i] = word2id['_UNK'] + + # responses_length + responses_length.append(len(item['response']) + 1) - # responses - for i, response_word in enumerate(padding(item['response'], decoder_len)): - if response_word in word2id: - responses_id[next_id, i] = word2id[response_word] + # local_entity + local_entity_tmp = [] + for i in range(len(item['post_triples'])): + if item['post_triples'][i] == 0: + continue + elif item['post'][i] not in entity2id: + continue + elif entity2id[item['post'][i]] in local_entity_tmp: + continue + else: + local_entity_tmp.append(entity2id[item['post'][i]]) + + for entity_index in item['all_entities_one_hop']: + if csk_entities[entity_index] not in entity2id: + continue + if entity2id[csk_entities[entity_index]] in local_entity_tmp: + continue + else: + local_entity_tmp.append(entity2id[csk_entities[entity_index]]) + local_entity_len_tmp = len(local_entity_tmp) + local_entity_tmp += [1] * (entity_len - len(local_entity_tmp)) + local_entity.append(local_entity_tmp) + + # kb_adj_mat and kb_fact_rel + g2l = dict() + for i in range(len(local_entity_tmp)): + g2l[local_entity_tmp[i]] = i + + entity2fact_e, entity2fact_f = [], [] + fact2entity_f, fact2entity_e = [], [] + + tmp_count = 0 + for i in range(len(item['all_triples_one_hop'])): + sbj = csk_triples[item['all_triples_one_hop'][i]].split()[0][:-1] + rel = csk_triples[item['all_triples_one_hop'][i]].split()[1][:-1] + obj = csk_triples[item['all_triples_one_hop'][i]].split()[2] + + if (sbj not in entity2id) or (obj not in entity2id): + continue + if (entity2id[sbj] not in g2l) or (entity2id[obj] not in g2l): + continue - else: - responses_id[next_id, i] = word2id['_UNK'] - - # responses_length - responses_length.append(len(item['response']) + 1) - - # local_entity - local_entity_tmp = [] - for i in range(len(item['post_triples'])): - if item['post_triples'][i] == 0: - continue - elif item['post'][i] not in entity2id: - continue - elif entity2id[item['post'][i]] in local_entity_tmp: - continue - else: - local_entity_tmp.append(entity2id[item['post'][i]]) - - for entity_index in item['all_entities_one_hop']: - if csk_entities[entity_index] not in entity2id: - continue - if entity2id[csk_entities[entity_index]] in local_entity_tmp: - continue - else: - local_entity_tmp.append(entity2id[csk_entities[entity_index]]) - local_entity_len_tmp = len(local_entity_tmp) - local_entity_tmp += [1] * (entity_len - len(local_entity_tmp)) - local_entity.append(local_entity_tmp) - - # kb_adj_mat and kb_fact_rel - g2l = dict() - for i in range(len(local_entity_tmp)): - g2l[local_entity_tmp[i]] = i - - entity2fact_e, entity2fact_f = [], [] - fact2entity_f, fact2entity_e = [], [] - - tmp_count = 0 - for i in range(len(item['all_triples_one_hop'])): - sbj = csk_triples[item['all_triples_one_hop'][i]].split()[0][:-1] - rel = csk_triples[item['all_triples_one_hop'][i]].split()[1][:-1] - obj = csk_triples[item['all_triples_one_hop'][i]].split()[2] - - if (sbj not in entity2id) or (obj not in entity2id): - continue - if (entity2id[sbj] not in g2l) or (entity2id[obj] not in g2l): - continue + entity2fact_e += [g2l[entity2id[sbj]]] + entity2fact_f += [tmp_count] + fact2entity_f += [tmp_count] + fact2entity_e += [g2l[entity2id[obj]]] + kb_fact_rels[next_id, tmp_count] = entity2id[rel] + tmp_count += 1 + + kb_adj_mats[next_id] = (np.array(entity2fact_f, dtype=int), np.array(entity2fact_e, dtype=int), np.array([1.0] * len(entity2fact_f))), (np.array(fact2entity_e, dtype=int), np.array(fact2entity_f, dtype=int), np.array([1.0] * len(fact2entity_e))) - entity2fact_e += [g2l[entity2id[sbj]]] - entity2fact_f += [tmp_count] - fact2entity_f += [tmp_count] - fact2entity_e += [g2l[entity2id[obj]]] - kb_fact_rels[next_id, tmp_count] = entity2id[rel] - tmp_count += 1 - - kb_adj_mats[next_id] = (np.array(entity2fact_f, dtype=int), np.array(entity2fact_e, dtype=int), np.array([1.0] * len(entity2fact_f))), (np.array(fact2entity_e, dtype=int), np.array(fact2entity_f, dtype=int), np.array([1.0] * len(fact2entity_e))) - - # q2e_adj_mat - for i in range(len(item['post_triples'])): - if item['post_triples'][i] == 0: - continue - elif item['post'][i] not in entity2id: - continue - else: - q2e_adj_mats[next_id, g2l[entity2id[item['post'][i]]]] = 1 - - # match_entity_one_hop - for i in range(len(item['match_response_index_one_hop'])): - if item['match_response_index_one_hop'][i] == -1: - continue - if csk_entities[item['match_response_index_one_hop'][i]] not in entity2id: - continue - if entity2id[csk_entities[item['match_response_index_one_hop'][i]]] not in g2l: - continue - else: - match_entity_one_hop[next_id, i] = g2l[entity2id[csk_entities[item['match_response_index_one_hop'][i]]]] - - # only_two_entity - only_two_entity_tmp = [] - for entity_index in item['only_two']: - if csk_entities[entity_index] not in entity2id: - continue - if entity2id[csk_entities[entity_index]] in only_two_entity_tmp: - continue - else: - only_two_entity_tmp.append(entity2id[csk_entities[entity_index]]) - only_two_entity_len_tmp = len(only_two_entity_tmp) - only_two_entity_tmp += [1] * (only_two_entity_len - len(only_two_entity_tmp)) - only_two_entity.append(only_two_entity_tmp) - - # match_entity_two_hop - g2l_only_two = dict() - for i in range(len(only_two_entity_tmp)): - g2l_only_two[only_two_entity_tmp[i]] = i - - for i in range(len(item['match_response_index_only_two'])): - if item['match_response_index_only_two'][i] == -1: - continue - if csk_entities[item['match_response_index_only_two'][i]] not in entity2id: - continue - else: - match_entity_only_two[next_id, i] = g2l_only_two[entity2id[csk_entities[item['match_response_index_only_two'][i]]]] - - # one_two_triple - one_two_triples_id.append(padding_triple_id(entity2id, [[csk_triples[x].split(', ') for x in triple] for triple in item['one_two_triple']], triple_num_one_two, triple_len_one_two)) + # q2e_adj_mat + for i in range(len(item['post_triples'])): + if item['post_triples'][i] == 0: + continue + elif item['post'][i] not in entity2id: + continue + else: + q2e_adj_mats[next_id, g2l[entity2id[item['post'][i]]]] = 1 + + # match_entity_one_hop + for i in range(len(item['match_response_index_one_hop'])): + if item['match_response_index_one_hop'][i] == -1: + continue + if csk_entities[item['match_response_index_one_hop'][i]] not in entity2id: + continue + if entity2id[csk_entities[item['match_response_index_one_hop'][i]]] not in g2l: + continue + else: + match_entity_one_hop[next_id, i] = g2l[entity2id[csk_entities[item['match_response_index_one_hop'][i]]]] + + # only_two_entity + only_two_entity_tmp = [] + for entity_index in item['only_two']: + if csk_entities[entity_index] not in entity2id: + continue + if entity2id[csk_entities[entity_index]] in only_two_entity_tmp: + continue + else: + only_two_entity_tmp.append(entity2id[csk_entities[entity_index]]) + only_two_entity_len_tmp = len(only_two_entity_tmp) + only_two_entity_tmp += [1] * (only_two_entity_len - len(only_two_entity_tmp)) + only_two_entity.append(only_two_entity_tmp) + + # match_entity_two_hop + g2l_only_two = dict() + for i in range(len(only_two_entity_tmp)): + g2l_only_two[only_two_entity_tmp[i]] = i + + for i in range(len(item['match_response_index_only_two'])): + if item['match_response_index_only_two'][i] == -1: + continue + if csk_entities[item['match_response_index_only_two'][i]] not in entity2id: + continue + else: + match_entity_only_two[next_id, i] = g2l_only_two[entity2id[csk_entities[item['match_response_index_only_two'][i]]]] + + # one_two_triple + one_two_triples_id.append(padding_triple_id(entity2id, [[csk_triples[x].split(', ') for x in triple] for triple in item['one_two_triple']], triple_num_one_two, triple_len_one_two)) + + ############################ g2l_only_two + g2l_only_two_list.append(g2l_only_two) + + # local_entity_length + local_entity_length.append(local_entity_len_tmp) + + # only_two_entity_length + only_two_entity_length.append(only_two_entity_len_tmp) + + next_id += 1 + + batched_data = {'query_text': np.array(posts_id), + 'answer_text': np.array(responses_id), + 'local_entity': np.array(local_entity), + 'responses_length': responses_length, + 'q2e_adj_mat': np.array(q2e_adj_mats), + 'kb_adj_mat': build_kb_adj_mat(kb_adj_mats, config.fact_dropout), + 'kb_fact_rel': np.array(kb_fact_rels), + 'match_entity_one_hop': np.array(match_entity_one_hop), + 'only_two_entity': np.array(only_two_entity), + 'match_entity_only_two': np.array(match_entity_only_two), + 'one_two_triples_id': np.array(one_two_triples_id), + 'word2id': word2id, + 'entity2id': entity2id, + 'local_entity_length': local_entity_length, + 'only_two_entity_length': only_two_entity_length} - ############################ g2l_only_two - g2l_only_two_list.append(g2l_only_two) - - # local_entity_length - local_entity_length.append(local_entity_len_tmp) - - # only_two_entity_length - only_two_entity_length.append(only_two_entity_len_tmp) - - next_id += 1 - - batched_data = {'query_text': np.array(posts_id), - 'answer_text': np.array(responses_id), - 'local_entity': np.array(local_entity), - 'responses_length': responses_length, - 'q2e_adj_mat': np.array(q2e_adj_mats), - 'kb_adj_mat': build_kb_adj_mat(kb_adj_mats, config.fact_dropout), - 'kb_fact_rel': np.array(kb_fact_rels), - 'match_entity_one_hop': np.array(match_entity_one_hop), - 'only_two_entity': np.array(only_two_entity), - 'match_entity_only_two': np.array(match_entity_only_two), - 'one_two_triples_id': np.array(one_two_triples_id), - 'word2id': word2id, - 'entity2id': entity2id, - 'local_entity_length': local_entity_length, - 'only_two_entity_length': only_two_entity_length} - - return batched_data + return batched_data + + return gen_batched_data diff --git a/train.py b/train.py index 92f3444..ea9aa43 100644 --- a/train.py +++ b/train.py @@ -2,7 +2,7 @@ import numpy as np import json from model import ConceptFlow, use_cuda -from preprocession import prepare_data, build_vocab, gen_batched_data +from preprocession import prepare_data, build_vocab, get_data import torch import warnings import yaml @@ -44,9 +44,7 @@ def list_all_member(self): print('%s = %s' % (name, value)) -def run(model, data_train, config, word2id, entity2id): - batched_data = gen_batched_data(data_train, config, word2id, entity2id) - +def run(model, batched_data): if model.is_inference == True: word_index, selector = model(batched_data) return word_index, selector @@ -67,10 +65,9 @@ def train(config, model, data_train, data_test, word2id, entity2id, model_optimi only_two_cut = use_cuda(torch.Tensor([0])) count = 0 - for iteration in range(len(data_train) // config.batch_size): + for iteration, batch in enumerate(data_train): decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, \ - only_two_neg_num = run(model, data_train[(iteration * config.batch_size):(iteration * \ - config.batch_size + config.batch_size)], config, word2id, entity2id) + only_two_neg_num = run(model, batch, config, word2id, entity2id) sentence_ppx_loss += torch.sum(sentence_ppx).data sentence_ppx_word_loss += torch.sum(sentence_ppx_word).data sentence_ppx_local_loss += torch.sum(sentence_ppx_local).data @@ -117,11 +114,10 @@ def evaluate(model, data_test, config, word2id, entity2id, epoch = 0, model_path id2word[word2id[key]] = key - for iteration in range(len(data_test) // config.batch_size): + for iteration, batch in enumerate(data_test): decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, \ - local_neg_num, only_two_neg_num = run(model, data_test[(iteration * config.batch_size):(iteration * \ - config.batch_size + config.batch_size)], config, word2id, entity2id) + local_neg_num, only_two_neg_num = run(model, batch, config, word2id, entity2id) sentence_ppx_loss += torch.sum(sentence_ppx).data sentence_ppx_word_loss += torch.sum(sentence_ppx_word).data sentence_ppx_local_loss += torch.sum(sentence_ppx_local).data @@ -151,8 +147,10 @@ def evaluate(model, data_test, config, word2id, entity2id, epoch = 0, model_path def main(): config = Config('config.yml') config.list_all_member() - raw_vocab, data_train, data_test = prepare_data(config) + raw_vocab = prepare_data(config) word2id, entity2id, vocab, embed, entity_vocab, entity_embed, relation_vocab, relation_embed, entity_relation_embed = build_vocab(config.data_dir, raw_vocab, config = config) + data_train, data_test = get_data(config,word2id,entity2id) + model = use_cuda(ConceptFlow(config, embed, entity_relation_embed)) model_optimizer = torch.optim.Adam(model.parameters(), lr = config.lr_rate) From 50cedc4b1db1d0311ce98af6088609812bc71934 Mon Sep 17 00:00:00 2001 From: Ilya Lasy Date: Tue, 2 Mar 2021 14:55:07 +0300 Subject: [PATCH 5/6] fixed training bugs --- train.py | 337 +++++++++++++++++++++++++++---------------------------- 1 file changed, 168 insertions(+), 169 deletions(-) diff --git a/train.py b/train.py index ea9aa43..4a9c7f9 100644 --- a/train.py +++ b/train.py @@ -1,169 +1,168 @@ -#coding:utf-8 -import numpy as np -import json -from model import ConceptFlow, use_cuda -from preprocession import prepare_data, build_vocab, get_data -import torch -import warnings -import yaml -import os -warnings.filterwarnings('ignore') - -csk_triples, csk_entities, kb_dict = [], [], [] -dict_csk_entities, dict_csk_triples = {}, {} -class Config(): - def __init__(self, path): - self.config_path = path - self._get_config() - - def _get_config(self): - with open(self.config_path, "r") as setting: - config = yaml.load(setting) - self.is_train = config['is_train'] - self.test_model_path = config['test_model_path'] - self.embed_units = config['embed_units'] - self.symbols = config['symbols'] - self.units = config['units'] - self.layers = config['layers'] - self.batch_size = config['batch_size'] - self.data_dir = config['data_dir'] - self.num_epoch = config['num_epoch'] - self.lr_rate = config['lr_rate'] - self.lstm_dropout = config['lstm_dropout'] - self.linear_dropout = config['linear_dropout'] - self.max_gradient_norm = config['max_gradient_norm'] - self.trans_units = config['trans_units'] - self.gnn_layers = config['gnn_layers'] - self.fact_dropout = config['fact_dropout'] - self.fact_scale = config['fact_scale'] - self.pagerank_lambda = config['pagerank_lambda'] - self.result_dir_name = config['result_dir_name'] - - def list_all_member(self): - for name, value in vars(self).items(): - print('%s = %s' % (name, value)) - - -def run(model, batched_data): - if model.is_inference == True: - word_index, selector = model(batched_data) - return word_index, selector - else: - decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, only_two_neg_num = model(batched_data) - return decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, only_two_neg_num - -def train(config, model, data_train, data_test, word2id, entity2id, model_optimizer): - for epoch in range(config.num_epoch): - print ("epoch: ", epoch) - sentence_ppx_loss = 0 - sentence_ppx_word_loss = 0 - sentence_ppx_local_loss = 0 - sentence_ppx_only_two_loss = 0 - - word_cut = use_cuda(torch.Tensor([0])) - local_cut = use_cuda(torch.Tensor([0])) - only_two_cut = use_cuda(torch.Tensor([0])) - - count = 0 - for iteration, batch in enumerate(data_train): - decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, \ - only_two_neg_num = run(model, batch, config, word2id, entity2id) - sentence_ppx_loss += torch.sum(sentence_ppx).data - sentence_ppx_word_loss += torch.sum(sentence_ppx_word).data - sentence_ppx_local_loss += torch.sum(sentence_ppx_local).data - sentence_ppx_only_two_loss += torch.sum(sentence_ppx_only_two).data - - word_cut += word_neg_num - local_cut += local_neg_num - only_two_cut += only_two_neg_num - - model_optimizer.zero_grad() - decoder_loss.backward() - torch.nn.utils.clip_grad_norm(model.parameters(), config.max_gradient_norm) - model_optimizer.step() - - if count % 50 == 0: - print ("iteration:", iteration, "Loss:", decoder_loss.data) - count += 1 - - print ("perplexity for epoch", epoch + 1, ":", np.exp(sentence_ppx_loss.cpu() / len(data_train)), " ppx_word: ", \ - np.exp(sentence_ppx_word_loss.cpu() / (len(data_train) - int(word_cut))), " ppx_local: ", \ - np.exp(sentence_ppx_local_loss.cpu() / (len(data_train) - int(local_cut))), " ppx_only_two: ", \ - np.exp(sentence_ppx_only_two_loss.cpu() / (len(data_train) - int(only_two_cut)))) - - torch.save(model.state_dict(), config.result_dir_name + '/' + '_epoch_' + str(epoch + 1) + '.pkl') - ppx, ppx_word, ppx_local, ppx_only_two = evaluate(model, data_test, config, word2id, entity2id, epoch + 1) - ppx_f = open(config.result_dir_name + '/result.txt','a') - ppx_f.write("epoch " + str(epoch + 1) + " ppx: " + str(ppx) + " ppx_word: " + str(ppx_word) + " ppx_local: " + \ - str(ppx_local) + " ppx_only_two: " + str(ppx_only_two) + '\n') - ppx_f.close() - -def evaluate(model, data_test, config, word2id, entity2id, epoch = 0, model_path = None): - if model_path != None: - model.load_state_dict(torch.load(model_path)) - sentence_ppx_loss = 0 - sentence_ppx_word_loss = 0 - sentence_ppx_local_loss = 0 - sentence_ppx_only_two_loss = 0 - word_cut = use_cuda(torch.Tensor([0])) - local_cut = use_cuda(torch.Tensor([0])) - only_two_cut = use_cuda(torch.Tensor([0])) - count = 0 - id2word = dict() - for key in word2id.keys(): - id2word[word2id[key]] = key - - - for iteration, batch in enumerate(data_test): - - decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, \ - local_neg_num, only_two_neg_num = run(model, batch, config, word2id, entity2id) - sentence_ppx_loss += torch.sum(sentence_ppx).data - sentence_ppx_word_loss += torch.sum(sentence_ppx_word).data - sentence_ppx_local_loss += torch.sum(sentence_ppx_local).data - sentence_ppx_only_two_loss += torch.sum(sentence_ppx_only_two).data - - word_cut += word_neg_num - local_cut += local_neg_num - only_two_cut += only_two_neg_num - - if count % 50 == 0: - print ("iteration for evaluate:", iteration, "Loss:", decoder_loss.data) - count += 1 - - model.is_inference = False - if model_path != None: - print(' perplexity on test set:', np.exp(sentence_ppx_loss.cpu() / len(data_test)), \ - np.exp(sentence_ppx_word_loss.cpu() / (len(data_test) - int(word_cut))), np.exp(sentence_ppx_local_loss.cpu() / (len(data_test) \ - - int(local_cut))), np.exp(sentence_ppx_only_two_loss.cpu() / (len(data_test) - int(only_two_cut)))) - exit() - print(' perplexity on test set:', np.exp(sentence_ppx_loss.cpu() / len(data_test)), np.exp(sentence_ppx_word_loss.cpu() / \ - (len(data_test) - int(word_cut))), np.exp(sentence_ppx_local_loss.cpu() / (len(data_test) - int(local_cut))), \ - np.exp(sentence_ppx_only_two_loss.cpu() / (len(data_test) - int(only_two_cut)))) - return np.exp(sentence_ppx_loss.cpu() / len(data_test)), np.exp(sentence_ppx_word_loss.cpu() / (len(data_test) - int(word_cut))), \ - np.exp(sentence_ppx_local_loss.cpu() / (len(data_test) - int(local_cut))), np.exp(sentence_ppx_only_two_loss.cpu() / \ - (len(data_test) - int(only_two_cut))) - -def main(): - config = Config('config.yml') - config.list_all_member() - raw_vocab = prepare_data(config) - word2id, entity2id, vocab, embed, entity_vocab, entity_embed, relation_vocab, relation_embed, entity_relation_embed = build_vocab(config.data_dir, raw_vocab, config = config) - data_train, data_test = get_data(config,word2id,entity2id) - - model = use_cuda(ConceptFlow(config, embed, entity_relation_embed)) - - model_optimizer = torch.optim.Adam(model.parameters(), lr = config.lr_rate) - - if not os.path.exists(config.result_dir_name): - os.mkdir(config.result_dir_name) - ppx_f = open(config.result_dir_name + '/result.txt','a') - for name, value in vars(config).items(): - ppx_f.write('%s = %s' % (name, value) + '\n') - - if config.is_train == False: - evaluate(model, data_test, config, word2id, entity2id, 0, model_path = config.test_model_path) - exit() - train(config, model, data_train, data_test, word2id, entity2id, model_optimizer) - -main() +#coding:utf-8 +import numpy as np +import json +from model import ConceptFlow, use_cuda +from preprocession import prepare_data, build_vocab, get_data +import torch +import warnings +import yaml +import os +warnings.filterwarnings('ignore') + +csk_triples, csk_entities, kb_dict = [], [], [] +dict_csk_entities, dict_csk_triples = {}, {} +class Config(): + def __init__(self, path): + self.config_path = path + self._get_config() + + def _get_config(self): + with open(self.config_path, "r") as setting: + config = yaml.load(setting) + self.is_train = config['is_train'] + self.test_model_path = config['test_model_path'] + self.embed_units = config['embed_units'] + self.symbols = config['symbols'] + self.units = config['units'] + self.layers = config['layers'] + self.batch_size = config['batch_size'] + self.data_dir = config['data_dir'] + self.num_epoch = config['num_epoch'] + self.lr_rate = config['lr_rate'] + self.lstm_dropout = config['lstm_dropout'] + self.linear_dropout = config['linear_dropout'] + self.max_gradient_norm = config['max_gradient_norm'] + self.trans_units = config['trans_units'] + self.gnn_layers = config['gnn_layers'] + self.fact_dropout = config['fact_dropout'] + self.fact_scale = config['fact_scale'] + self.pagerank_lambda = config['pagerank_lambda'] + self.result_dir_name = config['result_dir_name'] + + def list_all_member(self): + for name, value in vars(self).items(): + print('%s = %s' % (name, value)) + + +def run(model, batched_data): + if model.is_inference == True: + word_index, selector = model(batched_data) + return word_index, selector + else: + decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, only_two_neg_num = model(batched_data) + return decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, only_two_neg_num + +def train(config, model, data_train, data_test, word2id, entity2id, model_optimizer): + for epoch in range(config.num_epoch): + print ("epoch: ", epoch) + sentence_ppx_loss = 0 + sentence_ppx_word_loss = 0 + sentence_ppx_local_loss = 0 + sentence_ppx_only_two_loss = 0 + + word_cut = use_cuda(torch.Tensor([0])) + local_cut = use_cuda(torch.Tensor([0])) + only_two_cut = use_cuda(torch.Tensor([0])) + + data_len = 0 + for iteration, batch in enumerate(data_train): + decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, \ + only_two_neg_num = run(model, batch) + sentence_ppx_loss += torch.sum(sentence_ppx).data + sentence_ppx_word_loss += torch.sum(sentence_ppx_word).data + sentence_ppx_local_loss += torch.sum(sentence_ppx_local).data + sentence_ppx_only_two_loss += torch.sum(sentence_ppx_only_two).data + + word_cut += word_neg_num + local_cut += local_neg_num + only_two_cut += only_two_neg_num + + model_optimizer.zero_grad() + decoder_loss.backward() + torch.nn.utils.clip_grad_norm(model.parameters(), config.max_gradient_norm) + model_optimizer.step() + + if iteration % 50 == 0: + print ("iteration:", iteration, "Loss:", decoder_loss.data) + data_len += len(batch['query_text']) + + print ("perplexity for epoch", epoch + 1, ":", np.exp(sentence_ppx_loss.cpu() / data_len), " ppx_word: ", \ + np.exp(sentence_ppx_word_loss.cpu() / (data_len - int(word_cut))), " ppx_local: ", \ + np.exp(sentence_ppx_local_loss.cpu() / (data_len - int(local_cut))), " ppx_only_two: ", \ + np.exp(sentence_ppx_only_two_loss.cpu() / (data_len - int(only_two_cut)))) + + torch.save(model.state_dict(), config.result_dir_name + '/' + '_epoch_' + str(epoch + 1) + '.pkl') + ppx, ppx_word, ppx_local, ppx_only_two = evaluate(model, data_test, config, word2id, entity2id, epoch + 1) + ppx_f = open(config.result_dir_name + '/result.txt','a') + ppx_f.write("epoch " + str(epoch + 1) + " ppx: " + str(ppx) + " ppx_word: " + str(ppx_word) + " ppx_local: " + \ + str(ppx_local) + " ppx_only_two: " + str(ppx_only_two) + '\n') + ppx_f.close() + +def evaluate(model, data_test, config, word2id, entity2id, epoch = 0, model_path = None): + if model_path != None: + model.load_state_dict(torch.load(model_path)) + sentence_ppx_loss = 0 + sentence_ppx_word_loss = 0 + sentence_ppx_local_loss = 0 + sentence_ppx_only_two_loss = 0 + word_cut = use_cuda(torch.Tensor([0])) + local_cut = use_cuda(torch.Tensor([0])) + only_two_cut = use_cuda(torch.Tensor([0])) + id2word = dict() + for key in word2id.keys(): + id2word[word2id[key]] = key + + data_len = 0 + for iteration, batch in enumerate(data_test): + + decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, \ + local_neg_num, only_two_neg_num = run(model, batch) + sentence_ppx_loss += torch.sum(sentence_ppx).data + sentence_ppx_word_loss += torch.sum(sentence_ppx_word).data + sentence_ppx_local_loss += torch.sum(sentence_ppx_local).data + sentence_ppx_only_two_loss += torch.sum(sentence_ppx_only_two).data + + word_cut += word_neg_num + local_cut += local_neg_num + only_two_cut += only_two_neg_num + + if iteration % 50 == 0: + print ("iteration for evaluate:", iteration, "Loss:", decoder_loss.data) + data_len += len(batch['query_text']) + + model.is_inference = False + if model_path != None: + print(' perplexity on test set:', np.exp(sentence_ppx_loss.cpu() / data_len), \ + np.exp(sentence_ppx_word_loss.cpu() / (data_len - int(word_cut))), np.exp(sentence_ppx_local_loss.cpu() / (data_len\ + - int(local_cut))), np.exp(sentence_ppx_only_two_loss.cpu() / (data_len - int(only_two_cut)))) + exit() + print(' perplexity on test set:', np.exp(sentence_ppx_loss.cpu() / data_len), np.exp(sentence_ppx_word_loss.cpu() / \ + (data_len- int(word_cut))), np.exp(sentence_ppx_local_loss.cpu() / (data_len - int(local_cut))), \ + np.exp(sentence_ppx_only_two_loss.cpu() / (data_len - int(only_two_cut)))) + return np.exp(sentence_ppx_loss.cpu() / data_len), np.exp(sentence_ppx_word_loss.cpu() / (data_len- int(word_cut))), \ + np.exp(sentence_ppx_local_loss.cpu() / (data_len - int(local_cut))), np.exp(sentence_ppx_only_two_loss.cpu() / \ + (data_len - int(only_two_cut))) + +def main(): + config = Config('config.yml') + config.list_all_member() + raw_vocab = prepare_data(config) + word2id, entity2id, vocab, embed, entity_vocab, entity_embed, relation_vocab, relation_embed, entity_relation_embed = build_vocab(config.data_dir, raw_vocab, config = config) + data_train, data_test = get_data(config,word2id,entity2id) + + model = use_cuda(ConceptFlow(config, embed, entity_relation_embed)) + + model_optimizer = torch.optim.Adam(model.parameters(), lr = config.lr_rate) + + if not os.path.exists(config.result_dir_name): + os.mkdir(config.result_dir_name) + ppx_f = open(config.result_dir_name + '/result.txt','a') + for name, value in vars(config).items(): + ppx_f.write('%s = %s' % (name, value) + '\n') + + if config.is_train == False: + evaluate(model, data_test, config, word2id, entity2id, 0, model_path = config.test_model_path) + exit() + train(config, model, data_train, data_test, word2id, entity2id, model_optimizer) + +main() From e234963d5475c7b95dc53046c6047312bb4cf570 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Oct 2021 23:08:08 +0000 Subject: [PATCH 6/6] Bump opencv-python from 4.1.1.26 to 4.2.0.32 Bumps [opencv-python](https://github.com/skvark/opencv-python) from 4.1.1.26 to 4.2.0.32. - [Release notes](https://github.com/skvark/opencv-python/releases) - [Commits](https://github.com/skvark/opencv-python/commits) --- updated-dependencies: - dependency-name: opencv-python dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1819cd6..6376b8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,7 +49,7 @@ networkx==2.4 nltk==3.4.5 numpy==1.18.1 nvidia-ml-py3==7.352.0 -opencv-python==4.1.1.26 +opencv-python==4.2.0.32 opt-einsum==3.1.0 periodictable==1.5.1 Pillow==6.2.1