Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def run(
# ---------------------------------------------------------------------------- #
# 自回归生成
# ---------------------------------------------------------------------------- #
input_ids_infini = infinicore.from_list(input_ids_list)
input_ids_infini = infinicore.from_list(input_ids_list, dtype=infinicore.int64)

t1 = time.time()
print("=================== start generate ====================")
Expand Down Expand Up @@ -348,7 +348,7 @@ def run(
for ids in test.input_ids_list
]

input_ids_infini = infinicore.from_list(warmup_ids)
input_ids_infini = infinicore.from_list(warmup_ids, dtype=infinicore.int64)

print("=================== warmup start ===================")

Expand Down
8 changes: 4 additions & 4 deletions examples/infer_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test(
# Generate
# ---------------------------------------------------------------------------- #
print(input_contents[0], end="", flush=True)
input_ids_infini = infinicore.from_list(input_ids_list)
input_ids_infini = infinicore.from_list(input_ids_list, dtype=infinicore.int64)

# Process multimodal inputs if needed
pixel_values_infini = None
Expand All @@ -207,9 +207,9 @@ def test(

# 1. Pixel values
all_pixel_values = []
assert (
len(pixel_values) == 1
), "Only batch_size=1 is supported yet for image inputs."
assert len(pixel_values) == 1, (
"Only batch_size=1 is supported yet for image inputs."
)
for pv in pixel_values:
all_pixel_values.extend(
[i.flatten(end_dim=1).permute(1, 0) for i in pv]
Expand Down
2 changes: 1 addition & 1 deletion examples/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test(
# ---------------------------------------------------------------------------- #
# 自回归生成
# ---------------------------------------------------------------------------- #
input_ids_infini = infinicore.from_list(input_ids_list)
input_ids_infini = infinicore.from_list(input_ids_list, dtype=infinicore.int64)

t1 = time.time()
print("=================== start generate ====================")
Expand Down
20 changes: 12 additions & 8 deletions python/infinilm/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,16 @@ def prepare_inputs_for_generation(
bs, seq_len = current_position_ids.shape
last_position = current_position_ids.narrow(1, seq_len - 1, 1)

one_value = infinicore.from_list(
[1] * bs,
dtype=last_position.dtype,
device=last_position.device,
).view((bs, 1))
one_value = (
infinicore.from_list(
[1] * bs,
dtype=last_position.dtype,
)
.view((bs, 1))
.to(
device=last_position.device,
)
)

next_position = one_value + last_position
model_inputs["position_ids"] = next_position
Expand All @@ -99,15 +104,14 @@ def prepare_inputs_for_generation(
] + infinicore.from_list(
[seq_len],
dtype=last_position.dtype,
device=last_position.device,
)
).to(device=last_position.device)
# -------------------------------------------------------------------- #
# 所需的: token的input_ids
# -------------------------------------------------------------------- #
if kwargs.get("next_token_ids", None) is not None:
next_token_ids = kwargs["next_token_ids"]
model_inputs["input_ids"] = infinicore.from_list(
[[id_] for id_ in next_token_ids],
[[id_] for id_ in next_token_ids], dtype=infinicore.int64
)

# -------------------------------------------------------------------- #
Expand Down
13 changes: 10 additions & 3 deletions python/infinilm/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def read_hf_config(model_path):
)
return config_dict


# config.json (required) defines model architecture, while generation_config.json
# (optional) defines generation behavior. They are kept as separate readers
# because: 1) config.json must exist and requires model_type validation,
Expand All @@ -43,6 +44,7 @@ def read_hf_generation_config(model_path):
return json.load(f)
return {}


@dataclass
class GenerationConfig:
max_new_tokens: int | None = None
Expand Down Expand Up @@ -244,9 +246,10 @@ def generate(
) // paged_block_size

block_tables_list = [
range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch)
list(range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch))
for i in range(batch_size)
]

block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int32,
Expand Down Expand Up @@ -375,10 +378,14 @@ def reset_cache(self, cache_config):
super().reset_cache(cache_config)

def state_dict_keyname(self):
return sorted({name for state_dict in super().state_dict() for name in state_dict.keys()})
return sorted(
{name for state_dict in super().state_dict() for name in state_dict.keys()}
)

def load_state_dict(self, state_dict, strict=None):
super().load_params({name: param._underlying for name, param in state_dict.items()})
super().load_params(
{name: param._underlying for name, param in state_dict.items()}
)

def process_weights_after_loading(self):
super().process_weights_after_loading()
2 changes: 1 addition & 1 deletion test/bench/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _generate_step(self, tokens, max_steps, topp_, topk_, temperature_):
from infinilm.infer_engine import GenerationConfig

input_ids_list = [tokens]
input_ids = infinicore.from_list(input_ids_list)
input_ids = infinicore.from_list(input_ids_list, dtype=infinicore.int64)

start_time = time.perf_counter()

Expand Down
Loading
Loading