• 使用FastAPI和Redis Caching加快机器学习模型推理

使用FastAPI和Redis Caching加快机器学习模型推理

2025-05-14 08:37:03 栏目:宝塔面板 113 阅读

译者 | 布加迪

审校 | 重楼

Redis 是一款开源内存数据结构存储系统,是机器学习应用领域中缓存的优选。它的速度、耐用性以及支持各种数据结构使其成为满足实时推理任务的高吞吐量需求的理想选择。

我们在本教程中将探讨Redis缓存在机器学习工作流程中的重要性。我们将演示如何使用FastAPI和Redis构建一个强大的机器学习应用程序。本教程介绍如何在Windows上安装Redis、在本地运行Redis以及如何将其集成到机器学习项目中。最后,我们将通过发送重复请求和独特请求来测试该应用程序,以验证Redis缓存系统正常运行。

为什么在机器学习中使用Redis缓存?

在当今快节奏的数字环境中,用户期望机器学习应用程序能够立即获得结果。比如说,使用推荐模型向用户推荐产品的电商平台。如果实施Redis来缓存重复请求,该平台就可以显著缩短响应时间。

当用户请求产品推荐时,系统先检查该请求是否已被缓存。如果已缓存,则在几微秒内返回缓存的响应,从而提供无缝的体验。如果没有缓存,模型就处理该请求,生成推荐,并将结果存储在Redis中供将来的请求使用。这种方法不仅提高了用户满意度,还优化了服务器资源,使模型能够高效地处理更多请求。

使用Redis构建网络钓鱼电子邮件分类应用程序

我们在本项目中将构建一个网络钓鱼电子邮件分类应用程序。整个过程包括加载和处理来自Kaggle的数据集,使用处理后的数据训练机器学习模型,评估其性能,保存经过训练的模型,最构建带有Redis集成机制FastAPI应用程序。

1. 设置

  • Kaggle下载网络钓鱼电子邮件检测数据集并将其放入data/目录。
  • 首先需要安装Redis。在终端中运行以下命令安装Redis Python客户程序
pip install redis
  • 如果使用Windows系统,且未安装Windows Subsystem for Linux(WSL,请按照微软指南启用WSL,并微软商店安装Linux发行版(比如Ubuntu)。
  • WSL设置完成后,打开WSL终端并执行以下命令安装Redis
sudo apt update
sudo apt install redis-server
  • 要启动Redis服务器,请运行:
sudo service redis-server start

应该会看到一条确认消息,表明redis-server已成功启动。

2. 模型训练

训练脚本加载数据集、处理数据、训练模型并将其保存在本地。

import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

def main():
 # Load dataset
 df = pd.read_csv("data/Phishing_Email.csv") # adjust the path as necessary

 # Assume dataset has columns "text" and "label"
 X = df["Email Text"].fillna("")
 y = df["Email Type"]

 # Split the dataset into training and testing sets
 X_train, X_test, y_train, y_test = train_test_split(
 X, y, test_size=0.2, random_state=42
 )

 # Create a pipeline with TF-IDF and Logistic Regression
 pipeline = Pipeline(
 [
 ("tfidf", TfidfVectorizer(stop_words="english")),
 ("clf", LogisticRegression(solver="liblinear")),
 ]
 )

 # Train the model
 pipeline.fit(X_train, y_train)

 # Save the trained model to a file
 joblib.dump(pipeline, "phishing_model.pkl")
 print("Model trained and saved as phishing_model.pkl")

if __name__ == "__main__":
 main()


python train.py


Model trained and saved as phishing_model.pkl

3. 模型评估

评估脚本加载数据集和保存的模型文件以执行模型评估。

import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import joblib

def main():
 # Load dataset
 df = pd.read_csv("data/Phishing_Email.csv") # adjust the path as necessary

 # Assume dataset has columns "text" and "label"
 X = df["Email Text"].fillna("")
 y = df["Email Type"]

 # Split the dataset
 X_train, X_test, y_train, y_test = train_test_split(
 X, y, test_size=0.2, random_state=42
 )

 # Load the trained model
 model = joblib.load("phishing_model.pkl")

 # Make predictions on the test set
 y_pred = model.predict(X_test)

 # Evaluate the model
 print("Accuracy: ", accuracy_score(y_test, y_pred))
 print("Classification Report:")
 print(classification_report(y_test, y_pred))

if __name__ == "__main__":
 main()

结果近乎完美,F1分数也非常出色。

python validate.py

Accuracy: 0.9723860589812332
Classification Report:
 precision recall   f1-score support

Phishing Email 0.96 0.97 0.96 1457
 Safe Email 0.98 0.97 0.98 2273

 accuracy 0.97 3730
 macro avg 0.97 0.97 0.97 3730
 weighted avg   0.97 0.97 0.97 3730

4. 使用Redis提供模型服务

为了提供模型服务,我们将使用FastAPI创建REST API并集成Redis缓存预测。

import asyncio
import json
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis

# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_respnotallow=True)

# Load the trained model (synchronously)
model = joblib.load("phishing_model.pkl")

app = FastAPI()

# Define the request and response data models
class PredictionRequest(BaseModel):
 text: str

class PredictionResponse(BaseModel):
 prediction: str
 probability: float

@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
 # Use the email text as a cache key
 cache_key = f"prediction:{data.text}"
 cached = await redis_client.get(cache_key)
 if cached:
 return json.loads(cached)

 # Run model inference in a thread to avoid blocking the event loop
 pred = await asyncio.to_thread(model.predict, [data.text])
 prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())

 result = {"prediction": str(pred[0]), "probability": float(prob)}

 # Cache the result for 1 hour (3600 seconds)
 await redis_client.setex(cache_key, 3600, json.dumps(result))
 return result

