使用FastAPI和Redis缓存加速机器学习模型服务
译者 | 李睿
审校 | 重楼
本文介绍了如何使用FastAPI和Redis缓存加速机器学习模型服务。FastAPI作为高性能Web框架用于构建API,Redis作为内存中的数据结构存储系统作为缓存层。通过集成FastAPI和Redis,系统能快速响应重复请求,避免冗余计算,显著降低延迟和CPU负载。此外还详细阐述了实现步骤,包括加载模型、创建FastAPI端点、设置Redis缓存及测试性能提升。
你是否因为等待机器学习模型返回预测结果而耗费过长时间?很多人都有过这样的经历。机器学习模型在实时服务时可能会非常缓慢,尤其是那些大型且复杂的机器学习模型。另一方面,用户希望得到即时反馈。因此,这使得延迟问题愈发凸显。从技术层面来看,最主要的问题之一是当相同的输入反复触发相同的缓慢过程时,会出现冗余计算。本文将展示如何解决这个问题,因此将构建一个基于FastAPI的机器学习服务,并集成Redis缓存,以便在毫秒级的时间内迅速返回重复的预测结果。
什么是FastAPI?
FastAPI是一个基于Python的现代Web框架,用于构建 API。它使用Python的类型提示进行数据验证,并使用Swagger UI和ReDoc自动生成交互式API文档。FastAPI基于Starlette和Pydantic构建,支持异步编程,使其性能可与Node.js和Go相媲美。其设计有助于快速开发健壮的、生产就绪的API,使其成为将机器学习模型部署为可扩展的RESTful服务的绝佳选择。
什么是Redis?
Redis(Remote Dictionary Server)是一个开源的内存数据结构存储系统,其功能包括数据库、缓存和消息代理。通过将数据存储在内存中,Redis为读写操作提供了超低延迟,使其成为缓存频繁或计算密集型任务(例如机器学习模型预测)的理想选择。它支持各种数据结构,包括字符串、列表、集合和散列,并提供密钥过期(TTL)等功能,以实现高效的缓存管理。
为什么要结合FastAPI和Redis?
将FastAPI与Redis集成,可以创建一个响应速度快、效率高的系统。FastAPI作为处理API请求的快速且可靠的接口,而Redis则作为缓存层可以存储之前计算的结果。当再次接收到相同的输入时,可以立即从Redis检索结果,无需重新计算。这种方法降低了延迟,减轻了计算负载,并提高了应用程序的可扩展性。在分布式环境中,Redis充当可由多个FastAPI实例访问的集中式缓存,使其非常适用于生产级机器学习部署。
接下来,深入了解如何实现一个使用Redis缓存提供机器学习模型预测的FastAPI应用程序。这种设置能够确保针对相同输入的重复请求能够迅速从缓存中获取服务,从而大幅减少计算时间,并缩短响应时间。其实现步骤如下:
- 加载预训练模型
- 创建FastAPI预测端点
- 设置Redis缓存
- 测试和衡量性能提升
以下详细地了解这些步骤。
步骤1:加载预训练模型
首先,假设已经拥有一个训练有素的机器学习模型,并准备将其投入部署。在实际应用中,大多数机器学习模型都是离线训练的(例如scikit-learn模型,TensorFlow/Pytorch模型等),并保存到磁盘中,然后加载到服务应用程序中。在这个示例中,将创建一个简单的scikit-learn分类器,它将在著名的Iris flower数据集上进行训练,并使用joblib库保存。如果已经保存了一个模型文件,可以跳过训练步骤直接加载它。以下介绍如何训练一个模型,然后加载它进行服务:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load example dataset and train a simple model (Iris classification)
X, y = load_iris(return_X_y=True)
# Train the model
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
# Load the pre-trained model from disk (using the saved file)
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")
在以上的代码中,使用了scikit-learn的内置Iris数据集训练了一个随机森林分类器,然后将该模型保存到一个名为model.joblib的文件中。之后,使用joblib.load方法将其重新加载。joblib库在保存scikit-learn模型时非常常见,主要是因为它擅长处理模型内的NumPy数组。随后,就有了一个可以预测新数据的模型对象。不过需要注意的是,在这里使用任何预训练的模型,使用FastAPI提供服务的方式以及缓存的结果或多或少是相同的。唯一的问题是,模型应该有一个预测方法,该方法接受一些输入并产生结果。此外,确保每次输入相同的数据时,都能给出一致的预测结果(即模型需具备确定性)。如果不是这样,缓存对于非确定性模型来说将会出现问题,因为它将返回不正确的结果。
步骤2:创建FastAPI预测端点
现在已经有了一个训练好的模型,可以通过API来使用它。我们将使用FASTAPI创建一个Web服务器来处理预测请求。FASTAPI可以很容易地定义端点并将请求参数映射到Python函数参数。在这个示例中,将假设模型需要四个特征作为输入。并将创建一个GET端点/预测,该端点/预测接受这些特征作为查询参数并返回模型的预测。
from fastapi import FastAPI
import joblib
app = FastAPI()
# Load the trained model at startup (to avoid re-loading on every request)
model = joblib.load("model.joblib") # Ensure this file exists from the training step
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
# Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features])
features = [[sepal_length, sepal_width, petal_length, petal_width]]
# Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species)
prediction = model.predict(features)[0] # Get the first (only) prediction
return {"prediction": str(prediction)}
在以上代码中,成功创建了一个FastAPI应用程序,并在执行该文件之后启动API服务器。FastAPI对于Python来说非常快,因此它可以轻松地处理大量请求。为避免在每次请求时都重复加载模型(这一操作会显著降低性能),在程序启动时就将模型加载到内存中,以便随时调用。随后使用@app创建了一个/predict端点。GET使测试变得简单,因为可以在URL中传递内容,但在实际项目中,可能会想要使用POST,特别是在发送大型或复杂的输入(如图像或JSON)时。
这个函数接受4个输入参数:sepal_length、sepal_width、petal_length和petal_width, FastAPI会自动从URL中读取它们。在函数内部,将所有输入放入一个2D列表中(因为scikit-learn只接受二维数组作为输入),然后调用model.predict(),它会返回一个列表。然后将其作为JSON返回,例如{ “prediction”: “...”}。
该系统现在已经能正常运行,可以使用uvicorn main:app–reload命令运行它,然后访问 /predict 端点并获取结果。然而,再次发送相同的输入,它仍然会再次运行模型,这显然不够高效,所以下一步是添加Redis来缓存之前的结果,从而避免重复计算。
步骤3:设置Redis缓存
为了缓存模型输出,将使用Redis。首先,确保Redis服务器正在运行。你可以在本地安装,或者直接运行Docker容器;在默认情况下,它通常运行在端口6379上,并使用Python Redis库与服务器通信。
所以,其思路很简单:当请求进来时,创建一个表示输入的唯一键。然后检查该键是否存在于Redis中;如果那个键已经存在,这意味着之前已经缓存了这个,所以只返回保存的结果,不需要再次调用模型。如果没有,则执行model.predict,获得输出,将其保存在Redis中,并返回预测。
现在更新FastAPI应用程序来添加这个缓存逻辑。
!pip install redis
import redis # New import to use Redis
# Connect to a local Redis server (adjust host/port if needed)
cache = redis.Redis(host="localhost", port=6379, db=0)
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
"""
Predict the species, with caching to speed up repeated predictions.
"""
# 1. Create a unique cache key from input parameters
cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
# 2. Check if the result is already cached in Redis
cached_val = cache.get(cache_key)
if cached_val:
# If cache hit, decode the bytes to a string and return the cached prediction
return {"prediction": cached_val.decode("utf-8")}
# 3. If not cached, compute the prediction using the model
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0]
# 4. Store the result in Redis for next time (as a string)
cache.set(cache_key, str(prediction))
# 5. Return the freshly computed prediction
return {"prediction": str(prediction)}
在以上的代码中添加了Redis。首先,使用redis.Redis()创建了一个客户端,它连接到Redis服务器。在默认情况下使用db=0。然后,通过连接输入值来创建一个缓存键。在这里,它之所以有效,是因为输入是简单的数字,但对于复杂的数字,最好使用散列或JSON字符串。每个输入的键必须是唯一的。因此使用了cache.get(cache_key)。如果它找到相同的键,它就返回这个键,这使其速度更快,并且不需要重新运行模型。但是如果在缓存中没有找到,需要运行模型并获得预测结果。最后,使用cache.set()保存在Redis中。而当相同的输入在下次到来时,因为它已经存在,因为缓存将会很快。
步骤4:测试和衡量性能提升
现在,FastAPI应用程序正在运行并连接到Redis,现在是测试缓存如何提高响应时间的时候了。在这里,演示如何使用Python的请求库使用相同的输入两次调用API,并衡量每次调用所花费的时间。此外,需要确保在运行测试代码之前启动FastAPI:
import requests, time
# Sample input to predict (same input will be used twice to test caching)
params = {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
# First request (expected to be a cache miss, will run the model)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")
# Second request (same params, expected cache hit, no model computation)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")
当运行这个命令时,应该看到第一个请求返回一个结果。然后第二个请求返回相同的结果,但明显速度更快。例如,可能会发现第一次调用花费了几十毫秒的时间(取决于模型的复杂性),而第二次调用可能只有几毫秒或更少的时间。在使用轻量级模型的简单演示中,差异可能很小(因为模型本身速度很快),但对于更大的模型来说,其效果非常显著。
比较
为了更好地理解这一点,可以了解一下取得的成果:
- 无缓存:每个请求,即使是相同的请求,都会命中模型。如果模型每次预测需要100毫秒,那么10个相同的请求仍然需要约1000毫秒。
- 使用缓存:第一个请求需要全部命中(100毫秒),但接下来的9个相同的请求可能每个需要1~2毫秒(只是一个Redis查找和返回数据)。因此,这10个请求可能总共120毫秒,而不是1000毫秒,在这种情况下,速度提高了8倍。
在实际实验中,缓存可以带来数量级的性能提升。例如,在电子商务领域中,使用Redis意味着在微秒内返回重复请求的建议,而不必使用完整的模型服务管道重新计算它们。性能提升将取决于模型推理的成本。模型越复杂,从缓存重复调用中的收益越大。这也取决于请求模式:如果每个请求都是唯一的,缓存将无法发挥作用(没有重复请求可以从内存中提供服务),但是许多应用程序确实会看到重叠的请求(例如,流行的搜索查询,推荐的项目等)。
为了验证Redis缓存是否正常存储键值对可以直接对Redis缓存进行检查。
结论
本文展示了FastAPI和Redis如何协同工作以加速机器学习模型服务。FastAPI提供了一个快速且易于构建的API层用于提供预测服务,Redis添加了一个缓存层,可以显著减少重复计算的延迟和CPU负载。通过避免重复的模型调用,提高了响应速度,并使系统能够使用相同的资源处理更多的请求。
原文标题:Accelerate Machine Learning Model Serving With FastAPI and Redis Caching,作者:Janvi Kumari