"""
Virtual Stock Trading Game (Streamlit)
Features:
- Simple username registration/login
- Buy / Sell simulated orders at current market price (via yfinance)
- Portfolio view, transaction history
- Leaderboard by total portfolio value
- SQLite persistence
Run:
streamlit run virtual_trading_app.py
"""
import math
import os
import sqlite3
from contextlib import closing
from datetime import datetime, timezone

import altair as alt
import pandas as pd
import streamlit as st
import yfinance as yf
# -----------------------
# Config
# -----------------------
# SQLite database file; a relative path resolves against the process working directory.
DB_FILE: str = "trading.db"
# Cash balance credited to every newly registered player (create_user default).
STARTING_CASH: float = 100_000.0
# -----------------------
# Database helpers
# -----------------------
def get_conn():
    """Return a new connection to the SQLite database at ``DB_FILE``.

    ``check_same_thread=False`` disables sqlite3's check that a connection is used
    only by the thread that created it. The caller owns the returned connection and
    is responsible for closing it.
    """
    return sqlite3.connect(DB_FILE, check_same_thread=False)
def init_db():
    """Create the ``users``, ``holdings`` and ``transactions`` tables if they are missing.

    Every statement is ``CREATE TABLE IF NOT EXISTS``, so calling this on each
    Streamlit rerun is harmless. The connection is closed even if a statement fails.
    """
    with closing(get_conn()) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT UNIQUE,
                    cash REAL,
                    created_at TEXT
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS holdings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER,
                    ticker TEXT,
                    quantity REAL,
                    avg_price REAL,
                    FOREIGN KEY(user_id) REFERENCES users(id)
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS transactions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER,
                    ticker TEXT,
                    quantity REAL,
                    price REAL,
                    side TEXT, -- 'BUY' or 'SELL'
                    timestamp TEXT,
                    FOREIGN KEY(user_id) REFERENCES users(id)
                )
            """)
def create_user(username, starting_cash=STARTING_CASH):
    """Register *username* with *starting_cash*; do nothing if the name is taken.

    ``users.username`` is UNIQUE, so inserting an existing name raises
    ``sqlite3.IntegrityError``. That error is swallowed on purpose: the UI calls this
    before ``get_user`` so that "Login / Register" logs returning players in.
    Any other database error propagates, and the connection is always closed.
    """
    # Naive UTC ISO-8601 string: the same format datetime.utcnow() produced,
    # without that method's Python 3.12 deprecation.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    with closing(get_conn()) as conn:
        try:
            with conn:  # commit on success, roll back on error
                conn.execute(
                    "INSERT INTO users (username, cash, created_at) VALUES (?, ?, ?)",
                    (username, float(starting_cash), now),
                )
        except sqlite3.IntegrityError:
            # Username already exists: keep the existing account untouched.
            pass
def get_user(username):
    """Look up a user by exact *username*.

    Returns:
        ``{"id", "username", "cash", "created_at"}`` dict, or ``None`` if no such user.
    """
    with closing(get_conn()) as conn:
        row = conn.execute(
            "SELECT id, username, cash, created_at FROM users WHERE username=?",
            (username,),
        ).fetchone()
    if row is None:
        return None
    user_id, name, cash, created_at = row
    return {"id": user_id, "username": name, "cash": cash, "created_at": created_at}
def update_cash(user_id, new_cash):
    """Overwrite the stored cash balance of user *user_id* with *new_cash*.

    The update is committed on success and rolled back on error; the connection
    is always closed.
    """
    with closing(get_conn()) as conn:
        with conn:
            conn.execute("UPDATE users SET cash=? WHERE id=?", (new_cash, user_id))
def get_holdings(user_id):
    """Return the holdings rows of user *user_id* as a DataFrame.

    The frame always has the columns ``ticker``, ``quantity`` and ``avg_price``,
    even when the user holds nothing.
    """
    columns = ["ticker", "quantity", "avg_price"]
    with closing(get_conn()) as conn:
        rows = conn.execute(
            "SELECT ticker, quantity, avg_price FROM holdings WHERE user_id=?",
            (user_id,),
        ).fetchall()
    # An empty row list still produces a frame with the requested columns.
    return pd.DataFrame(rows, columns=columns)
def upsert_holding(user_id, ticker, qty_delta, trade_price):
    """Apply a signed quantity change to the (user_id, ticker) holding row.

    - ``qty_delta > 0`` (buy): add shares. ``avg_price`` becomes the
      quantity-weighted average of the old position and *trade_price*. The row is
      created if it does not exist.
    - ``qty_delta <= 0`` (sell): remove shares and keep ``avg_price``. The row is
      deleted once the remaining quantity is <= 1e-6, which treats float dust as
      zero. Selling a ticker that has no row does nothing.

    The write is committed on success and rolled back on error.
    """
    with closing(get_conn()) as conn:
        with conn:
            row = conn.execute(
                "SELECT id, quantity, avg_price FROM holdings WHERE user_id=? AND ticker=?",
                (user_id, ticker),
            ).fetchone()
            if row is None:
                if qty_delta > 0:
                    conn.execute(
                        "INSERT INTO holdings (user_id, ticker, quantity, avg_price) VALUES (?,?,?,?)",
                        (user_id, ticker, qty_delta, trade_price),
                    )
                return
            holding_id, qty, avg = row
            new_qty = qty + qty_delta
            if new_qty <= 0.000001:
                conn.execute("DELETE FROM holdings WHERE id=?", (holding_id,))
                return
            if qty_delta > 0:
                # Weighted average: (qty*avg + qty_delta*trade_price) / (qty + qty_delta)
                new_avg = (qty * avg + qty_delta * trade_price) / new_qty
            else:
                new_avg = avg
            conn.execute(
                "UPDATE holdings SET quantity=?, avg_price=? WHERE id=?",
                (new_qty, new_avg, holding_id),
            )
def record_transaction(user_id, ticker, quantity, price, side):
    """Append one row to ``transactions``, stamped with the current UTC time.

    *side* is the literal stored in the ``side`` column (callers pass "BUY"/"SELL").
    The timestamp is a naive UTC ISO-8601 string.
    """
    # Same string format datetime.utcnow() produced, without its 3.12 deprecation.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    with closing(get_conn()) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute(
                "INSERT INTO transactions (user_id, ticker, quantity, price, side, timestamp) VALUES (?, ?, ?, ?, ?, ?)",
                (user_id, ticker, quantity, price, side, now),
            )
def get_transactions(user_id, limit=200):
    """Return up to *limit* of the user's transactions, newest first (by row id).

    The frame always has the columns ``ticker``, ``quantity``, ``price``, ``side``
    and ``timestamp``, even when there are no transactions.
    """
    columns = ["ticker", "quantity", "price", "side", "timestamp"]
    with closing(get_conn()) as conn:
        rows = conn.execute(
            "SELECT ticker, quantity, price, side, timestamp FROM transactions WHERE user_id=? ORDER BY id DESC LIMIT ?",
            (user_id, limit),
        ).fetchall()
    # An empty row list still produces a frame with the requested columns.
    return pd.DataFrame(rows, columns=columns)
def get_leaderboard(top_n=20):
    """Return the top *top_n* players ranked by total portfolio value.

    Total = cash + sum(quantity * current market price) over the player's holdings.
    All held tickers are priced with a single ``fetch_market_prices`` call instead of
    one network round-trip per user. Tickers that cannot be priced count as 0.0.

    Returns:
        DataFrame with columns ``username`` and ``total``, sorted by ``total``
        descending with a fresh index. It is empty, but still has those columns,
        when there are no users.
    """
    with closing(get_conn()) as conn:
        users = conn.execute("SELECT id, username, cash FROM users").fetchall()
        # Only holdings belonging to existing users, in insertion order.
        holdings = conn.execute(
            "SELECT h.user_id, h.ticker, h.quantity FROM holdings h "
            "JOIN users u ON u.id = h.user_id ORDER BY h.id"
        ).fetchall()
    market = fetch_market_prices(sorted({ticker for _, ticker, _ in holdings}))
    totals = {uid: float(cash) for uid, _, cash in users}
    for uid, ticker, quantity in holdings:
        totals[uid] += float(quantity) * market.get(ticker, 0.0)
    lb = pd.DataFrame(
        [{"username": username, "total": totals[uid]} for uid, username, _ in users],
        columns=["username", "total"],
    )
    return lb.sort_values("total", ascending=False).reset_index(drop=True).head(top_n)
# -----------------------
# Market helpers (yfinance)
# -----------------------
def fetch_price(ticker):
    """Return the latest price for *ticker* as a finite, positive float.

    The symbol is stripped of surrounding whitespace and upper-cased. The most
    recent non-NaN close from today's 1-minute history is preferred. If there is
    none, ``Ticker.info["regularMarketPrice"]`` is used.

    Raises:
        ValueError: for a blank symbol, or when no usable price can be obtained.
            Any underlying yfinance error is wrapped and chained.
    """
    try:
        symbol = (ticker or "").strip().upper()
        if not symbol:
            raise ValueError("Empty ticker")
        t = yf.Ticker(symbol)
        price = None
        quote = t.history(period="1d", interval="1m")
        if not quote.empty:
            closes = quote["Close"].dropna()
            if not closes.empty:
                price = closes.iloc[-1]
        if price is None:
            price = t.info.get("regularMarketPrice")
        if price is None:
            raise ValueError("Price not available")
        price = float(price)
        # A NaN, zero or negative price would make trades free or poison the cash balance.
        if not math.isfinite(price) or price <= 0:
            raise ValueError("Price not available")
        return price
    except Exception as e:
        raise ValueError(f"Could not fetch price for {ticker}: {e}") from e
def fetch_market_prices(tickers):
    """Return ``{ticker: latest price}`` for *tickers*, using one bulk ``yf.download``.

    The bulk call fetches today's 1-minute bars. For each ticker, the last non-NaN
    close is used: when tickers trade at different minutes, the final row of the
    combined frame can be NaN for some of them. Tickers the bulk call could not
    price (missing column, all-NaN, or the download failing outright) fall back to
    ``fetch_price``. If that also fails, the price is 0.0.
    """
    out = {}
    if not tickers:
        return out
    tickers = list(tickers)
    bulk = {}
    try:
        df = yf.download(tickers, period="1d", interval="1m", progress=False)
        closes = df["Close"]
        # For a single ticker the frame may have flat columns, making "Close" a Series.
        if isinstance(closes, pd.Series):
            closes = closes.to_frame(name=tickers[0])
        for t in tickers:
            if t in closes.columns:
                series = closes[t].dropna()
                if not series.empty:
                    bulk[t] = float(series.iloc[-1])
    except Exception:
        # Best effort: whatever the bulk call could not price is retried per ticker below.
        pass
    for t in tickers:
        if t in bulk:
            out[t] = bulk[t]
            continue
        try:
            out[t] = fetch_price(t)
        except ValueError:
            out[t] = 0.0
    return out
# -----------------------
# Trading logic
# -----------------------
def attempt_buy(user, ticker, quantity):
    """Attempt to buy ``quantity`` shares of ``ticker`` at the current market price.

    Args:
        user: user dict as returned by ``get_user`` (``id`` and ``username`` are read).
        ticker: ticker symbol. Surrounding whitespace is stripped and it is
            upper-cased before pricing and storage.
        quantity: share count. Anything ``float()`` accepts; it must be finite and > 0.

    Returns:
        ``(success, message)`` tuple suitable for display.
    """
    try:
        qty = float(quantity)
    except (TypeError, ValueError):
        return False, "Invalid quantity"
    # NaN passes "<= 0" checks and would poison the cash balance; inf is never affordable.
    if not math.isfinite(qty):
        return False, "Invalid quantity"
    if qty <= 0:
        return False, "Quantity must be > 0"
    symbol = (ticker or "").strip().upper()
    if not symbol:
        return False, "Enter ticker"
    try:
        price = fetch_price(symbol)
    except ValueError as e:
        return False, f"Price fetch error: {e}"
    cost = qty * price
    # Use the balance stored in the DB rather than the caller's possibly stale dict
    # (attempt_sell does the same).
    cash = float(get_user(user["username"])["cash"])
    if cost > cash + 1e-9:
        return False, f"Insufficient funds: need {cost:.2f}, available {cash:.2f}"
    # NOTE(review): these three writes use separate connections/transactions, so a
    # failure part-way can leave cash, holdings and history out of sync.
    update_cash(user["id"], cash - cost)
    upsert_holding(user["id"], symbol, qty, price)
    record_transaction(user["id"], symbol, qty, price, "BUY")
    return True, f"Bought {qty} shares of {symbol} at {price:.2f} (cost {cost:.2f})"
def attempt_sell(user, ticker, quantity):
    """Attempt to sell ``quantity`` shares of ``ticker`` at the current market price.

    Args:
        user: user dict as returned by ``get_user`` (``id`` and ``username`` are read).
        ticker: ticker symbol. Surrounding whitespace is stripped and it is upper-cased.
        quantity: share count. Anything ``float()`` accepts; it must be finite, > 0
            and no more than the shares owned.

    Returns:
        ``(success, message)`` tuple suitable for display.
    """
    try:
        qty = float(quantity)
    except (TypeError, ValueError):
        return False, "Invalid quantity"
    # NaN passes every comparison below and would credit NaN cash.
    if not math.isfinite(qty):
        return False, "Invalid quantity"
    if qty <= 0:
        return False, "Quantity must be > 0"
    symbol = (ticker or "").strip().upper()
    if not symbol:
        return False, "Enter ticker"
    holdings = get_holdings(user["id"])
    if holdings.empty:
        return False, "No holdings for this ticker"
    # Compare normalized symbols so rows stored with stray whitespace (earlier buys
    # stored ticker.upper() unstripped) still match.
    matches = holdings[holdings["ticker"].str.strip().str.upper() == symbol]
    if matches.empty:
        return False, "No holdings for this ticker"
    row = matches.iloc[0]
    owned = float(row["quantity"])
    if qty > owned + 1e-9:
        return False, f"Not enough shares to sell (owned {owned})"
    try:
        price = fetch_price(symbol)
    except ValueError as e:
        return False, f"Price fetch error: {e}"
    proceeds = qty * price
    new_cash = float(get_user(user["username"])["cash"]) + proceeds
    # NOTE(review): these writes are not a single transaction; see attempt_buy.
    update_cash(user["id"], new_cash)
    # Use the ticker exactly as stored so upsert_holding finds the same row.
    upsert_holding(user["id"], row["ticker"], -qty, price)
    record_transaction(user["id"], symbol, qty, price, "SELL")
    return True, f"Sold {qty} shares of {symbol} at {price:.2f} (proceeds {proceeds:.2f})"
# -----------------------
# UI
# -----------------------
def login_ui():
    """Render the sidebar login/registration form.

    When the button is clicked with a non-blank username, the user is created if
    new, loaded, stored in ``st.session_state["user"]`` and returned. Otherwise
    returns ``None``; a blank username also shows an error.
    """
    sidebar = st.sidebar
    sidebar.header("Player Login / Register")
    entered = sidebar.text_input("Enter username", key="login_username")
    if not sidebar.button("Login / Register"):
        return None
    name = entered.strip()
    if not name:
        sidebar.error("Please enter a username")
        return None
    create_user(name)
    account = get_user(name)
    st.session_state["user"] = account
    sidebar.success(f"Logged in as {account['username']}")
    return account
def main_app(user):
    """Render the trading page for a logged-in *user* (a dict as returned by get_user).

    Sections, top to bottom:
    - account summary and leaderboard, side by side;
    - trade panel;
    - transaction history with a current-holdings chart;
    - sidebar actions (refresh, log out).
    """
    st.title("📈 Virtual Stock Trading Game")
    st.write("**Simulation only — not financial advice.**")
    st.markdown("---")
    # Priced once and shared by the summary table and the history chart.
    holdings = _holdings_with_market(user["id"])
    col1, col2 = st.columns([2, 1])
    with col1:
        _render_summary(user, holdings)
    with col2:
        _render_leaderboard()
    st.markdown("---")
    _render_trade_panel(user)
    st.markdown("---")
    _render_history(user, holdings)
    st.markdown("---")
    _render_sidebar_actions(user)


def _rerun():
    """Restart the script run: ``st.rerun`` when available, else ``st.experimental_rerun``."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _holdings_with_market(user_id):
    """Return the user's holdings with market_price, market_value and unreal_pnl columns added.

    If the user holds nothing, the plain (empty) holdings frame is returned.
    Tickers that could not be priced get a market_price of 0.0.
    """
    holdings = get_holdings(user_id)
    if holdings.empty:
        return holdings
    market = fetch_market_prices(list(holdings["ticker"].unique()))
    enriched = holdings.copy()
    enriched["market_price"] = enriched["ticker"].apply(lambda t: market.get(t, 0.0))
    enriched["market_value"] = enriched["quantity"] * enriched["market_price"]
    enriched["unreal_pnl"] = enriched["market_value"] - enriched["quantity"] * enriched["avg_price"]
    return enriched


