#!/usr/bin/env python3
"""Concurrency + NIAH bench for Qwen3.5-397B-A17B-FP8 with MTP on 4x GX10."""

import json
import os
import random
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests

BASE = os.environ.get("VLLM_BASE_URL", "http://HEAD_NODE:8000")
MODEL = "Qwen3.5-397B-A17B-FP8"
TIMEOUT_TOK = 120    # passed as the `timeout` of /tokenize requests
TIMEOUT_GEN = 1800   # passed as the `timeout` of streaming chat requests

# Accepted window around the requested token count in build_text_of_tokens:
# the result should land in [target - _FIT_UNDER, target + _FIT_OVER].
_FIT_OVER = 30
_FIT_UNDER = 20
# Upper bound on trim/extend rounds so fitting can never loop forever.
_MAX_FIT_ROUNDS = 16
# Rough chars-per-token guess used only for the initial text size.
_CHARS_PER_TOKEN_GUESS = 4.5


# ---------- helpers ----------
def tokenize_count(text):
    """Return the token count the server's /tokenize endpoint reports for *text*.

    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    r = requests.post(
        f"{BASE}/tokenize",
        json={"model": MODEL, "prompt": text},
        timeout=TIMEOUT_TOK,
    )
    r.raise_for_status()
    return r.json()["count"]


def build_text_of_tokens(target):
    """Generate varied English filler text of ~target tokens.

    The text is built from randomly combined sentence fragments using a
    private, fixed-seed generator, so the output is reproducible and the
    process-wide ``random`` state is left untouched.  The text is then
    trimmed / extended until the server-side token count is within
    ``[target - 20, target + 30]`` (best effort, bounded number of rounds).

    Returns:
        ``(text, n)`` where ``n`` is the token count of ``text`` as reported
        by :func:`tokenize_count`.
    """
    # A local Random(42) yields the same sequence as random.seed(42) followed
    # by module-level calls, without reseeding the global generator.
    rng = random.Random(42)
    subjects = ["The detective", "She", "He", "The cat", "A small fox",
                "Mary", "John", "The team", "An engineer", "The pilot",
                "A researcher", "The astronaut", "The chef", "The child",
                "The gardener", "Captain Ortiz", "Dr. Reed", "The librarian"]
    verbs = ["walked through", "examined", "spoke to", "considered",
             "wrote about", "imagined", "noticed", "discovered", "observed",
             "described", "remembered", "studied", "explained", "decided on",
             "questioned", "ignored"]
    objects = ["the strange wooden box", "an old crumpled letter",
               "a faded photograph", "the heavy iron key",
               "a thin leather notebook", "the morning sunlight",
               "a distant white ship", "the rusted garden gate",
               "an unknown radio signal", "the heavy oak door",
               "a quiet voice in the hallway", "a small printing mistake",
               "a complex chemical idea", "the new spiral pattern",
               "a difficult open question"]
    extras = ["under a wide grey sky", "after the long cold winter",
              "while the kettle whistled softly", "before anyone else woke",
              "in the quiet of the library", "as the rain fell steadily",
              "across the empty wheat field", "near the old stone wall",
              "between two very tall trees", "during the second hour",
              "around the small bronze fountain", "amid the murmuring crowd",
              "behind a locked oak cabinet", "above the muddy slow river",
              "below a starless winter ceiling", "outside the rusted gate"]

    # over-estimate then trim
    target_chars = int(target * _CHARS_PER_TOKEN_GUESS)
    parts = []
    total = 0
    while total < target_chars:
        s = (f"{rng.choice(subjects)} {rng.choice(verbs)} "
             f"{rng.choice(objects)} {rng.choice(extras)}.")
        parts.append(s)
        total += len(s) + 1
    text = " ".join(parts)
    n = tokenize_count(text)

    extra_idx = 0
    for _ in range(_MAX_FIT_ROUNDS):
        if n > target + _FIT_OVER:
            # binary-ish trim: scale the character length by the token
            # ratio, then cut back to the last full stop.
            ratio = min(max(target / n, 0.0), 1.0) if n > 0 else 0.0
            trimmed = text[:int(len(text) * ratio)]
            idx = trimmed.rfind(".")
            if idx > 0:
                trimmed = trimmed[:idx + 1]
            if trimmed == text:
                break  # no progress possible
            text = trimmed
        elif n < target - _FIT_UNDER and parts:
            # extend if short: size the batch from the measured
            # chars-per-token ratio so we overshoot by at most about one
            # sentence, instead of blindly appending 50 sentences between
            # token counts.
            chars_per_token = len(text) / n if n > 0 else 0.0
            if chars_per_token <= 0:
                chars_per_token = _CHARS_PER_TOKEN_GUESS
            deficit_chars = (target - n) * chars_per_token
            batch = []
            added = 0
            while True:
                s = parts[extra_idx % len(parts)]
                extra_idx += 1
                batch.append(s)
                added += len(s) + 1
                if added >= deficit_chars:
                    break
            text = " ".join([text, *batch])
        else:
            break
        n = tokenize_count(text)
    return text, n


def stream_chat(prompt, max_tokens, label):
    """Send one streaming chat completion and collect timing + output.

    TTFT is taken as the arrival time of the first parseable SSE ``data:``
    chunk, measured from just before the request is issued.

    Returns a dict with keys ``label``, ``t0``, ``ttft`` (``None`` if no
    chunk arrived), ``first_token_t``, ``end``, ``duration``, ``content``,
    ``reasoning``, ``usage`` (``None`` if the server sent none) and
    ``chunks``.

    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    ttft = None
    first_token_t = None
    content_buf = []
    reasoning_buf = []
    usage = None
    n_chunks = 0

    t0 = time.perf_counter()
    # The context manager guarantees the streamed connection is released,
    # including when we break out early or an exception propagates.
    with requests.post(f"{BASE}/v1/chat/completions", json=payload,
                       stream=True, timeout=TIMEOUT_GEN) as r:
        r.raise_for_status()
        for raw in r.iter_lines(decode_unicode=False):
            if not raw or not raw.startswith(b"data: "):
                continue
            data = raw[6:]
            if data == b"[DONE]":
                break
            try:
                obj = json.loads(data)
            except ValueError:
                # Malformed / non-JSON event: skip it (best effort).
                continue
            now = time.perf_counter()
            n_chunks += 1
            if ttft is None:
                ttft = now - t0
                first_token_t = now
            if obj.get("choices"):
                delta = obj["choices"][0].get("delta") or {}
                if delta.get("content"):
                    content_buf.append(delta["content"])
                # The reasoning text may arrive under either key.
                rc = delta.get("reasoning") or delta.get("reasoning_content")
                if rc:
                    reasoning_buf.append(rc)
            if obj.get("usage"):
                usage = obj["usage"]
        # Stop the clock before the connection teardown.
        end = time.perf_counter()

    return {
        "label": label,
        "t0": t0,
        "ttft": ttft,
        "first_token_t": first_token_t,
        "end": end,
        "duration": end - t0,
        "content": "".join(content_buf),
        "reasoning": "".join(reasoning_buf),
        "usage": usage,
        "chunks": n_chunks,
    }


