diff --git a/src/backend.py b/src/backend.py index 6e45f24..2dfcc06 100644 --- a/src/backend.py +++ b/src/backend.py @@ -1,5 +1,6 @@ import json import os +import sys import requests from json_manager import JsonManager from input_manager import InputManager @@ -26,48 +27,64 @@ def type_check_all(self): def build_prompt(self, current_field): - """ - This method is in charge of the prompt engineering. It creates a specific prompt for each target field. + """ + This method is in charge of the prompt engineering. It creates a specific prompt for each target field. @params: current_field -> represents the current element of the json that is being prompted. """ - prompt = f""" + prompt = f""" SYSTEM PROMPT: - You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. - You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return - only a single string containing the identified value for the JSON field. - If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". - If you don't identify the value in the provided text, return "-1". + You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. + You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. + Return ONLY valid JSON with this exact format: {{"value": "", "confidence": <0.0_to_1.0>}} + The "confidence" field is a number from 0.0 to 1.0 indicating how confident you are in the extracted value. + If the field name is plural, and you identify more than one possible value in the text, separate them with ";" inside the "value" string. + If you don't identify the value in the provided text, return: {{"value": "-1", "confidence": 0.0}} --- DATA: Target JSON field to find in text: {current_field} - + TEXT: {self.__transcript_text} """ return prompt + def parse_llm_response(self, raw_response): + """ + Parses the LLM JSON response and extracts value and confidence. + Returns (value, confidence) tuple. + Falls back to (raw_text, 0.0) on parse failure. + """ + try: + data = json.loads(raw_response.strip()) + value = str(data.get("value", raw_response.strip())) + confidence = float(data.get("confidence", 0.0)) + confidence = max(0.0, min(1.0, confidence)) + return value, confidence + except (json.JSONDecodeError, ValueError, TypeError): + print(f"\t[WARNING] Failed to parse LLM response as JSON: {raw_response[:100]}", file=sys.stderr) + return raw_response.strip(), 0.0 + def main_loop(self): #FUTURE -> Refactor this to its own class for field in self.__target_fields: prompt = self.build_prompt(field) - # print(prompt) - # ollama_url = "http://localhost:11434/api/generate" ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") ollama_url = f"{ollama_host}/api/generate" payload = { "model": "mistral", "prompt": prompt, - "stream": False # don't really know why --> look into this later. + "stream": False, + "format": "json" } response = requests.post(ollama_url, json=payload) # parse response json_data = response.json() - parsed_response = json_data['response'] - # print(parsed_response) - self.add_response_to_json(field, parsed_response) - + raw_response = json_data['response'] + value, confidence = self.parse_llm_response(raw_response) + self.add_response_to_json(field, value, confidence) + print("----------------------------------") print("\t[LOG] Resulting JSON created from the input text:") print(json.dumps(self.__json, indent=2)) @@ -75,28 +92,35 @@ def main_loop(self): #FUTURE -> Refactor this to its own class return None - def add_response_to_json(self, field, value): - """ - this method adds the following value under the specified field, - or under a new field if the field doesn't exist, to the json dict + def add_response_to_json(self, field, value, confidence): + """ + this method adds the following value under the specified field, + or under a new field if the field doesn't exist, to the json dict. + Stores each field as {"value": parsed_value, "confidence": confidence}. """ value = value.strip().replace('"', '') parsed_value = None - plural = False - + if value != "-1": - parsed_value = value - + parsed_value = value + if ";" in value: parsed_value = self.handle_plural_values(value) - plural = True + if confidence < 0.5: + print(f"\t[WARNING] Low confidence ({confidence}) for field '{field}'", file=sys.stderr) + + entry = {"value": parsed_value, "confidence": confidence} if field in self.__json.keys(): - self.__json[field].append(parsed_value) - else: - self.__json[field] = parsed_value - + existing = self.__json[field] + if isinstance(existing, list): + existing.append(entry) + else: + self.__json[field] = [existing, entry] + else: + self.__json[field] = entry + return def handle_plural_values(self, plural_value): @@ -126,6 +150,21 @@ def handle_plural_values(self, plural_value): def get_data(self): return self.__json + def get_confidence_report(self): + """Returns a {field: confidence} dict for easy inspection by callers.""" + report = {} + for field, entry in self.__json.items(): + if isinstance(entry, dict) and "confidence" in entry: + report[field] = entry["confidence"] + elif isinstance(entry, list): + report[field] = [ + item["confidence"] if isinstance(item, dict) and "confidence" in item else 0.0 + for item in entry + ] + else: + report[field] = 0.0 + return report + class Fill(): def __init__(self): pass @@ -142,7 +181,10 @@ def fill_form(user_input: str, definitions: list, pdf_form: str): t2j = textToJSON(user_input, definitions) textbox_answers = t2j.get_data() # This is a dictionary - answers_list = list(textbox_answers.values()) + answers_list = [ + entry["value"] if isinstance(entry, dict) and "value" in entry else entry + for entry in textbox_answers.values() + ] # Read PDF pdf = PdfReader(pdf_form) diff --git a/src/test/test_confidence.py b/src/test/test_confidence.py new file mode 100644 index 0000000..c9467f3 --- /dev/null +++ b/src/test/test_confidence.py @@ -0,0 +1,151 @@ +""" +Unit tests for confidence score extraction in textToJSON. +Tests parse_llm_response(), add_response_to_json(), and get_confidence_report(). +""" +import json +import sys +import os +from unittest.mock import patch, MagicMock +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from backend import textToJSON + + +def make_t2j(fields=None): + """Create a textToJSON instance without triggering main_loop (which calls Ollama).""" + with patch.object(textToJSON, "main_loop", return_value=None): + return textToJSON("dummy text", fields or ["field1"], json={}) + + +class TestParseLlmResponse: + def test_valid_json(self): + t2j = make_t2j() + raw = '{"value": "John Doe", "confidence": 0.92}' + value, confidence = t2j.parse_llm_response(raw) + assert value == "John Doe" + assert confidence == 0.92 + + def test_not_found_response(self): + t2j = make_t2j() + raw = '{"value": "-1", "confidence": 0.0}' + value, confidence = t2j.parse_llm_response(raw) + assert value == "-1" + assert confidence == 0.0 + + def test_plural_values(self): + t2j = make_t2j() + raw = '{"value": "Alice; Bob; Charlie", "confidence": 0.85}' + value, confidence = t2j.parse_llm_response(raw) + assert value == "Alice; Bob; Charlie" + assert confidence == 0.85 + + def test_confidence_clamped_above_one(self): + t2j = make_t2j() + raw = '{"value": "test", "confidence": 1.5}' + value, confidence = t2j.parse_llm_response(raw) + assert value == "test" + assert confidence == 1.0 + + def test_confidence_clamped_below_zero(self): + t2j = make_t2j() + raw = '{"value": "test", "confidence": -0.3}' + value, confidence = t2j.parse_llm_response(raw) + assert value == "test" + assert confidence == 0.0 + + def test_malformed_json_fallback(self): + t2j = make_t2j() + raw = "just some plain text" + value, confidence = t2j.parse_llm_response(raw) + assert value == "just some plain text" + assert confidence == 0.0 + + def test_malformed_json_warns_to_stderr(self, capsys): + t2j = make_t2j() + t2j.parse_llm_response("not valid json") + captured = capsys.readouterr() + assert "[WARNING]" in captured.err + assert "Failed to parse" in captured.err + + def test_missing_value_key(self): + t2j = make_t2j() + raw = '{"confidence": 0.8}' + value, confidence = t2j.parse_llm_response(raw) + # Falls back to the raw string representation of missing key + assert confidence == 0.8 + + def test_missing_confidence_key(self): + t2j = make_t2j() + raw = '{"value": "hello"}' + value, confidence = t2j.parse_llm_response(raw) + assert value == "hello" + assert confidence == 0.0 + + def test_whitespace_padding(self): + t2j = make_t2j() + raw = ' {"value": "trimmed", "confidence": 0.7} ' + value, confidence = t2j.parse_llm_response(raw) + assert value == "trimmed" + assert confidence == 0.7 + + +class TestAddResponseToJson: + def test_stores_value_and_confidence(self): + t2j = make_t2j() + t2j.add_response_to_json("name", "John Doe", 0.92) + data = t2j.get_data() + assert data["name"] == {"value": "John Doe", "confidence": 0.92} + + def test_not_found_stores_none(self): + t2j = make_t2j() + t2j.add_response_to_json("missing_field", "-1", 0.0) + data = t2j.get_data() + assert data["missing_field"] == {"value": None, "confidence": 0.0} + + def test_plural_values_stored_as_list(self): + t2j = make_t2j() + t2j.add_response_to_json("items", "apple; banana; cherry", 0.85) + data = t2j.get_data() + assert data["items"]["value"] == ["apple", "banana", "cherry"] + assert data["items"]["confidence"] == 0.85 + + def test_low_confidence_warning(self, capsys): + t2j = make_t2j() + t2j.add_response_to_json("uncertain", "maybe", 0.3) + captured = capsys.readouterr() + assert "[WARNING]" in captured.err + assert "0.3" in captured.err + + def test_high_confidence_no_warning(self, capsys): + t2j = make_t2j() + t2j.add_response_to_json("certain", "definitely", 0.95) + captured = capsys.readouterr() + assert "[WARNING]" not in captured.err + + def test_duplicate_field_becomes_list(self): + t2j = make_t2j() + t2j.add_response_to_json("name", "John", 0.9) + t2j.add_response_to_json("name", "Jane", 0.8) + data = t2j.get_data() + assert isinstance(data["name"], list) + assert len(data["name"]) == 2 + assert data["name"][0] == {"value": "John", "confidence": 0.9} + assert data["name"][1] == {"value": "Jane", "confidence": 0.8} + + +class TestGetConfidenceReport: + def test_single_fields(self): + t2j = make_t2j() + t2j.add_response_to_json("name", "John", 0.92) + t2j.add_response_to_json("phone", "555-1234", 0.88) + report = t2j.get_confidence_report() + assert report == {"name": 0.92, "phone": 0.88} + + def test_list_field_reports_all_confidences(self): + t2j = make_t2j() + t2j.add_response_to_json("name", "John", 0.9) + t2j.add_response_to_json("name", "Jane", 0.7) + report = t2j.get_confidence_report() + assert report["name"] == [0.9, 0.7]