使用模型实现从数据库表格中查询各种信息
实践1:查询订单表中的各种信息
编写 Function Calling 函数(改编自 OpenAI 官方示例)
def get_sql_completion(client, database_schema_string, messages, model="qwen-turbo"):
    """Send *messages* to the chat model and offer it the ``ask_database`` tool.

    The tool's parameter description embeds *database_schema_string*, so the
    model can write a SQLite query against that schema.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        database_schema_string: DDL text of the table(s) the model may query.
        messages: Chat history in OpenAI message format.
        model: Model name passed through to the API.

    Returns:
        ``response.choices[0].message``: the assistant message, which may
        carry ``tool_calls`` instead of (or alongside) text content.
    """
    # Keep the schema and each instruction on separate lines of the prompt.
    query_description = (
        "SQL query extracting info to answer the user's question.\n"
        "SQL should be written using this database schema:\n"
        f"{database_schema_string}\n"
        "The query should be returned in plain text, not in JSON.\n"
        "The query should only contain grammars supported by SQLite."
    )
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.5,
        tools=[{
            "type": "function",
            "function": {
                "name": "ask_database",
                # Implicit string concatenation; a backslash inside the literal
                # would be an invalid escape and end up in the prompt text.
                "description": (
                    "Use this function to answer user questions about business. "
                    "Output should be a fully formed SQL query."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": query_description,
                        },
                    },
                    "required": ["query"],
                },
            },
        }],
    )
    return response.choices[0].message
初始化数据库
def ask_database(cursor, query):
    """Execute *query* on *cursor* and return all result rows as a list.

    NOTE(review): *query* is model-generated SQL and is executed verbatim;
    only point this at a disposable or read-only database.
    """
    cursor.execute(query)
    return cursor.fetchall()


def DBInit(cursor, database_schema_string, DataList):
    """Create the ``orders`` table from *database_schema_string* and load *DataList*.

    Each item of *DataList* is a 7-tuple
    ``(id, customer_id, product_id, price, status, create_time, pay_time)``.
    The insert is committed on the cursor's own connection, so the caller
    does not need to expose a global ``conn``.
    """
    cursor.execute(database_schema_string)
    cursor.executemany(
        "INSERT INTO orders (id, customer_id, product_id, price, status, create_time, pay_time) "
        "VALUES (?, ?, ?, ?, ?, ?, ?)",
        DataList,
    )
    # Commit through the cursor's connection instead of an undeclared global.
    cursor.connection.commit()


def select_all(cursor):
    """Print every row of the ``orders`` table (debug helper)."""
    cursor.execute("SELECT * FROM orders")
    for row in cursor.fetchall():
        print(row)
初始化数据
import sqlite3

# DDL for the demo table: executed by DBInit and also shown to the model in
# the tool description. Each SQL "--" comment runs to the end of its line, so
# every column definition must stay on its own line.
database_schema_string = """CREATE TABLE orders (
    id INT PRIMARY KEY NOT NULL, -- 主键,不允许为空
    customer_id INT NOT NULL, -- 客户ID,不允许为空
    product_id STR NOT NULL, -- 产品ID,不允许为空
    price DECIMAL(10,2) NOT NULL, -- 价格,不允许为空
    status INT NOT NULL, -- 订单状态,整数类型,不允许为空。0代表待支付,1代表已支付,2代表已退款
    create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 创建时间,默认为当前时间
    pay_time TIMESTAMP -- 支付时间,可以为空
);"""

# Mock orders: (id, customer_id, product_id, price, status, create_time, pay_time).
# Unpaid orders (status 0) have no pay_time; the others are paid one hour
# (order 3: 30 minutes) after creation, never before it.
DataList = [
    (1, 1001, 'TSHIRT_1', 50.00, 0, '2023-10-12 10:00:00', None),
    (2, 1001, 'TSHIRT_2', 75.50, 1, '2023-10-16 11:00:00', '2023-10-16 12:00:00'),
    (3, 1002, 'SHOES_X2', 25.25, 2, '2023-10-17 12:30:00', '2023-10-17 13:00:00'),
    (4, 1003, 'HAT_Z112', 60.75, 1, '2023-10-20 14:00:00', '2023-10-20 15:00:00'),
    (5, 1002, 'WATCH_X001', 90.00, 0, '2023-10-28 16:00:00', None),
]

conn = sqlite3.connect(':memory:')
cursor = conn.cursor()
DBInit(cursor, database_schema_string, DataList)
select_all(cursor)
在主程序中调用上述函数,完成查询任务
import json
import sqlite3

