论文:Interactive Text-to-ImageRetrieval with Large Language Models: A Plug-and-Play Approach

1. Introduction

ChatIR是一种基于聊天的文本到图像的方法,提出了利用LLM进行多轮对话以提高检索效率。(我的阅读笔记见复现|ChatIR

这个方法具有以下不足:

  1. 需要进行微调文本编码器,来适应多轮对话数据
    • 微调耗费资源、可扩展性差。
    • 解决方法:将对话重构为可以直接输入到预训练的视觉语言模型的格式,不需要对模型进行微调
  2. LLM提问者G只知道对话历史,无法查看候选图像
    • 可能生成图像中不存在的属性的询问
    • LLM提问者基于候选集提问,确保问题与图像属性相关

本文提出的PlugIR系统包含2个部分:上下文重构(Context Reformulation)和上下文感知对话生成(Context-aware Dialogue Generation)

PlugIR

本文贡献:

  1. 实证0样本或微调的大模型难以理解对话数据
  2. 提出了一种LLM提问者,解决了因冗余问题和噪声问题带来的性能瓶颈
  3. 提出新指标BRI(最佳对数排名积分,Best log Rank Integral)
    • 比Recall@K和Hit@K更接近人的评价,更全面地评估交互式检索系统
  4. PlugIR具有即插即用的特性,且具有实用性
  • 文本到图像检索
  • 视觉语言模型:BLIP、CLIP
  • 大语言模型LLM

3. Method

3.1 Preliminaries:InteractiveText-to-Image Retrieval 交互式文本到图像检索

对话记录表示:$D_i = (C, Q_1, A_1, …, Q_i)$

  • $C$:目标图像的初始文本描述(标题)
  • $Q_i$:第i个问题
  • $A_i$:第i个回答

检索系统将数据库中的所有图片与文本进行匹配,根据相似度进行排序,根据目标排名评估系统性能。

  • Recall@K:本轮交互检索到的前K张图片包含目标的概率
  • Hit@K:本轮以及任意一轮的交互检索到的前K张图片包含目标的概率

3.2 Context Reformulation 上下文重构

作者测试了0样本的CLIP、BLIP、BLIP-2和一个黑箱模型ATM

  • Hit@K逐步提升,但这是由其定义决定的
  • Recall@K在仅包含最初的文本描述时最高,随着对话轮次增加而下降
    • 对话在0样本模型上可能没有贡献、产生了噪声
    • 0样本模型无法理解对话数据

为了解决这个问题,一种方法是像ChatIR一样对模型进行微调,但这样做有以下限制:

  1. 不能使用黑箱模型,比如ATM
  2. 需要大量的训练数据

本文不直接使用对话作为输入进行查询,而是将对话重构为可以直接输入到预训练的视觉语言模型的格式,不需要对模型进行微调(即所谓的Plug-and-Play)。

3.3 Context-aware Dialogue Generation 上下文感知对话生成

仅靠对话历史生成问题具有以下问题:

  1. 生成的问题可能与图像属性无关
  2. 可能询问历史对话中已有信息

提问过程(用于解决问题1):

  1. 使用重构后的查询语句进行检索,找出高相似度的“检索候选”图像集
  2. 对候选图像Embeddings进行K-means聚类,得到每个候选图像与其他图像的相似度得分分布
  3. 对于每个聚类,选择相似度分布熵最小的图像作为代表
    • 熵越小,属性越真实、越容易区分
    • 例如,同一组图像对“一张配有2台电脑显示器和一副键盘的桌子”的描述熵更低,对“办公室”的描述熵更高
  4. 将这K副图像通过image2text模型生成caption,作为附加信息提供给LLM提问者

提问(算法1)伪代码:

  1. 输入:对话上下文$c$、图像库$I$、“检索候选”图像数$n$、聚类数$m$、相似度函数$sim$、$KMeans$、i2t模型$Captioning$
  2. 从$I$中选出前$n$个和$c$最相似的图像,作为$S_R$
    1. 初始化$S_R \leftarrow {}$
    2. $while S_R.size() < n do$
      1. 将和$c$最相似的图像$x$加入$S_R$
      2. 将$x$从$I$中移除
  3. 对$S_R$进行$KMeans$聚类,得到$m$个聚类$S_R^{(1)}, S_R^{(2)}, …, S_R^{(m)}$
  4. 计算每个图像相对$S_R$的概率,使用Softmax得到$P_c(x)=\frac{exp(sim(c, x))}{\sum_{x’ \in S_R} exp(sim(c, x’))}$
  5. 从每个簇$S_R^{(i)}$中选择最优的图像,并对这$m$个图像进行$Captioning$,得到$T$
    1. $for i in range(1,m+1) do$
      1. 计算当前簇$S_R^{(i)}$中所有图像的熵,并找出最小熵的图像$\hat x^{(i)}$
      2. 对$\hat x^{(i)}$进行$Captioning$,并加入$T$
  6. 返回:$T$

采用思维链(Chain of Thought)的方法,提示词位于原文18~19页,获取与图像相关的问题。
这样生成的问题仍然可能冗余(已经知道答案),还需要经过过滤。

过滤过程(用于解决问题2):

  1. 通过上下文回答函数,判断问题是否“确定”,选取“不确定”的问题
  2. 选择“不确定”的问题中KL散度最小的问题
    1. KL散度:$KL(P_c||P_{c,q})=\sum_{x \in T} P_c(x)log\frac{P_c(x)}{P_{c,q}(x)}$
    2. 用于防止不合适的问题导致相似度骤变

过滤(算法2)伪代码:

  1. 输入:对话上下文$c$、问题集合$Q$、检索候选集$T$、相似度函数$sim$、上下文回答函数$Answer$
  2. 定义计算上下文概率分布的函数
    1. 图像$x$在上下文$c$下的分布:$P_c(x)=\frac{exp(sim(c, x))}{\sum_{x’ \in T} exp(sim(c, x’))}$
    2. 加入问题$q$后图像$x$在上下文$c$下的分布:$P_{c,q}(x)=\frac{exp(sim(concat(c, q), x))}{\sum_{x’ \in T} exp(sim(concat(c, q), x’))}$
  3. 筛选出答案“不确定”的问题,作为$Q’$
    1. 初始化$Q’ \leftarrow {}$
    2. $for q in Q do$
      1. 如果$Answer(c, q)$为“不确定”,则加入$Q’$
  4. 选择KL散度最小的问题$\hat q$
  5. 返回:$\hat q$

3.4 The Best Log Rank Integral (BRI) Metric 最佳对数排名积分

作者指出,在评估交互式检索系统时,有3个关键点:

  1. 用户满意度:在多少次交互中至少找到了一次目标图像算满意
  2. 效率:成功检索所需轮次越少越好
  3. 排名提升意义:排名靠前时提升排名的意义更大,如从2到1比从100到99更有意义

Recall@K用于非交互式检索;Hit@K只考虑了用户满意度

作者提出了BRI指标,综合了用户满意度、效率和排名提升意义

记:$Q$问题集合、$T$最大轮次

$\pi(q_t)$:表示具有$t$轮对话的查询$q_t$,在这$t$轮查询中,目标图像的历史最佳排名,用于衡量用户满意度

BRI:$\mathbb E_{q \in Q}\left[ \dfrac{1}{2T}\log\pi(q_0)\pi(q_T)+\dfrac{1}{T}\sum\limits_{t=1}^{T-1}\log\pi(q_t)\right]$

  • 边界项:$\dfrac{1}{2T}\log\pi(q_0)\pi(q_T)$
    • 权重较小,反映初始查询$q_0$到最终查询$q_T$的排名改善情况
  • 平均查询排名项:$\dfrac{1}{T}\sum\limits_{t=1}^{T-1}\log\pi(q_t)$
    • 计算了查询$q$的所有$t$轮中,目标的历史最佳排名的对数的均值
    • 对数函数使得低排名的进一步降低对BRI变化的影响更大
  • BRI越小,性能越好
  • BRI不依赖于具体的K值,更全面、统一
  • 实验表明,BRI与人类评价更接近

4. Experiments

  • 数据集:Visdail、COCO、Flickr30k
  • 文本到图像检索模型:默认BLIP,也有BLIP-2、ATM
  • LLM提问者:ChatGPT
  • 测试集回答者:BLIP-2
  • 聚类数m:10

Baseline:0-shot、ChatIR

同时进行了Ablation Study,测试了不同组件的加入对结果的影响

总结和实现

  • PlugIR系统也是一个基于对话的图像检索系统,在ChatIR的基础上进行了改进
  • 主要优化了提问过程,使得提问的有效性提升
  • 使用了新的评估指标BRI,能够更全面地评估交互式检索系统

由于系统代码量较大,在实现时划分到多个文件

config.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch

config = {
"hitk" : 10, # hit@k
"q_num" : 5, # 待select的问题数量
"threshold_low" : 500, # 低阈值,用于计算Kmeans聚类时的样本数
"gpt_model" : 'gpt-4o-mini', # OpenAI模型名称
"api_key" : "", # OpenAI API密钥
"vqa_model" : 'Salesforce/blip2-flan-t5-xl', # VQA模型名称
"retriever" : "blip", # 检索器名称,blip或clip
"device" : torch.device("cuda" if torch.cuda.is_available() else "cpu"),
"sep_token" : ", ", # ChatIR分隔符
"eval_caption" : True, # eval时是否为PlugIR总结的Caption,True:PlugIR,False:ChatIR

# 算法配置
"referring" : True, # 是否参考候选集Caption提问(算法1)
"filtering" : True, # 是否AI过滤问题(算法2)
"select" : True, # 是否使用KL散度选择问题(算法2)
"reconstruct" : True, # 是否重构对话

# 路径配置
# 图片路径前缀,路径=dir_prefix+<img_path>("unlabeled2017/xxx.jpg")
"dir_prefix" : "./",
# Visdial数据,[{"img":"<img_path>", "dialog":["<caption>", "Q? A", "Q? A", ...]}, ...],2064
"visdial_path" : "./dialogues/VisDial_v1.0_queries_val.json",
# 搜索空间,["<img_path>", "<img_path>", ...],50000
"search_space" : "./Protocol/Search_Space_val_50k.json",
# Visdial数据,仅保留caption,[{"id": "<img_path>", "caption": ["<caption>"]},...],50000
"captions_path" : "./ChatIR/ChatIR_Protocol/visdial_captions.json",
# blip预处理的embeddings
"img_emb_path" : "./ChatIR/temp/corpus_finetuned_blip.pth",
}

OpenAI.py

所有函数统一返回response对象,包含了所有信息。

messages.py中定义了消息格式,包含prompt信息,具体文本见论文附录。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import openai
import time

import API.messages as messages
from config import config

client = openai.OpenAI(api_key=config["api_key"])

SEED = 1021 # 随机种子 TODO:不设置的效果

def reconstruct_dialog(dialog:list[str], temperature:float=.0, model:str='gpt-4o-mini')->str:
"""重构对话框

将对话中的有效信息总结为一段描述。

Args:
dialog : 对话列表
temperature : 采样温度,越小越保守 [0.0,2.0]
model : 模型名称

"""
retry_count = 0
while True:
try:
response = client.chat.completions.create(
model=model,
messages=messages.reconstruct_dialog_message(dialog),
n=1,
temperature=temperature,
seed=SEED,
max_tokens=512)
break
except Exception as e:
retry_count += 1
print(f"Error: {e}")
time.sleep(3)
if retry_count > 5:
raise Exception("Retry limit exceeded!")
print(f"Retry {retry_count} times...")
continue
return response

def generate_questions(dialog:list[str], n:int=1, model:str='gpt-4o-mini')->list[str]:
"""Baseline: 1-shot生成问题 ChatIR

仅参考历史对话生成新的问题。

Args:
dialog : 对话列表
n : 生成问题数量
model : 模型名称

Returns:
response :
"""
retry_count = 0
while True:
try:
response = client.chat.completions.create(
model=model,
messages=messages.generate_questions_message(dialog),
n=n,
temperature=0.5,
max_tokens=32)
break
except Exception as e:
retry_count += 1
print(f"Error: {e}")
time.sleep(3)
if retry_count > 5:
raise Exception("Retry limit exceeded!")
print(f"Retry {retry_count} times...")
continue

return response

def generate_questions_referring(dialog:list[str], prompt_related_captions:str="", questions:list=[], n:int=1, model='gpt-4o-mini'):
"""利用思维链CoT和候选集Captions生成新问题

Args:
dailog : 对话列表
prompt_related_captions : 预处理的候选集Caption
questions : 历史问答对
n : 生成问题数量
model : 模型名称
"""
retry_count = 0
message = messages.generate_questions_referring_message(dialog, prompt_related_captions, questions)
while True:
try:
response = client.chat.completions.create(
model=model,
messages=message,
n=1,
temperature=0.5,
max_tokens=32)
break
except Exception as e:
retry_count += 1
print(f"Error: {e}")
time.sleep(3)
if retry_count > 5:
raise Exception("Retry limit exceeded!")
print(f"Retry {retry_count} times...")
continue

return response

def filter_questions(context:str, question:str, model='gpt-4o-mini'):
"""过滤问题

判断问题是否“Uncertain”

Args:
context : 上下文
question : 问题
model : 模型名称
"""
retry_count = 0
while True:
try:
response = client.chat.completions.create(
model=model,
messages=messages.filter_questions_message(context, question),
n=1,
temperature=.0,
max_tokens=32)
break
except Exception as e:
retry_count += 1
print(f"Error: {e}")
time.sleep(3)
if retry_count > 5:
raise Exception("Retry limit exceeded!")
print(f"Retry {retry_count} times...")
continue

return response

def paraphrase(text:str="", model='gpt-4o-mini'):
"""重述

重述给定的文本。

Args:
text : 待重述文本
model : 模型名称
"""
retry_count = 0
while True:
try:
response = client.chat.completions.create(
model=model,
messages=messages.paraphrase(text),
n=1,
temperature=0.7,
top_p=0.8,
max_tokens=512)
break
except Exception as e:
retry_count += 1
print(f"Error: {e}")
time.sleep(3)
if retry_count > 5:
raise Exception("Retry limit exceeded!")
print(f"Retry {retry_count} times...")
continue

return response

utils.py

实现了特征提取、K-means聚类、KL散度计算、获取簇中心caption、熵计算等函数功能,使主程序代码简洁易读。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import torch
from torch.nn.functional import normalize
from transformers import BlipForImageTextRetrieval,AutoProcessor
import json
from torch.utils.data import Dataset, DataLoader
import os
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import tqdm
import torch.nn.functional as F
from config import config

with open('./Protocol/visdial_captions.json', 'r') as ss_cap_json:
captions = json.load(ss_cap_json) # length:50000
# {"id": "<img_path>", "caption": ["<caption>"]}

class BlipForRetrieval(BlipForImageTextRetrieval):
def get_text_features(
self,
input_ids: torch.LongTensor, # Tokenized input IDs
attention_mask: torch.LongTensor | None = None,
return_dict: bool|None = None,
) -> torch.FloatTensor:
"""获取文本特征

Args:
input_ids : 文本的token ID(即经过分词后的输入)
attention_mask : 注意力掩码
return_dict : 是否返回字典
"""
return_dict=return_dict if return_dict is not None else self.config.use_return_dict

text_embeddings = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=return_dict,
)
text_embeddings = text_embeddings[0] if not return_dict else text_embeddings.last_hidden_state
return normalize(self.text_proj(text_embeddings[:, 0, :]), dim=-1) # [:,0,:]取的是[CLS]标记的隐藏状态,它通常被用作整个句子的特征表示

