caching working version

This commit is contained in:
James Ketr 2025-03-18 12:59:28 -07:00
parent dd18ca858b
commit 5f6971510a

604
jupyter/stock.py Normal file
View File

@ -0,0 +1,604 @@
# %%
import os
import json
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import ollama
import re
import yfinance as yf
from datetime import datetime
import pytz
import requests
import json
from geopy.geocoders import Nominatim
import time
import pydle
import asyncio
import anyio
import logging
import queue
# %%
system_message = """
You are a helpful stock trading agent.
You have real time access to current stock trading values and other information including the current date and time and current weather information.
If you succeed in getting stock values or time information, do not mention your knowledge cutoff.
Give short, courteous answers, no more than 1 sentence.
You have access to several tools for obtaining real-time information.
Always be accurate.
If you don't know the answer, say so. Do not make up details.
"""
system_log = [{"role": "system", "content": system_message}]
history = []
tool_log = []
# %%
client = ollama.Client(host="http://ollama:11434")
#model = 'deepseek-r1:7b' # Does not support tools with template used by ollama
model = 'llama3.2'
# %%
def get_weather_by_location(city, state, country="USA"):
"""
Get weather information from weather.gov based on city, state, and country.
Args:
city (str): City name
state (str): State name or abbreviation
country (str): Country name (defaults to "USA" as weather.gov is for US locations)
Returns:
dict: Weather forecast information
"""
# Step 1: Get coordinates for the location using geocoding
location = f"{city}, {state}, {country}"
coordinates = get_coordinates(location)
if not coordinates:
return {"error": f"Could not find coordinates for {location}"}
# Step 2: Get the forecast grid endpoint for the coordinates
grid_endpoint = get_grid_endpoint(coordinates)
if not grid_endpoint:
return {"error": f"Could not find weather grid for coordinates {coordinates}"}
# Step 3: Get the forecast data from the grid endpoint
forecast = get_forecast(grid_endpoint)
return forecast
def get_coordinates(location):
"""Convert a location string to latitude and longitude using Nominatim geocoder."""
try:
# Create a geocoder with a meaningful user agent
geolocator = Nominatim(user_agent="weather_app_example")
# Get the location
location_data = geolocator.geocode(location)
if location_data:
return {
"latitude": location_data.latitude,
"longitude": location_data.longitude
}
else:
print(f"Location not found: {location}")
return None
except Exception as e:
print(f"Error getting coordinates: {e}")
return None
def get_grid_endpoint(coordinates):
"""Get the grid endpoint from weather.gov based on coordinates."""
try:
lat = coordinates["latitude"]
lon = coordinates["longitude"]
# Define headers for the API request
headers = {
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
"Accept": "application/geo+json"
}
# Make the request to get the grid endpoint
url = f"https://api.weather.gov/points/{lat},{lon}"
response = requests.get(url, headers=headers)
if response.status_code == 200:
data = response.json()
return data["properties"]["forecast"]
else:
print(f"Error getting grid: {response.status_code} - {response.text}")
return None
except Exception as e:
print(f"Error in get_grid_endpoint: {e}")
return None
def get_forecast(grid_endpoint):
"""Get the forecast data from the grid endpoint."""
try:
# Define headers for the API request
headers = {
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
"Accept": "application/geo+json"
}
# Make the request to get the forecast
response = requests.get(grid_endpoint, headers=headers)
if response.status_code == 200:
data = response.json()
# Extract the relevant forecast information
periods = data["properties"]["periods"]
# Process the forecast data into a simpler format
forecast = {
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
"updated": data["properties"].get("updated", ""),
"periods": []
}
for period in periods:
forecast["periods"].append({
"name": period.get("name", ""),
"temperature": period.get("temperature", ""),
"temperatureUnit": period.get("temperatureUnit", ""),
"windSpeed": period.get("windSpeed", ""),
"windDirection": period.get("windDirection", ""),
"shortForecast": period.get("shortForecast", ""),
"detailedForecast": period.get("detailedForecast", "")
})
return forecast
else:
print(f"Error getting forecast: {response.status_code} - {response.text}")
return {"error": f"API Error: {response.status_code}"}
except Exception as e:
print(f"Error in get_forecast: {e}")
return {"error": f"Exception: {str(e)}"}
# Example usage
def do_weather():
city = input("Enter city: ")
state = input("Enter state: ")
country = input("Enter country (default USA): ") or "USA"
print(f"Getting weather for {city}, {state}, {country}...")
weather_data = get_weather_by_location(city, state, country)
if "error" in weather_data:
print(f"Error: {weather_data['error']}")
else:
print("\nWeather Forecast:")
print(f"Location: {weather_data.get('location', {}).get('city', city)}, {weather_data.get('location', {}).get('state', state)}")
print(f"Last Updated: {weather_data.get('updated', 'N/A')}")
print("\nForecast Periods:")
for period in weather_data.get("periods", []):
print(f"\n{period['name']}:")
print(f" Temperature: {period['temperature']}{period['temperatureUnit']}")
print(f" Wind: {period['windSpeed']} {period['windDirection']}")
print(f" Forecast: {period['shortForecast']}")
print(f" Details: {period['detailedForecast']}")
# %%
def get_ticker_price(ticker_symbols):
"""
Look up the current price of a stock using its ticker symbol.
Args:
ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple)
Returns:
dict: Current stock information including price
"""
results = []
print(f"get_ticker_price('{ticker_symbols}')")
for ticker_symbol in ticker_symbols.split(','):
ticker_symbol = ticker_symbol.strip()
if ticker_symbol == "":
continue
# Create a Ticker object
try:
ticker = yf.Ticker(ticker_symbol)
# Get the latest market data
ticker_data = ticker.history(period="1d")
if ticker_data.empty:
results.append({"error": f"No data found for ticker {ticker_symbol}"})
continue
# Get the latest closing price
latest_price = ticker_data['Close'].iloc[-1]
# Get some additional info
info = ticker.info
results.append({ 'symbol': ticker_symbol, 'price': latest_price })
except Exception as e:
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
return results[0] if len(results) == 1 else results
#{
# "symbol": ticker_symbol,
# "price": latest_price,
# "currency": info.get("currency", "Unknown"),
# "company_name": info.get("shortName", "Unknown"),
# "previous_close": info.get("previousClose", "Unknown"),
# "market_cap": info.get("marketCap", "Unknown"),
#}
# %%
def get_current_datetime(timezone="America/Los_Angeles"):
"""
Returns the current date and time in the specified timezone in ISO 8601 format.
Args:
timezone (str): Timezone name (e.g., "UTC", "America/New_York", "Europe/London")
Default is "America/Los_Angeles"
Returns:
str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
"""
try:
if timezone == 'system' or timezone == '' or not timezone:
timezone = 'America/Los_Angeles'
# Get current UTC time (timezone-aware)
local_tz = pytz.timezone("America/Los_Angeles")
local_now = datetime.now(tz=local_tz)
# Convert to target timezone
target_tz = pytz.timezone(timezone)
target_time = local_now.astimezone(target_tz)
return target_time.isoformat()
except Exception as e:
return {'error': f"Invalid timezone {timezone}: {str(e)}"}
# %%
tools = [{
"type": "function",
"function": {
"name": "get_ticker_price",
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
"parameters": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
},
},
"required": ["ticker"],
"additionalProperties": False
}
}
}, {
"type": "function",
"function": {
"name": "get_current_datetime",
"description": "Get the current date and time in a specified timezone",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'."
}
},
"required": []
}
}
}, {
"type": "function",
"function": {
"name": "get_weather_by_location",
"description": "Get the current and future weather for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?'",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City to find the weather (e.g., 'Portland', 'Seattle')."
},
"state": {
"type": "string",
"description": "State to find the weather (e.g., 'OR', 'WA')."
}
},
"required": [ "city", "state" ],
"additionalProperties": False
}
}
}]
# %%
def handle_tool_calls(message):
response = []
tools_used = []
for tool_call in message['tool_calls']:
arguments = tool_call['function']['arguments']
tool = tool_call['function']['name']
if tool == 'get_ticker_price':
ticker = arguments.get('ticker')
if not ticker:
ret = None
else:
ret = get_ticker_price(ticker)
tools_used.append(tool)
elif tool == 'get_current_datetime':
ret = get_current_datetime(arguments.get('timezone'))
tools_used.append(tool)
elif tool == 'get_weather_by_location':
ret = get_weather_by_location(arguments.get('city'), arguments.get('state'))
tools_used.append(tool)
else:
ret = None
response.append({
"role": "tool",
"content": str(ret),
"name": tool_call['function']['name']
})
return response[0] if len(response) == 1 else response, tools_used
# %%
irc_channel = None
def chat(history, is_irc=False):
global irc_bot, irc_channel
message = history[-1]
logging.info("chat()")
if irc_channel and not is_irc:
asyncio.run(irc_bot.message(irc_channel, f"[console] {message['content']}"))
matches = re.match(r"^!([^\s]+)\s*(.*)$", message['content'])
if message['role'] == 'user' and matches:
command = matches.group(1)
arguments = matches.group(2).strip()
# Handle special IRC commands
if command == "test":
response = f"test received: {arguments}"
elif command == "server":
server = arguments.split(" ", 1)
if len(server) == 1:
server = "irc.libera.chat"
else:
server = server[1]
response = asyncio.run(irc_bot.connect(server, 6667, tls=False))
elif command == "join":
channel = arguments.strip()
if channel == "" or re.match(r"\s", channel):
response = "Usage: !join CHANNEL"
else:
if not re.match(r"^#", channel):
channel = f"#{channel}"
if irc_channel and irc_channel != channel:
asyncio.run(irc_bot.part(channel))
if channel:
asyncio.run(irc_bot.join(channel))
irc_channel = channel
response = f"Joined {channel}."
else:
response = f"Unrecognized command: {command}"
airc_response = [{ "role": "user", "content": response, "metadata": { "title": f"airc: {command}"}}]
return history + airc_response
#return response, chat_history
messages = system_log + history
response = client.chat(model=model, messages=messages, tools=tools)
tools_used = []
if 'tool_calls' in response['message']:
message = response['message']
tool_result, tools_used = handle_tool_calls(message)
messages.append(message)
messages.append(tool_result)
response = client.chat(model=model, messages=messages)
tool_log.append({ 'call': message, 'tool_result': tool_result, 'tools_used': tools_used, 'response': response['message']['content']})
reply = response['message']['content']
if len(tools_used):
history += [{"role":"assistant", "content":reply, 'metadata': {"title": f"🛠️ Tool(s) used: {','.join(tools_used)}"}}]
else:
history += [{"role":"assistant", "content":reply}]
if irc_channel:
asyncio.run(irc_bot.message(irc_channel, reply))
return history
# %%
# Queue to pass messages from IRC and CLI thread to Gradio
message_queue = queue.Queue()
def check_message_queue(history=[]):
if message_queue.empty():
return gr.skip()
while not message_queue.empty():
irc_message = message_queue.get_nowait()
logging.info(f"Processing: {irc_message}")
# Add the CLI message to chat history
last_message = history[-1] if len(history) else None
try:
matches = re.match(r"^([^:]+)\s*:\s*(.*)$", irc_message['message'])
if not matches:
if last_message and last_message["role"] == "user":
logging.info("Not a user directed message: modifying last USER context")
last_message['content'] += irc_message['message']
history[-1] = last_message
else:
logging.info("Not a user directed message: appending new USER context")
history.append({
"role": "user",
"content": irc_message['message']
})
else:
user = matches.group(1).strip()
content = matches.group(2).strip()
if user != irc_bot.nickname:
if last_message and last_message["role"] == "user":
logging.info(f"Not directed message to {irc_bot.nickname} ({user}): modifying last USER context")
last_message['content'] += irc_message['message']
history[-1] = last_message
else:
logging.info(f"Not directed message to {irc_bot.nickname} ({user}): appending a new USER context")
history.append({
"role": "user",
"content": content
})
else:
if last_message and last_message["role"] == "user":
logging.info(f"Directed message to {irc_bot.nickname}: modifying last USER context")
last_message['content'] += f"\n{irc_message['message']}"
history[-1] = last_message
else:
logging.info(f"Directed message to {irc_bot.nickname}: appending a new USER context")
history.append({
"role": "user",
"content": f"\n{irc_message['message']}"
})
history = chat(history, is_irc=True)
pass
except Exception as e:
logging.error({ "error": f"Unable to process message {irc_message}", "exception": e})
return history
# %%
def create_ui():
global chatbot, timer
with gr.Blocks() as ui:
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(
# value=check_message_queue,
# every=1,
type="messages",
height=600, # Set a fixed height
container=True, # Enable scrolling
elem_id="chatbot",
)
entry = gr.Textbox(
label="Chat with our AI Assistant:",
container=False,
scale=1
)
with gr.Column(scale=1):
chat_history = gr.JSON(
system_log,
label="Chat History",
height=300 # Set height for JSON viewer
)
tool_history = gr.JSON(
tool_log,
label="Tool calls",
height=300 # Set height for JSON viewer
)
with gr.Row():
clear = gr.Button("Clear")
timer = gr.Timer(1)
def do_entry(message, history):
history += [{"role":"user", "content":message}]
return "", history, system_log + history, tool_log # message, chat, logs
def do_clear():
history = []
tool_log = []
return history, system_log, tool_log
def update_log(history):
# This function updates the log after the chatbot responds
return system_log + history, tool_log
entry.submit(
do_entry,
inputs=[entry, chatbot],
outputs=[entry, chatbot, chat_history, tool_history]
).then(
chat, # Your chat function that processes responses
inputs=chatbot,
outputs=chatbot
).then(
update_log, # This new function updates the log after chatbot processing
inputs=chatbot,
outputs=[chat_history, tool_history]
)
# timer.tick(check_message_queue, inputs=chatbot, outputs=chatbot).then(
# update_log, # This new function updates the log after chatbot processing
# inputs=chatbot,
# outputs=[chat_history, tool_history]
# )
clear.click(do_clear, inputs=None, outputs=[chatbot, chat_history, tool_history], queue=False)
return ui
# %%
# %%
irc_bot = None # Placeholder for the bot instance
class DynamicIRCBot(pydle.Client):
async def on_connect(self):
await super().on_connect()
print(f"CONNECT: {self.nickname}")
if irc_channel:
await self.join(irc_channel, self.nickname)
async def on_join(self, channel, user):
await super().on_join(channel, user)
print(f"JOIN: {user} => {channel}")
async def on_part(self, channel, user):
await super().on_part(channel, user)
print(f"PART: {channel} => {user}")
async def on_message(self, target, source, message):
global chatbot, timer
await super().on_message(target, source, message)
print(f"MESSAGE: {source} => {target}: {message}")
if source == self.nickname:
return
message_queue.put({'target': target, 'source': source, 'message': message})
# Main function to run everything
async def main():
setup_logging("DEBUG")
global irc_bot, irc_channel
irc_bot = DynamicIRCBot("airc")
irc_channel = "#airc-test"
# await irc_bot.connect("irc.libera.chat", 6667, tls=False)
await irc_bot.connect("miniircd", 6667, tls=False)
ui = create_ui()
ui.launch(server_name="0.0.0.0", server_port=60673, prevent_thread_lock=True, pwa=True)
await irc_bot.handle_forever()
def setup_logging(level):
numeric_level = getattr(logging, level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f"Invalid log level: {level}")
logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
logging.info(f"Logging is set to {level} level.")
# Run the main function using anyio
asyncio.run(main())