import os
import json

import torch
from safetensors.torch import save_file, load_file


def convert_bin_to_safetensors(bin_path, output_dir=None):
    """Convert one PyTorch ``.bin`` checkpoint into a ``.safetensors`` file.

    The output file keeps the base name of ``bin_path`` and only swaps the
    extension.

    Args:
        bin_path: Path to a checkpoint readable by ``torch.load``.
        output_dir: Directory for the converted file. Defaults to the
            directory containing ``bin_path``. Created if it does not exist.

    Returns:
        The path of the written ``.safetensors`` file.

    Raises:
        ValueError: If the loaded object is not a ``dict``.
    """
    if output_dir is None:
        output_dir = os.path.dirname(bin_path) or "."
    # save_file does not create missing directories; `or "."` keeps an empty
    # string (meaning the current directory) working as it did before.
    os.makedirs(output_dir or ".", exist_ok=True)

    print(f"Loading weights from {bin_path}...")
    state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)

    if not isinstance(state_dict, dict):
        raise ValueError("The loaded file does not contain a valid state_dict.")

    # safetensors refuses non-contiguous tensors, so normalise them first.
    # Non-tensor values are passed through untouched so save_file reports them.
    # NOTE(review): tensors that share storage (e.g. tied weights) are still
    # rejected by save_file; de-duplicating them is out of scope here.
    tensors = {
        name: value.contiguous() if isinstance(value, torch.Tensor) else value
        for name, value in state_dict.items()
    }

    base_name = os.path.splitext(os.path.basename(bin_path))[0]
    sf_path = os.path.join(output_dir, f"{base_name}.safetensors")

    print(f"Saving to {sf_path}")
    # The "format" entry tells downstream loaders (e.g. transformers) which
    # framework wrote the file; some versions reject files without it.
    save_file(tensors, sf_path, metadata={"format": "pt"})
    print(f"✅ Converted {bin_path} -> {sf_path}")
    return sf_path


def generate_index_json(output_dir, bin_files, index_filename="model.safetensors.index.json"):
    """Write a sharded-checkpoint index for already converted files.

    For every entry in ``bin_files`` the matching ``.safetensors`` file (same
    base name) is looked up in ``output_dir`` and each tensor name it contains
    is mapped to that file name.

    Args:
        output_dir: Directory holding the converted ``.safetensors`` files;
            the index is written here as well.
        bin_files: Paths of the original ``.bin`` files; only their base
            names are used.
        index_filename: File name of the index to create.

    Returns:
        The path of the written index file.

    Raises:
        FileNotFoundError: If a converted file is missing from ``output_dir``.
        RuntimeError: If a converted file cannot be read.
    """
    index_path = os.path.join(output_dir, index_filename)
    weight_map = {}
    total_size = 0

    for bin_file in bin_files:
        base_name = os.path.splitext(os.path.basename(bin_file))[0]
        sf_file = f"{base_name}.safetensors"
        sf_path = os.path.join(output_dir, sf_file)

        if not os.path.exists(sf_path):
            raise FileNotFoundError(f"Converted file not found: {sf_path}")

        try:
            tensors = load_file(sf_path)
        except Exception as e:
            raise RuntimeError(f"Failed to read tensors from {sf_path}: {e}") from e

        weight_map.update(dict.fromkeys(tensors, sf_file))
        # total_size is the number of bytes held by the tensors themselves,
        # not the on-disk file size (which also includes the JSON header).
        total_size += sum(t.numel() * t.element_size() for t in tensors.values())

    index_data = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }

    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index_data, f, indent=2, ensure_ascii=False)

    print(f"✅ Created index file at {index_path}")
    return index_path


def main():
    """Convert every ``pytorch_model*.bin`` in a directory and write the index."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert multiple .bin files to .safetensors and generate index.json"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing .bin files (e.g., pytorch_model-*.bin)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output directory for converted files. Default: same as input_dir.",
    )
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir if args.output_dir else input_dir

    # Discover all checkpoint shards; sorting keeps the processing order stable.
    bin_files = sorted(
        os.path.join(input_dir, f)
        for f in os.listdir(input_dir)
        if f.startswith("pytorch_model") and f.endswith(".bin")
    )

    if not bin_files:
        raise FileNotFoundError(f"No .bin files found in {input_dir}")

    print(f"Found {len(bin_files)} .bin files. Starting conversion...")

    # Convert every shard before building the index, which reads them back.
    for bin_path in bin_files:
        convert_bin_to_safetensors(bin_path, output_dir)

    # Build index.json from the converted shards.
    generate_index_json(output_dir, bin_files)
    print("🎉 All conversions completed successfully!")


if __name__ == "__main__":
    main()