相关代码
import hashlib
import json
import logging
import os
import pickle
import re

import chromadb
import pandas as pd
import redis
import requests
from vanna.base.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore
from vanna.ollama import Ollama
from IPython.display import display

logging.basicConfig(level=logging.INFO)

# Matches a Markdown fenced code block (```python ... ``` or plain ```) and
# captures its body.
_CODE_FENCE_RE = re.compile(r"```[\w+-]*[ \t]*\n(.*?)```", re.DOTALL)


def _strip_code_fences(text):
    """Return the body of the first Markdown fenced code block in ``text``.

    LLMs often wrap code in fences even when told not to, and the fence lines
    would make the snippet invalid Python. If ``text`` has no fence, it is
    returned with surrounding whitespace stripped.
    """
    match = _CODE_FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


class MyVanna(ChromaDB_VectorStore, Ollama):
    """Vanna text-to-SQL agent that talks to a local Ollama server.

    DDL and question/SQL training pairs are persisted in Redis by ``train()``.
    Prompts are sent to Ollama's HTTP ``/api/generate`` endpoint with
    ``requests`` (see ``submit_prompt``).
    """

    def __init__(self, config=None):
        """Set up configuration, ChromaDB, Redis and the Vanna parent classes.

        Args:
            config: optional dict whose entries override the defaults below.

        Raises:
            Exception: whatever the Redis client raises when the connection or
                the initial PING fails (logged, then re-raised).
        """
        self.config = {
            'model': 'llama2:latest',
            'ollama_host': 'http://127.0.0.1:11434',
            'verbose': True,
            'temperature': 0.1,
            'collection_name': 'my_vanna_collection',
            # Directory of the persistent ChromaDB client created below; it is
            # read unconditionally, so it needs a default.
            'chroma_db_path': './chroma_db',
            # Seconds to wait for an Ollama reply (None waits forever).
            'request_timeout': 120,
            # Maximum number of stored question/SQL pairs put into one prompt
            # (None = no limit).
            'max_prompt_examples': 10,
            # Connection settings can come from the environment so the password
            # need not live in source; the fallbacks are the previous literals.
            'redis_host': os.environ.get('REDIS_HOST', '127.0.0.1'),
            'redis_port': int(os.environ.get('REDIS_PORT', '6379')),
            'redis_db': int(os.environ.get('REDIS_DB', '5')),
            'redis_password': os.environ.get('REDIS_PASSWORD', '123456'),
            'redis_key_prefix': 'vanna_training:',
        }
        if config:
            self.config.update(config)

        # ChromaDB: reuse the named collection if present, otherwise create it.
        self.chroma_client = chromadb.PersistentClient(path=self.config['chroma_db_path'])
        try:
            self._collection = self.chroma_client.get_collection(self.config['collection_name'])
            logging.info(f"获取已存在的集合: {self.config['collection_name']}")
        except Exception:
            # get_collection fails when the collection does not exist yet;
            # fall back to creating it. (Narrowed from a bare except so that
            # KeyboardInterrupt/SystemExit are no longer swallowed.)
            self._collection = self.chroma_client.create_collection(self.config['collection_name'])
            logging.info(f"创建新的集合: {self.config['collection_name']}")

        # Redis holds all training data.
        try:
            self.redis_client = redis.Redis(
                host=self.config['redis_host'],
                port=self.config['redis_port'],
                db=self.config['redis_db'],
                password=self.config['redis_password'],
                decode_responses=False,  # values are raw bytes (pickled records, DDL text)
                socket_timeout=5,
                retry_on_timeout=True,
            )
            # Fail fast if the server is unreachable or the credentials are wrong.
            self.redis_client.ping()
            logging.info("Redis 连接成功")
        except Exception as e:
            logging.error(f"Redis 连接错误: {str(e)}")
            raise

        ChromaDB_VectorStore.__init__(self, config=self.config)
        Ollama.__init__(self, config=self.config)
        # Cached DDL text; populated by train() or lazily by get_related_ddl().
        self._ddl = None

    def submit_prompt(self, prompt, **kwargs):
        """Send ``prompt`` to Ollama's ``/api/generate`` and return the reply text.

        Args:
            prompt: a string, or a list of chat-style message dicts. A list is
                flattened by joining each dict's ``'content'`` with newlines;
                non-dict entries are skipped.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            The ``response`` field of Ollama's JSON reply, stripped.

        Raises:
            requests.RequestException: on connection/HTTP errors or timeout.
            ValueError: if the reply has no ``response`` field.
        """
        try:
            url = f"{self.config['ollama_host']}/api/generate"

            if isinstance(prompt, list):
                full_prompt = "\n".join(
                    str(msg.get('content') or '') for msg in prompt if isinstance(msg, dict)
                )
            else:
                full_prompt = prompt

            data = {
                "model": self.config['model'],
                "prompt": full_prompt,
                "stream": False,
            }
            # Forward the configured sampling temperature; otherwise the
            # 'temperature' config entry has no effect.
            temperature = self.config.get('temperature')
            if temperature is not None:
                data["options"] = {"temperature": temperature}

            headers = {"Content-Type": "application/json"}
            logging.info(f"发送请求到 Ollama: {url}")
            logging.debug(f"请求数据: {json.dumps(data, ensure_ascii=False)}")

            # A timeout keeps a stalled Ollama server from hanging the caller forever.
            response = requests.post(
                url,
                json=data,
                headers=headers,
                timeout=self.config.get('request_timeout', 120),
            )
            response.raise_for_status()
            response_data = response.json()
            logging.debug(f"Ollama 响应: {json.dumps(response_data, ensure_ascii=False)}")

            if 'response' in response_data:
                return response_data['response'].strip()
            logging.error(f"Ollama 响应格式错误: {response_data}")
            raise ValueError("无效的 Ollama 响应格式")
        except Exception as e:
            logging.error(f"提交 prompt 错误: {str(e)}")
            raise

    def train(self, ddl=None, question=None, sql=None, documentation=None):
        """Persist training data to Redis.

        Args:
            ddl: schema text. It is stored under ``<prefix>ddl`` and cached on
                the instance. A later call replaces the previous DDL; only one
                is kept.
            question: natural-language question; stored only together with ``sql``.
            sql: SQL answering ``question``; stored only together with ``question``.
            documentation: free text saved inside the question/SQL record. It is
                not stored on its own; a warning is logged instead.

        Returns:
            True once the call completes.

        Raises:
            Exception: any Redis/serialization error (logged, then re-raised).
        """
        try:
            prefix = self.config['redis_key_prefix']

            if ddl:
                self._ddl = ddl
                self.redis_client.set(f"{prefix}ddl", ddl)
                logging.info("DDL 已保存到 Redis")

            if question and sql:
                data = {
                    'question': question,
                    'sql': sql,
                    'documentation': documentation or '',
                }
                # Content hash as the id, so re-training the same pair
                # overwrites one key instead of creating duplicates.
                doc_id = hashlib.md5(json.dumps(data, ensure_ascii=False).encode()).hexdigest()

                key = f"{prefix}data:{doc_id}"
                # NOTE(review): records are pickled; they are read back with
                # pickle.loads in get_training_data, which is only safe while
                # this Redis instance is fully trusted.
                self.redis_client.set(key, pickle.dumps(data))
                # The set of ids is the index that get_training_data iterates.
                self.redis_client.sadd(f"{prefix}data_ids", doc_id)
                logging.info(f"训练数据已保存到 Redis: {data}")
            elif question or sql or documentation:
                # Previously these inputs were dropped without any signal.
                logging.warning("question 和 sql 需同时提供, 本次的问答/文档数据未保存")

            return True
        except Exception as e:
            logging.error(f"训练错误: {str(e)}")
            raise

    def get_sql_prompt(self, question, ddl=None, similar_questions=None, similar_sql=None,
                       initial_prompt=None, question_sql_list=None, ddl_list=None,
                       doc_list=None, **kwargs):
        """Build the plain-text prompt that asks the LLM for SQL answering ``question``.

        Schema text comes from the first available of:
        1. ``ddl``,
        2. the cached DDL,
        3. ``ddl_list`` (a string or a list of DDL strings).

        Few-shot examples come from two sources, de-duplicated in order:
        - ``similar_questions`` zipped with ``similar_sql``;
        - ``question_sql_list`` (dicts with ``'question'`` and ``'sql'`` keys).

        When ``initial_prompt`` is given, it replaces the default role line.
        """
        if not ddl:
            ddl = self._ddl
        if not ddl and ddl_list:
            ddl = ddl_list if isinstance(ddl_list, str) else "\n\n".join(str(d) for d in ddl_list)

        examples = []
        if similar_questions and similar_sql:
            examples.extend(zip(similar_questions, similar_sql))
        for item in question_sql_list or []:
            if isinstance(item, dict) and item.get('question') and item.get('sql'):
                examples.append((item['question'], item['sql']))
        # dict.fromkeys drops duplicates while keeping first-seen order.
        examples = list(dict.fromkeys(examples))

        parts = [
            f"{initial_prompt}\n\n" if initial_prompt
            else "你是一个 SQL 专家。请根据以下信息生成 SQL 查询。\n\n",
            "### 数据库结构:\n",
        ]
        if ddl:
            parts.append(f"{ddl}\n\n")

        if doc_list:
            parts.append("### 相关文档:\n")
            parts.extend(f"{doc}\n" for doc in doc_list)
            parts.append("\n")

        parts.append("### 问题:\n")
        parts.append(f"{question}\n\n")

        if examples:
            parts.append("### 相似问题和对应的 SQL:\n")
            parts.extend(f"\n问题: {q}\nSQL: {s}\n" for q, s in examples)

        parts.append("\n### 请生成对应的 SQL 查询 汉字转为简体:\n")
        return "".join(parts)

    def get_similar_question_sql(self, question, **kwargs):
        """Return stored question/SQL pairs to use as few-shot prompt examples.

        ``train()`` writes pairs only to Redis, so a vector-store lookup would
        never find them. This returns the Redis pairs as a list of
        ``{'question', 'sql'}`` dicts, sorted by question for a stable prompt
        and capped at ``config['max_prompt_examples']``.

        Pairs are NOT ranked by similarity to ``question``.
        """
        df = self.get_training_data()
        if df.empty:
            return []
        pairs = []
        for row in df.to_dict('records'):
            q, s = row.get('question'), row.get('sql')
            # Type check also rejects NaN that pandas fills in for missing keys.
            if isinstance(q, str) and isinstance(s, str) and q and s:
                pairs.append({'question': q, 'sql': s})
        pairs.sort(key=lambda p: p['question'])
        limit = self.config.get('max_prompt_examples')
        return pairs if limit is None else pairs[:limit]

    def generate_sql(self, question, **kwargs):
        """Generate SQL for ``question`` through the inherited Vanna pipeline.

        The stored DDL (loaded from Redis if not yet cached) is passed through
        as the ``ddl`` keyword. If the caller supplied ``ddl`` explicitly, that
        value is kept instead.

        Raises:
            Exception: whatever the parent pipeline raises (logged, then re-raised).
        """
        try:
            ddl = self._ddl or self.get_related_ddl(question)
            if ddl:
                kwargs.setdefault('ddl', ddl)
            return super().generate_sql(question, **kwargs)
        except Exception as e:
            logging.error(f"SQL 生成错误: {str(e)}")
            raise

    def get_related_ddl(self, question=None, **kwargs):
        """Return the stored DDL text, loading and caching it from Redis on first use.

        ``question`` is unused; there is a single DDL for all questions.

        Returns:
            The DDL string, or None when nothing is stored or Redis access
            fails (the error is logged).
        """
        try:
            if self._ddl:
                return self._ddl
            raw = self.redis_client.get(f"{self.config['redis_key_prefix']}ddl")
            if raw:
                self._ddl = raw.decode('utf-8') if isinstance(raw, bytes) else raw
                return self._ddl
            return None
        except Exception as e:
            logging.error(f"获取 DDL 错误: {str(e)}")
            return None

    def generate_plotly_code(self, question, sql_result=None, **kwargs):
        """Ask the LLM for Plotly code that visualizes a query result.

        Args:
            question: the user's question.
            sql_result: optional query result included in the prompt.
            **kwargs: forwarded to ``get_plotly_prompt`` (e.g. ``sql``, ``df_metadata``).

        Returns:
            The generated Python source, with any Markdown code fences removed.

        Raises:
            Exception: prompt/LLM errors (logged, then re-raised).
        """
        try:
            prompt = self.get_plotly_prompt(question, sql_result=sql_result, **kwargs)
            system_prompt = "你是一个数据可视化专家。请根据用户的需求生成 Plotly 图表代码。只返回 Python 代码,不需要其他解释。如果繁体转为简体。"
            full_prompt = f"{system_prompt}\n\n{prompt}"
            reply = self.submit_prompt(full_prompt, is_plotly=True)
            # Fences like ```python would make the snippet fail when executed.
            return _strip_code_fences(reply)
        except Exception as e:
            logging.error(f"生成图表代码错误: {str(e)}")
            raise

    def get_plotly_prompt(self, question, sql=None, sql_result=None, **kwargs):
        """Build the prompt asking for Plotly Express code.

        ``sql`` and ``sql_result`` are compared against None rather than tested
        for truthiness. A DataFrame's truth value is ambiguous and raises
        ValueError.

        An optional ``df_metadata`` keyword (free-form text describing the
        data) is included when given.
        """
        df_metadata = kwargs.get('df_metadata')
        lines = [
            "请根据以下信息生成 Plotly 图表代码:",
            f"问题:{question}",
            f"SQL查询:{'' if sql is None else sql}",
            f"查询结果:{'' if sql_result is None else sql_result}",
        ]
        if df_metadata is not None:
            lines.append(f"数据信息:{df_metadata}")
        lines += [
            "要求:",
            "1. 使用 Plotly Express 生成图表",
            "2. 只返回 Python 代码",
            "3. 不要包含任何解释或说明",
            "4. 确保代码的正确性",
            "5. 如果繁体转为简体",
            # NOTE(review): items 6-7 assume the code is exec()'d with the result
            # bound to `df` and that `fig` is read afterwards (as Vanna's
            # plotting helper appears to do) — confirm for the installed version.
            "6. 数据已在名为 df 的 pandas DataFrame 中,直接使用 df,不要重新构造数据",
            "7. 将最终的图表对象赋值给变量 fig",
        ]
        return "\n".join(lines) + "\n"

    def get_training_data(self, **kwargs):
        """Load all question/SQL training records from Redis.

        Returns:
            A DataFrame built from the stored record dicts (columns
            ``question``, ``sql``, ``documentation``). When there is no data or
            an error occurs, returns an empty DataFrame with those columns.

        Unreadable individual records are logged and skipped.
        """
        columns = ['question', 'sql', 'documentation']
        try:
            prefix = self.config['redis_key_prefix']
            data_ids = self.redis_client.smembers(f"{prefix}data_ids")
            if not data_ids:
                logging.info("Redis 中没有找到训练数据")
                return pd.DataFrame(columns=columns)

            documents = []
            for doc_id in data_ids:
                try:
                    key = f"{prefix}data:{doc_id.decode()}"
                    raw = self.redis_client.get(key)
                    if raw:
                        # NOTE(review): pickle.loads can execute arbitrary code;
                        # only acceptable while this Redis instance is trusted.
                        doc_data = pickle.loads(raw)
                        documents.append(doc_data)
                        logging.debug(f"从 Redis 加载数据: {doc_data}")
                except Exception as e:
                    logging.error(f"处理 Redis 数据时出错: {e}")

            if documents:
                df = pd.DataFrame(documents)
                logging.info(f"已加载 {len(df)} 条训练数据")
                return df
            logging.info("没有找到有效的训练数据")
            return pd.DataFrame(columns=columns)
        except Exception as e:
            logging.error(f"获取训练数据错误: {str(e)}")
            return pd.DataFrame(columns=columns)


