Skip to content

Commit 2b01757

Browse files
authored
Merge pull request #142 from hspark1212/develop
Develop
2 parents 8f53b5e + c064e25 commit 2b01757

File tree

7 files changed

+28
-12
lines changed

7 files changed

+28
-12
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
<p align="center">
44
<a href="https://hspark1212.github.io/MOFTransformer/">
5-
<img alt="Docs" src="https://img.shields.io/badge/Docs-v2.1.1-brightgreen.svg?style=plastic">
5+
<img alt="Docs" src="https://img.shields.io/badge/Docs-v2.1.2-brightgreen.svg?style=plastic">
66
</a>
77
<a href="https://pypi.org/project/moftransformer/">
8-
<img alt="PypI" src="https://img.shields.io/badge/PyPI-v2.1.1-blue.svg?style=plastic&logo=PyPI">
8+
<img alt="PypI" src="https://img.shields.io/badge/PyPI-v2.1.2-blue.svg?style=plastic&logo=PyPI">
99
</a>
1010
<a href="https://doi.org/10.6084/m9.figshare.21155506.v2">
1111
<img alt="Figshare" src="https://img.shields.io/badge/Figshare-v2-blue.svg?style=plastic&logo=figshare">

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
project = "MOFTransformer"
1010
copyright = "2022, Yeonghun Kang, Hyunsoo Park"
1111
author = "Yeonghun Kang, Hyunsoo Park"
12-
release = "2.1.1"
12+
release = "2.1.2"
1313

1414
# -- General configuration ---------------------------------------------------
1515
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

moftransformer/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# MOFTransformer version 2.1.1
1+
# MOFTransformer version 2.1.2
22
import os
33

4-
__version__ = "2.1.1"
4+
__version__ = "2.1.2"
55
__root_dir__ = os.path.dirname(__file__)
66

77
from moftransformer import visualize, utils, modules, libs, gadgets, datamodules, assets

moftransformer/modules/module.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,13 @@ def on_predict_start(self):
330330
def predict_step(self, batch, batch_idx, dataloader_idx=0):
331331
output = self(batch)
332332

333-
softmax = torch.nn.Softmax(dim=1)
334333
if 'classification_logits' in output:
335-
output['classification_logits'] = softmax(output['classification_logits'])
336-
output['classification_logits_index'] = torch.argmax(output['classification_logits'], dim=1)
334+
if self.hparams.config['n_classes'] == 2:
335+
output['classification_logits_index'] = torch.round(output['classification_logits']).to(torch.int)
336+
else:
337+
softmax = torch.nn.Softmax(dim=1)
338+
output['classification_logits'] = softmax(output['classification_logits'])
339+
output['classification_logits_index'] = torch.argmax(output['classification_logits'], dim=1)
337340

338341
output = {
339342
k: (v.cpu().tolist() if torch.is_tensor(v) else v)

moftransformer/utils/install_griday.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,18 @@ def _make_griday():
6565
print(
6666
"=== Successfully download ======================================================="
6767
)
68-
if Path(GRIDAY_PATH).exists():
68+
if not Path(GRIDAY_PATH).exists():
69+
raise InstallationError(f"GRIDAY is not installed. Please try again.")
70+
71+
print(
72+
"=== Check GIRDAY ================================================================"
73+
)
74+
ps = subprocess.run([str(GRIDAY_PATH)], cwd=dir_griday, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
75+
if ps.stderr == b'./make_egrid spacing atom_type force_field input_cssr egrid_stem\n':
6976
print(f"GRIDAY is installed to {dir_griday}")
7077
else:
71-
raise InstallationError(f"GRIDAY is not installed. Please try again.")
78+
print (ps.stdout, ps.stderr)
79+
print(f'GRIDAY does not installed correctly. Please uninstall griday and re-install.')
7280

7381

7482
def install_griday(install_make=False):

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# MOFTransformer version 2.0.1
1+
# MOFTransformer version 2.1.2
22
wget
33
torch<2.0.0
44
pytorch-lightning==1.7.0 # pytorch-lightning>=1.7.0
5-
torchmetrics>=0.6.0
5+
torchmetrics<1.0.0, >=0.6.0
66
transformers>=4.12.5
77
timm>=0.4.12
88
sacred>=0.8.2

updates.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Update
22

3+
## Version 2.1.2
4+
- Check GRIDAY is vailed or not when you run `install_griday`
5+
- `torchmetrics < 1.0.0` in requirements
6+
- Fix bugs in predict when loss is `classification` and n_classes are 2.
7+
38
## Version 2.1.1
49
- Fixed a bug when the structure name of raw_[downstream].json contains a cif during prepare_data.
510
- Changed an error that occurred when there were multiple devices in an interactive environment to a warning, automatically converting the configuration to a single device.

0 commit comments

Comments
 (0)