# BitFunded Decision Transformer (Stage 1)
Decision Transformer trained for the BitFunded prop firm Stage 1 crypto challenge. Predicts optimal trading actions (HOLD, LONG, SHORT, CLOSE) conditioned on a target return of 8%, while respecting drawdown limits.
## Challenge Constraints
| Parameter | Value |
|---|---|
| Account | 10,000 USDT |
| Leverage | 1:5 |
| Profit target | 800 USDT (8%) |
| Max daily loss | 500 USDT (5%) |
| Max total loss | 1,000 USDT (10%) |
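The pass/fail rules in the table can be expressed as a small state check. This is an illustrative helper, not part of the released code; the order in which a simultaneous daily and total breach is classified is an assumption (total checked first, since it carries the larger terminal penalty).

```python
# Hypothetical helper: classify account state under the Stage 1 rules above.
ACCOUNT = 10_000.0        # USDT starting balance
PROFIT_TARGET = 800.0     # 8%
MAX_DAILY_LOSS = 500.0    # 5%
MAX_TOTAL_LOSS = 1_000.0  # 10%

def challenge_status(equity: float, day_start_equity: float) -> str:
    """Return 'passed', 'failed_total', 'failed_daily', or 'active'."""
    if equity - ACCOUNT >= PROFIT_TARGET:
        return "passed"
    if ACCOUNT - equity >= MAX_TOTAL_LOSS:      # assumed to take precedence
        return "failed_total"
    if day_start_equity - equity >= MAX_DAILY_LOSS:
        return "failed_daily"
    return "active"
```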
## Model Architecture
| Parameter | Value |
|---|---|
| Type | Decision Transformer (causal) |
| Hidden dim | 128 |
| Attention heads | 4 |
| Transformer layers | 3 |
| Context length | 20 timesteps |
| Parameters | ~635K |
| Actions | HOLD (0), LONG (1), SHORT (2), CLOSE (3) |
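The ~635K parameter count can be sanity-checked by hand from the table. The card does not state `max_ep_len`, so the value 200 below is an assumption chosen to illustrate the arithmetic; the rest follows from the listed dimensions and a standard pre-norm transformer block.

```python
# Hedged parameter-count check; MAX_EP = 200 is an assumption, not from the card.
H, LAYERS, SEQ, STATE, ACT, MAX_EP = 128, 3, 20, 39, 4, 200

embeds = (STATE * H + H) + ACT * H + (1 * H + H) + MAX_EP * H + 3 * SEQ * H
attn   = 4 * H * H + 4 * H                    # QKV in_proj + out_proj
ffn    = (H * 4 * H + 4 * H) + (4 * H * H + H)
block  = attn + ffn + 2 * (2 * H)             # plus two LayerNorms
head   = 2 * H + (H * ACT + ACT)              # final LayerNorm + action head
total  = embeds + LAYERS * block + head       # ≈ 635K under these assumptions
```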
## Training
- Data: 23 crypto pairs × 800 4H candles from OKX (~6 months)
- Pairs: BTC, ETH, SOL, BNB, XRP, ADA, DOT, LINK, AVAX, NEAR, SUI, TON, ATOM, APT, ARB, OP, TRX, UNI, LTC, IMX, ONDO, ICP, FET
- Expert trajectories: 150 episodes with enriched rule-based policy (trend pullback, RSI at S/R, BB squeeze, MACD crossover)
- Action distribution: HOLD=21,507 | LONG=1,097 | SHORT=1,604 | CLOSE=2,576
- Training accuracy: 89.4% | Validation accuracy: 85.8%
- Trade action accuracy: 98.8%
- Composite score: 0.936
## Input Features (39 dimensions)
| Index | Description |
|---|---|
| 0-4 | Price returns (1/6/18/42 bar log returns + candle body ratio) |
| 5-8 | Volatility (14/42 bar rolling std, normalized ATR, vol ratio) |
| 9-14 | Moving averages (EMA 9/21/50 distance, slopes, alignment signal) |
| 15-17 | RSI (normalized [-1,1], zone signal, divergence) |
| 18-20 | MACD (normalized line, histogram, crossover signal) |
| 21-24 | Bollinger Bands (%B position, width, squeeze ratio, momentum) |
| 25-28 | Volume (ratio to 20-MA, log ratio, trend, price-volume divergence) |
| 29-31 | Support/Resistance (distance to nearest high/low, range position) |
| 32-34 | Market regime (trend strength, mean reversion z-score, composite) |
| 35-38 | Portfolio state (position flag, PnL ratio, daily PnL, drawdown) |
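A couple of the features above can be sketched to show the intended shapes. The lag windows and the RSI rescaling come from the table, but the exact computation details (e.g., simple-mean RSI rather than Wilder smoothing) are assumptions, not the training pipeline.

```python
import numpy as np

def log_returns(close: np.ndarray) -> np.ndarray:
    """Features 0-3: log returns over 1/6/18/42 bars (4H candles)."""
    return np.array([np.log(close[-1] / close[-1 - lag]) for lag in (1, 6, 18, 42)])

def rsi_normalized(close: np.ndarray, period: int = 14) -> float:
    """Feature 15: RSI rescaled from [0, 100] to [-1, 1] (simple-mean variant)."""
    deltas = np.diff(close[-(period + 1):])
    gains = np.clip(deltas, 0, None).mean()
    losses = np.clip(-deltas, 0, None).mean()
    rsi = 100.0 if losses == 0 else 100.0 - 100.0 / (1.0 + gains / losses)
    return rsi / 50.0 - 1.0
```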
## Reward Shaping
- Base: realized + unrealized PnL (normalized)
- Penalty at 3% daily drawdown (approaching 5% limit)
- Penalty at 6% total drawdown (approaching 10% limit)
- Progress bonus toward 8% profit target
- Small HOLD cost to encourage decision-making
- Terminal: +5.0 for hitting target, -3.0 for daily loss breach, -5.0 for total loss breach
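The shaping terms above can be combined as follows. The drawdown thresholds and terminal values come from the list; the penalty slopes, progress bonus weight, and HOLD cost magnitude are illustrative assumptions.

```python
# Hedged sketch of the shaped reward; coefficients marked below are assumptions.
def shaped_reward(pnl_step, daily_dd, total_dd, progress, action, done, outcome):
    r = pnl_step                        # base: normalized step PnL
    if daily_dd > 0.03:                 # approaching the 5% daily limit
        r -= 0.1 * (daily_dd - 0.03)    # slope 0.1 is an assumption
    if total_dd > 0.06:                 # approaching the 10% total limit
        r -= 0.1 * (total_dd - 0.06)    # slope 0.1 is an assumption
    r += 0.05 * progress                # progress bonus weight is an assumption
    if action == 0:                     # small HOLD cost
        r -= 0.001                      # magnitude is an assumption
    if done:                            # terminal values from the list above
        r += {"target": 5.0, "daily_breach": -3.0, "total_breach": -5.0}.get(outcome, 0.0)
    return r
```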
## Usage

```python
import torch
import torch.nn as nn
import json

# Load config
with open("config.json") as f:
    cfg = json.load(f)

# Define architecture (must match training)
class DTBlock(nn.Module):
    def __init__(self, h, nh, do):
        super().__init__()
        self.attn = nn.MultiheadAttention(h, nh, dropout=do, batch_first=True)
        self.ln1 = nn.LayerNorm(h)
        self.ln2 = nn.LayerNorm(h)
        self.ffn = nn.Sequential(
            nn.Linear(h, h * 4), nn.GELU(), nn.Dropout(do),
            nn.Linear(h * 4, h), nn.Dropout(do),
        )

    def forward(self, x, mask=None):
        # Pre-norm attention with residual connection
        n = self.ln1(x)
        a, _ = self.attn(n, n, n, attn_mask=mask)
        x = x + a
        return x + self.ffn(self.ln2(x))

class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, act_dim, hidden_dim, n_heads, n_layers,
                 max_ep_len, seq_len, dropout=0.0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.state_embed = nn.Linear(state_dim, hidden_dim)
        self.action_embed = nn.Embedding(act_dim, hidden_dim)
        self.return_embed = nn.Linear(1, hidden_dim)
        self.timestep_embed = nn.Embedding(max_ep_len, hidden_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, 3 * seq_len, hidden_dim))
        self.blocks = nn.ModuleList(
            [DTBlock(hidden_dim, n_heads, dropout) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(hidden_dim)
        self.action_head = nn.Linear(hidden_dim, act_dim)

    def forward(self, states, actions, returns_to_go, timesteps):
        B, T = states.shape[:2]
        te = self.timestep_embed(timesteps)
        se = self.state_embed(states) + te
        ae = self.action_embed(actions) + te
        re = self.return_embed(returns_to_go.unsqueeze(-1)) + te
        # Interleave tokens as (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        tok = torch.zeros(B, 3 * T, self.hidden_dim, device=states.device)
        tok[:, 0::3] = re
        tok[:, 1::3] = se
        tok[:, 2::3] = ae
        tok = tok + self.pos_embed[:, :3 * T]
        # Causal mask over the interleaved sequence
        mask = torch.triu(torch.ones(3 * T, 3 * T, device=states.device), diagonal=1).bool()
        x = tok
        for block in self.blocks:
            x = block(x, mask)
        # Predict actions from the state-token positions
        return self.action_head(self.ln_f(x[:, 1::3]))

# Load model
model = DecisionTransformer(**cfg)
model.load_state_dict(torch.load("decision_transformer_v3.pt", map_location="cpu"))
model.eval()

# Inference: predict action for the last timestep
# states: (1, seq_len, 39), actions: (1, seq_len), rtg: (1, seq_len), timesteps: (1, seq_len)
with torch.no_grad():
    logits = model(states, actions, returns_to_go, timesteps)
    probs = torch.softmax(logits[0, -1], dim=-1)
    action = torch.argmax(probs).item()
# 0=HOLD, 1=LONG, 2=SHORT, 3=CLOSE
```
## Integration
This model is used as a confidence overlay in a live signal scanner. Technical analysis generates candidate signals, then the DT provides:
- Confirmation (upgrades signal strength) when it agrees
- Caution (downgrades signal strength) when it disagrees
Risk per trade is capped at 40 USDT (0.4%) with a personal daily stop of 200 USDT.
## Disclaimer
This model is for educational and research purposes. Not financial advice. Past performance does not guarantee future results. Always use proper risk management.