• 使用FastAPI和Redis Caching加快机器学习模型推理

使用FastAPI和Redis Caching加快机器学习模型推理

2025-05-14 08:37:03 栏目:宝塔面板 135 阅读

译者 | 布加迪

审校 | 重楼

Redis 是一款开源内存数据结构存储系统,是机器学习应用领域中缓存的优选。它的速度、耐用性以及支持各种数据结构使其成为满足实时推理任务的高吞吐量需求的理想选择。

我们在本教程中将探讨Redis缓存在机器学习工作流程中的重要性。我们将演示如何使用FastAPI和Redis构建一个强大的机器学习应用程序。本教程介绍如何在Windows上安装Redis、在本地运行Redis以及如何将其集成到机器学习项目中。最后,我们将通过发送重复请求和独特请求来测试该应用程序,以验证Redis缓存系统正常运行。

为什么在机器学习中使用Redis缓存?

在当今快节奏的数字环境中,用户期望机器学习应用程序能够立即获得结果。比如说,使用推荐模型向用户推荐产品的电商平台。如果实施Redis来缓存重复请求,该平台就可以显著缩短响应时间。

当用户请求产品推荐时,系统先检查该请求是否已被缓存。如果已缓存,则在几微秒内返回缓存的响应,从而提供无缝的体验。如果没有缓存,模型就处理该请求,生成推荐,并将结果存储在Redis中供将来的请求使用。这种方法不仅提高了用户满意度,还优化了服务器资源,使模型能够高效地处理更多请求。

使用Redis构建网络钓鱼电子邮件分类应用程序

我们在本项目中将构建一个网络钓鱼电子邮件分类应用程序。整个过程包括加载和处理来自Kaggle的数据集,使用处理后的数据训练机器学习模型,评估其性能,保存经过训练的模型,最构建带有Redis集成机制FastAPI应用程序。

1. 设置

  • Kaggle下载网络钓鱼电子邮件检测数据集并将其放入data/目录。
  • 首先需要安装Redis。在终端中运行以下命令安装Redis Python客户程序
pip install redis
  • 如果使用Windows系统,且未安装Windows Subsystem for Linux(WSL,请按照微软指南启用WSL,并微软商店安装Linux发行版(比如Ubuntu)。
  • WSL设置完成后,打开WSL终端并执行以下命令安装Redis
sudo apt update
sudo apt install redis-server
  • 要启动Redis服务器,请运行:
sudo service redis-server start

应该会看到一条确认消息,表明redis-server已成功启动。

2. 模型训练

训练脚本加载数据集、处理数据、训练模型并将其保存在本地。

import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

def main():
 # Load dataset
 df = pd.read_csv("data/Phishing_Email.csv") # adjust the path as necessary

 # Assume dataset has columns "text" and "label"
 X = df["Email Text"].fillna("")
 y = df["Email Type"]

 # Split the dataset into training and testing sets
 X_train, X_test, y_train, y_test = train_test_split(
 X, y, test_size=0.2, random_state=42
 )

 # Create a pipeline with TF-IDF and Logistic Regression
 pipeline = Pipeline(
 [
 ("tfidf", TfidfVectorizer(stop_words="english")),
 ("clf", LogisticRegression(solver="liblinear")),
 ]
 )

 # Train the model
 pipeline.fit(X_train, y_train)

 # Save the trained model to a file
 joblib.dump(pipeline, "phishing_model.pkl")
 print("Model trained and saved as phishing_model.pkl")

if __name__ == "__main__":
 main()


python train.py


Model trained and saved as phishing_model.pkl

3. 模型评估

评估脚本加载数据集和保存的模型文件以执行模型评估。

import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import joblib

def main():
 # Load dataset
 df = pd.read_csv("data/Phishing_Email.csv") # adjust the path as necessary

 # Assume dataset has columns "text" and "label"
 X = df["Email Text"].fillna("")
 y = df["Email Type"]

 # Split the dataset
 X_train, X_test, y_train, y_test = train_test_split(
 X, y, test_size=0.2, random_state=42
 )

 # Load the trained model
 model = joblib.load("phishing_model.pkl")

 # Make predictions on the test set
 y_pred = model.predict(X_test)

 # Evaluate the model
 print("Accuracy: ", accuracy_score(y_test, y_pred))
 print("Classification Report:")
 print(classification_report(y_test, y_pred))

if __name__ == "__main__":
 main()

结果近乎完美,F1分数也非常出色。

python validate.py

Accuracy: 0.9723860589812332
Classification Report:
 precision recall   f1-score support

Phishing Email 0.96 0.97 0.96 1457
 Safe Email 0.98 0.97 0.98 2273

 accuracy 0.97 3730
 macro avg 0.97 0.97 0.97 3730
 weighted avg   0.97 0.97 0.97 3730

4. 使用Redis提供模型服务

为了提供模型服务,我们将使用FastAPI创建REST API并集成Redis缓存预测。

import asyncio
import json
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis

# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_respnotallow=True)

# Load the trained model (synchronously)
model = joblib.load("phishing_model.pkl")

app = FastAPI()

# Define the request and response data models
class PredictionRequest(BaseModel):
 text: str

class PredictionResponse(BaseModel):
 prediction: str
 probability: float

@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
 # Use the email text as a cache key
 cache_key = f"prediction:{data.text}"
 cached = await redis_client.get(cache_key)
 if cached:
 return json.loads(cached)

 # Run model inference in a thread to avoid blocking the event loop
 pred = await asyncio.to_thread(model.predict, [data.text])
 prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())

 result = {"prediction": str(pred[0]), "probability": float(prob)}

 # Cache the result for 1 hour (3600 seconds)
 await redis_client.setex(cache_key, 3600, json.dumps(result))
 return result