def get_image_features(
self,
pixel_values: torch.FloatTensor,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
) -> torch.FloatTensor:
"""获取图片特征

Args:
pixel_values : 图片像素值
output_attentions : 是否输出注意力
output_hidden_states : 是否输出隐藏状态
return_dict : 是否返回字典
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

vision_outputs = self.vision_encoder(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeddings = vision_outputs[0]
return normalize(self.vision_proj(image_embeddings[:, 0, :]), dim=-1)

processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-large-coco")
model = BlipForRetrieval.from_pretrained("Salesforce/blip-itm-large-coco").to(config["device"])

def get_text_features(text: str, model=model, processor=processor) -> torch.FloatTensor:
"""获取文本特征

Args:
text : 文本
model : Pretrained Model,用于获取文本特征
processor : 文本处理器,用于将文本转换为模型输入
"""
text_encodings = processor(text=text, padding=True, return_tensors="pt").to(config["device"])
return model.get_text_features(**text_encodings)

# ----------------------------------------------------------------------------------------------------------------------------------------------------

def search_imgs(query="", img_embs=None, search_space=None, k=10):
"""搜索前k个相关图片

Args:
query : 查询文本
img_embs : 图片特征
search_space : 搜索空间,图片路径列表
k : 返回的图片数量
Returns:
related_imgs : 前k个相关图片
related_indices : 前k个相关图片的索引
cos_sim : 余弦相似度
"""
query_emb = get_text_features(query)
query_emb_norm = normalize(query_emb, dim=-1) # 归一化查询特征
cos_sim = torch.matmul(query_emb_norm, img_embs.T).squeeze() # 计算余弦相似度
related_indices = cos_sim.sort()[1][-k:]
# related_indices = cos_sim.topk(k).indices
related_imgs = []
for idx in range(k):
related_imgs.append(search_space[related_indices[idx].item()])

return related_imgs, related_indices, cos_sim

from fast_pytorch_kmeans import KMeans
kmeans = KMeans(n_clusters=10, mode='cosine', verbose=0)

def get_related_captions(caption_recon, round=1, threshold_low=500, img_embs=None, captions=None):
"""获取簇中心的相关图片描述