# Alternative questions to try:
# prompt = "上个月的销售额"
# prompt = "统计每月每件商品的销售额"
prompt = "哪个用户消费最高?消费多少?"
messages = [
    {"role": "system", "content": "基于 order 表回答用户问题"},
    {"role": "user", "content": prompt},
]

# Round 1: the model either answers directly or requests ask_database calls.
response = get_sql_completion(client, database_schema_string, messages)
if response.content is None:
    # Normalise a null content to "" before echoing the message into the history.
    response.content = ""
messages.append(response)

if response.tool_calls:
    # Every tool_call_id must be answered by a "tool" message before the next request.
    for tool_call in response.tool_calls:
        if tool_call.function.name == "ask_database":
            args = json.loads(tool_call.function.arguments)
            print(args["query"])
            try:
                result = ask_database(cursor, args["query"])
            except sqlite3.Error as exc:
                # Report invalid model-generated SQL back to the model instead of crashing.
                result = f"SQL error: {exc}"
        else:
            result = f"Unknown tool: {tool_call.function.name}"
        print(result)
        messages.append({
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": tool_call.function.name,
            "content": str(result),
        })
    # Round 2: let the model phrase the answer from the query results.
    response = get_sql_completion(client, database_schema_string, messages)

print("====最终回复====")
print(response.content)
整体代码如下
import json
import os
import sqlite3

import requests  # NOTE(review): unused in this script
from cv2.version import ci_build  # NOTE(review): unused in this script
from openai import OpenAI
from openpyxl.chart.label import DataLabelList  # NOTE(review): unused in this script
from sympy.physics.units import temperature  # NOTE(review): unused in this script


def get_sql_completion(client, database_schema_string, messages, model="qwen-turbo"):
    """Send *messages* to the chat model and offer it the ``ask_database`` tool.

    The tool's parameter description embeds *database_schema_string*, so the
    model can write a SQLite query against that schema.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        database_schema_string: DDL text of the table(s) the model may query.
        messages: Chat history in OpenAI message format.
        model: Model name passed through to the API.

    Returns:
        ``response.choices[0].message``: the assistant message, which may
        carry ``tool_calls`` instead of text content.
    """
    # Keep the schema and each instruction on separate lines of the prompt.
    query_description = (
        "SQL query extracting info to answer the user's question.\n"
        "SQL should be written using this database schema:\n"
        f"{database_schema_string}\n"
        "The query should be returned in plain text, not in JSON.\n"
        "The query should only contain grammars supported by SQLite."
    )
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,  # randomness of the model output; 0 = least random
        tools=[{
            "type": "function",
            "function": {
                "name": "ask_database",
                # Implicit concatenation; a backslash inside the literal would be
                # an invalid escape and end up in the prompt text.
                "description": (
                    "Use this function to answer user questions about business. "
                    "Output should be a fully formed SQL query."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": query_description,
                        },
                    },
                    "required": ["query"],
                },
            },
        }],
    )
    return response.choices[0].message


def ask_database(cursor, query):
    """Execute *query* on *cursor* and return all result rows as a list.

    NOTE(review): *query* is model-generated SQL and is executed verbatim;
    only point this at a disposable or read-only database.
    """
    cursor.execute(query)
    return cursor.fetchall()


def DBInit(cursor, database_schema_string, DataList):
    """Create the ``orders`` table and load *DataList* into it.

    Each item of *DataList* is a 7-tuple
    ``(id, customer_id, product_id, price, status, create_time, pay_time)``.
    """
    cursor.execute(database_schema_string)
    cursor.executemany(
        "INSERT INTO orders (id, customer_id, product_id, price, status, create_time, pay_time) "
        "VALUES (?, ?, ?, ?, ?, ?, ?)",
        DataList,
    )
    # Commit through the cursor's connection instead of relying on a global `conn`.
    cursor.connection.commit()


def select_all(cursor):
    """Print every row of the ``orders`` table (debug helper)."""
    cursor.execute("SELECT * FROM orders")
    for row in cursor.fetchall():
        print(row)


