import json
import os
import shutil

from safetensors import safe_open
from safetensors.numpy import save_file


def modify_telechat_streaming(
    input_dir: str,
    output_dir: str,
    modify_fn,
    shard_size_limit: int = 9 * 1024 * 1024 * 1024,  # 9GB
):
    """
    Stream-rewrite TeleChat2-35B model weights shard by shard to avoid
    loading the whole checkpoint into memory.

    Args:
        input_dir: input directory
        output_dir: output directory
        modify_fn: callable (key, tensor) -> dict {new_key: new_tensor};
            map a value to None to drop that entry
        shard_size_limit: maximum size of each output shard, in bytes
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. Copy non-weight files (everything except the shards and the index)
    for fname in os.listdir(input_dir):
        if not fname.startswith("model-") and not fname.startswith("model."):
            src = os.path.join(input_dir, fname)
            dst = os.path.join(output_dir, fname)
            if os.path.isfile(src):
                print(f"Copying {fname}...")
                shutil.copy2(src, dst)

    # 2. Read the index and config files
    index_file = "model.safetensors.index.json"
    index_path = os.path.join(input_dir, index_file)
    with open(index_path, 'r') as f:
        index_data = json.load(f)

    # config.json is loaded for reference; the example modify_fn below
    # hardcodes the corresponding values instead of reading them from here.
    config_file = "config.json"
    config_path = os.path.join(input_dir, config_file)
    with open(config_path, 'r') as f:
        config_data = json.load(f)

    metadata = index_data.get("metadata", {})
    weight_map = index_data["weight_map"]  # key -> shard_file

    # 3. Collect all shard files
    shard_files = sorted([f for f in os.listdir(input_dir) if f.endswith(".safetensors")])
    if not shard_files:
        raise FileNotFoundError("No shard files found")

    # 4. Build the reverse mapping: shard_file -> [keys]
    shard_to_keys = {f: [] for f in shard_files}
    for key, shard_file in weight_map.items():
        if shard_file not in shard_to_keys:
            raise ValueError(f"Shard file {shard_file} not found in directory")
        shard_to_keys[shard_file].append(key)

    # 5. Process each shard in a streaming manner
    new_weight_map = {}
    current_shard_idx = 1
    total_shards = len(shard_files)
    current_shard_name = f"model-{current_shard_idx:05d}-of-{total_shards:05d}.safetensors"
    current_data = {}
    current_size = 0

    def flush_current_shard():
        nonlocal current_shard_idx, current_data, current_size
        if not current_data:
            return
        # Save the current shard
        save_file(current_data, os.path.join(output_dir, current_shard_name))
        print(f"Saved {current_shard_name} ({len(current_data)} tensors)")
        # Reset the buffer
        current_data.clear()
        current_size = 0

    def switch_shard():
        nonlocal current_shard_idx, current_shard_name
        flush_current_shard()
        current_shard_idx += 1
        current_shard_name = f"model-{current_shard_idx:05d}-of-{total_shards:05d}.safetensors"

    # Iterate through each original shard file
    for shard_file in shard_files:
        shard_path = os.path.join(input_dir, shard_file)
        keys_in_shard = shard_to_keys[shard_file]
        print(f"Processing {shard_file} with {len(keys_in_shard)} keys...")

        # Only the current shard is ever held open
        with safe_open(shard_path, framework="np") as f:
            for key in keys_in_shard:
                tensor = f.get_tensor(key)
                param_dict = modify_fn(key, tensor)
                for new_key, new_tensor in param_dict.items():
                    if new_tensor is None:
                        continue
                    tensor_size = new_tensor.size * new_tensor.itemsize

                    # Start a new shard if this tensor would exceed the size limit
                    if current_data and (current_size + tensor_size > shard_size_limit):
                        switch_shard()

                    current_data[new_key] = new_tensor
                    new_weight_map[new_key] = current_shard_name
                    current_size += tensor_size

    # Write the final shard
    flush_current_shard()
    # 6. Recompute the total number of output shards
    final_shard_files = [f for f in os.listdir(output_dir)
                         if f.startswith("model-") and f.endswith(".safetensors")]
    final_nshards = len(final_shard_files)

    # Rename shards to the final "-of-NNNNN" scheme and rebuild the index
    final_weight_map = {}
    for old_name in sorted(final_shard_files):
        # Parse the shard index from the filename
        try:
            idx = int(old_name.split('-')[1])
        except (IndexError, ValueError):
            idx = 1
        new_name = f"model-{idx:05d}-of-{final_nshards:05d}.safetensors"
        if new_name != old_name:
            os.rename(
                os.path.join(output_dir, old_name),
                os.path.join(output_dir, new_name)
            )
        # Update the index
        with safe_open(os.path.join(output_dir, new_name), framework="np") as f:
            for key in f.keys():
                final_weight_map[key] = new_name

    # 7. Save the new index
    new_index = {
        "metadata": metadata,
        "weight_map": final_weight_map
    }
    index_out = os.path.join(output_dir, "model.safetensors.index.json")
    with open(index_out, 'w') as f:
        json.dump(new_index, f, indent=2)

    print(f"✅ Done! Saved to {output_dir}")
    print(f"   {final_nshards} shards, {len(final_weight_map)} weights")


def modify_fn(key, tensor):
    # Split fused key_value projection weights into separate key and value weights.
    param_dict = {}
    kv_value = tensor
    num_heads = 32      # config.num_attention_heads
    n_kv_heads = 32     # config.num_query_groups or num_heads
    hidden_size = 4096  # config.hidden_size
    head_dim = hidden_size // num_heads
    n_rep = num_heads // n_kv_heads
    kv_channel = hidden_size // n_rep
    if ".key_value." in key:
        kv_value = kv_value.reshape(n_kv_heads, 2 * head_dim, -1)
        k_key = key.replace("key_value", "key")
        k_weight = kv_value[:, :head_dim, :]
        k_weight = k_weight.reshape(kv_channel, -1)
        v_key = key.replace("key_value", "value")
        v_weight = kv_value[:, head_dim:head_dim * 2, :]
        v_weight = v_weight.reshape(kv_channel, -1)
        param_dict[k_key] = k_weight
        param_dict[v_key] = v_weight
    else:
        param_dict[key] = tensor
    return param_dict


modify_telechat_streaming(
    input_dir="/nfs/dataset/workspace/mindspore_dataset/weight/telechat2_7b_tmp",
    output_dir="/nfs/dataset/workspace/mindspore_dataset/weight/telechat2_7b",
    modify_fn=modify_fn,
    shard_size_limit=5 * 1024 * 1024 * 1024,  # 5GB
)
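
# Optional sanity check (a minimal sketch, not part of the conversion itself):
# reopen the rewritten index and confirm that every key listed in weight_map
# can be read back from the shard it points to. The path mirrors the
# output_dir used in the call above; json/os/safe_open are already imported
# at the top of the script. Names introduced here (check_dir, keys_by_shard,
# check_index) are only for this sketch.
from collections import defaultdict

check_dir = "/nfs/dataset/workspace/mindspore_dataset/weight/telechat2_7b"
with open(os.path.join(check_dir, "model.safetensors.index.json")) as f:
    check_index = json.load(f)

# Group keys by shard so each output file is opened only once
keys_by_shard = defaultdict(set)
for key, shard in check_index["weight_map"].items():
    keys_by_shard[shard].add(key)

for shard, expected_keys in keys_by_shard.items():
    with safe_open(os.path.join(check_dir, shard), framework="np") as f:
        missing = expected_keys - set(f.keys())
    if missing:
        raise ValueError(f"{shard} is missing {len(missing)} keys, e.g. {sorted(missing)[0]}")

print(f"Verified {len(check_index['weight_map'])} keys across {len(keys_by_shard)} shards")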