caching working version
This commit is contained in:
parent
dd18ca858b
commit
5f6971510a
604
jupyter/stock.py
Normal file
604
jupyter/stock.py
Normal file
@ -0,0 +1,604 @@
|
||||
# %%
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
import gradio as gr
|
||||
import ollama
|
||||
import re
|
||||
import yfinance as yf
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
import requests
|
||||
import json
|
||||
from geopy.geocoders import Nominatim
|
||||
import time
|
||||
import pydle
|
||||
import asyncio
|
||||
import anyio
|
||||
import logging
|
||||
import queue
|
||||
|
||||
# %%
|
||||
system_message = """
|
||||
You are a helpful stock trading agent.
|
||||
You have real time access to current stock trading values and other information including the current date and time and current weather information.
|
||||
If you succeed in getting stock values or time information, do not mention your knowledge cutoff.
|
||||
Give short, courteous answers, no more than 1 sentence.
|
||||
You have access to several tools for obtaining real-time information.
|
||||
Always be accurate.
|
||||
If you don't know the answer, say so. Do not make up details.
|
||||
"""
|
||||
system_log = [{"role": "system", "content": system_message}]
|
||||
history = []
|
||||
tool_log = []
|
||||
|
||||
# %%
|
||||
client = ollama.Client(host="http://ollama:11434")
|
||||
#model = 'deepseek-r1:7b' # Does not support tools with template used by ollama
|
||||
model = 'llama3.2'
|
||||
|
||||
# %%
|
||||
def get_weather_by_location(city, state, country="USA"):
|
||||
"""
|
||||
Get weather information from weather.gov based on city, state, and country.
|
||||
|
||||
Args:
|
||||
city (str): City name
|
||||
state (str): State name or abbreviation
|
||||
country (str): Country name (defaults to "USA" as weather.gov is for US locations)
|
||||
|
||||
Returns:
|
||||
dict: Weather forecast information
|
||||
"""
|
||||
# Step 1: Get coordinates for the location using geocoding
|
||||
location = f"{city}, {state}, {country}"
|
||||
coordinates = get_coordinates(location)
|
||||
|
||||
if not coordinates:
|
||||
return {"error": f"Could not find coordinates for {location}"}
|
||||
|
||||
# Step 2: Get the forecast grid endpoint for the coordinates
|
||||
grid_endpoint = get_grid_endpoint(coordinates)
|
||||
|
||||
if not grid_endpoint:
|
||||
return {"error": f"Could not find weather grid for coordinates {coordinates}"}
|
||||
|
||||
# Step 3: Get the forecast data from the grid endpoint
|
||||
forecast = get_forecast(grid_endpoint)
|
||||
|
||||
return forecast
|
||||
|
||||
def get_coordinates(location):
|
||||
"""Convert a location string to latitude and longitude using Nominatim geocoder."""
|
||||
try:
|
||||
# Create a geocoder with a meaningful user agent
|
||||
geolocator = Nominatim(user_agent="weather_app_example")
|
||||
|
||||
# Get the location
|
||||
location_data = geolocator.geocode(location)
|
||||
|
||||
if location_data:
|
||||
return {
|
||||
"latitude": location_data.latitude,
|
||||
"longitude": location_data.longitude
|
||||
}
|
||||
else:
|
||||
print(f"Location not found: {location}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error getting coordinates: {e}")
|
||||
return None
|
||||
|
||||
def get_grid_endpoint(coordinates):
|
||||
"""Get the grid endpoint from weather.gov based on coordinates."""
|
||||
try:
|
||||
lat = coordinates["latitude"]
|
||||
lon = coordinates["longitude"]
|
||||
|
||||
# Define headers for the API request
|
||||
headers = {
|
||||
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
|
||||
"Accept": "application/geo+json"
|
||||
}
|
||||
|
||||
# Make the request to get the grid endpoint
|
||||
url = f"https://api.weather.gov/points/{lat},{lon}"
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data["properties"]["forecast"]
|
||||
else:
|
||||
print(f"Error getting grid: {response.status_code} - {response.text}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error in get_grid_endpoint: {e}")
|
||||
return None
|
||||
|
||||
def get_forecast(grid_endpoint):
|
||||
"""Get the forecast data from the grid endpoint."""
|
||||
try:
|
||||
# Define headers for the API request
|
||||
headers = {
|
||||
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
|
||||
"Accept": "application/geo+json"
|
||||
}
|
||||
|
||||
# Make the request to get the forecast
|
||||
response = requests.get(grid_endpoint, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
|
||||
# Extract the relevant forecast information
|
||||
periods = data["properties"]["periods"]
|
||||
|
||||
# Process the forecast data into a simpler format
|
||||
forecast = {
|
||||
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
|
||||
"updated": data["properties"].get("updated", ""),
|
||||
"periods": []
|
||||
}
|
||||
|
||||
for period in periods:
|
||||
forecast["periods"].append({
|
||||
"name": period.get("name", ""),
|
||||
"temperature": period.get("temperature", ""),
|
||||
"temperatureUnit": period.get("temperatureUnit", ""),
|
||||
"windSpeed": period.get("windSpeed", ""),
|
||||
"windDirection": period.get("windDirection", ""),
|
||||
"shortForecast": period.get("shortForecast", ""),
|
||||
"detailedForecast": period.get("detailedForecast", "")
|
||||
})
|
||||
|
||||
return forecast
|
||||
else:
|
||||
print(f"Error getting forecast: {response.status_code} - {response.text}")
|
||||
return {"error": f"API Error: {response.status_code}"}
|
||||
except Exception as e:
|
||||
print(f"Error in get_forecast: {e}")
|
||||
return {"error": f"Exception: {str(e)}"}
|
||||
|
||||
# Example usage
|
||||
def do_weather():
|
||||
city = input("Enter city: ")
|
||||
state = input("Enter state: ")
|
||||
country = input("Enter country (default USA): ") or "USA"
|
||||
|
||||
print(f"Getting weather for {city}, {state}, {country}...")
|
||||
weather_data = get_weather_by_location(city, state, country)
|
||||
|
||||
if "error" in weather_data:
|
||||
print(f"Error: {weather_data['error']}")
|
||||
else:
|
||||
print("\nWeather Forecast:")
|
||||
print(f"Location: {weather_data.get('location', {}).get('city', city)}, {weather_data.get('location', {}).get('state', state)}")
|
||||
print(f"Last Updated: {weather_data.get('updated', 'N/A')}")
|
||||
print("\nForecast Periods:")
|
||||
|
||||
for period in weather_data.get("periods", []):
|
||||
print(f"\n{period['name']}:")
|
||||
print(f" Temperature: {period['temperature']}{period['temperatureUnit']}")
|
||||
print(f" Wind: {period['windSpeed']} {period['windDirection']}")
|
||||
print(f" Forecast: {period['shortForecast']}")
|
||||
print(f" Details: {period['detailedForecast']}")
|
||||
|
||||
# %%
|
||||
def get_ticker_price(ticker_symbols):
|
||||
"""
|
||||
Look up the current price of a stock using its ticker symbol.
|
||||
|
||||
Args:
|
||||
ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple)
|
||||
|
||||
Returns:
|
||||
dict: Current stock information including price
|
||||
"""
|
||||
results = []
|
||||
print(f"get_ticker_price('{ticker_symbols}')")
|
||||
for ticker_symbol in ticker_symbols.split(','):
|
||||
ticker_symbol = ticker_symbol.strip()
|
||||
if ticker_symbol == "":
|
||||
continue
|
||||
# Create a Ticker object
|
||||
try:
|
||||
ticker = yf.Ticker(ticker_symbol)
|
||||
|
||||
# Get the latest market data
|
||||
ticker_data = ticker.history(period="1d")
|
||||
|
||||
if ticker_data.empty:
|
||||
results.append({"error": f"No data found for ticker {ticker_symbol}"})
|
||||
continue
|
||||
|
||||
# Get the latest closing price
|
||||
latest_price = ticker_data['Close'].iloc[-1]
|
||||
|
||||
# Get some additional info
|
||||
info = ticker.info
|
||||
results.append({ 'symbol': ticker_symbol, 'price': latest_price })
|
||||
|
||||
except Exception as e:
|
||||
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
|
||||
|
||||
return results[0] if len(results) == 1 else results
|
||||
#{
|
||||
# "symbol": ticker_symbol,
|
||||
# "price": latest_price,
|
||||
# "currency": info.get("currency", "Unknown"),
|
||||
# "company_name": info.get("shortName", "Unknown"),
|
||||
# "previous_close": info.get("previousClose", "Unknown"),
|
||||
# "market_cap": info.get("marketCap", "Unknown"),
|
||||
#}
|
||||
|
||||
|
||||
# %%
|
||||
def get_current_datetime(timezone="America/Los_Angeles"):
|
||||
"""
|
||||
Returns the current date and time in the specified timezone in ISO 8601 format.
|
||||
|
||||
Args:
|
||||
timezone (str): Timezone name (e.g., "UTC", "America/New_York", "Europe/London")
|
||||
Default is "America/Los_Angeles"
|
||||
|
||||
Returns:
|
||||
str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
|
||||
"""
|
||||
try:
|
||||
if timezone == 'system' or timezone == '' or not timezone:
|
||||
timezone = 'America/Los_Angeles'
|
||||
# Get current UTC time (timezone-aware)
|
||||
local_tz = pytz.timezone("America/Los_Angeles")
|
||||
local_now = datetime.now(tz=local_tz)
|
||||
|
||||
# Convert to target timezone
|
||||
target_tz = pytz.timezone(timezone)
|
||||
target_time = local_now.astimezone(target_tz)
|
||||
|
||||
return target_time.isoformat()
|
||||
except Exception as e:
|
||||
return {'error': f"Invalid timezone {timezone}: {str(e)}"}
|
||||
|
||||
|
||||
# %%
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_ticker_price",
|
||||
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ticker": {
|
||||
"type": "string",
|
||||
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
|
||||
},
|
||||
},
|
||||
"required": ["ticker"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_datetime",
|
||||
"description": "Get the current date and time in a specified timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather_by_location",
|
||||
"description": "Get the current and future weather for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City to find the weather (e.g., 'Portland', 'Seattle')."
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State to find the weather (e.g., 'OR', 'WA')."
|
||||
}
|
||||
},
|
||||
"required": [ "city", "state" ],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
# %%
|
||||
def handle_tool_calls(message):
|
||||
response = []
|
||||
tools_used = []
|
||||
for tool_call in message['tool_calls']:
|
||||
arguments = tool_call['function']['arguments']
|
||||
tool = tool_call['function']['name']
|
||||
if tool == 'get_ticker_price':
|
||||
ticker = arguments.get('ticker')
|
||||
if not ticker:
|
||||
ret = None
|
||||
else:
|
||||
ret = get_ticker_price(ticker)
|
||||
tools_used.append(tool)
|
||||
elif tool == 'get_current_datetime':
|
||||
ret = get_current_datetime(arguments.get('timezone'))
|
||||
tools_used.append(tool)
|
||||
elif tool == 'get_weather_by_location':
|
||||
ret = get_weather_by_location(arguments.get('city'), arguments.get('state'))
|
||||
tools_used.append(tool)
|
||||
else:
|
||||
ret = None
|
||||
response.append({
|
||||
"role": "tool",
|
||||
"content": str(ret),
|
||||
"name": tool_call['function']['name']
|
||||
})
|
||||
return response[0] if len(response) == 1 else response, tools_used
|
||||
|
||||
# %%
|
||||
irc_channel = None
|
||||
def chat(history, is_irc=False):
|
||||
global irc_bot, irc_channel
|
||||
message = history[-1]
|
||||
|
||||
logging.info("chat()")
|
||||
if irc_channel and not is_irc:
|
||||
asyncio.run(irc_bot.message(irc_channel, f"[console] {message['content']}"))
|
||||
|
||||
matches = re.match(r"^!([^\s]+)\s*(.*)$", message['content'])
|
||||
if message['role'] == 'user' and matches:
|
||||
command = matches.group(1)
|
||||
arguments = matches.group(2).strip()
|
||||
|
||||
# Handle special IRC commands
|
||||
if command == "test":
|
||||
response = f"test received: {arguments}"
|
||||
|
||||
elif command == "server":
|
||||
server = arguments.split(" ", 1)
|
||||
if len(server) == 1:
|
||||
server = "irc.libera.chat"
|
||||
else:
|
||||
server = server[1]
|
||||
|
||||
response = asyncio.run(irc_bot.connect(server, 6667, tls=False))
|
||||
|
||||
elif command == "join":
|
||||
channel = arguments.strip()
|
||||
if channel == "" or re.match(r"\s", channel):
|
||||
response = "Usage: !join CHANNEL"
|
||||
else:
|
||||
if not re.match(r"^#", channel):
|
||||
channel = f"#{channel}"
|
||||
if irc_channel and irc_channel != channel:
|
||||
asyncio.run(irc_bot.part(channel))
|
||||
if channel:
|
||||
asyncio.run(irc_bot.join(channel))
|
||||
irc_channel = channel
|
||||
response = f"Joined {channel}."
|
||||
|
||||
else:
|
||||
response = f"Unrecognized command: {command}"
|
||||
|
||||
airc_response = [{ "role": "user", "content": response, "metadata": { "title": f"airc: {command}"}}]
|
||||
|
||||
return history + airc_response
|
||||
|
||||
#return response, chat_history
|
||||
messages = system_log + history
|
||||
response = client.chat(model=model, messages=messages, tools=tools)
|
||||
tools_used = []
|
||||
if 'tool_calls' in response['message']:
|
||||
message = response['message']
|
||||
tool_result, tools_used = handle_tool_calls(message)
|
||||
messages.append(message)
|
||||
messages.append(tool_result)
|
||||
response = client.chat(model=model, messages=messages)
|
||||
tool_log.append({ 'call': message, 'tool_result': tool_result, 'tools_used': tools_used, 'response': response['message']['content']})
|
||||
|
||||
reply = response['message']['content']
|
||||
if len(tools_used):
|
||||
history += [{"role":"assistant", "content":reply, 'metadata': {"title": f"🛠️ Tool(s) used: {','.join(tools_used)}"}}]
|
||||
else:
|
||||
history += [{"role":"assistant", "content":reply}]
|
||||
|
||||
if irc_channel:
|
||||
asyncio.run(irc_bot.message(irc_channel, reply))
|
||||
|
||||
return history
|
||||
|
||||
|
||||
# %%
|
||||
# Queue to pass messages from IRC and CLI thread to Gradio
|
||||
message_queue = queue.Queue()
|
||||
|
||||
def check_message_queue(history=[]):
|
||||
if message_queue.empty():
|
||||
return gr.skip()
|
||||
|
||||
while not message_queue.empty():
|
||||
irc_message = message_queue.get_nowait()
|
||||
logging.info(f"Processing: {irc_message}")
|
||||
# Add the CLI message to chat history
|
||||
last_message = history[-1] if len(history) else None
|
||||
try:
|
||||
matches = re.match(r"^([^:]+)\s*:\s*(.*)$", irc_message['message'])
|
||||
if not matches:
|
||||
if last_message and last_message["role"] == "user":
|
||||
logging.info("Not a user directed message: modifying last USER context")
|
||||
last_message['content'] += irc_message['message']
|
||||
history[-1] = last_message
|
||||
else:
|
||||
logging.info("Not a user directed message: appending new USER context")
|
||||
history.append({
|
||||
"role": "user",
|
||||
"content": irc_message['message']
|
||||
})
|
||||
else:
|
||||
user = matches.group(1).strip()
|
||||
content = matches.group(2).strip()
|
||||
if user != irc_bot.nickname:
|
||||
if last_message and last_message["role"] == "user":
|
||||
logging.info(f"Not directed message to {irc_bot.nickname} ({user}): modifying last USER context")
|
||||
last_message['content'] += irc_message['message']
|
||||
history[-1] = last_message
|
||||
else:
|
||||
logging.info(f"Not directed message to {irc_bot.nickname} ({user}): appending a new USER context")
|
||||
history.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
})
|
||||
else:
|
||||
if last_message and last_message["role"] == "user":
|
||||
logging.info(f"Directed message to {irc_bot.nickname}: modifying last USER context")
|
||||
last_message['content'] += f"\n{irc_message['message']}"
|
||||
history[-1] = last_message
|
||||
else:
|
||||
logging.info(f"Directed message to {irc_bot.nickname}: appending a new USER context")
|
||||
history.append({
|
||||
"role": "user",
|
||||
"content": f"\n{irc_message['message']}"
|
||||
})
|
||||
history = chat(history, is_irc=True)
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.error({ "error": f"Unable to process message {irc_message}", "exception": e})
|
||||
|
||||
return history
|
||||
|
||||
# %%
|
||||
def create_ui():
|
||||
global chatbot, timer
|
||||
with gr.Blocks() as ui:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
chatbot = gr.Chatbot(
|
||||
# value=check_message_queue,
|
||||
# every=1,
|
||||
type="messages",
|
||||
height=600, # Set a fixed height
|
||||
container=True, # Enable scrolling
|
||||
elem_id="chatbot",
|
||||
)
|
||||
entry = gr.Textbox(
|
||||
label="Chat with our AI Assistant:",
|
||||
container=False,
|
||||
scale=1
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
chat_history = gr.JSON(
|
||||
system_log,
|
||||
label="Chat History",
|
||||
height=300 # Set height for JSON viewer
|
||||
)
|
||||
tool_history = gr.JSON(
|
||||
tool_log,
|
||||
label="Tool calls",
|
||||
height=300 # Set height for JSON viewer
|
||||
)
|
||||
with gr.Row():
|
||||
clear = gr.Button("Clear")
|
||||
timer = gr.Timer(1)
|
||||
|
||||
def do_entry(message, history):
|
||||
history += [{"role":"user", "content":message}]
|
||||
return "", history, system_log + history, tool_log # message, chat, logs
|
||||
|
||||
def do_clear():
|
||||
history = []
|
||||
tool_log = []
|
||||
return history, system_log, tool_log
|
||||
|
||||
def update_log(history):
|
||||
# This function updates the log after the chatbot responds
|
||||
return system_log + history, tool_log
|
||||
|
||||
entry.submit(
|
||||
do_entry,
|
||||
inputs=[entry, chatbot],
|
||||
outputs=[entry, chatbot, chat_history, tool_history]
|
||||
).then(
|
||||
chat, # Your chat function that processes responses
|
||||
inputs=chatbot,
|
||||
outputs=chatbot
|
||||
).then(
|
||||
update_log, # This new function updates the log after chatbot processing
|
||||
inputs=chatbot,
|
||||
outputs=[chat_history, tool_history]
|
||||
)
|
||||
|
||||
# timer.tick(check_message_queue, inputs=chatbot, outputs=chatbot).then(
|
||||
# update_log, # This new function updates the log after chatbot processing
|
||||
# inputs=chatbot,
|
||||
# outputs=[chat_history, tool_history]
|
||||
# )
|
||||
|
||||
clear.click(do_clear, inputs=None, outputs=[chatbot, chat_history, tool_history], queue=False)
|
||||
|
||||
return ui
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
irc_bot = None # Placeholder for the bot instance
|
||||
|
||||
class DynamicIRCBot(pydle.Client):
|
||||
async def on_connect(self):
|
||||
await super().on_connect()
|
||||
print(f"CONNECT: {self.nickname}")
|
||||
if irc_channel:
|
||||
await self.join(irc_channel, self.nickname)
|
||||
|
||||
async def on_join(self, channel, user):
|
||||
await super().on_join(channel, user)
|
||||
print(f"JOIN: {user} => {channel}")
|
||||
|
||||
async def on_part(self, channel, user):
|
||||
await super().on_part(channel, user)
|
||||
print(f"PART: {channel} => {user}")
|
||||
|
||||
async def on_message(self, target, source, message):
|
||||
global chatbot, timer
|
||||
await super().on_message(target, source, message)
|
||||
print(f"MESSAGE: {source} => {target}: {message}")
|
||||
if source == self.nickname:
|
||||
return
|
||||
message_queue.put({'target': target, 'source': source, 'message': message})
|
||||
|
||||
# Main function to run everything
|
||||
async def main():
|
||||
setup_logging("DEBUG")
|
||||
|
||||
global irc_bot, irc_channel
|
||||
irc_bot = DynamicIRCBot("airc")
|
||||
irc_channel = "#airc-test"
|
||||
# await irc_bot.connect("irc.libera.chat", 6667, tls=False)
|
||||
await irc_bot.connect("miniircd", 6667, tls=False)
|
||||
|
||||
ui = create_ui()
|
||||
ui.launch(server_name="0.0.0.0", server_port=60673, prevent_thread_lock=True, pwa=True)
|
||||
|
||||
await irc_bot.handle_forever()
|
||||
|
||||
def setup_logging(level):
|
||||
numeric_level = getattr(logging, level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
raise ValueError(f"Invalid log level: {level}")
|
||||
|
||||
logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logging.info(f"Logging is set to {level} level.")
|
||||
|
||||
# Run the main function using anyio
|
||||
asyncio.run(main())
|
Loading…
x
Reference in New Issue
Block a user