论文:Chatting Makes Perfect: Chat-based Image Retrival

1. Introduction

ChatIR系统包含2个部分:对话构建(Dialog Building)和图像搜索(Image Search)

  • 对话构建:使用问题生成器G,考虑当前对话历史,生成下一个问题
  • 图像搜索:使用模型F,将不同长度的对话序列映射到视觉嵌入空间
  • 两个组成部分建立在Instructional LLMs和fundation Vision and Language Models上

Figure1

考虑三个问题:

  1. 使用什么数据集训练?是否需要新创建和标注数据集?
    • 使用VisDial数据集
    • 问题:VisDial是一个用于“创建关于图像的聊天”的数据集,没有检索目标
    • 解决:输入输出反置,对话作为输入、图像作为输出
  2. 如何独立评估ChatIR系统的不同组件?
    • 测试使用不同的F训练策略提问模型G对检索性能的影响
    • 使用了BLIP替代用户回答问题
  3. 如何定义评估指标?
    • 每一轮对话的成功检索概率Hit@10
  1. 视觉对话(Visual Conversation)领域
    • 当前视觉领域工作的重点在于图像理解和生成模型,而不是检索
    • 生成式视觉对话中,近期的基础模型V&L性能优越,因此ChatIR系统在此基础上构建
  2. 视觉搜索(Visual Search)领域
    • CoIR:使用多模态询问查找目标图像
    • 一些研究基于CoIR,利用用户反馈细化查询结果
    • 但是,没有考虑用户交互(只有用户反馈,没有机器提问)、没有明确利用对话历史

3. Method

3.1 Dialog Builder Model 对话生成模型

对话记录表示:$D_i = (C, Q_1, A_1, …, Q_i)$

  • $C$:目标图像的初始文本描述(标题)
  • $Q_i$:第i个问题
  • $A_i$:第i个回答

对话生成模型包含两个部分:

  • 问题生成器G:一个LLM,根据对话记录生成下一个问题
    • $G: Di \rightarrow Q{i+1}$
    • G不知道目标图像T是什么,只知道对话历史
  • 答案提供者A:在实践中,通常是一个脑海中有大致目标图像的人类
    • 由于需要大规模实验,不能依赖用户提供答案
    • 因此,使用了一个现成的模型BLIP2来回答

3.2 Image Retrieval Model 图像检索模型

图像搜索过程:在查询目标图像共享的视觉嵌入空间中,搜索匹配项。

  • 所有的目标图像先经过图像嵌入模块进行编码,由一个$d$维的特征向量$f\in \mathbb R^d$表示。
  • 图像检索模块F将对话历史$D_i$映射到视觉嵌入空间,$F: D_i \rightarrow \mathbb R^d$,
  • 候选对象根据相似度进行排序。

引入分隔符[SEP]和添加符[CLS],表示整个对话序列,投射到视觉嵌入空间。

F采用(使用BLIP)预训练的图像/文本编码器,并通过对比学习,对基于对话的检索进行微调。

通过提取VisDial数据集中的图像和相应对话,手动标注,训练F。

4. Evaluation

在评估环节,原文使用了Hit@10指标,即目标图像在前10个检索结果中的试验占比。

原文从三个方面进行了对比:

  1. 与现有文本到图像(Text to Image,TTI)检索方法的比较
    1. ChatIR使用ChatGPT作为提问者G,BLIP2作为回答者A
    2. 与Zero-shot的BLIP、CLIP以及fine-tuned SoTA TTI BLIP进行比较
    3. 结论:ChatIR在多轮对话环境中,相比传统单跳 TTI 方法表现更优
  2. 不同提问者G的比较
    1. 使用ChatGPT、FLAN-ALPACA-XXL、人类等8中不同提问者
    2. ChatGPT表现最好
  3. 人类参与对话的影响
    1. ChatGPT提问,人类回答;ChatGPT提问,BLIP2回答;人类提问,人类回答
    2. 由于人类生成的答案质量明显优于BLIP2,因此人类参与对话时,检索性能会比测试的数据更好

