本文最后更新于 34 天前,其中的信息可能已经有所发展或是发生改变。
本文以 m3e 文本相关性查询为例，演示如何用 FastAPI 和 FAISS 搭建一个文本向量检索接口。
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss
# Create the FastAPI application.
app = FastAPI()
# Load the SentenceTransformer model once at import time from a local directory.
# NOTE(review): this absolute Windows path is machine-specific — adjust it
# before running the service on another machine.
model_path = "C:\\Users\\1\\.cache\\modelscope\\hub\\models\\xrunda\\m3e-base"
model = SentenceTransformer(model_path)
# Request and response schemas: the fields each JSON payload must carry.
class TextRequest(BaseModel):
    # Query string that is embedded and searched for.
    text: str
    # Candidate strings that are embedded and indexed for the search.
    texts: list
class EmbeddingResponse(BaseModel):
    # Embedding of the query text, converted to nested Python lists.
    embeddings: list
    # Positions (into the request's `texts`) returned by the index search.
    indices: list
    # Scores returned by the index search, aligned with `indices`.
    distances: list
    # Candidate texts picked out of the request's `texts` by `indices`.
    most_relevant: list
# POST endpoint: rank the submitted candidate texts against the query text.
@app.post("/embedding", response_model=EmbeddingResponse)
async def get_embedding(request: TextRequest):
    """Embed the query and return the candidate texts most similar to it.

    The candidates in ``request.texts`` are encoded with normalized
    embeddings and added to a fresh inner-product FAISS index; the query
    ``request.text`` is encoded the same way and searched against that index.

    Args:
        request: Payload carrying the query ``text`` and candidate ``texts``.

    Returns:
        EmbeddingResponse holding the query embedding, the hit indices and
        scores, and the matching candidate texts (at most two).

    Raises:
        HTTPException: 400 when ``texts`` is empty; 500 when encoding or
            searching fails.
    """
    # NOTE(review): model.encode and the FAISS calls below are synchronous
    # and are not awaited, so they presumably occupy the event loop for the
    # duration of the request — confirm whether a plain ``def`` endpoint
    # would suit the deployment better.
    candidates = request.texts

    # An empty candidate list cannot be indexed or searched meaningfully.
    # Raised outside the ``try`` so the generic handler does not rewrap it
    # as a 500.
    if not candidates:
        raise HTTPException(status_code=400, detail="texts must not be empty")

    try:
        # Build a per-request index over the candidate embeddings.
        embeddings = model.encode(candidates, normalize_embeddings=True).astype('float32')
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)

        # Encode the query text.
        query_embed = model.encode([request.text], normalize_embeddings=True).astype('float32')

        # Never request more neighbours than there are candidates: FAISS pads
        # missing results with label -1, which a plain ``idx < len(texts)``
        # check lets through and Python then resolves to the *last* text.
        top_k = min(2, len(candidates))
        D, I = index.search(query_embed, k=top_k)

        hits = I[0].tolist()
        # Defensive bounds check on both sides; negative labels are skipped.
        most_relevant = [candidates[i] for i in hits if 0 <= i < len(candidates)]

        return EmbeddingResponse(
            embeddings=query_embed.tolist(),
            indices=hits,
            distances=D[0].tolist(),
            most_relevant=most_relevant,
        )
    except Exception as e:
        # Surface any encoding/search failure as a 500, keeping the cause chained.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Start the FastAPI service when this file is executed directly.
if __name__ == "__main__":
    import uvicorn

    # Listen on the loopback interface only, port 8000.
    uvicorn.run(app, port=8000, host="127.0.0.1")