# 模型训练 (model training script)
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义全连接神经网络模型
class FNN(nn.Module):
    """Fully connected classifier for 28x28 single-channel MNIST images.

    Architecture: 784 -> 128 -> 64 -> 10 with ReLU between the linear layers.
    The last layer returns raw logits (no softmax), which is the input format
    ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10).

        ``x`` is flattened to (N, 784), so any input whose element count is a
        multiple of 784 (e.g. a (N, 1, 28, 28) image batch) is accepted.
        """
        # Flatten each image to a 784-vector; reshape (unlike view) also
        # accepts non-contiguous tensors.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Data preprocessing
# ToTensor() scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them
# to [-1, 1]. Any inference code must apply the same normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the MNIST training split (downloaded into ./data if not present).
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Create the model instance.
model = FNN()

# Loss and optimizer; CrossEntropyLoss consumes the raw logits FNN returns.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model.
epochs = 5
model.train()  # explicit, although an nn.Module starts in training mode
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        # .item() extracts a Python float so no graph is kept alive across batches.
        running_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1} - Loss: {running_loss / len(trainloader)}")

# Save only the learned parameters (state_dict), not the pickled module object.
torch.save(model.state_dict(), 'fnn_model.pth')
print("模型已保存为 fnn_model.pth")
# 测试 (testing: Flask inference service)
from flask import Flask, request, jsonify
import torch
from torchvision import transforms
from PIL import Image
import io
from torch import nn, optim


class FNN(nn.Module):
    """Fully connected digit classifier used for inference.

    Layer names and sizes (fc1: 784->128, fc2: 128->64, fc3: 64->10) must stay
    identical to the training-script definition so that
    ``load_state_dict`` on the saved checkpoint finds matching keys.
    ``forward`` returns raw logits of shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten ``x`` to (N, 784) and return (N, 10) class logits."""
        # reshape (unlike view) also accepts non-contiguous tensors.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Define the Flask app
app = Flask(__name__)

# Load the saved model once at startup.
model = FNN()
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code (torch >= 1.13).
model.load_state_dict(torch.load('fnn_model.pth', map_location='cpu', weights_only=True))
# Put the module in evaluation mode before serving predictions.
model.eval()

# Data preprocessing function
def transform_image(image_bytes):
    """Convert raw uploaded image bytes into a batched input tensor for FNN.

    The image is reduced to one grayscale channel, resized to 28x28, scaled to
    [0, 1] by ToTensor and mapped to [-1, 1] by Normalize((0.5,), (0.5,)),
    matching the normalization used by the training script.

    Returns:
        A tensor of shape (1, 1, 28, 28): a batch containing one image.

    NOTE(review): MNIST digits are light strokes on a dark background; uploads
    drawn dark-on-light probably need inverting first. TODO confirm.
    """
    steps = [
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    pipeline = transforms.Compose(steps)
    buffer = io.BytesIO(image_bytes)
    picture = Image.open(buffer)
    single = pipeline(picture)
    # Prepend the batch dimension the model expects.
    return single.unsqueeze(dim=0)


# Prediction function
def get_prediction(image_tensor, net=None):
    """Return the predicted class index for a single preprocessed image.

    Args:
        image_tensor: an input batch holding exactly one sample, e.g. the
            (1, 1, 28, 28) output of ``transform_image``.
        net: the model to run; defaults to the module-level ``model`` loaded
            at startup.

    Returns:
        The index of the highest-scoring class as a Python ``int``.
    """
    if net is None:
        net = model
    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        output = net(image_tensor)
    _, predicted = torch.max(output, dim=1)
    # .item() requires exactly one prediction, i.e. a batch of size one.
    return predicted.item()


# Define the API route
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image.

    Expects a multipart/form-data POST carrying the image in the ``file`` field.

    Returns:
        200 with ``{"prediction": <int>}`` on success, or 400 with
        ``{"error": ...}`` when no usable file was sent or the bytes are not a
        readable image.
    """
    # The route only accepts POST, so the method needs no further checking.
    file = request.files.get('file')
    # A FileStorage is falsy when the part has no filename (e.g. an empty form
    # field). Returning nothing there would make Flask raise a 500.
    if not file:
        return jsonify({'error': 'No file provided'}), 400
    image_bytes = file.read()
    try:
        image_tensor = transform_image(image_bytes)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # non-image data, and OSError for truncated images.
        return jsonify({'error': 'Invalid image file'}), 400
    prediction = get_prediction(image_tensor)
    return jsonify({'prediction': prediction})


# Run the Flask application
if __name__ == '__main__':
    # Debug mode is opt-in: set FLASK_DEBUG=1, which app.run() honours when no
    # explicit debug argument is passed. The Werkzeug interactive debugger allows
    # arbitrary code execution, so it must not be on by default for a service
    # that accepts uploaded files.
    app.run()