如何实现embedding搜索功能

人工智能 潘老师 1个月前 (03-23) 35 ℃ (0) 扫码查看

今天带大伙实现RAG(检索增强生成)里embedding模型处理文本向量化的过程,搞清楚embedding搜索功能到底是咋回事儿!

一、embedding到底是啥?

在RAG架构里,embedding可是实现文本向量化的关键一环。简单来说,它的核心操作就是把自然语言文本变成高维向量。为啥要这么干呢?因为有了这些向量,咱们就能实现基于语义的搜索啦。

打个比方,我们先把资料库中的文本,像文章标题、分类信息啥的,通过embedding模型转换成向量形式。用户提出问题时,也用同样的方法把问题转成向量。然后通过计算这两个向量之间的相似度,就能找到最符合用户意图的文本内容。这可比传统的搜索方式智能多了!要是想深入了解RAG,可以去看看“AI全栈必问的RAG是什么!”这篇文章。

二、简化版embedding实现流程

接下来,我给大家详细讲讲如何快速实现一个简化版的embedding应用,这里面涉及后端环境搭建、模型封装、文件读写,还有跨域处理这些关键步骤。

1. 环境初始化与模型封装

咱先从初始化后端Node.js环境开始,在命令行里敲上npm init -y,就能快速完成初始化。这和之前封装openai的操作有点类似,这次我们要封装的是embedding模型。在这个过程中,推荐大伙用dotenv模块来保护自己的API key,防止信息泄露,安全问题可不能马虎!

// 引入OpenAI和dotenv模块
import OpenAI from 'openai';
import dotenv from 'dotenv';
// 加载环境变量配置文件
dotenv.config({
  path: '.env'
});
// 创建OpenAI实例,设置API key和baseURL
export const client = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY,
  baseURL: process.env.OPENAI_API_BASE_URL,
});

2. 读写文件及调用embedding模型

文件读写这块,我们用fs/promises模块来操作,它能帮我们避免回调地狱,让代码逻辑更清晰。再结合async/await语法,代码读起来就更顺畅了。

我们从posts.json文件里读取那些需要向量化的文章数据,然后调用embedding模型,生成对应的向量,最后把这些结果存到新文件里。这里的数据格式可以自己模拟,像下面这样:

[
    {
      "title": "如何使用Nuxt.js进行服务器端渲染",
      "category": "前端开发"
    },
    // 其他数据可以仿照这个格式自行补充
  ]

文件目录结构大概是这样:

ai-server
├── data
│   ├── posts_with_embeddings.json
│   └── posts.json
├── node_modules
├──.env
├── app.service.mjs
├── create-embedding.mjs
├── index.mjs
├── package.json
└── pnpm-lock.yaml

具体代码实现如下:

// 引入fs/promises模块和之前创建的client实例
import fs from 'fs/promises';
import { client } from './app.service.mjs';
// 定义输入输出文件路径
const inputFilePath = './data/posts.json';
const outputFilePath = './data/posts_with_embeddings.json';
// 异步读取数据文件,并将其解析为JSON格式
const data = await fs.readFile(inputFilePath, 'utf8');
const posts = JSON.parse(data);
// 用于存储带有embedding向量的文章数据
const postsWithEmbedding = [];
// 遍历每篇文章,生成embedding向量
for (const { title, category } of posts) {
    const response = await client.embeddings.create({
        model: 'text-embedding-ada-002',
        // 将文章标题和分类拼接作为输入
        input: `标题:${title};分类:${category}`
    });
    postsWithEmbedding.push({
        title,
        category,
        // 提取生成的embedding向量
        embedding: response.data[0].embedding
    });
}
// 将生成embedding的结果写入到新文件中
await fs.writeFile(outputFilePath, JSON.stringify(postsWithEmbedding));

3. 构建后端服务并实现搜索接口

这里我们用Koa框架来搭建服务,借助@koa/cors处理跨域问题。因为前端传值一般是JSON格式,所以还得引入koa-bodyparser来自动解析请求体。

下面这段代码实现了监听3000端口,并且创建了一个/search接口。这个接口的作用就是接收查询关键字,生成向量,计算余弦相似度,最后返回最匹配的结果。

