[ET-VK] Fix use-after-free in PrepackNode when TensorRefs are shared#18906
meta-codesync[bot] merged 2 commits into gh/SS-JIA/520/base
Conversation
When a model has shared/tied weights (e.g. tied embeddings in transformers), serialization deduplicates them into a single TensorRef that multiple PrepackNodes reference. Previously, `PrepackNode::create_staging_buffer()` called `tref->free_buffer()` unconditionally after copying weight data to a GPU staging buffer. The first PrepackNode to execute would therefore free the underlying host memory, and subsequent PrepackNodes sharing the same TensorRef would read from a dangling pointer, producing garbage/NaN values in prepacked weight and bias tensors on the GPU.

The fix adds a `prepack_use_count` field to `TensorRef` that tracks how many PrepackNodes still need to read from it. Each PrepackNode increments the count in its constructor and decrements it after copying data; the buffer is freed only when the count reaches zero. This preserves the original eager-free behavior for non-shared weights (freeing immediately after the single consumer copies) while correctly deferring the free for shared weights until the last consumer is done, avoiding both the use-after-free and an unnecessary increase in peak memory.

Differential Revision: [D101009402](https://our.internmc.facebook.com/intern/diff/D101009402/)
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18906.

CI status as of commit 5fd2bae with merge base 89e6416: 1 pending, 2 unrelated failures (broken trunk: the failing jobs were also present on the merge base; rebase onto the `viable/strict` branch to avoid them).
…are shared" When a model has shared/tied weights (e.g. tied embeddings in transformers), the serialization deduplicates them into a single TensorRef that multiple PrepackNodes reference. Previously, `PrepackNode::create_staging_buffer()` called `tref->free_buffer()` unconditionally after copying weight data to a GPU staging buffer. This meant the first PrepackNode to execute would free the underlying host memory, and subsequent PrepackNodes sharing the same TensorRef would read from a dangling pointer — producing garbage/NaN values in prepacked weight and bias tensors on the GPU. The fix adds a `prepack_use_count` field to `TensorRef` that tracks how many PrepackNodes still need to read from it. Each PrepackNode increments the count in its constructor and decrements it after copying data. The buffer is only freed when the count reaches zero. This preserves the original eager-free behavior for non-shared weights (freeing immediately after the single consumer copies) while correctly deferring the free for shared weights until the last consumer is done — avoiding both the use-after-free and unnecessary peak memory increase. Differential Revision: [D101009402](https://our.internmc.facebook.com/intern/diff/D101009402/) [ghstack-poisoned]
Pull Request resolved: #18906 When a model has shared/tied weights (e.g. tied embeddings in transformers), the serialization deduplicates them into a single TensorRef that multiple PrepackNodes reference. Previously, `PrepackNode::create_staging_buffer()` called `tref->free_buffer()` unconditionally after copying weight data to a GPU staging buffer. This meant the first PrepackNode to execute would free the underlying host memory, and subsequent PrepackNodes sharing the same TensorRef would read from a dangling pointer — producing garbage/NaN values in prepacked weight and bias tensors on the GPU. The fix adds a `prepack_use_count` field to `TensorRef` that tracks how many PrepackNodes still need to read from it. Each PrepackNode increments the count in its constructor and decrements it after copying data. The buffer is only freed when the count reaches zero. This preserves the original eager-free behavior for non-shared weights (freeing immediately after the single consumer copies) while correctly deferring the free for shared weights until the last consumer is done — avoiding both the use-after-free and unnecessary peak memory increase. ghstack-source-id: 367726483 @exported-using-ghexport Differential Revision: [D101009402](https://our.internmc.facebook.com/intern/diff/D101009402/)
Merged as cd7dc61 into gh/SS-JIA/520/base.