Saltar al contenido principal

Rendimiento Lento de Inferencia

Cuando experimentes tiempos de inferencia lentos, primero establece una línea base de rendimiento. Utiliza la Calculadora de VRAM para determinar el throughput esperado de tu GPU y compáralo con las especificaciones de tu modelo. Zylon proporciona un script de benchmarking que simula solicitudes de inferencia concurrentes para medir el Time To First Token (TTFT), la latencia y el throughput. Con él puedes identificar si el rendimiento está por debajo de lo esperado y si se degrada bajo carga.
Zylon asigna los recursos de cómputo de forma que los tiempos de respuesta se mantengan consistentes bajo carga concurrente (8–10 usuarios simultáneos). Por eso, un benchmark con una única inferencia puede mostrar menos tokens/s que el máximo teórico del hardware, mientras que el rendimiento real con múltiples usuarios cumplirá o superará las expectativas.
import asyncio
import json
import random
import time
from dataclasses import dataclass
from typing import Optional
from urllib.parse import urlsplit

# --- Benchmark configuration: edit these values before running. ---
# API base URL; build_http_request appends "/v1/messages" to its path.
BASE_URL = "https://<host>/api/gpt"
# Sent as "Authorization: Bearer <token>" on every request.
BEARER_TOKEN = "your_token_here"
# Model identifier placed in the "model" field of the request payload.
MODEL = "qwen-3-14b-awq"
# Value of the "max_tokens" field of every request payload.
MAX_TOKENS = 4096
# Sent as the single user message of every request.
PROMPT = "Write a paragraph about artificial intelligence."
# Concurrency levels benchmarked one after another, in this order.
CONCURRENCY_LEVELS = [1, 4, 8, 16, 32]
# When True, debug_log prints verbose per-request progress messages.
DEBUG = False
# Each request first sleeps a random 0..JITTER_MAX_MS milliseconds so that
# "concurrent" requests do not all start at exactly the same instant.
JITTER_MAX_MS = 200


@dataclass
class RequestResult:
    """Outcome and timing measurements of a single benchmark request."""

    # Whether the request is counted as successful in the statistics.
    success: bool
    # Seconds from request start until the first "content_block_delta"
    # event; None when no such event was observed.
    ttft: Optional[float] = None
    # Total wall-clock seconds attributed to the request.
    latency: float = 0.0
    # Seconds between the first content delta and "content_block_stop".
    generation_time: float = 0.0
    # Output token count taken from the server's usage data.
    tokens: int = 0
    # tokens / generation_time, in tokens per second (0.0 if not measured).
    throughput: float = 0.0
    # Description of the failure for unsuccessful requests.
    error: Optional[str] = None


@dataclass
class BenchmarkStatistics:
    """Aggregated metrics for one concurrency level (all times in seconds)."""

    # Fraction (0.0-1.0) of requests counted as successful.
    success_rate: float
    # Means over the successful requests.
    avg_ttft: float
    avg_latency: float
    avg_generation_time: float
    # Mean output tokens per second.
    avg_throughput: float
    # Interpolated percentiles of time-to-first-token.
    p50_ttft: float
    p95_ttft: float
    p99_ttft: float


@dataclass
class ConnectionConfig:
    """Raw-socket connection settings derived from BASE_URL by parse_url."""

    # Host name used for the TCP connection and the HTTP Host header.
    host: str
    # TCP port to connect to.
    port: int
    # Whether to wrap the connection in TLS.
    use_ssl: bool
    # Base path of the API; the endpoint path is appended to it.
    path: str


def debug_log(message: str, force: bool = False) -> None:
    """Print ``message`` with a wall-clock timestamp.

    Output is produced only when the module-level ``DEBUG`` flag is set or
    the caller passes ``force=True``; otherwise the call is a no-op.
    """
    if not (DEBUG or force):
        return
    stamp = time.strftime("%H:%M:%S", time.localtime())
    print(f"[DEBUG {stamp}] {message}")


def parse_url(url: str) -> ConnectionConfig:
    """Split an ``http``/``https`` URL into raw-socket connection settings.

    Args:
        url: Absolute URL such as ``https://host[:port]/base/path``.

    Returns:
        ConnectionConfig holding the host name, the explicit port or the
        scheme default (443 for https, 80 otherwise), whether TLS is used,
        and the URL path (``"/"`` when the URL has none). Query string and
        fragment are not part of the returned path.

    Raises:
        ValueError: If the URL has no host or its port is invalid.
    """
    parts = urlsplit(url)
    if not parts.hostname:
        raise ValueError(f"URL has no host: {url!r}")

    # urlsplit lower-cases the scheme, so "HTTPS://" is detected as well.
    use_ssl = parts.scheme == "https"
    # An explicit ":port" must not end up inside the host name.
    explicit_port = parts.port  # raises ValueError on an invalid port
    port = explicit_port if explicit_port is not None else (443 if use_ssl else 80)

    return ConnectionConfig(
        host=parts.hostname,
        port=port,
        use_ssl=use_ssl,
        path=parts.path or "/",
    )


