February 17, 2025
Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B
Flax (JAX) implementation of DeepSeek-R1-Distill-Qwen-1.5B with weights ported from Hugging Face.
Overview
Colab: https://colab.research.google.com/drive/1jJaAARwbsFeV5hZoffrNwFhc8i2I-7ji?usp=sharing
This repository provides both Flax (JAX) and PyTorch implementations of the DeepSeek-R1-Distill-Qwen-1.5B model. It includes:
- Inference [QUICKSTART]: inference.ipynb contains a quickstart script that downloads the PyTorch parameters, converts them to Flax, loads the model, and performs text generation.
- Flax Implementation: model_flax.py contains the Flax implementation.
- PyTorch Implementation: model_torch.py is a reference implementation in PyTorch.
- Conversion Script: torch_to_flax.py is a utility that converts a PyTorch checkpoint (state dictionary) into Flax parameters.
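To illustrate what a PyTorch-to-Flax state-dict conversion generally involves, here is a minimal sketch. It is not the code in torch_to_flax.py; the key layout, the leaf names (`kernel`, `embedding`, `scale`), and the function name `torch_state_dict_to_flax` are assumptions for illustration. It assumes the values are already NumPy arrays (for example obtained with `tensor.float().numpy()` on CPU). The core idea is that a PyTorch `nn.Linear` stores its weight as `(out_features, in_features)`, while a Flax `nn.Dense` kernel is `(in_features, out_features)`, so linear weights are transposed and the flat dotted keys are unflattened into a nested dict.

```python
import numpy as np


def torch_state_dict_to_flax(state_dict):
    """Convert a flat PyTorch-style state dict into a nested Flax-style param tree.

    Hypothetical sketch; the naming rules below are assumptions, not the repo's:
    - keys are dot-separated, e.g. "model.layers.0.self_attn.q_proj.weight"
    - 2-D ".weight" tensors of linear layers become "kernel", transposed
      from PyTorch's (out, in) layout to Flax's (in, out) layout
    - embedding weights become "embedding" (same (vocab, dim) layout in both)
    - RMSNorm weights become "scale"
    """
    params = {}
    for key, value in state_dict.items():
        arr = np.asarray(value)
        *path, leaf = key.split(".")
        if leaf == "weight":
            if "embed_tokens" in path:
                leaf = "embedding"           # nn.Embed: no transpose needed
            elif "norm" in path[-1]:
                leaf = "scale"               # 1-D norm scale vector
            elif arr.ndim == 2:
                leaf, arr = "kernel", arr.T  # nn.Linear (out, in) -> nn.Dense (in, out)
        # "bias" keys keep their name and shape
        node = params
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = arr
    return params
```

In a Flax Linen setup, a tree like this would typically be passed to the module as `model.apply({"params": params}, ...)`; the exact nesting (for example `layers_0` versus `layers` / `0`) has to match the module names in the Flax model definition.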
System Requirements
Single GPU
16GB of GPU VRAM plus 64GB of system RAM (swap space can count toward the RAM requirement).
Multi-Device
Runs sharded on a v2-8 TPU on Google Colab.
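For a rough sense of scale, the memory needed just to hold the weights is the parameter count times the bytes per parameter. The sketch below uses the nominal 1.5B count from the model name (an approximation, not the exact parameter count) and the hypothetical helper `weight_bytes`. Activations, the KV cache during generation, and any intermediate copies made during conversion come on top of this, so treat the results as a lower bound.

```python
def weight_bytes(num_params: float, dtype: str) -> float:
    """Bytes needed to store num_params weights in the given dtype (weights only)."""
    bytes_per_param = {"float32": 4, "bfloat16": 2, "float16": 2}
    return num_params * bytes_per_param[dtype]


GIB = 1024 ** 3
n = 1.5e9  # nominal parameter count taken from the model name, not an exact figure

for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_bytes(n, dtype) / GIB:.1f} GiB")
# float32: 5.6 GiB
# bfloat16: 2.8 GiB
```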