# Re-adding the `shard_checkpoint` method (removed in transformers 4.47.0 and later)

`transformers` removed `shard_checkpoint` from `modeling_utils` in version 4.47.0. The definition below restores it for code that still depends on it.

```python
from typing import Dict, Union

import torch

# NOTE: the import paths for these helpers are an assumption based on transformers 4.x;
# they may move between releases, so adjust them if your installed version differs.
from transformers.modeling_utils import dtype_byte_size
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import WEIGHTS_NAME, logging
from transformers.utils.hub import convert_file_size_to_int

logger = logging.get_logger(__name__)


def shard_checkpoint(
    state_dict: Dict[str, torch.Tensor],
    max_shard_size: Union[int, str] = "10GB",
    weights_name: str = WEIGHTS_NAME,
):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
            The name of the model save file.
    """
    logger.warning(
        "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
        "split_torch_state_dict_into_shards from huggingface_hub library"
    )
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = [{}]
    last_block_size = 0
    total_size = 0
    storage_id_to_block = {}

    for key, weight in state_dict.items():
        # when bnb serialization is used the weights in the state dict can be strings
        # check: https://github.com/huggingface/transformers/pull/24416 for more details
        if isinstance(weight, str):
            continue
        else:
            storage_id = id_tensor_storage(weight)

        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
        if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
            block_id = storage_id_to_block[storage_id]
            sharded_state_dicts[block_id][key] = weight
            continue

        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
        # weight in the current shard.
        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
            sharded_state_dicts.append({})
            last_block_size = 0

        sharded_state_dicts[-1][key] = weight
        last_block_size += weight_size
        total_size += weight_size
        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin")
        shard_file = shard_file.replace(
            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
        )
        shards[shard_file] = shard
        for key in shard.keys():
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
```
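For reference, a minimal usage sketch follows. The toy model, the `32MB` shard size, the output directory, and the `pytorch_model.bin.index.json` index file name are illustrative choices added here, not part of the original snippet.

```python
# Minimal usage sketch: shard a state dict, write each shard, and write the index if sharding occurred.
import json
import os

import torch

# Two 4096x4096 float32 linear layers (~64MB each) force the checkpoint to be split at a 32MB limit.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096), torch.nn.Linear(4096, 4096))
shards, index = shard_checkpoint(model.state_dict(), max_shard_size="32MB")

save_dir = "sharded_checkpoint"
os.makedirs(save_dir, exist_ok=True)
for shard_file, shard in shards.items():
    # Each shard is an ordinary state dict saved under the generated file name,
    # e.g. "pytorch_model-00001-of-00004.bin".
    torch.save(shard, os.path.join(save_dir, shard_file))
if index is not None:
    # The index maps every parameter name to the shard file that contains it.
    with open(os.path.join(save_dir, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index, f, indent=2)
```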