def _render_summary(user, holdings):
    """Show the greeting, cash balance and priced holdings table."""
    st.subheader(f"Hello, {user['username']} 👋")
    st.write(f"**Cash:** ${user['cash']:.2f}")
    if holdings.empty:
        st.info("You have no holdings yet. Search a ticker and buy to get started.")
        return
    st.write("Your holdings:")
    st.dataframe(
        holdings.style.format({
            "quantity": "{:.3f}",
            "avg_price": "{:.2f}",
            "market_price": "{:.2f}",
            "market_value": "{:.2f}",
            "unreal_pnl": "{:.2f}",
        }),
        use_container_width=True,
    )
    total_market = holdings["market_value"].sum()
    st.write(f"Total holdings market value: ${total_market:.2f}")


def _render_leaderboard():
    """Show the top 10 players by total portfolio value."""
    st.subheader("Leaderboard")
    lb = get_leaderboard()
    if lb.empty:
        st.write("No players yet.")
        return
    # Slice the DataFrame before styling: a pandas Styler has no .head().
    st.table(lb.head(10).style.format({"total": "${:,.2f}"}))


def _render_trade_panel(user):
    """Ticker/quantity inputs with Fetch Price, Buy and Sell buttons."""
    st.header("Trade")
    # Message from a trade completed on the previous run (stashed just before rerunning).
    flash = st.session_state.pop("trade_flash", None)
    if flash:
        st.success(flash)
    tcol1, tcol2 = st.columns(2)
    with tcol1:
        ticker = st.text_input("Ticker (e.g., AAPL)", key="trade_ticker")
        qty = st.number_input("Quantity", min_value=0.0, value=1.0, step=1.0, key="trade_qty")
    with tcol2:
        if st.button("Fetch Price"):
            try:
                price = fetch_price(ticker)
            except ValueError as e:
                st.error(str(e))
            else:
                st.success(f"Price for {ticker.strip().upper()}: ${price:.2f}")
        if st.button("Buy"):
            _execute_trade(user, ticker, qty, attempt_buy)
        if st.button("Sell"):
            _execute_trade(user, ticker, qty, attempt_sell)


