Updating search space#156
Updating search space#156ravinkohli merged 4 commits intorefactor_development_regularization_cocktailsfrom
Conversation
|
@franchuterivera @ravinkohli this is the same pr as #154. It contains additional bug fixes though. We can probably even delete that one. |
ravinkohli
left a comment
There was a problem hiding this comment.
Thanks for the changes. I think it can be merged
There was a problem hiding this comment.
Hey @ArlindKadra, although it may be outside the scope of this PR, could you also add the changes for fixing the bug in predicting on a GPU? As it is just a simple update I think it's fine to have it in this PR, allowing us to run the experiments on a GPU directly after merging this PR. It just involves pushing the network to the device(i.e, GPU) before predicting with it. I have pasted the changes I made locally with this comment. Also, we are looking for ways to run the tests also on a GPU so we are sure that the code can run on a GPU as well.
def _predict(self, network: torch.nn.Module, loader: torch.utils.data.DataLoader) -> torch.Tensor:
network.to(self.device)
network.float()
network.eval()
# Batch prediction
Y_batch_preds = list()
for i, (X_batch, Y_batch) in enumerate(loader):
# Predict on batch
X_batch = X_batch.float().to(self.device)
Y_batch_pred = network(X_batch)
if self.final_activation is not None:
Y_batch_pred = self.final_activation(Y_batch_pred)
Y_batch_preds.append(Y_batch_pred.detach().cpu())
return torch.cat(Y_batch_preds, 0)
@ravinkohli Sure thing, I added that, however, is there a reason why we are actually converting the |
yes I don't think we need to do network.float(). I added it because I was getting an error for a PR but I tried running the tests without it and they passed. So its your choice, we can remove it and if we get any errors, then we know why and we can fix them right away or you can leave it as is. |
ravinkohli
left a comment
There was a problem hiding this comment.
I think it can be merged once all the tests pass
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
* Updating search space * fix typo * Bug fix * Fixing buggy implementation of predict when using gpu bug fixes fixing code style checks bug fix for use_pynisher in the base pipeline bug fix
Updating the search space the same as for the `refactor_development' branch in #154