# ---------- concurrency test ----------
def concurrency_test(prompt, input_tokens, max_tokens, n):
    """Fire *n* identical streaming requests at once and summarise throughput.

    All worker threads wait on a barrier so the requests are issued together.
    ``input_tokens`` is used as the prompt size for any request whose
    response carried no usage block.

    Returns a dict with keys ``n``, ``wall_s``, ``per`` (list of per-request
    stats), ``agg_prefill_tps``, ``agg_decode_tps``, ``total_in`` and
    ``total_out``.

    Raises ``ValueError`` if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError("n must be >= 1")

    barrier = threading.Barrier(n)

    def task(i):
        barrier.wait()
        return stream_chat(prompt, max_tokens, f"c{n}-r{i}")

    t_start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=n) as ex:
        results = list(ex.map(task, range(n)))
    t_end = time.perf_counter()

    per = []
    for r in results:
        ttft = r["ttft"] or 0.0
        decode_time = max(r["duration"] - ttft, 1e-6)
        usage = r["usage"] or {}
        in_tokens = usage.get("prompt_tokens", input_tokens)
        out_tokens = usage.get("completion_tokens", 0)
        prefill_tps = in_tokens / ttft if ttft > 0 else 0
        # The first token is attributed to prefill, hence "- 1".
        decode_tps = (out_tokens - 1) / decode_time if out_tokens > 1 else 0
        per.append({
            "ttft_s": ttft,
            "duration_s": r["duration"],
            "in_tokens": in_tokens,
            "out_tokens": out_tokens,
            "prefill_tps": prefill_tps,
            "decode_tps": decode_tps,
        })

    # aggregate using wall-clock windows
    # A request that streamed nothing has ttft None; treat it as 0.0 here,
    # exactly as the per-request stats above do.
    first_chunk_walls = [r["t0"] + (r["ttft"] or 0.0) for r in results]
    min_t0 = min(r["t0"] for r in results)
    last_ttft_wall = max(first_chunk_walls) - min_t0
    first_tok_wall = min(first_chunk_walls)
    decode_window = max(r["end"] for r in results) - first_tok_wall
    total_in = sum(p["in_tokens"] for p in per)
    total_out = sum(p["out_tokens"] for p in per)
    agg_prefill_tps = total_in / last_ttft_wall if last_ttft_wall > 0 else 0
    # One token per request is attributed to prefill; clamp so missing usage
    # data cannot yield a negative throughput.
    decoded = max(total_out - n, 0)
    agg_decode_tps = decoded / decode_window if decode_window > 0 else 0
    return {
        "n": n,
        "wall_s": t_end - t_start,
        "per": per,
        "agg_prefill_tps": agg_prefill_tps,
        "agg_decode_tps": agg_decode_tps,
        "total_in": total_in,
        "total_out": total_out,
    }


# ---------- NIAH ----------
def niah_test(filler_text, needle, needle_value, question, max_tokens=2048):
    """Run one needle-in-a-haystack query.

    The needle is inserted after the first full stop at or beyond the middle
    of ``filler_text`` (or exactly at the middle if there is none), the
    question is appended, and the streamed answer (content plus reasoning)
    is searched for ``needle_value``.

    Returns a dict with keys ``in_tokens``, ``ttft_s`` (0.0 if nothing was
    streamed), ``duration_s``, ``out_tokens``, ``prefill_tps``,
    ``decode_tps``, ``found``, ``needle_value``, ``content`` (first 600
    chars) and ``reasoning_head`` (first 400 chars).
    """
    mid = len(filler_text) // 2
    idx = filler_text.find(".", mid)
    if idx < 0:
        idx = mid
    body = filler_text[:idx + 1] + "\n\n" + needle + "\n\n" + filler_text[idx + 1:]
    prompt = body + "\n\n" + question
    in_tokens = tokenize_count(prompt)
    print(f" NIAH prompt = {in_tokens} tokens", file=sys.stderr, flush=True)

    r = stream_chat(prompt, max_tokens, "niah")
    # Normalise a missing TTFT to 0.0 so arithmetic and formatting downstream
    # never see None (matches concurrency_test).
    ttft = r["ttft"] or 0.0
    out_tokens = r["usage"]["completion_tokens"] if r["usage"] else 0
    blob = r["content"] + "\n" + r["reasoning"]
    found = needle_value in blob
    decode_time = max(r["duration"] - ttft, 1e-6)
    return {
        "in_tokens": in_tokens,
        "ttft_s": ttft,
        "duration_s": r["duration"],
        "out_tokens": out_tokens,
        "prefill_tps": in_tokens / ttft if ttft else 0,
        "decode_tps": (out_tokens - 1) / decode_time if out_tokens > 1 else 0,
        "found": found,
        "needle_value": needle_value,
        "content": r["content"][:600],
        "reasoning_head": r["reasoning"][:400],
    }


# ---------- main ----------
def main():
    """Run the concurrency sweep and the NIAH test; print JSON to stdout.

    Progress messages go to stderr so stdout carries only the JSON report.
    """
    print("== Building 10k-token prompt ==", file=sys.stderr, flush=True)
    body_10k, _ = build_text_of_tokens(9800)  # leave room for instruction
    prompt_10k = body_10k + "\n\nPlease continue the narrative above with new events for the next thousand words or so."
    n10k_full = tokenize_count(prompt_10k)
    print(f" 10k prompt: {n10k_full} tokens", file=sys.stderr, flush=True)

    print("== Concurrency sweep (10k in / 1024 out) ==", file=sys.stderr, flush=True)
    results_conc = []
    for n in [1, 2, 4, 8, 16]:
        print(f" -- N={n}", file=sys.stderr, flush=True)
        res = concurrency_test(prompt_10k, n10k_full, 1024, n)
        results_conc.append(res)
        print(f" wall={res['wall_s']:.1f}s "
              f"agg_prefill={res['agg_prefill_tps']:.0f}t/s "
              f"agg_decode={res['agg_decode_tps']:.1f}t/s",
              file=sys.stderr, flush=True)

    print("== Building 200k-token filler for NIAH ==", file=sys.stderr, flush=True)
    body_200k, n200 = build_text_of_tokens(199000)
    print(f" 200k filler: {n200} tokens", file=sys.stderr, flush=True)

    needle_value = "BANANA-MOOSE-7421"
    needle = (f"IMPORTANT: The secret access code for Project Aurora is "
              f"{needle_value}. Remember this exactly.")
    question = ("Based on the text above, what is the secret access code for "
                "Project Aurora? Reply with just the code itself, nothing else.")

    print("== NIAH test (200k prefill) ==", file=sys.stderr, flush=True)
    niah = niah_test(body_200k, needle, needle_value, question, max_tokens=2048)
    print(f" NIAH: in={niah['in_tokens']}t ttft={niah['ttft_s']:.1f}s "
          f"prefill={niah['prefill_tps']:.0f}t/s found={niah['found']}",
          file=sys.stderr, flush=True)

    out = {"concurrency": results_conc, "niah": niah}
    print(json.dumps(out, indent=2))


if __name__ == "__main__":
    main()