if __name__ == "__main__":
 import uvicorn
 uvicorn.run(app, host="0.0.0.0", port=8000)

python serve.py

INFO: Started server process [17640]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

可以通过访问URL查看REST API 文档。

本项目的源代码、配置文件、模型和数据集可在kingabzpro/Redis-ml-project GitHub代码库中找到。如果在运行上述代码时遇到任何问题,随时参阅

Redis缓存在机器学习应用中的工作原理

下面逐步解释Redis缓存在我们的机器学习应用程序中的运作方式,一张流程图加以说明:

  • 客户程序提交输入数据,请求机器学习模型进行预测。
  • 系统根据输入数据生成独特的标识符,以检查预测是否已存在。
  • 系统使用生成的键查询Redis缓存,以查找先前存储的预测。

A.如果找到缓存的预测,则检索该预测并以JSON响应的形式返回。

B.如果没有找到缓存的预测,则将输入数据传递给机器学习模型以生成新的预测。

  • 新生成的预测存储在Redis缓存中将来使用。
  • 最终结果以JSON格式返回给客户程序

测试网络钓鱼电子邮件分类应用程序

构建完网络钓鱼电子邮件分类应用程序后,就可以测试其功能了。我们在本节中使用 `cURL` 命令发送多封电子邮件并分析响应来评估该应用程序。此外,我们将验证Redis数据库,以确保缓存系统正常运行。

使用CURL命令测试 API

为了测试API,我们将向`/predict`端点发送五个请求。其中三个请求包含独特的电子邮件文本,另外两个请求是之前发送的电子邮件的复制版本。这将使我们能够验证预测准确性和缓存机制。

echo "
===== Testing API Endpoint with 5 Requests =====
"

# First unique email
echo "
----- Request 1 (First unique email) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Second unique email
echo "

----- Request 2 (Second unique email) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

# First duplicate (same as first email)
echo "

----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'

# Third unique email
echo "

----- Request 4 (Third unique email) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'

# Second duplicate (same as second email)
echo "

----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' 
 'http://localhost:8000/predict' 
 -H 'accept: application/json' 
 -H 'Content-Type: application/json' 
 -d '{
 "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'

echo "

===== Test Complete =====
"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"

运行上述脚本时,API应该返回每封电子邮件的预测结果。对于重复的请求,响应应该从Redis缓存中加以检索,以确保更快的响应时间。

sh test.sh



===== Testing API Endpoint with 5 Requests =====


----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}

----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}

----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}

----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}

----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}

===== Test Complete =====

Now run 'python check_redis.py' to verify the Redis cache entries

验证Redis缓存

为了确认缓存系统正常运行,我们将使用Python脚本`check_redis.py`检查Redis数据库。该脚本检索缓存的预测结果并将其以表格形式显示出来。

import redis
import json
from tabulate import tabulate

