BERT的中文问答系统34

为了使项目的GUI更加美观和人性化,我们进行以下改进:
使用现代主题:使用 ttk 的现代主题来提升界面的美观度。
增加提示信息:在各个按钮和输入框上增加提示信息,帮助用户更好地理解如何使用。
优化布局:调整布局,使界面更加整洁和易用。
增加动画效果:在某些操作上增加简单的动画效果,提升用户体验。
增加图标:使用图标来增强按钮的视觉效果。
以下是改进后的代码:

xihe_chatbot.py

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)except Exception as e:logging.warning(f"跳过无效项 {idx}: {e}")return self.__getitem__((idx + 1) % len(self.data))return {'input_ids': inputs['input_ids'].squeeze(),'attention_mask': inputs['attention_mask'].squeeze(),'human_input_ids': human_inputs['input_ids'].squeeze(),'human_attention_mask': human_inputs['attention_mask'].squeeze(),'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),'human_answer': human_answer,'chatgpt_answer': chatgpt_answer}# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):dataset = XihuaDataset(file_path, tokenizer, max_length)return DataLoader(dataset, batch_size=batch_size, shuffle=True)# 模型定义
class XihuaModel(torch.nn.Module):def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):super(XihuaModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model_name)self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs.pooler_outputlogits = self.classifier(pooled_output)return logits# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):model.train()total_loss = 0.0num_batches = len(data_loader)for batch_idx, batch in enumerate(data_loader):try:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)human_input_ids = batch['human_input_ids'].to(device)human_attention_mask = batch['human_attention_mask'].to(device)chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)optimizer.zero_grad()human_logits = model(human_input_ids, human_attention_mask)chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)human_labels = torch.ones(human_logits.size(0), 1).to(device)chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)loss.backward()optimizer.step()total_loss += loss.item()if progress_var:progress_var.set((batch_idx + 1) / num_batches * 100)except Exception as e:logging.warning(f"跳过无效批次: {e}")return total_loss / len(data_loader)# 主训练函数
def main_train(retrain=False):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device: {device}')tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)if retrain:model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):model.load_state_dict(torch.load(model_path, map_location=device))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")optimizer = optim.Adam(model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)num_epochs = 30for epoch in range(num_epochs):train_loss = train(model, train_data_loader, optimizer, criterion, device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")# 网络搜索函数
def search_baidu(query):url = f"https://www.baidu.com/s?wd={query}"headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}response = requests.get(url, headers=headers)soup = BeautifulSoup(response.text, 'html.parser')results = soup.find_all('div', class_='c-abstract')if results:return results[0].get_text().strip()return "没有找到相关信息"# GUI界面
class XihuaChatbotGUI:def __init__(self, root):self.root = rootself.root.title("羲和聊天机器人")self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)self.load_model()self.model.eval()# 加载训练数据集以便在获取答案时使用self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))# 历史记录self.history = []self.create_widgets()def create_widgets(self):# 设置样式style = ttk.Style()style.theme_use('clam')# 顶部框架top_frame = ttk.Frame(self.root)top_frame.pack(pady=10)self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))self.question_label.grid(row=0, column=0, padx=10)self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))self.question_entry.grid(row=0, column=1, padx=10)self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')self.answer_button.grid(row=0, column=2, padx=10)# 中部框架middle_frame = ttk.Frame(self.root)middle_frame.pack(pady=10)self.answer_label = ttk.Label(middle_frame, text="回答:", font=("Arial", 12))self.answer_label.grid(row=0, column=0, padx=10)self.answer_text = tk.Text(middle_frame, height=10, width=70, font=("Arial", 12))self.answer_text.grid(row=1, column=0, padx=10)# 底部框架bottom_frame = ttk.Frame(self.root)bottom_frame.pack(pady=10)self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')self.correct_button.grid(row=0, column=0, padx=10)self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')self.incorrect_button.grid(row=0, column=1, padx=10)self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')self.train_button.grid(row=0, column=2, padx=10)self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')self.retrain_button.grid(row=0, column=3, padx=10)self.progress_var = tk.DoubleVar()self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))self.log_text.grid(row=2, column=0, columnspan=4, pady=10)self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')self.history_button.grid(row=3, column=1, padx=10, pady=10)self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')self.save_history_button.grid(row=3, column=2, padx=10, pady=10)def get_answer(self):question = self.question_entry.get()if not question:messagebox.showwarning("输入错误", "请输入问题")returninputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)with torch.no_grad():input_ids = inputs['input_ids'].to(self.device)attention_mask = inputs['attention_mask'].to(self.device)logits = self.model(input_ids, attention_mask)if logits.item() > 0:answer_type = "羲和回答"else:answer_type = "零回答"specific_answer = self.get_specific_answer(question, answer_type)self.answer_text.delete(1.0, tk.END)self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")# 添加到历史记录self.history.append({'question': question,'answer_type': answer_type,'specific_answer': specific_answer,'accuracy': None  # 初始状态为未评价})def get_specific_answer(self, question, answer_type):# 使用模糊匹配查找最相似的问题best_match = Nonebest_ratio = 0.0for item in self.data:ratio = SequenceMatcher(None, question, item['question']).ratio()if ratio > best_ratio:best_ratio = ratiobest_match = itemif best_match:if answer_type == "羲和回答":return best_match['human_answers'][0]else:return best_match['chatgpt_answers'][0]return "这个我也不清楚,你问问零吧"def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef load_model(self):model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):self.model.load_state_dict(torch.load(model_path, map_location=self.device))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")def train_model(self, retrain=False):file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])if not file_path:messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")returntry:dataset = XihuaDataset(file_path, self.tokenizer)data_loader = DataLoader(dataset, batch_size=8, shuffle=True)# 加载已训练的模型权重if retrain:self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))self.model.to(self.device)self.model.train()optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()num_epochs = 30for epoch in range(num_epochs):train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')self.log_text.see(tk.END)torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")self.log_text.insert(tk.END, "模型训练完成并保存\n")self.log_text.see(tk.END)messagebox.showinfo("训练完成", "模型训练完成并保存")except Exception as e:logging.error(f"模型训练失败: {e}")self.log_text.insert(tk.END, f"模型训练失败: {e}\n")self.log_text.see(tk.END)messagebox.showerror("训练失败", f"模型训练失败: {e}")def evaluate_model(self):# 这里可以添加模型评估的逻辑messagebox.showinfo("评估结果", "模型评估功能暂未实现")def mark_correct(self):if self.history:self.history[-1]['accuracy'] = Truemessagebox.showinfo("评价成功", "您认为这次回答是准确的")def mark_incorrect(self):if self.history:self.history[-1]['accuracy'] = Falsequestion = self.history[-1]['question']answer = search_baidu(question)self.answer_text.delete(1.0, tk.END)self.answer_text.insert(tk.END, f"搜索引擎结果:\n{answer}")messagebox.showinfo("评价成功", "您认为这次回答是不准确的")def view_history(self):history_window = tk.Toplevel(self.root)history_window.title("历史记录")history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))history_text.pack(padx=10, pady=10)for entry in self.history:history_text.insert(tk.END, f"问题: {entry['question']}\n")history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")if entry['accuracy'] is None:history_text.insert(tk.END, "评价: 未评价\n")elif entry['accuracy']:history_text.insert(tk.END, "评价: 准确\n")else:history_text.insert(tk.END, "评价: 不准确\n")history_text.insert(tk.END, "-" * 50 + "\n")def save_history(self):file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])if not file_path:returnwith open(file_path, 'w') as f:json.dump(self.history, f, ensure_ascii=False, indent=4)messagebox.showinfo("保存成功", "历史记录已保存到文件")# 主函数
if __name__ == "__main__":# 启动GUIroot = tk.Tk()app = XihuaChatbotGUI(root)root.mainloop()

