diff --git a/requirements.txt b/requirements.txt
index eaa6c81..6dac9b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ flask
 commonforms
 fastapi
 uvicorn
-pydantic
+pydantic>=2.0.0
 sqlmodel
 pytest
 httpx
diff --git a/src/llm.py b/src/llm.py
index 70937f9..9f003f7 100644
--- a/src/llm.py
+++ b/src/llm.py
@@ -1,6 +1,7 @@
 import json
 import os
 import requests
+from pydantic import create_model, Field
 
 
 class LLM:
@@ -17,64 +18,139 @@ def type_check_all(self):
                 f"ERROR in LLM() attributes ->\
                 Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
             )
-        elif type(self._target_fields) is not list:
+        # Target fields may be a list of names or a dict keyed by field name.
+        if not isinstance(self._target_fields, (list, dict)):
             raise TypeError(
                 f"ERROR in LLM() attributes ->\
-                Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
+                Target fields must be a list or dict. Input:\n\ttarget_fields: {self._target_fields}"
             )
 
-    def build_prompt(self, current_field):
-        """
-        This method is in charge of the prompt engineering. It creates a specific prompt for each target field.
-        @params: current_field -> represents the current element of the json that is being prompted.
-        """
+    def _get_fields_iterable(self):
+        """Return the target field names as a list, whether given as a list or a dict."""
+        if isinstance(self._target_fields, dict):
+            return list(self._target_fields)
+        return self._target_fields
+
+    @staticmethod
+    def _clean_field_name(field):
+        """Turn a PDF field name into the key used for it in the JSON schema.
+
+        Spaces, hyphens and slashes become underscores, apostrophes are dropped,
+        and a name that would start with a digit gets an "f_" prefix.
+        """
+        # NOTE(review): distinct names can collapse to the same key
+        # ("first name" / "first-name"); they then share one extracted value.
+        clean_name = str(field).replace(" ", "_").replace("'", "").replace("-", "_").replace("/", "_")
+        if clean_name and clean_name[0].isdigit():
+            clean_name = "f_" + clean_name
+        return clean_name
+
+    def build_schema(self):
+        """Build a JSON schema with one optional string property per target field.
+
+        The schema comes from a pydantic model created on the fly; main_loop
+        passes it to Ollama as the "format" of the request.
+        """
+        field_definitions = {}
+        for field in self._get_fields_iterable():
+            clean_name = self._clean_field_name(field)
+            if not clean_name:
+                # Nothing usable is left of the name, so it cannot be a schema key.
+                continue
+
+            description = f"Extract the value for '{field}'. If not found, return an empty string. If multiple values, separate by ';'."
+            field_definitions[clean_name] = (str, Field(default="", description=description))
+
+        DynamicModel = create_model("FormExtraction", **field_definitions)
+        return DynamicModel.model_json_schema()
+
+    def map_schema_to_json(self, extracted_data):
+        """Copy the model's answers into self._json under the original PDF field names.
+
+        Empty answers and the "-1" not-found marker of the previous prompt are
+        skipped. An answer containing ';' is stored as a list of its stripped parts.
+        """
+        if not isinstance(extracted_data, dict):
+            # json.loads can just as well return a list, a string or a number.
+            return
+
+        for original_field in self._get_fields_iterable():
+            raw_value = extracted_data.get(self._clean_field_name(original_field))
+            if raw_value is None:
+                continue
+
+            # The schema asks for strings, but the model is not guaranteed to
+            # comply; normalise so the ';' test below cannot raise TypeError.
+            if isinstance(raw_value, (list, tuple)):
+                raw_value = ";".join(str(item) for item in raw_value)
+            raw_value = str(raw_value).strip()
+            if not raw_value or raw_value == "-1":
+                continue
+
+            if ";" in raw_value:
+                values = [v.strip() for v in raw_value.split(";") if v.strip()]
+                if values:
+                    self._json[original_field] = values
+            else:
+                self._json[original_field] = raw_value
+
+    def main_loop(self):
+        """Extract all target fields from the transcript with a single Ollama request.
+
+        Raises:
+            ConnectionError: Ollama cannot be reached.
+            RuntimeError: Ollama times out or answers with an HTTP error status.
+        """
+        self.type_check_all()
+
+        fields_list = self._get_fields_iterable()
+        print(f"\t[LOG] Extracting {len(fields_list)} fields using Pydantic structured output...")
+
+        schema = self.build_schema()
+
         prompt = f"""
             SYSTEM PROMPT:
             You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.
-            You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return
-            only a single string containing the identified value for the JSON field.
-            If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".
-            If you don't identify the value in the provided text, return "-1".
+            You will receive the transcription and must extract the values for all fields defined in the JSON schema.
+            Return ONLY a valid JSON object matching the provided schema. No markdown, no extra text.
+            If you don't identify the value in the provided text, leave it as an empty string.
             ---
-            DATA:
-            Target JSON field to find in text: {current_field}
-
             TEXT: {self._transcript_text}
             """
 
