Install requirements
pip install tensorflow Pillow opencv-python numpy sympy
Train a digit model once (MNIST) — train_mnist_cnn.py
This trains a small CNN on MNIST and saves it as mnist_cnn.h5.
# train_mnist_cnn.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
def build_model():
    """Build and compile the small CNN used to classify MNIST digits.

    The network takes 28x28 single-channel images and stacks two
    Conv-Conv-MaxPool-Dropout stages (32 filters, then 64) followed by a
    128-unit dense layer with dropout and a 10-way softmax output.

    Returns:
        The compiled ``keras.Sequential`` model (Adam optimizer, sparse
        categorical cross-entropy loss, accuracy metric).
    """
    model = keras.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation='relu'),
        layers.Conv2D(32, 3, activation='relu'),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
        layers.Conv2D(64, 3, activation='relu'),
        layers.Conv2D(64, 3, activation='relu'),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(10, activation='softmax'),
    ])
    # Labels are passed as plain integers (see model.fit in main), hence the
    # "sparse" variant of categorical cross-entropy.
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
def main():
    """Train the CNN on MNIST, print its test accuracy, and save it.

    The trained model is written to ``mnist_cnn.h5`` in the current working
    directory.
    """
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Scale pixel values to [0, 1] and append a channel axis so the arrays
    # match the model's (28, 28, 1) input shape.
    x_train = (x_train.astype("float32") / 255.0)[..., None]
    x_test = (x_test.astype("float32") / 255.0)[..., None]

    model = build_model()
    model.fit(x_train, y_train, batch_size=128, epochs=5, validation_split=0.1)

    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
    print(f"Test accuracy: {test_acc:.4f}")

    model.save("mnist_cnn.h5")
    print("Saved model to mnist_cnn.h5")
# Run the training only when this file is executed as a script.
if __name__ == "__main__":
    main()
The GUI solver — handwritten_math_solver.py
# handwritten_math_solver.py
import tkinter as tk
from tkinter import messagebox
from PIL import Image, ImageDraw, ImageOps
import numpy as np
import cv2
import io
from sympy import sympify, simplify
from tensorflow.keras.models import load_model
# ---- Config ----
# File the trained digit model is loaded from.
MODEL_PATH = "mnist_cnn.h5"
# Side length of the square drawing canvas, in pixels.
CANVAS_SIZE = 400
# Brush thickness in pixels (thicker strokes are easier to recognize).
DRAW_WIDTH = 14
# Bounding boxes with a smaller area than this are discarded as noise.
MIN_CONTOUR_AREA = 60
# Border added around each glyph before it is resized for the 28x28 input.
PADDING = 8

