# Batch Sequencing with Limit Orders
# Estimate job size, place a limit order for tokens (1 unit = 1M tokens),
# wait until balance is available, then run the batch.
import asyncio
import base64
import math
import os
import time
from typing import List
import requests
from nacl.signing import SigningKey
from openai import AsyncOpenAI
# Endpoints: chat completions go to the consumption API, orders and balances
# go to the trading API.
CONSUMPTION_BASE_URL = "https://consumption.api.thegrid.ai/api/v1"
TRADING_BASE_URL = "https://trading.api.thegrid.ai"

# Market the buy order is placed on.
MARKET_ID = "market_788dcbd5-ac68-4c61-acf1-4443beaf2a1c"

# Environment-driven settings (defaults apply when the variable is unset).
# MAX_PRICE is sent as the limit price; MAX_WAIT_SEC bounds the polling loop.
MAX_PRICE = float(os.environ.get("MAX_PRICE", "2.50"))
MAX_WAIT_SEC = int(os.environ.get("MAX_WAIT_SEC", "120"))
# Only the exact (case-insensitive) value "true" switches this on.
INCLUDE_EXISTING_BALANCE = "true" == os.environ.get(
    "INCLUDE_EXISTING_BALANCE", "false"
).lower()

# Prompts sent as one batch.
PROMPTS: List[str] = [
    "What is 2 + 2?",
    "Name three colors.",
    "What is the capital of France?",
    "How many days in a week?",
    "What is H2O?",
]

# Model name and the cap on simultaneous in-flight requests.
CONFIG = dict(model="chat-fast", max_concurrent=3)
class SignatureAuth:
    """Produce signed request headers for the trading API.

    Holds a PyNaCl ``SigningKey`` and the key fingerprint that is sent
    alongside every signature.
    """

    def __init__(self, private_key_b64: str, fingerprint: str):
        """Build the signer.

        Args:
            private_key_b64: Base64-encoded private key material.
            fingerprint: Value sent in the ``x-thegrid-fingerprint`` header.
        """
        # Only the first 32 decoded bytes are handed to SigningKey as the seed;
        # any trailing bytes in the decoded key are ignored.
        self.private_key = SigningKey(base64.b64decode(private_key_b64)[:32])
        self.fingerprint = fingerprint

    def get_headers(self, method: str, path: str, body: str = "") -> dict:
        """Return the signature headers for one request.

        The signed message is the concatenation
        ``<unix timestamp><UPPERCASE METHOD><path><body>``.

        Args:
            method: HTTP method; upper-cased before signing.
            path: Request path exactly as it should be signed.
            body: Serialized request body, or "" when there is none.

        Returns:
            A dict with the base64 signature, the timestamp that was signed,
            and the key fingerprint.
        """
        # Whole seconds since the epoch, as a string; the same value is both
        # signed and sent so the two cannot drift apart.
        timestamp = str(int(time.time()))
        message = f"{timestamp}{method.upper()}{path}{body}"
        signature = self.private_key.sign(message.encode()).signature
        return {
            "x-thegrid-signature": base64.b64encode(signature).decode(),
            "x-thegrid-timestamp": timestamp,
            "x-thegrid-fingerprint": self.fingerprint,
        }
# AsyncOpenAI client configured with the consumption base URL; the API key is
# required in the environment (a missing variable raises KeyError here).
client = AsyncOpenAI(
    base_url=CONSUMPTION_BASE_URL,
    api_key=os.environ["GRID_CONSUMPTION_API_KEY"],
)

# Signer used for every trading-API request; both variables are required.
trading_auth = SignatureAuth(
    private_key_b64=os.environ["GRID_TRADING_PRIVATE_KEY"],
    fingerprint=os.environ["GRID_TRADING_FINGERPRINT"],
)
def estimate_tokens(prompt_list: List[str]) -> int:
    """Return a rough token estimate for a list of prompts.

    Each prompt contributes ``ceil((len(prompt) + 50) / 4)``; the per-prompt
    values are summed. An empty list yields 0.

    NOTE(review): the constants look like "~4 characters per token plus a
    fixed 50-character allowance per request" -- confirm against real usage.
    """
    return sum(math.ceil((len(prompt) + 50) / 4) for prompt in prompt_list)
def units_needed(total_tokens: int) -> int:
    """Convert a token count into whole 1M-token units, rounding up.

    The result is never below 1, so even a zero-token estimate yields one
    unit.
    """
    return max(1, math.ceil(total_tokens / 1_000_000))
