Handwritten Math Solver

 Install requirements

pip install tensorflow Pillow opencv-python numpy sympy

Train a digit model once (MNIST) — train_mnist_cnn.py

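This script trains a small CNN on MNIST and saves it as mnist_cnn.h5 in the current directory. The GUI solver below loads that file at startup, so run this script once before launching the solver.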

# train_mnist_cnn.py

import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import layers


def build_model():
    """Build and compile the digit-classification CNN.

    The network accepts 28x28 single-channel images and consists of two
    convolutional stages (32 filters, then 64). Each stage is two 3x3 ReLU
    convolutions followed by max pooling and 25% dropout. The head is a
    128-unit ReLU dense layer with 50% dropout and a 10-way softmax output.

    Returns:
        The compiled ``keras.Sequential`` model (Adam optimizer, sparse
        categorical cross-entropy loss, accuracy metric).
    """
    stack = [layers.Input(shape=(28, 28, 1))]

    # Both convolutional stages share the same layout; only the width differs.
    for n_filters in (32, 64):
        stack.extend([
            layers.Conv2D(n_filters, 3, activation='relu'),
            layers.Conv2D(n_filters, 3, activation='relu'),
            layers.MaxPooling2D(),
            layers.Dropout(0.25),
        ])

    # Classification head.
    stack.extend([
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(10, activation='softmax'),
    ])

    net = keras.Sequential(stack)
    net.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return net


def main():
    """Train the CNN on MNIST, print the test accuracy and save the model.

    Images are converted to float32, divided by 255 and given a trailing
    channel axis. Training runs for 5 epochs with batch size 128, holding
    out 10% of the training set for validation. The trained model is written
    to ``mnist_cnn.h5`` in the current working directory.
    """
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

    # Scale the pixels and append the channel axis in a single expression.
    train_images = (train_images.astype("float32") / 255.0)[..., None]
    test_images = (test_images.astype("float32") / 255.0)[..., None]

    net = build_model()
    net.fit(train_images, train_labels, batch_size=128, epochs=5, validation_split=0.1)

    # evaluate() returns (loss, accuracy); only the accuracy is reported.
    _, test_acc = net.evaluate(test_images, test_labels, verbose=0)
    print(f"Test accuracy: {test_acc:.4f}")

    net.save("mnist_cnn.h5")
    print("Saved model to mnist_cnn.h5")


# Run the training only when this file is executed as a script, not on import.
if __name__ == "__main__":

    main()

The GUI solver — handwritten_math_solver.py

# handwritten_math_solver.py
import tkinter as tk
from tkinter import messagebox
from PIL import Image, ImageDraw, ImageOps
import numpy as np
import cv2
import io
from sympy import sympify, simplify
from tensorflow.keras.models import load_model
# ---- Config ----
# Keras model file loaded at startup (written by train_mnist_cnn.py).
MODEL_PATH = "mnist_cnn.h5"
CANVAS_SIZE = 400           # side length of the square drawing canvas and of the backing PIL image
DRAW_WIDTH = 14             # brush thickness (thicker = easier OCR)
MIN_CONTOUR_AREA = 60       # bounding boxes with w*h below this are discarded as noise
PADDING = 8                 # border added around each glyph crop in classify_digit
# Heuristics thresholds for operators
MINUS_AR_THRESH = 2.0       # bounding-box width/height at or above this → '-' candidate
MINUS_HEIGHT_FRAC = 0.45    # glyph height / median glyph height at or below this → '-' candidate
PLUS_PEAK_FRAC = 0.6        # minimum ink fraction of the centre column AND centre row for '+'
class MathSolverApp:
    def __init__(self, root):
        self.root = root
        self.root.title("Handwritten Math Solver (digits + +/−)")
        # Canvas to draw
        self.canvas = tk.Canvas(root, width=CANVAS_SIZE, height=CANVAS_SIZE, bg="white", cursor="cross")
        self.canvas.grid(row=0, column=0, columnspan=3, padx=10, pady=10)
        # PIL image to accumulate strokes (black on white)
        self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), color=255)
        self.draw = ImageDraw.Draw(self.image)
        # Bind drawing
        self.last_x, self.last_y = None, None
        self.canvas.bind("<ButtonPress-1>", self.pen_down)
        self.canvas.bind("<B1-Motion>", self.paint)
        self.canvas.bind("<ButtonRelease-1>", self.pen_up)
        # Buttons
        tk.Button(root, text="Recognize & Solve", command=self.recognize_and_solve).grid(row=1, column=0, pady=6)
        tk.Button(root, text="Clear", command=self.clear_canvas).grid(row=1, column=1, pady=6)
        tk.Button(root, text="Quit", command=root.quit).grid(row=1, column=2, pady=6)
        # Output
        self.expr_var = tk.StringVar(value="Expression: ")
        self.result_var = tk.StringVar(value="Result: ")
        self.step_text = tk.Text(root, width=60, height=10, wrap="word")
        tk.Label(root, textvariable=self.expr_var, anchor="w").grid(row=2, column=0, columnspan=3, sticky="w", padx=10)
        tk.Label(root, textvariable=self.result_var, anchor="w").grid(row=3, column=0, columnspan=3, sticky="w", padx=10)
        tk.Label(root, text="Steps:").grid(row=4, column=0, sticky="w", padx=10)
        self.step_text.grid(row=5, column=0, columnspan=3, padx=10, pady=4)
        # Load model
        try:
            self.model = load_model(MODEL_PATH)
        except Exception as e:
            messagebox.showerror("Model Error",
                                 f"Could not load {MODEL_PATH}.\nTrain it first with train_mnist_cnn.py.\n\n{e}")
            self.model = None
    # ---------- Drawing handlers ----------
    def pen_down(self, event):
        self.last_x, self.last_y = event.x, event.y
    def paint(self, event):
        if self.last_x is not None and self.last_y is not None:
            # Draw on Tk canvas
            self.canvas.create_line(self.last_x, self.last_y, event.x, event.y,
                                    width=DRAW_WIDTH, fill="black", capstyle=tk.ROUND, smooth=True)
            # Draw on PIL image
            self.draw.line([self.last_x, self.last_y, event.x, event.y],
                           fill=0, width=DRAW_WIDTH)
        self.last_x, self.last_y = event.x, event.y
    def pen_up(self, event):
        self.last_x, self.last_y = None, None
    def clear_canvas(self):
        self.canvas.delete("all")
        self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), color=255)
        self.draw = ImageDraw.Draw(self.image)
        self.expr_var.set("Expression: ")
        self.result_var.set("Result: ")
        self.step_text.delete("1.0", tk.END)
    # ---------- Core pipeline ----------
    def recognize_and_solve(self):
        if self.model is None:
            messagebox.showwarning("Model", "Model not loaded.")
            return
        # Convert PIL to OpenCV
        img = np.array(self.image)
        expr, tokens_dbg = self.image_to_expression(img)
        if not expr:
            messagebox.showwarning("Parse", "Could not parse any symbols. Try writing bigger/cleaner.")
            return
        self.expr_var.set(f"Expression: {expr}")
        try:
            # Use sympy to evaluate
            sym_expr = sympify(expr)
            simplified = simplify(sym_expr)
            self.result_var.set(f"Result: {simplified}")
            # Show steps (simple for now)
            self.step_text.delete("1.0", tk.END)
            self.step_text.insert(tk.END, "Tokens (left→right):\n")
            self.step_text.insert(tk.END, " ".join(tokens_dbg) + "\n\n")
            self.step_text.insert(tk.END, f"SymPy parsed: {sym_expr}\n")
            if str(sym_expr) != str(simplified):
                self.step_text.insert(tk.END, f"Simplified: {simplified}\n")
            else:
                self.step_text.insert(tk.END, "No further simplification needed.\n")
        except Exception as e:
            messagebox.showerror("Evaluation Error", f"Failed to evaluate expression:\n{e}")
    def image_to_expression(self, gray_img: np.ndarray) -> tuple[str, list]:
        """
        Segment symbols, classify digits with CNN, infer + / - with projection heuristics.
        Returns (expression_string, debug_tokens)
        """
        # 1) Binarize & clean
        # Invert: handwriting is black (0), background white (255) => for OpenCV we want white-on-black for morphology ops.
        inv = 255 - gray_img
        # Threshold
        _, th = cv2.threshold(inv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Morph open small noise
        kernel = np.ones((3,3), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
        # 2) Find contours (symbols)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        boxes = []
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            area = w * h
            if area < MIN_CONTOUR_AREA:
                continue
            boxes.append((x, y, w, h))
        if not boxes:
            return "", []
        # Sort left-to-right
        boxes.sort(key=lambda b: b[0])
        # Median height (helps operator heuristics)
        med_h = np.median([h for (_, _, _, h) in boxes])
        tokens = []
        debug_tokens = []
        for (x, y, w, h) in boxes:
            crop = th[y:y+h, x:x+w]  # white ink on black background
            # Operator heuristic first (minus / plus)
            op = self.classify_operator(crop, w, h, med_h)
            if op is not None:
                tokens.append(op)
                debug_tokens.append(f"[{op}]")
                continue
            # Otherwise, digit classification
            digit = self.classify_digit(crop)
            if digit is None:
                # If not digit and not recognized operator, skip (or treat as minus attempt)
                # Safer to skip
                continue
            tokens.append(str(digit))
            debug_tokens.append(str(digit))
        # Merge digits & operators into expression string
        expr = self.tokens_to_expression(tokens)
        return expr, debug_tokens
    def classify_digit(self, crop_bin: np.ndarray) -> int | None:
        """
        Prepare glyph for MNIST CNN (28x28, centered), and predict 0-9.
        crop_bin: white ink on black background (binary)
        """
        # Make sure it's binary (0/255)
        crop = (crop_bin > 0).astype(np.uint8) * 255
        # Add padding
        crop = cv2.copyMakeBorder(crop, PADDING, PADDING, PADDING, PADDING, cv2.BORDER_CONSTANT, value=0)
        # Find tight box again after pad
        ys, xs = np.where(crop > 0)
        if len(xs) == 0 or len(ys) == 0:
            return None
        x0, x1 = xs.min(), xs.max()
        y0, y1 = ys.min(), ys.max()
        crop = crop[y0:y1+1, x0:x1+1]
        # Resize to 20x20 then center in 28x28 (like MNIST preprocessing)
        h, w = crop.shape
        if h > w:
            new_h = 20
            new_w = int(w * (20.0 / h))
        else:
            new_w = 20
            new_h = int(h * (20.0 / w))
        if new_h <= 0: new_h = 1
        if new_w <= 0: new_w = 1
        resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
        canvas = np.zeros((28, 28), dtype=np.uint8)
        y_off = (28 - new_h) // 2
        x_off = (28 - new_w) // 2
        canvas[y_off:y_off+new_h, x_off:x_off+new_w] = resized
        # Normalize for model: MNIST is black background (0) with white strokes (1)
        img = canvas.astype("float32") / 255.0
        img = img[..., None]  # (28,28,1)
        pred = self.model.predict(img[None, ...], verbose=0)[0]
        cls = int(np.argmax(pred))
        conf = float(np.max(pred))
        # Optional confidence filtering
        if conf < 0.40:
            return None
        return cls
    def classify_operator(self, crop_bin: np.ndarray, w: int, h: int, med_h: float) -> str | None:
        """
        Very lightweight heuristics:
        - '-' : wide, short, one thick horizontal stroke (width/height large, height << median digit height)
        - '+' : strong central vertical and horizontal projections (peaks)
        """
        # Work on binary with 1s where stroke is present
        b = (crop_bin > 0).astype(np.uint8)
        # Aspect ratio heuristic for '-'
        if h > 0:
            ar = w / float(h)
        else:
            ar = 0
        # height relative to median digit height
        h_frac = h / float(med_h) if med_h > 0 else 1.0
        # Horizontal projection profile (sum along columns) and vertical profile (sum along rows)
        vproj = b.sum(axis=0)  # per column
        hproj = b.sum(axis=1)  # per row
        v_center_peak = vproj[len(vproj)//2] / (b.shape[0] + 1e-6)
        h_center_peak = hproj[len(hproj)//2] / (b.shape[1] + 1e-6)
        # Minus: flat, wide, short
        if ar >= MINUS_AR_THRESH and h_frac <= MINUS_HEIGHT_FRAC:
            return "-"
        # Plus: vertical & horizontal strong central strokes
        if v_center_peak >= PLUS_PEAK_FRAC and h_center_peak >= PLUS_PEAK_FRAC:
            return "+"
        return None
    def tokens_to_expression(self, tokens: list[str]) -> str:
        """
        Combine tokens into a valid expression.
        - Collapse consecutive digits into multi-digit numbers.
        - Keep '+' and '-' as operators.
        - Remove illegal leading/trailing operators.
        """
        # Collapse digits
        out = []
        num_buf = []
        for t in tokens:
            if t.isdigit():
                num_buf.append(t)
            else:
                # flush number
                if num_buf:
                    out.append("".join(num_buf))
                    num_buf = []
                # operator allowed only if last is number
                if len(out) > 0 and out[-1][-1].isdigit() and t in {"+", "-"}:
                    out.append(t)
        # flush at end
        if num_buf:
            out.append("".join(num_buf))
        # Join safely
        expr = ""
        for item in out:
            if item in {"+", "-"}:
                expr += f" {item} "
            else:
                expr += item
        return expr.strip()
# Launch the GUI only when this file is run as a script, not on import.
if __name__ == "__main__":
    root = tk.Tk()
    app = MathSolverApp(root)
    # Blocks until the window is closed or the Quit button calls root.quit.
    root.mainloop()


No comments: