import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
mean_squared_error, confusion_matrix
)
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.svm import SVC, SVR
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
# Streamlit app: upload a CSV, pick a target column and a model, train it on
# the numeric feature columns, and display evaluation metrics.
st.set_page_config(page_title="Local ML Model Trainer", layout="wide")
st.title("Local ML Model Trainer Interface")
st.write("Upload a dataset → choose an algorithm → train → view results")

# Estimator classes keyed by the label shown in the "Choose Model" selectbox.
# Dict order is the order in which the options are listed in the UI.
CLASSIFIERS = {
    "Logistic Regression": LogisticRegression,
    "Random Forest Classifier": RandomForestClassifier,
    "SVM Classifier": SVC,
    "KNN Classifier": KNeighborsClassifier,
}
REGRESSORS = {
    "Linear Regression": LinearRegression,
    "Random Forest Regressor": RandomForestRegressor,
    "SVM Regressor": SVR,
    "KNN Regressor": KNeighborsRegressor,
}

# ────────────────────────────────────────────────
# Upload Dataset
# ────────────────────────────────────────────────
uploaded_file = st.file_uploader("📤 Upload CSV Dataset", type=["csv"])
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
        # Show a readable message instead of a raw traceback for bad uploads.
        st.error(f"Could not read the uploaded CSV: {exc}")
        st.stop()

    st.success("Dataset Loaded Successfully!")
    st.write("### Data Preview")
    st.dataframe(df.head())
    st.write("### Dataset Info")
    st.write(df.describe())

    # Target column selection
    target_col = st.selectbox(" Select Target Column (Y)", df.columns)

    # Auto detect problem type: text targets, or targets with fewer than 15
    # distinct values, are treated as classification; everything else as
    # regression.
    if df[target_col].dtype == object or df[target_col].nunique() < 15:
        problem_type = "classification"
    else:
        problem_type = "regression"
    st.info(f"Detected Problem Type: **{problem_type.upper()}**")

    # Choose model based on problem type
    model_registry = CLASSIFIERS if problem_type == "classification" else REGRESSORS
    model_choice = st.selectbox("Choose Model", list(model_registry))

    # The slider value is passed to train_test_split(test_size=...), i.e. it
    # is the fraction of rows held out for evaluation, not the training share.
    test_size = st.slider(
        "Test Size (fraction of rows held out for testing)", 0.1, 0.5, 0.2
    )

    # Train button
    if st.button(" Train Model"):
        # ── Preprocessing ───────────────────────────────
        # Rows without a target value cannot be used for supervised training.
        labelled = df.dropna(subset=[target_col])
        y = labelled[target_col]

        # Only numeric columns are used as features; columns that are
        # entirely empty carry no information and cannot be imputed.
        all_features = labelled.drop(columns=[target_col])
        X = all_features.select_dtypes(include=np.number).dropna(axis=1, how="all")
        feature_names = list(X.columns)

        ignored = [col for col in all_features.columns if col not in feature_names]
        if ignored:
            st.warning(
                "Ignoring non-numeric or empty feature columns: "
                + ", ".join(map(str, ignored))
            )
        if not feature_names:
            st.error("No numeric feature columns are available for training.")
            st.stop()
        if len(X) < 2:
            st.error("At least two rows with a target value are required.")
            st.stop()

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=42
        )

        # Fill missing feature values with medians computed on the training
        # split only, so no information from the test split is used.
        if X.isna().any().any():
            st.warning("Missing feature values were filled with training-set medians.")
        medians = X_train.median().fillna(0.0)
        X_train = X_train.fillna(medians)
        X_test = X_test.fillna(medians)

        # Fit the scaler on the training split only; fitting on the full
        # dataset would leak test-set statistics into training.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        # ── Model selection and training ────────────────
        model = model_registry[model_choice]()
        try:
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
        except ValueError as exc:
            # scikit-learn reports unusable data (e.g. too few samples, an
            # unsupported target type) as ValueError.
            st.error(f"Training failed: {exc}")
            st.stop()
        st.success("Model Trained Successfully!")

        # ────────────────────────────────────────────────
        # Show Metrics
        # ────────────────────────────────────────────────
        st.write("## 📈 Model Performance")
        if problem_type == "classification":
            st.write("### 🔹 Classification Metrics")
            accuracy = accuracy_score(y_test, y_pred)
            # zero_division=0 scores a class that is never predicted (or never
            # present) as 0 instead of emitting an undefined-metric warning.
            precision = precision_score(
                y_test, y_pred, average="weighted", zero_division=0
            )
            recall = recall_score(
                y_test, y_pred, average="weighted", zero_division=0
            )
            f1 = f1_score(y_test, y_pred, average="weighted", zero_division=0)
            st.write(f"Accuracy: **{accuracy:.4f}**")
            st.write(f"Precision: **{precision:.4f}**")
            st.write(f"Recall: **{recall:.4f}**")
            st.write(f"F1 Score: **{f1:.4f}**")

            # Confusion Matrix: rows are true classes, columns are predicted
            # classes. The sorted union of observed and predicted labels is
            # passed explicitly so the tick labels line up with the matrix.
            class_labels = np.union1d(np.asarray(y_test), np.asarray(y_pred))
            cm = confusion_matrix(y_test, y_pred, labels=class_labels)
            fig, ax = plt.subplots(figsize=(5, 4))
            sns.heatmap(
                cm,
                annot=True,
                fmt="d",
                cmap="Blues",
                xticklabels=class_labels,
                yticklabels=class_labels,
                ax=ax,
            )
            ax.set_xlabel("Predicted")
            ax.set_ylabel("Actual")
            st.write("### Confusion Matrix")
            st.pyplot(fig)
            # Close the figure so it does not accumulate across reruns.
            plt.close(fig)
        else:
            st.write("### 🔹 Regression Metrics")
            rmse = np.sqrt(mean_squared_error(y_test, y_pred))
            st.write(f"RMSE: **{rmse:.4f}**")

        # ────────────────────────────────────────────────
        # Feature Importance (for tree models)
        # ────────────────────────────────────────────────
        if hasattr(model, "feature_importances_"):
            st.write("## Feature Importance")
            importance = model.feature_importances_
            fig, ax = plt.subplots(figsize=(6, 4))
            # Use the names of the columns the model was actually trained on;
            # the importances have one entry per numeric feature only.
            sns.barplot(x=importance, y=feature_names, ax=ax)
            st.pyplot(fig)
            plt.close(fig)
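# NOTE: the text "No comments: / Post a Comment" that followed here was
# blog-page footer residue from the original paste, not part of the script.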