8-RAG

1 核心概念和作用

定义:RAG 是一种结合检索与生成的技术,解决大语言模型(LLM)在长文本处理、事实准确性和上下文感知方面的局限。

核心逻辑:通过检索外部知识库(如向量数据库)获取相关文档,将其作为上下文附加到用户查询中,辅助 LLM 生成更准确、有依据的回答。

Spring AI 支持:提供模块化架构和开箱即用的 Advisor,允许自定义 RAG 流程或使用现成实现。

2 Milvus的操作

2.1 向量化

这里可以通过OpenAI提供的向量模型来处理。首先是配置相关的信息

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
@Autowired
private EmbeddingModel embeddingModel;
/**
* 向量化的实现
*/
@Test
void embeddings(){
// 例 1:将文本转换为 Embedding
float[] embeddings1 = embeddingModel.embed("我喜欢Java");
this.printFloatArrays(embeddings1);
// 例 2:将文档转换为 Embedding
float[] embeddings2 = embeddingModel.embed(new Document("我喜欢Java"));
this.printFloatArrays(embeddings2);
// 例 3:使用选项将文本转换为 Embedding
EmbeddingRequest embeddingRequest = new EmbeddingRequest(
List.of("我喜欢Java"),
OpenAiEmbeddingOptions.builder()
.model("text-embedding-3-small")
.build());
EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest);
float[] embeddings3 = embeddingResponse.getResult().getOutput();
this.printFloatArrays(embeddings3);
}

private void printFloatArrays(float[] embeddings1) {
System.out.println(embeddings1.length);
for (float v : embeddings1) {
System.out.print(v+"\t");
}
System.out.println();
}

2.2 创建集合和索引

引入对应的依赖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.19</version>
<exclusions>
<exclusion>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-vector-store-milvus</artifactId>
</dependency>

protobuf的依赖有冲突,这里排除mysql中的冲突依赖。需要把向量数据信息存储到Milvus数据库中,那么就得先创建对应的集合和索引,具体代码如下:

先把核心的MilvusServiceClient注入到Spring容器中

1
2
3
4
5
6
7
8
@Bean
public MilvusServiceClient getMilvusServiceClient(ChatClient client) {
ConnectParam connectParam = ConnectParam.newBuilder()
.withHost("localhost")
.withPort(19530)
.build();
return new MilvusServiceClient(connectParam);
}

然后创建对应的集合和索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
@Autowired
private MilvusServiceClient client;
/**
* 创建集合
* @throws Exception
*/
@Test
void createCollection() throws Exception {
List<FieldType> fieldTypes = Arrays.asList(
FieldType.newBuilder()
.withName(MilvusEntity.Field.ID)
.withDescription("主键ID")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build(),
FieldType.newBuilder()
.withName(MilvusEntity.Field.FEATURE)
.withDescription("特征向量")
.withDataType(DataType.FloatVector)
.withDimension(MilvusEntity.FEATURE_DIM)
.build(),
// 设置向量维度
FieldType.newBuilder()
.withName(MilvusEntity.Field.TEXT)
.withDescription("输入数据")
.withDataType(DataType.VarChar)
.withTypeParams(Collections.singletonMap("max_length", "65535"))
.build(),
FieldType.newBuilder()
.withName(MilvusEntity.Field.OUTPUT)
.withDescription("问题答案数据")
.withDataType(DataType.VarChar)
.withTypeParams(Collections.singletonMap("max_length", "65535"))
.build());
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
.withCollectionName(MilvusEntity.COLLECTION_NAME)
.withDescription("rag collection")
.withShardsNum(MilvusEntity.SHARDS_NUM)
.withFieldTypes(fieldTypes)
.build();
client.createCollection(createCollectionReq);
// 同时给向量创建对应的索引
CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
.withCollectionName(MilvusEntity.COLLECTION_NAME)
.withFieldName(MilvusEntity.Field.FEATURE) // 向量字段名
.withIndexType(IndexType.IVF_FLAT) // 使用IVF_FLAT索引类型
.withMetricType(MetricType.L2) // 指定度量类型,如L2距离
.withExtraParam("{\"nlist\":128}") // 根据索引类型提供额外参数,比如nlist
.build();
client.createIndex(createIndexParam);
}

这里为了方便自定义了MilvusEntity

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
public class MilvusEntity {
/**
* 向量数据库名称
*/
public static final String DB_NAME = "default";
/**
* 集合名称
*/
public static final String COLLECTION_NAME = "springai_rag";
/**
* 分片数量
*/
public static final int SHARDS_NUM = 1;
/**
* 分区数量
*/
public static final int PARTITION_NUM = 1;
/**
* 特征向量维度
*/
public static final Integer FEATURE_DIM = 1536;
/**
* 字段
*/
public static class Field {
/**
* id
*/
public static final String ID = "id";
/**
* 文本特征向量
*/
public static final String FEATURE = "feature";
/**
* 文本
*/
public static final String TEXT = "instruction";
/**
* 问答匹配的结果
*/
public static final String OUTPUT = "output";
}
}

执行代码后可以看到创建的结构

2.3 插入数据

现在就可以通过代码把需要准备的数据插入到Milvus数据库中去了,数据存放在resources目录下

