From 130246cae76957bdf20466e9161fb655ff5c857e Mon Sep 17 00:00:00 2001 From: zhushuang Date: Thu, 11 Jun 2026 20:13:49 +0800 Subject: [PATCH] issue/429 - feat: adjust warmup --- examples/bench.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/examples/bench.py b/examples/bench.py index 41cb54ac..b3d97709 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -331,32 +331,43 @@ def run( warmup_steps = 1 # warmup cache capacity - warmup_cache_len = 128 - warmup_batch = len(test.input_ids_list) - - test.model.reset_cache( - StaticKVCacheConfig( + warmup_case = next(iter(cases_dict.values())) + warmup_batch = warmup_case["batch_size"] + warmup_input_len = warmup_case["input_len"] + warmup_decode_len = 5 + + if enable_paged_attn: + warmup_num_blocks = ( + (warmup_input_len + warmup_decode_len + paged_kv_block_size - 1) + // paged_kv_block_size + ) * warmup_batch + warmup_cache_config = PagedKVCacheConfig( + warmup_num_blocks, paged_kv_block_size + ) + else: + warmup_cache_config = StaticKVCacheConfig( max_batch_size=warmup_batch, - max_cache_len=warmup_cache_len, + max_cache_len=warmup_input_len + warmup_decode_len, ) - ) - avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list)) + test.model.reset_cache(warmup_cache_config) - warmup_ids = [ - ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids - for ids in test.input_ids_list - ] + warmup_prompt_ids = repeat_prompt(test.input_ids_list[0], warmup_input_len) + warmup_ids = [warmup_prompt_ids] * warmup_batch input_ids_infini = infinicore.from_list(warmup_ids, dtype=infinicore.int64) + print( + f"\033[93m[warmup] batch={warmup_batch}, input_len={warmup_input_len}, " + f"will prefill + {warmup_decode_len} decode steps\033[0m" + ) print("=================== warmup start ===================") for _ in range(warmup_steps): _ = test.model.generate( input_ids_infini, GenerationConfig( - max_new_tokens=5, # decode kernel warmup + max_new_tokens=warmup_decode_len, # decode kernel warmup temperature=cfg.temperature, top_k=cfg.top_k, top_p=cfg.top_p,