@@ -2,12 +2,14 @@
 import multiprocessing as mp
 import time
 from functools import partial
+from typing import List
 
 import numpy as np
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from multiaddr import Multiaddr
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
@@ -227,8 +229,10 @@ def test_progress_tracker():
     finished_evt = mp.Event()
     emas = mp.Array(ctypes.c_double, 5)
 
-    def run_worker(index: int, batch_size: int, period: float, **kwargs):
-        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
+    root_maddrs = dht_root.get_visible_maddrs()
+
+    def run_worker(index: int, batch_size: int, step_time: float, initial_peers: List[Multiaddr]):
+        dht = hivemind.DHT(initial_peers=initial_peers, start=True)
         tracker = ProgressTracker(
             dht,
             prefix,
@@ -238,18 +242,17 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
             default_refresh_period=0.2,
             max_refresh_period=0.5,
             private_key=RSAPrivateKey(),
-            **kwargs,
         )
+        with tracker.pause_updates():
+            barrier.wait()
+            if index == 4:
+                delayed_start_evt.wait()
 
-        barrier.wait()
-        if index == 4:
-            delayed_start_evt.wait()
-
-        local_epoch = 2 if index == 4 else 0
-        samples_accumulated = 0
+            local_epoch = 2 if index == 4 else 0
+            samples_accumulated = 0
 
         while True:
-            time.sleep(period)
+            time.sleep(step_time)
             if finished_evt.is_set():
                 break
 
@@ -270,10 +273,10 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
         dht.shutdown()
 
     workers = [
-        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
-        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
-        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
-        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
+        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, step_time=0.6, initial_peers=root_maddrs)),
+        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, step_time=0.5, initial_peers=root_maddrs)),
+        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, step_time=0.2, initial_peers=root_maddrs)),
+        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, step_time=0.2, initial_peers=root_maddrs)),
     ]
     for worker in workers:
         worker.start()
@@ -336,7 +339,7 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
         (False, True, True, True, True),
         (False, True, True, False, True),
         (True, False, False, False, False),
-        (True, True, False, False, False,),
+        (True, True, False, False, False),
     ],
     # fmt: on
)
@@ -359,6 +362,8 @@ def test_optimizer(
 def _test_optimizer(
     num_peers: int = 1,
     num_clients: int = 0,
+    default_batch_size: int = 4,
+    default_batch_time: float = 0.1,
     target_batch_size: int = 32,
     total_epochs: int = 3,
     use_local_updates: bool = False,
422427
423428 prev_time = time .perf_counter ()
424429
425- time .sleep (1.0 )
426430 optimizer .shutdown ()
427431 return optimizer
428432
429433 peers = []
430434
431435 for index in range (num_peers ):
436+ peer_batch_size = default_batch_size + index
437+ peer_batch_time = default_batch_time + 0.01 * index
432438 peers .append (
433439 mp .Process (
434440 target = run_trainer ,
435441 name = f"trainer-{ index } " ,
436442 kwargs = dict (
437- batch_size = 4 + index ,
438- batch_time = 0.3 + 0.2 * index ,
443+ batch_size = peer_batch_size ,
444+ batch_time = peer_batch_time ,
439445 client_mode = (index >= num_peers - num_clients ),
440446 ),
441447 )
@@ -451,7 +457,12 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
     assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
     expected_samples_accumulated = target_batch_size * total_epochs
     assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
-    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
+    expected_performance = default_batch_size / default_batch_time
+    assert (
+        expected_performance * 0.8
+        <= optimizer.tracker.performance_ema.samples_per_second
+        <= expected_performance * 1.2
+    )
 
     assert not optimizer.state_averager.is_alive()
     assert not optimizer.tracker.is_alive()