章
目
录
今天带大伙实现RAG(检索增强生成)里embedding模型处理文本向量化的过程,搞清楚embedding搜索功能到底是咋回事儿!
一、embedding到底是啥?
在RAG架构里,embedding可是实现文本向量化的关键一环。简单来说,它的核心操作就是把自然语言文本变成高维向量。为啥要这么干呢?因为有了这些向量,咱们就能实现基于语义的搜索啦。
打个比方,我们先把资料库中的文本,像文章标题、分类信息啥的,通过embedding模型转换成向量形式。用户提出问题时,也用同样的方法把问题转成向量。然后通过计算这两个向量之间的相似度,就能找到最符合用户意图的文本内容。这可比传统的搜索方式智能多了!要是想深入了解RAG,可以去看看“AI全栈必问的RAG是什么!”这篇文章。
二、简化版embedding实现流程
接下来,我给大家详细讲讲如何快速实现一个简化版的embedding应用,这里面涉及后端环境搭建、模型封装、文件读写,还有跨域处理这些关键步骤。
1. 环境初始化与模型封装
咱先从初始化后端Node.js环境开始,在命令行里敲上npm init -y
,就能快速完成初始化。这和之前封装openai的操作有点类似,这次我们要封装的是embedding模型。在这个过程中,推荐大伙用dotenv
模块来保护自己的API key,防止信息泄露,安全问题可不能马虎!
// 引入OpenAI和dotenv模块
import OpenAI from 'openai';
import dotenv from 'dotenv';
// 加载环境变量配置文件
dotenv.config({
path: '.env'
});
// 创建OpenAI实例,设置API key和baseURL
export const client = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
baseURL: process.env.OPENAI_API_BASE_URL,
});
2. 读写文件及调用embedding模型
文件读写这块,我们用fs/promises
模块来操作,它能帮我们避免回调地狱,让代码逻辑更清晰。再结合async/await
语法,代码读起来就更顺畅了。
我们从posts.json
文件里读取那些需要向量化的文章数据,然后调用embedding模型,生成对应的向量,最后把这些结果存到新文件里。这里的数据格式可以自己模拟,像下面这样:
[
{
"title": "如何使用Nuxt.js进行服务器端渲染",
"category": "前端开发"
},
// 其他数据可以仿照这个格式自行补充
]
文件目录结构大概是这样:
ai-server
├── data
│ ├── posts_with_embeddings.json
│ └── posts.json
├── node_modules
├──.env
├── app.service.mjs
├── create-embedding.mjs
├── index.mjs
├── package.json
└── pnpm-lock.yaml
具体代码实现如下:
// 引入fs/promises模块和之前创建的client实例
import fs from 'fs/promises';
import { client } from './app.service.mjs';
// 定义输入输出文件路径
const inputFilePath = './data/posts.json';
const outputFilePath = './data/posts_with_embeddings.json';
// 异步读取数据文件,并将其解析为JSON格式
const data = await fs.readFile(inputFilePath, 'utf8');
const posts = JSON.parse(data);
// 用于存储带有embedding向量的文章数据
const postsWithEmbedding = [];
// 遍历每篇文章,生成embedding向量
for (const { title, category } of posts) {
const response = await client.embeddings.create({
model: 'text-embedding-ada-002',
// 将文章标题和分类拼接作为输入
input: `标题:${title};分类:${category}`
});
postsWithEmbedding.push({
title,
category,
// 提取生成的embedding向量
embedding: response.data[0].embedding
});
}
// 将生成embedding的结果写入到新文件中
await fs.writeFile(outputFilePath, JSON.stringify(postsWithEmbedding));
3. 构建后端服务并实现搜索接口
这里我们用Koa框架来搭建服务,借助@koa/cors
处理跨域问题。因为前端传值一般是JSON格式,所以还得引入koa-bodyparser
来自动解析请求体。
下面这段代码实现了监听3000端口,并且创建了一个/search
接口。这个接口的作用就是接收查询关键字,生成向量,计算余弦相似度,最后返回最匹配的结果。
// 引入Koa、cors、Router、bodyParser等模块,以及之前创建的client实例和fs/promises模块
import Koa from 'koa';
import cors from '@koa/cors';
import Router from 'koa-router';
import bodyParser from 'koa-bodyparser';
import { client } from './app.service.mjs';
import fs from 'fs/promises';
// 定义存储带有embedding向量数据的文件路径
const inputFilePath = './data/posts_with_embeddings.json';
// 读取文件数据并解析为JSON格式
const data = await fs.readFile(inputFilePath, 'utf8');
const posts = JSON.parse(data);
// 创建Koa应用实例和Router实例
const app = new Koa();
const router = new Router();
// 设置服务监听端口
const port = 3000;
// 使用cors和bodyParser中间件
app.use(cors());
app.use(bodyParser());
// 使用路由处理请求
app.use(router.routes());
app.use(router.allowedMethods());
// 监听服务启动,打印提示信息
app.listen(port, () => {
console.log(`Server is running on port ${port}`);
});
// 计算余弦相似度的函数
function cosineSimilarity(a, b) {
if (a.length!== b.length) {
throw new Error('向量长度不匹配');
}
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
// 定义搜索路由
router.post('/search', async (ctx) => {
const { keword } = ctx.request.body; // 从请求体中获取关键字
console.log(keword);
// 生成查询关键字的embedding向量
const response = await client.embeddings.create({
model: 'text-embedding-ada-002',
input: keword,
});
const { embedding } = response.data[0]; // 获取生成的向量
// 计算每篇文章与查询向量的相似度
const results = posts.map(item => ({
...item,
similarity: cosineSimilarity(embedding, item.embedding)
}));
// 按相似度降序排序,并提取最相似的前三条记录
const topResults = results.sort((a, b) => b.similarity - a.similarity)
.slice(0, 3)
.map((item, index) => ({
id: index,
title: `${index + 1}.${item.title}, ${item.category}`
}));
ctx.body = {
status: 200,
data: topResults
};
});
这里有个小细节要注意,sort
方法会返回一个新数组,所以不能直接用data:results
。可以在原results
上链式调用sort
,也可以用topResults
接收新值再传给data
。
三、余弦相似度函数解析
上面代码里的余弦相似度函数,用来衡量两个向量在方向上的相似程度。它的取值范围一般在 -1到1之间,对于正向量来说,通常在0到1之间。这个值越接近1,就表示两个向量在空间中的方向越接近,语义上也就越相关;值越低,说明两个向量在语义上越不相关。
function cosineSimilarity(a, b) {
if (a.length!== b.length) {
throw new Error('向量长度不匹配');
}
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
具体实现步骤是这样的:
- 判断长度是否一致:先检查两个向量的长度,如果不一样,就抛出错误,因为长度不同的向量没法计算相似度。
- 计算点积:遍历向量,把对应位置的元素相乘,再把这些乘积加起来,得到点积。
- 计算向量模:分别计算向量
a
和向量b
各元素的平方和,然后对平方和开平方,得到两个向量的模。 - 返回余弦相似度:把点积除以两个向量模的乘积,得到的结果就是两向量之间的相似度。
四、CORS配置扩展
默认情况下,我们允许所有跨域请求。但有时候,我们需要更细致地控制跨域访问,比如设置允许跨域的源、方法、请求头,以及是否允许携带凭据。下面这段代码就展示了具体的配置方法:
// 配置CORS
app.use(cors({
origin: (ctx) => {
const allowedOrigins = ['http://localhost:3000', 'http://example.com'];
const requestOrigin = ctx.request.header.origin;
if (allowedOrigins.includes(requestOrigin)) {
return requestOrigin; // 允许该来源
}
return ''; // 拒绝跨域请求
},
allowMethods: ['GET', 'POST'], // 允许的HTTP方法
allowHeaders: ['Content-Type', 'Authorization'], // 允许的请求头
credentials: true // 允许携带凭据
}));
五、小结
通过上面这些代码示例,咱们一步步实现了利用embedding模型进行文本向量化,再结合余弦相似度计算,完成了基于自然语义的搜索功能。这种搜索方式比传统的搜索更精准,能处理各种复杂的文本匹配场景。