将low-(round-1)*(low/10.0)个最相关的图片作为候选集S_R,进行k-means聚类,
得到10个簇中,信息熵最小的图片描述作为当前轮次的相关图片描述。

Args:
caption_recon : 对话上下文(重构的)
round : 当前轮次
threshold_low : 根据轮次计算相关图片的数量 low-(round-1)*(low/10.0)
img_embs : 图片特征
captions : 图片描述
"""
caps = []
# related_size = 100 - (round-1)*10
related_size = int(threshold_low - (round-1) * (threshold_low / 10.0)) # 候选集大小 500-(round-1)*50
emb = normalize(get_text_features(caption_recon), dim=-1)
sim = torch.matmul(emb, img_embs.T).squeeze() # 计算余弦相似度
topk = sim.argsort()[-related_size:] # 获取前related_size个相关图片的索引
img_embs_topk = img_embs[topk] # 获取前related_size个相关图片的特征

entropies = torch.zeros([related_size])
for i in range(related_size):
cap = captions[topk[i].item()]['caption']
emb = normalize(get_text_features(cap), dim=-1)
sim = torch.matmul(emb, img_embs_topk.T).squeeze()
p = torch.nn.functional.softmax(sim, dim=0)
entropy = (-p * p.log()).sum().detach().cpu() # 信息熵
entropies[i] += entropy
idx_entropies_sorted = entropies.argsort()

cluster_label = kmeans.fit_predict(img_embs_topk)
cluster_label_sorted = cluster_label[idx_entropies_sorted]
for i in range(10):
if (cluster_label_sorted == i).any():
idx_c = (cluster_label_sorted == i).nonzero().squeeze().min()
caps.append(captions[topk[idx_entropies_sorted[idx_c]].item()]['caption'][0])

# for i in range(10):
# caps.append(captions[topk[entropies.argsort()[i]].item()]['caption'][0])

return caps

def get_referring_prompt(caption="", img_embs=None, k=10, round=1, search_space=None):
""" 进行k-means聚类,获取相关图片描述;获取前k个相关图片的索引和余弦相似度。