# ---- Operator heuristic thresholds ----
# A glyph whose width/height ratio is at least this is a '-' candidate.
MINUS_AR_THRESH = 2.0
# ...and its height must be at most this fraction of the median glyph height.
MINUS_HEIGHT_FRAC = 0.45
# Minimum fill fraction of the central row and central column for a '+'.
PLUS_PEAK_FRAC = 0.6
class MathSolverApp:
def __init__(self, root):
self.root = root
self.root.title("Handwritten Math Solver (digits + +/−)")
# Canvas to draw
self.canvas = tk.Canvas(root, width=CANVAS_SIZE, height=CANVAS_SIZE, bg="white", cursor="cross")
self.canvas.grid(row=0, column=0, columnspan=3, padx=10, pady=10)
# PIL image to accumulate strokes (black on white)
self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), color=255)
self.draw = ImageDraw.Draw(self.image)
# Bind drawing
self.last_x, self.last_y = None, None
self.canvas.bind("<ButtonPress-1>", self.pen_down)
self.canvas.bind("<B1-Motion>", self.paint)
self.canvas.bind("<ButtonRelease-1>", self.pen_up)
# Buttons
tk.Button(root, text="Recognize & Solve", command=self.recognize_and_solve).grid(row=1, column=0, pady=6)
tk.Button(root, text="Clear", command=self.clear_canvas).grid(row=1, column=1, pady=6)
tk.Button(root, text="Quit", command=root.quit).grid(row=1, column=2, pady=6)
# Output
self.expr_var = tk.StringVar(value="Expression: ")
self.result_var = tk.StringVar(value="Result: ")
self.step_text = tk.Text(root, width=60, height=10, wrap="word")
tk.Label(root, textvariable=self.expr_var, anchor="w").grid(row=2, column=0, columnspan=3, sticky="w", padx=10)
tk.Label(root, textvariable=self.result_var, anchor="w").grid(row=3, column=0, columnspan=3, sticky="w", padx=10)
tk.Label(root, text="Steps:").grid(row=4, column=0, sticky="w", padx=10)
self.step_text.grid(row=5, column=0, columnspan=3, padx=10, pady=4)
# Load model
try:
self.model = load_model(MODEL_PATH)
except Exception as e:
messagebox.showerror("Model Error",
f"Could not load {MODEL_PATH}.\nTrain it first with train_mnist_cnn.py.\n\n{e}")
self.model = None
# ---------- Drawing handlers ----------
def pen_down(self, event):
self.last_x, self.last_y = event.x, event.y
def paint(self, event):
if self.last_x is not None and self.last_y is not None:
# Draw on Tk canvas
self.canvas.create_line(self.last_x, self.last_y, event.x, event.y,
width=DRAW_WIDTH, fill="black", capstyle=tk.ROUND, smooth=True)
# Draw on PIL image
self.draw.line([self.last_x, self.last_y, event.x, event.y],
fill=0, width=DRAW_WIDTH)
self.last_x, self.last_y = event.x, event.y
def pen_up(self, event):
self.last_x, self.last_y = None, None
def clear_canvas(self):
self.canvas.delete("all")
self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), color=255)
self.draw = ImageDraw.Draw(self.image)
self.expr_var.set("Expression: ")
self.result_var.set("Result: ")
self.step_text.delete("1.0", tk.END)
# ---------- Core pipeline ----------
def recognize_and_solve(self):
if self.model is None:
messagebox.showwarning("Model", "Model not loaded.")
return
# Convert PIL to OpenCV
img = np.array(self.image)
expr, tokens_dbg = self.image_to_expression(img)
if not expr:
messagebox.showwarning("Parse", "Could not parse any symbols. Try writing bigger/cleaner.")
return
self.expr_var.set(f"Expression: {expr}")
try:
# Use sympy to evaluate
sym_expr = sympify(expr)
simplified = simplify(sym_expr)
self.result_var.set(f"Result: {simplified}")
# Show steps (simple for now)
self.step_text.delete("1.0", tk.END)
self.step_text.insert(tk.END, "Tokens (left→right):\n")
self.step_text.insert(tk.END, " ".join(tokens_dbg) + "\n\n")
self.step_text.insert(tk.END, f"SymPy parsed: {sym_expr}\n")
if str(sym_expr) != str(simplified):
self.step_text.insert(tk.END, f"Simplified: {simplified}\n")
else:
self.step_text.insert(tk.END, "No further simplification needed.\n")
except Exception as e:
messagebox.showerror("Evaluation Error", f"Failed to evaluate expression:\n{e}")
def image_to_expression(self, gray_img: np.ndarray) -> tuple[str, list]:
"""
Segment symbols, classify digits with CNN, infer + / - with projection heuristics.
Returns (expression_string, debug_tokens)
"""
# 1) Binarize & clean
# Invert: handwriting is black (0), background white (255) => for OpenCV we want white-on-black for morphology ops.
inv = 255 - gray_img
# Threshold
_, th = cv2.threshold(inv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Morph open small noise
kernel = np.ones((3,3), np.uint8)
th = cv2.morphologyEx(th, cv2.MORPH_OPEN, kernel, iterations=1)
# 2) Find contours (symbols)
contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = []
for cnt in contours:
x, y, w, h = cv2.boundingRect(cnt)
area = w * h
if area < MIN_CONTOUR_AREA:
continue
boxes.append((x, y, w, h))
if not boxes:
return "", []
# Sort left-to-right
boxes.sort(key=lambda b: b[0])
# Median height (helps operator heuristics)
med_h = np.median([h for (_, _, _, h) in boxes])
tokens = []
debug_tokens = []
for (x, y, w, h) in boxes:
crop = th[y:y+h, x:x+w] # white ink on black background
# Operator heuristic first (minus / plus)
op = self.classify_operator(crop, w, h, med_h)
if op is not None:
tokens.append(op)
debug_tokens.append(f"[{op}]")
continue
# Otherwise, digit classification
digit = self.classify_digit(crop)
if digit is None:
# If not digit and not recognized operator, skip (or treat as minus attempt)
# Safer to skip
continue
tokens.append(str(digit))
debug_tokens.append(str(digit))
# Merge digits & operators into expression string
expr = self.tokens_to_expression(tokens)
return expr, debug_tokens
def classify_digit(self, crop_bin: np.ndarray) -> int | None:
"""
Prepare glyph for MNIST CNN (28x28, centered), and predict 0-9.
crop_bin: white ink on black background (binary)
"""
# Make sure it's binary (0/255)
crop = (crop_bin > 0).astype(np.uint8) * 255
# Add padding
crop = cv2.copyMakeBorder(crop, PADDING, PADDING, PADDING, PADDING, cv2.BORDER_CONSTANT, value=0)
# Find tight box again after pad
ys, xs = np.where(crop > 0)
if len(xs) == 0 or len(ys) == 0:
return None
x0, x1 = xs.min(), xs.max()
y0, y1 = ys.min(), ys.max()
crop = crop[y0:y1+1, x0:x1+1]
# Resize to 20x20 then center in 28x28 (like MNIST preprocessing)
h, w = crop.shape
if h > w:
new_h = 20
new_w = int(w * (20.0 / h))
else:
new_w = 20
new_h = int(h * (20.0 / w))
if new_h <= 0: new_h = 1
if new_w <= 0: new_w = 1
resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
canvas = np.zeros((28, 28), dtype=np.uint8)
y_off = (28 - new_h) // 2
x_off = (28 - new_w) // 2
canvas[y_off:y_off+new_h, x_off:x_off+new_w] = resized
# Normalize for model: MNIST is black background (0) with white strokes (1)
img = canvas.astype("float32") / 255.0
img = img[..., None] # (28,28,1)
pred = self.model.predict(img[None, ...], verbose=0)[0]
cls = int(np.argmax(pred))
conf = float(np.max(pred))
# Optional confidence filtering
if conf < 0.40:
return None
return cls
def classify_operator(self, crop_bin: np.ndarray, w: int, h: int, med_h: float) -> str | None:
"""
Very lightweight heuristics:
- '-' : wide, short, one thick horizontal stroke (width/height large, height << median digit height)
- '+' : strong central vertical and horizontal projections (peaks)
"""
# Work on binary with 1s where stroke is present
b = (crop_bin > 0).astype(np.uint8)
# Aspect ratio heuristic for '-'
if h > 0:
ar = w / float(h)
else:
ar = 0
# height relative to median digit height
h_frac = h / float(med_h) if med_h > 0 else 1.0
# Horizontal projection profile (sum along columns) and vertical profile (sum along rows)
vproj = b.sum(axis=0) # per column
hproj = b.sum(axis=1) # per row
v_center_peak = vproj[len(vproj)//2] / (b.shape[0] + 1e-6)
h_center_peak = hproj[len(hproj)//2] / (b.shape[1] + 1e-6)
# Minus: flat, wide, short
if ar >= MINUS_AR_THRESH and h_frac <= MINUS_HEIGHT_FRAC:
return "-"
# Plus: vertical & horizontal strong central strokes
if v_center_peak >= PLUS_PEAK_FRAC and h_center_peak >= PLUS_PEAK_FRAC:
return "+"
return None
def tokens_to_expression(self, tokens: list[str]) -> str:
"""
Combine tokens into a valid expression.
- Collapse consecutive digits into multi-digit numbers.
- Keep '+' and '-' as operators.
- Remove illegal leading/trailing operators.
"""
# Collapse digits
out = []
num_buf = []
for t in tokens:
if t.isdigit():
num_buf.append(t)
else:
# flush number
if num_buf:
out.append("".join(num_buf))
num_buf = []
# operator allowed only if last is number
if len(out) > 0 and out[-1][-1].isdigit() and t in {"+", "-"}:
out.append(t)
# flush at end
if num_buf:
out.append("".join(num_buf))
# Join safely
expr = ""
for item in out:
if item in {"+", "-"}:
expr += f" {item} "
else:
expr += item
return expr.strip()
# Launch the GUI only when this file is executed as a script.
if __name__ == "__main__":
    root = tk.Tk()
    app = MathSolverApp(root)
    root.mainloop()