Tools are working and messages are rate limited
This commit is contained in:
parent
6829decfbb
commit
786444ce3c
@ -252,7 +252,8 @@ RUN pip3 install 'bigdl-core-xe-all>=2.6.0b'
|
||||
# NOTE: IPEX includes the oneAPI components... not sure if they still need to be installed separately with a oneAPI env
|
||||
RUN pip install einops diffusers # Required for IPEX optimize(), which is required to convert from Params4bit
|
||||
|
||||
RUN pip install yfinance pyzt geopy
|
||||
# Install packages needed for stock.py
|
||||
RUN pip install yfinance pyzt geopy PyHyphen
|
||||
|
||||
SHELL [ "/bin/bash", "-c" ]
|
||||
|
||||
|
179
jupyter/stock.py
179
jupyter/stock.py
@ -11,6 +11,7 @@ import queue
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
import textwrap
|
||||
|
||||
def try_import(module_name, pip_name=None):
|
||||
try:
|
||||
@ -29,6 +30,7 @@ try_import('requests')
|
||||
try_import('yfinance', 'yfinance')
|
||||
try_import('dotenv', 'python-dotenv')
|
||||
try_import('geopy', 'geopy')
|
||||
try_import('hyphen', 'PyHyphen')
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from geopy.geocoders import Nominatim
|
||||
@ -39,6 +41,7 @@ import pydle
|
||||
import pytz
|
||||
import requests
|
||||
import yfinance as yf
|
||||
from hyphen import hyphenator
|
||||
|
||||
# Local defined imports
|
||||
from tools import (
|
||||
@ -119,6 +122,88 @@ def setup_logging(level):
|
||||
|
||||
logging.info(f"Logging is set to {level} level.")
|
||||
|
||||
# %%
|
||||
def split_paragraph_with_hyphenation(text, line_length=80, language='en_US'):
|
||||
"""
|
||||
Split a paragraph into multiple lines with proper hyphenation.
|
||||
|
||||
Args:
|
||||
text (str): The text to split.
|
||||
line_length (int): The maximum length of each line.
|
||||
language (str): The language code for hyphenation rules.
|
||||
|
||||
Returns:
|
||||
[str]: The text split into multiple lines with proper hyphenation.
|
||||
"""
|
||||
# Initialize the hyphenator for the specified language
|
||||
h = hyphenator.Hyphenator(language)
|
||||
|
||||
# First attempt: try to wrap without hyphenation
|
||||
lines = textwrap.wrap(text, width=line_length)
|
||||
|
||||
# If any lines are too long, we need to apply hyphenation
|
||||
result_lines = []
|
||||
|
||||
for line in lines:
|
||||
# If the line is already short enough, keep it as is
|
||||
if len(line) <= line_length:
|
||||
result_lines.append(line)
|
||||
continue
|
||||
|
||||
# Otherwise, we need to hyphenate
|
||||
words = line.split()
|
||||
current_line = ""
|
||||
|
||||
for word in words:
|
||||
# If adding the word doesn't exceed the limit, add it
|
||||
if len(current_line) + len(word) + (1 if current_line else 0) <= line_length:
|
||||
if current_line:
|
||||
current_line += " "
|
||||
current_line += word
|
||||
# If the word itself is too long, hyphenate it
|
||||
elif len(word) > line_length - len(current_line) - (1 if current_line else 0):
|
||||
# If we already have content on the line, add it to results
|
||||
if current_line:
|
||||
result_lines.append(current_line)
|
||||
current_line = ""
|
||||
|
||||
# Get hyphenation points for the word
|
||||
hyphenated = h.syllables(word)
|
||||
|
||||
if not hyphenated:
|
||||
# If no hyphenation points found, just add the word to a new line
|
||||
result_lines.append(word)
|
||||
continue
|
||||
|
||||
# Try to find a suitable hyphenation point
|
||||
partial_word = ""
|
||||
for syllable in hyphenated:
|
||||
if len(partial_word) + len(syllable) + 1 > line_length:
|
||||
# Add hyphen to the partial word and start a new line
|
||||
if partial_word:
|
||||
result_lines.append(partial_word + "-")
|
||||
partial_word = syllable
|
||||
else:
|
||||
# If a single syllable is too long, just add it
|
||||
result_lines.append(syllable)
|
||||
else:
|
||||
partial_word += syllable
|
||||
|
||||
# Don't forget the remaining part
|
||||
if partial_word:
|
||||
current_line = partial_word
|
||||
|
||||
else:
|
||||
# Start a new line with this word
|
||||
result_lines.append(current_line)
|
||||
current_line = word
|
||||
|
||||
# Don't forget any remaining content
|
||||
if current_line:
|
||||
result_lines.append(current_line)
|
||||
|
||||
return result_lines
|
||||
|
||||
# %%
|
||||
def handle_tool_calls(message):
|
||||
response = []
|
||||
@ -165,8 +250,20 @@ async def chat(history, is_irc=False):
|
||||
tools_used = []
|
||||
if 'tool_calls' in response['message']:
|
||||
message = response['message']
|
||||
# Convert Message object to a proper dictionary format
|
||||
message_dict = {
|
||||
'role': message.get('role', 'assistant'),
|
||||
'content': message.get('content', '')
|
||||
}
|
||||
# Add tool_calls if present, ensuring they're properly serialized
|
||||
if 'tool_calls' in message:
|
||||
message_dict['tool_calls'] = [
|
||||
{'function': {'name': tc.function.name, 'arguments': tc.function.arguments}}
|
||||
for tc in message['tool_calls']
|
||||
]
|
||||
|
||||
tool_result, tools_used = handle_tool_calls(message)
|
||||
messages.append(message)
|
||||
messages.append(message_dict) # Add properly formatted dict instead of Message object
|
||||
messages.append(tool_result)
|
||||
try:
|
||||
response = client.chat(model=model, messages=messages)
|
||||
@ -188,12 +285,71 @@ async def chat(history, is_irc=False):
|
||||
# %%
|
||||
|
||||
class DynamicIRCBot(pydle.Client):
|
||||
def __init__(self, nickname, channel, bot_admin, system_info, **kwargs):
|
||||
def __init__(self, nickname, channel, bot_admin, system_info, burst_limit = 5, rate_limit = 1.0, burst_reset_timeout = 10.0, **kwargs):
|
||||
super().__init__(nickname, **kwargs)
|
||||
self.history = []
|
||||
self.channel = channel
|
||||
self.bot_admin = bot_admin
|
||||
self.system_info = system_info
|
||||
# Message throttling
|
||||
self.burst_limit = burst_limit
|
||||
self.sent_burst = 0
|
||||
self.rate_limit = rate_limit
|
||||
self.burst_reset_timeout = burst_reset_timeout
|
||||
self.sent_burst = 0 # Track messages sent in burst
|
||||
self.last_message_time = None # Track last message time
|
||||
self._message_queue = asyncio.Queue()
|
||||
self._task = asyncio.create_task(self._send_from_queue())
|
||||
|
||||
async def _send_from_queue(self):
|
||||
"""Background task that sends queued messages with burst + rate limiting."""
|
||||
while True:
|
||||
target, message = await self._message_queue.get()
|
||||
|
||||
# If burst is still available, send immediately
|
||||
if self.sent_burst < self.burst_limit:
|
||||
self.sent_burst += 1
|
||||
else:
|
||||
await asyncio.sleep(self.rate_limit) # Apply rate limit
|
||||
|
||||
logging.debug(f"Sending {message} => {target} from queue")
|
||||
await super().message(target, message) # Send message
|
||||
self.last_message_time = asyncio.get_event_loop().time() # Update last message timestamp
|
||||
|
||||
# Start burst reset countdown after each message
|
||||
asyncio.create_task(self._reset_burst_after_inactivity())
|
||||
|
||||
async def _reset_burst_after_inactivity(self):
|
||||
"""Resets burst counter only if no new messages are sent within timeout."""
|
||||
last_time = self.last_message_time
|
||||
await asyncio.sleep(self.burst_reset_timeout) # Wait for inactivity period
|
||||
|
||||
# Only reset if no new messages were sent during the wait
|
||||
if self.last_message_time == last_time:
|
||||
self.sent_burst = 0
|
||||
logging.info("Burst limit reset due to inactivity.")
|
||||
|
||||
async def message(self, target, message):
|
||||
"""Splits a multi-line message and sends each line separately. If more than 10 lines, truncate and add a message."""
|
||||
max_lines = 10
|
||||
irc_lines = []
|
||||
for line in message.splitlines():
|
||||
lines = split_paragraph_with_hyphenation(line, line_length=450)
|
||||
irc_lines.extend(lines)
|
||||
|
||||
# Send the first 'max_lines' non-empty lines
|
||||
i=0
|
||||
sent_lines=0
|
||||
while i < len(irc_lines) and sent_lines < max_lines:
|
||||
line = irc_lines[i].strip()
|
||||
i+=1
|
||||
if line != "":
|
||||
sent_lines += 1
|
||||
await self._message_queue.put((target, line))
|
||||
|
||||
# If there are more than 10 lines, add the truncation message
|
||||
if len(irc_lines) > max_lines:
|
||||
await self._message_queue.put((target, f"...and so on (message truncated to {max_lines} lines.)"))
|
||||
|
||||
async def on_connect(self):
|
||||
await super().on_connect()
|
||||
@ -263,10 +419,7 @@ class DynamicIRCBot(pydle.Client):
|
||||
self.history.append(last_message)
|
||||
self.history = await chat(self.history, is_irc=True)
|
||||
chat_response = self.history[-1]
|
||||
lines = [line.strip() for line in chat_response['content'].split('\n') if line.strip()]
|
||||
if len(lines) > 10:
|
||||
lines = lines[:9] + ["...and so on. I'll send the rest to you in a PRIVMSG (wip) :)"]
|
||||
await self.message(target, "\n".join(lines))
|
||||
await self.message(target, chat_response['content'])
|
||||
return
|
||||
|
||||
command = matches.group(1)
|
||||
@ -274,8 +427,8 @@ class DynamicIRCBot(pydle.Client):
|
||||
logging.info(f"Command directed to {self.nickname}: command={command}, arguments={arguments}")
|
||||
|
||||
match command:
|
||||
case "test":
|
||||
response = f"test received: {arguments}"
|
||||
case "help":
|
||||
response = f"info, context, reset, system [prompt], server [address], join channel"
|
||||
|
||||
case "info":
|
||||
response = str(self.system_info)
|
||||
@ -303,15 +456,16 @@ class DynamicIRCBot(pydle.Client):
|
||||
|
||||
case "server":
|
||||
server = arguments.split(" ", 1)
|
||||
if len(server) == 1:
|
||||
if server[0] == "":
|
||||
server = IRC_SERVER
|
||||
else:
|
||||
server = server[1]
|
||||
server = server[0]
|
||||
try:
|
||||
await self.connect(server, 6667, tls=False)
|
||||
response="Connected to {server}"
|
||||
except Exception:
|
||||
response = f"Unable to connect to {server}"
|
||||
logging.exception({ "error": f"Unable to process message {content}"})
|
||||
reponse = f"Unable to connect to {server}"
|
||||
|
||||
case "join":
|
||||
channel = arguments.strip()
|
||||
@ -444,7 +598,8 @@ async def main():
|
||||
global irc_bot, client, model
|
||||
# Parse command-line arguments
|
||||
args = parse_args()
|
||||
|
||||
if not re.match(r"^#", args.irc_channel):
|
||||
args.irc_channel = f"#{args.irc_channel}"
|
||||
# Setup logging based on the provided level
|
||||
setup_logging(args.level)
|
||||
|
||||
|
@ -305,17 +305,17 @@ tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather_by_location",
|
||||
"description": "Get the current and future weather for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?'",
|
||||
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City to find the weather (e.g., 'Portland', 'Seattle')."
|
||||
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle')."
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State to find the weather (e.g., 'OR', 'WA')."
|
||||
"description": "State to find the weather forecast (e.g., 'OR', 'WA')."
|
||||
}
|
||||
},
|
||||
"required": [ "city", "state" ],
|
||||
|
Loading…
x
Reference in New Issue
Block a user