66from pina .solver import SupervisedSolver
77from pina .model import FeedForward
88from pina .trainer import Trainer
9+ from pina .graph import KNNGraph
910from torch ._dynamo .eval_frame import OptimizedModule
11+ from torch_geometric .nn import GCNConv
1012
1113
1214class LabelTensorProblem (AbstractProblem ):
@@ -30,9 +32,72 @@ class TensorProblem(AbstractProblem):
3032 }
3133
3234
# Synthetic graph dataset: 100 graphs, each with 20 nodes carrying
# 5 node features, a 2D position, and a scalar target.
x = torch.rand((100, 20, 5))
pos = torch.rand((100, 20, 2))
output_ = torch.rand((100, 20, 1))
# One KNN graph per sample, built from the per-sample slices above.
input_ = []
for feats, coords, target in zip(x, pos, output_):
    input_.append(
        KNNGraph(x=feats, pos=coords, neighbours=3, edge_attr=True, y=target)
    )
42+
43+
class GraphProblem(AbstractProblem):
    """Problem whose single condition supplies plain-tensor graph data."""

    # No named output variables: targets are raw tensors on the graphs.
    output_variables = None
    conditions = {
        "data": Condition(
            graph=input_,
        )
    }
51+
52+
# Labelled variant of the synthetic dataset: same shapes as above, but
# features, positions and targets carry PINA variable names.
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
# Index-based slicing (x[idx]) is kept so LabelTensor labels are preserved
# on each per-graph slice.
input_ = []
for idx in range(len(x)):
    input_.append(
        KNNGraph(
            x=x[idx], pos=pos[idx], neighbours=3, edge_attr=True, y=output_[idx]
        )
    )
60+
61+
class GraphProblemLT(AbstractProblem):
    """Graph problem with LabelTensor data, so variable names are declared."""

    # Names must match the labels given to the LabelTensors above.
    output_variables = ["u"]
    input_variables = ["a", "b", "c", "d", "e"]
    conditions = {
        "data": Condition(
            graph=input_,
        )
    }