说明
使用现代主题:通过 ttk.Style() 设置了现代主题 clam,使界面更加美观。
增加提示信息:在按钮和输入框上增加了提示信息,帮助用户更好地理解如何使用。
优化布局:调整了布局,使界面更加整洁和易用。
增加动画效果:进度条 ttk.Progressbar 已设置为 mode=‘determinate’,以显示训练进度。
增加图标:虽然这里没有实际添加图标文件,但可以在 ttk.Button 中通过 image 参数添加图标,以增强按钮的视觉效果。
这些改进将使您的聊天机器人界面更加现代化和用户友好。如果您需要进一步的定制或有其他需求,请告诉我!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/18593.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

红外遥控信号解码

红外遥控信号解码 之前就已经做过红外遥控的解码了,但是一直没有做记录,最近的项目又使用到了红外遥控,索性就把他捡起来记录一下,对于信号的解码,我一般的习惯都是先用逻辑分析仪抓取一下信号波形,然后对…

Spring:纯注解开发模式-Ioc对bean的管理

我们知道(请见 注解开发定义bean),可以使用注解来配置bean,但是依然有用到配置文件,在配置文件中对包进行了扫描,Spring在3.0版已经支持纯注解开发 Spring3.0开启了纯注解开发模式,使用Java类替代配置文件…

