# Blog Pages

# Virtual Stock Trading Game

"""

Virtual Stock Trading Game (Streamlit)


Features:

- Simple username registration/login

- Buy / Sell simulated orders at current market price (via yfinance)

- Portfolio view, transaction history

- Leaderboard by total portfolio value

- SQLite persistence


Run:

    streamlit run virtual_trading_app.py

"""


import math
import os
import sqlite3
from contextlib import closing
from datetime import datetime, timezone

import altair as alt
import pandas as pd
import streamlit as st
import yfinance as yf


# -----------------------

# Config

# -----------------------

# Path of the SQLite database file. It can be overridden with the
# TRADING_DB_FILE environment variable; it defaults to "trading.db" (a relative
# path, so it resolves against the process working directory).
DB_FILE = os.environ.get("TRADING_DB_FILE", "trading.db")

# Cash balance credited to every newly registered user (create_user default).
STARTING_CASH = 100000.0


# -----------------------

# Database helpers

# -----------------------

def get_conn():
    """Open and return a new SQLite connection to ``DB_FILE``.

    ``check_same_thread=False`` allows the connection to be used from a thread
    other than the one that created it. The caller must close the connection.
    """
    return sqlite3.connect(DB_FILE, check_same_thread=False)


def init_db():
    """Create the ``users``, ``holdings`` and ``transactions`` tables if missing.

    Every statement uses ``CREATE TABLE IF NOT EXISTS``, so calling this on each
    app run is harmless. The connection is closed even if a statement fails.
    """
    statements = (
        """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE,
            cash REAL,
            created_at TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS holdings (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER,
            ticker TEXT,
            quantity REAL,
            avg_price REAL,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS transactions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER,
            ticker TEXT,
            quantity REAL,
            price REAL,
            side TEXT,              -- 'BUY' or 'SELL'
            timestamp TEXT,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
    )
    with closing(get_conn()) as conn:
        cur = conn.cursor()
        for ddl in statements:
            cur.execute(ddl)
        conn.commit()


def create_user(username, starting_cash=STARTING_CASH):
    """Insert a new user holding ``starting_cash`` unless the username exists.

    An existing username is left untouched: the UNIQUE constraint raises
    ``sqlite3.IntegrityError``, which is deliberately ignored, so the login flow
    can call this as "register if new". Any other error propagates, and the
    connection is still closed.
    """
    # Naive UTC ISO-8601 string, the same format datetime.utcnow() produced,
    # without the deprecated call. Keeping it naive keeps stored timestamps
    # uniformly formatted with rows written earlier.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    with closing(get_conn()) as conn:
        try:
            conn.execute(
                "INSERT INTO users (username, cash, created_at) VALUES (?, ?, ?)",
                (username, float(starting_cash), now),
            )
            conn.commit()
        except sqlite3.IntegrityError:
            # Username already taken -> keep the existing account as-is.
            pass


def get_user(username):
    """Look up a user by ``username``.

    Returns:
        A dict with keys ``id``, ``username``, ``cash`` and ``created_at``, or
        None when no such user exists.
    """
    with closing(get_conn()) as conn:
        row = conn.execute(
            "SELECT id, username, cash, created_at FROM users WHERE username=?",
            (username,),
        ).fetchone()
    if row is None:
        return None
    return dict(zip(("id", "username", "cash", "created_at"), row))


def update_cash(user_id, new_cash):
    """Overwrite the cash balance of user ``user_id`` with ``new_cash``.

    The change is committed immediately; the connection is closed even on error.
    """
    with closing(get_conn()) as conn:
        conn.execute("UPDATE users SET cash=? WHERE id=?", (new_cash, user_id))
        conn.commit()


def get_holdings(user_id):
    """Return the user's positions as a DataFrame.

    Columns are ``ticker``, ``quantity`` and ``avg_price``. When the user holds
    nothing, the result is an empty DataFrame that still has those columns.
    """
    columns = ["ticker", "quantity", "avg_price"]
    with closing(get_conn()) as conn:
        rows = conn.execute(
            "SELECT ticker, quantity, avg_price FROM holdings WHERE user_id=?",
            (user_id,),
        ).fetchall()
    # An empty row list still yields a DataFrame with the named columns.
    return pd.DataFrame(rows, columns=columns)


def upsert_holding(user_id, ticker, qty_delta, trade_price):
    """Apply a quantity change to the user's position in ``ticker``.

    - ``qty_delta > 0`` (buy): add shares. The average cost becomes the weighted
      mean ``(qty*avg + qty_delta*trade_price) / (qty + qty_delta)``. If the user
      had no position, a new row is created at ``trade_price``.
    - ``qty_delta < 0`` (sell): reduce shares and keep ``avg_price`` unchanged.
      Selling a position that does not exist is a no-op.

    A resulting quantity at or below 1e-6 deletes the row, which absorbs
    floating-point residue from fractional sells. The ticker is matched exactly,
    so callers pass it already normalised (upper-case). Nothing is committed if
    a statement fails; the connection is always closed.
    """
    with closing(get_conn()) as conn:
        cur = conn.cursor()
        cur.execute(
            "SELECT id, quantity, avg_price FROM holdings WHERE user_id=? AND ticker=?",
            (user_id, ticker),
        )
        row = cur.fetchone()
        if row is None:
            if qty_delta > 0:
                cur.execute(
                    "INSERT INTO holdings (user_id, ticker, quantity, avg_price) VALUES (?,?,?,?)",
                    (user_id, ticker, qty_delta, trade_price),
                )
        else:
            hid, qty, avg = row
            new_qty = qty + qty_delta
            if new_qty <= 1e-6:
                cur.execute("DELETE FROM holdings WHERE id=?", (hid,))
            else:
                if qty_delta > 0:
                    new_avg = (qty * avg + qty_delta * trade_price) / new_qty
                else:
                    new_avg = avg
                cur.execute(
                    "UPDATE holdings SET quantity=?, avg_price=? WHERE id=?",
                    (new_qty, new_avg, hid),
                )
        conn.commit()


def record_transaction(user_id, ticker, quantity, price, side):
    """Append one trade to the ``transactions`` table.

    ``side`` is 'BUY' or 'SELL' (as written by attempt_buy / attempt_sell). The
    timestamp is the current UTC time as a naive ISO-8601 string.
    """
    # Same string format datetime.utcnow().isoformat() produced, minus the
    # deprecated call.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    with closing(get_conn()) as conn:
        conn.execute(
            "INSERT INTO transactions (user_id, ticker, quantity, price, side, timestamp) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (user_id, ticker, quantity, price, side, now),
        )
        conn.commit()


def get_transactions(user_id, limit=200):
    """Return up to ``limit`` of the user's transactions, most recent id first.

    Columns are ``ticker``, ``quantity``, ``price``, ``side`` and ``timestamp``.
    When there are none, the result is an empty DataFrame with those columns.
    """
    columns = ["ticker", "quantity", "price", "side", "timestamp"]
    with closing(get_conn()) as conn:
        rows = conn.execute(
            "SELECT ticker, quantity, price, side, timestamp FROM transactions "
            "WHERE user_id=? ORDER BY id DESC LIMIT ?",
            (user_id, limit),
        ).fetchall()
    return pd.DataFrame(rows, columns=columns)


def get_leaderboard(top_n=20):
    """Rank all users by total account value (cash + holdings at market price).

    Args:
        top_n: maximum number of rows to return.

    Returns:
        A DataFrame with ``username`` and ``total`` columns, sorted by
        ``total`` descending. When there are no users, it is an empty DataFrame
        with those columns instead of raising. A holding whose price is missing
        or non-finite counts as 0.0.
    """
    with closing(get_conn()) as conn:
        users = conn.execute("SELECT id, username, cash FROM users").fetchall()
    # The connection is released before the slow per-user network price lookups.

    leaderboard = []
    for uid, username, cash in users:
        total = float(cash)
        holdings = get_holdings(uid)
        if not holdings.empty:
            market = fetch_market_prices(list(holdings["ticker"].unique()))
            for ticker, quantity in zip(holdings["ticker"], holdings["quantity"]):
                price = float(market.get(ticker, 0.0))
                # A single NaN/inf price would poison the whole total.
                if not math.isfinite(price):
                    price = 0.0
                total += float(quantity) * price
        leaderboard.append({"username": username, "total": total})

    if not leaderboard:
        # pd.DataFrame([]) has no "total" column, so sort_values would raise.
        return pd.DataFrame(columns=["username", "total"])

    lb = pd.DataFrame(leaderboard).sort_values("total", ascending=False).reset_index(drop=True)
    return lb.head(top_n)


# -----------------------

# Market helpers (yfinance)

# -----------------------

def fetch_price(ticker):
    """Return the latest price for ``ticker`` as a finite, positive float.

    The last non-missing close of today's 1-minute history is used first. If
    there is none, the function falls back to
    ``Ticker.info["regularMarketPrice"]``. Surrounding whitespace in ``ticker``
    is ignored.

    Raises:
        ValueError: the ticker is blank, no price is available, the price is
            NaN/inf/non-positive (which would otherwise allow NaN- or zero-cost
            trades), or the yfinance lookup failed. The message is
            "Could not fetch price for <ticker>: <reason>".
    """
    try:
        symbol = str(ticker).strip()
        if not symbol:
            raise ValueError("empty ticker")
        t = yf.Ticker(symbol)
        # Prefer the most recent intraday bar; drop NaN bars first so a
        # trailing gap does not produce a NaN price.
        quote = t.history(period="1d", interval="1m")
        closes = quote["Close"].dropna() if "Close" in quote.columns else None
        if closes is not None and not closes.empty:
            price = closes.iloc[-1]
        else:
            price = t.info.get("regularMarketPrice")
        if price is None:
            raise ValueError("Price not available")
        price = float(price)
        if not math.isfinite(price) or price <= 0:
            raise ValueError("Price not available")
        return price
    except Exception as e:
        raise ValueError(f"Could not fetch price for {ticker}: {e}") from e


def fetch_market_prices(tickers):
    """Return a dict mapping each ticker in ``tickers`` to its latest price.

    One bulk ``yf.download`` call is made for all tickers. For each ticker, the
    last non-missing 1-minute close is taken. Tickers from different exchanges
    may not share the final bar, so the raw last row can contain NaN. Both flat
    (single-ticker) and MultiIndex column layouts are handled.

    Any ticker without a usable bulk price is retried individually with
    ``fetch_price``. If the bulk call fails outright, every ticker is retried
    this way. A ticker whose price still cannot be obtained maps to 0.0.
    An empty input returns an empty dict.
    """
    out = {}
    if not tickers:
        return out
    # De-duplicate while preserving order; duplicate columns would make
    # closes[t] a DataFrame instead of a Series.
    tickers = list(dict.fromkeys(tickers))

    try:
        df = yf.download(tickers, period="1d", interval="1m", progress=False)
        closes = df["Close"]
        # Flat columns (older yfinance, single ticker) give a Series here.
        if isinstance(closes, pd.Series):
            closes = closes.to_frame(name=tickers[0])
        for t in tickers:
            if t not in closes.columns:
                continue
            series = closes[t].dropna()
            if series.empty:
                continue
            price = float(series.iloc[-1])
            if math.isfinite(price):
                out[t] = price
    except Exception:
        # Bulk download failed; every ticker is handled by the fallback below.
        pass

    for t in tickers:
        if t in out:
            continue
        try:
            out[t] = fetch_price(t)
        except Exception:
            out[t] = 0.0
    return out


# -----------------------

# Trading logic

# -----------------------

def attempt_buy(user, ticker, quantity):
    """Buy ``quantity`` shares of ``ticker`` at the current market price.

    Args:
        user: dict with at least ``id`` and ``username`` (as returned by get_user).
        ticker: symbol. Surrounding whitespace is ignored, and the symbol is
            stored upper-cased.
        quantity: number of shares (fractional allowed). It must convert to a
            finite float greater than 0.

    Returns:
        ``(success, message)``.

    The cash balance is re-read from the database (as attempt_sell does)
    rather than taken from ``user``, so a stale session dict cannot overdraw
    the account.
    NOTE(review): the cash, holding and transaction updates use separate
    connections and are not atomic.
    """
    try:
        qty = float(quantity)
    except (TypeError, ValueError):
        return False, "Invalid quantity"
    if not math.isfinite(qty):
        # NaN would pass both the "> 0" and the funds checks below.
        return False, "Invalid quantity"
    if qty <= 0:
        return False, "Quantity must be > 0"

    symbol = str(ticker).strip().upper()
    if not symbol:
        return False, "Invalid ticker"

    try:
        price = fetch_price(symbol)
    except Exception as e:
        return False, f"Price fetch error: {e}"

    current = get_user(user["username"])
    if current is None:
        return False, "User not found"
    cash = float(current["cash"])

    cost = qty * price
    if cost > cash + 1e-9:
        return False, f"Insufficient funds: need {cost:.2f}, available {cash:.2f}"

    update_cash(user["id"], cash - cost)
    upsert_holding(user["id"], symbol, qty, price)
    record_transaction(user["id"], symbol, qty, price, "BUY")
    return True, f"Bought {qty} shares of {symbol} at {price:.2f} (cost {cost:.2f})"


def attempt_sell(user, ticker, quantity):
    """Sell ``quantity`` shares of ``ticker`` at the current market price.

    Args:
        user: dict with at least ``id`` and ``username`` (as returned by get_user).
        ticker: symbol. Whitespace is ignored, and the match against holdings is
            case-insensitive.
        quantity: number of shares. It must convert to a finite float greater
            than 0 and must not exceed the shares owned.

    Returns:
        ``(success, message)``.

    Proceeds are added to the cash balance re-read from the database.
    NOTE(review): the cash, holding and transaction updates use separate
    connections and are not atomic.
    """
    try:
        qty = float(quantity)
    except (TypeError, ValueError):
        return False, "Invalid quantity"
    if not math.isfinite(qty):
        return False, "Invalid quantity"
    if qty <= 0:
        return False, "Quantity must be > 0"

    symbol = str(ticker).strip().upper()
    if not symbol:
        return False, "Invalid ticker"

    holdings = get_holdings(user["id"])
    if holdings.empty:
        return False, "No holdings for this ticker"
    position = holdings[holdings["ticker"].str.upper() == symbol]
    if position.empty:
        return False, "No holdings for this ticker"

    owned = float(position.iloc[0]["quantity"])
    if qty > owned + 1e-9:
        return False, f"Not enough shares to sell (owned {owned})"

    try:
        price = fetch_price(symbol)
    except Exception as e:
        return False, f"Price fetch error: {e}"

    current = get_user(user["username"])
    if current is None:
        return False, "User not found"

    proceeds = qty * price
    update_cash(user["id"], float(current["cash"]) + proceeds)
    upsert_holding(user["id"], symbol, -qty, price)
    record_transaction(user["id"], symbol, qty, price, "SELL")
    return True, f"Sold {qty} shares of {symbol} at {price:.2f} (proceeds {proceeds:.2f})"


# -----------------------

# UI

# -----------------------

def login_ui():
    """Render the sidebar login/register form.

    When the button is pressed with a non-blank username, the account is
    created if it does not exist yet, stored in ``st.session_state["user"]``,
    and returned. Otherwise None is returned, and a blank username shows an
    error.
    """
    sidebar = st.sidebar
    sidebar.header("Player Login / Register")
    raw_name = sidebar.text_input("Enter username", key="login_username")

    if not sidebar.button("Login / Register"):
        return None

    name = raw_name.strip()
    if not name:
        sidebar.error("Please enter a username")
        return None

    create_user(name)
    account = get_user(name)
    st.session_state["user"] = account
    sidebar.success(f"Logged in as {account['username']}")
    return account


def _rerun():
    """Trigger a Streamlit rerun.

    Uses ``st.rerun`` when available and falls back to ``st.experimental_rerun``
    on older Streamlit releases, where ``st.rerun`` does not exist.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _with_market_values(holdings):
    """Return a copy of ``holdings`` with ``market_price`` and ``market_value`` columns.

    Prices come from fetch_market_prices; a ticker missing from its result
    gets 0.0.
    """
    out = holdings.copy()
    prices = fetch_market_prices(list(out["ticker"].unique()))
    out["market_price"] = out["ticker"].apply(lambda t: prices.get(t, 0.0))
    out["market_value"] = out["quantity"] * out["market_price"]
    return out


def _render_account_summary(user):
    """Show the greeting, the cash balance and a holdings table with unrealised P&L."""
    st.subheader(f"Hello, {user['username']} 👋")
    st.write(f"**Cash:** ${user['cash']:.2f}")
    holdings = get_holdings(user["id"])
    if holdings.empty:
        st.info("You have no holdings yet. Search a ticker and buy to get started.")
        return
    st.write("Your holdings:")
    display = _with_market_values(holdings)
    display["unreal_pnl"] = display["market_value"] - display["quantity"] * display["avg_price"]
    st.dataframe(
        display.style.format({
            "quantity": "{:.3f}",
            "avg_price": "{:.2f}",
            "market_price": "{:.2f}",
            "market_value": "{:.2f}",
            "unreal_pnl": "{:.2f}",
        }),
        use_container_width=True,
    )
    st.write(f"Total holdings market value: ${display['market_value'].sum():.2f}")


def _render_leaderboard():
    """Show the top 10 players by total account value."""
    st.subheader("Leaderboard")
    lb = get_leaderboard()
    if lb.empty:
        st.write("No players yet.")
        return
    # Slice the DataFrame BEFORE styling: a pandas Styler has no .head().
    st.table(lb.head(10).style.format({"total": "${:,.2f}"}))


def _execute_trade(action, user, ticker, qty):
    """Run ``action`` (attempt_buy / attempt_sell) and report the outcome.

    On success, the refreshed user row is stored in the session.
    """
    if not ticker.strip():
        st.error("Enter ticker")
        return
    ok, msg = action(user, ticker, qty)
    if ok:
        st.success(msg)
        st.session_state["user"] = get_user(user["username"])
    else:
        st.error(msg)


def _render_trade_panel(user):
    """Render the ticker/quantity inputs and the Fetch Price / Buy / Sell buttons."""
    st.header("Trade")
    inputs_col, actions_col = st.columns(2)
    with inputs_col:
        ticker = st.text_input("Ticker (e.g., AAPL)", key="trade_ticker")
        qty = st.number_input("Quantity", min_value=0.0, value=1.0, step=1.0, key="trade_qty")
    with actions_col:
        if st.button("Fetch Price"):
            try:
                price = fetch_price(ticker)
                st.success(f"Price for {ticker.strip().upper()}: ${price:.2f}")
            except Exception as e:
                st.error(str(e))
        if st.button("Buy"):
            _execute_trade(attempt_buy, user, ticker, qty)
        if st.button("Sell"):
            _execute_trade(attempt_sell, user, ticker, qty)


def _render_history(user):
    """Show recent transactions and a bar chart of current per-holding market value.

    Historic prices at each transaction time are not fetched, so the chart
    reflects today's prices only. Holdings are re-read here (not reused from
    the summary) so a trade made earlier in this same run is included.
    """
    st.header("Transaction History & Portfolio Value")
    tx = get_transactions(user["id"], limit=500)
    st.subheader("Recent Transactions")
    if tx.empty:
        st.info("No transactions yet.")
    else:
        st.dataframe(tx, use_container_width=True)

    st.subheader("Portfolio Value (by re-using transactions)")
    if tx.empty:
        return
    holdings_now = get_holdings(user["id"])
    if holdings_now.empty:
        return
    holdings_now = _with_market_values(holdings_now)
    chart_df = holdings_now[["ticker", "market_value"]].rename(columns={"market_value": "value"})
    st.write("Current holdings market values:")
    st.dataframe(holdings_now)
    chart = alt.Chart(chart_df).mark_bar().encode(x="ticker", y="value")
    st.altair_chart(chart, use_container_width=True)


def _render_player_actions(user):
    """Add the sidebar Refresh Data / Log out buttons; each triggers a rerun."""
    st.sidebar.markdown("## Player Actions")
    if st.sidebar.button("Refresh Data"):
        st.session_state["user"] = get_user(user["username"])
        _rerun()
    if st.sidebar.button("Log out"):
        st.session_state.pop("user", None)
        _rerun()


def main_app(user):
    """Render the full game screen for a logged-in ``user``.

    ``user`` is the dict returned by get_user. The sections are: account
    summary, leaderboard, trade panel, transaction history and sidebar player
    actions.
    NOTE(review): the summary is drawn before the trade buttons run, so after a
    Buy/Sell it shows pre-trade cash/holdings until the next rerun.
    """
    st.title("📈 Virtual Stock Trading Game")
    st.write("**Simulation only — not financial advice.**")
    st.markdown("---")

    summary_col, leaderboard_col = st.columns([2, 1])
    with summary_col:
        _render_account_summary(user)
    with leaderboard_col:
        _render_leaderboard()

    st.markdown("---")
    _render_trade_panel(user)

    st.markdown("---")
    _render_history(user)

    st.markdown("---")
    _render_player_actions(user)


# -----------------------

# App entrypoint

# -----------------------

def main():
    """Streamlit entry point.

    Configures the page, ensures the database schema exists, and then shows
    either the game (for a logged-in player) or the login/welcome screen.
    """
    st.set_page_config(page_title="Virtual Stock Trading Game", layout="wide")
    init_db()

    st.sidebar.title("Virtual Trading")

    session_user = st.session_state.get("user", None)
    if session_user:
        # Reload from the database so cash/holdings reflect the latest trades.
        current = get_user(session_user["username"])
        st.session_state["user"] = current
    else:
        current = login_ui()

    if not current:
        st.title("Welcome to the Virtual Stock Trading Game")
        st.write("Create a username in the left panel to start. You'll receive some starting cash to practice trading.")
        st.info("This app uses real market prices via yfinance but only simulates trades with fake money.")
        return

    main_app(current)


# Script entry point: run main() when this file is executed as the main module
# (the module docstring shows `streamlit run virtual_trading_app.py`).
if __name__ == "__main__":

    main()

# No comments:

# Post a Comment