
 '''

-import pathlib
-import pickle as pkl
-import re
+import os
+
+import torch
+from transformers import AutoModelForTokenClassification, AutoTokenizer

 import harmony
-from harmony.parsing.util.feature_extraction import convert_text_to_features
 from harmony.parsing.util.tika_wrapper import parse_pdf_to_plain_text
 from harmony.schemas.requests.text import RawFile, Instrument

-model_containing_folder = pathlib.Path(__file__).parent.resolve()
+# Disable tokenizer parallelism
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def group_token_spans_by_class(tokens, classes,
+                               tokenizer=AutoTokenizer.from_pretrained("harmonydata/debertaV2_pdfparser")) -> dict:
43+ """
44+ Given a list of tokens, and a list of predicted classes
45+ for each token, create a dictionary to hold each
46+ span of tokens.
47+ Example:
48+ > group_token_spans_by_classes(['▁how', '▁are', '▁you', '?', '▁1'],
49+ ['question', 'question', 'question', 'question', 'answer'],
50+ bert_tokenizer)
51+ > {"question":["How are you?"], "answer":["1"]}
52+ Notice that some tokens begin with ▁ (ASCII 9601) instead of _ (ASCII 95)
53+ :param tokens: List of tokens
54+ :type tokens: List[str]
55+ :param classes: List of predicted classes
56+ :type classes: List[str]
57+ :param tokenizer: Tokenizer (defaulted to harmonydata/debertaV2_pdfparser)
58+ :return: Dictionary of each span relative to its class
59+ """
+    grouped_spans = {"answer": [], "question": [], "other": []}
+    span = []
+    prev_cls = None
+
+    for token, cls in zip(tokens, classes):
+        if cls != prev_cls and span:
+            grouped_spans[prev_cls].append(tokenizer.convert_tokens_to_string(span))
+            span = []
+        span.append(token)
+        prev_cls = cls
+    # Append the final span to its class's list
+    if span:
+        grouped_spans[prev_cls].append(tokenizer.convert_tokens_to_string(span))
+
+    return grouped_spans

-with open(f"{model_containing_folder}/20240719_pdf_question_extraction_sklearn_crf_model.pkl", "rb") as f:
-    crf_text_model = pkl.load(f)

-# Predict method is taken from the training repo. Use the training repo as the master copy of the predict method.
-# All training code is in https://github.com/harmonydata/pdf-questionnaire-extraction
 def predict(test_text):
-    token_texts, token_start_char_indices, token_end_char_indices, token_properties = convert_text_to_features(
-        test_text)
-
-    X = []
-    X.append(token_properties)
-
-    y_pred = crf_text_model.predict(X)
+    # Load fine-tuned huggingface model and tokenizer
+    model = AutoModelForTokenClassification.from_pretrained("harmonydata/debertaV2_pdfparser")
+    tokenizer = AutoTokenizer.from_pretrained("harmonydata/debertaV2_pdfparser")

-    questions_from_text = []
+    # Tokenize input text
+    tokenized_texts = tokenizer(test_text, return_tensors="pt")

-    tokens_already_used = set()
+    # Inference with tokenized input text
+    with torch.no_grad():
+        logits = model(**tokenized_texts).logits

-    last_token_category = "O"
+    # Retrieve predicted class for each token
+    predictions = torch.argmax(logits, dim=2)
+    predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]

-    for idx in range(len(X[0])):
+    # Get input IDs (tensor) and convert to list
+    input_ids = tokenized_texts["input_ids"][0].tolist()
+    # Convert input IDs to tokens
+    decoded_tokenized_texts = tokenizer.convert_ids_to_tokens(input_ids)

-        if y_pred[0][idx] != "O" and idx not in tokens_already_used:
-            if last_token_category == "O" or y_pred[0][idx] == "B":
-                start_idx = token_start_char_indices[idx]
-                end_idx = len(test_text)
-                for j in range(idx + 1, len(X[0])):
-                    if y_pred[0][j] == "O" or y_pred[0][j] == "B":
-                        end_idx = token_end_char_indices[j - 1]
-                        break
-                    tokens_already_used.add(j)
+    # Remove the leading [CLS] and trailing [SEP] tokens from the decoded
+    # tokens and from the list of predictions
+    predicted_token_class = predicted_token_class[1:-1]
+    decoded_tokenized_texts = decoded_tokenized_texts[1:-1]

-                question_text = test_text[start_idx:end_idx]
-                question_text = re.sub(r'\s+', ' ', question_text)
-                question_text = question_text.strip()
-                questions_from_text.append(question_text)
+    grouped_tokens = group_token_spans_by_class(decoded_tokenized_texts, predicted_token_class, tokenizer)

-        last_token_category = y_pred[0][idx]
-
-    return questions_from_text
+    return grouped_tokens


 def convert_pdf_to_instruments(file: RawFile) -> Instrument:
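
For reference, here is a minimal, self-contained sketch of the span-grouping logic introduced in the hunk above. The StubTokenizer class is hypothetical, a stand-in for the real harmonydata/debertaV2_pdfparser tokenizer whose convert_tokens_to_string joins SentencePiece tokens and maps the ▁ word-boundary marker (U+2581) back to spaces, so the snippet runs without downloading the model; the helper body is reproduced from the diff to keep the sketch self-contained.

# Hypothetical stand-in for the model's SentencePiece tokenizer:
# convert_tokens_to_string() joins tokens and maps the ▁ marker (U+2581)
# back to ordinary spaces.
class StubTokenizer:
    def convert_tokens_to_string(self, tokens):
        return "".join(tokens).replace("\u2581", " ").strip()


def group_token_spans_by_class(tokens, classes, tokenizer):
    # Same logic as the committed helper: accumulate consecutive
    # same-class tokens into spans, flushing each span when the class changes.
    grouped_spans = {"answer": [], "question": [], "other": []}
    span, prev_cls = [], None
    for token, cls in zip(tokens, classes):
        if cls != prev_cls and span:
            grouped_spans[prev_cls].append(tokenizer.convert_tokens_to_string(span))
            span = []
        span.append(token)
        prev_cls = cls
    if span:
        grouped_spans[prev_cls].append(tokenizer.convert_tokens_to_string(span))
    return grouped_spans


tokens = ["\u2581how", "\u2581are", "\u2581you", "?", "\u25811"]
classes = ["question", "question", "question", "question", "answer"]
print(group_token_spans_by_class(tokens, classes, StubTokenizer()))
# {'answer': ['1'], 'question': ['how are you?'], 'other': []}
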
@@ -87,8 +114,12 @@ def convert_pdf_to_instruments(file: RawFile) -> Instrument:
     if not file.text_content:
         file.text_content = parse_pdf_to_plain_text(file.content)  # call Tika to convert the PDF to plain text

-    questions_from_text = predict(file.text_content)
+    # Run prediction to extract questions and answers from the file's text content
+    questions_answers_from_text = predict(file.text_content)
+
+    questions_from_text = questions_answers_from_text["question"]
+    answers_from_text = questions_answers_from_text["answer"]

-    instrument = harmony.create_instrument_from_list(questions_from_text, instrument_name=file.file_name,
+    instrument = harmony.create_instrument_from_list(questions_from_text, answers_from_text, instrument_name=file.file_name,
                                                      file_name=file.file_name)
     return [instrument]
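
With this change, predict() returns a dict of class-keyed spans instead of a flat list of question strings, and convert_pdf_to_instruments pulls the question and answer lists out separately. A rough usage sketch of the new contract (the first call downloads harmonydata/debertaV2_pdfparser from the Hugging Face Hub; the spans shown in comments are illustrative assumptions, not actual model output):

# Hypothetical caller exercising the new predict() return shape.
spans = predict("How often do you feel nervous? 1 Not at all 2 Several days")
questions = spans["question"]  # e.g. ["How often do you feel nervous?"]
answers = spans["answer"]      # e.g. ["1 Not at all", "2 Several days"]
# Both lists are then passed to harmony.create_instrument_from_list()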