-        return prompt
-
-    def main_loop(self):
-        # self.type_check_all()
-        for field in self._target_fields.keys():
-            prompt = self.build_prompt(field)
-            # print(prompt)
-            # ollama_url = "http://localhost:11434/api/generate"
-            ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
-            ollama_url = f"{ollama_host}/api/generate"
-
-            payload = {
-                "model": "mistral",
-                "prompt": prompt,
-                "stream": False,  # don't really know why --> look into this later.
-            }
-
-            try:
-                response = requests.post(ollama_url, json=payload)
-                response.raise_for_status()
-            except requests.exceptions.ConnectionError:
-                raise ConnectionError(
-                    f"Could not connect to Ollama at {ollama_url}. "
-                    "Please ensure Ollama is running and accessible."
-                )
-            except requests.exceptions.HTTPError as e:
-                raise RuntimeError(f"Ollama returned an error: {e}")
-
-            # parse response
-            json_data = response.json()
-            parsed_response = json_data["response"]
-            # print(parsed_response)
-            self.add_response_to_json(field, parsed_response)
+        ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
+        ollama_url = f"{ollama_host}/api/generate"
+        # Without a timeout, requests waits indefinitely on a stalled server.
+        timeout_seconds = float(os.getenv("OLLAMA_TIMEOUT", "300"))
+
+        payload = {
+            "model": "mistral",
+            "prompt": prompt,
+            "format": schema,
+            "stream": False,
+        }
+
+        try:
+            response = requests.post(ollama_url, json=payload, timeout=timeout_seconds)
+            response.raise_for_status()
+        except requests.exceptions.ConnectionError as e:
+            raise ConnectionError(
+                f"Could not connect to Ollama at {ollama_url}. "
+                "Please ensure Ollama is running and accessible."
+            ) from e
+        except requests.exceptions.Timeout as e:
+            raise RuntimeError(f"Ollama did not answer within {timeout_seconds} seconds.") from e
+        except requests.exceptions.HTTPError as e:
+            raise RuntimeError(f"Ollama returned an error: {e}") from e
+
+        # The model output arrives as a JSON-encoded string under "response".
+        try:
+            parsed_response = json.loads(response.json()["response"])
+        except (KeyError, TypeError, ValueError):
+            print("\t[ERROR] LLM did not return valid JSON. Defaulting to empty extraction.")
+            parsed_response = {}
+
+        self.map_schema_to_json(parsed_response)
 
         print("----------------------------------")
         print("\t[LOG] Resulting JSON created from the input text:")
@@ -83,53 +159,5 @@ def main_loop(self):
 
         return self
 
-    def add_response_to_json(self, field, value):
-        """
-        this method adds the following value under the specified field,
-        or under a new field if the field doesn't exist, to the json dict
-        """
-        value = value.strip().replace('"', "")
-        parsed_value = None
-
-        if value != "-1":
-            parsed_value = value
-
-        if ";" in value:
-            parsed_value = self.handle_plural_values(value)
-
-        if field in self._json.keys():
-            self._json[field].append(parsed_value)
-        else:
-            self._json[field] = parsed_value
-
-        return
-
-    def handle_plural_values(self, plural_value):
-        """
-        This method handles plural values.
-        Takes in strings of the form 'value1; value2; value3; ...; valueN'
-        returns a list with the respective values -> [value1, value2, value3, ..., valueN]
-        """
-        if ";" not in plural_value:
-            raise ValueError(
-                f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
-            )
-
-        print(
-            f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..."
-        )
-        values = plural_value.split(";")
-
-        # Remove trailing leading whitespace
-        for i in range(len(values)):
-            current = i + 1
-            if current < len(values):
-                clean_value = values[current].lstrip()
-                values[current] = clean_value
-
-        print(f"\t[LOG]: Resulting formatted list of values: {values}")
-
-        return values
-
     def get_data(self):
         return self._json