# Seconds a single LLM inference may run before the streaming endpoint aborts it.
LLM_TIMEOUT = 600

from utils import logger
from pydantic import BaseModel, Field, ValidationError  # type: ignore
from pydantic_core import PydanticSerializationError  # type: ignore
from typing import List
from typing import AsyncGenerator, Dict, Optional

# %%
# Imports [standard]
# Standard library modules (no try-except needed)
import argparse
import asyncio
import json
import logging
import os
import re
import uuid
import subprocess
import math
import warnings

# from typing import Any
import inspect
import time
import traceback


def try_import(module_name, pip_name=None):
    """Check that *module_name* is importable and print install help if not.

    Args:
        module_name: Name passed to ``__import__``.
        pip_name: Package name to suggest for ``pip install`` when it differs
            from the module name.

    Returns:
        True when the import succeeded, False otherwise. The ImportError is
        not re-raised; the later unconditional ``import`` will still fail,
        but only after the hint has been printed.
    """
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f"  pip install {pip_name or module_name}")
        return False
    return True


# Third-party modules with import checks
try_import("ollama")
try_import("requests")
try_import("fastapi")
try_import("uvicorn")
try_import("numpy")
try_import("umap")
try_import("sklearn")
try_import("prometheus_client")
try_import("prometheus_fastapi_instrumentator")

import ollama
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException  # type: ignore
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse  # type: ignore
from fastapi.middleware.cors import CORSMiddleware  # type: ignore
import uvicorn  # type: ignore
import numpy as np  # type: ignore

# Prometheus
from prometheus_client import Summary  # type: ignore
from prometheus_fastapi_instrumentator import Instrumentator  # type: ignore
from prometheus_client import CollectorRegistry, Counter  # type: ignore

from utils import (
    rag as Rag,
    RagEntry,
    tools as Tools,
    Context,
    Conversation,
    Message,
    Agent,
    Metrics,
    Tunables,
    defines,
    User,
    check_serializable,
    logger,
)


class Query(BaseModel):
    """Request body accepted by the agent POST endpoint."""

    # Text submitted to the agent.
    prompt: str
    # Per-request tunables; defaults to a fresh Tunables instance.
    tunables: Tunables = Field(default_factory=Tunables)
    # Extra keyword arguments forwarded when the agent is created; excluded
    # from serialization.
    agent_options: Dict[str, str | int | float | Dict] = Field(
        default_factory=dict, exclude=True
    )


REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")


def get_installed_ram():
    """Return the installed RAM as a string such as ``"31GB"``.

    Returns:
        The formatted size, ``None`` when ``/proc/meminfo`` has no
        ``MemTotal`` line, or an ``"Error retrieving RAM: ..."`` string when
        the file cannot be read.
    """
    try:
        with open("/proc/meminfo", "r") as f:
            meminfo = f.read()
        match = re.search(r"MemTotal:\s+(\d+)", meminfo)
        if match is None:
            return None
        return f"{math.floor(int(match.group(1)) / 1000**2)}GB"  # Convert KB to GB
    except Exception as e:
        return f"Error retrieving RAM: {e}"


def _parse_gpu_info(gpu_info):
    """Build a GPU description dict from the lines of ``ze-monitor --info``."""
    gpu = {
        "discrete": True,  # Assume it's discrete initially
        "name": None,
        "memory": None,
    }
    for line in gpu_info:
        match = re.match(r"^Device: [^(]*\((.*)\)", line)
        if match:
            gpu["name"] = match.group(1)
            continue
        match = re.match(r"^\s*Memory: (.*)", line)
        if match:
            gpu["memory"] = match.group(1)
            continue
        if re.match(r"^.*Is integrated with host: Yes.*", line):
            gpu["discrete"] = False
    return gpu


