A Scala 2 port of Andrej Karpathy's llama2.c

August 14, 2023 · View on GitHub

This is a Scala port of Andrej Karpathy's llama2.c, a bare-bones implementation for running inference on models with a Llama-like, transformer-based LLM architecture.

The code expects tokenizer.bin and stories15M.bin in the current directory.

This started as a port of the original code to pure Scala. Later, higher-level abstractions were added, as well as low-level C kernels with AVX2 intrinsics to speed up matrix multiplication.
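For reference, the hot loop everything revolves around is the matrix-vector multiplication from llama2.c. A direct Scala port of it looks roughly like the sketch below (simplified for illustration, not the project's exact code):

```scala
// Simplified sketch of llama2.c's matmul ported to plain Scala:
// W is a (d x n) matrix stored row-major in a flat array, x a vector of length n,
// and the result xout = W * x has length d.
def matMul(xout: Array[Float], x: Array[Float], w: Array[Float], n: Int, d: Int): Unit = {
  var i = 0
  while (i < d) {
    val rowOffset = i * n
    var sum       = 0.0f
    var j         = 0
    while (j < n) {
      sum += w(rowOffset + j) * x(j)
      j += 1
    }
    xout(i) = sum
    i += 1
  }
}
```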

[asciinema demo recording]

Features:

  • Two implementations of the model architecture are available:
    • Llama2SimpleTransformer which is a direct port of the original C code
    • Llama2TensorTransformer which uses a Tensor abstraction to make the code more readable
  • Different matrix multiplication kernels:
    • ScalaMathImplementation (direct port of llama2.c in pure Scala)
    • AVX2MathImplementation (using JNI to call kernels written with C SIMD intrinsics)
  • Models are mapped into memory to avoid loading them into the JVM heap (see the sketch after this list)
  • Quantization modes (see the sketch after this list):
    • use the weights as given in the model
    • Q8: quantize weights after loading to 8 bits (all but rmsnorm)
    • Q4: quantize weights after loading to 4 bits (all but rmsnorm)
  • Multi-threading:
    • AVX2MathImplementation uses OpenMP
  • Support for loading ggml models
    • only weights in Q4_0 and FP32 are supported
  • scala-native support (mostly broken right now)
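A note on the memory mapping mentioned above: the idea is to map the checkpoint file read-only so the weight data stays outside the JVM heap. A minimal sketch using Java NIO (illustrative, not the project's exact loader) could look like this:

```scala
import java.nio.channels.FileChannel
import java.nio.file.{Paths, StandardOpenOption}
import java.nio.{ByteOrder, MappedByteBuffer}

// Illustrative only: map a checkpoint read-only so the weights live outside the
// JVM heap. A single MappedByteBuffer is limited to 2 GB, so a model like
// llama2_7b.bin would have to be mapped in several slices instead.
def mapWeights(path: String): MappedByteBuffer = {
  val channel = FileChannel.open(Paths.get(path), StandardOpenOption.READ)
  val buffer  = channel.map(FileChannel.MapMode.READ_ONLY, 0L, channel.size())
  buffer.order(ByteOrder.LITTLE_ENDIAN) // llama2.c checkpoints store little-endian float32
  buffer
}
```

Individual values can then be read straight out of the mapping (e.g. with getFloat at an offset) without ever copying a whole tensor onto the heap.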
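Similarly, the Q8 mode quantizes weights block-wise after loading: each group of values is reduced to signed bytes plus one float scale. A simplified sketch of the idea (the block size and in-memory layout are assumptions, not necessarily the project's exact format):

```scala
// Illustrative only: block-wise 8-bit quantization. Each group of 32 floats is
// stored as 32 signed bytes plus one float32 scale; dequantization is roughly
// value ≈ scale * byte.
final case class Q8Block(scale: Float, values: Array[Byte])

def quantizeQ8(weights: Array[Float], blockSize: Int = 32): Array[Q8Block] =
  weights.grouped(blockSize).map { block =>
    val maxAbs = block.foldLeft(0.0f)((m, v) => math.max(m, math.abs(v)))
    val scale  = if (maxAbs == 0.0f) 1.0f else maxAbs / 127.0f
    Q8Block(scale, block.map(v => math.round(v / scale).toByte))
  }.toArray
```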

Performance

Current numbers were obtained with version 08c65d04 on my AMD Ryzen 7 4800H laptop with GraalVM JDK 17.

Implementations:

  • Scala = Llama2TensorTransformer with ScalaMathImplementation
  • native-avx2 = Llama2TensorTransformer with AVX2MathImplementation (using JNI to call kernels written with C SIMD intrinsics)
  • llama2.c = as of 94a3a5e0
  • llama.cpp = as of d783f798
  • scala-native = Using scala-native 0.4.14 with the Llama2SimpleTransformer implementation

Notes:

  • Approximate speedups are:
    • pure Scala -> AVX2: > 10x
    • FP32 -> Q8/Q4 (in Scala): same speed
    • FP32 -> Q8 (AVX2): ~ 2x
    • Q8 -> Q4 (AVX2) on one thread: same speed
    • Q4 1 thread -> 6 threads on small models: ~ 2x
    • Q4 1 thread -> 6 threads on large models: ~ 3x
  • The pure Scala mode on GraalVM JDK 17 is only competitive with a llama2.c version compiled with -O3. Using -Ofast on the C side already makes a huge difference. It would be interesting to see the exact differences between JIT-compiled code and gcc output with -Ofast, and whether something like -Ofast (using less strict FP math) is even possible on the JVM.
  • Using C kernels with SIMD intrinsics (mostly adapted from llama.cpp) and calling them from Scala via JNI makes a huge difference (see the sketch after these notes). It is easy to do for a local build, but, of course, much harder to do in a portable way.
  • As expected, quantization gives another boost. Interesting that it is more pronounced when multi-threading is enabled.
  • OMP-based multithreading is simple to use from C and helps a lot. Scaling is not perfect, with benefits diminishing sharply after using more than 6 (of 8) threads.
  • Multithreading is interesting, as the task units are quite small (one matrix multiplication) and overheads can be significant.
  • Quantization only helps in combination with SIMD optimization. Only SIMD gives access to byte-wise (int8) operations, and shrinking the data type increases the number of lanes per vector by the same factor. It is unclear why going from 32-bit to 8-bit gives only a 2x speedup when 4x more operations can run in parallel. One explanation could be that extra instructions are needed to deal with the added complexity of quantization.
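For illustration, reaching a C kernel from Scala via JNI needs little more than loading a shared library and declaring a @native method. Everything named in the sketch below (library name, method name, signature) is hypothetical and not the project's actual binding:

```scala
// Illustrative only: a hypothetical JNI binding for a quantized matmul kernel.
object NativeMath {
  System.loadLibrary("llama2math") // expects libllama2math.so on java.library.path

  // Implemented on the C side (AVX2 intrinsics, OpenMP pragma over the output rows).
  @native def matMulQ8(
      xout: Array[Float],     // output vector, length d
      x: Array[Float],        // input vector, length n
      quantized: Array[Byte], // quantized weights, d * n bytes
      scales: Array[Float],   // one scale per quantization block
      n: Int,
      d: Int
  ): Unit
}
```

The matching C function, whose symbol name follows from the class and method names via the JNI naming scheme, is where the AVX2 intrinsics and the OpenMP parallelization live.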
Model                        Quantization   Implementation                 Threads   tok/s
stories15M.bin               Q4             native-avx2                    1         494
stories15M.bin               Q4             native-avx2                    6         931
stories15M.bin               Q4             Scala                          1         65
stories15M.bin               Q8             native-avx2                    1         533
stories15M.bin               Q8             native-avx2                    6         800
stories15M.bin               Q8             Scala                          1         57
stories15M.bin               none           native-avx2                    1         374
stories15M.bin               none           native-avx2                    6         677
stories15M.bin               none           Scala                          1         66
stories15M.bin               none           scala-native vanilla           1         14
stories15M.bin               none           scala-native (native mmaps)    1         50
stories42M.bin               Q4             native-avx2                    1         223
stories42M.bin               Q4             native-avx2                    6         497
stories42M.bin               Q4             Scala                          1         24
stories42M.bin               Q8             native-avx2                    1         229
stories42M.bin               Q8             native-avx2                    6         407
stories42M.bin               Q8             Scala                          1         22
stories42M.bin               none           native-avx2                    1         137
stories42M.bin               none           native-avx2                    6         243
stories42M.bin               none           Scala                          1         24
stories42M.bin               none           llama2.c / run                 1         21
stories42M.bin               none           llama2.c / runfast             1         69
stories42M.bin               none           llama2.c / runomp              1         98
stories42M.bin               none           llama2.c / runomp              6         195
stories110M.bin              Q4             native-avx2                    1         95
stories110M.bin              Q4             native-avx2                    6         239
stories110M.bin              Q4             Scala                          1         9.6
stories110M.bin              Q8             native-avx2                    1         99
stories110M.bin              Q8             native-avx2                    6         183
stories110M.bin              Q8             Scala                          1         8.4
stories110M.bin              none           native-avx2                    1         50
stories110M.bin              none           native-avx2                    6         85
stories110M.bin              none           Scala                          1         8.9
stories110M.bin              none           llama2.c / runomp              6         77
llama2_7b.bin                Q4             native-avx2                    1         2.0
llama2_7b.bin                Q4             native-avx2                    6         6.5
llama2_7b.bin                Q4             Scala                          1         0.16
llama2_7b.bin                Q8             native-avx2                    1         1.9
llama2_7b.bin                Q8             native-avx2                    6         4.46
llama2_7b.bin                Q8             Scala                          1         0.14
llama-2-7b.ggmlv3.q4_0.bin   as provided    native-avx2                    1         1.66
llama-2-7b.ggmlv3.q4_0.bin   as provided    native-avx2                    6         6.71
llama-2-7b.ggmlv3.q4_0.bin   as provided    Scala                          1         0.13
llama-2-7b.ggmlv3.q4_0.bin   as provided    llama.cpp                      1         2.0
llama-2-7b.ggmlv3.q4_0.bin   as provided    llama.cpp                      6         8.1

License

MIT