if __name__ == '__main__':
    client = OpenAI(
        api_key=os.getenv('DASHSCOPE_API_KEY'),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    # Each SQL "--" comment runs to the end of its line, so every column
    # definition must stay on its own line.
    database_schema_string = """CREATE TABLE orders (
        id INT PRIMARY KEY NOT NULL, -- 主键,不允许为空
        customer_id INT NOT NULL, -- 客户ID,不允许为空
        product_id STR NOT NULL, -- 产品ID,不允许为空
        price DECIMAL(10,2) NOT NULL, -- 价格,不允许为空
        status INT NOT NULL, -- 订单状态,整数类型,不允许为空。0代表待支付,1代表已支付,2代表已退款
        create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 创建时间,默认为当前时间
        pay_time TIMESTAMP -- 支付时间,可以为空
    );"""
    # (id, customer_id, product_id, price, status, create_time, pay_time);
    # pay_time never precedes create_time.
    DataList = [
        (1, 1001, 'TSHIRT_1', 50.00, 0, '2023-10-12 10:00:00', None),
        (2, 1001, 'TSHIRT_2', 75.50, 1, '2023-10-16 11:00:00', '2023-10-16 12:00:00'),
        (3, 1002, 'SHOES_X2', 25.25, 2, '2023-10-17 12:30:00', '2023-10-17 13:00:00'),
        (4, 1003, 'HAT_Z112', 60.75, 1, '2023-10-20 14:00:00', '2023-10-20 15:00:00'),
        (5, 1002, 'WATCH_X001', 90.00, 0, '2023-10-28 16:00:00', None),
    ]
    conn = sqlite3.connect(':memory:')
    cursor = conn.cursor()
    DBInit(cursor, database_schema_string, DataList)
    select_all(cursor)

    # prompt = "上个月的销售额"
    # prompt = "统计每月每件商品的销售额"
    prompt = "哪个用户消费最高?消费多少?"
    messages = [
        {"role": "system", "content": "基于 order 表回答用户问题"},
        {"role": "user", "content": prompt},
    ]

    # Round 1: the model either answers directly or requests ask_database calls.
    response = get_sql_completion(client, database_schema_string, messages)
    if response.content is None:
        response.content = ""
    messages.append(response)

    if response.tool_calls:
        # Every tool_call_id must be answered by a "tool" message.
        for tool_call in response.tool_calls:
            if tool_call.function.name == "ask_database":
                args = json.loads(tool_call.function.arguments)
                print(args["query"])
                try:
                    result = ask_database(cursor, args["query"])
                except sqlite3.Error as exc:
                    # Report invalid model-generated SQL back to the model.
                    result = f"SQL error: {exc}"
            else:
                result = f"Unknown tool: {tool_call.function.name}"
            print(result)
            messages.append({
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": str(result),
            })
        # Round 2: let the model phrase the answer from the query results.
        response = get_sql_completion(client, database_schema_string, messages)

    print("====最终回复====")
    print(response.content)
====最终回复====
消费最高的用户是用户ID为1001的用户,该用户的总消费金额为75.5元。
实践2:查询学生信息
随机生成学生信息,格式为[student_id, name, class, gender, scores, total, average, status, enrollment_date, graduation_date]
def generate_student_data(count, rng=None):
    """Generate *count* random student records for the ``students`` demo table.

    Each record is a tuple ``(student_id, name, class, gender, scores, total,
    average, status, enrollment_date, graduation_date)``.

    Args:
        count: Number of records; student IDs are "10000", "10001", ...
        rng: Optional ``random.Random`` for reproducible output; defaults to
            the module-level ``random`` functions.

    Returns:
        A list of tuples where ``scores`` is a JSON object string (readable by
        SQLite's JSON functions), ``average`` is rounded to 2 decimals to match
        DECIMAL(5,2), dates are 'YYYY-MM-DD' strings, and ``graduation_date``
        is None unless ``status == 2``.
    """
    import json
    import random
    from datetime import datetime, timedelta

    rng = random if rng is None else rng

    names = ["程健", "张伟", "李娜", "王强", "赵敏", "陈晨", "杨洋", "刘磊", "吴婷", "周杰"]
    base_student_id = 10000
    statuses = [1, 2, 0]  # 1 = enrolled, 2 = graduated, 0 = on leave
    # Enrollment dates are drawn uniformly between these two fixed dates.
    enrollment_start_date = datetime(2020, 9, 1)
    enrollment_end_date = datetime(2023, 9, 1)
    enrollment_span = enrollment_end_date - enrollment_start_date
    subjects = ["Chinese", "Math", "English", "Physics", "Chemistry", "Biology"]

    students_data = []
    for i in range(count):
        name = rng.choice(names)
        student_id = str(base_student_id + i)
        class_name = f"高三{rng.randint(1, 5)}班"
        scores = {subject: rng.randint(60, 100) for subject in subjects}
        total = sum(scores.values())
        average = round(total / len(scores), 2)
        gender = rng.choice(["M", "F"])
        status = rng.choice(statuses)

        enrollment_date = enrollment_start_date + enrollment_span * rng.random()
        enrollment_date_str = enrollment_date.strftime('%Y-%m-%d')

        # Only graduated students get a graduation date, 700-800 days
        # (roughly two years) after enrollment.
        graduation_date_str = None
        if status == 2:
            graduation_date = enrollment_date + timedelta(days=rng.randint(700, 800))
            graduation_date_str = graduation_date.strftime('%Y-%m-%d')

        # json.dumps (not str) so the column holds real JSON, as the schema declares.
        students_data.append((student_id, name, class_name, gender, json.dumps(scores),
                              total, average, status, enrollment_date_str, graduation_date_str))
    return students_data
main函数
具体步骤见代码中的注释
if __name__ == '__main__':
    client = OpenAI(
        api_key=os.getenv('DASHSCOPE_API_KEY'),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    # Each SQL "--" comment runs to the end of its line, so every column
    # definition must stay on its own line.
    database_schema_string = """CREATE TABLE students (
        student_id INT NOT NULL, -- 学号,不允许为空
        name VARCHAR(50) NOT NULL, -- 姓名,不允许为空
        gender CHAR(1) NOT NULL, -- 性别,单个字符,'M'表示男,'F'表示女
        class VARCHAR(50) NOT NULL, -- 班级,不允许为空
        scores JSON NOT NULL, -- 成绩,存储为JSON类型
        total INT NOT NULL, -- 总分
        average DECIMAL(5,2) NOT NULL, -- 平均分
        status INT NOT NULL DEFAULT 1, -- 学生状态,0表示休学,1表示在读,2表示已毕业
        enrollment_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 入学时间,默认为当前时间
        graduation_date TIMESTAMP DEFAULT NULL -- 毕业时间,可以为空
    );
    """

    # Create the in-memory database and generate 5 random students.
    conn = sqlite3.connect(':memory:')
    cursor = conn.cursor()
    a = generate_student_data(5)
    # Load the data, then print it back to check that every row was inserted.
    DBInit(conn, cursor, database_schema_string, a)
    select_all(cursor)

    prompt = "找出英语最高分的班级和姓名"
    # prompt = "目前在读学生有多少人"
    # prompt = "计算一下高三4班的平均分是多少"
    messages = [
        {"role": "system", "content": "基于 student 表回答用户问题"},
        {"role": "user", "content": prompt},
    ]
    response = get_sql_completion(client, database_schema_string, messages)
    if response.content is None:
        response.content = ""
    messages.append(response)
    print("====Function Calling====")
    print(response)

    if response.tool_calls:
        # Every tool_call_id must be answered by a "tool" message.
        for tool_call in response.tool_calls:
            if tool_call.function.name == "ask_database":
                arguments = tool_call.function.arguments
                # Workaround for a model-output bug: the key sometimes comes back
                # as ["Query"], which json.loads cannot parse.
                if '["Query"]' in arguments:
                    arguments = arguments.replace('["Query"]', "Query")
                print(arguments)
                args = json.loads(arguments)
                print("====SQL====")
                print(args["Query"])
                try:
                    result = ask_database(cursor, args["Query"])
                except sqlite3.Error as exc:
                    # Report invalid model-generated SQL back to the model.
                    result = f"SQL error: {exc}"
                print("====DB Records====")
                print(result)
            else:
                result = f"Unknown tool: {tool_call.function.name}"
            messages.append({
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": str(result),
            })
        # Reply based on the function results.
        response = get_sql_completion(client, database_schema_string, messages)

    print("====最终回复====")
    print(response.content)
完整代码
import json
import sqlite3
from openai import OpenAI
import os
import requests
import random
from datetime import datetime, timedelta


def generate_student_data(count, rng=None):
    """Generate *count* random student records for the ``students`` table.

    Each record is a tuple ``(student_id, name, class, gender, scores, total,
    average, status, enrollment_date, graduation_date)``, in the column order
    used by DBInit's INSERT.

    Args:
        count: Number of records; student IDs are "10000", "10001", ...
        rng: Optional ``random.Random`` for reproducible output; defaults to
            the module-level ``random`` functions.

    Returns:
        A list of tuples where ``scores`` is a JSON object string (readable by
        SQLite's JSON functions), ``average`` is rounded to 2 decimals to match
        DECIMAL(5,2), dates are 'YYYY-MM-DD' strings, and ``graduation_date``
        is None unless ``status == 2``.
    """
    rng = random if rng is None else rng

    names = ["程健", "张伟", "李娜", "王强", "赵敏", "陈晨", "杨洋", "刘磊", "吴婷", "周杰"]
    base_student_id = 10000
    statuses = [1, 2, 0]  # 1 = enrolled, 2 = graduated, 0 = on leave
    # Enrollment dates are drawn uniformly between these two fixed dates.
    enrollment_start_date = datetime(2020, 9, 1)
    enrollment_end_date = datetime(2023, 9, 1)
    enrollment_span = enrollment_end_date - enrollment_start_date
    subjects = ["Chinese", "Math", "English", "Physics", "Chemistry", "Biology"]

    students_data = []
    for i in range(count):
        name = rng.choice(names)
        student_id = str(base_student_id + i)
        class_name = f"高三{rng.randint(1, 5)}班"
        scores = {subject: rng.randint(60, 100) for subject in subjects}
        total = sum(scores.values())
        average = round(total / len(scores), 2)
        gender = rng.choice(["M", "F"])
        status = rng.choice(statuses)

        enrollment_date = enrollment_start_date + enrollment_span * rng.random()
        enrollment_date_str = enrollment_date.strftime('%Y-%m-%d')

        # Only graduated students get a graduation date, 700-800 days
        # (roughly two years) after enrollment.
        graduation_date_str = None
        if status == 2:
            graduation_date = enrollment_date + timedelta(days=rng.randint(700, 800))
            graduation_date_str = graduation_date.strftime('%Y-%m-%d')

        # json.dumps (not str) so the column holds real JSON, as the schema declares.
        students_data.append((student_id, name, class_name, gender, json.dumps(scores),
                              total, average, status, enrollment_date_str, graduation_date_str))
    return students_data


def get_sql_completion(client, database_schema_string, messages, model="qwen-turbo"):
    """Send *messages* to the chat model and offer it the ``ask_database`` tool.

    The tool takes a single required string argument ``Query``; its
    description embeds *database_schema_string* so the model can write SQLite
    SQL against that schema.

    Returns:
        ``response.choices[0].message``: the assistant message, which may
        carry ``tool_calls`` instead of text content.
    """
    # Keep the schema and each instruction on separate lines of the prompt.
    query_description = (
        "SQL query extracting info to answer the user's question.\n"
        "SQL should be written using this database schema:\n"
        f"{database_schema_string}\n"
        "The query should be returned in plain text, not in JSON.\n"
        "The query should only contain grammars supported by SQLite."
    )
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,  # randomness of the model output; 0 = least random
        tools=[{
            "type": "function",
            "function": {
                "name": "ask_database",
                "description": (
                    "Use this function to answer user questions about business. "
                    "Output should be a fully formed SQL query."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "Query": {
                            "type": "string",
                            "description": query_description,
                        },
                    },
                    # JSON Schema requires `required` to be an array of names.
                    "required": ["Query"],
                },
            },
        }],
    )
    return response.choices[0].message