def build_http_request(
    config: ConnectionConfig, payload: dict[str, object], bearer_token: str
) -> bytes:
    """Serialize a POST to ``<base path>/v1/messages`` as raw HTTP/1.1 bytes.

    Args:
        config: Target host, port, TLS flag and base path.
        payload: JSON-serializable request body.
        bearer_token: Token sent in the ``Authorization`` header.

    Returns:
        The complete request: start line, headers, blank line and JSON body.
    """
    # rstrip avoids "//v1/messages" when the base path is "/" or ends in "/".
    path = config.path.rstrip("/") + "/v1/messages"
    body = json.dumps(payload).encode("utf-8")

    # The Host header must carry the port when it is not the scheme default.
    default_port = 443 if config.use_ssl else 80
    host_header = (
        config.host if config.port == default_port else f"{config.host}:{config.port}"
    )

    head_lines = [
        f"POST {path} HTTP/1.1",
        f"Host: {host_header}",
        "Content-Type: application/json",
        f"Authorization: Bearer {bearer_token}",
        # Content-Length counts bytes of the encoded body, not characters.
        f"Content-Length: {len(body)}",
        "Connection: close",
    ]
    head = "\r\n".join(head_lines) + "\r\n\r\n"
    return head.encode() + body


async def read_stream_response(
    reader: asyncio.StreamReader, session_id: int, start_time: float
) -> tuple[Optional[float], Optional[float], int, int]:
    """Read a streamed HTTP response and extract timing and usage metrics.

    The status line and headers are read first; a non-2xx status raises so
    that the caller records the request as failed instead of as an empty
    successful stream. The body is de-chunked when the server uses
    ``Transfer-Encoding: chunked``. Otherwise it is read until the peer
    closes the connection. Each body line is then parsed as a Server-Sent
    Events ``data:`` line carrying a JSON event.

    Args:
        reader: Stream positioned at the start of the HTTP response.
        session_id: Identifier used only in debug log messages.
        start_time: ``time.perf_counter()`` value taken when the request
            started; TTFT is measured relative to it.

    Returns:
        ``(ttft, content_block_stop_time, output_tokens, event_count)``:
        seconds until the first ``content_block_delta`` event (None if none
        arrived), the ``perf_counter`` timestamp of the last
        ``content_block_stop`` event (or None), the ``output_tokens`` value of
        the last ``message_delta`` usage payload, and the number of JSON
        events decoded.

    Raises:
        RuntimeError: If the response status is not 2xx.
        asyncio.IncompleteReadError: If the connection closes inside the
            headers or inside a chunk.
        ValueError: If a chunk-size line is malformed.
    """
    ttft: Optional[float] = None
    content_block_stop_time: Optional[float] = None
    output_tokens = 0
    event_count = 0

    # --- Status line and headers (terminated by an empty line). ---
    head = await reader.readuntil(b"\r\n\r\n")
    head_lines = head.decode("latin-1").split("\r\n")
    status_line = head_lines[0]
    try:
        status = int(status_line.split()[1])
    except (IndexError, ValueError):
        status = 0
    debug_log(f"Session {session_id}: Response status {status_line!r}")
    if not 200 <= status < 300:
        raise RuntimeError(f"Unexpected HTTP response: {status_line}")

    chunked = False
    for header in head_lines[1:]:
        name, _, value = header.partition(":")
        if name.strip().lower() == "transfer-encoding" and "chunked" in value.lower():
            chunked = True

    async def body_pieces():
        """Yield raw body bytes with any chunked transfer framing removed."""
        if chunked:
            while True:
                size_line = await reader.readline()
                if not size_line:
                    return  # peer closed without sending the last chunk
                # Chunk extensions after ";" are ignored.
                size = int(size_line.split(b";", 1)[0].strip().decode("ascii"), 16)
                if size == 0:
                    return  # last-chunk marker; trailers are ignored
                data = await reader.readexactly(size + 2)  # payload + CRLF
                yield data[:-2]
        else:
            while True:
                data = await reader.read(8192)
                if not data:
                    return
                yield data

    def handle_line(raw_line: bytes) -> None:
        """Parse one SSE line and update the collected metrics."""
        nonlocal ttft, content_block_stop_time, output_tokens, event_count

        line_str = raw_line.decode("utf-8", errors="replace").strip()
        if not line_str.startswith("data:"):
            return  # blank separators, "event:" lines, comments

        data_str = line_str[len("data:"):].strip()
        if data_str == "[DONE]":
            debug_log(f"Session {session_id}: Stream complete")
            return

        try:
            event = json.loads(data_str)
        except json.JSONDecodeError as e:
            debug_log(f"Session {session_id}: JSON decode error - {e}")
            return

        event_count += 1
        if not isinstance(event, dict):
            return  # valid JSON but not an event object
        event_type = event.get("type")

        if ttft is None and event_type == "content_block_delta":
            ttft = time.perf_counter() - start_time
            debug_log(f"Session {session_id}: TTFT = {ttft:.3f}s")

        if event_type == "content_block_stop":
            # Overwritten on every stop, so the last content block wins.
            content_block_stop_time = time.perf_counter()
            debug_log(
                f"Session {session_id}: Content block stopped at {content_block_stop_time - start_time:.3f}s"
            )

        if event_type == "message_delta":
            usage = event.get("usage") or {}
            output_tokens = usage.get("output_tokens", 0)
            debug_log(
                f"Session {session_id}: Received usage data - {output_tokens} tokens"
            )

    # --- Body: split into lines, keeping any partial line for the next read. ---
    pending = b""
    async for piece in body_pieces():
        pending += piece
        *complete_lines, pending = pending.split(b"\n")
        for raw_line in complete_lines:
            handle_line(raw_line)
    if pending:
        handle_line(pending)  # final line that lacked a trailing newline

    return ttft, content_block_stop_time, output_tokens, event_count