70+
71+
# Plain feed-forward network (2 inputs -> 1 output) used by the
# tensor-based solver tests.
model = FeedForward(2, 1)
3473
3574
class Model(torch.nn.Module):
    """Minimal GNN for the graph tests: linear lift -> GCN layer -> readout.

    Maps 5 node features to a scalar per node, so the output shape is
    (num_nodes, 1).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lift = torch.nn.Linear(5, 10)
        self.activation = torch.nn.Tanh()
        self.output = torch.nn.Linear(10, 1)
        self.conv = GCNConv(10, 10)

    def forward(self, batch):
        """Return per-node predictions for a PyG-style graph ``batch``.

        ``batch`` must expose node features ``x`` (num_nodes, 5) and the
        connectivity ``edge_index``.
        """
        # The original wrapped these steps in `for _ in range(1):`, a dead
        # single-iteration loop (which would also have re-lifted the raw
        # features on every pass had the bound been raised); removed.
        y = self.activation(self.lift(batch.x))
        y = self.activation(self.conv(y, batch.edge_index))
        return self.output(y)
96+
97+
# Single shared GNN instance reused across all graph-based solver tests.
graph_model = Model()
99+
100+
36101def test_constructor ():
37102 SupervisedSolver (problem = TensorProblem (), model = model )
38103 SupervisedSolver (problem = LabelTensorProblem (), model = model )
@@ -64,6 +129,24 @@ def test_solver_train(use_lt, batch_size, compile):
64129 assert isinstance (solver .model , OptimizedModule )
65130
66131
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):
    """Training on graph conditions runs with and without LabelTensors."""
    if use_lt:
        problem = GraphProblemLT()
    else:
        problem = GraphProblem()
    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
    trainer_kwargs = {
        "solver": solver,
        "max_epochs": 2,
        "accelerator": "cpu",
        "batch_size": batch_size,
        "train_size": 1.0,
        "test_size": 0.0,
        "val_size": 0.0,
    }
    Trainer(**trainer_kwargs).train()
148+
149+
67150@pytest .mark .parametrize ("use_lt" , [True , False ])
68151@pytest .mark .parametrize ("compile" , [True , False ])
69152def test_solver_validation (use_lt , compile ):
@@ -84,6 +167,24 @@ def test_solver_validation(use_lt, compile):
84167 assert isinstance (solver .model , OptimizedModule )
85168
86169
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_val_graph(batch_size, use_lt):
    """Training with a validation split runs on graph conditions."""
    if use_lt:
        problem = GraphProblemLT()
    else:
        problem = GraphProblem()
    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
    trainer_kwargs = {
        "solver": solver,
        "max_epochs": 2,
        "accelerator": "cpu",
        "batch_size": batch_size,
        "train_size": 0.9,
        "val_size": 0.1,
        "test_size": 0.0,
    }
    Trainer(**trainer_kwargs).train()
186+
187+
87188@pytest .mark .parametrize ("use_lt" , [True , False ])
88189@pytest .mark .parametrize ("compile" , [True , False ])
89190def test_solver_test (use_lt , compile ):
@@ -104,6 +205,24 @@ def test_solver_test(use_lt, compile):
104205 assert isinstance (solver .model , OptimizedModule )
105206
106207
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_test_graph(batch_size, use_lt):
    """The test stage runs on graph conditions with a held-out split."""
    if use_lt:
        problem = GraphProblemLT()
    else:
        problem = GraphProblem()
    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
    trainer_kwargs = {
        "solver": solver,
        "max_epochs": 2,
        "accelerator": "cpu",
        "batch_size": batch_size,
        "train_size": 0.8,
        "val_size": 0.1,
        "test_size": 0.1,
    }
    Trainer(**trainer_kwargs).test()
224+
225+
107226def test_train_load_restore ():
108227 dir = "tests/test_solver/tmp/"
109228 problem = LabelTensorProblem ()
@@ -145,3 +264,52 @@ def test_train_load_restore():
145264 import shutil
146265
147266 shutil .rmtree ("tests/test_solver/tmp" )
267+
268+
def test_train_load_restore_graph():
    """Train on graphs, restore from checkpoint, reload, and compare outputs.

    The solver loaded from the checkpoint must reproduce the original
    solver's predictions exactly on a fresh graph.
    """
    # `dir` shadowed the builtin in the original; renamed.
    tmp_dir = "tests/test_solver/tmp/"
    problem = GraphProblemLT()
    solver = SupervisedSolver(problem=problem, model=graph_model)
    trainer = Trainer(
        solver=solver,
        max_epochs=5,
        accelerator="cpu",
        batch_size=None,
        train_size=0.9,
        test_size=0.1,
        val_size=0.0,
        default_root_dir=tmp_dir,
    )
    trainer.train()

    # batch_size=None gives one optimizer step per epoch, hence "step=5"
    # after 5 epochs. Build the path once instead of twice.
    ckpt = f"{tmp_dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt"

    # restore training from the checkpoint
    new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
    new_trainer.train(ckpt_path=ckpt)

    # load a fresh solver from the same checkpoint
    new_solver = SupervisedSolver.load_from_checkpoint(
        ckpt,
        problem=problem,
        model=graph_model,
    )

    test_pts = KNNGraph(
        x=LabelTensor(torch.rand(20, 5), ["a", "b", "c", "d", "e"]),
        pos=LabelTensor(torch.rand(20, 2), ["x", "y"]),
        neighbours=3,
        edge_attr=True,
    )

    # Evaluate each solver once (the original recomputed each forward pass
    # twice) and check shape and exact agreement.
    restored_out = new_solver.forward(test_pts)
    original_out = solver.forward(test_pts)
    assert restored_out.shape == (20, 1)
    assert restored_out.shape == original_out.shape
    torch.testing.assert_close(restored_out, original_out)

    # rm directories
    import shutil

    shutil.rmtree(tmp_dir)