Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/codegate/pipeline/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def _record_to_db(self):
await self._db_recorder.record_context(self._input_context)

async def process_stream(
self, stream: AsyncIterator[ModelResponse]
self, stream: AsyncIterator[ModelResponse], cleanup_sensitive: bool = True
) -> AsyncIterator[ModelResponse]:
"""
Process a stream through all pipeline steps
Expand Down Expand Up @@ -182,7 +182,7 @@ async def process_stream(
self._context.buffer.clear()

# Cleanup sensitive data through the input context
if self._input_context and self._input_context.sensitive:
if cleanup_sensitive and self._input_context and self._input_context.sensitive:
self._input_context.sensitive.secure_cleanup()


Expand Down
64 changes: 45 additions & 19 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,15 @@ def __init__(self, proxy: CopilotProvider):
self.stream_queue: Optional[asyncio.Queue] = None
self.processing_task: Optional[asyncio.Task] = None

self.finish_stream = False

# For debugging only
# self.data_sent = []

def connection_made(self, transport: asyncio.Transport) -> None:
"""Handle successful connection to target"""
self.transport = transport
logger.debug(f"Target transport peer: {transport.get_extra_info('peername')}")
logger.debug(f"Connection established to target: {transport.get_extra_info('peername')}")
self.proxy.target_transport = transport

def _ensure_output_processor(self) -> None:
Expand Down Expand Up @@ -737,7 +742,7 @@ async def _process_stream(self):
try:

async def stream_iterator():
while True:
while not self.stream_queue.empty():
incoming_record = await self.stream_queue.get()

record_content = incoming_record.get("content", {})
Expand All @@ -750,6 +755,9 @@ async def stream_iterator():
else:
content = choice.get("delta", {}).get("content")

if choice.get("finish_reason", None) == "stop":
self.finish_stream = True

streaming_choices.append(
StreamingChoices(
finish_reason=choice.get("finish_reason", None),
Expand All @@ -771,22 +779,18 @@ async def stream_iterator():
)
yield mr

async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
async for record in self.output_pipeline_instance.process_stream(
stream_iterator(), cleanup_sensitive=False
):
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
sse_data = f"data: {chunk}\n\n".encode("utf-8")
chunk_size = hex(len(sse_data))[2:] + "\r\n"
self._proxy_transport_write(chunk_size.encode())
self._proxy_transport_write(sse_data)
self._proxy_transport_write(b"\r\n")

sse_data = b"data: [DONE]\n\n"
# Add chunk size for DONE message too
chunk_size = hex(len(sse_data))[2:] + "\r\n"
self._proxy_transport_write(chunk_size.encode())
self._proxy_transport_write(sse_data)
self._proxy_transport_write(b"\r\n")
# Now send the final zero chunk
self._proxy_transport_write(b"0\r\n\r\n")
if self.finish_stream:
self.finish_data()

except asyncio.CancelledError:
logger.debug("Stream processing cancelled")
Expand All @@ -795,12 +799,37 @@ async def stream_iterator():
logger.error(f"Error processing stream: {e}")
finally:
# Clean up
self.stream_queue = None
if self.processing_task and not self.processing_task.done():
self.processing_task.cancel()
if self.proxy.context_tracking and self.proxy.context_tracking.sensitive:
self.proxy.context_tracking.sensitive.secure_cleanup()

def finish_data(self):
    """Terminate the streamed SSE response to the client.

    Sends the ``data: [DONE]`` sentinel event — chunked-encoded exactly like
    the regular records (hex size line, payload, CRLF) — followed by the
    zero-length chunk (``0\r\n\r\n``) that ends an HTTP chunked-transfer
    body, then resets the per-stream flags so this protocol instance can
    handle a subsequent streamed response.
    """
    logger.debug("Finishing data stream")
    sse_data = b"data: [DONE]\n\n"
    # Add chunk size for DONE message too
    chunk_size = hex(len(sse_data))[2:] + "\r\n"
    self._proxy_transport_write(chunk_size.encode())
    self._proxy_transport_write(sse_data)
    self._proxy_transport_write(b"\r\n")
    # Now send the final zero chunk
    self._proxy_transport_write(b"0\r\n\r\n")

    # For debugging only
    # print("===========START DATA SENT====================")
    # for data in self.data_sent:
    #     print(data)
    # self.data_sent = []
    # print("===========END DATA SENT======================")

    # Reset stream state: finish_stream was set when a "stop" finish_reason
    # arrived; headers_sent must be cleared so the next response re-emits
    # its headers.
    self.finish_stream = False
    self.headers_sent = False

def _process_chunk(self, chunk: bytes):
# For debugging only
# print("===========START DATA RECVD====================")
# print(chunk)
# print("===========END DATA RECVD======================")

records = self.sse_processor.process_chunk(chunk)

for record in records:
Expand All @@ -812,13 +841,12 @@ def _process_chunk(self, chunk: bytes):
self.stream_queue.put_nowait(record)

def _proxy_transport_write(self, data: bytes) -> None:
    """Write raw bytes to the client-facing proxy transport.

    If the transport is gone or already closing, the write is dropped and
    an error is logged instead of raising — presumably to avoid blowing up
    mid-stream when the client disconnects (NOTE(review): confirm callers
    tolerate silently dropped chunks).
    """
    # For debugging only
    # self.data_sent.append(data)
    if not self.proxy.transport or self.proxy.transport.is_closing():
        logger.error("Proxy transport not available")
        return
    self.proxy.transport.write(data)

def data_received(self, data: bytes) -> None:
"""Handle data received from target"""
Expand Down Expand Up @@ -848,15 +876,13 @@ def data_received(self, data: bytes) -> None:
logger.debug(f"Headers sent: {headers}")

data = data[header_end + 4 :]
# print("DEBUG =================================")
# print(data)
# print("DEBUG =================================")

self._process_chunk(data)

def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle connection loss to target"""

logger.debug("Lost connection to target")
if (
not self.proxy._closing
and self.proxy.transport
Expand Down
Loading