Save and restore contexts

This commit is contained in:
James Ketr 2025-03-31 15:26:41 -07:00
parent 0938cc5b55
commit 2d5f6a2797
2 changed files with 59 additions and 4 deletions

2
src/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
sessions

View File

@ -333,7 +333,6 @@ def is_valid_uuid(value):
except (ValueError, TypeError): except (ValueError, TypeError):
return False return False
def default_tools(tools): def default_tools(tools):
return [{**tool, "enabled": True} for tool in tools] return [{**tool, "enabled": True} for tool in tools]
@ -394,7 +393,8 @@ class WebServer:
if not response: if not response:
return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"}) return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"})
else: else:
return JSONResponse(response); self.save_context(context_id)
return JSONResponse(response)
except: except:
return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"}) return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"})
@ -411,6 +411,7 @@ class WebServer:
if not system_prompt: if not system_prompt:
return JSONResponse({ "status": "error", "message": "System prompt can not be empty." }) return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
context["system"] = [{"role": "system", "content": system_prompt}] context["system"] = [{"role": "system", "content": system_prompt}]
self.save_context(context_id)
return JSONResponse({ "system-prompt": system_prompt }) return JSONResponse({ "system-prompt": system_prompt })
@self.app.get('/api/system-prompt/{context_id}') @self.app.get('/api/system-prompt/{context_id}')
@ -429,6 +430,8 @@ class WebServer:
async for message in self.chat(context=context, content=data['content']): async for message in self.chat(context=context, content=data['content']):
# Convert to JSON and add newline # Convert to JSON and add newline
yield json.dumps(message) + "\n" yield json.dumps(message) + "\n"
# Save the history as its generated
self.save_context(context_id)
# Explicitly flush after each yield # Explicitly flush after each yield
await asyncio.sleep(0) # Allow the event loop to process the write await asyncio.sleep(0) # Allow the event loop to process the write
@ -472,6 +475,7 @@ class WebServer:
for tool in context["tools"]: for tool in context["tools"]:
if modify == tool["function"]["name"]: if modify == tool["function"]["name"]:
tool["enabled"] = enabled tool["enabled"] = enabled
self.save_context(context_id)
return JSONResponse(context["tools"]) return JSONResponse(context["tools"])
return JSONResponse({ "status": f"{modify} not found in tools." }), 404 return JSONResponse({ "status": f"{modify} not found in tools." }), 404
except: except:
@ -495,6 +499,55 @@ class WebServer:
self.logging.info(f"Serve index.html for {path}") self.logging.info(f"Serve index.html for {path}")
return FileResponse('/opt/airc/src/ketr-chat/build/index.html') return FileResponse('/opt/airc/src/ketr-chat/build/index.html')
def save_context(self, session_id):
"""
Serialize a Python dictionary to a file in the sessions directory.
Args:
data: Dictionary containing the session data
session_id: UUID string for the context. If it doesn't exist, it is created
Returns:
The session_id used for the file
"""
context = self.upsert_context(session_id)
# Create sessions directory if it doesn't exist
if not os.path.exists("sessions"):
os.makedirs("sessions")
# Create the full file path
file_path = os.path.join("sessions", session_id)
# Serialize the data to JSON and write to file
with open(file_path, 'w') as f:
json.dump(context, f)
return session_id
def load_context(self, session_id):
"""
Load a serialized Python dictionary from a file in the sessions directory.
Args:
session_id: UUID string for the filename
Returns:
The deserialized dictionary, or a new context if it doesn't exist on disk.
"""
file_path = os.path.join("sessions", session_id)
# Check if the file exists
if not os.path.exists(file_path):
return self.create_context(session_id)
# Read and deserialize the data
with open(file_path, 'r') as f:
self.contexts[session_id] = json.load(f)
return self.contexts[session_id]
def create_context(self, context_id = None): def create_context(self, context_id = None):
if not context_id: if not context_id:
context_id = str(uuid.uuid4()) context_id = str(uuid.uuid4())
@ -517,7 +570,7 @@ class WebServer:
logging.info(f"Context {context_id} found.") logging.info(f"Context {context_id} found.")
return self.contexts[context_id] return self.contexts[context_id]
logging.info(f"Context {context_id} not found. Creating new context.") logging.info(f"Context {context_id} not found. Creating new context.")
return self.create_context(context_id) return self.load_context(context_id)
async def chat(self, context, content): async def chat(self, context, content):
content = content.strip() content = content.strip()