为了方便存储创建了对应的对象来结构化数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
public class FaqItem {
private String instruction;
private String input;
private String output;
// Getters and Setters
public String getInstruction() {
return instruction;
}
public void setInstruction(String instruction) {
this.instruction = instruction;
}
public String getInput() {
return input;
}
public void setInput(String input) {
this.input = input;
}
public String getOutput() {
return output;
}
public void setOutput(String output) {
this.output = output;
}
@Override
public String toString() {
return "FaqItem{" +
"instruction='" + instruction + '\'' +
", input='" + input + '\'' +
", output='" + output + '\'' +
'}';
}
}

现在创建添加的逻辑代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
private static final ObjectMapper objectMapper = new ObjectMapper();
@Test
public void insertTestData() throws Exception {
// 加载数据
InputStream inputStream = getClass().getClassLoader().getResourceAsStream("train_zh.json");
// 一行一行读取数据
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().isEmpty()) continue;
// 这里需要提取的是 instruction 的信息,然后向量化
FaqItem item = objectMapper.readValue(line, FaqItem.class);
// 获取原始的问题
String instruction = item.getInstruction();
// 对问题向量化
float[] embeddings = embeddingModel.embed(instruction);
// float[] embeddings 转换为 List<Float>
List<Float> embeddingList = new ArrayList<>(embeddings.length);
for (float f : embeddings) {
embeddingList.add(f);
}
// 最终需要存储到向量数据库中的结构
List<List<Float>> floats = new ArrayList<>();
floats.add(embeddingList);
List<InsertParam.Field> fields = new ArrayList<>();
fields.add(new InsertParam.Field(MilvusEntity.Field.FEATURE, floats));
fields.add(new InsertParam.Field(MilvusEntity.Field.TEXT, Arrays.asList(item.getInstruction())));
fields.add(new InsertParam.Field(MilvusEntity.Field.OUTPUT, Arrays.asList(item.getOutput())));
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(MilvusEntity.COLLECTION_NAME)
.withFields(fields)
.build();
client.insert(insertParam);
}
}
}

执行后就可以查看具体的数据信息

2.4 数据检索

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
@Test
void search(){
List<List<Float>> floats = new ArrayList<>();
float[] embeddings = embeddingModel.embed("白癜风怎么治疗");
// float[] embeddings 转换为 List<Float>
List<Float> embeddingList = new ArrayList<>(embeddings.length);
for (float f : embeddings) {
embeddingList.add(f);
}
floats.add(embeddingList);
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(MilvusEntity.COLLECTION_NAME)
.withMetricType(MetricType.L2)// 使用 L2 距离作为相似度度量
.withTopK(3) // 返回最接近的前3个结果
.withVectors(floats)
.withVectorFieldName(MilvusEntity.Field.FEATURE)
// 向量字段名
.withOutFields(Arrays.asList(MilvusEntity.Field.ID,MilvusEntity.Field.OUTPUT,MilvusEntity.Field.TEXT)) // 需要返回的字段
.build();
SearchResults data = client.search(searchParam).getData();
SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(data.getResults());
List<QueryResultsWrapper.RowRecord> rowRecords = resultsWrapper.getRowRecords();
if(rowRecords != null && !rowRecords.isEmpty()){
System.out.println(rowRecords);
}
}

3 RAG的实现

创建对应的工具方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@Tool(description = "病人患病信息检索")
String vectorSearch(@ToolParam(description = "患者问题") String msg, ToolContext context) {
// 1.获取查询的向量信息
List<List<Float>> floats = new ArrayList<>();
EmbeddingModel embeddingModel = (EmbeddingModel) context.getContext().get("embeddingModel");
MilvusServiceClient client = (MilvusServiceClient) context.getContext().get("client");
float[] embeddings = embeddingModel.embed(msg);
// float[] embeddings 转换为 List<Float>
List<Float> embeddingList = new ArrayList<>(embeddings.length);
for (float f : embeddings) {
embeddingList.add(f);
}
floats.add(embeddingList);
// 封装向量数据库检索的信息
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(MilvusEntity.COLLECTION_NAME)
.withMetricType(MetricType.L2)// 使用 L2 距离作为相似度度量
.withTopK(3) // 返回最接近的前3个结果
.withVectors(floats)
.withVectorFieldName(MilvusEntity.Field.FEATURE)
// 向量字段名
.withOutFields(Arrays.asList(MilvusEntity.Field.ID,MilvusEntity.Field.OUTPUT,MilvusEntity.Field.TEXT)) // 需要返回的字段
.build();
SearchResults data = client.search(searchParam).getData();
// 3.处理响应的信息
SearchResultsWrapper resultsWrapper = new SearchResultsWrapper(data.getResults());
List<QueryResultsWrapper.RowRecord> rowRecords = resultsWrapper.getRowRecords();
StringBuilder sb = new StringBuilder();
if(rowRecords != null && !rowRecords.isEmpty()){
sb.append(rowRecords);
}
return sb.toString();
}

然后在使用的时候添加工具就可以

1
2
3
4
5
6
7
8
9
10
11
12
@Test
void ragTest(){
String content = ChatClient.builder(chatModel)
.build().prompt()
.tools(new Tools())
.toolContext(Map.of("embeddingModel",embeddingModel,"client",client))
.user("白癜风怎么治疗")
.call()
.content();
System.out.println(content);
}


8-RAG
http://www.zivjie.cn/2025/11/22/spring框架/springAI/SpringAi框架/8-RAG/
作者
Francis
发布于
2025年11月22日
许可协议