Profiling with SGLang

This chapter describes several approaches to profiling SGLang, helping you understand where time is spent, detect bottlenecks, and analyze both host and device behavior during inference.

SGLang includes several profiling tools and supports integration with Intel Gaudi’s native profiling features.

Built-in Profiling

  1. Enable basic profiling to track key metrics:

    python -m sglang.launch_server \
        --model-path meta-llama/Meta-Llama-3.1-8B \
        --enable-profiling \
        --profiling-interval 10 \
        --host 0.0.0.0 \
        --port 30000
    
  2. Access metrics via the metrics endpoint:

curl http://localhost:30000/metrics

Example metrics output:

{
    "throughput": {
        "requests_per_second": 12.5,
        "tokens_per_second": 1250.0,
        "prefill_tokens_per_second": 2500.0,
        "decode_tokens_per_second": 1150.0
    },
    "latency": {
        "time_to_first_token_ms": 45.2,
        "inter_token_latency_ms": 8.1,
        "end_to_end_latency_ms": 892.3
    },
    "memory": {
        "hpu_memory_used_gb": 24.5,
        "hpu_memory_total_gb": 94.6,
        "kv_cache_usage_gb": 18.2,
        "model_memory_gb": 6.3
    },
    "requests": {
        "active_requests": 3,
        "queued_requests": 1,
        "completed_requests": 1247
    }
}
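
The same endpoint can be polled programmatically. Below is a minimal sketch that assumes the server was launched as shown above and returns the JSON layout from the example output; adjust the field names if your build reports a different schema:

import requests

# Fetch current metrics from a locally running SGLang server
response = requests.get("http://localhost:30000/metrics", timeout=5)
metrics = response.json()

# Field names follow the example output above
print(f"Tokens/s:        {metrics['throughput']['tokens_per_second']:.1f}")
print(f"TTFT (ms):       {metrics['latency']['time_to_first_token_ms']:.1f}")
print(f"HPU memory (GB): {metrics['memory']['hpu_memory_used_gb']:.1f}"
      f" / {metrics['memory']['hpu_memory_total_gb']:.1f}")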

Gaudi Profiler Integration

Use Habana’s profiler for detailed device-level analysis:

# Set environment variables for profiling
export HABANA_PROFILE=profile_api_light
export SGLANG_TORCH_PROFILER_DIR=<path for profiling output>
python -m sglang.bench_offline_throughput \
    --dataset-name random \
    --dtype bfloat16 \
    --model-path /mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct \
    --num-prompts $NUM_PROMPTS \
    --random-input-len $INPUT_LEN \
    --random-output-len $OUTPUT_LEN \
    --device hpu \
    --page-size 128 \
    --disable-radix-cache \
    --max-prefill-tokens 2048 \
    --random-range-ratio 1.0 \
    --profile
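
To capture traces from an already running server instead of the offline benchmark, recent SGLang releases expose /start_profile and /stop_profile endpoints that write Torch profiler traces into SGLANG_TORCH_PROFILER_DIR. Availability depends on the SGLang version, so treat the following as a sketch to verify against your installation:

import requests

SERVER = "http://localhost:30000"

# Start writing a trace into the directory set via SGLANG_TORCH_PROFILER_DIR
requests.post(f"{SERVER}/start_profile", timeout=5)

# ... send the requests you want profiled here ...

# Stop the capture; inspect the resulting trace files offline
requests.post(f"{SERVER}/stop_profile", timeout=5)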

Use the HPU profiler for kernel-level analysis:

import habana_frameworks.torch.hpu.profiler as hpu_profiler
import sglang as sgl

# Enable profiling
with hpu_profiler.profile(
    activities=[
        hpu_profiler.ProfilerActivity.HPU,
        hpu_profiler.ProfilerActivity.CPU
    ],
    output_dir="./hpu_traces"
) as prof:

    # Set up SGLang backend
    backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B")
    sgl.set_default_backend(backend)

    # Run inference with profiling
    @sgl.function
    def profiled_inference(s, prompt):
        s += sgl.user(prompt)
        s += sgl.assistant(sgl.gen("response", max_tokens=100))

    # Execute multiple requests for comprehensive profiling
    for i in range(10):
        result = profiled_inference.run(prompt=f"Test prompt {i}")
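
If the dedicated hpu_profiler module is not available in your environment, the standard torch.profiler API is an alternative; on Gaudi, importing habana_frameworks.torch registers HPU activity with it. The following is a minimal sketch under that assumption, not a drop-in replacement guaranteed for every SynapseAI release:

import torch
import habana_frameworks.torch as htorch  # registers HPU support with torch.profiler
import sglang as sgl

backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B")
sgl.set_default_backend(backend)

@sgl.function
def profiled_inference(s, prompt):
    s += sgl.user(prompt)
    s += sgl.assistant(sgl.gen("response", max_tokens=100))

# Collect CPU and HPU activity and emit a TensorBoard-compatible trace
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.HPU,
    ],
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./hpu_traces"),
) as prof:
    for i in range(10):
        profiled_inference.run(prompt=f"Test prompt {i}")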

Custom Profiling

Create custom benchmarks for your specific use case:

import time
import statistics
import sglang as sgl
from sglang import function, user, assistant, gen

class SGLangProfiler:
    def __init__(self, model_path, num_warmup=5, num_iterations=20):
        self.backend = sgl.Runtime(model_path=model_path)
        sgl.set_default_backend(self.backend)
        self.num_warmup = num_warmup
        self.num_iterations = num_iterations

    def benchmark_throughput(self, prompts, max_tokens=100):
        @function
        def inference_func(s, prompt):
            s += user(prompt)
            s += assistant(gen("response", max_tokens=max_tokens))

        # Warmup
        for _ in range(self.num_warmup):
            inference_func.run(prompt=prompts[0])

        # Actual benchmark
        latencies = []
        start_time = time.time()

        for i in range(self.num_iterations):
            prompt = prompts[i % len(prompts)]

            request_start = time.time()
            result = inference_func.run(prompt=prompt)
            request_end = time.time()

            latencies.append(request_end - request_start)

        total_time = time.time() - start_time

        return {
            "total_time": total_time,
            "requests_per_second": self.num_iterations / total_time,
            "avg_latency": statistics.mean(latencies),
            "p50_latency": statistics.median(latencies),
            "p95_latency": sorted(latencies)[int(0.95 * len(latencies))],
            "p99_latency": sorted(latencies)[int(0.99 * len(latencies))]
        }

# Usage
profiler = SGLangProfiler("meta-llama/Meta-Llama-3.1-8B")
test_prompts = [
    "Explain machine learning",
    "What is quantum computing?",
    "Describe neural networks",
    # Add more test prompts...
]

results = profiler.benchmark_throughput(test_prompts)
print(f"Throughput: {results['requests_per_second']:.2f} req/s")
print(f"P50 Latency: {results['p50_latency']:.3f}s")

Monitor memory usage patterns:

import psutil
import time
import sglang as sgl

class MemoryProfiler:
    def __init__(self):
        self.memory_snapshots = []

    def take_snapshot(self, label):
        # Get HPU memory info (requires habana-torch-plugin)
        try:
            import habana_frameworks.torch.hpu as hpu
            hpu_memory = hpu.memory_stats()
            hpu_allocated = hpu_memory.get('allocated_bytes.all.current', 0) / 1024**3
            hpu_reserved = hpu_memory.get('reserved_bytes.all.current', 0) / 1024**3
        except Exception:
            # Fall back to zeros if the Habana runtime is unavailable
            hpu_allocated = hpu_reserved = 0

        # Get system memory
        process = psutil.Process()
        memory_info = process.memory_info()

        snapshot = {
            'label': label,
            'timestamp': time.time(),
            'hpu_allocated_gb': hpu_allocated,
            'hpu_reserved_gb': hpu_reserved,
            'system_memory_gb': memory_info.rss / 1024**3,
            'virtual_memory_gb': memory_info.vms / 1024**3
        }

        self.memory_snapshots.append(snapshot)
        return snapshot

    def print_summary(self):
        print("Memory Usage Summary:")
        print("-" * 80)
        for snapshot in self.memory_snapshots:
            print(f"{snapshot['label']:20} | "
                f"HPU: {snapshot['hpu_allocated_gb']:.2f}GB | "
                f"System: {snapshot['system_memory_gb']:.2f}GB")

# Usage
memory_profiler = MemoryProfiler()

# Profile different stages
memory_profiler.take_snapshot("Initial")

backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B")
memory_profiler.take_snapshot("After model load")

sgl.set_default_backend(backend)
memory_profiler.take_snapshot("After backend setup")

# Run inference
@sgl.function
def test_inference(s, prompt):
    s += sgl.user(prompt)
    s += sgl.assistant(sgl.gen("response", max_tokens=100))

result = test_inference.run(prompt="Hello, how are you?")
memory_profiler.take_snapshot("After inference")

memory_profiler.print_summary()

Production Monitoring

Set up continuous monitoring for production deployments:

import time
import json
import requests
from datetime import datetime

class ProductionMonitor:
    def __init__(self, sglang_url="http://localhost:30000"):
        self.sglang_url = sglang_url
        self.metrics_history = []

    def collect_metrics(self):
        try:
            response = requests.get(f"{self.sglang_url}/metrics", timeout=5)
            metrics = response.json()

            # Add timestamp
            metrics['timestamp'] = datetime.now().isoformat()
            self.metrics_history.append(metrics)

            return metrics
        except Exception as e:
            print(f"Failed to collect metrics: {e}")
            return None

    def check_health(self):
        try:
            response = requests.get(f"{self.sglang_url}/health", timeout=5)
            return response.status_code == 200
        except requests.RequestException:
            return False

    def analyze_performance(self, window_minutes=10):
        cutoff_time = datetime.now().timestamp() - (window_minutes * 60)
        recent_metrics = [
            m for m in self.metrics_history
            if datetime.fromisoformat(m['timestamp']).timestamp() > cutoff_time
        ]

        if not recent_metrics:
            return None

        # Calculate averages
        avg_throughput = sum(m['throughput']['tokens_per_second']
                        for m in recent_metrics) / len(recent_metrics)
        avg_latency = sum(m['latency']['time_to_first_token_ms']
                        for m in recent_metrics) / len(recent_metrics)

        return {
            'avg_throughput_tps': avg_throughput,
            'avg_time_to_first_token_ms': avg_latency,
            'sample_count': len(recent_metrics)
        }

    def save_metrics(self, filename):
        with open(filename, 'w') as f:
            json.dump(self.metrics_history, f, indent=2)

# Usage for continuous monitoring
monitor = ProductionMonitor()

while True:
    if monitor.check_health():
        metrics = monitor.collect_metrics()
        if metrics:
            # Check for performance issues
            throughput = metrics['throughput']['tokens_per_second']
            if throughput < 100:  # Threshold
                print(f"WARNING: Low throughput detected: {throughput} TPS")

            memory_usage = metrics['memory']['hpu_memory_used_gb']
            total_memory = metrics['memory']['hpu_memory_total_gb']
            if memory_usage / total_memory > 0.95:  # 95% threshold
                print(f"WARNING: High memory usage: {memory_usage:.1f}GB / {total_memory:.1f}GB")
    else:
        print("ERROR: SGLang server health check failed")

    time.sleep(30)  # Check every 30 seconds
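
Outside of the monitoring loop (for example, from another session, or after adapting the loop to exit), the collected history can be summarized and persisted; the window size and file name below are illustrative:

# Summarize the last 10 minutes of collected samples
summary = monitor.analyze_performance(window_minutes=10)
if summary:
    print(f"Avg throughput: {summary['avg_throughput_tps']:.1f} tokens/s "
          f"over {summary['sample_count']} samples")
    print(f"Avg TTFT: {summary['avg_time_to_first_token_ms']:.1f} ms")

# Persist the raw history for offline analysis
monitor.save_metrics("sglang_metrics_history.json")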

Profiling Different Workloads

Profile batch inference performance:

import concurrent.futures
import time
import sglang as sgl

def profile_batch_inference(prompts, batch_size=4, max_workers=2):
    backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B")
    sgl.set_default_backend(backend)

    @sgl.function
    def batch_inference(s, prompt):
        s += sgl.user(prompt)
        s += sgl.assistant(sgl.gen("response", max_tokens=100))

    start_time = time.time()

    # Process in batches
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i+batch_size]
            futures = [executor.submit(batch_inference.run, prompt=p) for p in batch]
            batch_results = [f.result() for f in futures]
            results.extend(batch_results)

    total_time = time.time() - start_time

    return {
        'total_requests': len(prompts),
        'total_time': total_time,
        'requests_per_second': len(prompts) / total_time,
        'batch_size': batch_size,
        'max_workers': max_workers
    }
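
A usage sketch with placeholder prompts and settings (the values are illustrative, not tuned recommendations):

prompts = [f"Summarize topic number {i}" for i in range(16)]

stats = profile_batch_inference(prompts, batch_size=4, max_workers=2)
print(f"Processed {stats['total_requests']} requests in {stats['total_time']:.1f}s "
      f"({stats['requests_per_second']:.2f} req/s)")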

Profile streaming response performance:

import sglang as sgl
import time

def profile_streaming(prompt, max_tokens=200):
    backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B")
    sgl.set_default_backend(backend)

    @sgl.function
    def streaming_inference(s, prompt):
        s += sgl.user(prompt)
        s += sgl.assistant(sgl.gen("response", max_tokens=max_tokens, stream=True))

    start_time = time.time()
    token_times = []

    result = streaming_inference.run(prompt=prompt)

    # Record the arrival time of every streamed token
    for _ in result.stream():
        token_times.append(time.time())

    total_time = time.time() - start_time

    # Time to first token, measured from the start of the request
    first_token_time = token_times[0] - start_time if token_times else None

    # Inter-token latencies between consecutive tokens
    inter_token_latencies = [
        token_times[i] - token_times[i - 1] for i in range(1, len(token_times))
    ]

    return {
        'time_to_first_token': first_token_time,
        'total_generation_time': total_time,
        'total_tokens': len(token_times),
        'avg_inter_token_latency': sum(inter_token_latencies) / len(inter_token_latencies) if inter_token_latencies else 0,
        'tokens_per_second': len(token_times) / total_time if total_time > 0 else 0
    }
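
Example usage (the prompt and token budget are placeholders):

stats = profile_streaming("Explain how KV cache paging works", max_tokens=200)
print(f"TTFT: {stats['time_to_first_token']:.3f}s")
print(f"Decode rate: {stats['tokens_per_second']:.1f} tokens/s "
      f"(avg inter-token latency {stats['avg_inter_token_latency'] * 1000:.1f} ms)")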

Troubleshooting Performance Issues

The following checks help diagnose common performance issues.

  1. Check resource utilization:

hl-smi  # Check HPU utilization
htop    # Check CPU usage

  2. Profile memory usage:

curl http://localhost:30000/metrics | jq '.memory'

  3. Analyze request queue:

curl http://localhost:30000/debug/queue_status

  4. Enable detailed tracing:

--enable-request-tracing --trace-output-dir ./traces

  5. Check for memory pressure:

# Look for memory allocation delays in logs
grep "memory" sglang_server.log

  6. Profile warmup completeness:

curl http://localhost:30000/warmup/status

  7. Monitor memory patterns:

# Use memory profiling script
python memory_profiler.py

  8. Check for memory leaks (a sampling sketch follows this list):

# Monitor memory over time
watch -n 5 'hl-smi | grep Memory'
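
For step 8, a simple leak check is to sample HPU allocation at fixed intervals and watch for steady growth under a constant workload. A minimal sketch, assuming the habana_frameworks Torch plugin is installed (interval and sample count are arbitrary):

import time
import habana_frameworks.torch.hpu as hpu

# Sample allocated HPU memory every 30 seconds; values that keep climbing
# under a steady workload point to a leak or unbounded KV-cache growth.
samples = []
for _ in range(20):
    stats = hpu.memory_stats()
    allocated_gb = stats.get('allocated_bytes.all.current', 0) / 1024**3
    samples.append(allocated_gb)
    print(f"{time.strftime('%H:%M:%S')}  allocated: {allocated_gb:.2f} GB")
    time.sleep(30)

print(f"Change over run: {samples[-1] - samples[0]:+.2f} GB")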

Best Practices

  Profile Regularly: Automate profiling in production to detect regressions early.
  Use Representative Workloads: Profile with realistic request patterns reflecting actual use.
  Monitor Key Metrics: Focus on throughput, latency, and memory usage.
  Establish Baselines: Record performance before changes to track improvements.
  Test Different Scenarios: Vary batch sizes, sequence lengths, and concurrency levels.
  Archive Results: Keep historical profiling data for comparison over time.

For additional profiling techniques and troubleshooting, see SGLang with Gaudi FAQs and Inference Using SGLang.