# Reverse Image Search (Local Only)

import os

import cv2

import numpy as np

import pickle

from PIL import Image, ImageTk

import tkinter as tk

from tkinter import ttk, filedialog, messagebox

from pathlib import Path

from math import ceil


# ----------------------

# Config

# ----------------------

# Lower-case file suffixes that are treated as images when scanning a folder.
SUPPORTED_EXT = (
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".tiff",
)

# Pickle file (path relative to the current working directory) holding cached
# descriptors; optional, it only speeds up re-indexing.
CACHE_FILE = "image_features_cache.pkl"

# (width, height) bounding box for thumbnails displayed in the GUI.
THUMB_SIZE = (200, 150)


# ----------------------

# Feature detector factory (SIFT preferred, fallback to ORB)

# ----------------------

def make_feature_detector():
    """Create the keypoint detector / descriptor extractor to use.

    SIFT is preferred: ``cv2.SIFT_create`` is tried first (in the main OpenCV
    package since 4.4), then ``cv2.xfeatures2d.SIFT_create`` (older
    opencv-contrib builds). If neither is present or both fail, an ORB
    detector limited to 1500 features is used instead.

    Returns:
        tuple: ``(name, detector)`` where ``name`` is ``"SIFT"`` or ``"ORB"``.
    """
    sift_factories = []
    if hasattr(cv2, "SIFT_create"):
        sift_factories.append(cv2.SIFT_create)
    xfeatures2d = getattr(cv2, "xfeatures2d", None)
    if xfeatures2d is not None and hasattr(xfeatures2d, "SIFT_create"):
        sift_factories.append(xfeatures2d.SIFT_create)

    for factory in sift_factories:
        try:
            sift = factory()
        except cv2.error:
            # e.g. a contrib build compiled without the non-free algorithms
            continue
        print("Using SIFT detector")
        return ("SIFT", sift)

    # fallback to ORB (binary descriptors, matched with Hamming distance)
    orb = cv2.ORB_create(nfeatures=1500)
    print("SIFT not available — falling back to ORB")
    return ("ORB", orb)


# ----------------------

# Matcher factory

# ----------------------

def make_matcher(detector_name):
    """Return a descriptor matcher suited to *detector_name*.

    ``"SIFT"`` gets a FLANN matcher backed by a KD-tree index (5 trees,
    50 checks) for float descriptors. Any other name gets a brute-force
    matcher using Hamming distance, as required by ORB's binary descriptors.
    """
    if detector_name != "SIFT":
        # binary descriptors (ORB): brute force with the Hamming norm
        return cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)

    kdtree_algorithm = 1  # FLANN_INDEX_KDTREE
    return cv2.FlannBasedMatcher(
        dict(algorithm=kdtree_algorithm, trees=5),
        dict(checks=50),
    )


# ----------------------

# Compute descriptors for one image

# ----------------------

def compute_descriptors(img_path, detector_tuple, max_dim=1024):
    """Detect keypoints and compute descriptors for one image file.

    The file is read as raw bytes and decoded with ``cv2.imdecode``, so
    paths with non-ASCII characters also work. The image is loaded in
    grayscale. If its longer side exceeds *max_dim*, it is downscaled
    (keeping the aspect ratio) to speed up detection.

    Args:
        img_path: path (str or Path) of the image file.
        detector_tuple: ``(name, detector)`` as returned by
            ``make_feature_detector``.
        max_dim: longest allowed image side in pixels before downscaling;
            ``None`` disables resizing.

    Returns:
        tuple: ``(keypoints, descriptors)``. When no features are found,
        descriptors is an empty ``(0, D)`` array. For SIFT it is ``float32``
        with D=128; otherwise (ORB) it is ``uint8`` with D=32.

    Raises:
        ValueError: if the file is empty or cannot be decoded as an image.
        OSError: if the file cannot be read.
    """
    detector_name, detector = detector_tuple

    raw = np.fromfile(str(img_path), dtype=np.uint8)
    # cv2.imdecode raises (instead of returning None) on an empty buffer
    img = cv2.imdecode(raw, cv2.IMREAD_GRAYSCALE) if raw.size else None
    if img is None:
        raise ValueError(f"Failed to read image: {img_path}")

    h, w = img.shape[:2]
    if max_dim is not None and max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        # clamp to 1 px: very thin images would otherwise get a 0-size target
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

    kp, des = detector.detectAndCompute(img, None)

    if des is None:
        # No features: return an empty array with the detector's descriptor
        # width and dtype so downstream code always sees a 2-D array.
        if detector_name == "SIFT":
            des = np.empty((0, 128), dtype=np.float32)
        else:
            des = np.empty((0, 32), dtype=np.uint8)

    return kp, des


