pytorch_trainer

PyTorchTrainer(model)

Bases: Trainer

Implements the `get` and `set` functions for the standard PyTorch trainer.

Source code in iflearner/business/homo/pytorch_trainer.py
def __init__(self, model: torch.nn.Module) -> None:
    self._model = model

get(param_type=Trainer.ParameterType.ParameterModel)

Get parameters from the client, either the model parameters or the gradients.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `param_type` | `Trainer.ParameterType` | Either `ParameterModel` or `ParameterGradient`. | `Trainer.ParameterType.ParameterModel` |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, npt.NDArray[np.float32]]` | A dict mapping each parameter name (`str`) to its value (`np.ndarray`). |

Source code in iflearner/business/homo/pytorch_trainer.py
def get(
    self, param_type: Trainer.ParameterType = Trainer.ParameterType.ParameterModel
) -> Dict[str, npt.NDArray[np.float32]]:  # type: ignore
    """get parameters form the client, maybe the model parameter or
    gradient.

    Args:
        param_type: Param_type is ParameterModel or ParameterGradient, default is ParameterModel.

    Returns:
        dict, k: str (the parameter name), v: np.ndarray (the parameter value)
    """
    parameters = dict()
    for name, param in self._model.named_parameters():
        if param.requires_grad:
            if param_type == self.ParameterType.ParameterModel:
                parameters[name] = param.cpu().detach().numpy()
            else:
                parameters[name] = param.grad.cpu().detach().numpy()

    return parameters
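As a quick sketch (not part of the library source), the extraction logic that `get` applies can be reproduced on a toy `nn.Linear` model, assuming only `torch` and `numpy` are installed:

```python
import torch

# Toy model standing in for the one passed to PyTorchTrainer(model).
model = torch.nn.Linear(4, 2)

# Mirror PyTorchTrainer.get with ParameterType.ParameterModel:
# collect every trainable parameter as a detached numpy array.
parameters = {
    name: param.cpu().detach().numpy()
    for name, param in model.named_parameters()
    if param.requires_grad
}

# nn.Linear exposes two trainable parameters: 'weight' and 'bias'.
print(sorted(parameters.keys()))  # → ['bias', 'weight']
```

Requesting `ParameterGradient` instead would read `param.grad`, so it only makes sense after a backward pass has populated the gradients.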

set(parameters, param_type=Trainer.ParameterType.ParameterModel)

Set parameters on the client, either the model parameters or the gradients.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `parameters` | `Dict[str, npt.NDArray[np.float32]]` | The same format as the return value of `get`. | *required* |
| `param_type` | `Trainer.ParameterType` | Either `ParameterModel` or `ParameterGradient`. | `Trainer.ParameterType.ParameterModel` |
Source code in iflearner/business/homo/pytorch_trainer.py
def set(
    self,
    parameters: Dict[str, npt.NDArray[np.float32]],  # type: ignore
    param_type: Trainer.ParameterType = Trainer.ParameterType.ParameterModel,
) -> None:
    """set parameters to the client, maybe the model parameter or gradient.

    Args:
        parameters: Parameters is the same as the return of 'get' function.
        param_type: Param_type is ParameterModel or ParameterGradient, default is ParameterModel.

    Returns: None
    """
    for name, param in self._model.named_parameters():
        if param.requires_grad:
            if param_type == self.ParameterType.ParameterModel:
                param.data.copy_(torch.from_numpy(parameters[name]))
            else:
                param.grad.copy_(torch.from_numpy(parameters[name]))
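The round trip between `get` and `set` can be sketched with two toy models (again, this is an illustration with plain `torch`, not the library source): export one model's weights as numpy arrays, load them into the other, and the two models end up identical.

```python
import torch

# Two independently initialized toy models.
src = torch.nn.Linear(3, 1)
dst = torch.nn.Linear(3, 1)

# Export src's parameters as a name -> numpy array dict (what get returns).
parameters = {
    name: param.cpu().detach().numpy()
    for name, param in src.named_parameters()
    if param.requires_grad
}

# Mirror PyTorchTrainer.set with ParameterType.ParameterModel:
# copy each array back into the matching parameter tensor of dst.
for name, param in dst.named_parameters():
    if param.requires_grad:
        param.data.copy_(torch.from_numpy(parameters[name]))
```

After the copy, `torch.equal(src.weight, dst.weight)` holds. The gradient branch (`param.grad.copy_(...)`) works the same way but requires that `.grad` already exists, i.e. a backward pass must have run on the receiving model first.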