Args:
caption : 图片描述
img_embs : 图片特征
k : 返回的图片数量
round : 当前轮次
search_space : 搜索空间,图片路径列表
Returns:
prompt_sys : 系统prompt
prompt_related_captions : 相关图片描述
top_k : 前k个相关图片的索引
cos_sims : 余弦相似度
"""
img_paths, top_k, cos_sims = search_imgs(caption, img_embs, search_space, k=k)
related_captions = get_related_captions(caption, round, img_embs=img_embs, captions=captions)

prompt_sys = ""
prompt_sys += "You should leverage the 'related_captions Information' that is related to the target image "
prompt_sys += "corresponding to the caption but does not match the target image."
# 你需要利用与目标图像相关的“虚假信息”,该信息与目标图像标题相关,但与目标图像不匹配。

prompt_related_captions = ""
for i in range(len(related_captions)):
prompt_related_captions += str(i) + '. ' + related_captions[i] + '\n'
# 1. caption1
# 2. caption2
# ...

return prompt_sys, prompt_related_captions, top_k, cos_sims

def select_question(caption_recon="", questions=[], cossim_prev=None, k=10, img_embs=None, threshold=500, round=1):
"""选择KL散度最小的问题
根据上一次的相似度,选择KL散度最小的问题。

Args:
caption_recon : 对话上下文(重构的)
questions : 问题列表
cossim_prev : 上一轮的相似度
k : 返回的图片数量
img_embs : 图片特征
threshold : 相关图片的数量 low-(round-1)*(low/10.0)
round : 当前轮次
Returns:
str : 选择的问题
"""
threshold = int(threshold - (round-1) * (threshold / 10.0))
idx_related = cossim_prev.argsort()[-threshold:-k]
p_prev = torch.nn.functional.softmax(cossim_prev[idx_related], dim=0)
kl_divs = torch.zeros([len(questions)])

for i, ques in enumerate(questions):
caption_tmp = caption_recon + ", " + ques

query_emb_tmp = normalize(get_text_features(caption_tmp), dim=-1)
cossim_tmp = torch.matmul(query_emb_tmp, img_embs.T).squeeze()
p_tmp = torch.nn.functional.softmax(cossim_tmp[idx_related], dim=0)
kl_div = (p_prev*(p_prev.log() - p_tmp.log())).sum().detach().cpu()
kl_divs[i] += kl_div

idx_final = kl_divs.argsort()[0].item()

return questions[idx_final]

系统实现:PlugIR_exec.py

PlugIR的运行版,实现利用多轮对话进行图像检索的功能,描述以及每次问答后显示当前最相关的图片。

改写为PlugIR_func.py,实现了函数化,便于后续批量生成对话数据用于evaluation。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from API import OpenAI
from PIL import Image
import utils
import json
from config import config

img_embs = torch.load(config["img_emb_path"], map_location=config["device"])[1] # blip预处理的embeddings

if not config["filtering"]:
config["q_num"] = 1 # 不过滤问题,生成1个问题

with open(config["search_space"], 'r') as ss_json:
search_space = json.load(ss_json) # length:50000
search_space = [config["dir_prefix"] + path for path in search_space]
# ["<img_path>", "<img_path>", ...]

with open(config["visdial_path"], 'r') as diag_json:
visdial = json.load(diag_json) # length:2064
# [{"img":"<img_path>", "dialog":["<caption>", "Q? A", "Q? A", ...]}, ...]

with open(config["captions_path"], 'r') as ss_cap_json:
captions = json.load(ss_cap_json) # length:50000
# [{"id": "<img_path>", "caption": ["<caption>"]},...]

def ask_for_caption():
"""询问用户描述"""
caption = input("Caption: ")
return caption.strip()

def ask_question(question):
"""询问用户问题"""
ans = input(f"{question} ")
return ans.strip()

def display_image(img_path):
"""显示图片"""
img = Image.open(img_path)
img.show()

if __name__ == "__main__":
caption = ask_for_caption() # 询问用户描述
dialogue = [caption] # 对话框
caption_recon = caption # 重构后的描述,用于匹配图片
caption_recons = [caption] # 重构描述列表

for round in range(10): # 10轮对话
questions = [] # 合法问题列表

if config["referring"]: # 参考候选集的Caption提问
ques_prior = [] # 非法问题列表
_, prompt_related_captions, top_k, cos_sims = utils.get_referring_prompt(
caption_recon,
img_embs=img_embs,
k=config["hitk"],
round=round+1,
search_space=search_space)
display_image(search_space[top_k[-1]]) # 显示第一张图片

for k in range(config["q_num"]): # 生成config["q_num"]个问题
if config["filtering"]: # AI过滤问题
for _ in range(3): # 尝试3次根据候选集Caption生成问题
resp = OpenAI.generate_questions_referring(
dialog=dialogue,
prompt_related_captions=prompt_related_captions,
questions=ques_prior,
model=config["gpt_model"]
).choices[0].message.content
question = resp.split('?')[0]+'?'
fq = OpenAI.filter_questions(caption_recon, question).choices[0].message.content
if "uncertain" not in fq.lower(): # 如果不uncertain,问题非法
ques_prior.append(question)
else : # 如果uncertain,问题合法
break
if len(ques_prior) == 3: # 三次非法问题,尝试直接询问一次,再不行,就固定询问其他对象
response = OpenAI.generate_questions(dialog=dialogue, n=1).choices[0].message.content
question = response.split('?')[0]+'?'
fq = OpenAI.filter_questions(caption_recon, question).choices[0].message.content
if "uncertain" not in fq.lower():
question = "what is the other object in the image?"
else: # 不过滤问题
response = OpenAI.generate_questions_referring(
dialog=dialogue, prompt_related_captions=prompt_related_captions, questions=ques_prior, n=1, model=config["gpt_model"]
).choices[0].message.content
question = response.split('?')[0]+'?'
questions.append(question)

else: # ChatIR:直接提问
_, top_k, cos_sims = utils.search_imgs(
query=caption_recon,
img_embs=img_embs,
search_space=search_space,
k=config["hitk"]
)
display_image(search_space[top_k[-1]]) # 显示第一张图片
for k in range(config["q_num"]): # 生成config["q_num"]个问题
question = OpenAI.generate_questions(
dialog=dialogue, n=1, model=config["gpt_model"]
).choices[0].message.content
questions.append(question)

if config["select"]: # 选择KL散度最小的问题
question_final = utils.select_question(
caption_recon=caption_recon,
questions=questions,
cossim_prev=cos_sims,
k=config["hitk"],
img_embs=img_embs,
threshold=config["threshold_low"],
round=round+1
)
else:
question_final = questions[0]

answer = ask_question(question_final)
qa = question_final + ' ' + answer
dialogue.append(qa)

if config["reconstruct"]: # 重构对话上下文
caption_recon = OpenAI.reconstruct_dialog(
dialog=dialogue,
model=config["gpt_model"]
).choices[0].message.content
if caption_recon == caption_recons[-1]:
caption_recon = OpenAI.paraphrase(caption_recon, model=config["gpt_model"]).choices[0].message.content
else: # ChatIR:", "连接"
caption_recon = ', '.join(dialogue)
caption_recons.append(caption_recon)

运行效果:

PlugIR_exec

对话数据生成:test_gen.py

为了自动获取测试数据,免除人工回答,使用了BLIP2模型回答问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from lavis.models import load_model_and_preprocess
from config import config
from PlugIR_func import generate_question
import json
from PIL import Image
from API import OpenAI
from tqdm import tqdm

with open(config["visdial_path"], 'r') as diag_json:
visdial = json.load(diag_json) # length:2064
# [{"img":"<img_path>", "dialog":["<caption>", "Q? A", "Q? A", ...]}, ...]
import sys

# st,ed = 5,100 # 目标范围[0,2064]
st,ed = int(sys.argv[1]), int(sys.argv[2]) # 目标范围[0,2064]
images = [visdial[i]["img"] for i in range(st,ed)] # 目标图片列表
target_captions = [visdial[i]["dialog"][0] for i in range(st,ed)] # 目标图片对应的caption作为用户描述

from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor2 = Blip2Processor.from_pretrained(config["vqa_model"])
blip2 = Blip2ForConditionalGeneration.from_pretrained(config["vqa_model"], device_map={"": 0}, torch_dtype=torch.float16)

def reconstruct_caption(dialogue, caption_recons):
caption_recon=""
if config["reconstruct"]: # 重构对话上下文
caption_recon = OpenAI.reconstruct_dialog(
dialog=dialogue,
model=config["gpt_model"]
).choices[0].message.content
if caption_recon == caption_recons[-1]:
caption_recon = OpenAI.paraphrase(caption_recon, model=config["gpt_model"]).choices[0].message.content
else: # ChatIR:", "连接
caption_recon = config["sep_token"].join(dialogue)
return caption_recon

def generate_dialog(idx):
"""生成对话数据