# ----------------------

# Indexing folder of images

# ----------------------

def index_folder(folder_path, detector_tuple, cache_enabled=True):
    """Compute descriptors for every supported image directly inside a folder.

    Only regular files whose lower-cased suffix is in ``SUPPORTED_EXT`` are
    indexed; subfolders are not scanned. When *cache_enabled* is true,
    descriptors are reused from and written back to ``CACHE_FILE``. A cached
    entry is reused only if it came from the same detector and the file's
    modification time and size are unchanged.

    Args:
        folder_path: folder to scan (str or Path).
        detector_tuple: ``(name, detector)`` as returned by
            ``make_feature_detector``.
        cache_enabled: whether to read and update the descriptor cache.

    Returns:
        list[dict]: one record per image, in sorted path order, of the form
        ``{"path": Path, "kp": keypoints, "des": descriptors}``. ``"kp"`` is
        ``None`` for cache hits (keypoints are not cached) and ``[]`` for
        images that failed to load.

    Raises:
        ValueError: if *folder_path* is not an existing directory.
    """
    folder = Path(folder_path)
    if not folder.is_dir():
        raise ValueError("Folder path invalid")

    detector_name = detector_tuple[0]
    cache = _load_feature_cache() if cache_enabled else {}

    records = []
    for p in sorted(folder.iterdir()):
        if p.suffix.lower() not in SUPPORTED_EXT or not p.is_file():
            continue
        key = str(p.resolve())
        st = p.stat()
        # a changed file (new mtime or size) must not reuse old descriptors
        stamp = (st.st_mtime_ns, st.st_size)

        cached = cache.get(key)
        if (
            isinstance(cached, dict)
            and cached.get("detector") == detector_name
            and cached.get("stamp") == stamp
        ):
            print("Cache hit:", p.name)
            records.append({"path": p, "kp": None, "des": cached["descriptors"]})
            continue

        try:
            kp, des = compute_descriptors(p, detector_tuple)
        except Exception as e:  # one unreadable file must not abort indexing
            print("Failed to compute:", p, e)
            records.append({"path": p, "kp": [], "des": np.array([])})
            # don't cache the failure, so a later run retries this file
            cache.pop(key, None)
            continue

        records.append({"path": p, "kp": kp, "des": des})
        cache[key] = {"detector": detector_name, "descriptors": des, "stamp": stamp}

    if cache_enabled:
        _save_feature_cache(cache)

    return records


def _load_feature_cache():
    """Load the descriptor cache dict from ``CACHE_FILE``; ``{}`` on any failure."""
    # NOTE(security): pickle.load can execute arbitrary code. Only load a cache
    # file this program wrote itself, never one from an untrusted source.
    if not os.path.exists(CACHE_FILE):
        return {}
    try:
        with open(CACHE_FILE, "rb") as f:
            cache = pickle.load(f)
    except Exception:
        return {}
    return cache if isinstance(cache, dict) else {}


def _save_feature_cache(cache):
    """Write *cache* to ``CACHE_FILE`` (best effort; failures are only printed)."""
    try:
        with open(CACHE_FILE, "wb") as f:
            pickle.dump(cache, f)
    except Exception as e:
        print("Could not write cache:", e)


# ----------------------

# Matching logic

# ----------------------

