diff --git a/examples/bench.py b/examples/bench.py index 6672d6d7d..41cb54acf 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -229,7 +229,7 @@ def run( # ---------------------------------------------------------------------------- # # 自回归生成 # ---------------------------------------------------------------------------- # - input_ids_infini = infinicore.from_list(input_ids_list) + input_ids_infini = infinicore.from_list(input_ids_list, dtype=infinicore.int64) t1 = time.time() print("=================== start generate ====================") @@ -348,7 +348,7 @@ def run( for ids in test.input_ids_list ] - input_ids_infini = infinicore.from_list(warmup_ids) + input_ids_infini = infinicore.from_list(warmup_ids, dtype=infinicore.int64) print("=================== warmup start ===================") diff --git a/examples/infer_backup.py b/examples/infer_backup.py index 54008bd90..912a3f869 100644 --- a/examples/infer_backup.py +++ b/examples/infer_backup.py @@ -194,7 +194,7 @@ def test( # Generate # ---------------------------------------------------------------------------- # print(input_contents[0], end="", flush=True) - input_ids_infini = infinicore.from_list(input_ids_list) + input_ids_infini = infinicore.from_list(input_ids_list, dtype=infinicore.int64) # Process multimodal inputs if needed pixel_values_infini = None @@ -207,9 +207,9 @@ def test( # 1. Pixel values all_pixel_values = [] - assert ( - len(pixel_values) == 1 - ), "Only batch_size=1 is supported yet for image inputs." + assert len(pixel_values) == 1, ( + "Only batch_size=1 is supported yet for image inputs." + ) for pv in pixel_values: all_pixel_values.extend( [i.flatten(end_dim=1).permute(1, 0) for i in pv] diff --git a/examples/llama.py b/examples/llama.py index 4a9b4b345..a3f0f11f8 100644 --- a/examples/llama.py +++ b/examples/llama.py @@ -85,7 +85,7 @@ def test( # ---------------------------------------------------------------------------- # # 自回归生成 # ---------------------------------------------------------------------------- # - input_ids_infini = infinicore.from_list(input_ids_list) + input_ids_infini = infinicore.from_list(input_ids_list, dtype=infinicore.int64) t1 = time.time() print("=================== start generate ====================") diff --git a/python/infinilm/generation/utils.py b/python/infinilm/generation/utils.py index 36f54cc6a..bad9c2613 100644 --- a/python/infinilm/generation/utils.py +++ b/python/infinilm/generation/utils.py @@ -86,11 +86,16 @@ def prepare_inputs_for_generation( bs, seq_len = current_position_ids.shape last_position = current_position_ids.narrow(1, seq_len - 1, 1) - one_value = infinicore.from_list( - [1] * bs, - dtype=last_position.dtype, - device=last_position.device, - ).view((bs, 1)) + one_value = ( + infinicore.from_list( + [1] * bs, + dtype=last_position.dtype, + ) + .view((bs, 1)) + .to( + device=last_position.device, + ) + ) next_position = one_value + last_position model_inputs["position_ids"] = next_position @@ -99,15 +104,14 @@ def prepare_inputs_for_generation( ] + infinicore.from_list( [seq_len], dtype=last_position.dtype, - device=last_position.device, - ) + ).to(device=last_position.device) # -------------------------------------------------------------------- # # 所需的: token的input_ids # -------------------------------------------------------------------- # if kwargs.get("next_token_ids", None) is not None: next_token_ids = kwargs["next_token_ids"] model_inputs["input_ids"] = infinicore.from_list( - [[id_] for id_ in next_token_ids], + [[id_] for id_ in next_token_ids], dtype=infinicore.int64 ) # -------------------------------------------------------------------- # diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 844989f43..75890b0f8 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -30,6 +30,7 @@ def read_hf_config(model_path): ) return config_dict + # config.json (required) defines model architecture, while generation_config.json # (optional) defines generation behavior. They are kept as separate readers # because: 1) config.json must exist and requires model_type validation, @@ -43,6 +44,7 @@ def read_hf_generation_config(model_path): return json.load(f) return {} + @dataclass class GenerationConfig: max_new_tokens: int | None = None @@ -244,9 +246,10 @@ def generate( ) // paged_block_size block_tables_list = [ - range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch) + list(range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch)) for i in range(batch_size) ] + block_tables = infinicore.from_list( block_tables_list, dtype=infinicore.int32, @@ -375,10 +378,14 @@ def reset_cache(self, cache_config): super().reset_cache(cache_config) def state_dict_keyname(self): - return sorted({name for state_dict in super().state_dict() for name in state_dict.keys()}) + return sorted( + {name for state_dict in super().state_dict() for name in state_dict.keys()} + ) def load_state_dict(self, state_dict, strict=None): - super().load_params({name: param._underlying for name, param in state_dict.items()}) + super().load_params( + {name: param._underlying for name, param in state_dict.items()} + ) def process_weights_after_loading(self): super().process_weights_after_loading() diff --git a/test/bench/test_benchmark.py b/test/bench/test_benchmark.py index c15c950fe..60e63dc66 100644 --- a/test/bench/test_benchmark.py +++ b/test/bench/test_benchmark.py @@ -175,7 +175,7 @@ def _generate_step(self, tokens, max_steps, topp_, topk_, temperature_): from infinilm.infer_engine import GenerationConfig input_ids_list = [tokens] - input_ids = infinicore.from_list(input_ids_list) + input_ids = infinicore.from_list(input_ids_list, dtype=infinicore.int64) start_time = time.perf_counter() diff --git a/test/models/llama/test_forward_validation.py b/test/models/llama/test_forward_validation.py index 4d51dc1a0..d109e98d9 100755 --- a/test/models/llama/test_forward_validation.py +++ b/test/models/llama/test_forward_validation.py @@ -25,8 +25,9 @@ # Import to_numpy extension for infinicore tensors try: from infinilm.generation.utils import infini_to_numpy + # This should already be registered, but ensure it's available - if not hasattr(infinicore.Tensor, 'to_numpy'): + if not hasattr(infinicore.Tensor, "to_numpy"): infinicore.Tensor.to_numpy = infini_to_numpy except ImportError: # If not available, we'll use fallback methods @@ -55,12 +56,13 @@ def infinicore_to_torch_tensor(infini_tensor, torch_tensor_for_shape=None): def torch_to_infinicore_tensor(torch_tensor, infini_device): """Fallback conversion.""" - return infinicore.from_list(torch_tensor.tolist()) + return infinicore.from_list(torch_tensor.tolist(), dtype=infinicore.int64) def get_args(): parser = argparse.ArgumentParser( - description="Validate forward pass across backends/dtypes") + description="Validate forward pass across backends/dtypes" + ) parser.add_argument( "--model_path", type=str, @@ -113,15 +115,19 @@ def create_inputs(prompt, tokenizer, device, backend="cpp"): # Match examples/llama.py: use from_list to create tensors # Wrap in list to create batch dimension: [[1, 2, 3, ...]] input_ids_infini = infinicore.from_list( - [input_ids_list], device=infini_device) + [input_ids_list], dtype=infinicore.int64 + ).to(device=infini_device) # Match generation code: use int64 dtype for position_ids position_ids_infini = infinicore.from_list( - [position_ids_list], dtype=infinicore.int64, device=infini_device) + [position_ids_list], dtype=infinicore.int64 + ).to(device=infini_device) return input_ids_infini, position_ids_infini, input_content -def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_steps=2): +def run_forward_pass( + model, input_ids, position_ids, backend, dtype, num_decode_steps=2 +): """Run prefill and multiple decode steps with KV cache, return all decode step logits.""" print(f" Running forward pass (prefill + {num_decode_steps} decode step(s))...") @@ -146,33 +152,37 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ prefill_logits_wrapped = infinicore.Tensor(prefill_logits) else: prefill_logits_wrapped = prefill_logits - print(f" DEBUG: Prefill logits tensor dtype={prefill_logits_wrapped.dtype}, " - f"device={prefill_logits_wrapped.device}, " - f"shape={prefill_logits_wrapped.shape}") + print( + f" DEBUG: Prefill logits tensor dtype={prefill_logits_wrapped.dtype}, " + f"device={prefill_logits_wrapped.device}, " + f"shape={prefill_logits_wrapped.shape}" + ) prefill_logits_np = infinicore_to_numpy(prefill_logits) - print( - f" ✓ Prefill completed, logits shape: {prefill_logits_np.shape}") + print(f" ✓ Prefill completed, logits shape: {prefill_logits_np.shape}") # Check prefill logits for issues if np.isnan(prefill_logits_np).any(): print(f" ⚠ WARNING: Prefill logits contain NaN values!") print(f" NaN count: {np.isnan(prefill_logits_np).sum()}") print( - f" Prefill logits stats: min={np.nanmin(prefill_logits_np):.6f}, max={np.nanmax(prefill_logits_np):.6f}, mean={np.nanmean(prefill_logits_np):.6f}") + f" Prefill logits stats: min={np.nanmin(prefill_logits_np):.6f}, max={np.nanmax(prefill_logits_np):.6f}, mean={np.nanmean(prefill_logits_np):.6f}" + ) if np.isinf(prefill_logits_np).any(): print(f" ⚠ WARNING: Prefill logits contain Inf values!") print(f" Inf count: {np.isinf(prefill_logits_np).sum()}") if not np.isnan(prefill_logits_np).any(): print( - f" Prefill logits stats: min={prefill_logits_np.min():.6f}, max={prefill_logits_np.max():.6f}, mean={prefill_logits_np.mean():.6f}") + f" Prefill logits stats: min={prefill_logits_np.min():.6f}, max={prefill_logits_np.max():.6f}, mean={prefill_logits_np.mean():.6f}" + ) # Get device from input_ids if hasattr(input_ids, "device"): input_device = input_ids.device else: input_device = getattr( - position_ids, "device", infinicore.device("cpu", 0)) + position_ids, "device", infinicore.device("cpu", 0) + ) # Initialize decode logits list decode_logits_list = [] @@ -185,7 +195,9 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ if decode_step == 0: # First decode step: use token from prefill if np.isnan(prefill_logits_np).any(): - print(f" ⚠ WARNING: Using default token 29902 due to NaN in prefill logits") + print( + f" ⚠ WARNING: Using default token 29902 due to NaN in prefill logits" + ) current_token_id = 29902 else: current_token_id = int(prefill_logits_np.argmax(axis=-1)[0, 0]) @@ -193,40 +205,56 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ # Subsequent decode steps: use token from previous decode prev_logits_np = decode_logits_list[-1] if np.isnan(prev_logits_np).any(): - print(f" ⚠ WARNING: Using default token 29902 due to NaN in decode step {decode_step} logits") + print( + f" ⚠ WARNING: Using default token 29902 due to NaN in decode step {decode_step} logits" + ) current_token_id = 29902 else: current_token_id = int(prev_logits_np.argmax(axis=-1)[0, 0]) - print(f" Step {decode_step + 2}: Decode step {decode_step + 1} (next_token_id={current_token_id})...") + print( + f" Step {decode_step + 2}: Decode step {decode_step + 1} (next_token_id={current_token_id})..." + ) # Create single token input for decode step decode_input_ids = infinicore.from_list( - [[current_token_id]], device=input_device) + [[current_token_id]], + dtype=infinicore.int64, + ).to(device=input_device) # Create position_ids for decode step decode_position_ids = infinicore.from_list( - [[seq_len + decode_step]], dtype=infinicore.int64, device=input_device - ) + [[seq_len + decode_step]], + dtype=infinicore.int64, + ).to(device=input_device) # Run decode step - C++ backend manages cache internally decode_logits = underlying_model.forward( - decode_input_ids, decode_position_ids) + decode_input_ids, decode_position_ids + ) # Convert decode logits to numpy decode_logits_np = infinicore_to_numpy(decode_logits) decode_logits_list.append(decode_logits_np) - print(f" ✓ Decode step {decode_step + 1} completed, logits shape: {decode_logits_np.shape}") + print( + f" ✓ Decode step {decode_step + 1} completed, logits shape: {decode_logits_np.shape}" + ) # Check decode logits for issues if np.isnan(decode_logits_np).any(): - print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain NaN values!") + print( + f" ⚠ WARNING: Decode step {decode_step + 1} logits contain NaN values!" + ) print(f" NaN count: {np.isnan(decode_logits_np).sum()}") if np.isinf(decode_logits_np).any(): - print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain Inf values!") + print( + f" ⚠ WARNING: Decode step {decode_step + 1} logits contain Inf values!" + ) print(f" Inf count: {np.isinf(decode_logits_np).sum()}") if not np.isnan(decode_logits_np).any(): - print(f" Decode step {decode_step + 1} logits stats: min={decode_logits_np.min():.6f}, max={decode_logits_np.max():.6f}, mean={decode_logits_np.mean():.6f}") + print( + f" Decode step {decode_step + 1} logits stats: min={decode_logits_np.min():.6f}, max={decode_logits_np.max():.6f}, mean={decode_logits_np.mean():.6f}" + ) else: # Python backend uses DynamicCache # Get model config @@ -246,8 +274,7 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ input_ids, position_ids, past_key_values=past_key_values, use_cache=True ) prefill_logits_np = infinicore_to_numpy(prefill_logits) - print( - f" ✓ Prefill completed, logits shape: {prefill_logits_np.shape}") + print(f" ✓ Prefill completed, logits shape: {prefill_logits_np.shape}") # Get device from input_ids if hasattr(input_ids, "device"): @@ -255,7 +282,8 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ else: # Fallback: try to get device from position_ids or use CPU input_device = getattr( - position_ids, "device", infinicore.device("cpu", 0)) + position_ids, "device", infinicore.device("cpu", 0) + ) # Initialize decode logits list decode_logits_list = [] @@ -268,7 +296,9 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ if decode_step == 0: # First decode step: use token from prefill if np.isnan(prefill_logits_np).any(): - print(f" ⚠ WARNING: Using default token 29902 due to NaN in prefill logits") + print( + f" ⚠ WARNING: Using default token 29902 due to NaN in prefill logits" + ) current_token_id = 29902 else: current_token_id = int(prefill_logits_np.argmax(axis=-1)[0, 0]) @@ -276,46 +306,68 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ # Subsequent decode steps: use token from previous decode prev_logits_np = decode_logits_list[-1] if np.isnan(prev_logits_np).any(): - print(f" ⚠ WARNING: Using default token 29902 due to NaN in decode step {decode_step} logits") + print( + f" ⚠ WARNING: Using default token 29902 due to NaN in decode step {decode_step} logits" + ) current_token_id = 29902 else: current_token_id = int(prev_logits_np.argmax(axis=-1)[0, 0]) - print(f" Step {decode_step + 2}: Decode step {decode_step + 1} (next_token_id={current_token_id})...") + print( + f" Step {decode_step + 2}: Decode step {decode_step + 1} (next_token_id={current_token_id})..." + ) # Create single token input for decode step decode_input_ids = infinicore.from_list( - [[current_token_id]], device=input_device) + [[current_token_id]], + dtype=infinicore.int64, + ).to(device=input_device) # Create position_ids for decode step decode_position_ids = infinicore.from_list( - [[seq_len + decode_step]], dtype=infinicore.int64, device=input_device - ) + [[seq_len + decode_step]], + dtype=infinicore.int64, + ).to(device=input_device) # Run decode step with KV cache decode_logits = underlying_model.forward( - decode_input_ids, decode_position_ids, past_key_values=past_key_values, use_cache=True + decode_input_ids, + decode_position_ids, + past_key_values=past_key_values, + use_cache=True, ) # Convert decode logits to numpy decode_logits_np = infinicore_to_numpy(decode_logits) decode_logits_list.append(decode_logits_np) - print(f" ✓ Decode step {decode_step + 1} completed, logits shape: {decode_logits_np.shape}") + print( + f" ✓ Decode step {decode_step + 1} completed, logits shape: {decode_logits_np.shape}" + ) # Check decode logits for issues if np.isnan(decode_logits_np).any(): - print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain NaN values!") + print( + f" ⚠ WARNING: Decode step {decode_step + 1} logits contain NaN values!" + ) print(f" NaN count: {np.isnan(decode_logits_np).sum()}") if np.isinf(decode_logits_np).any(): - print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain Inf values!") + print( + f" ⚠ WARNING: Decode step {decode_step + 1} logits contain Inf values!" + ) print(f" Inf count: {np.isinf(decode_logits_np).sum()}") if not np.isnan(decode_logits_np).any(): - print(f" Decode step {decode_step + 1} logits stats: min={decode_logits_np.min():.6f}, max={decode_logits_np.max():.6f}, mean={decode_logits_np.mean():.6f}") + print( + f" Decode step {decode_step + 1} logits stats: min={decode_logits_np.min():.6f}, max={decode_logits_np.max():.6f}, mean={decode_logits_np.mean():.6f}" + ) # Summary of all decode steps - print(f" ✓ Forward pass completed (prefill + {num_decode_steps} decode step(s))") + print( + f" ✓ Forward pass completed (prefill + {num_decode_steps} decode step(s))" + ) for i, logits_np in enumerate(decode_logits_list): - print(f" Decode step {i + 1} logits shape: {logits_np.shape}, dtype: {logits_np.dtype}") + print( + f" Decode step {i + 1} logits shape: {logits_np.shape}, dtype: {logits_np.dtype}" + ) # Check for issues in all decode steps has_error = False @@ -329,12 +381,16 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ print(f" Inf count: {np.isinf(logits_np).sum()}") has_error = True if np.abs(logits_np).max() < 1.0: - print(f" ⚠ WARNING: Decode step {i + 1} logits are very small (max abs: {np.abs(logits_np).max():.6f})") + print( + f" ⚠ WARNING: Decode step {i + 1} logits are very small (max abs: {np.abs(logits_np).max():.6f})" + ) # Get predicted token from last decode step if decode_logits_list and not np.isnan(decode_logits_list[-1]).any(): predicted_token = int(decode_logits_list[-1].argmax(axis=-1)[0, 0]) - print(f" Predicted token ID from decode step {num_decode_steps}: {predicted_token}") + print( + f" Predicted token ID from decode step {num_decode_steps}: {predicted_token}" + ) # Return tuple of all decode logits return tuple(decode_logits_list), has_error @@ -342,6 +398,7 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_ except Exception as e: print(f" ✗ Forward pass failed: {e}") import traceback + traceback.print_exc() return None, True @@ -364,10 +421,12 @@ def infinicore_to_numpy(tensor): # (to_numpy doesn't support bfloat16 directly) if tensor_cpu.dtype == infinicore.bfloat16: import ctypes + # Ensure tensor is actually on CPU and contiguous if tensor_cpu.device.type != "cpu": print( - f" DEBUG: WARNING - tensor_cpu.device.type={tensor_cpu.device.type}, forcing CPU move") + f" DEBUG: WARNING - tensor_cpu.device.type={tensor_cpu.device.type}, forcing CPU move" + ) tensor_cpu = tensor_cpu.to(infinicore.device("cpu", 0)) if not tensor_cpu.is_contiguous(): tensor_cpu = tensor_cpu.contiguous() @@ -380,12 +439,14 @@ def infinicore_to_numpy(tensor): # Debug: Check data pointer and device print( - f" DEBUG: Reading bfloat16 data: data_ptr={data_ptr}, num_elements={num_elements}, shape={shape}, device={tensor_cpu.device}") + f" DEBUG: Reading bfloat16 data: data_ptr={data_ptr}, num_elements={num_elements}, shape={shape}, device={tensor_cpu.device}" + ) # Use a safer approach: copy data using ctypes.memmove to ensure we read from CPU memory uint16_array = np.empty(shape, dtype=np.uint16) - ctypes.memmove(uint16_array.ctypes.data, data_ptr, - num_elements * 2) # 2 bytes per uint16 + ctypes.memmove( + uint16_array.ctypes.data, data_ptr, num_elements * 2 + ) # 2 bytes per uint16 # Convert to torch bfloat16, then to float32, then to numpy torch_uint16 = torch.from_numpy(uint16_array) @@ -398,11 +459,14 @@ def infinicore_to_numpy(tensor): print(f" DEBUG: NaN detected after bfloat16->float32 conversion") print(f" NaN count: {np.isnan(result).sum()}/{result.size}") print( - f" uint16_array stats: min={uint16_array.min()}, max={uint16_array.max()}, mean={uint16_array.mean():.2f}") + f" uint16_array stats: min={uint16_array.min()}, max={uint16_array.max()}, mean={uint16_array.mean():.2f}" + ) print( - f" torch_bf16 stats: min={torch_bf16.min().item():.6f}, max={torch_bf16.max().item():.6f}, mean={torch_bf16.mean().item():.6f}") + f" torch_bf16 stats: min={torch_bf16.min().item():.6f}, max={torch_bf16.max().item():.6f}, mean={torch_bf16.mean().item():.6f}" + ) print( - f" torch_f32 stats: min={torch_f32.min().item():.6f}, max={torch_f32.max().item():.6f}, mean={torch_f32.mean().item():.6f}") + f" torch_f32 stats: min={torch_f32.min().item():.6f}, max={torch_f32.max().item():.6f}, mean={torch_f32.mean().item():.6f}" + ) return result @@ -412,7 +476,8 @@ def infinicore_to_numpy(tensor): # Debug: Check for NaN in conversion result if np.isnan(result).any(): print( - f" DEBUG: NaN detected after to_numpy conversion (dtype={tensor_cpu.dtype})") + f" DEBUG: NaN detected after to_numpy conversion (dtype={tensor_cpu.dtype})" + ) print(f" NaN count: {np.isnan(result).sum()}/{result.size}") return result @@ -458,6 +523,7 @@ def test_configuration(model_path, device, backend, dtype, prompt, num_decode_st except Exception as e: print(f" ✗ Failed to create model: {e}") import traceback + traceback.print_exc() return None, True @@ -474,6 +540,7 @@ def test_configuration(model_path, device, backend, dtype, prompt, num_decode_st except Exception as e: print(f" ✗ Failed to load weights: {e}") import traceback + traceback.print_exc() return None, True @@ -481,22 +548,26 @@ def test_configuration(model_path, device, backend, dtype, prompt, num_decode_st print(f"\n4. Creating inputs from prompt: '{prompt}'...") try: input_ids, position_ids, input_content = create_inputs( - prompt, tokenizer, device, backend=backend) + prompt, tokenizer, device, backend=backend + ) print(f" ✓ Inputs created") print(f" Input content: {input_content[:100]}...") print(f" Input shape: {input_ids.shape}") print( - f" Input device: {input_ids.device.type if hasattr(input_ids, 'device') else 'unknown'}") + f" Input device: {input_ids.device.type if hasattr(input_ids, 'device') else 'unknown'}" + ) except Exception as e: print(f" ✗ Failed to create inputs: {e}") import traceback + traceback.print_exc() return None, True # Run forward pass (prefill + multiple decode steps) print(f"\n5. Running forward pass (prefill + {num_decode_steps} decode step(s))...") logits_tuple, has_error = run_forward_pass( - model, input_ids, position_ids, backend, dtype, num_decode_steps) + model, input_ids, position_ids, backend, dtype, num_decode_steps + ) if has_error: return None, True @@ -529,7 +600,7 @@ def compare_logits(logits1, logits2, name1, name2, step_name="logits"): # Check if they're close (allowing for dtype differences) # For bfloat16 vs float32, we expect larger differences rtol = 1e-2 # 1% relative tolerance - atol = 1.0 # Absolute tolerance + atol = 1.0 # Absolute tolerance is_close = np.allclose(logits1, logits2, rtol=rtol, atol=atol) @@ -544,7 +615,8 @@ def compare_logits(logits1, logits2, name1, name2, step_name="logits"): for idx in top_indices: pos = np.unravel_index(idx, diff.shape) print( - f" Position {pos}: {logits1[pos]:.6f} vs {logits2[pos]:.6f}, diff={diff[pos]:.6f}") + f" Position {pos}: {logits1[pos]:.6f} vs {logits2[pos]:.6f}, diff={diff[pos]:.6f}" + ) return is_close @@ -568,7 +640,12 @@ def main(): print("TEST 1: Python Backend + BFloat16") print("=" * 80) logits_py_bf16, error = test_configuration( - args.model_path, args.device, "python", "bfloat16", args.prompt, args.num_decode_steps + args.model_path, + args.device, + "python", + "bfloat16", + args.prompt, + args.num_decode_steps, ) results["python_bf16"] = (logits_py_bf16, error) @@ -577,7 +654,12 @@ def main(): print("TEST 3: C++ Backend + BFloat16") print("=" * 80) logits_cpp_bf16, error = test_configuration( - args.model_path, args.device, "cpp", "bfloat16", args.prompt, args.num_decode_steps + args.model_path, + args.device, + "cpp", + "bfloat16", + args.prompt, + args.num_decode_steps, ) results["cpp_bf16"] = (logits_cpp_bf16, error) @@ -603,7 +685,7 @@ def main(): cpp_logits[step_idx], "Python BF16", "C++ BF16", - step_name + step_name, ) comparisons.append((f"Python BF16 vs C++ BF16 ({step_name})", is_close))