论文:Chatting Makes Perfect: Chat-based Image Retrival
1. Introduction ChatIR系统包含2个部分:对话构建(Dialog Building)和图像搜索(Image Search)
对话构建:使用问题生成器G,考虑当前对话历史,生成下一个问题
图像搜索:使用模型F,将不同长度的对话序列映射到视觉嵌入空间
两个组成部分建立在Instructional LLMs和fundation Vision and Language Models上
考虑三个问题:
使用什么数据集训练?是否需要新创建和标注数据集?
使用VisDial数据集
问题:VisDial是一个用于“创建关于图像的聊天”的数据集,没有检索目标
解决:输入输出反置,对话作为输入、图像作为输出
如何独立评估ChatIR系统的不同组件?
测试使用不同的F训练策略
和提问模型G
对检索性能的影响
使用了BLIP替代用户回答问题
如何定义评估指标?
视觉对话(Visual Conversation)领域
当前视觉领域工作的重点在于图像理解和生成模型,而不是检索
生成式视觉对话中,近期的基础模型V&L性能优越,因此ChatIR系统在此基础上构建
视觉搜索(Visual Search)领域
CoIR:使用多模态询问查找目标图像
一些研究基于CoIR,利用用户反馈细化查询结果
但是,没有考虑用户交互(只有用户反馈,没有机器提问)、没有明确利用对话历史
3. Method 3.1 Dialog Builder Model 对话生成模型 对话记录表示:$D_i = (C, Q_1, A_1, …, Q_i)$
$C$:目标图像的初始文本描述(标题)
$Q_i$:第i个问题
$A_i$:第i个回答
对话生成模型包含两个部分:
问题生成器G:一个LLM,根据对话记录生成下一个问题
$G: Di \rightarrow Q {i+1}$
G不知道目标图像T是什么,只知道对话历史
答案提供者A:在实践中,通常是一个脑海中有大致目标图像的人类
由于需要大规模实验,不能依赖用户提供答案
因此,使用了一个现成的模型BLIP2来回答
3.2 Image Retrieval Model 图像检索模型 图像搜索过程:在查询
和目标图像
共享的视觉嵌入空间中,搜索匹配项。
所有的目标图像先经过图像嵌入模块进行编码,由一个$d$维的特征向量$f\in \mathbb R^d$表示。
图像检索模块F将对话历史$D_i$映射到视觉嵌入空间,$F: D_i \rightarrow \mathbb R^d$,
候选对象根据相似度进行排序。
引入分隔符[SEP]和添加符[CLS],表示整个对话序列,投射到视觉嵌入空间。
F采用(使用BLIP)预训练的图像/文本编码器,并通过对比学习,对基于对话的检索进行微调。
通过提取VisDial数据集中的图像和相应对话,手动标注,训练F。
4. Evaluation 在评估环节,原文使用了Hit@10指标,即目标图像在前10个检索结果中的试验占比。
原文从三个方面进行了对比:
与现有文本到图像(Text to Image,TTI)检索方法的比较
ChatIR使用ChatGPT作为提问者G,BLIP2作为回答者A
与Zero-shot的BLIP、CLIP以及fine-tuned SoTA TTI BLIP进行比较
结论:ChatIR在多轮对话环境中,相比传统单跳 TTI 方法表现更优
不同提问者G的比较
使用ChatGPT、FLAN-ALPACA-XXL、人类等8中不同提问者
ChatGPT表现最好
人类参与对话的影响
ChatGPT提问,人类回答;ChatGPT提问,BLIP2回答;人类提问,人类回答
由于人类生成的答案质量明显优于BLIP2,因此人类参与对话时,检索性能会比测试的数据更好
总结和复现
ChatIR系统是一个基于对话的图像检索系统,包含对话构建和图像搜索两个部分
对话构建:使用问题生成器G,考虑当前对话历史,生成下一个问题
原文测试了不同的问题生成器,其中ChatGPT表现最好
图像搜索:使用模型F,将不同长度的对话序列映射到视觉嵌入空间
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 import torchimport torch.nn.functional as Ffrom torch.utils.data import Dataset, DataLoaderfrom torchvision import transformsfrom torchvision.transforms.functional import InterpolationModeimport os.pathfrom PIL import Imagefrom tqdm import tqdmimport jsonimport multiprocessing multiprocessing.set_start_method('spawn' , force=True ) from OpenAI import request_chat import sys sys.path.insert(0 , './BLIP' ) from BLIP.models.blip_itm import blip_itm config = { "corpus_path" : "VisualDial/search_space.json" , "queries_path" : "dialogues/ChatGPT4oMini_BLIP2.json" , "corpus_cache" : "VisualDial/corpus_cache.pth" , "device" : "cuda" if torch.cuda.is_available() else "cpu" , "sep" : ", " , "batch_size" : 100 , "num_workers" : 8 , "image_size" : 224 } corpus = None dialog = [] images = []
图像数据集类 继承自Dataset,用于加载图像数据集
corpus_path:图像数据集.json,包含一个list[str],每个元素是一个图像的路径
建立路径字符串到索引的映射,便于赋值传递和查询
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 class Corpus (Dataset ): """图像数据集""" def __init__ (self, corpus_path, preprocessor ): """加载图像数据集 Args: corpus_path: 图片路径列表 preprocessor: 图像预处理函数 """ with open (corpus_path, "r" ) as f: self.corpus = json.load(f) f.close() self.preprocessor = preprocessor self.path2idx = {self.corpus[i]:i for i in range (len (self.corpus))} def __len__ (self ): return len (self.corpus) def __getitem__ (self, idx ): image = self.preprocessor(self.corpus[idx]) return {"idx" : idx, "image" : image} def path_to_index (self, path ): return self.path2idx[path]
图像预处理函数 1 2 3 4 5 6 7 8 9 10 11 def image_preprocessor (image_path ): transform_prep = transforms.Compose([ transforms.Resize((config["image_size" ], config["image_size" ]), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466 , 0.4578275 , 0.40821073 ), (0.26862954 , 0.26130258 , 0.27577711 )) ]) raw = Image.open (image_path).convert("RGB" ) img = transform_prep(raw) return img
BLIP_ITM模型的图像编码器和对话编码器 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def get_funcs (): model = blip_itm(pretrained='chatir_weights.ckpt' , med_config="BLIP/configs/med_config.json" , image_size=config["image_size" ], vit="base" ) device = config["device" ] model = model.to(device).eval () def image_encoder (img ): embeddings = model.visual_encoder(img) vision_proj = model.vision_proj(embeddings[:, 0 , :]) return F.normalize(vision_proj, dim=-1 ) def dialog_encoder (dialog ): text = model.tokenizer(dialog, padding='longest' , truncation=True , max_length=200 , return_tensors="pt" ).to(device) text_out = model.text_encoder(text.input_ids, attention_mask=text.attention_mask, return_dict=True , mode='text' ) shift = model.text_proj(text_out.last_hidden_state[:, 0 , :]) return F.normalize(shift, dim=-1 ) return image_encoder, dialog_encoder
处理图像库 由于加载时间长,一次加载后将数据存储在本地,便于二次调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 corpus_dataset = Corpus(config["corpus_path" ], image_preprocessor) def prepare_corpus (image_encoder ): """处理图像库 Returns: corpus: tuple[torch.Tensor, torch.Tensor] 图像库索引和对应的图像特征向量 """ global corpus corpus_cache = config["corpus_cache" ] if corpus_cache and os.path.exists(corpus_cache): print (f"-----Loading corpus from {corpus_cache} -----" ) corpus = torch.load(corpus_cache) return print ("-----Preparing corpus-----" ) corpus_dataloader = DataLoader( corpus_dataset, batch_size=config["batch_size" ], shuffle=False , num_workers=config["num_workers" ], pin_memory=True , drop_last=False ) corpus_vectors = [] corpus_ids = [] for batch in tqdm(corpus_dataloader): batch_vectors = F.normalize(image_encoder(batch["image" ].to(config["device" ])), dim=-1 ) corpus_vectors.append(batch_vectors) corpus_ids.append(batch["idx" ].to(config["device" ])) corpus_vectors = torch.cat(corpus_vectors) corpus_ids = torch.cat(corpus_ids) arg_ids = torch.argsort(corpus_ids) corpus_vectors = corpus_vectors[arg_ids] corpus_ids = corpus_ids[arg_ids] corpus = corpus_ids, corpus_vectors if config["corpus_cache" ]: torch.save(corpus, config["corpus_cache" ])
提问与匹配函数 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 def ask_for_caption (): """询问用户描述""" caption = input ("Describe the image: " ) return caption def ask_question (): """询问用户问题,获取回答""" question = request_chat(dialog) answer = input (f"Q: {question} \nA: " ) return question+' ' +answer def get_top_results (dialog, dialog_encoder, n=1 ): """获取前n最佳匹配结果 Args: dialog: str 对话 dialog_encoder: 对话编码器 Returns: tops: list[int] 前n个匹配结果的索引 topscores: list[float] 前n个匹配结果的得分 """ dialog = config["sep" ].join(dialog) dialog_vector = dialog_encoder(dialog) scores = dialog_vector @ corpus[1 ].T top_id = torch.argsort(scores, descending=True ) tops = top_id.tolist()[0 ][:n] topscores = scores[0 ][tops].tolist() return tops, topscores
主函数 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 if __name__ == "__main__" : image_encoder, dialog_encoder = get_funcs() with open (config["corpus_path" ], "r" ) as f: images = json.load(f) f.close() with torch.no_grad(): prepare_corpus(image_encoder) dialog.append(ask_for_caption()) tops,topscores = get_top_results(dialog, dialog_encoder) best_image = images[tops[0 ]] best_score = topscores[0 ] print (f"Best image: {best_image} " ) print (f"Best score: {best_score} " ) for i in range (10 ): dialog.append(ask_question()) tops,topscores = get_top_results(dialog, dialog_encoder) best_image = images[tops[0 ]] best_score = topscores[0 ] print (f"Best image: {best_image} " ) print (f"Best score: {best_score} " ) if i==1 : display(Image.open (best_image))
测试和评估 原文评估ChatIR性能的指标是Hit@10,即目标图像出现在最匹配的10个候选图像中的概率,我采用相同的评估方式
对话数据集类 单个测试对话数据结构:
1 2 3 4 { "image" : "image_path" , "dialog" : [ "caption" , "question1? answer1" , "question2? answer2" , ...] }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 class Queries (Dataset ): """对话-图像数据集""" def __init__ (self, queries_path, sep ): """加载对话-图像数据集 Args: queries_path: str 查询数据集路径 sep: str 分隔符 """ with open (queries_path, "r" ) as f: self.queries = json.load(f) f.close() self.dialog_length = None self.sep = sep def __len__ (self ): return len (self.queries) def __getitem__ (self, idx ): assert self.dialog_length is not None target_path = self.queries[idx]["img" ] text = self.sep.join(self.queries[idx]["dialog" ][:self.dialog_length + 1 ]) return {"text" : text, "target_path" : target_path}
测试代码 使用了论文仓库提供的测试代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 corpus_dataset = Corpus(config["corpus_path" ], image_preprocessor) query_dataset = Queries(config["queries_path" ], config["sep" ]) def _get_recalls (dataloader, dialog_length, dialog_encoder ): """ 计算dataloader中长度为dialog_length的对话的召回结果 Args: dataloader: 数据加载器 dialog_length: 对话长度 """ dataloader.dataset.dialog_length = dialog_length recalls = [] for batch in tqdm(dataloader): target_ids = torch.tensor( [corpus_dataset.path_to_index(p) for p in batch['target_path' ]] ).unsqueeze(1 ).to(config['device' ]) pred_vec = F.normalize(dialog_encoder(batch['text' ]), dim=-1 ) scores = pred_vec @ corpus[1 ].T arg_ranks = torch.argsort(scores, descending=True , dim=1 ).long() target_recall = ((arg_ranks - target_ids) == 0 ).nonzero()[:, 1 ] recalls.append(target_recall) return torch.cat(recalls) def get_first_hitting_time (target_recall, hitting_recall=10 ): """ 返回(11, n)张量,其中包含每轮(0, 11)的命中时间。inf表示未命中(10轮后没有命中) """ target_recalls = target_recall.view(11 , -1 ).T hits = (target_recalls < hitting_recall) final_hits = torch.inf * torch.ones(target_recalls.shape[0 ]) hitting_times = [] for ro_i in range (11 ): rh = hits[:, ro_i] final_hits[rh] = torch.min (final_hits[rh], torch.ones(final_hits[rh].shape) * ro_i) hitting_times.append(final_hits.clone()) return torch.stack(hitting_times) def cumulative_hits_per_round (target_recall, hitting_recall=10 ): """ 返回直到第x轮的平均命中次数 """ if type (hitting_recall) is tuple : assert len (hitting_recall) == 1 hitting_recall = hitting_recall[0 ] ht_times = get_first_hitting_time(target_recall, hitting_recall) return ((ht_times < torch.inf).sum (dim=-1 ) * 100 / ht_times[0 ].shape[0 ]) def eval (image_encoder, dialog_encoder, hits_at=10 ): prepare_corpus(image_encoder) query_dataloader = torch.utils.data.DataLoader(query_dataset, batch_size=config['batch_size' ], shuffle=False , num_workers=config['num_workers' ], pin_memory=True , drop_last=False ) hits_results = [] for dl in range (11 ): print (f"Calculate recalls for each dialogues of length {dl} ..." ) dialog_recalls = _get_recalls(query_dataloader, dialog_length=dl, dialog_encoder=dialog_encoder) hits_results.append(dialog_recalls) hits_results = cumulative_hits_per_round(torch.cat(hits_results).cpu(), hitting_recall=10 ).tolist() print ("====== Results for Hits@10 ====== " ) for dl in range (11 ): print (f"\t Dialog Length: {dl} : {round (hits_results[dl], 2 )} %" ) if __name__ == "__main__" : image_encoder, dialog_encoder = get_funcs() with open (config["corpus_path" ], "r" ) as f: images = json.load(f) f.close() with torch.no_grad(): eval (image_encoder, dialog_encoder)
对话数据生成 原文中,为了自动获取测试数据,免除人工回答,使用了BLIP2模型回答问题
我参考原文的方法,使用BLIP2生成对话数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 import torchfrom PIL import Imageimport jsonfrom tqdm import tqdmfrom lavis.models import load_model_and_preprocess from OpenAI import request_chatimport sysst,ed = int (sys.argv[1 ]), int (sys.argv[2 ]) with open ("../dialogues/ChatGPT_BLIP2.json" , "r" ) as f: data = json.load(f) f.close() images = [d["img" ] for d in data][st:ed] captions = [d["dialog" ][0 ] for d in data][st:ed] config = { "device" : "cuda" if torch.cuda.is_available() else "cpu" , "image_size" : 224 , "sep" : ", " } if config["device" ] == "cuda" : print ("Using GPU" ) def load_image (image_path ): raw_image = Image.open (image_path).convert('RGB' ) image = vis_processors["eval" ](raw_image).unsqueeze(0 ).to(config["device" ]) return image def visual_qa (image, question, model ): answer = model.generate({"image" : image, "prompt" : f"Question: {question} Answer:" }) return answer[0 ] def ask_question (dialog ): """询问用户问题""" question = request_chat(dialog) return question def generate_dialog (idx, model ): image = load_image(images[idx]) caption = captions[idx] dialog = [caption] for i in range (10 ): question = ask_question(dialog) answer = visual_qa(image, question, model) dialog.append(question+answer) ret = { "img" : images[idx], "dialog" : dialog } return ret if __name__ == "__main__" : model, vis_processors, _ = load_model_and_preprocess( name="blip2_opt" , model_type="caption_coco_opt6.7b" , is_eval=True , device=config["device" ] ) with torch.no_grad(): with open ("ChatGPT4oMini_BLIP2.txt" , "a" ) as f: for idx in tqdm(range (len (images))): dialog = generate_dialog(idx, model) json.dump(dialog, f) f.write("\n" ) f.close()
测试日志
由于OpenAI的token限制,我使用了讯飞星火API作为提问模型G
论文没读仔细,一开始以为生成对话数据时,代替人类回答的模型和编码用的模型一样,是BLIP
在以上基础上,测试的Hit@10结果(40%~60%)与原文(63%~80%)有较大差距
原本认为是提问模型G没有用ChatGPT的原因,微氪token,调整为ChatGPT4o-mini,但结果依然不理想(40%~67%)
注意到,40%是$D_0$,也就是只有用BLIP生成的第一句描述时的准确率,和提问模型G无关
对比原文,发现回答模型应该是另一篇paper中的BLIP2,而不是BLIP
修改数据生成代码,使用BLIP2生成对话数据,测试结果与原文接近(60%~80%)
客观问题:由于硬件条件有限,跑出一条数据需要3~5分钟,因此只测试了1021条数据,对于整体性能评估可能不够准确
测试结果 在使用ChatGPT作为提问模型G,BLIP2作为回答模型A,使用同一测试代码的情况下,测试结果和原文对比如下:
length
原文(2064 testcases)
复现(1021 testcases)
0
63.42%
62.98%
1
69.43%
70.62%
2
72.38%
72.67%
3
74.47%
74.53%
4
76.02%
75.32%
5
77.47%
75.42%
6
78.49%
76.00%
7
79.65%
76.20%
8
80.09%
76.69%
9
80.43%
76.98%
10
80.77%
77.18%
在第五轮对话后,对话长度增加对检索性能的提升不再明显