vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler

OffloadingConnectorScheduler

Implementation of scheduler-side methods

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.config = SchedulerOffloadConfig.from_spec(spec)
        self.manager: OffloadingManager = spec.get_manager()

        attention_groups: list[int] = []
        for idx, _ in enumerate(spec.kv_cache_config.kv_cache_groups):
            # currently treat all groups as full attention
            attention_groups.append(idx)

        self.lookup_groups = attention_groups

        self._req_status: dict[ReqId, RequestOffloadState] = {}
        # requests to load for the current scheduler step
        self._reqs_to_load: dict[ReqId, TransferSpec] = {}
        # if GPU prefix caching is enabled,
        # track loaded blocks to avoid redundant loads
        self._blocks_being_loaded: set[OffloadKey] | None = (
            set() if spec.vllm_config.cache_config.enable_prefix_caching else None
        )

        # request ID -> set(offload keys being stored/loaded)
        self._reqs_being_stored = defaultdict[ReqId, set[OffloadKey]](set)
        self._reqs_being_loaded = defaultdict[ReqId, set[OffloadKey]](set)

    def _maximal_prefix_lookup(
        self, keys: Iterable[OffloadKey], req_context: ReqContext
    ) -> int | None:
        """Find the length of the maximal prefix of offloaded blocks."""
        hit_count = 0
        defer_lookup = False
        for key in keys:
            result = self.manager.lookup(key, req_context)
            if result is None:
                defer_lookup = True
                # continue the lookup so the manager can kick off async
                # lookups for all blocks (until a miss is detected)
                result = True
            if not result:
                break
            hit_count += 1
        return hit_count if not defer_lookup else None

    def _sliding_window_lookup(
        self,
        keys: Sequence[OffloadKey],
        sliding_window_size: int,
        req_context: ReqContext,
    ) -> int | None:
        """Find the maximal ending position of consecutive offloaded blocks
        within a sliding window."""
        defer_lookup = False
        consecutive_hits = 0
        for idx in range(len(keys) - 1, -1, -1):
            result = self.manager.lookup(keys[idx], req_context)
            if result is None:
                defer_lookup = True
                # continue the lookup so the manager can kick off async
                # lookups for all blocks (until a hit is detected)
                result = False
            if not result:
                consecutive_hits = 0
            else:
                consecutive_hits += 1
                if consecutive_hits == sliding_window_size:
                    return idx + sliding_window_size if not defer_lookup else None
        return consecutive_hits if not defer_lookup else None

    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        if req_status := self._req_status.get(request.request_id):
            # make sure block IDs are cleared
            for group_state in req_status.group_states:
                group_state.block_ids.clear()
        else:
            req_status = RequestOffloadState(config=self.config, req=request)
            self._req_status[request.request_id] = req_status

        req_status.update_offload_keys()
        req_status.num_locally_computed_tokens = num_computed_tokens

        for gs in req_status.group_states:
            self.manager.touch(gs.offload_keys)

        # Start with the full request size as the maximum loadable
        max_hit_size_tokens: int = req_status.req.num_tokens
        num_hit_tokens: int = 0
        defer_lookup = False
        delay_request = False
        for group_idx in self.lookup_groups:
            group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx]
            offloaded_block_size = group_config.offloaded_block_size
            offload_keys = req_status.group_states[group_idx].offload_keys

            num_blocks = max_hit_size_tokens // offloaded_block_size
            assert len(offload_keys) >= num_blocks

            # Constrain to block-aligned boundary for this group
            max_hit_size_tokens = num_blocks * offloaded_block_size
            num_hit_tokens = max_hit_size_tokens - num_computed_tokens
            if num_hit_tokens < offloaded_block_size:
                # fewer than one full block would be loaded; better to skip
                return 0, False

            start_block_idx = num_computed_tokens // offloaded_block_size
            offload_keys = offload_keys[start_block_idx:num_blocks]
            # Full attention relies on all previous KV cache blocks.
            # Thus, we search for the maximal prefix of KV cache blocks
            # that are all cached.
            block_hits = self._maximal_prefix_lookup(
                offload_keys, req_status.req_context
            )
            if block_hits == 0:
                return 0, False

            if block_hits is None:
                defer_lookup = True
            else:
                # Further constrain based on what the backend actually has available
                max_hit_size_tokens = offloaded_block_size * (
                    start_block_idx + block_hits
                )

            num_hit_tokens = max_hit_size_tokens - num_computed_tokens
            if num_hit_tokens < offloaded_block_size:
                # fewer than one full block would be loaded; better to skip
                return 0, False

            if (
                block_hits
                and self._blocks_being_loaded
                and any(
                    key in self._blocks_being_loaded
                    for key in offload_keys[:block_hits]
                )
            ):
                # hit blocks are being loaded, delay request
                delay_request = True

        if defer_lookup:
            logger.debug(
                "Offloading manager delayed request %s as backend requested",
                req_status.req.request_id,
            )
            return None, False

        if delay_request:
            logger.debug(
                "Delaying request %s since some of its blocks are already being loaded",
                req_status.req.request_id,
            )
            return None, False

        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )

        return num_hit_tokens, True

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        if num_external_tokens == 0:
            return

        req_status = self._req_status[request.request_id]

        num_locally_computed_tokens = req_status.num_locally_computed_tokens
        num_cached_tokens = num_locally_computed_tokens + num_external_tokens

        keys_to_load: list[OffloadKey] = []
        dst_block_ids: list[int] = []
        # per group
        group_sizes: list[int] = []
        block_indices: list[int] = []
        for group_config, group_state, group_blocks in zip(
            self.config.kv_group_configs,
            req_status.group_states,
            blocks.blocks,
        ):
            gpu_block_size = group_config.gpu_block_size
            offloaded_block_size = group_config.offloaded_block_size
            offload_keys = group_state.offload_keys
            num_gpu_blocks = cdiv(num_cached_tokens, gpu_block_size)

            assert len(group_blocks) >= num_gpu_blocks
            num_locally_computed_gpu_blocks = num_gpu_blocks
            # Skip null placeholder blocks (used for sliding window or mamba padding).
            for i, block in enumerate(group_blocks[:num_gpu_blocks]):
                if not block.is_null and block.block_hash is None:
                    num_locally_computed_gpu_blocks = i
                    break

            assert (
                num_locally_computed_tokens
                <= num_locally_computed_gpu_blocks * gpu_block_size
            )
            num_pending_gpu_blocks = num_gpu_blocks - num_locally_computed_gpu_blocks

            num_blocks = cdiv(num_cached_tokens, offloaded_block_size)
            assert len(offload_keys) >= num_blocks
            if num_pending_gpu_blocks:
                start_block_idx = (
                    num_locally_computed_gpu_blocks // self.config.block_size_factor
                )
                keys_to_load.extend(offload_keys[start_block_idx:num_blocks])

            dst_block_ids.extend(
                block.block_id
                for block in group_blocks[
                    num_locally_computed_gpu_blocks:num_gpu_blocks
                ]
            )
            group_sizes.append(num_pending_gpu_blocks)
            block_indices.append(num_locally_computed_gpu_blocks)

            group_state.next_stored_block_idx = num_blocks

        src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context)
        dst_spec = GPULoadStoreSpec(
            dst_block_ids, group_sizes=group_sizes, block_indices=block_indices
        )

        self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
        req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
        req_blocks_being_loaded.update(keys_to_load)

        if self._blocks_being_loaded is not None:
            self._blocks_being_loaded.update(req_blocks_being_loaded)

    def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
        # The assertion below will be removed once this function supports HMA
        assert len(self.config.kv_group_configs) == 1
        group_config = self.config.kv_group_configs[0]

        reqs_to_store: dict[ReqId, TransferSpec] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            req_status = self._req_status[req_id]
            req_status.update_offload_keys()

            if preempted:
                for group_state in req_status.group_states:
                    group_state.block_ids.clear()

            if new_block_id_groups:
                req_status.update_block_id_groups(new_block_id_groups)

            # The assertion below will be removed once this function supports HMA
            assert len(req_status.group_states) == 1
            group_state = req_status.group_states[0]

            block_ids = group_state.block_ids

            req = req_status.req
            new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            expected_tokens = req.num_computed_tokens + new_tokens
            # with async scheduling, some tokens may be missing
            total_tokens = min(expected_tokens, req.num_tokens)
            num_blocks = total_tokens // group_config.offloaded_block_size
            start_block_idx = group_state.next_stored_block_idx
            num_new_blocks = num_blocks - start_block_idx

            if num_new_blocks <= 0:
                continue

            num_gpu_blocks = num_blocks * self.config.block_size_factor
            assert len(req.block_hashes) >= num_gpu_blocks

            new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
            store_output = self.manager.prepare_store(
                new_offload_keys, req_status.req_context
            )
            if store_output is None:
                logger.warning(
                    "Request %s: cannot store %s blocks", req_id, num_new_blocks
                )
                continue

            group_state.next_stored_block_idx = num_blocks

            if not store_output.keys_to_store:
                continue
            keys_to_store = set(store_output.keys_to_store)

            self.manager.touch(group_state.offload_keys[:num_blocks])

            dst_spec = store_output.store_spec
            src_block_ids: list[int] = []
            for idx, key in enumerate(new_offload_keys):
                if key not in keys_to_store:
                    continue
                offloaded_block_idx = start_block_idx + idx
                gpu_block_idx = offloaded_block_idx * self.config.block_size_factor
                for i in range(self.config.block_size_factor):
                    src_block_ids.append(block_ids[gpu_block_idx + i])
            src_spec = GPULoadStoreSpec(
                src_block_ids,
                group_sizes=(len(src_block_ids),),
                block_indices=(0,),
            )

            reqs_to_store[req_id] = (src_spec, dst_spec)
            self._reqs_being_stored[req_id] |= keys_to_store

            logger.debug(
                "Request %s offloading %s blocks starting from block #%d",
                req_id,
                len(keys_to_store),
                start_block_idx,
            )

        return reqs_to_store

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        meta = OffloadingConnectorMetadata(
            reqs_to_load=self._reqs_to_load,
            reqs_to_store=self._get_reqs_to_store(scheduler_output),
            reqs_to_flush=scheduler_output.preempted_req_ids,
        )
        self._reqs_to_load = {}

        # NOTE (orozery): we should move this logic to update_connector_output
        # once KVConnectorOutput allows us to report completed transfers
        for req_id in scheduler_output.preempted_req_ids or ():
            keys = self._reqs_being_stored.get(req_id)
            if keys:
                self.manager.complete_store(keys)
                keys.clear()

        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        for req_id in connector_output.finished_sending or []:
            keys = self._reqs_being_stored.pop(req_id, None)
            if keys:
                self.manager.complete_store(keys)

        for req_id in connector_output.finished_recving or []:
            keys = self._reqs_being_loaded.pop(req_id, None)
            if keys:
                if self._blocks_being_loaded:
                    self._blocks_being_loaded.difference_update(keys)
                self.manager.complete_load(keys)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id

        # TODO(orozery): possibly kick off offload for the last block
        # which may have been deferred due to async scheduling
        self._req_status.pop(req_id, None)

        request_being_stored = req_id in self._reqs_being_stored
        return request_being_stored, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Take the KV cache events from the connector.

        Returns:
            A list of KV cache events.
        """
        for event in self.manager.take_events():
            block_hashes = [get_offload_block_hash(key) for key in event.keys]
            if event.removed:
                yield BlockRemoved(block_hashes=block_hashes, medium=event.medium)
            else:
                yield BlockStored(
                    block_hashes=block_hashes,
                    parent_block_hash=None,
                    token_ids=[],
                    lora_id=None,
                    block_size=0,
                    medium=event.medium,
                    lora_name=None,
                )

    def shutdown(self) -> None:
        self.manager.shutdown()
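
The store and load paths above rely on config.block_size_factor, the number
of GPU blocks covered by one offloaded block. A standalone sketch of that
index arithmetic with hypothetical sizes (none of the names below are vLLM
API):

def gpu_blocks_for_offloaded_block(
    block_ids: list[int], offloaded_block_idx: int, block_size_factor: int
) -> list[int]:
    # Offloaded block i covers GPU blocks [i * factor, (i + 1) * factor),
    # mirroring the inner loop over block_size_factor in _get_reqs_to_store.
    start = offloaded_block_idx * block_size_factor
    return block_ids[start : start + block_size_factor]

# E.g. a GPU block size of 16 and an offloaded block size of 64 give
# factor = 64 // 16 = 4.
block_ids = [10, 11, 12, 13, 20, 21, 22, 23]  # a request's GPU block IDs
assert gpu_blocks_for_offloaded_block(block_ids, 1, 4) == [20, 21, 22, 23]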

_maximal_prefix_lookup

_maximal_prefix_lookup(
    keys: Iterable[OffloadKey], req_context: ReqContext
) -> int | None

Find the length of the maximal prefix of offloaded blocks.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def _maximal_prefix_lookup(
    self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> int | None:
    """Find the length of the maximal prefix of offloaded blocks."""
    hit_count = 0
    defer_lookup = False
    for key in keys:
        result = self.manager.lookup(key, req_context)
        if result is None:
            defer_lookup = True
            # continue the lookup so the manager can kick off async
            # lookups for all blocks (until a miss is detected)
            result = True
        if not result:
            break
        hit_count += 1
    return hit_count if not defer_lookup else None
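
The method assumes a tri-state lookup contract from the manager: True (hit),
False (miss), or None (answer not ready yet; an async lookup has been kicked
off). A minimal standalone illustration of the same scan, with the manager
replaced by a plain list of precomputed results:

def maximal_prefix(results: list[bool | None]) -> int | None:
    hit_count = 0
    defer = False
    for result in results:
        if result is None:
            defer = True   # keep scanning so async lookups are started
            result = True  # for every block up to the first definite miss
        if not result:
            break
        hit_count += 1
    return None if defer else hit_count

assert maximal_prefix([True, True, False, True]) == 2  # prefix ends at the miss
assert maximal_prefix([True, None, True]) is None      # deferred: query again later
assert maximal_prefix([False, True]) == 0              # no offloaded prefix at all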

_sliding_window_lookup

_sliding_window_lookup(
    keys: Sequence[OffloadKey],
    sliding_window_size: int,
    req_context: ReqContext,
) -> int | None

Find the maximal ending position of consecutive offloaded blocks within a sliding window.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def _sliding_window_lookup(
    self,
    keys: Sequence[OffloadKey],
    sliding_window_size: int,
    req_context: ReqContext,
) -> int | None:
    """Find the maximal ending position of consecutive offloaded blocks
    within a sliding window."""
    defer_lookup = False
    consecutive_hits = 0
    for idx in range(len(keys) - 1, -1, -1):
        result = self.manager.lookup(keys[idx], req_context)
        if result is None:
            defer_lookup = True
            # continue the lookup so the manager can kick off async
            # lookups for all blocks (until a hit is detected)
            result = False
        if not result:
            consecutive_hits = 0
        else:
            consecutive_hits += 1
            if consecutive_hits == sliding_window_size:
                return idx + sliding_window_size if not defer_lookup else None
    return consecutive_hits if not defer_lookup else None
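
The scan runs backwards from the newest block, looking for the run of
sliding_window_size consecutive hits closest to the end; the return value is
the exclusive end position of that run. A standalone sketch of the same
logic over a plain result list:

def sliding_window_end(results: list[bool | None], window: int) -> int | None:
    defer = False
    consecutive = 0
    for idx in range(len(results) - 1, -1, -1):  # newest block first
        result = results[idx]
        if result is None:
            defer = True    # treat as a miss, but keep scanning so async
            result = False  # lookups are kicked off for earlier blocks
        if not result:
            consecutive = 0
        else:
            consecutive += 1
            if consecutive == window:
                return None if defer else idx + window
    return None if defer else consecutive

# The run over indices 2..3 is the latest full window; its exclusive end is 4.
assert sliding_window_end([False, True, True, True, False], 2) == 4
# No full window anywhere, but the prefix run of hits still counts.
assert sliding_window_end([True, False, False, False], 2) == 1
# An unresolved lookup defers the whole answer.
assert sliding_window_end([True, True, None, False], 2) is None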

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]

Get number of new tokens that can be loaded beyond the num_computed_tokens.

Parameters:

    request (Request): the request object. Required.
    num_computed_tokens (int): the number of locally computed tokens for
        this request. Required.

Returns:

    tuple[int | None, bool]: A tuple with the following elements:
        - The number of tokens that can be loaded beyond what is already
          computed. If None, the connector needs more time to determine the
          number of matched tokens, and the scheduler should query for this
          request again later.
        - True if tokens will be loaded asynchronously (between scheduler
          steps).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def get_num_new_matched_tokens(
    self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]:
    """
    Get number of new tokens that can be loaded beyond the
    num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A tuple with the following elements:
            - The number of tokens that can be loaded beyond what is
              already computed.
              If None, it means that the connector needs more time to
              determine the number of matched tokens, and the scheduler
              should query for this request again later.
            - `True` if tokens will be loaded asynchronously
              (between scheduler steps).
    """
    if req_status := self._req_status.get(request.request_id):
        # make sure block IDs are cleared
        for group_state in req_status.group_states:
            group_state.block_ids.clear()
    else:
        req_status = RequestOffloadState(config=self.config, req=request)
        self._req_status[request.request_id] = req_status

    req_status.update_offload_keys()
    req_status.num_locally_computed_tokens = num_computed_tokens

    for gs in req_status.group_states:
        self.manager.touch(gs.offload_keys)

    # Start with the full request size as the maximum loadable
    max_hit_size_tokens: int = req_status.req.num_tokens
    num_hit_tokens: int = 0
    defer_lookup = False
    delay_request = False
    for group_idx in self.lookup_groups:
        group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx]
        offloaded_block_size = group_config.offloaded_block_size
        offload_keys = req_status.group_states[group_idx].offload_keys

        num_blocks = max_hit_size_tokens // offloaded_block_size
        assert len(offload_keys) >= num_blocks

        # Constrain to block-aligned boundary for this group
        max_hit_size_tokens = num_blocks * offloaded_block_size
        num_hit_tokens = max_hit_size_tokens - num_computed_tokens
        if num_hit_tokens < offloaded_block_size:
            # fewer than one full block would be loaded; better to skip
            return 0, False

        start_block_idx = num_computed_tokens // offloaded_block_size
        offload_keys = offload_keys[start_block_idx:num_blocks]
        # Full attention relies on all previous KV cache blocks.
        # Thus, we search for the maximal prefix of KV cache blocks
        # that are all cached.
        block_hits = self._maximal_prefix_lookup(
            offload_keys, req_status.req_context
        )
        if block_hits == 0:
            return 0, False

        if block_hits is None:
            defer_lookup = True
        else:
            # Further constrain based on what the backend actually has available
            max_hit_size_tokens = offloaded_block_size * (
                start_block_idx + block_hits
            )

        num_hit_tokens = max_hit_size_tokens - num_computed_tokens
        if num_hit_tokens < offloaded_block_size:
            # fewer than one full block would be loaded; better to skip
            return 0, False

        if (
            block_hits
            and self._blocks_being_loaded
            and any(
                key in self._blocks_being_loaded
                for key in offload_keys[:block_hits]
            )
        ):
            # hit blocks are being loaded, delay request
            delay_request = True

    if defer_lookup:
        logger.debug(
            "Offloading manager delayed request %s as backend requested",
            req_status.req.request_id,
        )
        return None, False

    if delay_request:
        logger.debug(
            "Delaying request %s since some of its blocks are already being loaded",
            req_status.req.request_id,
        )
        return None, False

    logger.debug(
        "Request %s hit %s offloaded tokens after %s GPU hit tokens",
        request.request_id,
        num_hit_tokens,
        num_computed_tokens,
    )

    return num_hit_tokens, True
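
The hit size is always aligned down to the group's offloaded block size, and
only whole blocks past the GPU-computed prefix are counted. A toy
walk-through of the arithmetic with hypothetical sizes (nothing below is
vLLM API):

offloaded_block_size = 256
num_request_tokens = 1000  # req.num_tokens
num_computed_tokens = 256  # already held in the GPU prefix cache

num_blocks = num_request_tokens // offloaded_block_size        # 3 full blocks
max_hit_size_tokens = num_blocks * offloaded_block_size        # 768, aligned
start_block_idx = num_computed_tokens // offloaded_block_size  # skip block 0

# Suppose the prefix lookup over blocks [1, 3) reports 2 hits:
block_hits = 2
max_hit_size_tokens = offloaded_block_size * (start_block_idx + block_hits)
num_hit_tokens = max_hit_size_tokens - num_computed_tokens
assert (start_block_idx, num_hit_tokens) == (1, 512)
# The 232-token tail (1000 - 768) never filled a whole offloaded block and
# is recomputed rather than loaded.

A (0, False) result means the request proceeds as a regular prefill, while
(None, False) asks the scheduler to query this request again on a later step.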

request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, dict[str, Any] | None]

Called when a request has finished, before its blocks are freed.

Returns:

    tuple[bool, dict[str, Any] | None]:
        - True if the request is being saved/sent asynchronously and blocks
          should not be freed until the request_id is returned from
          get_finished().
        - Optional KVTransferParams to be included in the request outputs
          returned by the engine.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def request_finished(
    self,
    request: Request,
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    """
    Called when a request has finished, before its blocks are freed.

    Returns:
        True if the request is being saved/sent asynchronously and blocks
        should not be freed until the request_id is returned from
        get_finished().
        Optional KVTransferParams to be included in the request outputs
        returned by the engine.
    """
    req_id = request.request_id

    # TODO(orozery): possibly kick off offload for the last block
    # which may have been deferred due to async scheduling
    self._req_status.pop(req_id, None)

    request_being_stored = req_id in self._reqs_being_stored
    return request_being_stored, None
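
A minimal sketch, with hypothetical names, of how a caller might honor the
returned flag: blocks stay allocated until the worker later reports the
request in its finished-sending set.

def on_request_finished(delay_free: bool, req_id: str,
                        pending_free: set[str], freed: list[str]) -> None:
    if delay_free:
        pending_free.add(req_id)  # store still in flight; free later
    else:
        freed.append(req_id)      # nothing in flight; free immediately

pending_free: set[str] = set()
freed: list[str] = []
on_request_finished(True, "req-1", pending_free, freed)
on_request_finished(False, "req-2", pending_free, freed)
assert pending_free == {"req-1"} and freed == ["req-2"]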

take_events

take_events() -> Iterable[KVCacheEvent]

Take the KV cache events from the connector.

Returns:

    Iterable[KVCacheEvent]: A list of KV cache events.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def take_events(self) -> Iterable[KVCacheEvent]:
    """Take the KV cache events from the connector.

    Returns:
        A list of KV cache events.
    """
    for event in self.manager.take_events():
        block_hashes = [get_offload_block_hash(key) for key in event.keys]
        if event.removed:
            yield BlockRemoved(block_hashes=block_hashes, medium=event.medium)
        else:
            yield BlockStored(
                block_hashes=block_hashes,
                parent_block_hash=None,
                token_ids=[],
                lora_id=None,
                block_size=0,
                medium=event.medium,
                lora_name=None,
            )
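
Despite the docstring's "list", this is a generator, and the "take" naming
suggests each event is handed out once. A toy analogue of that semantics,
using no vLLM types:

from collections import deque

events: deque[str] = deque(["stored:b1", "removed:b2"])

def take_events():
    while events:
        yield events.popleft()  # each event is surrendered exactly once

assert list(take_events()) == ["stored:b1", "removed:b2"]
assert list(take_events()) == []  # already taken; nothing is replayed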

update_connector_output

update_connector_output(
    connector_output: KVConnectorOutput,
)

Update KVConnector state from worker-side connectors output.

Parameters:

    connector_output (KVConnectorOutput): the worker-side connectors
        output. Required.
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from worker-side connectors output.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    for req_id in connector_output.finished_sending or []:
        keys = self._reqs_being_stored.pop(req_id, None)
        if keys:
            self.manager.complete_store(keys)

    for req_id in connector_output.finished_recving or []:
        keys = self._reqs_being_loaded.pop(req_id, None)
        if keys:
            if self._blocks_being_loaded:
                self._blocks_being_loaded.difference_update(keys)
            self.manager.complete_load(keys)
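
A standalone sketch of the completion bookkeeping above, with plain dicts
and sets standing in for the manager and the vLLM types:

from collections import defaultdict

reqs_being_stored: defaultdict[str, set[str]] = defaultdict(set)
reqs_being_loaded: defaultdict[str, set[str]] = defaultdict(set)
blocks_being_loaded: set[str] = set()

reqs_being_stored["req-1"] |= {"k1", "k2"}
reqs_being_loaded["req-2"] |= {"k3"}
blocks_being_loaded |= {"k3"}

# Worker output: req-1 finished sending, req-2 finished receiving.
stored_keys = reqs_being_stored.pop("req-1", None)
assert stored_keys == {"k1", "k2"}  # these would go to complete_store()

loaded_keys = reqs_being_loaded.pop("req-2", None)
if loaded_keys:
    blocks_being_loaded.difference_update(loaded_keys)
assert not blocks_being_loaded  # "k3" no longer gates duplicate loads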