def match_descriptors(query_des, target_des, matcher, detector_name, ratio_thresh=0.75):
    """Count good matches between two descriptor sets using Lowe's ratio test.

    For each query descriptor, the two nearest target descriptors are found
    (``knnMatch`` with k=2). The nearest one is kept when its distance is
    below ``ratio_thresh`` times the second-nearest distance. This works for
    float (SIFT) and binary (ORB) descriptors alike.

    If *matcher* raises an OpenCV error (e.g. FLANN with too few target
    descriptors), a brute-force matcher with the norm appropriate to
    *detector_name* is used instead. The same ratio test is then applied,
    so scores stay comparable across images.

    Args:
        query_des, target_des: descriptor arrays, one row per descriptor.
        matcher: an OpenCV-style matcher exposing ``knnMatch``.
        detector_name: ``"SIFT"`` (float descriptors, L2 norm) or anything
            else (binary descriptors, Hamming norm).
        ratio_thresh: Lowe ratio threshold.

    Returns:
        tuple: ``(count, good_matches)``. It is ``(0, [])`` when either set
        is empty or matching fails entirely.
    """
    if query_des is None or target_des is None or len(query_des) == 0 or len(target_des) == 0:
        return 0, []

    if detector_name == "SIFT":
        # FLANN's KD-tree index requires float32 descriptors
        query_des = np.asarray(query_des, dtype=np.float32)
        target_des = np.asarray(target_des, dtype=np.float32)

    try:
        knn = matcher.knnMatch(query_des, target_des, k=2)
    except cv2.error:
        # Brute-force fallback with the correct norm (Hamming for binary
        # descriptors). The ratio test is still applied: counting every raw
        # match as "good" would inflate the score of images whose matching failed.
        norm = cv2.NORM_L2 if detector_name == "SIFT" else cv2.NORM_HAMMING
        try:
            knn = cv2.BFMatcher(norm).knnMatch(query_des, target_des, k=2)
        except cv2.error:
            return 0, []

    good = _lowe_ratio_filter(knn, ratio_thresh)
    return len(good), good


def _lowe_ratio_filter(knn_matches, ratio_thresh):
    """Return the best match of every k-NN pair that passes the ratio test."""
    good = []
    for pair in knn_matches:
        if len(pair) < 2:
            # the ratio test needs two neighbours
            continue
        best, second = pair[0], pair[1]
        if best.distance < ratio_thresh * second.distance:
            good.append(best)
    return good


# ----------------------

# Query: find top K similar

# ----------------------

def find_similar(query_path, records, detector_tuple, top_k=6):
    """Rank indexed images by similarity to a query image.

    Each record's descriptors are matched against the query's. The score is
    the number of good matches divided by ``sqrt(n_query * n_target)``, with
    each descriptor count floored at 1. This keeps images that merely have
    many descriptors from dominating the ranking.

    Args:
        query_path: path of the query image.
        records: records as produced by ``index_folder`` (keys ``"path"`` and
            ``"des"``); they should come from the same detector as
            *detector_tuple*.
        detector_tuple: ``(name, detector)`` as returned by
            ``make_feature_detector``.
        top_k: maximum number of results; values below 1 yield ``[]``.

    Returns:
        list[dict]: up to *top_k* dicts with keys ``"path"``, ``"matches"``,
        ``"score"`` and ``"good"``, sorted by descending score (ties keep
        record order).

    Raises:
        Whatever ``compute_descriptors`` raises for an unreadable query image
        (e.g. ``ValueError``, ``OSError``).
    """
    detector_name = detector_tuple[0]
    matcher = make_matcher(detector_name)

    _, qdes = compute_descriptors(query_path, detector_tuple)
    n_query = max(1, len(qdes))

    results = []
    for rec in records:
        tdes = rec["des"]
        count, good_matches = match_descriptors(qdes, tdes, matcher, detector_name)
        n_target = max(1, 0 if tdes is None else len(tdes))
        # n_query * n_target >= 1, so the square root never drops below 1
        score = count / np.sqrt(n_query * n_target)
        results.append({"path": rec["path"], "matches": count, "score": score, "good": good_matches})

    # list.sort is stable: equal scores keep their original record order
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[:max(0, top_k)]


# ----------------------

# Utilities: load thumbnail for Tkinter display

# ----------------------

def pil_image_from_path(p):
    """Load an image file as an RGB ``PIL.Image``.

    The bytes are read with ``np.fromfile`` and decoded with ``cv2.imdecode``,
    so non-ASCII paths are handled.

    Args:
        p: path of the image (str or Path).

    Returns:
        PIL.Image.Image or None: ``None`` if the file cannot be read (missing,
        no permission, a directory), is empty, or is not a decodable image.
    """
    try:
        arr = np.fromfile(str(p), dtype=np.uint8)
    except OSError:
        # e.g. the file was moved or deleted after the folder was indexed
        return None
    if arr.size == 0:
        # cv2.imdecode raises instead of returning None on an empty buffer
        return None
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if img is None:
        return None
    # OpenCV decodes to BGR channel order; PIL expects RGB
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


# ----------------------

# Tkinter GUI