async def make_request(session_id: int) -> RequestResult:
    """Send one streaming inference request to BASE_URL and measure it.

    After a random start-up delay of up to ``JITTER_MAX_MS`` milliseconds,
    opens a fresh connection and posts a streaming request for ``MODEL``
    with ``PROMPT``. It then reads the response to completion.

    Args:
        session_id: Identifier used in the ``correlation_id`` payload field
            and in debug logs.

    Returns:
        A RequestResult. ``success`` is False when any exception is raised,
        or when the stream contained no content delta (so no TTFT exists);
        ``error`` then describes the cause. Latency runs from just before
        the connection is opened until the response has been read. It
        excludes closing the socket.
    """
    jitter = random.randint(0, JITTER_MAX_MS) / 1000.0
    debug_log(f"Session {session_id}: Waiting {jitter:.3f}s before starting")
    await asyncio.sleep(jitter)

    debug_log(f"Session {session_id}: Starting request")

    payload: dict[str, object] = {
        "model": MODEL,
        "max_tokens": MAX_TOKENS,
        "messages": [{"role": "user", "content": PROMPT}],
        "stream": True,
        "correlation_id": f"session-{session_id}",
    }

    start_time = time.perf_counter()
    writer: Optional[asyncio.StreamWriter] = None

    try:
        debug_log(f"Session {session_id}: Opening connection")

        config = parse_url(BASE_URL)
        reader, writer = await asyncio.open_connection(
            config.host, config.port, ssl=config.use_ssl
        )

        writer.write(build_http_request(config, payload, BEARER_TOKEN))
        await writer.drain()

        debug_log(f"Session {session_id}: Connection established")

        ttft, content_block_stop_time, output_tokens, event_count = (
            await read_stream_response(reader, session_id, start_time)
        )
        # Stop the clock before tearing the connection down.
        end_time = time.perf_counter()
    except Exception as e:
        debug_log(f"Session {session_id}: Error - {type(e).__name__}: {e}")
        # Include the type: some exceptions (e.g. timeouts) have an empty str().
        return RequestResult(success=False, error=f"{type(e).__name__}: {e}")
    finally:
        # Always release the socket, including when reading failed.
        if writer is not None:
            writer.close()
            try:
                await writer.wait_closed()
            except Exception as close_error:
                # Best-effort close: an error during shutdown (e.g. an abrupt
                # TLS teardown) must not fail a request that already completed.
                debug_log(
                    f"Session {session_id}: Error while closing - "
                    f"{type(close_error).__name__}: {close_error}"
                )

    total_latency = end_time - start_time

    if ttft is None:
        # Without any content delta there is no TTFT; counting this as a
        # success would inflate the success rate (e.g. on auth errors).
        error = f"No content received in response stream (events: {event_count})"
        debug_log(f"Session {session_id}: Error - {error}")
        return RequestResult(success=False, latency=total_latency, error=error)

    generation_time = 0.0
    throughput = 0.0
    if content_block_stop_time is not None:
        # Generation time spans from the first content delta to the end of
        # the last content block.
        generation_time = content_block_stop_time - (start_time + ttft)
        if generation_time > 0:
            throughput = output_tokens / generation_time

    debug_log(
        f"Session {session_id}: Completed - "
        f"Latency: {total_latency:.3f}s, "
        f"Generation: {generation_time:.3f}s, "
        f"Tokens: {output_tokens}, "
        f"Events: {event_count}, "
        f"Throughput: {throughput:.2f} tok/s"
    )

    return RequestResult(
        success=True,
        ttft=ttft,
        latency=total_latency,
        generation_time=generation_time,
        tokens=output_tokens,
        throughput=throughput,
    )