if __name__ == "__main__":
 import uvicorn
 uvicorn.run(app, host="0.0.0.0", port=8000)

python serve.py

INFO: Started server process [17640]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

可以通过访问URL查看REST API 文档。

本项目的源代码、配置文件、模型和数据集可在kingabzpro/Redis-ml-project GitHub代码库中找到。如果在运行上述代码时遇到任何问题,随时参阅

Redis缓存在机器学习应用中的工作原理

下面逐步解释Redis缓存在我们的机器学习应用程序中的运作方式,一张流程图加以说明:

  • 客户程序提交输入数据,请求机器学习模型进行预测。
  • 系统根据输入数据生成独特的标识符,以检查预测是否已存在。
  • 系统使用生成的键查询Redis缓存,以查找先前存储的预测。

A.如果找到缓存的预测,则检索该预测并以JSON响应的形式返回。

B.如果没有找到缓存的预测,则将输入数据传递给机器学习模型以生成新的预测。

  • 新生成的预测存储在Redis缓存中将来使用。
  • 最终结果以JSON格式返回给客户程序

测试网络钓鱼电子邮件分类应用程序

构建完网络钓鱼电子邮件分类应用程序后,就可以测试其功能了。我们在本节中使用 `cURL` 命令发送多封电子邮件并分析响应来评估该应用程序。此外,我们将验证Redis数据库,以确保缓存系统正常运行。

使用CURL命令测试 API

为了测试API,我们将向`/predict`端点发送五个请求。其中三个请求包含独特的电子邮件文本,另外两个请求是之前发送的电子邮件的复制版本。这将使我们能够验证预测准确性和缓存机制。

echo "
===== Testing API Endpoint with 5 Requests =====
"

# First unique email
echo "
----- Request 1 (First unique email) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Second unique email
echo "

----- Request 2 (Second unique email) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

# First duplicate (same as first email)
echo "

----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Third unique email
echo "

----- Request 4 (Third unique email) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'

# Second duplicate (same as second email)
echo "

----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

echo "

===== Test Complete =====
"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"

运行上述脚本时,API应该返回每封电子邮件的预测结果。对于重复的请求,响应应该从Redis缓存中加以检索,以确保更快的响应时间。

sh test.sh



===== Testing API Endpoint with 5 Requests =====


----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}

----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}

----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}

----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}

----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}

===== Test Complete =====

Now run 'python check_redis.py' to verify the Redis cache entries

验证Redis缓存

为了确认缓存系统正常运行,我们将使用Python脚本`check_redis.py`检查Redis数据库。该脚本检索缓存的预测结果并将其以表格形式显示出来。

import redis
import json
from tabulate import tabulate