# ----------------------

class ReverseImageSearchGUI:
    """Tkinter front-end: choose a folder, index it, then search with a query image.

    Holds the detector picked at start-up, the current index (``records``),
    the chosen folder and the most recent query path.
    """

    def __init__(self, master):
        self.master = master
        master.title("Reverse Image Search — Local (SIFT/ORB)")
        master.geometry("1000x700")

        self.detector_tuple = make_feature_detector()
        self.records = []
        self.indexed_folder = None
        self.query_path = None  # set by choose_query

        # Top controls
        top = ttk.Frame(master)
        top.pack(side=tk.TOP, fill=tk.X, padx=8, pady=8)

        ttk.Button(top, text="Choose Images Folder", command=self.choose_folder).pack(side=tk.LEFT, padx=4)
        self.folder_label = ttk.Label(top, text="No folder chosen")
        self.folder_label.pack(side=tk.LEFT, padx=6)
        ttk.Button(top, text="Index Folder", command=self.index_folder).pack(side=tk.LEFT, padx=6)
        ttk.Button(top, text="Choose Query Image", command=self.choose_query).pack(side=tk.LEFT, padx=6)
        ttk.Label(top, text="Top K:").pack(side=tk.LEFT, padx=(10, 0))
        self.topk_var = tk.IntVar(value=6)
        ttk.Entry(top, textvariable=self.topk_var, width=4).pack(side=tk.LEFT)

        # Query preview + results area
        mid = ttk.Frame(master)
        mid.pack(fill=tk.BOTH, expand=True, padx=8, pady=6)

        left = ttk.Frame(mid, width=300)
        left.pack(side=tk.LEFT, fill=tk.Y)
        ttk.Label(left, text="Query Image:").pack(anchor="w")
        # width/height count characters while only text is shown; choose_query
        # resizes the label in pixels once a thumbnail is displayed
        self.query_canvas = tk.Label(left, text="No query selected", width=40, height=12, bg="#ddd")
        self.query_canvas.pack(padx=6, pady=6)

        ttk.Button(left, text="Clear Cache", command=self.clear_cache).pack(pady=6)
        ttk.Button(left, text="Re-index", command=self.reindex).pack(pady=6)

        right = ttk.Frame(mid)
        right.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
        ttk.Label(right, text="Top Matches:").pack(anchor="w")
        self.results_frame = ttk.Frame(right)
        self.results_frame.pack(fill=tk.BOTH, expand=True)

        # status
        self.status_var = tk.StringVar(value="Ready")
        ttk.Label(master, textvariable=self.status_var).pack(side=tk.BOTTOM, fill=tk.X)

    def choose_folder(self):
        """Ask for the image folder; choosing a different folder drops the old index."""
        folder = filedialog.askdirectory()
        if not folder:
            return
        if folder != self.indexed_folder:
            # records from another folder must not answer queries for this one
            self.records = []
        self.indexed_folder = folder
        self.folder_label.config(text=folder)
        self.status_var.set(f"Selected folder: {folder}")

    def index_folder(self):
        """Index the chosen folder (using the cache) and report the result."""
        if not self.indexed_folder:
            messagebox.showwarning("Pick folder", "Choose images folder first")
            return
        self.status_var.set("Indexing folder (computing descriptors)...")
        self.master.update_idletasks()  # show the status before the blocking work
        try:
            self.records = index_folder(self.indexed_folder, self.detector_tuple, cache_enabled=True)
            self.status_var.set(f"Indexed {len(self.records)} images")
            messagebox.showinfo("Indexed", f"Indexed {len(self.records)} images.")
        except Exception as e:  # GUI boundary: report instead of crashing the callback
            messagebox.showerror("Indexing failed", str(e))
            self.status_var.set("Indexing failed")

    def reindex(self):
        """Delete the descriptor cache file, then index the chosen folder again."""
        if not self.indexed_folder:
            messagebox.showwarning("Pick folder", "Choose images folder first")
            return
        try:
            if os.path.exists(CACHE_FILE):
                os.remove(CACHE_FILE)
        except OSError as e:
            # best effort: indexing still runs, it may just reuse cached entries
            print("Could not remove cache:", e)
        self.index_folder()

    def choose_query(self):
        """Pick a query image, preview it, and search (indexing first if needed)."""
        q = filedialog.askopenfilename(filetypes=[("Images", "*.jpg *.jpeg *.png *.bmp *.tiff")])
        if not q:
            return
        self.query_path = Path(q)
        pil = pil_image_from_path(q)
        if pil is None:
            messagebox.showerror("Error", "Could not load image")
            return
        thumb = pil.copy()
        thumb.thumbnail(THUMB_SIZE)
        tkimg = ImageTk.PhotoImage(thumb)
        self.query_canvas.image = tkimg  # keep a reference so the image isn't garbage-collected
        # With an image shown, Label width/height are pixels, not characters;
        # keeping 40x12 would clip the thumbnail to a 40x12 px strip.
        self.query_canvas.config(image=tkimg, text="", width=THUMB_SIZE[0], height=THUMB_SIZE[1])
        self.status_var.set(f"Query: {os.path.basename(q)}")

        # Run search if indexed
        if not self.records:
            if not messagebox.askyesno("Not indexed", "Folder not indexed yet. Index now?"):
                return
            self.index_folder()
            if not self.records:
                # no folder chosen, indexing failed, or the folder has no images
                return
        self.search_query(q)

    def search_query(self, qpath):
        """Search the index for *qpath* and show the top matches in a grid."""
        self.status_var.set("Searching for similar images...")
        self.master.update_idletasks()
        try:
            topk = max(1, int(self.topk_var.get()))
        except (tk.TclError, ValueError):
            # the Top K entry does not hold a number
            topk = 6
        try:
            results = find_similar(qpath, self.records, self.detector_tuple, top_k=topk)
        except Exception as e:  # GUI boundary: report instead of crashing the callback
            messagebox.showerror("Search failed", str(e))
            self.status_var.set("Search failed")
            return

        # Clear previous results
        for child in self.results_frame.winfo_children():
            child.destroy()

        # Display results in a grid, skipping images that can no longer be loaded
        cols = min(3, topk)
        shown = 0
        for res in results:
            path = res["path"]
            pil = pil_image_from_path(path)
            if pil is None:
                continue
            thumb = pil.copy()
            thumb.thumbnail(THUMB_SIZE)
            tkimg = ImageTk.PhotoImage(thumb)
            row, col = divmod(shown, cols)
            panel = ttk.Frame(self.results_frame, relief=tk.RIDGE, borderwidth=1)
            panel.grid(row=row, column=col, padx=6, pady=6, sticky="nsew")
            lbl = tk.Label(panel, image=tkimg)
            lbl.image = tkimg  # keep a reference so the image isn't garbage-collected
            lbl.pack()
            info = ttk.Label(
                panel,
                text=f"{path.name}\nScore:{res['score']:.3f}\nMatches:{res['matches']}",
                anchor="center",
            )
            info.pack()
            # click to open the full image in the default viewer
            lbl.bind("<Button-1>", lambda _event, p=path: self._open_externally(p))
            shown += 1

        self.status_var.set("Search complete")

    @staticmethod
    def _viewer_command(path):
        """Return the argv that opens *path* in the default viewer, or None on Windows."""
        import sys

        if os.name == "nt":
            return None  # Windows uses os.startfile instead
        if sys.platform == "darwin":
            return ["open", str(path)]
        return ["xdg-open", str(path)]

    def _open_externally(self, path):
        """Open *path* with the system's default image viewer."""
        import subprocess

        argv = self._viewer_command(path)
        try:
            if argv is None:
                os.startfile(str(path))
            else:
                # argument list without a shell: file names cannot inject commands
                subprocess.Popen(argv)
        except OSError as e:
            messagebox.showerror("Error", f"Could not open {path}: {e}")

    def clear_cache(self):
        """Delete the descriptor cache file, reporting the outcome in a dialog."""
        if os.path.exists(CACHE_FILE):
            try:
                os.remove(CACHE_FILE)
                messagebox.showinfo("Cache", "Cache file removed")
                self.status_var.set("Cache cleared")
            except Exception as e:
                messagebox.showerror("Error", f"Could not remove cache: {e}")
        else:
            messagebox.showinfo("Cache", "No cache file present")


# ----------------------

# Run

# ----------------------

def main():
    """Create the Tk root window, attach the search GUI, and run the event loop."""
    window = tk.Tk()
    gui = ReverseImageSearchGUI(window)  # referenced for the whole lifetime of the loop
    window.mainloop()


if __name__ == "__main__":
    main()


# (end of script)