diff --git a/src/llm.py b/src/llm.py index 70937f9..4b5f7d2 100644 --- a/src/llm.py +++ b/src/llm.py @@ -4,132 +4,104 @@ class LLM: - def __init__(self, transcript_text=None, target_fields=None, json=None): - if json is None: - json = {} - self._transcript_text = transcript_text # str - self._target_fields = target_fields # List, contains the template field. - self._json = json # dictionary + def __init__(self, transcript_text=None, target_fields=None, json_data=None): + if json_data is None: + json_data = {} + + self._transcript_text = transcript_text + self._target_fields = target_fields + self._json = json_data def type_check_all(self): if type(self._transcript_text) is not str: raise TypeError( - f"ERROR in LLM() attributes ->\ - Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}" + f"Transcript must be text. Received: {self._transcript_text}" ) - elif type(self._target_fields) is not list: + + if type(self._target_fields) is not dict: raise TypeError( - f"ERROR in LLM() attributes ->\ - Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}" + f"Target fields must be a dictionary. Received: {self._target_fields}" ) - def build_prompt(self, current_field): - """ - This method is in charge of the prompt engineering. It creates a specific prompt for each target field. - @params: current_field -> represents the current element of the json that is being prompted. - """ - prompt = f""" - SYSTEM PROMPT: - You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. - You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return - only a single string containing the identified value for the JSON field. - If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". - If you don't identify the value in the provided text, return "-1". - --- - DATA: - Target JSON field to find in text: {current_field} - - TEXT: {self._transcript_text} - """ + def build_structured_prompt(self): + fields = list(self._target_fields.keys()) + schema = "\n".join(fields) + + prompt = f""" +You are an AI system that extracts structured information from incident reports. + +Return ONLY valid JSON matching the fields below. + +Fields: +{schema} + +Text: +{self._transcript_text} +""" return prompt def main_loop(self): - # self.type_check_all() - for field in self._target_fields.keys(): - prompt = self.build_prompt(field) - # print(prompt) - # ollama_url = "http://localhost:11434/api/generate" - ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") - ollama_url = f"{ollama_host}/api/generate" - - payload = { - "model": "mistral", - "prompt": prompt, - "stream": False, # don't really know why --> look into this later. - } - - try: - response = requests.post(ollama_url, json=payload) - response.raise_for_status() - except requests.exceptions.ConnectionError: - raise ConnectionError( - f"Could not connect to Ollama at {ollama_url}. " - "Please ensure Ollama is running and accessible." - ) - except requests.exceptions.HTTPError as e: - raise RuntimeError(f"Ollama returned an error: {e}") - - # parse response - json_data = response.json() - parsed_response = json_data["response"] - # print(parsed_response) - self.add_response_to_json(field, parsed_response) + self.type_check_all() + return self.structured_extraction() - print("----------------------------------") - print("\t[LOG] Resulting JSON created from the input text:") - print(json.dumps(self._json, indent=2)) - print("--------- extracted data ---------") + def structured_extraction(self): + print("[LLM] Running structured extraction") - return self + prompt = self.build_structured_prompt() - def add_response_to_json(self, field, value): - """ - this method adds the following value under the specified field, - or under a new field if the field doesn't exist, to the json dict - """ - value = value.strip().replace('"', "") - parsed_value = None + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + url = f"{ollama_host}/api/generate" - if value != "-1": - parsed_value = value + payload = { + "model": "mistral", + "prompt": prompt, + "stream": False + } - if ";" in value: - parsed_value = self.handle_plural_values(value) + try: + response = requests.post(url, json=payload, timeout=120) + response.raise_for_status() + except requests.exceptions.ConnectionError: + raise RuntimeError( + f"Could not connect to Ollama at {url}. Ensure Ollama is running." + ) + except requests.exceptions.Timeout: + raise RuntimeError("LLM request timed out.") + + result = response.json()["response"].strip() - if field in self._json.keys(): - self._json[field].append(parsed_value) - else: - self._json[field] = parsed_value + try: + parsed = json.loads(result) + except json.JSONDecodeError: + raise RuntimeError( + f"LLM returned invalid JSON:\n{result}" + ) - return + self._json = parsed + + print("----------------------------------") + print("[LOG] Resulting JSON created from the input text:") + print(json.dumps(self._json, indent=2)) + print("--------- extracted data ---------") + + return self def handle_plural_values(self, plural_value): - """ - This method handles plural values. - Takes in strings of the form 'value1; value2; value3; ...; valueN' - returns a list with the respective values -> [value1, value2, value3, ..., valueN] - """ if ";" not in plural_value: raise ValueError( - f"Value is not plural, doesn't have ; separator, Value: {plural_value}" + f"Value does not contain ';' separator: {plural_value}" ) print( - f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..." + f"[LOG] Formatting plural values for JSON (input: {plural_value})" ) - values = plural_value.split(";") - # Remove trailing leading whitespace - for i in range(len(values)): - current = i + 1 - if current < len(values): - clean_value = values[current].lstrip() - values[current] = clean_value + values = [v.strip() for v in plural_value.split(";")] - print(f"\t[LOG]: Resulting formatted list of values: {values}") + print(f"[LOG] Resulting formatted list: {values}") return values def get_data(self): - return self._json + return self._json \ No newline at end of file