Skip to content

Commit ca3de7c

Browse files
committed
Add test supervised solver for graph data and remove PinaBatch
1 parent 2c371ae commit ca3de7c

3 files changed

Lines changed: 181 additions & 25 deletions

File tree

pina/data/dataset.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -109,28 +109,6 @@ def input_points(self):
109109
return {k: v["input_points"] for k, v in self.conditions_dict.items()}
110110

111111

112-
class PinaBatch(Batch):
113-
"""
114-
Add extract function to torch_geometric Batch object
115-
"""
116-
117-
def __init__(self):
118-
119-
super().__init__(self)
120-
121-
def extract(self, labels):
122-
"""
123-
Perform extraction of labels on node features (x)
124-
125-
:param labels: Labels to extract
126-
:type labels: list[str] | tuple[str] | str
127-
:return: Batch object with extraction performed on x
128-
:rtype: PinaBatch
129-
"""
130-
self.x = self.x.extract(labels)
131-
return self
132-
133-
134112
class PinaGraphDataset(PinaDataset):
135113

136114
def __init__(
@@ -182,12 +160,10 @@ def _divide_batch(self, batch):
182160
"""
183161
to_return_dict = {}
184162
to_return_dict["input_points"] = batch
185-
if hasattr(batch, "y"):
186-
to_return_dict["output_points"] = batch.y
187163
return to_return_dict
188164

189165
def _base_create_graph_batch_from_list(self, data):
190-
batch = PinaBatch.from_data_list(data)
166+
batch = Batch.from_data_list(data)
191167
return batch
192168

193169
def _getitem_dummy(self, idx):

pina/graph.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,18 @@ def _preprocess_edge_index(edge_index, undirected):
162162
edge_index = to_undirected(edge_index)
163163
return edge_index
164164

165+
def extract(self, labels):
166+
"""
167+
Perform extraction of labels on node features (x)
168+
169+
:param labels: Labels to extract
170+
:type labels: list[str] | tuple[str] | str
171+
:return: Batch object with extraction performed on x
172+
:rtype: PinaBatch
173+
"""
174+
self.x = self.x.extract(labels)
175+
return self
176+
165177

166178
class GraphBuilder:
167179
"""

tests/test_solver/test_supervised_solver.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from pina.solver import SupervisedSolver
77
from pina.model import FeedForward
88
from pina.trainer import Trainer
9+
from pina.graph import KNNGraph
910
from torch._dynamo.eval_frame import OptimizedModule
11+
from torch_geometric.nn import GCNConv
1012

1113

1214
class LabelTensorProblem(AbstractProblem):
@@ -30,9 +32,72 @@ class TensorProblem(AbstractProblem):
3032
}
3133

3234

35+
x = torch.rand((100, 20, 5))
36+
pos = torch.rand((100, 20, 2))
37+
output_ = torch.rand((100, 20, 1))
38+
input_ = [
39+
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True, y=y_)
40+
for x_, pos_, y_ in zip(x, pos, output_)
41+
]
42+
43+
44+
class GraphProblem(AbstractProblem):
45+
output_variables = None
46+
conditions = {
47+
"data": Condition(
48+
graph=input_,
49+
)
50+
}
51+
52+
53+
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
54+
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
55+
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
56+
input_ = [
57+
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True, y=output_[i])
58+
for i in range(len(x))
59+
]
60+
61+
62+
class GraphProblemLT(AbstractProblem):
63+
output_variables = ["u"]
64+
input_variables = ["a", "b", "c", "d", "e"]
65+
conditions = {
66+
"data": Condition(
67+
graph=input_,
68+
)
69+
}
70+
71+
3372
model = FeedForward(2, 1)
3473

3574

75+
class Model(torch.nn.Module):
76+
77+
def __init__(self, *args, **kwargs):
78+
super().__init__(*args, **kwargs)
79+
self.lift = torch.nn.Linear(5, 10)
80+
self.activation = torch.nn.Tanh()
81+
self.output = torch.nn.Linear(10, 1)
82+
83+
self.conv = GCNConv(10, 10)
84+
85+
def forward(self, batch):
86+
87+
x = batch.x
88+
edge_index = batch.edge_index
89+
for _ in range(1):
90+
y = self.lift(x)
91+
y = self.activation(y)
92+
y = self.conv(y, edge_index)
93+
y = self.activation(y)
94+
y = self.output(y)
95+
return y
96+
97+
98+
graph_model = Model()
99+
100+
36101
def test_constructor():
37102
SupervisedSolver(problem=TensorProblem(), model=model)
38103
SupervisedSolver(problem=LabelTensorProblem(), model=model)
@@ -64,6 +129,24 @@ def test_solver_train(use_lt, batch_size, compile):
64129
assert isinstance(solver.model, OptimizedModule)
65130

