LLM_TIMEOUT = 600

from utils import logger
from pydantic import BaseModel, Field, ValidationError  # type: ignore
from pydantic_core import PydanticSerializationError  # type: ignore
from typing import List
from typing import AsyncGenerator, Dict, Optional

# %%
# Imports [standard]
# Standard library modules (no try-except needed)
import argparse
import asyncio
import json
import logging
import os
import re
import uuid
import subprocess
import math
import warnings
from typing import Any
from datetime import datetime
import inspect
from uuid import uuid4
import time
import traceback


def try_import(module_name, pip_name=None):
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f"  pip install {pip_name or module_name}")


# Third-party modules with import checks
try_import("ollama")
try_import("requests")
try_import("fastapi")
try_import("uvicorn")
try_import("numpy")
try_import("umap")
try_import("sklearn")
try_import("prometheus_client")
try_import("prometheus_fastapi_instrumentator")

import ollama
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request  # type: ignore
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse  # type: ignore
from fastapi.middleware.cors import CORSMiddleware  # type: ignore
import uvicorn  # type: ignore
import numpy as np  # type: ignore

# Prometheus
from prometheus_client import Summary  # type: ignore
from prometheus_fastapi_instrumentator import Instrumentator  # type: ignore
from prometheus_client import CollectorRegistry, Counter  # type: ignore

from utils import (
    rag as Rag,
    ChromaDBGetResponse,
    tools as Tools,
    Context,
    Conversation,
    Message,
    Agent,
    Metrics,
    Tunables,
    defines,
    check_serializable,
    logger,
)

rags: List[ChromaDBGetResponse] = [
    ChromaDBGetResponse(
        name="JPK",
        enabled=True,
        description="Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
    ),
    # {
    #     "name": "LKML",
    #     "enabled": False,
    #     "description": "Full associative data for entire LKML mailing list archive."
    # },
]


class Query(BaseModel):
    prompt: str
    tunables: Tunables = Field(default_factory=Tunables)
    agent_options: Dict[str, Any] = Field(default={})


REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")


def get_installed_ram():
    try:
        with open("/proc/meminfo", "r") as f:
            meminfo = f.read()
        match = re.search(r"MemTotal:\s+(\d+)", meminfo)
        if match:
            return f"{math.floor(int(match.group(1)) / 1000**2)}GB"  # Convert KB to GB
    except Exception as e:
        return f"Error retrieving RAM: {e}"


def get_graphics_cards():
    gpus = []
    try:
        # Run the ze-monitor utility
        result = subprocess.run(
            ["ze-monitor"], capture_output=True, text=True, check=True
        )
        # Clean up the output (remove leading/trailing whitespace and newlines)
        output = result.stdout.strip()
        for index in range(len(output.splitlines())):
            result = subprocess.run(
                ["ze-monitor", "--device", f"{index+1}", "--info"],
                capture_output=True,
                text=True,
                check=True,
            )
            gpu_info = result.stdout.strip().splitlines()
            gpu = {
                "discrete": True,  # Assume it's discrete initially
                "name": None,
                "memory": None,
            }
            gpus.append(gpu)
            for line in gpu_info:
                match = re.match(r"^Device: [^(]*\((.*)\)", line)
                if match:
                    gpu["name"] = match.group(1)
                    continue
                match = re.match(r"^\s*Memory: (.*)", line)
                if match:
                    gpu["memory"] = match.group(1)
                    continue
                match = re.match(r"^.*Is integrated with host: Yes.*", line)
                if match:
                    gpu["discrete"] = False
                    continue
        return gpus
    except Exception as e:
        return f"Error retrieving GPU info: {e}"


def get_cpu_info():
    try:
        with open("/proc/cpuinfo", "r") as f:
            cpuinfo = f.read()
        model_match = re.search(r"model name\s+:\s+(.+)", cpuinfo)
        cores_match = re.findall(r"processor\s+:\s+\d+", cpuinfo)
        if model_match and cores_match:
            return f"{model_match.group(1)} with {len(cores_match)} cores"
    except Exception as e:
        return f"Error retrieving CPU info: {e}"


def system_info(model):
    return {
        "System RAM": get_installed_ram(),
        "Graphics Card": get_graphics_cards(),
        "CPU": get_cpu_info(),
        "LLM Model": model,
        "Embedding Model": defines.embedding_model,
        "Context length": defines.max_context,
    }


# %%
# Defaults
OLLAMA_API_URL = defines.ollama_api_url
MODEL_NAME = defines.model
LOG_LEVEL = "info"
WEB_HOST = "0.0.0.0"
WEB_PORT = 8911
DEFAULT_HISTORY_LENGTH = 5

# %%
# Globals
model = None
web_server = None


# %%
# Cmd line overrides
def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=OLLAMA_API_URL,
        help=f"Ollama API endpoint. default={OLLAMA_API_URL}",
    )
    parser.add_argument(
        "--ollama-model",
        type=str,
        default=MODEL_NAME,
        help=f"LLM model to use. default={MODEL_NAME}",
    )
    parser.add_argument(
        "--web-host",
        type=str,
        default=WEB_HOST,
        help=f"Host to launch the web server. default={WEB_HOST} only if --web-disable not specified.",
    )
    parser.add_argument(
        "--web-port",
        type=int,
        default=WEB_PORT,
        help=f"Port to launch the web server. default={WEB_PORT} only if --web-disable not specified.",
    )
    parser.add_argument(
        "--level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default=LOG_LEVEL,
        help=f"Set the logging level. default={LOG_LEVEL}",
    )
    return parser.parse_args()
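

# Example invocation (illustrative; the script name and model are placeholders):
#   python server.py --ollama-server http://localhost:11434 --ollama-model llama3 --web-port 8911 --level DEBUG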


# %%
def is_valid_uuid(value: str) -> bool:
    try:
        uuid_obj = uuid.UUID(value, version=4)
        return str(uuid_obj) == value
    except (ValueError, TypeError):
        return False


# %%
class WebServer:
    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        # Start the file watcher
        self.observer, self.file_watcher = Rag.start_file_watcher(
            llm=self.llm,
            watch_directory=defines.doc_dir,
            recreate=False,  # Don't recreate if exists
        )
        logger.info(
            f"API started with {self.file_watcher.collection.count()} documents in the collection"
        )
        yield
        if self.observer:
            self.observer.stop()
            self.observer.join()
            logger.info("File watcher stopped")

    def __init__(self, llm, model=MODEL_NAME):
        self.app = FastAPI(lifespan=self.lifespan)
        self.prometheus_collector = CollectorRegistry()
        self.metrics = Metrics(prometheus_collector=self.prometheus_collector)

        # Keep the Instrumentator instance alive
        self.instrumentator = Instrumentator(registry=self.prometheus_collector)
        # Instrument the FastAPI app
        self.instrumentator.instrument(self.app)
        # Expose the /metrics endpoint
        self.instrumentator.expose(self.app, endpoint="/metrics")

        self.contexts = {}
        self.llm = llm
        self.model = model
        self.processing = False
        self.file_watcher = None
        self.observer = None
        self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(
            defines.cert_path
        )
        if self.ssl_enabled:
            allow_origins = ["https://battle-linux.ketrenos.com:3000"]
        else:
            allow_origins = ["http://battle-linux.ketrenos.com:3000"]
        logger.info(f"Allowed origins: {allow_origins}")
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=allow_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        self.setup_routes()

    def setup_routes(self):
        @self.app.get("/")
        async def root():
            context = self.create_context()
            logger.info(f"Redirecting non-context to {context.id}")
            return RedirectResponse(url=f"/{context.id}", status_code=307)
            # return JSONResponse({"redirect": f"/{context.id}"})

        @self.app.get("/api/umap/entry/{doc_id}/{context_id}")
        async def get_umap(doc_id: str, context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            try:
                if not self.file_watcher:
                    raise Exception("File watcher not initialized")

                context = self.upsert_context(context_id)
                if not context:
                    return JSONResponse(
                        {"error": f"Invalid context: {context_id}"}, status_code=400
                    )

                collection = self.file_watcher.umap_collection
                if not collection:
                    return JSONResponse(
                        {"error": "No UMAP collection found"}, status_code=404
                    )

                if not collection.get("metadatas", None):
                    return JSONResponse(f"Document id {doc_id} not found.", 404)

                for index, id in enumerate(collection.get("ids", [])):
                    if id == doc_id:
                        metadata = collection.get("metadatas", [])[index].copy()
                        content = self.file_watcher.prepare_metadata(metadata)
                        return JSONResponse(content)

                return JSONResponse(f"Document id {doc_id} not found.", 404)
            except Exception as e:
                logger.error(f"get_umap error: {str(e)}")
                logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, 500)
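
        # Example request body for PUT /api/umap/{context_id} below (illustrative):
        #   { "dimensions": 3 }
        # "dimensions" defaults to 2 when omitted.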

        @self.app.put("/api/umap/{context_id}")
        async def put_umap(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            try:
                if not self.file_watcher:
                    raise Exception("File watcher not initialized")

                context = self.upsert_context(context_id)
                if not context:
                    return JSONResponse(
                        {"error": f"Invalid context: {context_id}"}, status_code=400
                    )

                data = await request.json()
                dimensions = data.get("dimensions", 2)

                collection = self.file_watcher.umap_collection
                if not collection:
                    return JSONResponse(
                        {"error": "No UMAP collection found"}, status_code=404
                    )

                if dimensions == 2:
                    logger.info("Returning 2D UMAP")
                    umap_embedding = self.file_watcher.umap_embedding_2d
                else:
                    logger.info("Returning 3D UMAP")
                    umap_embedding = self.file_watcher.umap_embedding_3d

                if len(umap_embedding) == 0:
                    return JSONResponse(
                        {"error": "No UMAP embedding found"}, status_code=404
                    )

                result = {
                    "ids": collection.get("ids", []),
                    "metadatas": collection.get("metadatas", []),
                    "documents": collection.get("documents", []),
                    "embeddings": umap_embedding.tolist(),
                    "size": self.file_watcher.collection.count(),
                }
                return JSONResponse(result)
            except Exception as e:
                logger.error(f"put_umap error: {str(e)}")
                logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, 500)

        @self.app.put("/api/similarity/{context_id}")
        async def put_similarity(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            if not self.file_watcher:
                raise Exception("File watcher not initialized")
            if not is_valid_uuid(context_id):
                logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            try:
                data = await request.json()
                query = data.get("query", "")
                threshold = data.get("threshold", defines.default_rag_threshold)
                results = data.get("results", defines.default_rag_top_k)
            except Exception:
                query = ""
                threshold = defines.default_rag_threshold
                results = defines.default_rag_top_k

            if not query:
                return JSONResponse(
                    {"error": "No query provided for similarity search"},
                    status_code=400,
                )

            try:
                chroma_results = self.file_watcher.find_similar(
                    query=query, top_k=results, threshold=threshold
                )
                if not chroma_results:
                    return JSONResponse({"error": "No results found"}, status_code=404)

                chroma_embedding = np.array(
                    chroma_results["query_embedding"]
                ).flatten()  # Ensure correct shape
                logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")

                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[
                    0
                ].tolist()
                logger.info(
                    f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
                )  # Debug output

                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[
                    0
                ].tolist()
                logger.info(
                    f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
                )  # Debug output

                return JSONResponse({
                    "distances": chroma_results["distances"],
                    "ids": chroma_results["ids"],
                    "metadatas": chroma_results["metadatas"],
                    "query": query,
                    "umap_embedding_2d": umap_2d,
                    "umap_embedding_3d": umap_3d,
                    "size": self.file_watcher.collection.count(),
                })
            except Exception as e:
                logger.error(e)
                logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, 500)
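
        # Example request body for PUT /api/reset/{context_id}/{agent_type} below (illustrative):
        #   { "reset": ["history", "rags", "tools"] }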

        @self.app.put("/api/reset/{context_id}/{agent_type}")
        async def put_reset(context_id: str, agent_type: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
                logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)

            context = self.upsert_context(context_id)
            agent = context.get_agent(agent_type)
            if not agent:
                response = {"history": []}
                return JSONResponse(response)

            data = await request.json()
            try:
                response = {}
                for reset_operation in data["reset"]:
                    match reset_operation:
                        case "system_prompt":
                            logger.info(f"Resetting {reset_operation}")
                        case "rags":
                            logger.info(f"Resetting {reset_operation}")
                            context.rags = [r.model_copy() for r in rags]
                            response["rags"] = [
                                r.model_dump(mode="json") for r in context.rags
                            ]
                        case "tools":
                            logger.info(f"Resetting {reset_operation}")
                            context.tools = Tools.enabled_tools(Tools.tools)
                            response["tools"] = context.tools
                        case "history":
                            reset_map = {
                                "job_description": (
                                    "job_description",
                                    "resume",
                                    "fact_check",
                                ),
                                "resume": ("job_description", "resume", "fact_check"),
                                "fact_check": (
                                    "job_description",
                                    "resume",
                                    "fact_check",
                                ),
                                "chat": ("chat",),
                            }
                            resets = reset_map.get(agent_type, ())
                            for mode in resets:
                                tmp = context.get_agent(mode)
                                if not tmp:
                                    logger.info(
                                        f"Agent {mode} not found for {context_id}"
                                    )
                                    continue
                                logger.info(f"Resetting {reset_operation} for {mode}")
                                tmp.conversation.reset()
                            response["history"] = []
                            response["context_used"] = agent.context_tokens
                        case "message_history_length":
                            logger.info(f"Resetting {reset_operation}")
                            context.message_history_length = DEFAULT_HISTORY_LENGTH
                            response["message_history_length"] = DEFAULT_HISTORY_LENGTH

                if not response:
                    return JSONResponse(
                        {"error": "Usage: { reset: rags|tools|history|system_prompt}"}
                    )
                else:
                    self.save_context(context_id)
                    return JSONResponse(response)
            except Exception as e:
                logger.error(f"Error in reset: {e}")
                logger.error(traceback.format_exc())
                return JSONResponse(
                    {"error": "Usage: { reset: rags|tools|history|system_prompt}"}
                )

        @self.app.put("/api/tunables/{context_id}")
        async def put_tunables(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            try:
                context = self.upsert_context(context_id)
                data = await request.json()
                agent = context.get_agent("chat")
                if not agent:
                    logger.info("chat agent does not exist on this context!")
                    return JSONResponse(
                        {"error": "chat is not recognized", "context": context.id},
                        status_code=404,
                    )
                for k in data.keys():
                    match k:
                        case "tools":
                            # { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
                            tools: list[dict[str, Any]] = data[k]
                            if not tools:
                                return JSONResponse(
                                    {
                                        "status": "error",
                                        "message": "Tools can not be empty.",
                                    }
                                )
                            for tool in tools:
                                for context_tool in context.tools:
                                    if context_tool["function"]["name"] == tool["name"]:
                                        context_tool["enabled"] = tool["enabled"]
                            self.save_context(context_id)
                            return JSONResponse(
                                {
                                    "tools": [
                                        {
                                            **t["function"],
                                            "enabled": t["enabled"],
                                        }
                                        for t in context.tools
                                    ]
                                }
                            )
                        case "rags":
                            # { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
                            rag_configs: list[dict[str, Any]] = data[k]
                            if not rag_configs:
                                return JSONResponse(
                                    {
                                        "status": "error",
                                        "message": "RAGs can not be empty.",
                                    }
                                )
                            for config in rag_configs:
                                for context_rag in context.rags:
                                    if context_rag.name == config["name"]:
                                        context_rag.enabled = config["enabled"]
                            self.save_context(context_id)
                            return JSONResponse(
                                {"rags": [r.model_dump(mode="json") for r in context.rags]}
                            )
                        case "system_prompt":
                            system_prompt = data[k].strip()
                            if not system_prompt:
                                return JSONResponse(
                                    {
                                        "status": "error",
                                        "message": "System prompt can not be empty.",
                                    }
                                )
                            agent.system_prompt = system_prompt
                            self.save_context(context_id)
                            return JSONResponse({"system_prompt": system_prompt})
                        case "message_history_length":
                            value = max(0, int(data[k]))
                            context.message_history_length = value
                            self.save_context(context_id)
                            return JSONResponse({"message_history_length": value})
                        case _:
                            return JSONResponse(
                                {"error": f"Unrecognized tunable {k}"}, status_code=404
                            )
            except Exception as e:
                logger.error(f"Error in put_tunables: {e}")
                return JSONResponse({"error": str(e)}, status_code=500)
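
        # Example request bodies for PUT /api/tunables/{context_id} above (illustrative):
        #   { "system_prompt": "You are a helpful assistant." }
        #   { "message_history_length": 10 }
        #   { "rags": [{ "name": "JPK", "enabled": false }] }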

        @self.app.get("/api/tunables/{context_id}")
        async def get_tunables(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            context = self.upsert_context(context_id)
            agent = context.get_agent("chat")
            if not agent:
                logger.info("chat agent does not exist on this context!")
                return JSONResponse(
                    {"error": "chat is not recognized", "context": context.id},
                    status_code=404,
                )
            return JSONResponse(
                {
                    "system_prompt": agent.system_prompt,
                    "message_history_length": context.message_history_length,
                    "rags": [r.model_dump(mode="json") for r in context.rags],
                    "tools": [
                        {
                            **t["function"],
                            "enabled": t["enabled"],
                        }
                        for t in context.tools
                    ],
                }
            )

        @self.app.get("/api/system-info/{context_id}")
        async def get_system_info(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            return JSONResponse(system_info(self.model))

        @self.app.post("/api/{agent_type}/{context_id}")
        async def post_agent_endpoint(
            agent_type: str, context_id: str, request: Request
        ):
            logger.info(f"{request.method} {request.url.path}")
            try:
                context = self.upsert_context(context_id)
            except Exception as e:
                error = {
                    "error": f"Unable to create or access context {context_id}: {e}"
                }
                logger.info(error)
                return JSONResponse(error, status_code=404)

            try:
                data = await request.json()
                query: Query = Query(**data)
            except Exception as e:
                error = {"error": f"Attempt to parse request: {e}"}
                logger.info(error)
                return JSONResponse(error, status_code=400)

            try:
                agent = context.get_or_create_agent(agent_type, **query.agent_options)
            except Exception as e:
                error = {
                    "error": f"Attempt to create agent type: {agent_type} failed: {e}"
                }
                logger.info(error)
                return JSONResponse(error, status_code=404)

            try:

                async def flush_generator():
                    logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}")
                    try:
                        start_time = time.perf_counter()
                        async for message in self.generate_response(
                            context=context,
                            agent=agent,
                            prompt=query.prompt,
                            tunables=query.tunables,
                        ):
                            if message.status != "done" and message.status != "partial":
                                if message.status == "streaming":
                                    result = {
                                        "status": "streaming",
                                        "chunk": message.chunk,
                                        "remaining_time": LLM_TIMEOUT
                                        - (time.perf_counter() - start_time),
                                    }
                                else:
                                    start_time = time.perf_counter()
                                    result = {
                                        "status": message.status,
                                        "response": message.response,
                                        "remaining_time": LLM_TIMEOUT,
                                    }
                            else:
                                logger.info(f"Providing {message.status} response.")
                                try:
                                    result = message.model_dump(
                                        by_alias=True, mode="json"
                                    )
                                except Exception as e:
                                    result = {"status": "error", "response": str(e)}
                                yield json.dumps(result) + "\n"
                                return

                            # Convert to JSON and add newline
                            result = json.dumps(result) + "\n"
                            message.network_packets += 1
                            message.network_bytes += len(result)
                            yield result

                            if await request.is_disconnected():
                                logger.info("Disconnect detected. Aborting generation.")
                                context.processing = False
                                # Save context on completion or error
                                message.prompt = query.prompt
                                message.status = "error"
                                message.response = (
                                    "Client disconnected during generation."
                                )
                                agent.conversation.add(message)
                                self.save_context(context_id)
                                return
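
                            # If a single inference has been running for more than
                            # LLM_TIMEOUT seconds, abort the stream; start_time is
                            # re-armed above each time a non-streaming interim
                            # status arrives.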
                            current_time = time.perf_counter()
                            if current_time - start_time > LLM_TIMEOUT:
                                message.status = "error"
                                message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
                                message.partial_response = message.response
                                logger.info(message.response + " Ending session")
                                result = message.model_dump(by_alias=True, mode="json")
                                result = json.dumps(result) + "\n"
                                yield result

                            if message.status == "error":
                                context.processing = False
                                return

                            # Allow the event loop to process the write
                            await asyncio.sleep(0)

                    except Exception as e:
                        context.processing = False
                        logger.error(f"Error in generate_response: {e}")
                        logger.error(traceback.format_exc())
                        yield json.dumps({"status": "error", "response": str(e)}) + "\n"
                    finally:
                        # Save context on completion or error
                        self.save_context(context_id)

                # Return StreamingResponse with appropriate headers
                return StreamingResponse(
                    flush_generator(),
                    media_type="application/json",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "X-Accel-Buffering": "no",  # Prevents Nginx buffering if you're using it
                    },
                )
            except Exception as e:
                context.processing = False
                logger.error(f"Error in post_agent_endpoint: {e}")
                return JSONResponse({"error": str(e)}, status_code=500)
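
        # Example lines from the newline-delimited JSON stream produced by
        # post_agent_endpoint above (illustrative; the "done" fields come from Message):
        #   {"status": "streaming", "chunk": "Hel", "remaining_time": 598.2}
        #   {"status": "done", "response": "Hello!", ...}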

        @self.app.post("/api/context")
        async def create_context():
            try:
                context = self.create_context()
                logger.info(f"Generated new agent as {context.id}")
                return JSONResponse({"id": context.id})
            except Exception as e:
                logger.error(f"create_context error: {str(e)}")
                logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, status_code=404)

        @self.app.get("/api/history/{context_id}/{agent_type}")
        async def get_history(context_id: str, agent_type: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            try:
                context = self.upsert_context(context_id)
                agent = context.get_agent(agent_type)
                if not agent:
                    logger.info(
                        f"Agent {agent_type} not found. Returning empty history."
                    )
                    return JSONResponse({"messages": []})
                logger.info(
                    f"History for {agent_type} contains {len(agent.conversation)} entries."
                )
                return agent.conversation
            except Exception as e:
                logger.error(f"get_history error: {str(e)}")
                logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, status_code=404)

        @self.app.get("/api/tools/{context_id}")
        async def get_tools(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            context = self.upsert_context(context_id)
            return JSONResponse(context.tools)

        @self.app.put("/api/tools/{context_id}")
        async def put_tools(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
                logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            try:
                data = await request.json()
                modify = data["tool"]
                enabled = data["enabled"]
                for tool in context.tools:
                    if modify == tool["function"]["name"]:
                        tool["enabled"] = enabled
                        self.save_context(context_id)
                        return JSONResponse(context.tools)
                return JSONResponse(
                    {"status": f"{modify} not found in tools."}, status_code=404
                )
            except Exception:
                return JSONResponse({"status": "error"}, 405)

        @self.app.get("/api/context-status/{context_id}/{agent_type}")
        async def get_context_status(context_id: str, agent_type: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
                logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            agent = context.get_agent(agent_type)
            if not agent:
                return JSONResponse(
                    {"context_used": 0, "max_context": defines.max_context}
                )
            return JSONResponse(
                {
                    "context_used": agent.context_tokens,
                    "max_context": defines.max_context,
                }
            )

        @self.app.get("/api/health")
        async def health_check():
            return JSONResponse({"status": "healthy"})

        @self.app.get("/{path:path}")
        async def serve_static(path: str, request: Request):
            full_path = os.path.join(defines.static_content, path)
            if os.path.exists(full_path) and os.path.isfile(full_path):
                logger.info(f"Serve static request for {full_path}")
                return FileResponse(full_path)
            logger.info(f"Serve index.html for {path}")
            return FileResponse(os.path.join(defines.static_content, "index.html"))
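
    # Context persistence: contexts are serialized to JSON files under
    # defines.context_dir (one file per context UUID) and reloaded lazily
    # on first access.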
    def save_context(self, context_id):
        """
        Serialize a context to a JSON file in the context directory.

        Args:
            context_id: UUID string for the context. If it doesn't exist, it is created.

        Returns:
            The context_id used for the file.
        """
        context = self.upsert_context(context_id)

        # Create the context directory if it doesn't exist
        if not os.path.exists(defines.context_dir):
            os.makedirs(defines.context_dir)

        # Create the full file path
        file_path = os.path.join(defines.context_dir, context_id)

        # Serialize the data to JSON and write to file
        try:
            # Check for non-serializable fields before dumping
            serialization_errors = check_serializable(context)
            if serialization_errors:
                for error in serialization_errors:
                    logger.error(error)
                raise ValueError("Found non-serializable fields in the model")
            # Dump the model prior to opening file in case there is
            # a validation error so it doesn't delete the current
            # context session
            json_data = context.model_dump_json(by_alias=True)
            with open(file_path, "w") as f:
                f.write(json_data)
        except ValidationError as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            for error in e.errors():
                print(f"Field: {error['loc'][0]}, Error: {error['msg']}")
        except PydanticSerializationError as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            logger.error(f"Serialization error: {str(e)}")
            # Inspect the model to identify problematic fields
            for field_name, value in context.__dict__.items():
                if isinstance(value, np.ndarray):
                    logger.error(
                        f"Field '{field_name}' contains non-serializable type: {type(value)}"
                    )
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error(e)

        return context_id

    def load_or_create_context(self, context_id) -> Context:
        """
        Load a context from a file in the context directory or create a new
        one if it doesn't exist.

        Args:
            context_id: UUID string for the context.

        Returns:
            A Context object with the specified ID and default settings.
        """
        if not self.file_watcher:
            raise Exception("File watcher not initialized")

        file_path = os.path.join(defines.context_dir, context_id)

        # Check if the file exists
        if not os.path.exists(file_path):
            logger.info(f"Context file {file_path} not found. Creating new context.")
            self.contexts[context_id] = self.create_context(context_id)
        else:
            # Read and deserialize the data
            with open(file_path, "r") as f:
                content = f.read()
                logger.info(
                    f"Loading context from {file_path}, content length: {len(content)}"
                )
                try:
                    # Try parsing as JSON first to ensure valid JSON
                    json_data = json.loads(content)
                    logger.info("JSON parsed successfully, attempting model validation")
                    # Validate from JSON (no prometheus_collector or file_watcher)
                    context = Context.model_validate(json_data)
                    # Set excluded fields
                    context.file_watcher = self.file_watcher
                    context.prometheus_collector = self.prometheus_collector
                    # Now set context on agents manually
                    agent_types = [agent.agent_type for agent in context.agents]
                    if len(agent_types) != len(set(agent_types)):
                        raise ValueError(
                            "Context cannot contain multiple agents of the same agent_type"
                        )
                    for agent in context.agents:
                        agent.set_context(context)
                    self.contexts[context_id] = context
                    logger.info(f"Successfully loaded context {context_id}")
                except Exception as e:
                    logger.error(f"Error validating context: {str(e)}")
                    logger.error(traceback.format_exc())
                    # Fallback to creating a new context
                    self.contexts[context_id] = Context(
                        id=context_id,
                        file_watcher=self.file_watcher,
                        prometheus_collector=self.prometheus_collector,
                    )

        return self.contexts[context_id]
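
    # New contexts start with a "chat" agent, the default tool set, and a copy
    # of the module-level RAG configuration; the user's resume is attached when
    # defines.resume_doc exists on disk.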
    def create_context(self, context_id=None) -> Context:
        """
        Create a new context with a unique ID and default settings.

        Args:
            context_id: Optional UUID string for the context. If not provided,
                a new UUID is generated.

        Returns:
            A Context object with the specified ID and default settings.
        """
        if not self.file_watcher:
            raise Exception("File watcher not initialized")

        if not context_id:
            context_id = str(uuid4())
        logger.info(f"Creating new context with ID: {context_id}")
        context = Context(
            id=context_id,
            file_watcher=self.file_watcher,
            prometheus_collector=self.prometheus_collector,
        )
        if os.path.exists(defines.resume_doc):
            context.user_resume = open(defines.resume_doc, "r").read()
        context.get_or_create_agent(agent_type="chat")  # system_prompt=system_message)
        # context.add_agent(Resume(system_prompt = system_generate_resume))
        # context.add_agent(JobDescription(system_prompt = system_job_description))
        # context.add_agent(FactCheck(system_prompt = system_fact_check))
        context.tools = Tools.enabled_tools(Tools.tools)
        context.rags = [r.model_copy() for r in rags]

        logger.info(f"{context.id} created and added to contexts.")
        self.contexts[context.id] = context
        self.save_context(context.id)
        return context

    def upsert_context(self, context_id=None) -> Context:
        """
        Upsert a context based on the provided context_id.

        Args:
            context_id: UUID string for the context. If it doesn't exist,
                a new context is created.

        Returns:
            A Context object with the specified ID and default settings.
        """
        if not context_id:
            logger.warning("No context ID provided. Creating a new context.")
            return self.create_context()

        if context_id in self.contexts:
            return self.contexts[context_id]

        logger.info(f"Context {context_id} is not yet loaded.")
        return self.load_or_create_context(context_id)

    @REQUEST_TIME.time()
    async def generate_response(
        self, context: Context, agent: Agent, prompt: str, tunables: Tunables | None
    ) -> AsyncGenerator[Message, None]:
        if not self.file_watcher:
            raise Exception("File watcher not initialized")

        agent_type = agent.get_agent_type()
        logger.info(f"generate_response: type - {agent_type}")

        # Merge tunables to take agent defaults and override with user supplied settings
        agent_tunables = agent.tunables.model_dump() if agent.tunables else {}
        user_tunables = tunables.model_dump() if tunables else {}
        merged_tunables = {**agent_tunables, **user_tunables}

        message = Message(prompt=prompt, tunables=Tunables(**merged_tunables))
        async for message in agent.prepare_message(message):
            # logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
            if message.status == "error":
                yield message
                return
            if message.status != "done":
                yield message

        async for message in agent.process_message(self.llm, self.model, message):
            if message.status != "done":
                yield message
            if message.status == "error":
                return

        logger.info(
            f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}"
        )
        message.status = "done"
        yield message
        return

    def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
        try:
            if self.ssl_enabled:
                logger.info(f"Starting web server at https://{host}:{port}")
                uvicorn.run(
                    self.app,
                    host=host,
                    port=port,
                    log_config=None,
                    ssl_keyfile=defines.key_path,
                    ssl_certfile=defines.cert_path,
                )
            else:
                logger.info(f"Starting web server at http://{host}:{port}")
                uvicorn.run(self.app, host=host, port=port, log_config=None)
        except KeyboardInterrupt:
            if self.observer:
                self.observer.stop()
            if self.observer:
                self.observer.join()
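

# Illustrative programmatic use (assumes a local Ollama instance; values are placeholders):
#   llm = ollama.Client(host="http://localhost:11434")
#   WebServer(llm, model="llama3").run(host="0.0.0.0", port=8911)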


# %%
# Main function to run everything
def main():
    global model

    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    logger.setLevel(args.level.upper())

    warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.*")
    warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")

    llm = ollama.Client(host=args.ollama_server)  # type: ignore
    model = args.ollama_model

    web_server = WebServer(llm, model)
    web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)


if __name__ == "__main__":
    main()
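

# Example session against a running server (illustrative; hostname, port, and
# UUID are placeholders):
#   curl -X POST http://localhost:8911/api/context
#     -> {"id": "<context-uuid>"}
#   curl -X POST http://localhost:8911/api/chat/<context-uuid> \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Tell me about James."}'
#     -> newline-delimited JSON status/response objects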