DBChat: Experiments with recursive self-correcting LLM agents (WIP)
This idea came to me while I was building a Slack bot for a client. It had to be connected to a Postgres database, and its job was to answer some predefined slash commands. Basic stuff.
That got me wondering whether there was a way to build a universal DB wrapper using LLMs: something that can answer natural-language queries on any database, given only the connection string, by automatically parsing and understanding the schema and generating SQL/NoSQL commands for each query.
The basic principle is to first give the language model enough context about the schema of our database, and then get it to generate correct SQL for a given natural-language query.
I had already tried this with earlier versions of GPT-3.5 Turbo and failed miserably. Those earlier versions felt like instruction-tuned, pared-down GPT-3 models: they kinda worked for extremely toy databases but quickly fell apart for any real use case.
But right around June of this year, OpenAI introduced function calling. With it came new instruction-tuned snapshots (gpt-4-0613, plus a 16k-context gpt-3.5-turbo) that adhered to instructions much better. The larger context window meant the model could take in the schema of a bigger database, and the better instruction following meant it wrote more reliable SQL from natural language.
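(Just to illustrate what that looks like — the prototype below doesn't actually use function calling yet, it simply parses a fenced ```sql block out of the reply — here's roughly how a function-calling request was made with the 0.x openai Python SDK. The run_sql function name and the prompts are made-up placeholders.)

```python
import json
import openai

openai.api_key = ""  # your key here

# Describe a hypothetical "run_sql" tool the model is allowed to call
functions = [{
    "name": "run_sql",
    "description": "Run a read-only SQL query against the connected database.",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "A valid PostgreSQL SELECT statement."}
        },
        "required": ["query"],
    },
}]

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-0613",
    messages=[
        {"role": "system", "content": "Translate the user's question into SQL for the given schema."},
        {"role": "user", "content": "How many books are currently issued?"},
    ],
    functions=functions,
    function_call={"name": "run_sql"},  # force the model to answer via the tool
)

args = json.loads(response.choices[0].message["function_call"]["arguments"])
print(args["query"])  # the generated SQL, as structured output instead of free text
```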
The remaining challenge was to make sure GPT always returned correct results and didn't f*ck up the SQL queries it generated. This sounds simple on the surface, but it took a lot of experimentation to make it even kinda viable.
To benchmark the viability I used my good old nalanda library database that I scraped during my junior year in college. It has about 10k rows of data, multiple tables and relationships, and is about 300 MB in size. I wrote about 20 queries of varying complexity and benchmarked the run time and accuracy of each query for every experiment (manually).
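(The suite itself comes in the next part; conceptually the harness is just a timing loop like this, with accuracy judged by eye. The query list and the run_pipeline stub below are placeholders, not my actual benchmark.)

```python
import time

def run_pipeline(question: str) -> str:
    """Placeholder for the NL -> SQL -> answer pipeline described below."""
    raise NotImplementedError

# A few made-up example queries; the real set has ~20 of varying complexity
test_queries = [
    "How many books does the library have in total?",
    "Which member has the most overdue books?",
    "List the five most-issued books of 2019.",
]

for q in test_queries:
    start = time.time()
    answer = run_pipeline(q)
    elapsed = time.time() - start
    print(f"{elapsed:6.2f}s  {q}\n    -> {answer}\n")  # correctness graded manually
```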
Here's the first viable experiment.
First, I extract the relevant parts of the schema, along with a few sample rows:
import psycopg2

# Database connection parameters (left blank here)
db_params = {
}

def fetch_database_info(db_params):
    result_str = ""
    connection = None
    try:
        # Establish a connection to the database
        connection = psycopg2.connect(**db_params)
        cursor = connection.cursor()

        # Fetch tables
        cursor.execute("""SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';""")
        tables = cursor.fetchall()

        tables_str = "Tables:\n"
        for table in tables:
            table = table[0]
            tables_str += f" - {table}\n"

            # Fetch columns for each table
            cursor.execute(f"""SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table}';""")
            columns = cursor.fetchall()
            columns_str = "   Columns:\n"
            for column in columns:
                column_name, data_type = column
                columns_str += f"    - {column_name} ({data_type})\n"

            # Fetch sample data (optional)
            cursor.execute(f"SELECT * FROM {table} LIMIT 5;")
            sample_data = cursor.fetchall()
            sample_data_str = "   Sample Data:\n"
            for row in sample_data:
                row = [str(item)[:30] + ('...' if len(str(item)) > 30 else '') for item in row]
                sample_data_str += f"    - {row}\n"

            tables_str += columns_str + sample_data_str + "\n"

        result_str += tables_str

        # Fetch relationships (foreign key references)
        cursor.execute("""
            SELECT
                tc.table_name, kcu.column_name,
                ccu.table_name AS foreign_table_name,
                ccu.column_name AS foreign_column_name
            FROM
                information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                  ON tc.constraint_name = kcu.constraint_name
                  AND tc.table_schema = kcu.table_schema
                JOIN information_schema.constraint_column_usage AS ccu
                  ON ccu.constraint_name = tc.constraint_name
            WHERE tc.constraint_type = 'FOREIGN KEY';
        """)
        relationships = cursor.fetchall()

        relationships_str = "Relationships:\n"
        for rel in relationships:
            table, column, foreign_table, foreign_column = rel
            relationships_str += f" - {table}.{column} -> {foreign_table}.{foreign_column}\n"

        result_str += relationships_str

    except Exception as e:
        return f"Error: {e}"
    finally:
        if connection:
            connection.close()

    return result_str

# Call the function
database_info_str = fetch_database_info(db_params)
print(database_info_str)
This gives me output similar to the following:
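(Illustrative only — the tables and columns below are made up, not the actual nalanda schema:)

```
Tables:
 - books
   Columns:
    - id (integer)
    - title (character varying)
    - author (character varying)
   Sample Data:
    - ['1', 'Introduction to Algorithms', 'Cormen, Thomas H.']
    - ['2', 'Operating System Concepts', 'Silberschatz, Abraham']

 - issues
   Columns:
    - id (integer)
    - book_id (integer)
    - member_id (integer)
    - issue_date (date)
   Sample Data:
    - ['1', '1', '42', '2019-08-14']

Relationships:
 - issues.book_id -> books.id
 - issues.member_id -> members.id
```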
Now, the GPT agent itself is quite simple:
Get a natural-language query from the user > give GPT context about the DB plus the user query and generate SQL from it > validate and execute the SQL > feed any SQL error back to GPT recursively until a valid query is generated > give the output of that query plus the original question to another agent to summarise > display the summary.
import psycopg2
import openai
import re

import schema  # the schema-dumping script above

openai.api_key = ""  # OpenAI API key (left blank here)

# Database connection parameters (left blank here)
db_params = {
}
def extract_sql(response):
    # Pull the SQL out of a ```sql ... ``` fenced block in the model's reply
    pattern = r'```sql\n(.*?)\n```'
    match = re.search(pattern, response, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return None
def get_valid_sql(messages, attempts=0):
    # Give up after five attempts so a bad query can't loop forever
    if attempts >= 5:
        return None, None, []
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        temperature=1,
        messages=messages
    )
    rawResponse = response.choices[0].message['content']
    sqlCommand = extract_sql(rawResponse)
    try:
        cur.execute(sqlCommand)
        rows = cur.fetchall()
        return rawResponse, sqlCommand, rows
    except Exception as e:
        conn.rollback()  # Roll back the failed transaction
        # Feed the database error back to the model and retry
        messages.append({"role": "user", "content": "It gave me this error, please correct it: " + str(e)})
        return get_valid_sql(messages, attempts + 1)
def convert_to_natural_language(rows, original_query):
    messages = [
        {"role": "system", "content": "You are a helper that interprets database results into natural language. Convert the given tuples into a human-readable answer based on the user's input question."},
        {"role": "user", "content": f"Original Question: {original_query}"},
        {"role": "user", "content": f"Database Results: {rows}"}
    ]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages
    )
    return "\n\n" + response.choices[0].message['content']
def fetch_db_schema():
    return schema.fetch_database_info(db_params)

# Earlier inline version of the schema fetch, kept for reference:
# """
# Fetch the database schema: list of tables, columns, and their types.
# """
# try:
#     cur.execute("""
#         SELECT
#             table_name,
#             column_name,
#             data_type
#         FROM
#             information_schema.columns
#         WHERE
#             table_schema = 'public'
#         ORDER BY
#             table_name,
#             ordinal_position
#     """)
#     rows = cur.fetchall()
#     schema_desc = ""
#     current_table = ""
#     for row in rows:
#         table_name, column_name, data_type = row
#         if current_table != table_name:
#             schema_desc += f"\nTable: {table_name}\n"
#             current_table = table_name
#         schema_desc += f"    {column_name} ({data_type})\n"
#     print(schema_desc)
#     return schema_desc
# except Exception as e:
#     print("Error fetching schema:", e)
#     return ""
def handle_user_query(messages):
    while True:  # Continue until the user decides to exit
        query = input("\nWhat's the query? (or type 'exit' to stop) ")
        if query.lower() == 'exit':
            break
        messages.append({"role": "user", "content": query})
        rawResponse, sqlCommand, rows = get_valid_sql(messages)
        if sqlCommand:
            print("\nExecuted SQL Command:", sqlCommand)
            messages.append({"role": "assistant", "content": f"SQL: {rawResponse}"})  # Keep the generated SQL in the conversation history
            natural_language_response = convert_to_natural_language(rows, query)
            print(natural_language_response)
        else:
            print("\nFailed to generate a valid SQL command after 5 attempts.")
# Connect to the database
conn = psycopg2.connect(**db_params)
cur = conn.cursor()

# Get the database schema (careful not to shadow the imported `schema` module)
db_schema = fetch_db_schema()

# Prepare initial messages for SQL generation
messages = [
    {"role": "system", "content": "You are a SQL generator program. The user gives you their db schema and then queries it in English. Your job is to think step by step and generate SQL for that query to the best of your ability. Important!! user input is case insensitive. Wrap the SQL part in ```sql ``` so that it's easy to extract."},
    {"role": "user", "content": db_schema}
]

# Start handling user queries
handle_user_query(messages)

# Close the database connection
cur.close()
conn.close()
Extremely rudimentary, but it works on 16 of my 20 test queries.
The next part is making it more robust for complex databases and modular enough to work on any given database connection string (something like the sketch below). I'll also share my benchmark and testing suite in the next part, and I'll be testing this OSS LLM (https://github.com/defog-ai/sqlcoder) on my bench.
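(A possible shape for the "any connection string" part — not implemented yet, just a sketch. psycopg2 already accepts a libpq DSN/URL directly, so db_params could collapse into a single string:)

```python
import psycopg2

def connect_from_string(conn_str: str):
    """Open a connection from a single connection string, e.g.
    'postgresql://user:password@host:5432/dbname' (placeholder values)."""
    # psycopg2 accepts libpq connection strings / URLs directly
    return psycopg2.connect(conn_str)
```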
To be continued....