def main():
 # Connect to Redis (ensure Redis is running on localhost:6379)
 redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_respnotallow=True)

 # Retrieve all keys that start with "prediction:"
 keys = redis_client.keys("prediction:*")
 total_entries = len(keys)
 print(f"Total number of cached prediction entries: {total_entries}
")

 table_data = []
 # Process only the first 5 entries
 for key in keys[:5]:
 # Remove the 'prediction:' prefix to get the original email text
 email_text = key.replace("prediction:", "", 1)

 # Retrieve the cached value
 value = redis_client.get(key)
 try:
 data = json.loads(value)
 except json.JSONDecodeError:
 data = {}

 prediction = data.get("prediction", "N/A")

 # Display only the first 7 words of the email text
 words = email_text.split()
 truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")

 table_data.append([truncated_text, prediction])

 # Print table using tabulate (only two columns now)
 headers = ["Email Text (First 7 Words)", "Prediction"]
 print(tabulate(table_data, headers=headers, tablefmt="pretty"))

if __name__ == "__main__":
 main()

运行check_redis.py脚本时,它会以表格形式显示缓存条目数量和已缓存的预测结果。

python check_redis.py


Total number of cached prediction entries: 3

+--------------------------------------------------+----------------+
| Email Text (First 7 Words) | Prediction | 
+--------------------------------------------------+----------------+
| congratulations you have won a free iphone,... | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
| todays floor meeting you may get a... | Safe Email |
+--------------------------------------------------+----------------+

结语

通过使用多个请求测试钓鱼邮件分类应用程序,我们成功地演示了该API能够准确识别钓鱼邮件,同时还能使用Redis高效地缓存重复请求。这种缓存机制通过减少重复输入的冗余计算显著提升了性能,这在API处理庞大流量的实际应用场景中尤其大有助益

虽然这是一个比较简单的机器学习模型,但在处理更庞大、更复杂的模型(比如图像识别)时,缓存的优势来得明显。比如说,如果在部署一个大规模图像分类模型,缓存频繁处理输入的预测结果可以节省大量计算资源,并显著缩短响应时间。

原文标题:Accelerate Machine Learning Model Serving with FastAPI and Redis Caching作者:Abid Ali Awan

本文地址:https://www.yitenyun.com/206.html

搜索文章

Tags

数据库 API FastAPI Calcite 电商系统 MySQL Web 应用 异步数据库 数据同步 ACK 双主架构 循环复制 TIME_WAIT 运维 负载均衡 JumpServer SSL 堡垒机 跳板机 HTTPS HexHub Docker 服务器 服务器性能 管理口 JumpServer安装 堡垒机安装 Linux安装JumpServer Deepseek 宝塔面板 Linux宝塔 生命周期 esxi esxi6 root密码不对 无法登录 web无法登录 SQL 查询 序列 核心机制 Windows Windows server net3.5 .NET 安装出错 HTTPS加密 Windows宝塔 Mysql重置密码 开源 PostgreSQL 存储引擎 锁机制 宝塔面板打不开 宝塔面板无法访问 查看硬件 Linux查看硬件 Linux查看CPU Linux查看内存 行业 趋势 Oracle 处理机制 无法访问宝塔面板 Undo Log 机制 监控 Spring Redis 异步化 InnoDB 数据库锁 连接控制 机器学习 优化 万能公式 动态查询 Serverless 无服务器 语言 响应模型 ES 协同 group by 索引 技术 openHalo 分页查询 scp Linux的scp怎么用 scp上传 scp下载 scp命令 Postgres OTel Iceberg 缓存方案 缓存架构 缓存穿透 工具 存储 高可用 GreatSQL 连接数 数据 主库 SVM Embedding R edis 线程 日志文件 MIXED 3 Linux 安全 国产数据库 R2DBC SQLite-Web SQLite 数据库管理工具 加密 场景 Netstat Linux 服务器 端口 启动故障 ​Redis 推荐模型 Recursive 自定义序列化 防火墙 黑客 云原生 RocketMQ 长轮询 配置 SQLark 向量数据库 大模型 共享锁 AI 助手 OB 单机版 Hash 字段 PG DBA Rsync 信息化 智能运维 Ftp 不宕机 磁盘架构 架构 电商 系统 数据分类 向量库 Milvus Canal Python 业务 IT运维 流量 修改DNS Centos7如何修改DNS 分库 分表 传统数据库 向量化 • 索引 • 数据库 线上 库存 预扣 filelock MVCC 人工智能 推荐系统 语句 MySQL 9.3 sftp 服务器 参数 redo log 重做日志 同城 双活 聚簇 非聚簇 PostGIS mini-redis INCR指令 MongoDB MCP 开放协议 频繁 Codis 失效 Doris SeaTunnel 缓存 Redisson 锁芯 高效统计 今天这篇文章就跟大家 主从复制 代理 数据类型 虚拟服务器 虚拟机 内存 工具链 事务 Java 开发 prometheus Alert 数据备份 千万级 大表 INSERT COMPACT 窗口 函数 数据结构 ZODB 发件箱模式 SSH 容器 网络架构 网络配置 EasyExcel MySQL8 引擎 性能 Web 分布式架构 分布式锁​ 聚簇索引 非聚簇索引 QPS 高并发 数据脱敏 加密算法 崖山 新版本 核心架构 订阅机制 Go 数据库迁移 B+Tree ID 字段 分布式 集中式 RDB AOF 分页 速度 服务器中毒 Web 接口 Redis 8.0 数据集成工具 读写 自动重启 网络故障 播客 OAuth2 Token 数据页 容器化 模型 StarRocks 数据仓库 池化技术 连接池 微软 SQL Server AI功能 Redka DBMS 管理系统 排行榜 排序 SpringAI JOIN MGR 分布式集群 Caffeine CP 部署 原子性 Entity 事务隔离 网络 业务场景 LRU 数据字典 兼容性 Valkey Valkey8.0 Pottery dbt 数据转换工具 分页方案 排版 Testcloud 云端自动化 优化器 ReadView 事务同步 sqlmock 1 关系数据库 意向锁 记录锁 悲观锁 乐观锁 日志 单线程 UUIDv7 主键 Weaviate 对象 单点故障 AIOPS 仪表盘 UUID ID Order 编程 InfluxDB Pump Ansible Crash 代码 RAG HelixDB 产业链 IT 双引擎 分布式锁 Zookeeper 恢复数据 字典 订单 LLM List 类型 线程安全 国产 用户 慢SQL优化 表空间 拦截器 动态代理 解锁 调优 Next-Key RR 互联网 GitHub Git 快照读 当前读 视图 神经系统 矢量存储 数据库类型 AI代理 查询规划 count(*) count(主键) 行数 算法 CAS 技巧 多线程 并发控制 恢复机制 闪回