keras_trainer

KerasTrainer(model)

Bases: Trainer

Implements the 'get' and 'set' functions for a standard Keras model trainer.

Source code in iflearner/business/homo/keras_trainer.py
def __init__(self, model: keras.models.Sequential) -> None:
    self._model = model

get(param_type=Trainer.ParameterType.ParameterModel)

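Get parameters from the client: either the model parameters or the gradients. In the source shown below, param_type is not consulted; the layer weights are returned in both cases.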

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| param_type | Trainer.ParameterType | Param_type is ParameterModel or ParameterGradient, default is ParameterModel. | Trainer.ParameterType.ParameterModel |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, npt.NDArray[np.float32]] | dict, k: str (the parameter name), v: np.ndarray (the parameter value) |

Source code in iflearner/business/homo/keras_trainer.py
def get(
    self, param_type: Trainer.ParameterType = Trainer.ParameterType.ParameterModel
) -> Dict[str, npt.NDArray[np.float32]]:  # type: ignore
    """Get parameters from the client: either the model parameters or the
    gradients.

    Args:
        param_type: Param_type is ParameterModel or ParameterGradient, default is ParameterModel.

    Returns:
        dict, k: str (the parameter name), v: np.ndarray (the parameter value)
    """
    parameters = dict()
    for item in self._model.layers:
        if item.name is not None:
            # Key each weight array as "<layer name>-<index>".
            for i, weight in enumerate(item.get_weights()):
                parameters[f"{item.name}-{i}"] = weight
    return parameters
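
The key scheme used by 'get' can be illustrated without Keras installed. The sketch below is not part of iflearner: FakeLayer and collect_parameters are hypothetical names, and plain lists stand in for the NumPy arrays that a real layer's get_weights() returns. It only shows how the "<layer name>-<index>" keys are formed.

```python
from typing import Dict, List


class FakeLayer:
    """Hypothetical stand-in for a Keras layer: only `name` and `get_weights()`."""

    def __init__(self, name: str, weights: List[list]) -> None:
        self.name = name
        self._weights = weights

    def get_weights(self) -> List[list]:
        return list(self._weights)


def collect_parameters(layers: List[FakeLayer]) -> Dict[str, list]:
    """Mirror the key scheme of `get`: "<layer name>-<weight index>"."""
    parameters: Dict[str, list] = {}
    for layer in layers:
        for i, weight in enumerate(layer.get_weights()):
            parameters[f"{layer.name}-{i}"] = weight
    return parameters


layers = [
    FakeLayer("dense", [[[0.1, 0.2]], [0.0, 0.0]]),  # kernel, then bias
    FakeLayer("dropout", []),  # no weights, so it contributes no keys
]
params = collect_parameters(layers)
print(sorted(params))  # ['dense-0', 'dense-1']
```

Layers without weights simply do not appear in the returned dict, so the receiving side has to skip them as well.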

set(parameters, param_type=Trainer.ParameterType.ParameterModel)

Set parameters on the client: either the model parameters or the gradients.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| parameters | Dict[str, npt.NDArray[np.float32]] | Same format as the dict returned by 'get'. | required |
| param_type | Trainer.ParameterType | Param_type is ParameterModel or ParameterGradient, default is ParameterModel. | Trainer.ParameterType.ParameterModel |

Source code in iflearner/business/homo/keras_trainer.py
def set(
    self,
    parameters: Dict[str, npt.NDArray[np.float32]],  # type: ignore
    param_type: Trainer.ParameterType = Trainer.ParameterType.ParameterModel,
) -> None:
    """Set parameters on the client: either the model parameters or the gradients.

    Args:
        parameters: Same format as the dict returned by 'get'.
        param_type: Param_type is ParameterModel or ParameterGradient, default is ParameterModel.

    Returns: None
    """
    for item in self._model.layers:
        if item.name is not None and len(item.get_weights()) > 0:
            i = 0
            weights = []
            # Collect "<layer name>-<index>" entries until the first missing index.
            while f"{item.name}-{i}" in parameters:
                weights.append(parameters[f"{item.name}-{i}"])
                i += 1
            item.set_weights(weights)
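
The lookup in 'set' stops at the first missing index, which matters when the incoming dict is incomplete. The sketch below is a stand-alone illustration, not iflearner code: FakeLayer and apply_parameters are hypothetical, plain lists replace NumPy arrays, and the stand-in's set_weights raises ValueError on a length mismatch on the assumption that a real Keras layer rejects a weight list of the wrong length in a similar way.

```python
from typing import Dict, List


class FakeLayer:
    """Hypothetical stand-in exposing `name`, `get_weights()`, `set_weights()`."""

    def __init__(self, name: str, weights: List[list]) -> None:
        self.name = name
        self._weights = weights

    def get_weights(self) -> List[list]:
        return list(self._weights)

    def set_weights(self, weights: List[list]) -> None:
        # Assumption: mimic Keras by rejecting a list of the wrong length.
        if len(weights) != len(self._weights):
            raise ValueError(f"{self.name}: expected {len(self._weights)} weights")
        self._weights = list(weights)


def apply_parameters(layers: List[FakeLayer], parameters: Dict[str, list]) -> None:
    """Mirror the lookup in `set`: read name-0, name-1, ... until a key is missing."""
    for layer in layers:
        if len(layer.get_weights()) == 0:
            continue  # layers without weights are skipped
        weights = []
        i = 0
        while f"{layer.name}-{i}" in parameters:
            weights.append(parameters[f"{layer.name}-{i}"])
            i += 1
        layer.set_weights(weights)


dense = FakeLayer("dense", [[[0.0, 0.0]], [0.0, 0.0]])
apply_parameters([dense], {"dense-0": [[1.0, 2.0]], "dense-1": [0.5, 0.5]})
print(dense.get_weights())  # [[[1.0, 2.0]], [0.5, 0.5]]

# With "dense-0" missing, the loop collects nothing and set_weights is
# called with an empty list, which the stand-in rejects.
try:
    apply_parameters([dense], {"dense-1": [9.0, 9.0]})
    incomplete_ok = True
except ValueError:
    incomplete_ok = False
print(incomplete_ok)  # False
```

Under these assumptions, a caller would want to pass the complete dict produced by 'get' (or an aggregate with the same keys) rather than a partial one.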