
mxnet_trainer

MxnetTrainer(model)

Bases: Trainer

Implements the 'get' and 'set' functions of Trainer for a standard MXNet Gluon model.

Source code in iflearner/business/homo/mxnet_trainer.py
def __init__(self, model: mx.gluon.nn.Sequential) -> None:
    # Keep a reference to the Gluon model whose parameters get/set exchange.
    self._model = model

get(param_type=Trainer.ParameterType.ParameterModel)

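Get parameters from the client: either the model parameters or the gradients. Note that the current implementation ignores param_type and returns only the parameters whose names match the regular expression '.*weight', so biases and other non-weight parameters are omitted.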

Parameters:

param_type (Trainer.ParameterType): Either ParameterModel or ParameterGradient. Default: Trainer.ParameterType.ParameterModel.

Returns:

Dict[str, npt.NDArray[np.float32]]: A dict mapping each parameter name (str) to its value (np.ndarray).

Source code in iflearner/business/homo/mxnet_trainer.py
def get(
    self, param_type: Trainer.ParameterType = Trainer.ParameterType.ParameterModel
) -> Dict[str, npt.NDArray[np.float32]]:  # type: ignore
    """Get parameters from the client: either the model parameters or the
    gradients.

    Args:
        param_type: Either ParameterModel or ParameterGradient. Defaults to ParameterModel.

    Returns:
        dict, k: str (the parameter name), v: np.ndarray (the parameter value)
    """
    parameters = dict()
    # param_type is currently unused: only parameters whose names match the
    # regex ".*weight" are collected, so biases are not included.
    for key, val in self._model.collect_params(".*weight").items():
        # Copy the parameter's NDArray into a NumPy array.
        parameters[key] = val.data().asnumpy()
    return parameters
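The selection string passed to collect_params is a regular expression. Assuming MXNet 1.x Gluon semantics, where the select pattern keeps the names for which re.match succeeds (matching from the start of the name), the stdlib-only sketch below shows which parameter names get would return. The names in NAMES and the helper select are hypothetical and for illustration only; they are not taken from a real model or from the iflearner API.

```python
import re
from typing import List

# Hypothetical parameter names in the style MXNet 1.x Gluon auto-generates
# for a Dense -> BatchNorm -> Dense network (illustrative only).
NAMES = [
    "dense0_weight",
    "dense0_bias",
    "batchnorm0_gamma",
    "batchnorm0_beta",
    "batchnorm0_running_mean",
    "batchnorm0_running_var",
    "dense1_weight",
    "dense1_bias",
]


def select(names: List[str], pattern: str = ".*weight") -> List[str]:
    """Hypothetical helper: keep names the regex matches from the start of the string."""
    compiled = re.compile(pattern)
    return [name for name in names if compiled.match(name)]


print(select(NAMES))  # ['dense0_weight', 'dense1_weight']
```

Under this assumption, biases and BatchNorm parameters are not part of the dict that get returns, so anything sent back through set is limited to weights as well.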

set(parameters, param_type=Trainer.ParameterType.ParameterModel)

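Set parameters on the client: either the model parameters or the gradients. Note that the current implementation ignores param_type and always treats the values as model parameters.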

Parameters:

parameters (Dict[str, npt.NDArray[np.float32]]): A dict in the same format as the return value of 'get'. Required.
param_type (Trainer.ParameterType): Either ParameterModel or ParameterGradient. Default: Trainer.ParameterType.ParameterModel.

Source code in iflearner/business/homo/mxnet_trainer.py
def set(
    self,
    parameters: Dict[str, npt.NDArray[np.float32]],  # type: ignore
    param_type: Trainer.ParameterType = Trainer.ParameterType.ParameterModel,
) -> None:
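    """Set parameters on the client: either the model parameters or the gradients.

    Args:
        parameters: A dict in the same format as the return value of 'get'.
        param_type: Either ParameterModel or ParameterGradient. Defaults to ParameterModel.

    Returns: None
    """
    # param_type is currently unused: values are always treated as model parameters.
    params = self._model.collect_params()
    for key, value in parameters.items():
        # Overwrite the named parameter's data in place. (ParameterDict.setattr(name, value)
        # would instead set an attribute called `key` on every parameter.)
        params[key].set_data(mx.nd.array(value))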
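The contract between get and set is that set receives a dict keyed by the same parameter names that get produced, and each entry overwrites the parameter with that name. The sketch below models this contract with a hypothetical in-memory stand-in (ToyParam, ToyTrainer) that uses plain lists instead of MXNet NDArrays. It is not the iflearner or MXNet API, and unlike the real get it returns every parameter rather than only '.*weight' names.

```python
from typing import Dict, List


class ToyParam:
    """Hypothetical stand-in for a Gluon Parameter holding a flat list of floats."""

    def __init__(self, values: List[float]) -> None:
        self._values = list(values)

    def data(self) -> List[float]:
        # Return a copy so callers cannot mutate the stored values directly.
        return list(self._values)

    def set_data(self, values: List[float]) -> None:
        # The toy rejects a length change, standing in for a shape check.
        if len(values) != len(self._values):
            raise ValueError("shape mismatch")
        self._values = list(values)


class ToyTrainer:
    """Hypothetical trainer exposing a name-keyed get/set contract."""

    def __init__(self, params: Dict[str, ToyParam]) -> None:
        self._params = params

    def get(self) -> Dict[str, List[float]]:
        # Copy every parameter's current value into a name-keyed dict.
        return {name: p.data() for name, p in self._params.items()}

    def set(self, parameters: Dict[str, List[float]]) -> None:
        # Look each name up and overwrite that parameter; unknown names raise KeyError.
        for name, values in parameters.items():
            self._params[name].set_data(values)


trainer = ToyTrainer({"dense0_weight": ToyParam([0.5, -0.5]), "dense0_bias": ToyParam([0.0])})
snapshot = trainer.get()
snapshot["dense0_weight"] = [1.0, 2.0]
trainer.set(snapshot)
print(trainer.get()["dense0_weight"])  # [1.0, 2.0]
```

In this toy, set touches only the names it receives, so a partial dict leaves every other parameter unchanged.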