def get_graphics_cards():
    """Describe the GPUs reported by the ``ze-monitor`` utility.

    ``ze-monitor`` is run once to count devices (one per output line) and
    then once per device, numbered from 1, with ``--info``.

    Returns:
        A list of dicts with ``discrete``, ``name`` and ``memory`` keys, or
        an ``"Error retrieving GPU info: ..."`` string if any invocation
        fails.
    """
    try:
        listing = subprocess.run(
            ["ze-monitor"], capture_output=True, text=True, check=True
        )
        # Clean up the output (remove leading/trailing whitespace and newlines)
        device_count = len(listing.stdout.strip().splitlines())
        gpus = []
        for device_number in range(1, device_count + 1):
            info = subprocess.run(
                ["ze-monitor", "--device", f"{device_number}", "--info"],
                capture_output=True,
                text=True,
                check=True,
            )
            gpus.append(_parse_gpu_info(info.stdout.strip().splitlines()))
        return gpus
    except Exception as e:
        return f"Error retrieving GPU info: {e}"


def get_cpu_info():
    """Return the CPU model and logical processor count from ``/proc/cpuinfo``.

    Returns:
        ``"<model> with <n> cores"``, ``None`` when the expected fields are
        missing, or an ``"Error retrieving CPU info: ..."`` string when the
        file cannot be read.
    """
    try:
        with open("/proc/cpuinfo", "r") as f:
            cpuinfo = f.read()
        model_match = re.search(r"model name\s+:\s+(.+)", cpuinfo)
        cores_match = re.findall(r"processor\s+:\s+\d+", cpuinfo)
        if model_match and cores_match:
            return f"{model_match.group(1)} with {len(cores_match)} cores"
        return None
    except Exception as e:
        return f"Error retrieving CPU info: {e}"


def system_info(model):
    """Collect host and model details for the system-info endpoint.

    Args:
        model: Name of the LLM model in use.

    Returns:
        A dict of human-readable labels to values.
    """
    return {
        "System RAM": get_installed_ram(),
        "Graphics Card": get_graphics_cards(),
        "CPU": get_cpu_info(),
        "LLM Model": model,
        "Embedding Model": defines.embedding_model,
        "Context length": defines.max_context,
    }


