Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ flask
commonforms
fastapi
uvicorn
pydantic
pydantic>=2.0.0
sqlmodel
pytest
httpx
Expand Down
181 changes: 85 additions & 96 deletions src/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import requests
from pydantic import create_model, Field


class LLM:
Expand All @@ -17,64 +18,100 @@ def type_check_all(self):
f"ERROR in LLM() attributes ->\
Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
)
elif type(self._target_fields) is not list:
# Updated to handle both list and dict based on earlier usage
if not isinstance(self._target_fields, (list, dict)):
raise TypeError(
f"ERROR in LLM() attributes ->\
Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
Target fields must be a list or dict. Input:\n\ttarget_fields: {self._target_fields}"
)

def build_prompt(self, current_field):
"""
This method is in charge of the prompt engineering. It creates a specific prompt for each target field.
@params: current_field -> represents the current element of the json that is being prompted.
"""
def _get_fields_iterable(self):
"""Helper to safely iterate over fields whether they are passed as a list or a dict."""
if isinstance(self._target_fields, dict):
return list(self._target_fields.keys())
return self._target_fields

def build_schema(self):
"""Dynamically generates a Pydantic schema based on target fields."""
field_definitions = {}
for field in self._get_fields_iterable():
clean_name = field.replace(" ", "_").replace("'", "").replace("-", "_").replace("/", "_")
if clean_name and clean_name[0].isdigit():
clean_name = "f_" + clean_name

description = f"Extract the value for '{field}'. If not found, return an empty string. If multiple values, separate by ';'."
field_definitions[clean_name] = (str, Field(default="", description=description))

DynamicModel = create_model('FormExtraction', **field_definitions)
return DynamicModel.model_json_schema()

def map_schema_to_json(self, extracted_data):
"""Maps the strictly typed LLM output back to the original PDF field names."""
for original_field in self._get_fields_iterable():
clean_name = original_field.replace(" ", "_").replace("'", "").replace("-", "_").replace("/", "_")
if clean_name and clean_name[0].isdigit():
clean_name = "f_" + clean_name

if clean_name in extracted_data:
raw_value = extracted_data[clean_name]

if not raw_value or raw_value == "-1":
continue

if ";" in raw_value:
values = [v.strip() for v in raw_value.split(";") if v.strip()]
self._json[original_field] = values
else:
self._json[original_field] = raw_value

def main_loop(self):
self.type_check_all()

fields_list = self._get_fields_iterable()
print(f"\t[LOG] Extracting {len(fields_list)} fields using Pydantic structured output...")

schema = self.build_schema()

prompt = f"""
SYSTEM PROMPT:
You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.
You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return
only a single string containing the identified value for the JSON field.
If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".
If you don't identify the value in the provided text, return "-1".
You will receive the transcription and must extract the values for all fields defined in the JSON schema.
Return ONLY a valid JSON object matching the provided schema. No markdown, no extra text.
If you don't identify the value in the provided text, leave it as an empty string.
---
DATA:
Target JSON field to find in text: {current_field}

TEXT: {self._transcript_text}
"""

return prompt

def main_loop(self):
# self.type_check_all()
for field in self._target_fields.keys():
prompt = self.build_prompt(field)
# print(prompt)
# ollama_url = "http://localhost:11434/api/generate"
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

payload = {
"model": "mistral",
"prompt": prompt,
"stream": False, # don't really know why --> look into this later.
}

try:
response = requests.post(ollama_url, json=payload)
response.raise_for_status()
except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Could not connect to Ollama at {ollama_url}. "
"Please ensure Ollama is running and accessible."
)
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}")

# parse response
json_data = response.json()
parsed_response = json_data["response"]
# print(parsed_response)
self.add_response_to_json(field, parsed_response)
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

payload = {
"model": "mistral",
"prompt": prompt,
"format": schema,
"stream": False,
}

try:
response = requests.post(ollama_url, json=payload)
response.raise_for_status()
except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Could not connect to Ollama at {ollama_url}. "
"Please ensure Ollama is running and accessible."
)
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}")

# parse response
json_data = response.json()
try:
parsed_response = json.loads(json_data["response"])
except json.JSONDecodeError:
print("\t[ERROR] LLM did not return valid JSON. Defaulting to empty extraction.")
parsed_response = {}

self.map_schema_to_json(parsed_response)

print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
Expand All @@ -83,53 +120,5 @@ def main_loop(self):

return self

def add_response_to_json(self, field, value):
"""
this method adds the following value under the specified field,
or under a new field if the field doesn't exist, to the json dict
"""
value = value.strip().replace('"', "")
parsed_value = None

if value != "-1":
parsed_value = value

if ";" in value:
parsed_value = self.handle_plural_values(value)

if field in self._json.keys():
self._json[field].append(parsed_value)
else:
self._json[field] = parsed_value

return

def handle_plural_values(self, plural_value):
"""
This method handles plural values.
Takes in strings of the form 'value1; value2; value3; ...; valueN'
returns a list with the respective values -> [value1, value2, value3, ..., valueN]
"""
if ";" not in plural_value:
raise ValueError(
f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
)

print(
f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..."
)
values = plural_value.split(";")

# Remove trailing leading whitespace
for i in range(len(values)):
current = i + 1
if current < len(values):
clean_value = values[current].lstrip()
values[current] = clean_value

print(f"\t[LOG]: Resulting formatted list of values: {values}")

return values

def get_data(self):
return self._json
return self._json