赛元免费开发板申请

在作者网上冲浪的时候,突然发现了一个国内的良心企业,虽然现在不是很有名,但是他现在是有一个样品申请的活动,他就是国内的Redfine新定义,他申请的板子是用的赛元MCU,作者本着有板子就要申请的原则&#xf…

【编译】多图解释 什么是短语、直接短语、句柄、素短语、可归约串

一、什么是短语二、什么是“直接”短语?三、什么是句柄?四、什么是素短语?五、什么是最左素短语可归约串就是“最左素短语” 首先,这些概念 都是相对于【句型】的,都是相对于【句型】的,都是相对于【句型】…

基础IO2

文章目录 磁盘结构磁盘存储结构磁盘的逻辑结构引入文件系统理解文件系统inode 映射 data blocks文件名与inode的关系dentry树文件描述符与进程之间的关系 深刻理解软硬链接软链接硬链接 动静态库静态库1. 手动制作静态库2.调用静态库(1)安装到系统(2)自己指定查找路径(3)自己创…

RPC-健康检测机制

什么是健康检测? 在真实环境中服务提供方是以一个集群的方式提供服务,这对于服务调用方来说,就是一个接口会有多个服务提供方同时提供服务,调用方在每次发起请求的时候都可以拿到一个可用的连接。 健康检测,能帮助从连…

ZD Soft Screen Recorder:电脑录屏软件

前言 ZD Soft Screen Recorder 是一款屏幕录制软件 安装环境 [名称]:ZD Soft Screen Recorder [版本]:11.7.2 [大小]:6.8MB [语言]:英文 [环境]:pc 链接: 百度网盘 请输入提取码 提取码: ua23 软件界面 1、双击…

某杀软环境下的添加账户

某杀软环境下的添加账户 我们在某个杀软环境下,正常添加账户一般是会被直接拦截的 白+黑 在这个环境下,白+黑是最实用的绕过方式,我们可以通过调用winapi来创建账户,这些代码再存储到dll里面&#xff0c…

《Spring 基础之 IoC 与 DI 入门指南》

一、IoC 与 DI 概念引入 Spring 的 IoC(控制反转)和 DI(依赖注入)在 Java 开发中扮演着至关重要的角色,是提升代码质量和可维护性的关键技术。 (一)IoC 的含义及作用 IoC 全称为 Inversion of…

【C++】set,map,multiset,multimap的介绍和使用