总结和复现

  • ChatIR系统是一个基于对话的图像检索系统,包含对话构建和图像搜索两个部分
  • 对话构建:使用问题生成器G,考虑当前对话历史,生成下一个问题
    • 原文测试了不同的问题生成器,其中ChatGPT表现最好
  • 图像搜索:使用模型F,将不同长度的对话序列映射到视觉嵌入空间
    • 我使用了论文仓库提供的预训练BLIP_ITM模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import os.path
from PIL import Image
from tqdm import tqdm
import json

import multiprocessing # 多进程处理
multiprocessing.set_start_method('spawn', force=True)

from OpenAI import request_chat # 自己实现的调用API函数 request_chat(dialog:list) -> new_question:str

import sys # 添加import路径
sys.path.insert(0, './BLIP')
from BLIP.models.blip_itm import blip_itm # BLIP用于图像文本匹配的预训练模型

config = {
"corpus_path": "VisualDial/search_space.json", # 图像库路径
"queries_path": "dialogues/ChatGPT4oMini_BLIP2.json", # 测试对话数据路径
"corpus_cache": "VisualDial/corpus_cache.pth", # 处理好的图像库缓存的路径
"device": "cuda" if torch.cuda.is_available() else "cpu",
"sep": ", ", # 对话分隔符
"batch_size": 100, # 批处理大小
"num_workers": 8, # 多进程处理数
"image_size" : 224 # 图像大小
}
corpus = None # 处理好的图像库
dialog = [] # 对话
images = [] # 图像库路径 list[图像路径]

图像数据集类

继承自Dataset,用于加载图像数据集

corpus_path:图像数据集.json,包含一个list[str],每个元素是一个图像的路径

建立路径字符串到索引的映射,便于赋值传递和查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Corpus(Dataset):
"""图像数据集"""
def __init__(self, corpus_path, preprocessor):
"""加载图像数据集

Args:
corpus_path: 图片路径列表
preprocessor: 图像预处理函数
"""
with open(corpus_path, "r") as f:
self.corpus = json.load(f)
f.close()
self.preprocessor = preprocessor
# 图片路径到索引的映射,用于快速查找
self.path2idx = {self.corpus[i]:i for i in range(len(self.corpus))}

def __len__(self):
return len(self.corpus)

def __getitem__(self, idx):
image = self.preprocessor(self.corpus[idx])
return {"idx": idx, "image": image}

def path_to_index(self, path):
return self.path2idx[path]

图像预处理函数

1
2
3
4
5
6
7
8
9
10
11
def image_preprocessor(image_path):
transform_prep = transforms.Compose([
transforms.Resize((config["image_size"], config["image_size"]),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)) # 参数参考BLIP的demo
])
raw = Image.open(image_path).convert("RGB")
img = transform_prep(raw)
return img

BLIP_ITM模型的图像编码器和对话编码器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def get_funcs():
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
model = blip_itm(pretrained='chatir_weights.ckpt', # 论文仓库中的预训练模型
med_config="BLIP/configs/med_config.json",
image_size=config["image_size"],
vit="base")
device = config["device"]
model = model.to(device).eval()

def image_encoder(img):
embeddings = model.visual_encoder(img) # embedding
# print(embeddings.shape) # (批次大小, patch个数+1, 隐层维度)
vision_proj = model.vision_proj(embeddings[:, 0, :]) # 取[CLS] token,提取全局特征
return F.normalize(vision_proj, dim=-1) # 正则化

def dialog_encoder(dialog):
text = model.tokenizer(dialog, padding='longest', truncation=True, max_length=200, # 填充到最长,截断到200
return_tensors="pt").to(device) # 返回PyTorch张量
text_out = model.text_encoder(text.input_ids, attention_mask=text.attention_mask,
return_dict=True, mode='text') # embedding
shift = model.text_proj(text_out.last_hidden_state[:, 0, :]) # 同
return F.normalize(shift, dim=-1)

