def load_weights_and_online_quantize(model_loader: DefaultModelLoader,
model: nn.Module,
model_config: ModelConfig) -> set[str]:
    # Online quantization, currently only enabled for torchao.
    # Steps R1, R2, R3, R4 refer to the Notes.
    # TODO: add fp8 support
    assert model_config.quantization == "torchao", (
        "online quantization is currently only enabled for torchao")
    # TODO: use create_weights to restore the weights to the original state
    # Step R1: first restore the quantized weights to the original bfloat16
    # weights, with their original metadata (shape, dtype, device) and
    # attributes, so that the bfloat16 weights can be loaded properly
existing_param_names = dict(
model.named_parameters(remove_duplicate=False)).keys()
named_modules = dict(model.named_modules(remove_duplicate=False))
model_device = None
    # Step R2: recover each parameter to its state before the first load
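    # Each rebuild entry records the metadata needed to re-create a
    # parameter. A hypothetical entry (names and sizes are illustrative,
    # not taken from the source):
    #
    #     "layers.0.qkv_proj.weight": {
    #         "shape": torch.Size([3072, 1024]),
    #         "dtype": torch.bfloat16,
    #         "device": torch.device("cuda:0"),
    #     }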
for name, d in model.original_weights_rebuild_keys.items():
_shape = d["shape"]
_dtype = d["dtype"]
_device = d["device"]
if model_device is not None:
            assert model_device == _device, (
                "Expecting all weights to be on the same device for now, "
                f"got both: {model_device} and {_device}")
else:
model_device = _device
if name in existing_param_names:
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(
module, weight_name,
torch.nn.Parameter(
torch.empty(_shape, dtype=_dtype, device=_device)))
    # recorded_weight_attr is
    # {"weight_name": {"weight_attr_key": attr}}
    # e.g.
    # {
    #     "layer.0.weight": {
    #         "weight_loader": weight_loader_function_object,
    #         "input_dim": 0, ...
    #     },
    #     "layer.1.weight": ...,
    # }
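    # Re-attach each recorded attribute to the restored parameter
    # (function attributes are re-bound to the new Parameter instance),
    # so that e.g. the weight_loader used by load_weights is available
    # again on the rebuilt weights.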
for full_weight_name, weight_attr_dict in \
model.recorded_weight_attr.items():
for attr_name, attr in weight_attr_dict.items():
module_name, weight_name = full_weight_name.rsplit(".", 1)
module = named_modules[module_name]
weight = getattr(module, weight_name)
if not hasattr(weight, attr_name):
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
    # Step I1: reload the bfloat16 / high-precision weights
loaded_weights = model.load_weights(
model_loader.get_all_weights(model_config, model))
    # Step I2: online-quantize the weights by manually re-running the
    # post-load processing (process_weights_after_loading)
model.process_weights_after_loading_already_called = False
process_weights_after_loading(model, model_config, model_device)
model.process_weights_after_loading_already_called = True
return loaded_weights
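

# A minimal sketch of the ``_bond_method_to_cls`` helper used above,
# assuming it re-binds recorded function attributes (such as
# ``weight_loader``) to the newly created Parameter so the parameter
# becomes the method's receiver again. This is an illustrative sketch,
# not necessarily the actual implementation; the ``_sketch`` suffix
# avoids shadowing the real helper defined elsewhere.
import types


def _bond_method_to_cls_sketch(attr, obj):
    # Bound methods expose their underlying function as ``__func__``;
    # re-bind that function so ``obj`` (the new Parameter) is ``self``.
    if callable(attr) and hasattr(attr, "__func__"):
        return types.MethodType(attr.__func__, obj)
    # Plain values (e.g. ``input_dim``) are attached unchanged.
    return attr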