def train_model(vn):
    """Seed ``vn`` with the customer-table DDL and sample Q/SQL pairs, then run a test question.

    Raises:
        Exception: any training or query error (logged, then re-raised).
    """
    try:
        print("开始训练 DDL...")
        # Keep in sync with the CREATE TABLE used to build the database: the
        # table has a surrogate `id` primary key, and the model must know it.
        ddl = """CREATE TABLE `customer` (
  `name` int NOT NULL COMMENT '姓名',
  `gender` int DEFAULT NULL COMMENT '性别(男性=1/女性=2)',
  `id_card` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '身份证',
  `mobile` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '手机',
  `nation` varchar(10) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '民族',
  `residential_city` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '居住城市',
  `age` int DEFAULT NULL COMMENT '岁数 年纪',
  `salary` int NOT NULL COMMENT '薪水',
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'id',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='customer'"""
        vn.train(ddl=ddl)
        print("DDL 训练完成")

        examples = [
            {
                'question': "宁波有多少客户?",
                'sql': "SELECT COUNT(*) as count FROM customer WHERE residential_city like '%宁波%'",
            },
            {
                'question': "有多少女性客户?",
                'sql': "SELECT COUNT(*) as count FROM customer WHERE gender = 2",
            },
            {
                'question': "客户平均年龄是多少?",
                'sql': "SELECT AVG(age) as average_age FROM customer",
            },
            {
                'question': "客户平均薪水是多少?",
                'sql': "SELECT AVG(salary) as average_salary FROM customer",
            },
        ]
        for example in examples:
            print(f"\n训练示例: {example['question']}")
            vn.train(question=example['question'], sql=example['sql'])

        print("\n所有训练完成")
        result = vn.ask("宁波有多少客户?")
        print(f"\n查询问题: 宁波有多少客户?\n查询结果: {result}")
    except Exception as e:
        logging.error(f"训练错误: {str(e)}")
        raise


