How to Create an AI Application That Can Chat with Massive SQL Databases

Harshit Ahluwalia 16 May, 2024 • 8 min read

Introduction

Building a simple application that can chat with a SQL database is easy. Making it work well on a large database is harder: when the database is huge, it is impractical to include the complete list of tables and columns in the prompt context. This article explains how to get around this obstacle by building an AI application that can chat with massive SQL databases.


Creating an AI Application to Chat with Massive SQL Databases

The following code sets up a simple Streamlit application that lets users connect to a PostgreSQL database and chat with it.

import streamlit as st
import requests
import os
import pandas as pd
from uuid import uuid4
import psycopg2
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessage, HumanMessagePromptTemplate
from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv

# Create necessary folders
folders_to_create = ['csvs']
for folder_name in folders_to_create:
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)
        print(f"Folder '{folder_name}' created.")
    else:
        print(f"Folder '{folder_name}' already exists.")

# Load the OpenAI API key
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
llm = OpenAI(openai_api_key=openai_api_key)
chat_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.4)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

def get_basic_table_details(cursor):
    query = """
    SELECT table_name, column_name, data_type
    FROM information_schema.columns
    WHERE table_name IN (
        SELECT tablename FROM pg_tables WHERE schemaname = 'public'
    );"""
    cursor.execute(query)
    return cursor.fetchall()

def save_db_details(db_uri):
    unique_id = str(uuid4()).replace("-", "_")
    connection = psycopg2.connect(db_uri)
    cursor = connection.cursor()
    tables_and_columns = get_basic_table_details(cursor)
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = f'csvs/tables_{unique_id}.csv'
    df.to_csv(filename_t, index=False)
    cursor.close()
    connection.close()
    return unique_id

def generate_template_for_sql(query, table_info, db_uri):
    template = ChatPromptTemplate.from_messages([
        SystemMessage(content=f"You are an assistant that can write SQL Queries. Given the text below, write a SQL query that answers the user's question. DB connection string is {db_uri} Here is a detailed description of the table(s): {table_info} Prepend and append the SQL query with three backticks '```'"),
        HumanMessagePromptTemplate.from_template("{text}")
    ])
    answer = chat_llm(template.format_messages(text=query))
    return answer.content

def get_the_output_from_llm(query, unique_id, db_uri):
    filename_t = f'csvs/tables_{unique_id}.csv'
    df = pd.read_csv(filename_t)
    table_info = ''
    for table in df['table_name'].unique():
        table_info += f'Information about table {table}:\n'
        table_info += df[df['table_name'] == table].to_string(index=False) + '\n\n'
    return generate_template_for_sql(query, table_info, db_uri)

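def execute_the_solution(solution, db_uri):
    _, final_query, _ = solution.split("```")
    # Drop an optional "sql" language tag after the opening fence
    final_query = final_query.strip()
    if final_query.lower().startswith("sql"):
        final_query = final_query[3:].strip()
    connection = psycopg2.connect(db_uri)
    cursor = connection.cursor()
    try:
        cursor.execute(final_query)
        result = cursor.fetchall()
    finally:
        cursor.close()
        connection.close()
    return str(result)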

def connect_with_db(uri):
    st.session_state.db_uri = uri
    st.session_state.unique_id = save_db_details(uri)
    return {"message": "Database connection established!"}

def send_message(message):
    solution = get_the_output_from_llm(message, st.session_state.unique_id, st.session_state.db_uri)
    result = execute_the_solution(solution, st.session_state.db_uri)
    return {"message": solution + "\n\n" + "Result:\n" + result}

# Streamlit interface setup
st.subheader("Instructions")
st.markdown("1. Enter your RDS Database URI below.\n2. ")

The basic strategy for simplifying the prompt is to send only the table and column names that are relevant to the user’s query. To do this, we generate embeddings for the table and column names, retrieve the most relevant ones at query time based on the user’s input, and include only those in the prompt. In this article, we use ChromaDB as the vector database, but Pinecone, Milvus, or any other suitable vector database would also work.
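
To make the idea concrete, here is a minimal, self-contained sketch of the retrieval step. It is only an illustration: `toy_embed` is a hypothetical bag-of-words stand-in for a real embedding model, and the schema rows are invented sample data. In the application itself, OpenAI embeddings and ChromaDB play these roles.

```python
import math
import re
from collections import Counter

# Invented sample schema rows: (table_name, column_name, data_type)
SCHEMA_ROWS = [
    ("customers", "customer_id", "integer"),
    ("customers", "email", "text"),
    ("orders", "order_id", "integer"),
    ("orders", "order_date", "date"),
    ("products", "product_name", "text"),
]

def toy_embed(text):
    """Hypothetical stand-in for an embedding model: lowercase word counts."""
    return Counter(re.findall(r"[a-z]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a if w in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def top_k_schema_rows(question, rows, k=2):
    """Return the k schema rows most similar to the question."""
    q = toy_embed(question)
    scored = [(cosine(q, toy_embed(" ".join(row).replace("_", " "))), row) for row in rows]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [row for _, row in scored[:k]]
```

Only the rows returned by `top_k_schema_rows` would be placed in the prompt, which is what keeps the prompt small no matter how many tables the database has.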

How to Simplify the Prompt

Let’s begin by installing ChromaDB.

pip install chromadb

First, we’ll set up an additional folder named ‘vectors’ alongside the ‘csvs’ folder to store embeddings of table and column names. This folder can also contain other pertinent database details, such as the foreign keys that link different tables, and potential values for the WHERE clause.

def generate_embeddings(filename, storage_folder):
    csv_loader = CSVLoader(file_path=filename, encoding="utf8")
    dataset = csv_loader.load()
    vector_database = Chroma.from_documents(dataset, embedding=embeddings, persist_directory=storage_folder)
    vector_database.persist()

Before retrieving anything, we first check whether the user’s query needs details about specific tables or is only asking about the general schema of the database.

def check_user_intent_for_database_info_or_sql(query):
    # Define a template for the conversation
    prompt_template = ChatPromptTemplate.from_messages([
        SystemMessage(
            content=(
                "Based on the provided text, the user is asking a question about databases. "
                "Determine if the user seeks information about the database schema or if they want to write a SQL query. "
                "Respond with 'yes' if the user is seeking information about the database schema and 'no' if they intend to write a SQL query."
            )
        ),
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    
    # Generate a response using a language model
    response = chat_llm(prompt_template.format_messages(text=query))
    print(response.content)
    return response.content

The model responds with either ‘yes’ or ‘no’. If the answer is ‘yes’, the user wants general schema information, so we generate a prompt that does not need any table details.

def generate_sql_query_prompt(query, db_uri):
    # Configure the chat template with system and human messages
    prompt_template = ChatPromptTemplate.from_messages([
        SystemMessage(
            content=(
                "As an assistant tasked with writing SQL queries, create a SQL query based on the text below. "
                "Enclose the SQL query within three backticks '```' for clarity. "
                "Aim to use 'SELECT' queries as much as possible. "
                f"The connection string for the database is {db_uri}."
            )
        ),
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    
    # Generate and print the answer using a language model
    response = chat_llm(prompt_template.format_messages(text=query))
    print(response.content)
    return response.content

If the answer is ‘no’, it indicates that the user’s query specifically requires the names of tables and columns within those tables. We will then identify the most relevant tables & columns, and construct a string from these to include in our prompt.
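
A chat model may reply ‘Yes.’, ‘YES’, or ‘yes’ with surrounding whitespace, so it helps to normalize the reply before branching. The sketch below shows one way to do that. The names `is_schema_question` and `route_query` are hypothetical, and the three callables are placeholders for the LLM-backed functions in this article.

```python
def is_schema_question(model_reply):
    """Interpret the classifier's reply; anything starting with 'yes' counts as yes."""
    return model_reply.strip().strip("'\"").lower().startswith("yes")

def route_query(query, classify, answer_schema_question, answer_with_table_context):
    """Send the query down the schema branch or the table-aware SQL branch.

    classify, answer_schema_question and answer_with_table_context are
    placeholder callables standing in for the LLM-backed functions.
    """
    if is_schema_question(classify(query)):
        return answer_schema_question(query)
    return answer_with_table_context(query)
```

Passing the branches in as callables also makes the routing logic easy to test without calling the model.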

Next, we’ll verify that our vectors have been successfully created and that all other components are functioning correctly. Below is the complete code up to this point.

import streamlit as st
import requests
import os
import pandas as pd
from uuid import uuid4
import psycopg2

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessage, HumanMessagePromptTemplate

from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.document_loaders.csv_loader import CSVLoader

# Create necessary folders for data storage
folders_to_create = ['csvs', 'vectors']
for folder_name in folders_to_create:
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)
        print(f"Folder '{folder_name}' created.")
    else:
        print(f"Folder '{folder_name}' already exists.")

# Load API key from environment variable
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

# Initialize language models and embeddings
llm = OpenAI(openai_api_key=openai_api_key)
chat_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.4)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Function to retrieve basic table details from the database
def get_basic_table_details(cursor):
    cursor.execute("""
        SELECT c.table_name, c.column_name, c.data_type
        FROM information_schema.columns c
        WHERE c.table_name IN (
            SELECT tablename
            FROM pg_tables
            WHERE schemaname = 'public'
        );""")
    return cursor.fetchall()

# Function to create vector databases from CSV files
def create_vectors(filename, persist_directory):
    loader = CSVLoader(file_path=filename, encoding="utf8")
    data = loader.load()
    vectordb = Chroma.from_documents(data, embedding=embeddings, persist_directory=persist_directory)
    vectordb.persist()

# Function to save database details and generate vectors
def save_db_details(db_uri):
    unique_id = str(uuid4()).replace("-", "_")
    connection = psycopg2.connect(db_uri)
    cursor = connection.cursor()
    tables_and_columns = get_basic_table_details(cursor)
    df = pd.DataFrame(tables_and_columns, columns=['table_name', 'column_name', 'data_type'])
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df.to_csv(filename_t, index=False)
    create_vectors(filename_t, "./vectors/tables_"+ unique_id)
    cursor.close()
    connection.close()
    return unique_id

# Function to generate SQL query templates
def generate_template_for_sql(query, table_info, db_uri):
    template = ChatPromptTemplate.from_messages([
        SystemMessage(
            content=(
                f"You are an assistant that can write SQL Queries. Given the text below, write a SQL query that answers the user's question. DB connection string is {db_uri}. Here is a detailed description of the table(s): {table_info} Prepend and append the SQL query with three backticks '```'"
            )
        ),
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    return chat_llm(template.format_messages(text=query)).content

# Function to determine if user's query is about general schema information or SQL
def check_if_users_query_want_general_schema_information_or_sql(query):
    template = ChatPromptTemplate.from_messages([
        SystemMessage(
            content=(
                f"In the given text, the user is asking a question about the database. Determine whether the user wants information about the database schema or wants to write a SQL query. Answer 'yes' for schema information and 'no' for SQL query."
            )
        ),
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    return answer.content

# Function to prompt when user wants general database information
def prompt_when_user_want_general_db_information(query, db_uri):
    template = ChatPromptTemplate.from_messages([
        SystemMessage(
            content=(
                f"You are an assistant who writes SQL queries. Given the text below, write a SQL query that answers the user's question. Prepend and append the SQL query with three backticks '```'. Write a SELECT query whenever possible. The connection string to this database is {db_uri}"
            )
        ),
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    answer = chat_llm(template.format_messages(text=query))
    print(answer.content)
    return answer.content

# Function to process user queries and generate outputs based on whether it's about general schema or specific SQL query
def get_the_output_from_llm(query, unique_id, db_uri):
    filename_t = 'csvs/tables_' + unique_id + '.csv'
    df = pd.read_csv(filename_t)
    table_info = ''
    for table in df['table_name'].unique():
        table_info += 'Information about table ' + table + ':\n'
        table_info += df[df['table_name'] == table].to_string(index=False) + '\n\n\n'
    answer_to_question_general_schema = check_if_users_query_want_general_schema_information_or_sql(query)
    if answer_to_question_general_schema.strip().lower().startswith("yes"):
        return prompt_when_user_want_general_db_information(query, db_uri)
    return generate_template_for_sql(query, table_info, db_uri)

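# Function to execute SQL solutions
def execute_the_solution(solution, db_uri):
    _, final_query, _ = solution.split("```")
    # Remove an optional leading "sql" language tag; str.strip('sql') would also
    # eat trailing s, q, or l characters from the query itself
    final_query = final_query.strip()
    if final_query.lower().startswith("sql"):
        final_query = final_query[3:].strip()
    connection = psycopg2.connect(db_uri)
    cursor = connection.cursor()
    try:
        cursor.execute(final_query)
        result = cursor.fetchall()
    finally:
        cursor.close()
        connection.close()
    return str(result)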

# Functions to connect to the database and handle a chat message (carried over from the first listing)
def connect_with_db(uri):
    st.session_state.db_uri = uri
    st.session_state.unique_id = save_db_details(uri)
    return {"message": "Database connection established!"}

def send_message(message):
    solution = get_the_output_from_llm(message, st.session_state.unique_id, st.session_state.db_uri)
    result = execute_the_solution(solution, st.session_state.db_uri)
    return {"message": solution + "\n\n" + "Result:\n" + result}

# Streamlit app setup and interaction handling
if __name__ == "__main__":
    st.subheader("Instructions")
    st.markdown("""
        1. Enter the URI of your RDS Database in the text box below.
        2. Click the **Start Chat** button to start the chat.
        3. Enter your message in the text box below and press **Enter** to send the message to the API.
    """)
    chat_history = []
    uri = st.text_input("Enter the RDS Database URI")
    if st.button("Start Chat"):
        if not uri:
            st.warning("Please enter a valid database URI.")
        else:
            st.info("Connecting to the API and starting the chat...")
            chat_response = connect_with_db(uri)
            if "error" in chat_response:
                st.error("Error: Failed to start the chat. Please check the URI and try again.")
            else:
                st.success("Chat started successfully!")
    st.subheader("Chat with the API")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if prompt := st.chat_input("What is up?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
        response = send_message(prompt)["message"]
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
    st.write("This is a simple Streamlit app for starting a chat with an RDS Database.")
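
Splitting the reply on triple backticks, as `execute_the_solution` does, assumes the model returns exactly one fenced block. A slightly more defensive sketch is shown here. `extract_sql` and `is_read_only` are hypothetical helpers, not part of the listing above, and the read-only check is a rough guard rather than a security boundary.

```python
import re

FENCE = "`" * 3  # three backticks, built this way to keep the example readable

def extract_sql(reply):
    """Return the first fenced SQL statement in an LLM reply, without fence or language tag."""
    pattern = re.escape(FENCE) + r"(?:sql)?\s*(.*?)" + re.escape(FENCE)
    match = re.search(pattern, reply, flags=re.DOTALL | re.IGNORECASE)
    if match is None:
        raise ValueError("No fenced SQL block found in the model reply.")
    return match.group(1).strip()

def is_read_only(sql):
    """Rough check that the statement starts with SELECT or WITH."""
    return sql.lstrip().lower().startswith(("select", "with"))
```

Raising an explicit error when no fenced block is found gives the chat interface something clearer to show than an unpacking failure.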

In the next step, we will perform vector retrieval to identify the most relevant tables. Once these tables are selected, we will gather all of their column details to provide additional context, and compile this information into a string to include in the prompt. The snippet assumes `query` and `unique_id` are in scope, as they are inside `get_the_output_from_llm`.

# Initialize the vector database for storing table embeddings
vectordb = Chroma(embedding_function=embeddings, persist_directory=f"./vectors/tables_{unique_id}")
retriever = vectordb.as_retriever()
docs = retriever.get_relevant_documents(query)
print(docs)

# Collecting the relevant tables and their columns
relevant_tables = []
relevant_table_details = []

for doc in docs:
    # CSVLoader stores each CSV row as "column: value" lines
    fields = doc.page_content.split("\n")
    table_name = fields[0].split(":")[1].strip()
    column_name = fields[1].split(":")[1].strip()
    data_type = fields[2].split(":")[1].strip()
    if table_name not in relevant_tables:  # avoid describing the same table twice
        relevant_tables.append(table_name)
    relevant_table_details.append((table_name, column_name, data_type))

# Load data about all tables from a CSV file
filename_t = f'csvs/tables_{unique_id}.csv'
df = pd.read_csv(filename_t)

# Construct a descriptive string for each relevant table, including all columns and their data types
table_info = ''
for table in relevant_tables:
    table_info += f'Information about table {table}:\n'
    table_info += df[df['table_name'] == table].to_string(index=False) + '\n\n\n'

def create_sql_query_template(query, relevant_tables, table_info):
    tables_list = ",".join(relevant_tables)
    chat_template = ChatPromptTemplate.from_messages([
        SystemMessage(
            content=(
                f"As an assistant capable of composing SQL queries, please write a query that resolves the user's inquiry based on the text provided. "
                f"Consider SQL tables named '{tables_list}'. "
                f"Below is a detailed description of these table(s): "
                f"{table_info}"
                "Enclose the SQL query within three backticks '```' for proper formatting."
            )
        ),
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    
    response = chat_llm(chat_template.format_messages(text=query))
    print(response.content)
    return response.content
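
The retrieval-to-prompt steps above can be gathered into one helper. The sketch below is an illustration under stated assumptions: `parse_table_name` and `build_table_info` are hypothetical names, the input strings mimic the `column: value` lines that `CSVLoader` puts in `page_content`, and schema rows are plain tuples instead of a DataFrame so the example runs without a database.

```python
def parse_table_name(page_content):
    """Read the table name from a 'table_name: value' first line."""
    first_line = page_content.split("\n")[0]
    return first_line.split(":", 1)[1].strip()

def build_table_info(page_contents, schema_rows):
    """Describe every column of each retrieved table, listing each table once.

    page_contents: retrieved document texts (invented sample format).
    schema_rows: (table_name, column_name, data_type) tuples for the whole database.
    """
    relevant_tables = []
    for content in page_contents:
        table = parse_table_name(content)
        if table not in relevant_tables:  # keep first-seen order, skip duplicates
            relevant_tables.append(table)

    sections = []
    for table in relevant_tables:
        lines = [f"Information about table {table}:"]
        lines += [f"  {col} ({dtype})" for t, col, dtype in schema_rows if t == table]
        sections.append("\n".join(lines))
    return relevant_tables, "\n\n".join(sections)
```

Listing each table once matters because the retriever returns one document per column, so several hits often point to the same table.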

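One final improvement is to collect information about foreign keys so that it can be added to the prompt. The code below fetches the foreign keys and saves them alongside the table details.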

import streamlit as st
import os
import pandas as pd
from uuid import uuid4
import psycopg2

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import SystemMessage, HumanMessagePromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.document_loaders.csv_loader import CSVLoader

# Ensure necessary directories exist
folders_to_create = ['csvs', 'vectors']
for folder in folders_to_create:
    os.makedirs(folder, exist_ok=True)
    print(f"Directory '{folder}' checked or created.")

# Load environment and API keys
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

# Initialize language models and embeddings
language_model = OpenAI(openai_api_key=openai_api_key)
chat_language_model = ChatOpenAI(openai_api_key=openai_api_key, temperature=0.4)
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

def fetch_table_details(cursor):
    sql = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public';
    """
    cursor.execute(sql)
    return cursor.fetchall()

def fetch_foreign_key_details(cursor):
    sql = """
        SELECT conrelid::regclass AS table_name, conname AS foreign_key,
               pg_get_constraintdef(oid) AS constraint_definition
        FROM pg_constraint
        WHERE contype = 'f' AND connamespace = 'public'::regnamespace;
    """
    cursor.execute(sql)
    return cursor.fetchall()

def create_vector_database(csv_path, directory):
    # CSVLoader reads from a file path, and Chroma.from_documents is a class method
    loader = CSVLoader(file_path=csv_path, encoding="utf8")
    document_data = loader.load()
    vector_db = Chroma.from_documents(document_data, embedding=embeddings, persist_directory=directory)
    vector_db.persist()

def save_database_details(uri):
    unique_id = str(uuid4()).replace("-", "_")
    conn = psycopg2.connect(uri)
    cur = conn.cursor()
    details = fetch_table_details(cur)
    df = pd.DataFrame(details, columns=['table_name', 'column_name', 'data_type'])
    csv_path = f'csvs/tables_{unique_id}.csv'
    df.to_csv(csv_path, index=False)
    # Build the vector store from the CSV file that was just written
    create_vector_database(csv_path, f"./vectors/tables_{unique_id}")
    
    foreign_keys = fetch_foreign_key_details(cur)
    fk_df = pd.DataFrame(foreign_keys, columns=['table_name', 'foreign_key', 'constraint_definition'])
    fk_csv_path = f'csvs/foreign_keys_{unique_id}.csv'
    fk_df.to_csv(fk_csv_path, index=False)
    
    cur.close()
    conn.close()
    return unique_id

def generate_sql_query_template(query, db_uri):
    template = ChatPromptTemplate.from_messages([
        SystemMessage(
            content=(
                f"You are an assistant capable of composing SQL queries. Use the details provided to write a relevant SQL query for the question below. DB connection string is {db_uri}. "
                "Enclose the SQL query with three backticks '```'."
            )
        ),
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    response = chat_language_model(template.format_messages(text=query))
    return response.content

# Streamlit application setup
st.title("Database Interaction Tool")
uri = st.text_input("Enter the RDS Database URI")
if st.button("Connect to Database"):
    if uri:
        try:
            unique_id = save_database_details(uri)
            st.success(f"Connected to database and data saved with ID: {unique_id}")
        except Exception as e:
            st.error(f"Failed to connect: {str(e)}")
    else:
        st.warning("Please enter a valid database URI.")
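
The code above saves the foreign keys to a CSV file but does not yet pass them to the model. One possible way to do that is sketched below. `format_foreign_keys` is a hypothetical helper; the column names match the `foreign_keys_{unique_id}.csv` file written by `save_database_details`, and the function takes the CSV text directly so the example runs without a database.

```python
import csv
import io

def format_foreign_keys(csv_text, relevant_tables):
    """Turn foreign-key CSV rows for the relevant tables into prompt-ready text."""
    reader = csv.DictReader(io.StringIO(csv_text))
    lines = [
        f"{row['table_name']}: {row['constraint_definition']}"
        for row in reader
        if row["table_name"] in relevant_tables
    ]
    if not lines:
        return ""
    return "Foreign keys between these tables:\n" + "\n".join(lines)
```

The returned string can be appended to the system message next to the table description, so the model knows which columns to join on.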

In a similar way, we can keep improving this application by adding fallbacks, with each fallback adding more information to the prompt.
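
One way to read "fallbacks" is as a retry loop that widens the context after each failure. The following is a minimal sketch under that assumption. `run_with_fallbacks`, `generate_sql`, and `run_sql` are hypothetical names, and the context levels are placeholders for whatever extra information (more tables, foreign keys, sample values) the application can supply.

```python
def run_with_fallbacks(question, context_levels, generate_sql, run_sql):
    """Try each context in turn, from smallest to richest, until a query executes.

    generate_sql(question, context) -> SQL string (placeholder for the LLM call)
    run_sql(sql) -> result, raising an exception on failure (placeholder for the DB call)
    """
    errors = []
    for context in context_levels:
        sql = generate_sql(question, context)
        try:
            return sql, run_sql(sql)
        except Exception as exc:  # in real code, catch the database driver's error type
            errors.append(f"{sql!r} failed: {exc}")
    raise RuntimeError("All fallbacks failed:\n" + "\n".join(errors))
```

Starting with the smallest context keeps most requests cheap, and the richer prompts are only used when a query actually fails.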

Conclusion

This article presented an approach to building an AI application that can chat with massive SQL databases. The challenge with large databases is that it is impractical to include the complete list of tables and columns in the prompt. Our solution retrieves the relevant table and column names dynamically, based on the user’s query, by storing their embeddings in a vector database such as ChromaDB. As a result, the prompt includes only pertinent information.

We walked through the implementation step by step with code examples. With further enhancements, such as incorporating foreign keys into the prompt and adding more fallbacks, the application can be adapted to a wide range of database interaction scenarios.
