
fix: correct EOS handling in batched generation#270

Open
RedaRahmani wants to merge 1 commit into mistralai:main from RedaRahmani:fix/batched-eos-token-handling

Conversation

@RedaRahmani

fix: correct EOS handling in batched generation

Problem

When running batched inference with eos_id set, sequences that finish at different steps produce corrupted output. If row 0 hits EOS at step 1 but row 1 hasn't finished yet, the current code:

  1. Keeps sampling random tokens for finished rows and appends them to the output
  2. Keeps recording logprobs for finished rows, inflating their logprob lists with meaningless values
  3. Feeds those garbage tokens back into the model, wasting compute on already-completed sequences

This means the returned generated_tokens for early-finishing rows contain junk tokens generated after their EOS, and the logprobs don't match the actual generation.

Root cause

The original loop checks is_finished.all() to break, but between EOS detection and the break, it unconditionally appends tokens and logprobs for every row — including rows that already hit EOS. It also feeds all rows back into model.forward() regardless of their finished state.
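The buggy shape can be reproduced with a minimal pure-Python stand-in for the tensor code (all names here — model_step, is_finished, eos_id — are illustrative, not the repo's actual identifiers):

```python
# Minimal sketch of the buggy decode loop described above (hypothetical names).
def generate_buggy(model_step, prompt_last_tokens, eos_id, max_steps):
    batch = len(prompt_last_tokens)
    generated = [[] for _ in range(batch)]
    logprobs = [[] for _ in range(batch)]
    is_finished = [False] * batch

    cur = list(prompt_last_tokens)
    for _ in range(max_steps):
        next_tok, next_lp = model_step(cur)
        for i in range(batch):
            is_finished[i] = is_finished[i] or next_tok[i] == eos_id
            generated[i].append(next_tok[i])  # appended even for finished rows
            logprobs[i].append(next_lp[i])    # logprob list inflated for finished rows
        if all(is_finished):
            break
        cur = next_tok  # junk tokens from finished rows fed back into the model
    return generated, logprobs

# Row 0 hits EOS (= 2) at step 1, row 1 at step 2: row 0's output gains a junk 9.
script = [([5, 7], [-0.1, -0.2]), ([2, 8], [-0.3, -0.4]), ([9, 2], [-0.5, -0.6])]
calls = iter(script)
gen, lps = generate_buggy(lambda cur: next(calls), [1, 1], eos_id=2, max_steps=10)
print(gen)  # [[5, 2, 9], [7, 8, 2]] -- the 9 appears after row 0's EOS
```

The junk token and the inflated logprob list for row 0 are exactly the symptoms listed in the Problem section.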

Fix

  • Freeze finished rows: overwrite next_token with eos_id for any row that already finished, so the model gets a consistent input and the KV cache stays clean.
  • Skip logprobs for finished rows: only append logprobs when not is_finished[i].
  • Track per-row stop position: a finished_after_tokens tensor records exactly when each row hit EOS. The final output is trimmed per-row to exclude post-EOS tokens.
  • Reorder EOS detection: moved the is_finished update to after appending the current token, so the token that triggered EOS is correctly included in the output before marking the row as done.
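The four changes above can be sketched together in the same pure-Python stand-in (hypothetical names; the PR operates on tensors and trims per-row with finished_after_tokens, while this sketch reaches the same result by never appending past a row's EOS):

```python
# Sketch of the fixed decode loop, assuming the illustrative names from above.
def generate_fixed(model_step, prompt_last_tokens, eos_id, max_steps):
    batch = len(prompt_last_tokens)
    generated = [[] for _ in range(batch)]
    logprobs = [[] for _ in range(batch)]
    is_finished = [False] * batch

    cur = list(prompt_last_tokens)
    for _ in range(max_steps):
        next_tok, next_lp = model_step(cur)
        next_tok = list(next_tok)
        for i in range(batch):
            if is_finished[i]:
                next_tok[i] = eos_id  # freeze: consistent input, clean KV cache
            else:
                generated[i].append(next_tok[i])
                logprobs[i].append(next_lp[i])  # skipped for finished rows
        # detect EOS only after appending, so the triggering token is kept
        for i in range(batch):
            if next_tok[i] == eos_id:
                is_finished[i] = True
        if all(is_finished):
            break  # no further forward passes once every row is done
        cur = next_tok
    return generated, logprobs

# Same scenario: row 0 finishes at step 1, row 1 at step 2.
script = [([5, 7], [-0.1, -0.2]), ([2, 8], [-0.3, -0.4]), ([9, 2], [-0.5, -0.6])]
calls = iter(script)
gen, lps = generate_fixed(lambda cur: next(calls), [1, 1], eos_id=2, max_steps=10)
print(gen)  # [[5, 2], [7, 8, 2]] -- no tokens after each row's EOS
```

Note that after the freeze, a finished row's last token is always eos_id, so the next forward pass sees a stable input rather than a freshly sampled random token.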

Tests

Added tests/test_generate_eos_batching.py with a lightweight DummyTransformer mock (no GPU needed):

  • test_generate_stops_each_row_after_its_own_eos — row 0 finishes at step 1, row 1 at step 3. Validates that outputs are trimmed correctly, logprob lengths match, and the frozen row feeds eos_id back to the model instead of random tokens.
  • test_generate_keeps_final_eos_when_all_rows_finish_same_step — both rows finish simultaneously. Validates the early exit and that no unnecessary forward passes happen.
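The per-row trimming that the first test validates can be sketched with a tiny helper (names are hypothetical; the actual code slices tensors using the recorded stop positions):

```python
# finished_after[i] is the number of generated tokens row i produced up to and
# including its EOS (the finished_after_tokens idea); anything beyond that
# index is post-EOS junk to discard.
def trim_to_eos(rows, finished_after):
    return [row[:n] for row, n in zip(rows, finished_after)]

rows = [[5, 2, 9, 9], [7, 8, 2, 9]]  # 2 is EOS; trailing 9s are junk
print(trim_to_eos(rows, [2, 3]))     # [[5, 2], [7, 8, 2]]
```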