return image_encoder, dialog_encoder

处理图像库

由于加载时间长,一次加载后将数据存储在本地,便于二次调用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
corpus_dataset = Corpus(config["corpus_path"], image_preprocessor)

def prepare_corpus(image_encoder):
"""处理图像库

Returns:
corpus: tuple[torch.Tensor, torch.Tensor] 图像库索引和对应的图像特征向量
"""
global corpus
corpus_cache = config["corpus_cache"]
if corpus_cache and os.path.exists(corpus_cache): # 读取缓存
print(f"-----Loading corpus from {corpus_cache}-----")
corpus = torch.load(corpus_cache)
return

print("-----Preparing corpus-----")
corpus_dataloader = DataLoader( # 图像库的DataLoader
corpus_dataset,
batch_size=config["batch_size"],
shuffle=False,
num_workers=config["num_workers"],
pin_memory=True,
drop_last=False
)
corpus_vectors = []
corpus_ids = []
for batch in tqdm(corpus_dataloader): # 预处理图像库
batch_vectors = F.normalize(image_encoder(batch["image"].to(config["device"])), dim=-1) # 正则化
corpus_vectors.append(batch_vectors)
corpus_ids.append(batch["idx"].to(config["device"]))

corpus_vectors = torch.cat(corpus_vectors)
corpus_ids = torch.cat(corpus_ids)

# 按照索引排序
arg_ids = torch.argsort(corpus_ids)
corpus_vectors = corpus_vectors[arg_ids]
corpus_ids = corpus_ids[arg_ids]

corpus = corpus_ids, corpus_vectors
if config["corpus_cache"]:
torch.save(corpus, config["corpus_cache"])

提问与匹配函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def ask_for_caption():
"""询问用户描述"""
caption = input("Describe the image: ")
return caption

def ask_question():
"""询问用户问题,获取回答"""
question = request_chat(dialog)
answer = input(f"Q: {question}\nA: ")
return question+' '+answer

def get_top_results(dialog, dialog_encoder, n=1):
"""获取前n最佳匹配结果

Args:
dialog: str 对话
dialog_encoder: 对话编码器

Returns:
tops: list[int] 前n个匹配结果的索引
topscores: list[float] 前n个匹配结果的得分
"""
dialog = config["sep"].join(dialog)
dialog_vector = dialog_encoder(dialog) # 提取特征&正则化
scores = dialog_vector @ corpus[1].T # 计算点积相似度
top_id = torch.argsort(scores, descending=True) # 排序
tops = top_id.tolist()[0][:n]
topscores = scores[0][tops].tolist()
return tops, topscores

主函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
if __name__ == "__main__":
image_encoder, dialog_encoder = get_funcs()
with open(config["corpus_path"], "r") as f:
images = json.load(f)
f.close()
with torch.no_grad():
prepare_corpus(image_encoder)

dialog.append(ask_for_caption())
tops,topscores = get_top_results(dialog, dialog_encoder)
best_image = images[tops[0]]
best_score = topscores[0]
print(f"Best image: {best_image}")
print(f"Best score: {best_score}")
# display(Image.open(best_image))

for i in range(10):
dialog.append(ask_question())
tops,topscores = get_top_results(dialog, dialog_encoder)
best_image = images[tops[0]]
best_score = topscores[0]
print(f"Best image: {best_image}")
print(f"Best score: {best_score}")
if i==1:
display(Image.open(best_image))

测试和评估

原文评估ChatIR性能的指标是Hit@10,即目标图像出现在最匹配的10个候选图像中的概率,我采用相同的评估方式

对话数据集类

单个测试对话数据结构:

