Skip to content

Commit d8511c5

Browse files
authored
Merge pull request #108 from montygole/llm_predict
Replace the PDF parsing code with LLM model prediction
2 parents a82ec7b + 41691be commit d8511c5

14 files changed

Lines changed: 158 additions & 104 deletions

pyproject.toml

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,23 @@ classifiers=[
4242
# core dependencies of harmony
4343
# this set should be kept minimal!
4444
dependencies = [
45-
"pydantic>=2.11.3; python_version <= '3.13.3'",
46-
"pandas>=2.2.3; python_version <= '3.13.3'",
47-
"tika>=3.1.0; python_version <= '3.13.3'",
48-
"lxml>=5.4.0; python_version <= '3.13.3'",
49-
"langdetect>=1.0.9; python_version <= '3.13.3'",
50-
"XlsxWriter>=3.2.3; python_version <= '3.13.3'",
51-
"openpyxl>=3.1.5; python_version <= '3.13.3'",
52-
"wget>=3.2; python_version <= '3.13.3'",
53-
"sentence-transformers>=4.1.0; python_version <= '3.13.3'",
54-
"numpy>=2.0.2; python_version <= '3.13.3'",
55-
"sklearn-crfsuite>=0.5.0; python_version <= '3.13.3'",
56-
"scikit-learn>=1.6.1; python_version <= '3.13.3'",
57-
"scipy>=1.13.1; python_version <= '3.13.3'",
58-
"huggingface-hub>=0.30.2; python_version <= '3.13.3'",
59-
"pydantic==2.8.2; python_version <= '3.13'",
60-
"pandas==2.2.2; python_version <= '3.13'",
61-
"tika==2.6.0; python_version <= '3.13'",
62-
"lxml==5.3.0; python_version <= '3.13'",
63-
"langdetect==1.0.9; python_version <= '3.13'",
64-
"XlsxWriter==3.0.9; python_version <= '3.13'",
65-
"openpyxl==3.1.2; python_version <= '3.13'",
66-
"wget==3.2; python_version <= '3.13'",
67-
"sentence-transformers==3.4.1; python_version <= '3.13'",
45+
"pydantic>=2.11.3; python_version <= '3.13'",
46+
"pandas>=2.2.3; python_version <= '3.13'",
47+
"tika>=3.1.0; python_version <= '3.13'",
48+
"lxml>=5.4.0; python_version <= '3.13'",
49+
"langdetect>=1.0.9; python_version <= '3.13'",
50+
"XlsxWriter>=3.2.3; python_version <= '3.13'",
51+
"openpyxl>=3.1.5; python_version <= '3.13'",
52+
"wget>=3.2; python_version <= '3.13'",
53+
"sentence-transformers>=4.1.0; python_version <= '3.13'",
6854
"numpy==1.26.4; python_version <= '3.13'",
69-
"sklearn-crfsuite==0.5.0; python_version <= '3.13'",
70-
"scikit-learn; python_version <= '3.13'",
71-
"scipy==1.14.1; python_version <= '3.13'",
72-
"huggingface-hub==0.29.3; python_version <= '3.13'",
73-
"fpdf==1.7.2; python_version <= '3.13'",
55+
"sklearn-crfsuite>=0.5.0; python_version <= '3.13'",
56+
"scikit-learn>=1.6.1; python_version <= '3.13'",
57+
"scipy>=1.13.1; python_version <= '3.13'",
58+
"huggingface-hub>=0.30.2; python_version <= '3.13'",
59+
"torch==2.2.2; python_version <= '3.13'",
60+
"transformers==4.50.3; python_version <= '3.13'",
61+
"fpdf2~=2.8.2; python_version <= '3.13'",
7462
]
7563

7664
[project.optional-dependencies]

requirements.txt

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,13 @@ XlsxWriter>=3.2.3
77
openpyxl>=3.1.5
88
wget>=3.2
99
sentence-transformers>=4.1.0
10-
numpy>=2.0.2
10+
numpy==1.26.4
1111
sklearn-crfsuite>=0.5.0
1212
scikit-learn>=1.6.1
1313
scipy>=1.13.1
1414
huggingface-hub>=0.30.2
15-
pydantic==2.8.2
16-
pandas==2.2.2
17-
tika==2.6.0
18-
lxml==5.3.0
19-
langdetect==1.0.9
20-
XlsxWriter==3.0.9
21-
openpyxl==3.1.2
22-
wget==3.2
23-
sentence-transformers==3.4.1
24-
numpy==1.26.4
2515
sklearn-crfsuite==0.5.0
26-
scikit-learn==1.5.0
2716
scipy==1.14.1
28-
huggingface-hub==0.29.3
17+
torch==2.2.2
18+
transformers==4.50.3
2919
fpdf2~=2.8.2

src/harmony/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .parsing.text_parser import convert_text_to_instruments
4040
from .parsing.excel_parser import convert_excel_to_instruments
4141
from .parsing.pdf_parser import convert_pdf_to_instruments
42+
from .parsing.pdf_parser import group_token_spans_by_class
4243
from .parsing.wrapper_all_parsers import convert_files_to_instruments
4344
from .parsing import *
4445
from .util.file_helper import load_instruments_from_local_file

src/harmony/parsing/pdf_parser.py

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,57 +25,84 @@
2525
2626
'''
2727

28-
import pathlib
29-
import pickle as pkl
30-
import re
28+
import os
29+
30+
import torch
31+
from transformers import AutoModelForTokenClassification, AutoTokenizer
3132

3233
import harmony
33-
from harmony.parsing.util.feature_extraction import convert_text_to_features
3434
from harmony.parsing.util.tika_wrapper import parse_pdf_to_plain_text
3535
from harmony.schemas.requests.text import RawFile, Instrument
3636

37-
model_containing_folder = pathlib.Path(__file__).parent.resolve()
37+
# Disable tokenizer parallelism
38+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
39+
40+
41+
def group_token_spans_by_class(tokens, classes, tokenizer=None) -> dict:
    """
    Group consecutive tokens that share the same predicted class into text spans.

    Tokens and classes are walked in lockstep. Every maximal run of adjacent
    tokens with the same class is joined into one string with
    ``tokenizer.convert_tokens_to_string`` and appended to the list for that class.

    Example:
    > group_token_spans_by_class(['▁how', '▁are', '▁you', '?', '▁1'],
                                 ['question', 'question', 'question', 'question', 'answer'],
                                 bert_tokenizer)
    > {"question": ["How are you?"], "answer": ["1"], "other": []}
    Notice that some tokens begin with ▁ (U+2581, code point 9601) instead of _ (U+005F, code point 95).

    :param tokens: List of tokens
    :type tokens: List[str]
    :param classes: List of predicted classes, one per token
    :type classes: List[str]
    :param tokenizer: Object exposing ``convert_tokens_to_string``. When omitted, the
        harmonydata/debertaV2_pdfparser tokenizer is loaded on first use and reused afterwards.
    :return: Dictionary mapping each class to its spans, in order of appearance. The keys
        "answer", "question" and "other" are always present; any other class found in
        ``classes`` gets its own key.
    """
    from itertools import groupby

    if tokenizer is None:
        # Load lazily instead of in the signature: a default argument is evaluated at
        # import time, which would fetch the tokenizer merely by importing this module.
        tokenizer = getattr(group_token_spans_by_class, "_default_tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("harmonydata/debertaV2_pdfparser")
            group_token_spans_by_class._default_tokenizer = tokenizer

    grouped_spans = {"answer": [], "question": [], "other": []}

    # groupby yields one group per run of adjacent pairs with the same class.
    for cls, run in groupby(zip(tokens, classes), key=lambda pair: pair[1]):
        span = [token for token, _ in run]
        # setdefault keeps an unexpected label from raising KeyError.
        grouped_spans.setdefault(cls, []).append(tokenizer.convert_tokens_to_string(span))

    return grouped_spans
3875

39-
with open(f"{model_containing_folder}/20240719_pdf_question_extraction_sklearn_crf_model.pkl", "rb") as f:
40-
crf_text_model = pkl.load(f)
4176

42-
# Predict method is taken from the training repo. Use the training repo as the master copy of the predict method.
43-
# All training code is in https://github.com/harmonydata/pdf-questionnaire-extraction
4477
def _load_pdf_parser_model():
    """Return the (model, tokenizer) pair, calling ``from_pretrained`` only once per process."""
    cached = getattr(_load_pdf_parser_model, "_cached", None)
    if cached is None:
        model = AutoModelForTokenClassification.from_pretrained("harmonydata/debertaV2_pdfparser")
        tokenizer = AutoTokenizer.from_pretrained("harmonydata/debertaV2_pdfparser")
        cached = (model, tokenizer)
        _load_pdf_parser_model._cached = cached
    return cached


def predict(test_text):
    """
    Classify every token of a text and return the text spans grouped by predicted class.

    :param test_text: Plain text to analyse (e.g. the text extracted from a PDF).
    :return: Dictionary produced by ``group_token_spans_by_class``, mapping each
        predicted class to the list of spans assigned to it.
    """
    # Reuse the fine-tuned model and tokenizer across calls instead of reloading them each time.
    model, tokenizer = _load_pdf_parser_model()

    # Tokenize input text
    tokenized_texts = tokenizer(test_text, return_tensors="pt")

    # Inference only: no gradients are needed
    with torch.no_grad():
        logits = model(**tokenized_texts).logits

    # Highest-scoring class id for each token of the single input sequence
    predicted_ids = torch.argmax(logits, dim=2)[0].tolist()
    predicted_token_class = [model.config.id2label[class_id] for class_id in predicted_ids]

    # Convert the input IDs back to their token strings
    input_ids = tokenized_texts["input_ids"][0].tolist()
    decoded_tokenized_texts = tokenizer.convert_ids_to_tokens(input_ids)

    # Remove the leading [CLS] and trailing [SEP] tokens from both the decoded
    # tokens and the list of predictions so the two stay aligned
    predicted_token_class = predicted_token_class[1:-1]
    decoded_tokenized_texts = decoded_tokenized_texts[1:-1]

    return group_token_spans_by_class(decoded_tokenized_texts, predicted_token_class, tokenizer)
79106

80107

81108
def convert_pdf_to_instruments(file: RawFile) -> Instrument:
@@ -87,8 +114,12 @@ def convert_pdf_to_instruments(file: RawFile) -> Instrument:
87114
if not file.text_content:
88115
file.text_content = parse_pdf_to_plain_text(file.content) # call Tika to convert the PDF to plain text
89116

90-
questions_from_text = predict(file.text_content)
117+
# Run prediction script to return questions and answers from file text content
118+
questions_answers_from_text = predict(file.text_content)
119+
120+
questions_from_text = questions_answers_from_text["question"]
121+
answers_from_text = questions_answers_from_text["answer"]
91122

92-
instrument = harmony.create_instrument_from_list(questions_from_text, instrument_name=file.file_name,
123+
instrument = harmony.create_instrument_from_list(questions_from_text, answers_from_text, instrument_name=file.file_name,
93124
file_name=file.file_name)
94125
return [instrument]

src/harmony/util/instrument_helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,21 @@
3232
from harmony.schemas.requests.text import Instrument, Question
3333

3434

35-
def create_instrument_from_list(question_texts: list, answer_texts: list = None, question_numbers: list = None,
                                answer_numbers: list = None,
                                instrument_name: str = "My instrument",
                                file_name="My file") -> Instrument:
    """
    Read a list of strings and create an Instrument object.

    :param question_texts: Text of each question, in order.
    :param answer_texts: Answer options attached to every question. Optional; when
        omitted, each question gets an empty list of options.
    :param question_numbers: Optional question numbers, parallel to ``question_texts``.
        When omitted, questions are numbered "1", "2", ... in order.
    :param answer_numbers: Accepted for interface compatibility; currently unused.
    :param instrument_name: Name given to the resulting instrument.
    :param file_name: File name recorded on the resulting instrument.
    :return: Single Instrument.
    """
    # Keep the one-argument call style working by treating a missing list as "no options".
    if answer_texts is None:
        answer_texts = []

    questions = []
    for ctr, question_text in enumerate(question_texts):
        if question_numbers is not None:
            question_no = question_numbers[ctr]
        else:
            question_no = str(ctr + 1)
        # Each question receives its own copy so the questions never share one list object.
        questions.append(Question(question_text=question_text, question_no=question_no,
                                  options=list(answer_texts)))

    return Instrument(questions=questions, instrument_name=instrument_name, instrument_id=uuid.uuid4().hex,
                      file_name=file_name, file_id=uuid.uuid4().hex)

tests/test_convert_pdf.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from harmony import convert_pdf_to_instruments
3434
from harmony.schemas.requests.text import RawFile
3535
from harmony import download_models
36+
from harmony import group_token_spans_by_class
3637

3738
pdf_gad_7_2_questions = RawFile.model_validate({
3839
"file_id": "d39f31718513413fbfc620c6b6135d0c",
@@ -53,5 +54,26 @@ def test_two_questions(self):
5354
self.assertEqual(2, len(convert_pdf_to_instruments(pdf_gad_7_2_questions)[0].questions))
5455

5556

57+
class TestTokenGroupingByClass(unittest.TestCase):
    def test_multiple_questions_answers_others(self):
        """Two identical question/answer/other sequences must yield two spans per class."""
        # One questionnaire item: a four-token question, a one-token answer, six "other" tokens.
        item_tokens = ['▁How', '▁are', '▁you', '?',
                       "▁5",
                       ".", "▁lore", "m", "▁ipsum", "▁dolor", "."]
        item_classes = ["question"] * 4 + ["answer"] + ["other"] * 6

        # The second item is identical except for its answer token.
        second_item_tokens = list(item_tokens)
        second_item_tokens[4] = "▁8"

        input_tokens = item_tokens + second_item_tokens
        input_classes = item_classes * 2

        expected_output = {
            "question": ["How are you?", "How are you?"],
            "answer": ["5", "8"],
            "other": [". lorem ipsum dolor.", ". lorem ipsum dolor."],
        }
        self.assertDictEqual(expected_output, group_token_spans_by_class(input_tokens, input_classes))
76+
77+
5678
if __name__ == '__main__':
5779
unittest.main()

tests/test_create_instrument_from_list.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,25 @@
3636
class TestCreateInstrument(unittest.TestCase):
3737

3838
def test_single_instrument_simple(self):
39-
instrument = create_instrument_from_list(["question A", "question B"])
39+
instrument = create_instrument_from_list(["question A", "question B"], [])
4040
self.assertEqual(2, len(instrument.questions))
4141

4242
def test_single_instrument_simple_2(self):
43-
instrument = create_instrument_from_list(["question A", "question B", "question C"], instrument_name="potato")
43+
instrument = create_instrument_from_list(["question A", "question B", "question C"], [],
44+
instrument_name="potato")
4445
self.assertEqual(3, len(instrument.questions))
4546
self.assertEqual("potato", instrument.instrument_name)
4647

48+
def test_single_instrument_with_answers(self):
49+
instrument = create_instrument_from_list(["question A", "question B", "question C"], ["Never", "Rarely", "Less than 2 times a week", "Everyday"],
50+
instrument_name="potato")
51+
self.assertEqual(3, len(instrument.questions))
52+
self.assertEqual(4, len(instrument.questions[0].options))
53+
self.assertEqual(4, len(instrument.questions[1].options))
54+
self.assertEqual(4, len(instrument.questions[2].options))
55+
self.assertEqual("potato", instrument.instrument_name)
4756
def test_single_instrument_send_to_web(self):
48-
instrument = create_instrument_from_list(["question A", "question B"])
57+
instrument = create_instrument_from_list(["question A", "question B"], [])
4958
web_url = import_instrument_into_harmony_web(instrument)
5059
self.assertIn("harmonydata.ac.uk", web_url)
5160

tests/test_crosswalk.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
class TestGenerateCrosswalkTable(unittest.TestCase):
4242
def setUp(self):
4343
# Sample data
44-
self.instruments_dummy = [create_instrument_from_list(["potato", "tomato", "radish"], instrument_name="veg")]
44+
self.instruments_dummy = [
45+
create_instrument_from_list(["potato", "tomato", "radish"], [], instrument_name="veg")]
4546

4647
self.similarity = np.array([
4748
[1.0, 0.7, 0.9],
@@ -51,6 +52,7 @@ def setUp(self):
5152

5253
self.instruments = [create_instrument_from_list(
5354
["Feeling nervous, anxious, or on edge", "Not being able to stop or control worrying"],
55+
[],
5456
instrument_name="GAD-7")]
5557

5658
self.threshold = 0.6
@@ -101,9 +103,10 @@ def test_generate_crosswalk_table_real(self):
101103

102104
def test_crosswalk_two_instruments_allow_many_to_one_matches(self):
103105

104-
instrument_1 = create_instrument_from_list(["I felt fearful."])
106+
instrument_1 = create_instrument_from_list(["I felt fearful."], [])
105107
instrument_2 = create_instrument_from_list(
106-
["Feeling afraid, as if something awful might happen", "Feeling nervous, anxious, or on edge"])
108+
["Feeling afraid, as if something awful might happen", "Feeling nervous, anxious, or on edge"],
109+
[])
107110
instruments = [instrument_1, instrument_2]
108111

109112
match_response = match_instruments(instruments)
@@ -114,9 +117,10 @@ def test_crosswalk_two_instruments_allow_many_to_one_matches(self):
114117

115118
def test_crosswalk_two_instruments_enforce_one_to_one_matches(self):
116119

117-
instrument_1 = create_instrument_from_list(["I felt fearful."])
120+
instrument_1 = create_instrument_from_list(["I felt fearful."], [])
118121
instrument_2 = create_instrument_from_list(
119-
["Feeling afraid, as if something awful might happen", "Feeling nervous, anxious, or on edge"])
122+
["Feeling afraid, as if something awful might happen", "Feeling nervous, anxious, or on edge"],
123+
[])
120124
instruments = [instrument_1, instrument_2]
121125

122126
match_response = match_instruments(instruments)

tests/test_deterministic_clustering.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@
3030

3131
sys.path.append("../src")
3232

33-
from harmony import match_instruments, create_instrument_from_list, find_clusters_deterministic
34-
from harmony.schemas.requests.text import Instrument, Question
33+
from harmony import create_instrument_from_list, find_clusters_deterministic
3534
import numpy as np
3635

37-
3836
if __name__ == '__main__':
3937
unittest.main()
4038

@@ -43,17 +41,21 @@ class TestDeterministicClustering(unittest.TestCase):
4341

4442
def test_two_questions_one_cluster(self):
4543
questions = create_instrument_from_list(
46-
["Feeling nervous, anxious, or on edge", "Not being able to stop or control worrying"]).questions
44+
["Feeling nervous, anxious, or on edge", "Not being able to stop or control worrying"],
45+
[]).questions
4746
item_to_item_similarity_matrix = np.eye(2) / 2 + np.ones((2, 2)) / 2
4847
clusters = find_clusters_deterministic(questions, item_to_item_similarity_matrix)
4948
self.assertEqual(1, len(clusters))
5049

5150
def test_three_questions_one_cluster(self):
5251
questions = create_instrument_from_list(
53-
["Feeling nervous, anxious, or on edge", "Not being able to stop or control worrying", "Worrying too much about different things"]).questions
52+
["Feeling nervous, anxious, or on edge", "Not being able to stop or control worrying",
53+
"Worrying too much about different things"],
54+
[]).questions
5455
item_to_item_similarity_matrix = np.eye(3) / 2 + np.ones((3, 3)) / 2
5556
clusters = find_clusters_deterministic(questions, item_to_item_similarity_matrix)
5657
self.assertEqual(1, len(clusters))
5758

59+
5860
if __name__ == '__main__':
5961
unittest.main()

0 commit comments

Comments
 (0)