if __name__ == "__main__":
    try:
        vn = MyVanna()
        # Credentials can be supplied via the environment; the fallbacks are
        # the previous hard-coded values.
        vn.connect_to_mysql(
            host=os.environ.get('MYSQL_HOST', 'localhost'),
            dbname=os.environ.get('MYSQL_DB', 'test'),
            user=os.environ.get('MYSQL_USER', 'root'),
            password=os.environ.get('MYSQL_PASSWORD', '123456'),
            port=int(os.environ.get('MYSQL_PORT', '3306')),
        )
        train_model(vn)

        from vanna.flask import VannaFlaskApp
        app = VannaFlaskApp(vn)
        app.run(host='0.0.0.0', port=7123)
    except Exception as e:
        # Log with traceback and exit non-zero so supervisors/scripts notice the failure.
        logging.exception(f"程序运行错误: {str(e)}")
        raise SystemExit(1) from e
-- Demo `customer` table. `id` is the auto-increment primary key; the next
-- generated id will be 21.
-- NOTE(review): `name` is declared int although its comment says 姓名 (name); confirm intent.
CREATE TABLE `customer` (
  `name` int NOT NULL COMMENT '姓名',
  `gender` int DEFAULT NULL COMMENT '性别(男性=1/女性=2)',
  `id_card` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '身份证',
  `mobile` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '手机',
  `nation` varchar(10) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '民族',
  `residential_city` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '居住城市',
  `age` int DEFAULT NULL COMMENT '岁数 年纪',
  `salary` int NOT NULL COMMENT '薪水',
  `id` int NOT NULL AUTO_INCREMENT COMMENT 'id',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=21 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='customer';
-- Seed 20 demo customers (id 1-20). Numeric values are written as quoted
-- literals; MySQL casts them to the int columns.
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES
  ('1','1','330201199001011234','13800001111','汉族','宁波','27','5520','1'),
  ('2','2','330201199102022345','13800002222','汉族','宁波','70','7042','2'),
  ('3','1','330201199203033456','13800003333','回族','宁波','94','4119','3'),
  ('4','2','330201199304044567','13800004444','汉族','宁波','60','4886','4'),
  ('5','1','330201199405055678','13800005555','壮族','宁波','5','5762','5'),
  ('6','1','110101199506066789','13800006666','汉族','北京','58','5515','6'),
  ('7','2','310101199607077890','13800007777','汉族','上海','69','2927','7'),
  ('8','1','440101199708088901','13800008888','满族','广州','90','5979','8'),
  ('9','2','500101199809099012','13800009999','汉族','重庆','91','7256','9'),
  ('10','1','610101199910101123','13800010000','回族','西安','28','4067','10'),
  ('11','2','320101199001111234','13800011111','汉族','南京','13','1979','11'),
  ('12','1','330101199002121345','13800012222','畲族','杭州','8','994','12'),
  ('13','2','420101199003131456','13800013333','汉族','武汉','29','1073','13'),
  ('14','1','510101199004141567','13800014444','彝族','成都','84','1441','14'),
  ('15','2','350101199005151678','13800015555','汉族','福州','33','7725','15'),
  ('16','1','370101199006161789','13800016666','汉族','济南','89','3821','16'),
  ('17','2','430101199007171890','13800017777','苗族','长沙','86','3082','17'),
  ('18','1','220101199008181901','13800018888','汉族','长春','48','4170','18'),
  ('19','2','450101199009192012','13800019999','壮族','南宁','30','1498','19'),
  ('20','1','130101199010202123','13800020000','汉族','石家庄','54','941','20');
1. 系统概述
这是一个基于 Vanna 框架、Ollama(本地 LLM 服务)和 Redis 的智能 SQL 问答系统,可以将自然语言问题转换为 SQL 查询语句。系统具有以下主要特点:
- 基于 LLM (Large Language Model) 的自然语言转 SQL
- 支持训练数据的持久化存储
- 通过 VannaFlaskApp 提供 Web 界面和 HTTP API 接口
- 支持数据可视化生成
2. 核心组件
2.1 MyVanna 类
主要继承关系:
class MyVanna(ChromaDB_VectorStore, Ollama)
核心配置参数:
self.config = {
    'model': 'llama2:latest',                 # LLM 模型
    'ollama_host': 'http://127.0.0.1:11434',  # Ollama 服务地址
    'temperature': 0.1,                       # 生成温度
    'redis_host': '127.0.0.1',                # Redis 配置
    'redis_port': 6379,
    'redis_db': 5,
    'redis_password': '123456',
    'redis_key_prefix': 'vanna_training:'
}
3. 主要功能模块
3.1 提示词生成 (Prompt Engineering)
def get_sql_prompt(self, question, ddl=None, similar_questions=None, similar_sql=None, ...):
提示词结构:
- 角色定义
- 数据库结构说明
- 相关文档
- 用户问题
- 相似问题参考
- 输出要求
3.2 训练功能
def train(self, ddl=None, question=None, sql=None, documentation=None):
训练数据包含:
- DDL(数据库结构)
- 问题-SQL 对
- 相关文档(仅在同时提供问题和 SQL 时,随该问题-SQL 对一起保存)
存储方式:
- 使用 Redis 持久化
- 使用训练数据内容的 MD5 哈希作为唯一标识,重复训练相同的问题-SQL 对不会产生重复记录
- 可通过循环多次调用 train() 导入多条示例
3.3 SQL 生成
def generate_sql(self, question, **kwargs):
工作流程:
- 获取相关 DDL
- 构建提示词
- 调用 LLM 生成 SQL
- 错误处理和日志记录
3.4 数据可视化
def generate_plotly_code(self, question, sql_result=None, **kwargs):
特点:
- 使用 Plotly 生成可视化代码
- 支持 SQL 结果的直接可视化
- 在提示词中要求模型将繁体中文转换为简体中文
4. 示例训练数据
examples = [
    {'question': "宁波有多少客户?", 'sql': "SELECT COUNT(*) as count FROM customer WHERE residential_city like '%宁波%'"},
    {'question': "有多少女性客户?", 'sql': "SELECT COUNT(*) as count FROM customer WHERE gender = 2"},
    # ...
]
5. 部署和使用
5.1 环境要求
- Python 3.x 及依赖包:vanna、chromadb、redis、requests、pandas
- Redis 服务
- MySQL 数据库
- Ollama 服务
5.2 启动服务
if __name__ == "__main__":
    vn = MyVanna()
    vn.connect_to_mysql(...)
    train_model(vn)
    app = VannaFlaskApp(vn)
    app.run(host='0.0.0.0', port=7123)
6. 改进建议
- 错误处理优化
  - 添加更详细的错误类型
  - 实现错误重试机制
- 性能优化
  - 添加缓存机制
  - 实现批量处理
- 安全性增强
  - 添加 SQL 注入防护
  - 实现访问控制
- 功能扩展
  - 支持更多数据库类型
  - 添加更多可视化选项
  - 实现对话历史记录
7. 总结
该系统通过结合 LLM 和传统数据库技术,实现了一个灵活的自然语言到 SQL 的转换系统。其模块化设计和可扩展性使其适合在实际业务场景中使用和扩展。
主要优势:
- 模块化设计
- 可扩展架构
- 完整的训练流程
- 持久化存储支持
潜在改进空间:
- 性能优化
- 安全性增强
- 功能扩展
- 错误处理完善