1
2
3
4
{
"image": "image_path",
"dialog": ["caption", "question1? answer1", "question2? answer2", ...]
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Queries(Dataset):
"""对话-图像数据集"""
def __init__(self, queries_path, sep):
"""加载对话-图像数据集

Args:
queries_path: str 查询数据集路径
sep: str 分隔符
"""
with open(queries_path, "r") as f:
self.queries = json.load(f)
f.close()
self.dialog_length = None
self.sep = sep

def __len__(self):
return len(self.queries)

def __getitem__(self, idx):
assert self.dialog_length is not None
target_path = self.queries[idx]["img"]
# 保留对话的前dialog_length轮
text = self.sep.join(self.queries[idx]["dialog"][:self.dialog_length + 1])
return {"text": text, "target_path": target_path}

测试代码

使用了论文仓库提供的测试代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
corpus_dataset = Corpus(config["corpus_path"], image_preprocessor)
query_dataset = Queries(config["queries_path"], config["sep"])

def _get_recalls(dataloader, dialog_length, dialog_encoder):
""" 计算dataloader中长度为dialog_length的对话的召回结果

Args:
dataloader: 数据加载器
dialog_length: 对话长度
"""
dataloader.dataset.dialog_length = dialog_length # 设置对话长度
recalls = [] # 每个对话的召回结果
for batch in tqdm(dataloader):
target_ids = torch.tensor(
[corpus_dataset.path_to_index(p) for p in batch['target_path']]
).unsqueeze(1).to(config['device']) # 图片路径转换为索引
pred_vec = F.normalize(dialog_encoder(batch['text']), dim=-1)

scores = pred_vec @ corpus[1].T # 计算点积,得到相似度分数
arg_ranks = torch.argsort(scores, descending=True, dim=1).long() # 对分数进行排序

target_recall = ((arg_ranks - target_ids) == 0).nonzero()[:, 1] # 目标图像在检索排名中出现的位置
recalls.append(target_recall)

return torch.cat(recalls)

def get_first_hitting_time(target_recall, hitting_recall=10):
""" 返回(11, n)张量,其中包含每轮(0, 11)的命中时间。inf表示未命中(10轮后没有命中) """
target_recalls = target_recall.view(11, -1).T # 转置
hits = (target_recalls < hitting_recall) # 目标图像是否在前 hitting_recall 轮内出现

final_hits = torch.inf * torch.ones(target_recalls.shape[0]) # 初始化为inf

hitting_times = [] # 每轮的命中时间
for ro_i in range(11):
rh = hits[:, ro_i]
final_hits[rh] = torch.min(final_hits[rh], torch.ones(final_hits[rh].shape) * ro_i)
hitting_times.append(final_hits.clone())

return torch.stack(hitting_times)


def cumulative_hits_per_round(target_recall, hitting_recall=10):
""" 返回直到第x轮的平均命中次数 """
if type(hitting_recall) is tuple:
assert len(hitting_recall) == 1
hitting_recall = hitting_recall[0]
ht_times = get_first_hitting_time(target_recall, hitting_recall)
return ((ht_times < torch.inf).sum(dim=-1) * 100 / ht_times[0].shape[0])

def eval(image_encoder, dialog_encoder, hits_at=10):
prepare_corpus(image_encoder)
query_dataloader = torch.utils.data.DataLoader(query_dataset, # 询问数据集
batch_size=config['batch_size'], # 批大小
shuffle=False, # 不打乱
num_workers=config['num_workers'], # 多线程
pin_memory=True, # 锁页内存
drop_last=False
)
hits_results = []
for dl in range(11): # 对话长度从0到10
print(f"Calculate recalls for each dialogues of length {dl}...")
dialog_recalls = _get_recalls(query_dataloader, dialog_length=dl, dialog_encoder=dialog_encoder)
hits_results.append(dialog_recalls)

# 使用cumulative_hits_per_round计算最终的Hits@10结果
# Hits@10:`在预测的前 10 个候选项中,包含了正确答案`的比例
hits_results = cumulative_hits_per_round(torch.cat(hits_results).cpu(), hitting_recall=10).tolist()
print("====== Results for Hits@10 ====== ")
for dl in range(11):
print(f"\t Dialog Length: {dl}: {round(hits_results[dl], 2)}%")

if __name__ == "__main__":
image_encoder, dialog_encoder = get_funcs()
with open(config["corpus_path"], "r") as f:
images = json.load(f)
f.close()
with torch.no_grad():
eval(image_encoder, dialog_encoder)

对话数据生成

原文中,为了自动获取测试数据,免除人工回答,使用了BLIP2模型回答问题

我参考原文的方法,使用BLIP2生成对话数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from PIL import Image
import json
from tqdm import tqdm

from lavis.models import load_model_and_preprocess # BLIP2

from OpenAI import request_chat

import sys
st,ed = int(sys.argv[1]), int(sys.argv[2])

with open("../dialogues/ChatGPT_BLIP2.json", "r") as f:
data = json.load(f)
f.close()

images = [d["img"] for d in data][st:ed]
captions = [d["dialog"][0] for d in data][st:ed]

config = {
"device": "cuda" if torch.cuda.is_available() else "cpu",
"image_size" : 224,
"sep": ", "
}

if config["device"] == "cuda":
print("Using GPU")

def load_image(image_path):
raw_image = Image.open(image_path).convert('RGB')
image = vis_processors["eval"](raw_image).unsqueeze(0).to(config["device"])
return image

def visual_qa(image, question, model):
answer = model.generate({"image": image, "prompt": f"Question: {question} Answer:"})
return answer[0]

def ask_question(dialog):
"""询问用户问题"""
question = request_chat(dialog)
return question

def generate_dialog(idx, model):
image = load_image(images[idx])
caption = captions[idx]
dialog = [caption]
for i in range(10):
question = ask_question(dialog)
answer = visual_qa(image, question, model)
dialog.append(question+answer)
ret = {
"img": images[idx],
"dialog": dialog
}
return ret

if __name__ == "__main__":
model, vis_processors, _ = load_model_and_preprocess(
name="blip2_opt", model_type="caption_coco_opt6.7b", is_eval=True, device=config["device"]
)
with torch.no_grad():
with open("ChatGPT4oMini_BLIP2.txt", "a") as f:
for idx in tqdm(range(len(images))):
dialog = generate_dialog(idx, model)
json.dump(dialog, f)
f.write("\n")
f.close()

测试日志

  1. 由于OpenAI的token限制,我使用了讯飞星火API作为提问模型G
  2. 论文没读仔细,一开始以为生成对话数据时,代替人类回答的模型和编码用的模型一样,是BLIP
  3. 在以上基础上,测试的Hit@10结果(40%~60%)与原文(63%~80%)有较大差距
  4. 原本认为是提问模型G没有用ChatGPT的原因,微氪token,调整为ChatGPT4o-mini,但结果依然不理想(40%~67%)
  5. 注意到,40%是$D_0$,也就是只有用BLIP生成的第一句描述时的准确率,和提问模型G无关
  6. 对比原文,发现回答模型应该是另一篇paper中的BLIP2,而不是BLIP
  7. 修改数据生成代码,使用BLIP2生成对话数据,测试结果与原文接近(60%~80%)

客观问题:由于硬件条件有限,跑出一条数据需要3~5分钟,因此只测试了1021条数据,对于整体性能评估可能不够准确

测试结果

在使用ChatGPT作为提问模型G,BLIP2作为回答模型A,使用同一测试代码的情况下,测试结果和原文对比如下:

length 原文(2064 testcases) 复现(1021 testcases)
0 63.42% 62.98%
1 69.43% 70.62%
2 72.38% 72.67%
3 74.47% 74.53%
4 76.02% 75.32%
5 77.47% 75.42%
6 78.49% 76.00%
7 79.65% 76.20%
8 80.09% 76.69%
9 80.43% 76.98%
10 80.77% 77.18%

在第五轮对话后,对话长度增加对检索性能的提升不再明显