66131

132+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
133+
@pytest.mark.parametrize("use_lt", [True, False])
134+
def test_solver_train_graph(batch_size, use_lt):
135+
problem = GraphProblemLT() if use_lt else GraphProblem()
136+
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
137+
trainer = Trainer(
138+
solver=solver,
139+
max_epochs=2,
140+
accelerator="cpu",
141+
batch_size=batch_size,
142+
train_size=1.0,
143+
test_size=0.0,
144+
val_size=0.0,
145+
)
146+
147+
trainer.train()
148+
149+
67150
@pytest.mark.parametrize("use_lt", [True, False])
68151
@pytest.mark.parametrize("compile", [True, False])
69152
def test_solver_validation(use_lt, compile):
@@ -84,6 +167,24 @@ def test_solver_validation(use_lt, compile):
84167
assert isinstance(solver.model, OptimizedModule)
85168

86169

170+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
171+
@pytest.mark.parametrize("use_lt", [True, False])
172+
def test_solver_val_graph(batch_size, use_lt):
173+
problem = GraphProblemLT() if use_lt else GraphProblem()
174+
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
175+
trainer = Trainer(
176+
solver=solver,
177+
max_epochs=2,
178+
accelerator="cpu",
179+
batch_size=batch_size,
180+
train_size=0.9,
181+
val_size=0.1,
182+
test_size=0.0,
183+
)
184+
185+
trainer.train()
186+
187+
87188
@pytest.mark.parametrize("use_lt", [True, False])
88189
@pytest.mark.parametrize("compile", [True, False])
89190
def test_solver_test(use_lt, compile):
@@ -104,6 +205,24 @@ def test_solver_test(use_lt, compile):
104205
assert isinstance(solver.model, OptimizedModule)
105206

106207

208+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
209+
@pytest.mark.parametrize("use_lt", [True, False])
210+
def test_solver_test_graph(batch_size, use_lt):
211+
problem = GraphProblemLT() if use_lt else GraphProblem()
212+
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
213+
trainer = Trainer(
214+
solver=solver,
215+
max_epochs=2,
216+
accelerator="cpu",
217+
batch_size=batch_size,
218+
train_size=0.8,
219+
val_size=0.1,
220+
test_size=0.1,
221+
)
222+
223+
trainer.test()
224+
225+
107226
def test_train_load_restore():
108227
dir = "tests/test_solver/tmp/"
109228
problem = LabelTensorProblem()
@@ -145,3 +264,52 @@ def test_train_load_restore():
145264
import shutil
146265

147266
shutil.rmtree("tests/test_solver/tmp")
267+
268+
269+
def test_train_load_restore_graph():
270+
dir = "tests/test_solver/tmp/"
271+
problem = GraphProblemLT()
272+
solver = SupervisedSolver(problem=problem, model=graph_model)
273+
trainer = Trainer(
274+
solver=solver,
275+
max_epochs=5,
276+
accelerator="cpu",
277+
batch_size=None,
278+
train_size=0.9,
279+
test_size=0.1,
280+
val_size=0.0,
281+
default_root_dir=dir,
282+
)
283+
trainer.train()
284+
285+
# restore
286+
new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
287+
new_trainer.train(
288+
ckpt_path=f"{dir}/lightning_logs/version_0/checkpoints/"
289+
+ "epoch=4-step=5.ckpt"
290+
)
291+
292+
# loading
293+
new_solver = SupervisedSolver.load_from_checkpoint(
294+
f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt",
295+
problem=problem,
296+
model=graph_model,
297+
)
298+
299+
test_pts = KNNGraph(
300+
x=LabelTensor(torch.rand(20, 5), ["a", "b", "c", "d", "e"]),
301+
pos=LabelTensor(torch.rand(20, 2), ["x", "y"]),
302+
neighbours=3,
303+
edge_attr=True,
304+
)
305+
306+
assert new_solver.forward(test_pts).shape == (20, 1)
307+
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
308+
torch.testing.assert_close(
309+
new_solver.forward(test_pts), solver.forward(test_pts)
310+
)
311+
312+
# rm directories
313+
import shutil
314+
315+
shutil.rmtree("tests/test_solver/tmp")

0 commit comments

Comments
 (0)