// 引入Koa、cors、Router、bodyParser等模块,以及之前创建的client实例和fs/promises模块
import Koa from 'koa';
import cors from '@koa/cors';
import Router from 'koa-router';
import bodyParser from 'koa-bodyparser';
import { client } from './app.service.mjs';
import fs from 'fs/promises';
// 定义存储带有embedding向量数据的文件路径
const inputFilePath = './data/posts_with_embeddings.json';
// 读取文件数据并解析为JSON格式
const data = await fs.readFile(inputFilePath, 'utf8');
const posts = JSON.parse(data);
// 创建Koa应用实例和Router实例
const app = new Koa();
const router = new Router();
// 设置服务监听端口
const port = 3000;
// 使用cors和bodyParser中间件
app.use(cors());
app.use(bodyParser());
// 使用路由处理请求
app.use(router.routes());
app.use(router.allowedMethods());
// 监听服务启动,打印提示信息
app.listen(port, () => {
  console.log(`Server is running on port ${port}`);
});
// 计算余弦相似度的函数
function cosineSimilarity(a, b) {
  if (a.length!== b.length) {
      throw new Error('向量长度不匹配');
  }
  let dotProduct = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
  }
  return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
// 定义搜索路由
router.post('/search', async (ctx) => {
  const { keword } = ctx.request.body; // 从请求体中获取关键字
  console.log(keword);
  // 生成查询关键字的embedding向量
  const response = await client.embeddings.create({
    model: 'text-embedding-ada-002',
    input: keword,
  });
  const { embedding } = response.data[0]; // 获取生成的向量
  // 计算每篇文章与查询向量的相似度
  const results = posts.map(item => ({
  ...item,
    similarity: cosineSimilarity(embedding, item.embedding)
  }));
  // 按相似度降序排序,并提取最相似的前三条记录 
  const topResults = results.sort((a, b) => b.similarity - a.similarity)
  .slice(0, 3)
  .map((item, index) => ({
      id: index,
      title: `${index + 1}.${item.title}, ${item.category}`
    }));
  ctx.body = {
    status: 200,
    data: topResults
  };
});

这里有个小细节要注意,sort方法会返回一个新数组,所以不能直接用data:results。可以在原results上链式调用sort,也可以用topResults接收新值再传给data

三、余弦相似度函数解析

上面代码里的余弦相似度函数,用来衡量两个向量在方向上的相似程度。它的取值范围一般在 -1到1之间,对于正向量来说,通常在0到1之间。这个值越接近1,就表示两个向量在空间中的方向越接近,语义上也就越相关;值越低,说明两个向量在语义上越不相关。

function cosineSimilarity(a, b) {
  if (a.length!== b.length) {
      throw new Error('向量长度不匹配');
  }
  let dotProduct = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
  }
  return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}

具体实现步骤是这样的:

  • 判断长度是否一致:先检查两个向量的长度,如果不一样,就抛出错误,因为长度不同的向量没法计算相似度。
  • 计算点积:遍历向量,把对应位置的元素相乘,再把这些乘积加起来,得到点积。
  • 计算向量模:分别计算向量a和向量b各元素的平方和,然后对平方和开平方,得到两个向量的模。
  • 返回余弦相似度:把点积除以两个向量模的乘积,得到的结果就是两向量之间的相似度。

四、CORS配置扩展

默认情况下,我们允许所有跨域请求。但有时候,我们需要更细致地控制跨域访问,比如设置允许跨域的源、方法、请求头,以及是否允许携带凭据。下面这段代码就展示了具体的配置方法:

// 配置CORS
app.use(cors({
  origin: (ctx) => {
    const allowedOrigins = ['http://localhost:3000', 'http://example.com'];
    const requestOrigin = ctx.request.header.origin;
    if (allowedOrigins.includes(requestOrigin)) {
      return requestOrigin; // 允许该来源
    }
    return ''; // 拒绝跨域请求
  },
  allowMethods: ['GET', 'POST'], // 允许的HTTP方法
  allowHeaders: ['Content-Type', 'Authorization'], // 允许的请求头
  credentials: true // 允许携带凭据
}));

五、小结

通过上面这些代码示例,咱们一步步实现了利用embedding模型进行文本向量化,再结合余弦相似度计算,完成了基于自然语义的搜索功能。这种搜索方式比传统的搜索更精准,能处理各种复杂的文本匹配场景。


版权声明:本站文章,如无说明,均为本站原创,转载请注明文章来源。如有侵权,请联系博主删除。
本文链接:https://www.panziye.com/ai/16159.html
喜欢 (0)
请潘老师喝杯Coffee吧!】
分享 (0)
用户头像
发表我的评论
取消评论
表情 贴图 签到 代码

Hi,您需要填写昵称和邮箱!

  • 昵称【必填】
  • 邮箱【必填】
  • 网址【可选】