def _execute_trade(user, ticker, qty, trade_fn):
    """Run *trade_fn* (attempt_buy or attempt_sell) and report the outcome.

    On success the message is saved in session state and the script reruns, so the
    cash, holdings and history already drawn above are redrawn from the updated DB.
    """
    if not ticker:
        st.error("Enter ticker")
        return
    ok, msg = trade_fn(user, ticker, qty)
    if not ok:
        st.error(msg)
        return
    st.session_state["user"] = get_user(user["username"])
    st.session_state["trade_flash"] = msg
    _rerun()


def _render_history(user, holdings):
    """Show recent transactions and a bar chart of current market value per holding."""
    st.header("Transaction History & Portfolio Value")
    tx = get_transactions(user["id"], limit=500)
    st.subheader("Recent Transactions")
    if tx.empty:
        st.info("No transactions yet.")
    else:
        st.dataframe(tx, use_container_width=True)
    st.subheader("Portfolio Value (by re-using transactions)")
    # A real value-over-time series would need historical prices at each trade time;
    # this shows a snapshot of current market values instead.
    if tx.empty or holdings.empty:
        return
    snapshot = holdings[["ticker", "quantity", "avg_price", "market_price", "market_value"]]
    st.write("Current holdings market values:")
    st.dataframe(snapshot)
    chart_df = snapshot[["ticker", "market_value"]].rename(columns={"market_value": "value"})
    chart = alt.Chart(chart_df).mark_bar().encode(x="ticker", y="value")
    st.altair_chart(chart, use_container_width=True)