Args:
idx : 图片索引(st->0)
"""
image = Image.open(images[idx])
caption = target_captions[idx] # 目标图片对应的caption作为用户描述
tqdm.write(f"\nCaption {idx}: {caption}")
dialogue = [caption] # 对话框
caption_recon = caption # 重构后的描述,用于匹配图片
caption_recons = [caption] # 重构描述列表
for i in range(10): # 10轮对话
question = generate_question(dialogue, caption_recon)
prompt = f"Question: {question} Answer: "
_inputs = processor2(images=image, text=prompt, return_tensors='pt').to(config["device"])
_outputs = blip2.generate(**_inputs, do_sample=False)
answer = processor2.decode(_outputs[0], skip_special_tokens=True).strip()
qa = question.strip()+ ' ' + answer
dialogue.append(qa)
caption_recon = reconstruct_caption(dialogue, caption_recons)
caption_recons.append(caption_recon)
tqdm.write(f"Round {i+1}: {caption_recon}")
ret = {
"img": images[idx],
"dialog": caption_recons
}
return ret

if __name__=="__main__":
filename = "dialogues/"+"mine_" + config["gpt_model"]+"_"+"BLIP2"+".txt"
with torch.no_grad():
with open(filename, 'a') as f:
for idx in tqdm(range(len(images)), desc="对话生成进度", position=0):
dialog = generate_dialog(idx)
f.write(json.dumps(dialog, ensure_ascii=False)+"\n")

Debug日志

测试使用了项目仓库的eval.py源代码。

可能是我使用Windows系统的缘故,eval.py代码中有一些报错,具体问题和调整如下:

  1. AttributeError: Can't pickle local object 'BLIP_ZERO_SHOT_BASELINE.<locals>.<lambda>'
    • 原因:在Windows上使用多进程(num_workers>0)时,需要pickle对象,但是其中的lambda函数或局部函数不能被pickle
    • 解决:将lambda函数改为全局函数,再使用functools.partial进行参数绑定
  2. RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    • 原因:函数text_encode_fn中的processor没有to(device)
    • 解决:多做一步,将processor的输出``to(device)`

尝试跑通generate_dialog.py的过程中也遇到了上述问题,解决方法同。

测试结果

正常情况下3~5分钟可以生成1条数据(大约80次请求)。
但由于我生成数据短时间内大量调用OpenAI API,导致被限流,20~30分钟才能产生1条数据,故本次实现在eval阶段仅有253条数据。

在使用ChatGPT-4o-mini作为提问模型G,BLIP2作为回答模型A,使用同一测试代码的情况下,测试结果和仓库的对话数据Hit@K对比如下:

length 仓库数据Hit@10(2064 testcases) 实现数据Hit@10(253 testcases)
0 71.12% 72.33%
1 79.02% 81.42%
2 83.09% 83.40%
3 85.85% 85.38%
4 87.55% 86.56%
5 88.71% 87.75%
6 89.39% 88.14%
7 90.12% 88.54%
8 90.70% 88.93%
9 91.09% 89.33%
10 91.47% 90.12%

BRI对比(越低越好):

  • 仓库对话的BRI:10.195615768432617
  • 实现对话的BRI:10.252569198608398