def main():
 # Connect to Redis (ensure Redis is running on localhost:6379)
 redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_respnotallow=True)

 # Retrieve all keys that start with "prediction:"
 keys = redis_client.keys("prediction:*")
 total_entries = len(keys)
 print(f"Total number of cached prediction entries: {total_entries}
")

 table_data = []
 # Process only the first 5 entries
 for key in keys[:5]:
 # Remove the 'prediction:' prefix to get the original email text
 email_text = key.replace("prediction:", "", 1)

 # Retrieve the cached value
 value = redis_client.get(key)
 try:
 data = json.loads(value)
 except json.JSONDecodeError:
 data = {}

 prediction = data.get("prediction", "N/A")

 # Display only the first 7 words of the email text
 words = email_text.split()
 truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")

 table_data.append([truncated_text, prediction])

 # Print table using tabulate (only two columns now)
 headers = ["Email Text (First 7 Words)", "Prediction"]
 print(tabulate(table_data, headers=headers, tablefmt="pretty"))

if __name__ == "__main__":
 main()

运行check_redis.py脚本时,它会以表格形式显示缓存条目数量和已缓存的预测结果。

python check_redis.py


Total number of cached prediction entries: 3

+--------------------------------------------------+----------------+
| Email Text (First 7 Words) | Prediction | 
+--------------------------------------------------+----------------+
| congratulations you have won a free iphone,... | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
| todays floor meeting you may get a... | Safe Email |
+--------------------------------------------------+----------------+

结语

通过使用多个请求测试钓鱼邮件分类应用程序,我们成功地演示了该API能够准确识别钓鱼邮件,同时还能使用Redis高效地缓存重复请求。这种缓存机制通过减少重复输入的冗余计算显著提升了性能,这在API处理庞大流量的实际应用场景中尤其大有助益

虽然这是一个比较简单的机器学习模型,但在处理更庞大、更复杂的模型(比如图像识别)时,缓存的优势来得明显。比如说,如果在部署一个大规模图像分类模型,缓存频繁处理输入的预测结果可以节省大量计算资源,并显著缩短响应时间。

原文标题:Accelerate Machine Learning Model Serving with FastAPI and Redis Caching作者:Abid Ali Awan

本文地址:https://www.yitenyun.com/206.html

搜索文章

Tags

数据库 API FastAPI Calcite 电商系统 MySQL Web 应用 异步数据库 数据同步 ACK 双主架构 循环复制 JumpServer SSL 堡垒机 跳板机 HTTPS TIME_WAIT 运维 负载均衡 HexHub Docker JumpServer安装 堡垒机安装 Linux安装JumpServer Deepseek 宝塔面板 Linux宝塔 生命周期 esxi esxi6 root密码不对 无法登录 web无法登录 服务器 管理口 序列 核心机制 HTTPS加密 Windows Windows server net3.5 .NET 安装出错 服务器性能 查看硬件 Linux查看硬件 Linux查看CPU Linux查看内存 宝塔面板打不开 宝塔面板无法访问 开源 PostgreSQL 存储引擎 Windows宝塔 Mysql重置密码 Oracle 处理机制 无法访问宝塔面板 InnoDB 数据库锁 连接控制 机制 监控 Serverless 无服务器 语言 Spring Redis 异步化 ES 协同 SQL 优化 万能公式 技术 Undo Log group by 索引 缓存方案 缓存架构 缓存穿透 分页查询 高可用 动态查询 机器学习 GreatSQL 连接数 工具 响应模型 查询 日志文件 MIXED 3 scp Linux的scp怎么用 scp上传 scp下载 scp命令 SVM Embedding R edis 线程 R2DBC 锁机制 加密 场景 数据 主库 Netstat Linux 服务器 端口 openHalo Postgres OTel Iceberg 云原生 Linux 安全 RocketMQ 长轮询 配置 自定义序列化 存储 AI 助手 ​Redis 推荐模型 SQLite-Web SQLite 数据库管理工具 Recursive SQLark 共享锁 Hash 字段 PG DBA 向量数据库 大模型 电商 系统 Ftp OB 单机版 启动故障 国产数据库 架构 修改DNS Centos7如何修改DNS MySQL 9.3 • 索引 • 数据库 数据分类 防火墙 黑客 人工智能 推荐系统 流量 sftp 服务器 参数 线上 库存 预扣 磁盘架构 Rsync Python 业务 redo log 重做日志 分库 分表 同城 双活 信息化 智能运维 向量库 Milvus mini-redis INCR指令 MVCC 不宕机 PostGIS 传统数据库 向量化 行业 趋势 Canal 聚簇 非聚簇 高效统计 今天这篇文章就跟大家 INSERT COMPACT 缓存 Redisson 锁芯 网络架构 网络配置 Doris SeaTunnel filelock prometheus Alert 数据备份 事务 Java 开发 虚拟服务器 虚拟机 内存 ZODB 语句 Web 窗口 函数 MongoDB 数据结构 RDB AOF 读写 引擎 性能 数据脱敏 加密算法 容器 失效 OAuth2 Token IT运维 核心架构 订阅机制 Go 数据库迁移 数据类型 频繁 Codis B+Tree ID 字段 模型 Redis 8.0 自动重启 分布式 集中式 崖山 新版本 发件箱模式 容器化 网络故障 DBMS 管理系统 聚簇索引 非聚簇索引 播客 SpringAI SSH 微软 SQL Server AI功能 MCP 开放协议 QPS 高并发 JOIN Entity 数据页 Web 接口 部署 原子性 排行榜 排序 数据集成工具 Pottery 工具链 Testcloud 云端自动化 速度 服务器中毒 网络 StarRocks 数据仓库 Redka Caffeine CP 分页方案 排版 1 大表 业务场景 池化技术 连接池 事务隔离 分布式架构 分布式锁​ dbt 数据转换工具 悲观锁 乐观锁 主从复制 代理 日志 LRU AIOPS sqlmock 分页 优化器 EasyExcel MySQL8 单点故障 Order 意向锁 记录锁 仪表盘 事务同步 数据字典 兼容性 InfluxDB 对象 UUIDv7 主键 RAG HelixDB Ansible ReadView 订单 Crash 代码 单线程 UUID ID 双引擎 IT 字典 Weaviate LLM Valkey Valkey8.0 恢复数据 产业链 编程 分布式锁 Zookeeper MGR 分布式集群 千万级 线程安全 Pump List 类型 关系数据库 拦截器 动态代理 Next-Key 表空间 解锁 调优 慢SQL优化 快照读 当前读 视图 矢量存储 数据库类型 AI代理 国产 用户 RR 互联网 GitHub Git 神经系统 查询规划 算法 count(*) count(主键) 行数 技巧 CAS 并发控制 恢复机制 多线程 闪回