Differentially private machine learning at scale with JAX-Privacy
November 12, 2025
Borja Balle, Staff Research Scientist, Google DeepMind, and Ryan McKenna, Senior Research Scientist, Google Research
We announce the release of JAX-Privacy 1.0, a library for differentially private machine learning built on JAX, the high-performance numerical computing library.
—
From personalized recommendations to scientific advances, AI models are helping to improve lives and transform industries. But the impact and accuracy of these AI models often depend on the quality of the data they use. Large, high-quality datasets are crucial for developing accurate and representative AI models; however, they must be used in ways that preserve individual privacy.
That’s where [JAX](https://docs.jax.dev/en/latest/) and [JAX-Privacy](https://github.com/google-deepmind/jax_privacy) come in. Introduced in 2020, JAX is a high-performance numerical computing library designed for large-scale machine learning (ML). Its core features — including automatic differentiation, just-in-time compilation, and seamless scaling across multiple accelerators — make it an ideal platform for building and training complex models efficiently. JAX [has become a cornerstone](https://github.com/jax-ml/jax/network/dependents) for researchers and engineers pushing the boundaries of AI. Its surrounding ecosystem includes a robust set of domain-specific libraries, such as [Flax](https://flax.readthedocs.io/en/latest/index.html), which simplifies the implementation of neural network architectures, and [Optax](https://github.com/google-deepmind/optax), which implements state-of-the-art optimizers.
Built on JAX, JAX-Privacy is a robust toolkit for building and auditing differentially private models. It enables researchers and developers to quickly and efficiently implement [differentially private](https://en.wikipedia.org/wiki/Differential_privacy) (DP) algorithms for training deep learning models on large datasets, providing core tools to integrate private training into modern distributed training workflows. The original version of JAX-Privacy was introduced in 2022 to enable external researchers to reproduce and validate some of our [advances on private training](https://github.com/google-deepmind/jax_privacy#reproducing-results). It has since evolved into a hub where research teams across Google integrate their novel research insights into DP training and auditing algorithms.
Today, we are proud to announce the release of [JAX-Privacy 1.0](https://github.com/google-deepmind/jax_privacy). Integrating our latest research advances and redesigned for modularity, this version makes it easier than ever for researchers and developers to build DP training pipelines that combine state-of-the-art DP algorithms with the scalability provided by JAX.
—
### How we got here: The need for JAX-Privacy
For years, researchers have turned to DP as the gold standard for quantifying and bounding privacy leakage. DP guarantees that the distribution of a randomized algorithm’s output is nearly the same whether or not any single individual (or example) is included in the dataset.
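For reference, the most widely used formal version of this guarantee, (ε, δ)-DP, can be stated as follows (the notation is introduced here for illustration): a randomized algorithm 𝓜 satisfies it if, for every pair of datasets D and D′ that differ in a single example and every set of possible outputs S,

```latex
\Pr[\mathcal{M}(D) \in S] \;\le\; e^{\varepsilon}\,\Pr[\mathcal{M}(D') \in S] + \delta .
```

Smaller values of ε and δ mean that any one example has less influence on what the algorithm outputs.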
While the theory of DP is well-established, its practical implementation in large-scale ML can be challenging. The most common approach, [differentially private stochastic gradient descent](https://arxiv.org/abs/1607.00133) (DP-SGD), requires customized batching procedures, per-example gradient clipping, and the addition of carefully calibrated noise. This process is computationally intensive and can be difficult to implement correctly and efficiently, especially at the scale of modern foundation models.
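To make these ingredients concrete, here is a minimal JAX sketch of a single DP-SGD step on a toy linear model. It is an illustration, not JAX-Privacy’s API: `loss_fn` and `dp_sgd_step` are hypothetical names, the batch is assumed to be already sampled, the update divides by the realized batch size, and the privacy accounting that turns `noise_multiplier` into an (ε, δ) guarantee is not shown. `jax.vmap` produces per-example gradients, each example’s gradient is clipped to L2 norm `clip_norm`, and Gaussian noise with standard deviation `noise_multiplier * clip_norm` is added to the clipped sum.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Squared error of a toy linear model on ONE example (illustrative only).
    pred = jnp.dot(x, params["w"]) + params["b"]
    return (pred - y) ** 2


def dp_sgd_step(params, xs, ys, key, clip_norm=1.0, noise_multiplier=1.0, lr=0.1):
    # 1) Per-example gradients: vmap over the batch axis of xs/ys, share params.
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)

    # 2) Per-example L2 norm across ALL parameter leaves, then clip to clip_norm.
    leaves = jax.tree_util.tree_leaves(grads)
    sq_norms = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in leaves)
    factors = jnp.minimum(1.0, clip_norm / (jnp.sqrt(sq_norms) + 1e-12))

    def clip_and_sum(g):
        # Broadcast the per-example factor over the leaf's trailing dims, then sum.
        f = factors.reshape((-1,) + (1,) * (g.ndim - 1))
        return jnp.sum(g * f, axis=0)

    summed = jax.tree_util.tree_map(clip_and_sum, grads)

    # 3) Add Gaussian noise with std = noise_multiplier * clip_norm, i.e. scaled
    #    to the L2 sensitivity of the clipped sum to adding/removing one example.
    flat, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(flat))
    noisy = [g + noise_multiplier * clip_norm * jax.random.normal(k, g.shape)
             for g, k in zip(flat, keys)]
    noisy_sum = jax.tree_util.tree_unflatten(treedef, noisy)

    # 4) Average over the batch and take a plain SGD step.
    batch_size = xs.shape[0]
    return jax.tree_util.tree_map(lambda p, g: p - lr * g / batch_size, params, noisy_sum)
```

Because every shape involved is static, `dp_sgd_step` can be wrapped in `jax.jit` unchanged.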

*JAX-Privacy’s primitive building blocks for gradient clipping and correlated noise generation work effectively in distributed environments, enabling researchers and developers to train and fine-tune foundation models on private data with state-of-the-art differentially private algorithms in a scalable and efficient way.*
—
Existing frameworks have made strides, but they often fall short in scalability or flexibility. Our work has consistently pushed the boundaries of private ML, from [pioneering new DP algorithms](https://arxiv.org/abs/1607.00133) to [developing sophisticated auditing techniques](https://arxiv.org/abs/2302.07956). We needed a tool that could keep pace with our research — a library that was not only correct and efficient but also designed from the ground up to handle the parallelism and complexity of state-of-the-art models.
JAX’s functional paradigm and powerful transformations, like `vmap` (for automatic vectorization) and `shard_map` (for single-program multiple-data parallelization), provided a strong foundation. By building on JAX, we could create a library that was parallelism-ready out-of-the-box, supporting the training of large-scale models across multiple accelerators and supercomputers. JAX-Privacy is the culmination of this effort, a time-tested library that has powered internal production integrations and is now being shared with the broader community.
—
### What JAX-Privacy delivers
JAX-Privacy simplifies the complexities of DP by providing a suite of carefully engineered components:
– **Core building blocks**: The library offers correct and efficient implementations of fundamental DP primitives, including [per-example gradient clipping](https://github.com/google-deepmind/jax_privacy/blob/main/jax_privacy/clipping.py), [noise addition](https://github.com/google-deepmind/jax_privacy/blob/main/jax_privacy/noise_addition.py), and [data batch construction](https://github.com/google-deepmind/jax_privacy/blob/main/jax_privacy/batch_selection.py). These components enable developers to build well-known algorithms like DP-SGD and [DP-FTRL](https://arxiv.org/abs/2103.00039) with confidence.
– **State-of-the-art algorithms**: JAX-Privacy goes beyond the basics, supporting advanced methods such as [DP matrix factorization](https://arxiv.org/abs/2506.08201), which injects noise that is correlated across iterations to improve model utility at a given privacy level. This makes it easier for researchers to experiment with cutting-edge private training techniques.
– **Scalability**: All components are designed to work seamlessly with JAX’s native parallelism features, allowing training of large-scale models that require data and model parallelism without complex, custom code. JAX-Privacy also provides tools like micro-batching and padding to handle massive, variable-sized batches needed for optimal privacy/utility trade-offs.
– **Correctness and auditing**: The library is built on Google’s state-of-the-art [DP accounting library](https://github.com/google/differential-privacy/tree/main/python/dp_accounting), ensuring mathematically correct and tight noise calibration. Formal bounds on privacy loss can be complemented by metrics that quantify empirical privacy loss, providing a comprehensive view of privacy properties. Users can easily test and develop their own auditing techniques, like our award-winning work on “[Tight Auditing of Differentially Private Machine Learning](https://www.usenix.org/conference/usenixsecurity23/presentation/nasr)”, which uses “canaries” — known data points — to compute auditing metrics at each step.
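The idea behind correlated noise can be sketched in a few lines. In the matrix-factorization view, the prefix sums of the gradients over T steps are A·g, where A is the T×T lower-triangular matrix of ones. Factoring A = B·C, privately releasing C·g + z and decoding with B is equivalent to adding C⁻¹z to the per-step gradients instead of independent noise. The sketch below assumes one simple, well-known choice, C = A^(1/2) (a lower-triangular Toeplitz matrix built from the Taylor coefficients of (1 − x)^(−1/2)), and assumes each example contributes to only one step. Both the factorization and the helper names `sqrt_prefix_sum_matrix` and `correlated_noise` are illustrative assumptions, not JAX-Privacy’s implementation.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def sqrt_prefix_sum_matrix(n):
    # Lower-triangular Toeplitz C with entries c_k = binom(2k, k) / 4^k, the
    # Taylor coefficients of (1 - x)^(-1/2). Since their square is (1 - x)^(-1),
    # C @ C equals the all-ones lower-triangular prefix-sum matrix A.
    coeffs = jnp.array([math.comb(2 * k, k) / 4**k for k in range(n)])
    i = jnp.arange(n)
    lag = i[:, None] - i[None, :]
    return jnp.where(lag >= 0, coeffs[jnp.clip(lag, 0, n - 1)], 0.0)


def correlated_noise(key, num_steps, dim, noise_multiplier, clip_norm):
    # Noise to add to the clipped gradient sum at each of num_steps steps.
    C = sqrt_prefix_sum_matrix(num_steps)
    # Sensitivity of C @ g when each example affects a single step
    # (single participation): the largest column norm of C, times clip_norm.
    sens = jnp.max(jnp.linalg.norm(C, axis=0))
    z = jax.random.normal(key, (num_steps, dim)) * noise_multiplier * clip_norm * sens
    # Releasing C @ g + z and decoding is the same as adding C^{-1} z to g:
    # row t of the result is the (correlated) noise for step t.
    return solve_triangular(C, z, lower=True)
```

Since A = C·C, the running sums of the returned noise, `jnp.cumsum(noise, axis=0)`, equal C·z, which is the noise an SGD trajectory actually accumulates.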

*JAX-Privacy implements foundational tools for clipping, noise addition, batch selection, accounting, and auditing that combine to build end-to-end DP training pipelines.*
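Variable-sized batches interact with JAX’s compilation model: under Poisson sampling, where each example is included independently with probability q, the batch size changes at every step, and each new input shape would trigger a recompilation under `jax.jit`. One common workaround, sketched below with hypothetical helper names, pads the sampled indices up to a multiple of a fixed micro-batch size, masks out the padding, and accumulates clipped gradient sums one micro-batch at a time, so the jitted function always sees the same shape. This illustrates the general technique and is not JAX-Privacy’s batch-selection API; Gaussian noise and the division by the expected batch size q·N would then typically be applied to the accumulated sum.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_fn(w, x, y):
    # Toy per-example squared error for a linear model (illustrative only).
    return (jnp.dot(x, w) - y) ** 2


@jax.jit
def masked_clipped_sum(w, xs, ys, mask, clip_norm):
    # Per-example grads -> clip to clip_norm -> zero out padded rows -> sum.
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(w, xs, ys)
    norms = jnp.linalg.norm(grads, axis=1)
    factors = jnp.minimum(1.0, clip_norm / (norms + 1e-12)) * mask
    return jnp.sum(grads * factors[:, None], axis=0)


def poisson_microbatches(rng, num_examples, sampling_prob, microbatch_size):
    # Poisson sampling: include each example independently with prob q.
    idx = np.flatnonzero(rng.random(num_examples) < sampling_prob)
    # Pad to a multiple of microbatch_size; padded slots get mask 0.
    num_micro = max(1, -(-len(idx) // microbatch_size))  # ceil division
    padded = np.zeros(num_micro * microbatch_size, dtype=np.int64)
    padded[: len(idx)] = idx
    mask = np.zeros(num_micro * microbatch_size, dtype=np.float32)
    mask[: len(idx)] = 1.0
    return padded.reshape(num_micro, -1), mask.reshape(num_micro, -1)


def clipped_sum_over_batch(w, X, Y, rng, sampling_prob, microbatch_size, clip_norm):
    # Every call to masked_clipped_sum has shape (microbatch_size, ...), so it
    # compiles once; only the number of loop iterations varies per step.
    idx_mb, mask_mb = poisson_microbatches(rng, X.shape[0], sampling_prob, microbatch_size)
    total = jnp.zeros_like(w)
    for idx, mask in zip(idx_mb, mask_mb):
        total = total + masked_clipped_sum(w, X[idx], Y[idx], mask, clip_norm)
    return total, int(mask_mb.sum())
```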
—
### From research to practice: Fine-tuning LLMs with confidence
One of the most exciting aspects of JAX-Privacy is its practical application. The library supports modern ML frameworks used for pre-training and fine-tuning large language models (LLMs). A notable example is our recent use of JAX-Privacy building blocks in training [VaultGemma](https://research.google/blog/vaultgemma-the-worlds-most-capable-differentially-private-llm/), the world’s most capable differentially private LLM.
With this open-source release, developers can easily fine-tune large models with just a few lines of code via the popular [Keras](https://keras.io/examples/nlp/) framework. Included are [fully-functional examples](https://github.com/google-deepmind/jax_privacy/tree/main/examples) for fine-tuning models in the [Gemma family](https://developers.googleblog.com/en/gemma-explained-overview-gemma-model-family-architectures/), a collection of open models built by Google DeepMind based on Gemini. These examples demonstrate how to apply JAX-Privacy to tasks like dialogue summarization and synthetic data generation, showing that the library can deliver state-of-the-art results even with the most advanced models.
By simplifying DP integration, JAX-Privacy empowers developers to build privacy-preserving applications from the ground up, whether fine-tuning a chatbot for healthcare or a model for personalized financial advice. It lowers the barrier to entry for privacy-preserving ML and makes powerful, responsible AI more accessible.
—
### Looking ahead
We are excited to share JAX-Privacy with the research community. This release is the result of years of dedicated effort and represents a significant contribution to privacy-preserving ML. We hope that by providing these tools, we can enable a new wave of research and innovation that benefits everyone.
We will continue to support and develop the library, incorporating new research advances and responding to the community’s needs. We look forward to seeing what you build with JAX-Privacy. Check out the [repository on GitHub](https://github.com/google-deepmind/jax_privacy) or the [PyPI package](https://pypi.org/project/jax-privacy/) to start training privacy-preserving ML models today.
—
### Acknowledgements
*JAX-Privacy includes contributions from: Leonard Berrada, Robert Stanforth, Brendan McMahan, Christopher A. Choquette-Choo, Galen Andrew, Mikhail Pravilov, Sahra Ghalebikesabi, Aneesh Pappu, Michael Reneer, Jamie Hayes, Vadym Doroshenko, Keith Rush, Dj Dvijotham, Zachary Charles, Peter Kairouz, Soham De, Samuel L. Smith, Judy Hanwen Shen.*