async def run_concurrent_requests(num_users: int) -> list[RequestResult]:
    """Launch ``num_users`` requests at once and wait for all of them.

    Session ids run from 0 to ``num_users - 1``; the returned list keeps
    that order.
    """
    debug_log(f"Starting {num_users} concurrent requests")
    pending = (make_request(session) for session in range(num_users))
    gathered = await asyncio.gather(*pending)
    debug_log(f"All {num_users} requests completed")
    return [outcome for outcome in gathered]


def calculate_percentile(data: list[float], p: float) -> float:
    """Return the ``p`` quantile (0.0-1.0) of ``data`` by linear interpolation.

    ``data`` must already be sorted in ascending order. An empty list
    yields 0.0.
    """
    if not data:
        return 0.0

    position = (len(data) - 1) * p
    lower = int(position)
    upper = lower + 1
    lower_value = data[lower]

    if upper < len(data):
        # Interpolate between the two neighbouring samples.
        return lower_value + (position - lower) * (data[upper] - lower_value)
    # At (or beyond) the last sample there is nothing to interpolate with.
    return lower_value


def calculate_statistics(results: list[RequestResult]) -> BenchmarkStatistics:
    """Aggregate per-request results into summary statistics.

    Only successful requests contribute to averages and percentiles. Each
    metric is averaged over the requests for which it was actually measured:
    TTFT over requests that have one, generation time and throughput over
    requests with a positive value. Metrics without samples are 0.0. When
    there are no successful requests, every field is 0.0.
    """
    successful = [r for r in results if r.success]
    debug_log(
        f"Calculating statistics for {len(successful)}/{len(results)} successful requests"
    )

    if not successful:
        debug_log("No successful requests to calculate statistics")
        return BenchmarkStatistics(
            success_rate=0.0,
            avg_ttft=0.0,
            avg_latency=0.0,
            avg_generation_time=0.0,
            avg_throughput=0.0,
            p50_ttft=0.0,
            p95_ttft=0.0,
            p99_ttft=0.0,
        )

    def mean(values: list[float]) -> float:
        """Arithmetic mean, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    # Percentiles require ascending order.
    ttfts = sorted(r.ttft for r in successful if r.ttft is not None)
    latencies = [r.latency for r in successful]
    # 0.0 means "not measured"; averaging it in would bias the mean low.
    generation_times = [r.generation_time for r in successful if r.generation_time > 0]
    throughputs = [r.throughput for r in successful if r.throughput > 0]

    return BenchmarkStatistics(
        success_rate=len(successful) / len(results),
        avg_ttft=mean(ttfts),
        avg_latency=mean(latencies),
        avg_generation_time=mean(generation_times),
        avg_throughput=mean(throughputs),
        p50_ttft=calculate_percentile(ttfts, 0.5),
        p95_ttft=calculate_percentile(ttfts, 0.95),
        p99_ttft=calculate_percentile(ttfts, 0.99),
    )


def print_statistics(concurrency: int, stats: BenchmarkStatistics) -> None:
    """Print a human-readable report for one concurrency level to stdout."""
    report = (
        f"\nTesting with {concurrency} concurrent inferences...",
        f"  Success Rate: {stats.success_rate:.2%}",
        f"  Avg TTFT: {stats.avg_ttft:.3f}s",
        f"  P50 TTFT: {stats.p50_ttft:.3f}s",
        f"  P95 TTFT: {stats.p95_ttft:.3f}s",
        f"  P99 TTFT: {stats.p99_ttft:.3f}s",
        f"  Avg Latency: {stats.avg_latency:.3f}s",
        f"  Avg Generation Time: {stats.avg_generation_time:.3f}s",
        f"  Avg Throughput: {stats.avg_throughput:.2f} tok/s",
    )
    for row in report:
        print(row)


async def benchmark() -> None:
    """Run the benchmark once per entry of CONCURRENCY_LEVELS and report each."""
    banner = (
        f"Benchmarking API: {BASE_URL}",
        f"Model: {MODEL}",
        f"Prompt: {PROMPT[:50]}...",
        "-" * 80,
    )
    for text in banner:
        print(text)

    for level in CONCURRENCY_LEVELS:
        debug_log(f"Starting concurrency level: {level}")
        outcomes = await run_concurrent_requests(level)
        print_statistics(level, calculate_statistics(outcomes))


# Script entry point: run the whole concurrency sweep on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(benchmark())
Actualiza BASE_URL, BEARER_TOKEN y MODEL en el script y luego ejecútalo para medir el TTFT, el throughput y la latencia con distintos niveles de concurrencia. Compara tus resultados con el máximo teórico que indica la Calculadora de VRAM para identificar si el rendimiento está por debajo de lo esperado.