def ask_database(cursor, query):
    """Execute *query* on *cursor* and return all result rows as a list.

    NOTE(review): *query* is model-generated SQL and is executed verbatim;
    only point this at a disposable or read-only database.
    """
    cursor.execute(query)
    return cursor.fetchall()


def DBInit(conn, cursor, database_schema_string, DataList):
    """Create the ``students`` table, insert *DataList* and commit on *conn*.

    Each item of *DataList* is a 10-tuple in the order
    ``(student_id, name, class, gender, scores, total, average, status,
    enrollment_date, graduation_date)``.
    """
    cursor.execute(database_schema_string)
    cursor.executemany(
        "INSERT INTO students (student_id, name, class, gender, scores, total, average, "
        "status, enrollment_date, graduation_date) VALUES (?,?,?,?,?,?,?,?,?,?)",
        DataList,
    )
    conn.commit()


def select_all(cursor):
    """Print every row of the ``students`` table (debug helper)."""
    cursor.execute("SELECT * FROM students")
    for row in cursor.fetchall():
        print(row)


if __name__ == '__main__':
    client = OpenAI(
        api_key=os.getenv('DASHSCOPE_API_KEY'),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    # Each SQL "--" comment runs to the end of its line, so every column
    # definition must stay on its own line.
    database_schema_string = """CREATE TABLE students (
        student_id INT NOT NULL, -- 学号,不允许为空
        name VARCHAR(50) NOT NULL, -- 姓名,不允许为空
        gender CHAR(1) NOT NULL, -- 性别,单个字符,'M'表示男,'F'表示女
        class VARCHAR(50) NOT NULL, -- 班级,不允许为空
        scores JSON NOT NULL, -- 成绩,存储为JSON类型
        total INT NOT NULL, -- 总分
        average DECIMAL(5,2) NOT NULL, -- 平均分
        status INT NOT NULL DEFAULT 1, -- 学生状态,0表示休学,1表示在读,2表示已毕业
        enrollment_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 入学时间,默认为当前时间
        graduation_date TIMESTAMP DEFAULT NULL -- 毕业时间,可以为空
    );
    """
    conn = sqlite3.connect(':memory:')
    cursor = conn.cursor()
    a = generate_student_data(5)
    DBInit(conn, cursor, database_schema_string, a)
    select_all(cursor)

    prompt = "找出英语最高分的班级和姓名"
    messages = [
        {"role": "system", "content": "基于 student 表回答用户问题"},
        {"role": "user", "content": prompt},
    ]
    response = get_sql_completion(client, database_schema_string, messages)
    if response.content is None:
        response.content = ""
    messages.append(response)
    print("====Function Calling====")
    print(response)

    if response.tool_calls:
        # Every tool_call_id must be answered by a "tool" message.
        for tool_call in response.tool_calls:
            if tool_call.function.name == "ask_database":
                arguments = tool_call.function.arguments
                # Workaround for a model-output bug: the key sometimes comes back
                # as ["Query"], which json.loads cannot parse.
                if '["Query"]' in arguments:
                    arguments = arguments.replace('["Query"]', "Query")
                print(arguments)
                args = json.loads(arguments)
                print("====SQL====")
                print(args["Query"])
                try:
                    result = ask_database(cursor, args["Query"])
                except sqlite3.Error as exc:
                    # Report invalid model-generated SQL back to the model.
                    result = f"SQL error: {exc}"
                print("====DB Records====")
                print(result)
            else:
                result = f"Unknown tool: {tool_call.function.name}"
            messages.append({
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": str(result),
            })
        response = get_sql_completion(client, database_schema_string, messages)

    print("====最终回复====")
    print(response.content)
输出信息为:
(10000, '张伟', 'F', '高三5班', "{'Chinese': 89, 'Math': 99, 'English': 75, 'Physics': 95, 'Chemistry': 98, 'Biology': 76}", 532, 88.66666666666667, 1, '2022-03-13', None)
(10001, '刘磊', 'M', '高三4班', "{'Chinese': 100, 'Math': 98, 'English': 88, 'Physics': 86, 'Chemistry': 86, 'Biology': 86}", 544, 90.66666666666667, 1, '2021-05-11', None)
(10002, '赵敏', 'M', '高三3班', "{'Chinese': 71, 'Math': 62, 'English': 93, 'Physics': 68, 'Chemistry': 94, 'Biology': 80}", 468, 78, 0, '2020-09-15', None)
(10003, '张伟', 'F', '高三5班', "{'Chinese': 83, 'Math': 86, 'English': 94, 'Physics': 84, 'Chemistry': 73, 'Biology': 70}", 490, 81.66666666666667, 2, '2021-03-27', '2023-03-20')
(10004, '周杰', 'M', '高三3班', "{'Chinese': 85, 'Math': 94, 'English': 89, 'Physics': 91, 'Chemistry': 76, 'Biology': 99}", 534, 89, 1, '2022-10-17', None)
====Function Calling====
ChatCompletionMessage(content='', refusal=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_0f92904f36084c65b90b09', function=Function(arguments='{"Query": "SELECT s.class, s.name FROM students s WHERE s.scores->>\'English\' = (SELECT MAX(scores->>\'English\') FROM students) LIMIT 1;"}', name='ask_database'), type='function', index=0)])
{"Query": "SELECT s.class, s.name FROM students s WHERE s.scores->>'English' = (SELECT MAX(scores->>'English') FROM students) LIMIT 1;"}
====SQL====
SELECT s.class, s.name FROM students s WHERE s.scores->>'English' = (SELECT MAX(scores->>'English') FROM students) LIMIT 1;
====DB Records====
[('高三5班', '张伟')]
====最终回复====
英语最高分的学生所在的班级是高三5班,学生姓名是张伟。