# %%
# Defaults
OLLAMA_API_URL = defines.ollama_api_url
MODEL_NAME = defines.model
LOG_LEVEL = "info"
WEB_HOST = "0.0.0.0"
WEB_PORT = 8911 DEFAULT_HISTORY_LENGTH = 5 # %% # Cmd line overrides def parse_args(): parser = argparse.ArgumentParser(description="AI is Really Cool") parser.add_argument( "--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}", ) parser.add_argument( "--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}", ) parser.add_argument( "--web-host", type=str, default=WEB_HOST, help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.", ) parser.add_argument( "--web-port", type=str, default=WEB_PORT, help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.", ) parser.add_argument( "--level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}", ) return parser.parse_args() # %% # %% # %% def is_valid_uuid(value: str) -> bool: try: uuid_obj = uuid.UUID(value, version=4) return str(uuid_obj) == value except (ValueError, TypeError): return False # %% class WebServer: @asynccontextmanager async def lifespan(self, app: FastAPI): yield for user in self.users: if user.observer: user.observer.stop() user.observer.join() logger.info("File watcher stopped") def __init__(self, llm, model=MODEL_NAME): self.app = FastAPI(lifespan=self.lifespan) self.prometheus_collector = CollectorRegistry() self.metrics = Metrics(prometheus_collector=self.prometheus_collector) # Keep the Instrumentator instance alive self.instrumentator = Instrumentator(registry=self.prometheus_collector) # Instrument the FastAPI app self.instrumentator.instrument(self.app) # Expose the /metrics endpoint self.instrumentator.expose(self.app, endpoint="/metrics") self.llm = llm self.model = model self.processing = False self.users = [] self.contexts = {} self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists( defines.cert_path ) if self.ssl_enabled: 
allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"] else: allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"] logger.info(f"Allowed origins: {allow_origins}") self.app.add_middleware( CORSMiddleware, allow_origins=allow_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) self.setup_routes() def setup_routes(self): # @self.app.get("/") # async def root(): # context = self.create_context(username=defines.default_username) # logger.info(f"Redirecting non-context to {context.id}") # return RedirectResponse(url=f"/{context.id}", status_code=307) # # return JSONResponse({"redirect": f"/{context.id}"}) @self.app.get("/api/umap/entry/{doc_id}/{context_id}") async def get_umap(doc_id: str, context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: context = await self.upsert_context(context_id) if not context: return JSONResponse( {"error": f"Invalid context: {context_id}"}, status_code=400 ) user = context.user collection = user.umap_collection if not collection: return JSONResponse( {"error": "No UMAP collection found"}, status_code=404 ) if not collection.get("metadatas", None): return JSONResponse(f"Document id {doc_id} not found.", 404) for index, id in enumerate(collection.get("ids", [])): if id == doc_id: metadata = collection.get("metadatas", [])[index].copy() content = user.file_watcher.prepare_metadata(metadata) return JSONResponse(content) return JSONResponse(f"Document id {doc_id} not found.", 404) except Exception as e: logger.error(f"get_umap error: {str(e)}") logger.error(traceback.format_exc()) return JSONResponse({"error": str(e)}, 500) @self.app.put("/api/umap/{context_id}") async def put_umap(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: context = await self.upsert_context(context_id) if not context: return JSONResponse( {"error": f"Invalid context: 
{context_id}"}, status_code=400 ) user = context.user data = await request.json() dimensions = data.get("dimensions", 2) collection = user.file_watcher.umap_collection if not collection: return JSONResponse( {"error": "No UMAP collection found"}, status_code=404 ) if dimensions == 2: logger.info("Returning 2D UMAP") umap_embedding = user.file_watcher.umap_embedding_2d else: logger.info("Returning 3D UMAP") umap_embedding = user.file_watcher.umap_embedding_3d if len(umap_embedding) == 0: return JSONResponse( {"error": "No UMAP embedding found"}, status_code=404 ) result = { "ids": collection.get("ids", []), "metadatas": collection.get("metadatas", []), "documents": collection.get("documents", []), "embeddings": umap_embedding.tolist(), "size": user.file_watcher.collection.count() } return JSONResponse(result) except Exception as e: logger.error(traceback.format_exc()) logger.error(f"put_umap error: {str(e)}") return JSONResponse({"error": str(e)}, 500) @self.app.put("/api/similarity/{context_id}") async def put_similarity(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") context = await self.upsert_context(context_id) user = context.user try: data = await request.json() query = data.get("query", "") threshold = data.get("threshold", defines.default_rag_threshold) results = data.get("results", defines.default_rag_top_k) except: query = "" threshold = defines.default_rag_threshold results = defines.default_rag_top_k if not query: return JSONResponse( {"error": "No query provided for similarity search"}, status_code=400, ) try: chroma_results = user.file_watcher.find_similar( query=query, top_k=results, threshold=threshold ) if not chroma_results: return JSONResponse({"error": "No results found"}, status_code=404) chroma_embedding = np.array( chroma_results["query_embedding"] ).flatten() # Ensure correct shape logger.info(f"Chroma embedding shape: {chroma_embedding.shape}") umap_2d = 
user.file_watcher.umap_model_2d.transform([chroma_embedding])[ 0 ].tolist() logger.info( f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}" ) # Debug output umap_3d = user.file_watcher.umap_model_3d.transform([chroma_embedding])[ 0 ].tolist() logger.info( f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}" ) # Debug output return JSONResponse({ "distances": chroma_results["distances"], "ids": chroma_results["ids"], "metadatas": chroma_results["metadatas"], "query": query, "umap_embedding_2d": umap_2d, "umap_embedding_3d": umap_3d, "size": user.file_watcher.collection.count() }) except Exception as e: logger.error(e) logging.error(traceback.format_exc()) return JSONResponse({"error": str(e)}, 500) @self.app.put("/api/reset/{context_id}/{agent_type}") async def put_reset(context_id: str, agent_type: str, request: Request): logger.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = await self.upsert_context(context_id) agent = context.get_agent(agent_type) if not agent: response = { "history": [] } return JSONResponse(response) data = await request.json() try: response = {} for reset_operation in data["reset"]: match reset_operation: case "system_prompt": logger.info(f"Resetting {reset_operation}") case "rags": logger.info(f"Resetting {reset_operation}") context.rags = [ r.model_copy() for r in context.user.rags] response["rags"] = [ r.model_dump(mode="json") for r in context.rags ] case "tools": logger.info(f"Resetting {reset_operation}") context.tools = Tools.all_tools() response["tools"] = Tools.llm_tools(context.tools) case "history": reset_map = { "job_description": ( "job_description", "resume", "fact_check", ), "resume": ("job_description", "resume", "fact_check"), "fact_check": ( "job_description", "resume", "fact_check", ), "chat": ("chat",), } resets = reset_map.get(agent_type, ()) for mode in 
resets: tmp = context.get_agent(mode) if not tmp: logger.info( f"Agent {mode} not found for {context_id}" ) continue logger.info(f"Resetting {reset_operation} for {mode}") if mode != "chat": logger.info(f"Removing agent {tmp.agent_type} from {context_id}") context.remove_agent(tmp) tmp.conversation.reset() response["history"] = [] response["context_used"] = agent.context_tokens if not response: return JSONResponse( {"error": "Usage: { reset: rags|tools|history|system_prompt}"} ) else: await self.save_context(context_id) return JSONResponse(response) except Exception as e: logger.error(f"Error in reset: {e}") logger.error(traceback.format_exc()) return JSONResponse( {"error": "Usage: { reset: rags|tools|history|system_prompt}"} ) @self.app.put("/api/tunables/{context_id}") async def put_tunables(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: context = await self.upsert_context(context_id) data = await request.json() agent = context.get_agent("chat") if not agent: logger.info("chat agent does not exist on this context!") return JSONResponse( {"error": f"chat is not recognized", "context": context.id}, status_code=404, ) for k in data.keys(): match k: case "tools": from typing import Any # { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] } tools: List[Dict[str, Any]] = data[k] if not tools: return JSONResponse( { "status": "error", "message": "Tools can not be empty.", } ) for tool in tools: for context_tool in context.tools: if context_tool.tool.function.name == tool["name"]: context_tool.enabled = tool.get("enabled", True) await self.save_context(context_id) return JSONResponse({ "tools": [{ **t.function.model_dump(mode='json'), "enabled": t.enabled, } for t in context.tools] }) case "rags": from typing import Any # { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] } rag_configs: list[dict[str, Any]] = data[k] if not rag_configs: return JSONResponse( { "status": "error", "message": "RAGs can not be 
empty.", } ) for config in rag_configs: for context_rag in context.rags: if context_rag.name == config["name"]: context_rag.enabled = config["enabled"] await self.save_context(context_id) return JSONResponse({"rags": [ r.model_dump(mode="json") for r in context.rags]}) case "system_prompt": system_prompt = data[k].strip() if not system_prompt: return JSONResponse( { "status": "error", "message": "System prompt can not be empty.", } ) agent.system_prompt = system_prompt await self.save_context(context_id) return JSONResponse({"system_prompt": system_prompt}) case _: return JSONResponse( {"error": f"Unrecognized tunable {k}"}, status_code=404 ) except Exception as e: logger.error(f"Error in put_tunables: {e}") return JSONResponse({"error": str(e)}, status_code=500) @self.app.get("/api/user/{context_id}") async def get_user(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") context = await self.upsert_context(context_id) user = context.user user_data = { "username": user.username, "first_name": user.first_name, "last_name": user.last_name, "full_name": user.full_name, "description": user.description, "contact_info": user.contact_info, "rag_content_size": user.rag_content_size, "has_profile": user.has_profile, "questions": [ q.model_dump(mode='json') for q in user.user_questions], } return JSONResponse(user_data) @self.app.get("/api/tunables/{context_id}") async def get_tunables(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") context = await self.upsert_context(context_id) agent = context.get_agent("chat") if not agent: logger.info("chat agent does not exist on this context!") return JSONResponse( {"error": f"chat is not recognized", "context": context.id}, status_code=404, ) return JSONResponse( { "system_prompt": agent.system_prompt, "rags": [ r.model_dump(mode="json") for r in context.rags ], "tools": [{ **t.function.model_dump(mode='json'), "enabled": t.enabled, } for t in context.tools], 
} ) @self.app.get("/api/system-info/{context_id}") async def get_system_info(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") return JSONResponse(system_info(self.model)) @self.app.post("/api/{agent_type}/{context_id}") async def post_agent_endpoint( agent_type: str, context_id: str, request: Request ): logger.info(f"{request.method} {request.url.path}") try: context = await self.upsert_context(context_id) except Exception as e: error = { "error": f"Unable to create or access context {context_id}: {e}" } logger.info(error) return JSONResponse(error, status_code=404) try: data = await request.json() query: Query = Query(**data) except Exception as e: error = {"error": f"Attempt to parse request: {e}"} logger.info(error) return JSONResponse(error, status_code=400) try: agent = context.get_or_create_agent(agent_type, **query.agent_options) except Exception as e: error = { "error": f"Attempt to create agent type: {agent_type} failed: {e}" } logger.error(traceback.format_exc()) logger.info(error) return JSONResponse(error, status_code=404) try: async def flush_generator(): logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}") try: start_time = time.perf_counter() async for message in self.generate_response( context=context, agent=agent, prompt=query.prompt, tunables=query.tunables, ): if message.status != "done" and message.status != "partial": if message.status == "streaming": result = { "status": "streaming", "chunk": message.chunk, "remaining_time": LLM_TIMEOUT - (time.perf_counter() - start_time), } else: start_time = time.perf_counter() result = { "status": message.status, "response": message.response, "remaining_time": LLM_TIMEOUT, } else: logger.info(f"Providing {message.status} response.") try: result = message.model_dump( by_alias=True, mode="json" ) except Exception as e: result = {"status": "error", "response": str(e)} yield json.dumps(result) + "\n" break # Convert to JSON and add newline 
message.network_packets += 1 message.network_bytes += len(result) try: disconnected = await request.is_disconnected() except Exception as e: logger.warning(f"Disconnection check failed: {e}") disconnected = True if disconnected: logger.info("Disconnect detected. Continuing generation to store in cache.") disconnected = True if not disconnected: yield json.dumps(result) + "\n" current_time = time.perf_counter() if current_time - start_time > LLM_TIMEOUT: message.status = "error" message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query." message.partial_response = message.response logger.info(message.response + " Ending session") result = message.model_dump(by_alias=True, mode="json") if not disconnected: yield json.dumps(result) + "\n" if message.status == "error": context.processing = False break # Allow the event loop to process the write await asyncio.sleep(0) except Exception as e: context.processing = False logger.error(traceback.format_exc()) logger.error(f"Error in generate_response: {e}") yield json.dumps({"status": "error", "response": str(e)}) + "\n" finally: # Save context on completion or error await self.save_context(context_id) logger.info("Flush generator completed normally.") # Return StreamingResponse with appropriate headers return StreamingResponse( flush_generator(), media_type="application/json", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", # Prevents Nginx buffering if you're using it }, ) except Exception as e: context.processing = False logger.error(f"Error in post_chat_endpoint: {e}") return JSONResponse({"error": str(e)}, status_code=500) @self.app.post("/api/create-session") async def create_session(request: Request): logger.info(f"{request.method} {request.url.path}") context = await self.create_context(username=defines.default_username) return JSONResponse({"id": context.id}) 
@self.app.get("/api/join-session/{context_id}") async def join_session(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") context = await self.load_context(context_id=context_id) if not context: return JSONResponse({"error": f"{context_id} does not exist."}, 404) return JSONResponse({"id": context.id}) @self.app.get("/api/u/{context_id}") async def get_users(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: context = await self.load_context(context_id) if not context: return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404) users = [User.sanitize(u) for u in User.get_users()] return JSONResponse(users) except Exception as e: logger.error(traceback.format_exc()) logger.error(f"get_users error: {str(e)}") return JSONResponse({ "error": "Unable to parse users"}, 500) @self.app.get("/api/u/{username}/profile/{context_id}") async def get_user_profile(username: str, context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: if not User.exists(username): return JSONResponse({"error": f"User {username} not found."}, status_code=404) context = await self.load_context(context_id) if not context: return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404) profile_path = os.path.join(defines.user_dir, username, f"profile.png") if not os.path.exists(profile_path): return JSONResponse({ "error": "User {username} does not have a profile picture"}, status_code=404) return FileResponse(profile_path) except Exception as e: return JSONResponse({ "error": "Unable to load user {username}"}, 500) @self.app.post("/api/u/{username}/{context_id}") async def post_user(username: str, context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: if not User.exists(username): return JSONResponse({"error": f"User {username} not found."}, status_code=404) context = await self.load_context(context_id) if 
not context: return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404) matching_user = next((user for user in self.users if user.username == username), None) if matching_user: user = matching_user else: user = User(username=username, llm=self.llm) await user.initialize(prometheus_collector=self.prometheus_collector) self.users.append(user) reset_map = ( "chat", "job_description", "resume", "fact_check", ) for mode in reset_map: tmp = context.get_agent(mode) if not tmp: continue logger.info(f"User change: Resetting history for {mode}") if mode != "chat": context.remove_agent(tmp) tmp.conversation.reset() context.user = user user_data = { "username": user.username, "first_name": user.first_name, "last_name": user.last_name, "full_name": user.full_name, "description": user.description, "contact_info": user.contact_info, "rag_content_size": user.rag_content_size, "has_profile": user.has_profile, "questions": [ q.model_dump(mode='json') for q in user.user_questions], } await self.save_context(context_id) return JSONResponse(user_data) except Exception as e: return JSONResponse({ "error": "Unable to load user {username}"}, 500) @self.app.post("/api/context/u/{username}") async def create_user_context(username: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: if not User.exists(username): return JSONResponse({"error": f"User {username} not found."}, status_code=404) context = await self.create_context(username=username) logger.info(f"Generated new context {context.id} for {username}") return JSONResponse({"id": context.id}) except Exception as e: logger.error(f"create_user_context error: {str(e)}") logger.error(traceback.format_exc()) return JSONResponse({"error": f"User {username} not found."}, status_code=404) @self.app.post("/api/context") async def create_context(request: Request): logger.info(f"{request.method} {request.url.path}") return self.app.create_user_context(defines.default_username, request) 
@self.app.get("/api/history/{context_id}/{agent_type}") async def get_history(context_id: str, agent_type: str, request: Request): logger.info(f"{request.method} {request.url.path}") try: context = await self.upsert_context(context_id) agent = context.get_agent(agent_type) if not agent: logger.info( f"Agent {agent_type} not found. Returning empty history." ) return JSONResponse({"messages": []}) logger.info( f"History for {agent_type} contains {len(agent.conversation)} entries." ) return agent.conversation except Exception as e: logger.error(traceback.format_exc()) logger.error(f"get_history error: {str(e)}") return JSONResponse({"error": str(e)}, status_code=404) @self.app.get("/api/tools/{context_id}") async def get_tools(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") context = await self.upsert_context(context_id) return JSONResponse(context.tools) @self.app.put("/api/tools/{context_id}") async def put_tools(context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = await self.upsert_context(context_id) try: data = await request.json() modify = data["tool"] enabled = data["enabled"] for tool in context.tools: if modify == tool.function.name: tool.enabled = enabled await self.save_context(context_id) return JSONResponse(context.tools) return JSONResponse( {"status": f"{modify} not found in tools."}, status_code=404 ) except: return JSONResponse({"status": "error"}, 405) @self.app.get("/api/context-status/{context_id}/{agent_type}") async def get_context_status(context_id, agent_type: str, request: Request): logger.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = await 
self.upsert_context(context_id) agent = context.get_agent(agent_type) if not agent: return JSONResponse( {"context_used": 0, "max_context": defines.max_context} ) return JSONResponse( { "context_used": agent.context_tokens, "max_context": defines.max_context, } ) @self.app.get("/api/health") async def health_check(): return JSONResponse({"status": "healthy"}) @self.app.get("/{path:path}") async def serve_static(path: str, request: Request): full_path = os.path.join(defines.static_content, path) # Check if the original path exists if os.path.exists(full_path) and os.path.isfile(full_path): logger.info(f"Serve static request for {full_path}") return FileResponse(full_path) # Check if the path matches /{filename}.png # if path.startswith("/") and path.endswith(".png"): # filename = path[1:-4] # Remove leading '/' and trailing '.png' # alt_path = f"/opt/backstory/users/{filename}/{filename}.png" # # Check if the alternative path exists # if os.path.exists(alt_path) and os.path.isfile(alt_path): # logger.info(f"Serve static request for alternative path {alt_path}") # return FileResponse(alt_path) # If neither path exists, return 404 logger.info(f"File not found for path {full_path} -- returning index.html") return FileResponse(os.path.join(defines.static_content, "index.html")) async def save_context(self, context_id): """ Serialize a Python dictionary to a file in the agents directory. Args: data: Dictionary containing the agent data context_id: UUID string for the context. 
If it doesn't exist, it is created Returns: The context_id used for the file """ context = await self.upsert_context(context_id) # Create agents directory if it doesn't exist if not os.path.exists(defines.context_dir): os.makedirs(defines.context_dir) # Create the full file path file_path = os.path.join(defines.context_dir, context_id) # Serialize the data to JSON and write to file try: # Check for non-serializable fields before dumping serialization_errors = check_serializable(context) if serialization_errors: for error in serialization_errors: logger.error(error) raise ValueError("Found non-serializable fields in the model") # Dump the model prior to opening file in case there is # a validation error so it doesn't delete the current # context session json_data = context.model_dump_json(by_alias=True) with open(file_path, "w") as f: f.write(json_data) except ValidationError as e: logger.error(e) logger.error(traceback.format_exc()) for error in e.errors(): print(f"Field: {error['loc'][0]}, Error: {error['msg']}") except PydanticSerializationError as e: logger.error(e) logger.error(traceback.format_exc()) logger.error(f"Serialization error: {str(e)}") # Inspect the model to identify problematic fields for field_name, value in context.__dict__.items(): if isinstance(value, np.ndarray): logger.error(f"Field '{field_name}' contains non-serializable type: {type(value)}") except Exception as e: logger.error(traceback.format_exc()) logger.error(e) return context_id async def load_context(self, context_id: str) -> Context | None: """ Load a context from a file in the context directory or create a new one if it doesn't exist. Args: context_id: UUID string for the context. Returns: A Context object with the specified ID and default settings. 
""" file_path = os.path.join(defines.context_dir, context_id) # Check if the file exists if not os.path.exists(file_path): return None # Read and deserialize the data with open(file_path, "r") as f: content = f.read() logger.info( f"Loading context from {file_path}, content length: {len(content)}" ) json_data = {} try: # Try parsing as JSON first to ensure valid JSON json_data = json.loads(content) logger.info("JSON parsed successfully, attempting model validation") context = Context.model_validate(json_data) username = context.username if not User.exists(username): raise ValueError(f"Attempt to load context {context.id} with invalid user {username}") matching_user = next((user for user in self.users if user.username == username), None) if matching_user: user = matching_user else: user = User(username=username, llm=self.llm) await user.initialize(prometheus_collector=self.prometheus_collector) self.users.append(user) context.user = user # Now set context on agents manually agent_types = [agent.agent_type for agent in context.agents] if len(agent_types) != len(set(agent_types)): raise ValueError( "Context cannot contain multiple agents of the same agent_type" ) for agent in context.agents: agent.set_context(context) self.contexts[context_id] = context logger.info(f"Successfully loaded context {context_id}") except ValidationError as e: logger.error(e) logger.error(traceback.format_exc()) for error in e.errors(): print(f"Field: {error['loc'][0]}, Error: {error['msg']}") except Exception as e: logger.error(f"Error validating context: {str(e)}") logger.error(traceback.format_exc()) for key in json_data: logger.info(f"{key} = {type(json_data[key])} {str(json_data[key])[:60] if json_data[key] else "None"}") logger.info("*" * 50) return None return self.contexts[context_id] async def load_or_create_context(self, context_id: str) -> Context: """ Load a context from a file in the context directory or create a new one if it doesn't exist. 
Args: context_id: UUID string for the context. Returns: A Context object with the specified ID and default settings. """ context = await self.load_context(context_id) if context: return context logger.info(f"Context not found. Creating new instance of context {context_id}.") self.contexts[context_id] = await self.create_context(username=defines.default_username, context_id=context_id) return self.contexts[context_id] async def create_context(self, username: str, context_id=None) -> Context: """ Create a new context with a unique ID and default settings. Args: context_id: Optional UUID string for the context. If not provided, a new UUID is generated. Returns: A Context object with the specified ID and default settings. """ if not User.exists(username): raise ValueError(f"{username} does not exist.") # If username matching_user = next((user for user in self.users if user.username == username), None) if matching_user: user = matching_user logger.info(f"Found matching user: {user.username}") else: user = User(username=username, llm=self.llm) await user.initialize(prometheus_collector=self.prometheus_collector) logger.info(f"Created new instance of user: {user.username}") self.users.append(user) logger.info(f"Creating context {context_id or "new"} for user: {user.username}") try: if context_id: context = Context( id=context_id, user=user, rags=[ rag.model_copy() for rag in user.rags ], tools=Tools.all_tools() ) else: context = Context( user=user, rags=[ rag.model_copy() for rag in user.rags ], tools=Tools.all_tools() ) except ValidationError as e: logger.error(e) logger.error(traceback.format_exc()) for error in e.errors(): print(f"Field: {error['loc'][0]}, Error: {error['msg']}") exit(1) logger.info(f"New context created with ID: {context.id}") if os.path.exists(defines.resume_doc): context.user_resume = open(defines.resume_doc, "r").read() context.get_or_create_agent(agent_type="chat") # system_prompt=system_message) # context.add_agent(Resume(system_prompt = 
system_generate_resume)) # context.add_agent(JobDescription(system_prompt = system_job_description)) # context.add_agent(FactCheck(system_prompt = system_fact_check)) logger.info(f"{context.id} created and added to contexts.") self.contexts[context.id] = context await self.save_context(context.id) return context async def upsert_context(self, context_id=None) -> Context: """ Upsert a context based on the provided context_id. Args: context_id: UUID string for the context. If it doesn't exist, a new context is created. Returns: A Context object with the specified ID and default settings. """ if not context_id: logger.warning("No context ID provided. Creating a new context.") return await self.create_context(username=defines.default_username) if context_id in self.contexts: return self.contexts[context_id] logger.info(f"Context {context_id} is not yet loaded.") return await self.load_or_create_context(context_id=context_id) @REQUEST_TIME.time() async def generate_response( self, context: Context, agent: Agent, prompt: str, tunables: Tunables | None ) -> AsyncGenerator[Message, None]: agent_type = agent.get_agent_type() logger.info(f"generate_response: type - {agent_type}") # Merge tunables to take agent defaults and override with user supplied settings agent_tunables = agent.tunables.model_dump() if agent.tunables else {} user_tunables = tunables.model_dump() if tunables else {} merged_tunables = {**agent_tunables, **user_tunables} message = Message(prompt=prompt, tunables=Tunables(**merged_tunables)) async for message in agent.prepare_message(message): # logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}") if message.status == "error": yield message return if message.status != "done": yield message async for message in agent.process_message(self.llm, self.model, message): if message.status != "done": yield message if message.status == "error": return logger.info( f"{agent_type}.process_message: {message.status} 
{f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}" ) message.status = "done" yield message return def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs): try: if self.ssl_enabled: logger.info(f"Starting web server at https://{host}:{port}") uvicorn.run( self.app, host=host, port=port, log_config=None, ssl_keyfile=defines.key_path, ssl_certfile=defines.cert_path, ) else: logger.info(f"Starting web server at http://{host}:{port}") uvicorn.run(self.app, host=host, port=port, log_config=None) except KeyboardInterrupt: for user in self.users: if user.observer: user.observer.stop() for user in self.users: if user.observer: user.observer.join() # %% # Main function to run everything def main(): # Parse command-line arguments args = parse_args() # Setup logging based on the provided level logger.setLevel(args.level.upper()) warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.*") warnings.filterwarnings("ignore", category=UserWarning, module="umap.*") llm = ollama.Client(host=args.ollama_server) # type: ignore web_server = WebServer(llm, args.ollama_model) web_server.run(host=args.web_host, port=args.web_port, use_reloader=False) main()