set、map、multiset、multimap set、multiset的介绍和使用1、关联式容器2、键值对3、树形结构的关联式容器4、setset的介绍set的定义set的使用 5、multisetmultiset的介绍multiset的使用 map、multimap的介绍和使用1、map的介绍map的定义insert插入函数map的迭代器find查找函数…

Midjourney基础命令和提示词

1 基础命令 1.1 /imagine prompt 生成图片的核心命令,prompt 后输入描述。 /imagine prompt: A majestic dragon flying over a misty mountain, cinematic lighting, 4K resolution 高级提示 1.1.1 基本参数 图片比例 --ar 图片比例 混乱 Aspect Ratios --…

【代码pycharm】动手学深度学习v2-04 数据操作 + 数据预处理

数据操作 数据预处理 1.数据操作运行结果 2.数据预处理实现运行结果 第四课链接 1.数据操作 import torch # 张量的创建 x1 torch.arange(12) print(1.有12个元素的张量:\n,x1) print(2.张量的形状:\n,x1.shape) print(3.张量中元素的总数&#xff1…

智云-一个抓取web流量的轻量级蜜罐v1.5

智云-一个抓取web流量的轻量级蜜罐v1.5 github地址 https://github.com/xiaoxiaoranxxx/POT-ZHIYUN 新增功能-自定义漏洞信息 可通过正则来添加相关路由以及响应来伪造 nacos的版本响应如下 日流量态势 月流量态势 抓取流量效果

【Android原生问题分析】夸克、抖音划动无响应问题【Android14】

1 问题描述 偶现问题,用户打开夸克、抖音后,在界面上划动无响应,但是没有ANR。回到Launcher后再次打开夸克/抖音,发现App的界面发生了变化,但是仍然是划不动的。 2 log初分析 复现问题附近的log为: 用户…

2024年11月16日 星期六 重新整理Go技术

今日格言 坚持每天进步一点点~ 一个人也可以是一个团队~ 学习全栈开发, 做自己喜欢的产品~~ 简介 大家好, 我是张大鹏, 今天是2024年11月16日星期六, 很高兴在这里给大家分享技术. 今天又是休息的一天, 做了很多的思考, 整理了自己掌握的技术, 比如Java, Python, Golang,…

go-zero(一) 介绍和使用

go-zero 介绍和使用 一、什么是 go-zero? go-zero 是一个基于 Go 语言的微服务框架,提供了高效、简单并易于扩展的 API 设计和开发模式。它主要目的是为开发者提供一种简单的方式来构建和管理云原生应用。 1.go-zero 的核心特性 高性能: g…

VUE+SPRINGBOOT实现邮箱注册、重置密码、登录功能

随着互联网的发展,网站用户的管理、触达、消息通知成为一个网站设计是否合理的重要标志。目前主流互联网公司都支持手机验证码注册、登录。但是手机短信作为服务端网站是需要付出运营商通信成本的,而邮箱的注册、登录、重置密码,无疑成为了这…

多目标优化算法:多目标蛇鹫优化算法(MOSBOA)求解ZDT1、ZDT2、ZDT3、ZDT4、ZDT6,提供完整MATLAB代码

一、蛇鹫优化算法 蛇鹫优化算法(Secretary Bird Optimization Algorithm,简称SBOA)由Youfa Fu等人于2024年4月发表在《Artificial Intelligence Review》期刊上的一种新型的元启发式算法。该算法旨在解决复杂工程优化问题,特别是…

2024-11-15 Java开发工程师 内推

Java开发工程师 坐标:大连 岗位要求: 1、本科以上学历,计算机相关专业 2、22/23/24届毕业生 小结:有意向的私信发简历

Python绘制雪花

文章目录 系列目录写在前面技术需求完整代码代码分析1. 代码初始化部分分析2. 雪花绘制核心逻辑分析3. 窗口保持部分分析4. 美学与几何特点总结 写在后面 系列目录 序号直达链接爱心系列1Python制作一个无法拒绝的表白界面2Python满屏飘字表白代码3Python无限弹窗满屏表白代码4…