1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
| import os import cv2 import numpy as np from sklearn.cluster import KMeans from scipy.spatial import distance from tqdm import tqdm from time import time import pickle
# Set LOKY_MAX_CPU_COUNT for joblib/loky (the process backend used by
# scikit-learn). Presumably this caps or silences loky's CPU-count
# detection during KMeans; confirm against the loky docs.
os.environ.update(LOKY_MAX_CPU_COUNT="4")

# Module-level state shared by the pipeline functions below.
descriptors_list = list()  # SIFT descriptor arrays, one entry per image
histograms = dict()        # image path -> bag-of-features histogram
sift = cv2.SIFT_create()   # one SIFT detector reused for every image
def get_images(image_folder, extensions=(".jpg",)):
    """Recursively collect image file paths under ``image_folder``.

    Args:
        image_folder: Root directory to walk. A missing directory yields an
            empty list, because ``os.walk`` produces nothing for it.
        extensions: File suffix or tuple of suffixes to accept, compared
            case-insensitively. The default ``(".jpg",)`` keeps the original
            behaviour. Pass e.g. ``(".jpg", ".jpeg", ".png")`` to widen it.

    Returns:
        A list of ``os.path.join(root, filename)`` paths in a deterministic,
        sorted traversal order.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    images = []
    for root, dirs, files in os.walk(image_folder):
        # os.walk yields entries in filesystem order, which differs between
        # machines. Sorting `dirs` in place (honoured by the top-down walk)
        # and sorting `files` makes the order reproducible. That order feeds
        # the K-Means input and the on-disk caches.
        dirs.sort()
        for filename in sorted(files):
            if filename.lower().endswith(suffixes):
                images.append(os.path.join(root, filename))
    return images
def extract_sift_feature(img_path):
    """Run the shared SIFT detector on one image, loaded as grayscale.

    Returns:
        ``(keypoints, descriptors)`` from ``sift.detectAndCompute``, or
        ``(None, None)`` when OpenCV cannot read ``img_path``.
    """
    gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if gray is not None:
        kps, desc = sift.detectAndCompute(gray, None)
        return kps, desc
    return None, None
def extract_sift_features(image_paths, cache="cache/BOF_SIFT_desc.pkl"):
    """Extract SIFT descriptors for every image, using a pickle cache on disk.

    If ``cache`` exists, its contents are loaded and returned unchanged and
    ``image_paths`` is ignored. The cache is not keyed on the image set, so
    delete it whenever the images change. ``pickle.load`` can execute
    arbitrary code, so only point ``cache`` at files you created yourself.

    Args:
        image_paths: Iterable of image file paths.
        cache: Path of the pickle cache. Its parent directory is created if
            needed.

    Returns:
        A new list of descriptor arrays, one for each image that was readable
        and produced descriptors.
    """
    print("1.提取SIFT特征")

    if os.path.exists(cache):
        with open(cache, "rb") as f:
            print(f"从缓存中加载SIFT特征: {cache}")
            return pickle.load(f)

    # Accumulate into a local list so repeated calls neither share state
    # through a module global nor pile up duplicate descriptors.
    descriptors = []
    for img_path in tqdm(image_paths):
        _, desc = extract_sift_feature(img_path)
        if desc is not None:
            descriptors.append(desc)

    # open(..., "wb") fails if the cache directory does not exist yet.
    cache_dir = os.path.dirname(cache)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache, "wb") as f:
        print(f"保存SIFT特征到缓存: {cache}")
        pickle.dump(descriptors, f)

    return descriptors
def create_vocabulary(descriptors_list, K, cache="cache/BOF_kmeans.pkl"):
    """Cluster SIFT descriptors into a K-word visual vocabulary with K-Means.

    A cached model is reused only if its ``n_clusters`` equals ``K``. A model
    fitted for a different vocabulary size would produce word indices that do
    not match the K-bin histograms, so it is rebuilt instead. ``pickle.load``
    can execute arbitrary code, so only point ``cache`` at trusted files.

    Args:
        descriptors_list: Sequence of 2-D descriptor arrays. They are stacked
            row-wise with ``np.vstack`` before clustering.
        K: Number of clusters (visual words).
        cache: Path of the pickle cache. Its parent directory is created if
            needed.

    Returns:
        The fitted ``KMeans`` model (``random_state=0``, ``n_init=10``).

    Raises:
        ValueError: If ``descriptors_list`` is empty and no matching cache
            exists.
    """
    print("2.创建视觉词典")

    if os.path.exists(cache):
        with open(cache, "rb") as f:
            print(f"从缓存中加载K-Means模型: {cache}")
            cached = pickle.load(f)
        # Objects without n_clusters are trusted as-is.
        cached_k = getattr(cached, "n_clusters", K)
        if cached_k == K:
            return cached
        print(f"cached model has n_clusters={cached_k}, expected {K}; rebuilding")

    if len(descriptors_list) == 0:
        raise ValueError("descriptors_list is empty: no SIFT descriptors to cluster")

    st = time()
    all_descriptors = np.vstack(descriptors_list)
    kmeans = KMeans(n_clusters=K, random_state=0, n_init=10)
    kmeans.fit(all_descriptors)
    print(f"K-Means聚类耗时: {time()-st:.2f}s")

    # open(..., "wb") fails if the cache directory does not exist yet.
    cache_dir = os.path.dirname(cache)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache, "wb") as f:
        print(f"保存K-Means模型到缓存: {cache}")
        pickle.dump(kmeans, f)

    return kmeans
def compute_histogram(img_path, kmeans, K):
    """Compute the L2-normalised bag-of-features histogram of one image.

    Args:
        img_path: Path of the image to describe.
        kmeans: Fitted clustering model. Its ``predict`` maps each SIFT
            descriptor to a visual-word index, presumably in ``[0, K)``.
        K: Vocabulary size (number of histogram bins).

    Returns:
        A float ndarray of length ``K`` with unit L2 norm. Returns ``None``
        when the image cannot be read, yields no SIFT descriptors, or has no
        words inside the K bins.
    """
    # Reuse the shared extractor. It returns (None, None) for unreadable
    # files, whereas passing a None image to detectAndCompute raises.
    _, descriptors = extract_sift_feature(img_path)
    if descriptors is None:
        return None

    words = kmeans.predict(descriptors)
    hist, _ = np.histogram(words, bins=np.arange(K + 1))
    norm = np.linalg.norm(hist)
    if norm == 0:
        # Only possible if no predicted word fell inside the K bins (e.g. a
        # model fitted with a different K). Avoid a 0/0 NaN histogram.
        return None
    return hist / norm
def compute_histograms(image_paths, kmeans, K, cache="cache/BOF_histogram.pkl"):
    """Compute BoF histograms for every database image, with a pickle cache.

    If ``cache`` exists, its contents are loaded and returned unchanged and
    ``image_paths`` is ignored. The cache is not keyed on the image set or on
    ``K``, so delete it when either changes. ``pickle.load`` can execute
    arbitrary code, so only point ``cache`` at trusted files.

    Args:
        image_paths: Iterable of database image paths.
        kmeans: Fitted vocabulary model passed through to
            ``compute_histogram``.
        K: Vocabulary size (histogram length).
        cache: Path of the pickle cache. Its parent directory is created if
            needed.

    Returns:
        A new dict mapping image path to histogram, for every image whose
        histogram could be computed.
    """
    print("3.计算数据库的BoF直方图")

    if os.path.exists(cache):
        with open(cache, "rb") as f:
            print(f"从缓存中加载数据库直方图: {cache}")
            return pickle.load(f)

    # Local dict: repeated calls must not write into shared module state.
    database = {}
    for img_path in tqdm(image_paths):
        hist = compute_histogram(img_path, kmeans, K)
        if hist is not None:
            database[img_path] = hist

    # open(..., "wb") fails if the cache directory does not exist yet.
    cache_dir = os.path.dirname(cache)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache, "wb") as f:
        print(f"保存数据库直方图到缓存: {cache}")
        pickle.dump(database, f)

    return database
def build_inverted_index(histograms):
    """Map every visual-word index to the images in which it occurs.

    Args:
        histograms: Mapping of image path to histogram (an indexable sequence
            of per-word weights).

    Returns:
        A dict of word index -> list of ``(image path, weight)`` pairs. Only
        bins whose weight is ``> 0`` are included, and pairs are listed in
        ``histograms`` iteration order.
    """
    print("4.构造倒排索引")
    index = {}
    for path, weights in tqdm(histograms.items()):
        for word, weight in enumerate(weights):
            # Written as `not weight > 0` so NaN weights are skipped.
            if not weight > 0:
                continue
            index.setdefault(word, []).append((path, weight))
    return index
def match_image(query_hist, histograms):
    """Return the database image whose histogram is closest to ``query_hist``.

    Closeness is Euclidean distance (``scipy.spatial.distance.euclidean``).
    On ties, the entry seen first in ``histograms`` iteration order wins.
    Returns ``None`` when ``histograms`` is empty.
    """
    print("5.直方图匹配")
    nearest, nearest_dist = None, float('inf')
    for candidate, candidate_hist in tqdm(histograms.items()):
        gap = distance.euclidean(query_hist, candidate_hist)
        if gap < nearest_dist:
            nearest, nearest_dist = candidate, gap
    return nearest
# Vocabulary size: number of K-Means clusters, which is also the histogram
# length. It must be identical for every stage of the pipeline.
VOCAB_SIZE = 50

# Build (or load from cache) the database side of the pipeline.
image_folder = "image"
image_paths = get_images(image_folder)
descriptors_list = extract_sift_features(image_paths)
kmeans = create_vocabulary(descriptors_list, K=VOCAB_SIZE)
histograms = compute_histograms(image_paths, kmeans, K=VOCAB_SIZE)
inverted_index = build_inverted_index(histograms)

# Query: describe the query image on its own and do NOT insert it into the
# database dict. If it were inserted, it would always be its own nearest
# neighbour (distance 0).
test_img = "query/A0C573_20151029074136_6562078379.jpg"
test_hist = compute_histogram(test_img, kmeans, K=VOCAB_SIZE)
matched_img = None
if test_hist is None:
    print(f"无法从查询图片提取SIFT特征: {test_img}")
else:
    matched_img = match_image(test_hist, histograms)
    print(f"{test_img} 匹配到的最相似图片是 {matched_img}")
|