初始化
This commit is contained in:
@@ -0,0 +1,100 @@
|
||||
from torch.nn.parallel import DataParallel
|
||||
import torch
|
||||
from torch.nn.parallel._functions import Scatter
|
||||
from torch.nn.parallel.parallel_apply import parallel_apply
|
||||
|
||||
|
||||
def scatter(inputs, target_gpus, chunk_sizes, dim=0):
|
||||
r"""
|
||||
Slices tensors into approximately equal chunks and
|
||||
distributes them across given GPUs. Duplicates
|
||||
references to objects that are not tensors.
|
||||
"""
|
||||
def scatter_map(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
try:
|
||||
return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
|
||||
except:
|
||||
print('obj', obj.size())
|
||||
print('dim', dim)
|
||||
print('chunk_sizes', chunk_sizes)
|
||||
quit()
|
||||
if isinstance(obj, tuple) and len(obj) > 0:
|
||||
return list(zip(*map(scatter_map, obj)))
|
||||
if isinstance(obj, list) and len(obj) > 0:
|
||||
return list(map(list, zip(*map(scatter_map, obj))))
|
||||
if isinstance(obj, dict) and len(obj) > 0:
|
||||
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
|
||||
return [obj for targets in target_gpus]
|
||||
|
||||
# After scatter_map is called, a scatter_map cell will exist. This cell
|
||||
# has a reference to the actual function scatter_map, which has references
|
||||
# to a closure that has a reference to the scatter_map cell (because the
|
||||
# fn is recursive). To avoid this reference cycle, we set the function to
|
||||
# None, clearing the cell
|
||||
try:
|
||||
return scatter_map(inputs)
|
||||
finally:
|
||||
scatter_map = None
|
||||
|
||||
|
||||
def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
|
||||
r"""Scatter with support for kwargs dictionary"""
|
||||
inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
|
||||
kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
|
||||
if len(inputs) < len(kwargs):
|
||||
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
|
||||
elif len(kwargs) < len(inputs):
|
||||
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
|
||||
inputs = tuple(inputs)
|
||||
kwargs = tuple(kwargs)
|
||||
return inputs, kwargs
|
||||
|
||||
|
||||
class BalancedDataParallel(DataParallel):
|
||||
def __init__(self, gpu0_bsz, *args, **kwargs):
|
||||
self.gpu0_bsz = gpu0_bsz
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
if not self.device_ids:
|
||||
return self.module(*inputs, **kwargs)
|
||||
if self.gpu0_bsz == 0:
|
||||
device_ids = self.device_ids[1:]
|
||||
else:
|
||||
device_ids = self.device_ids
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
|
||||
# print('len(inputs)1: ', str(len(inputs)))
|
||||
# print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))
|
||||
if len(self.device_ids) == 1:
|
||||
return self.module(*inputs[0], **kwargs[0])
|
||||
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
||||
if self.gpu0_bsz == 0:
|
||||
replicas = replicas[1:]
|
||||
outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
|
||||
return self.gather(outputs, self.output_device)
|
||||
|
||||
def parallel_apply(self, replicas, device_ids, inputs, kwargs):
|
||||
return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])
|
||||
|
||||
def scatter(self, inputs, kwargs, device_ids):
|
||||
bsz = inputs[0].size(self.dim)
|
||||
num_dev = len(self.device_ids)
|
||||
gpu0_bsz = self.gpu0_bsz
|
||||
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
|
||||
if gpu0_bsz < bsz_unit:
|
||||
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
|
||||
delta = bsz - sum(chunk_sizes)
|
||||
for i in range(delta):
|
||||
chunk_sizes[i + 1] += 1
|
||||
if gpu0_bsz == 0:
|
||||
chunk_sizes = chunk_sizes[1:]
|
||||
else:
|
||||
return super().scatter(inputs, kwargs, device_ids)
|
||||
|
||||
# print('bsz: ', bsz)
|
||||
# print('num_dev: ', num_dev)
|
||||
# print('gpu0_bsz: ', gpu0_bsz)
|
||||
# print('bsz_unit: ', bsz_unit)
|
||||
# print('chunk_sizes: ', chunk_sizes)
|
||||
return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
|
||||
Reference in New Issue
Block a user