def get_consumption_balance(auth: SignatureAuth) -> float:
    """Return the summed balance across all consumption accounts.

    Issues a signed GET to the consumption-accounts endpoint and, for each
    account in the response's ``data`` list, adds the larger of
    ``available_balance`` and ``total_balance``. Accounts whose balances
    cannot be parsed as floats are skipped.

    Args:
        auth: Signer used to build the request headers.

    Returns:
        The summed balance (0.0 when there are no usable accounts).

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    path = "/api/v1/trading/consumption-accounts"
    resp = requests.get(
        f"{TRADING_BASE_URL}{path}?order_by=created_at",
        # The signature covers the bare path; the query string is not signed.
        headers=auth.get_headers("GET", path),
        timeout=10,
    )
    resp.raise_for_status()
    data = resp.json()["data"]

    balance = 0.0
    for acct in data:
        try:
            # `or 0` maps None / "" / missing values to zero before parsing.
            available = float(acct.get("available_balance", 0) or 0)
            total = float(acct.get("total_balance", 0) or 0)
        except (TypeError, ValueError):
            continue
        # NOTE(review): taking the max counts whichever figure is larger --
        # confirm this is the intended notion of "balance".
        balance += max(available, total)
    return balance
def place_limit_order(auth: SignatureAuth, quantity: int) -> str:
    """Place a GTC limit buy order on ``MARKET_ID`` and return its order id.

    Args:
        auth: Signer used to build the request headers.
        quantity: Number of units to buy.

    Returns:
        The ``order_id`` found under ``data`` in the API response.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    order_data = {
        "market_id": MARKET_ID,
        "side": "buy",
        "type": "limit",
        "quantity": quantity,
        # Price is sent as a two-decimal string.
        "price": f"{MAX_PRICE:.2f}",
        "time_in_force": "gtc",
        # Second-resolution timestamp: two orders placed within the same
        # second get the same client_order_id.
        "client_order_id": f"batch-{int(time.time())}",
    }
    path = "/api/v1/trading/orders"
    # Serialize once so the signed bytes and the transmitted bytes are identical.
    body = json_dumps(order_data)
    resp = requests.post(
        f"{TRADING_BASE_URL}{path}",
        data=body,
        headers={
            "Content-Type": "application/json",
            **auth.get_headers("POST", path, body),
        },
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()["data"]["order_id"]
def get_order_status(auth: SignatureAuth, order_id: str) -> str:
    """Fetch one order and return its ``status`` field.

    Args:
        auth: Signer used to build the request headers.
        order_id: Identifier of the order to look up.

    Returns:
        The ``status`` value found under ``data`` in the API response.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    path = f"/api/v1/trading/orders/{order_id}"
    resp = requests.get(
        f"{TRADING_BASE_URL}{path}",
        headers=auth.get_headers("GET", path),
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()["data"]["status"]
def wait_for_fill(
    auth: SignatureAuth, target_units: int, order_id: str | None
) -> float:
    """Poll until the consumption balance reaches ``target_units``.

    Every 3 seconds, for up to ``MAX_WAIT_SEC`` seconds, the balance and
    (when ``order_id`` is given) the order status are fetched. A failed
    lookup is reported and treated as balance 0.0 / status "unknown" for
    that round, so a transient API error does not abort the wait.

    Args:
        auth: Signer used for the trading-API requests.
        target_units: Balance at which waiting stops.
        order_id: Order to monitor, or None to watch the balance only.

    Returns:
        The first observed balance that is >= ``target_units``.

    Raises:
        RuntimeError: If the order is reported "canceled" or "rejected"
            while the balance is still below target.
        TimeoutError: If the deadline passes first; the message says whether
            a "filled" status was seen along the way.
    """
    poll_interval_sec = 3
    fill_seen = False
    deadline = time.time() + MAX_WAIT_SEC
    while time.time() < deadline:
        balance = 0.0
        try:
            balance = get_consumption_balance(auth)
        except Exception as exc:
            # Best-effort polling: surface the error instead of hiding it,
            # then retry on the next round.
            print(f"Warning: balance check failed: {exc}")
        status = "unknown"
        if order_id:
            try:
                status = get_order_status(auth, order_id)
            except Exception as exc:
                print(f"Warning: order status check failed: {exc}")
        print(
            f"Waiting: order {order_id or 'n/a'} is {status}, balance {balance}/{target_units}"
        )
        # Balance is checked first: a sufficient balance wins even when the
        # order status is terminal or unknown.
        if balance >= target_units:
            return balance
        if status == "filled":
            fill_seen = True
        if status in {"canceled", "rejected"}:
            raise RuntimeError(f"Order {order_id} {status}")
        # Never sleep past the deadline.
        time.sleep(min(poll_interval_sec, max(0.0, deadline - time.time())))
    if fill_seen:
        raise TimeoutError(
            "Order filled but balance did not reach target within timeout"
        )
    raise TimeoutError("Timed out waiting for order fill")
async def process_prompt(prompt, index, semaphore):
    """Send one prompt to the chat model and return its result record.

    Args:
        prompt: User message text.
        index: Position of the prompt in the batch; echoed back in the
            result so the caller can restore the original order.
        semaphore: Async context manager that limits how many requests are
            in flight at once.

    Returns:
        A dict with ``index``, ``prompt`` and ``response`` (the content of
        the first choice's message).
    """
    # The request itself is made while holding the semaphore.
    async with semaphore:
        response = await client.chat.completions.create(
            model=CONFIG["model"],
            messages=[
                {"role": "system", "content": "Be concise and correct."},
                {"role": "user", "content": prompt},
            ],
        )
    return {
        "index": index,
        "prompt": prompt,
        "response": response.choices[0].message.content,
    }
async def process_batch(prompts, max_concurrent):
    """Run all prompts concurrently and return the successful results.

    Progress is printed as each request finishes. A request that raises is
    reported with a "Failed: ..." line and left out of the returned list.

    Args:
        prompts: Sequence of prompt strings.
        max_concurrent: Maximum number of requests in flight at once.

    Returns:
        The result dicts of the successful requests, sorted by their
        original ``index``.
    """
    semaphore = asyncio.Semaphore(max_concurrent)
    tasks = [
        process_prompt(prompt, i, semaphore) for i, prompt in enumerate(prompts)
    ]
    results = []
    # as_completed yields in finish order, hence the sort on return.
    for coro in asyncio.as_completed(tasks):
        try:
            result = await coro
            results.append(result)
            print(f"[{len(results)}/{len(prompts)}] {result['prompt'][:30]}...")
        except Exception as e:
            print(f"Failed: {e}")
    return sorted(results, key=lambda x: x["index"])
async def main():
    """Estimate the job, buy units if needed, wait for balance, run the batch.

    Steps:
      1. Estimate tokens for ``PROMPTS`` and convert to whole units.
      2. Read the current consumption balance.
      3. Unless ``INCLUDE_EXISTING_BALANCE`` is set and the balance already
         covers the job, place a limit order and wait for the balance.
      4. Run the batch and print each prompt with a response preview.

    NOTE(review): ``wait_for_fill`` is given ``required_units`` as its
    target, so a pre-existing balance that already covers the job ends the
    wait on the first poll even when ``INCLUDE_EXISTING_BALANCE`` is false.
    Confirm that is intended.
    """
    total_tokens = estimate_tokens(PROMPTS)
    required_units = units_needed(total_tokens)
    print(
        f"Estimated {total_tokens} tokens ({required_units} unit(s) of 1M tokens each)"
    )

    current_balance = get_consumption_balance(trading_auth)
    print(f"Current consumption balance: {current_balance} unit(s)")

    if INCLUDE_EXISTING_BALANCE and current_balance >= required_units:
        print("Balance already sufficient, running batch\n")
    else:
        # Only whole units of the existing balance are counted. In this
        # branch to_buy is always >= 1: either nothing is counted, or the
        # balance is strictly below the (integer) requirement.
        balance_used = math.floor(current_balance) if INCLUDE_EXISTING_BALANCE else 0
        to_buy = required_units - balance_used
        print(f"Placing limit order for {to_buy} unit(s) at ${MAX_PRICE:.2f}...")
        order_id = place_limit_order(trading_auth, to_buy)
        print(f"Order placed: {order_id}, waiting for fill...")
        final_balance = wait_for_fill(trading_auth, required_units, order_id)
        print(f"Order filled (balance now {final_balance}), running batch\n")

    print(
        f"Processing {len(PROMPTS)} prompts (max {CONFIG['max_concurrent']} concurrent)\n"
    )
    start_time = time.time()
    results = await process_batch(PROMPTS, CONFIG["max_concurrent"])
    elapsed = time.time() - start_time
    print(f"\nCompleted {len(results)} requests in {elapsed:.2f}s\n")

    for result in results:
        print(f"{result['index'] + 1}. {result['prompt']}")
        # The message content may be None; fall back to an empty preview
        # instead of raising on the slice.
        preview = (result["response"] or "")[:100]
        print(f" → {preview}...\n")
def json_dumps(obj) -> str:
    """Serialize ``obj`` to a JSON string using ``json.dumps`` defaults."""
    # Imported locally: the module's top-level imports do not include json.
    import json

    return json.dumps(obj)
# Script entry point: runs the whole workflow on a fresh event loop.
# NOTE: there is no `if __name__ == "__main__"` guard, so importing this
# module also runs the batch.
asyncio.run(main())
# Last updated
# Was this helpful?