Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 68 additions & 96 deletions src/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,132 +4,104 @@


class LLM:
def __init__(self, transcript_text=None, target_fields=None, json=None):
if json is None:
json = {}
self._transcript_text = transcript_text # str
self._target_fields = target_fields # List, contains the template field.
self._json = json # dictionary
def __init__(self, transcript_text=None, target_fields=None, json_data=None):
if json_data is None:
json_data = {}

self._transcript_text = transcript_text
self._target_fields = target_fields
self._json = json_data

def type_check_all(self):
if type(self._transcript_text) is not str:
raise TypeError(
f"ERROR in LLM() attributes ->\
Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
f"Transcript must be text. Received: {self._transcript_text}"
)
elif type(self._target_fields) is not list:

if type(self._target_fields) is not dict:
raise TypeError(
f"ERROR in LLM() attributes ->\
Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
f"Target fields must be a dictionary. Received: {self._target_fields}"
)

def build_prompt(self, current_field):
"""
This method is in charge of the prompt engineering. It creates a specific prompt for each target field.
@params: current_field -> represents the current element of the json that is being prompted.
"""
prompt = f"""
SYSTEM PROMPT:
You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.
You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return
only a single string containing the identified value for the JSON field.
If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".
If you don't identify the value in the provided text, return "-1".
---
DATA:
Target JSON field to find in text: {current_field}

TEXT: {self._transcript_text}
"""
def build_structured_prompt(self):
fields = list(self._target_fields.keys())
schema = "\n".join(fields)

prompt = f"""
You are an AI system that extracts structured information from incident reports.

Return ONLY valid JSON matching the fields below.

Fields:
{schema}

Text:
{self._transcript_text}
"""

return prompt

def main_loop(self):
# self.type_check_all()
for field in self._target_fields.keys():
prompt = self.build_prompt(field)
# print(prompt)
# ollama_url = "http://localhost:11434/api/generate"
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

payload = {
"model": "mistral",
"prompt": prompt,
"stream": False, # don't really know why --> look into this later.
}

try:
response = requests.post(ollama_url, json=payload)
response.raise_for_status()
except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Could not connect to Ollama at {ollama_url}. "
"Please ensure Ollama is running and accessible."
)
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}")

# parse response
json_data = response.json()
parsed_response = json_data["response"]
# print(parsed_response)
self.add_response_to_json(field, parsed_response)
self.type_check_all()
return self.structured_extraction()

print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
print(json.dumps(self._json, indent=2))
print("--------- extracted data ---------")
def structured_extraction(self):
print("[LLM] Running structured extraction")

return self
prompt = self.build_structured_prompt()

def add_response_to_json(self, field, value):
"""
this method adds the following value under the specified field,
or under a new field if the field doesn't exist, to the json dict
"""
value = value.strip().replace('"', "")
parsed_value = None
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
url = f"{ollama_host}/api/generate"

if value != "-1":
parsed_value = value
payload = {
"model": "mistral",
"prompt": prompt,
"stream": False
}

if ";" in value:
parsed_value = self.handle_plural_values(value)
try:
response = requests.post(url, json=payload, timeout=120)
response.raise_for_status()
except requests.exceptions.ConnectionError:
raise RuntimeError(
f"Could not connect to Ollama at {url}. Ensure Ollama is running."
)
except requests.exceptions.Timeout:
raise RuntimeError("LLM request timed out.")

result = response.json()["response"].strip()

if field in self._json.keys():
self._json[field].append(parsed_value)
else:
self._json[field] = parsed_value
try:
parsed = json.loads(result)
except json.JSONDecodeError:
raise RuntimeError(
f"LLM returned invalid JSON:\n{result}"
)

return
self._json = parsed

print("----------------------------------")
print("[LOG] Resulting JSON created from the input text:")
print(json.dumps(self._json, indent=2))
print("--------- extracted data ---------")

return self

def handle_plural_values(self, plural_value):
"""
This method handles plural values.
Takes in strings of the form 'value1; value2; value3; ...; valueN'
returns a list with the respective values -> [value1, value2, value3, ..., valueN]
"""
if ";" not in plural_value:
raise ValueError(
f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
f"Value does not contain ';' separator: {plural_value}"
)

print(
f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..."
f"[LOG] Formatting plural values for JSON (input: {plural_value})"
)
values = plural_value.split(";")

# Remove trailing leading whitespace
for i in range(len(values)):
current = i + 1
if current < len(values):
clean_value = values[current].lstrip()
values[current] = clean_value
values = [v.strip() for v in plural_value.split(";")]

print(f"\t[LOG]: Resulting formatted list of values: {values}")
print(f"[LOG] Resulting formatted list: {values}")

return values

def get_data(self):
return self._json
return self._json