def _render_sidebar_actions(user):
    """Sidebar Refresh Data / Log out buttons."""
    st.sidebar.markdown("## Player Actions")
    if st.sidebar.button("Refresh Data"):
        st.session_state["user"] = get_user(user["username"])
        _rerun()
    if st.sidebar.button("Log out"):
        st.session_state.pop("user", None)
        _rerun()
# -----------------------
# App entrypoint
# -----------------------
def main():
    """Streamlit entry point.

    Configures the page and makes sure the schema exists. Then it renders the
    trading app for a logged-in user, or the welcome screen otherwise.
    """
    st.set_page_config(page_title="Virtual Stock Trading Game", layout="wide")
    init_db()
    st.sidebar.title("Virtual Trading")
    session_user = st.session_state.get("user", None)
    if session_user:
        # Reload from the DB so the balance reflects trades made since login.
        current = get_user(session_user["username"])
        st.session_state["user"] = current
    else:
        current = login_ui()
    if not current:
        st.title("Welcome to the Virtual Stock Trading Game")
        st.write("Create a username in the left panel to start. You'll receive some starting cash to practice trading.")
        st.info("This app uses real market prices via yfinance but only simulates trades with fake money.")
        return
    main_app(current)
if __name__ == "__main__":
    # `streamlit run <file>` executes the script as __main__, so this starts the app.
    main()