AI Frameworks
Compares major ML frameworks including TensorFlow, PyTorch, and JAX. Analyzes the framework ecosystem, trade-offs, and systems-level implications.
- Compare PyTorch, TensorFlow, and JAX across key dimensions like ease of use and performance
- Explain the computational graph paradigms: eager execution vs. traced compilation
- Evaluate framework selection criteria based on project requirements and team expertise
- Describe the role of ONNX in framework interoperability and model portability
- Analyze framework ecosystem maturity including tooling, community, and deployment support
01 The ML Framework Landscape
Machine learning frameworks provide the foundational software layer between model definitions and hardware execution. They abstract the complexity of tensor operations, automatic differentiation, GPU memory management, and distributed computation, allowing researchers and engineers to focus on model design rather than low-level implementation details.
ML Framework
A software library that provides high-level APIs for defining, training, and deploying neural networks, handling tensor operations, automatic differentiation, and hardware acceleration transparently.
The Big Three: PyTorch, TensorFlow, and JAX
The framework landscape has consolidated around three major players: PyTorch, TensorFlow, and JAX. Each embodies different design philosophies with distinct trade-offs. PyTorch prioritizes developer experience and flexibility. TensorFlow emphasizes production deployment and ecosystem completeness. JAX focuses on functional programming, composability, and high-performance computing.
Figure: ML Framework Ecosystem Positioning Map
| Dimension | PyTorch | TensorFlow | JAX |
|---|---|---|---|
| Design Philosophy | Pythonic, eager-first | Production-first, graph-based | Functional, composable |
| Default Execution | Eager (dynamic graph) | Eager (with tf.function) | JIT-compiled via XLA |
| Primary Strength | Research flexibility | Deployment ecosystem | Composable transformations |
| Developer | Meta (Facebook) | Google Brain | Google DeepMind |
| Community Size | Largest (research) | Large (industry) | Growing (research) |
Table 7.1: High-level comparison of the three dominant ML frameworks.
As of 2026, PyTorch dominates research and a growing share of production deployments, while TensorFlow retains a legacy presence in large-scale serving. JAX has established itself as the framework of choice for large-scale training at Google DeepMind and beyond. Older frameworks such as Caffe, Theano, and MXNet have been effectively superseded.
Framework choice has profound implications for the entire ML stack. It influences debugging workflows, deployment options, hardware compatibility, community support, and the availability of pre-trained models and libraries. Organizations must consider not just the current feature set but the long-term trajectory and community momentum of each framework.
Specialized Frameworks and Runtimes
Beyond the big three, specialized frameworks serve specific domains. ONNX Runtime provides cross-framework inference optimization. TensorRT optimizes for NVIDIA GPU deployment. Core ML and TFLite Micro target mobile and embedded devices respectively. Understanding this ecosystem is essential for making informed infrastructure decisions.
- ONNX Runtime — Cross-framework inference optimization via a standardized model format.
- TensorRT — NVIDIA's high-performance inference optimizer for GPU deployment.
- Core ML — Apple's framework for on-device ML on iOS, macOS, and Apple Silicon.
- TFLite Micro — TensorFlow's runtime for microcontrollers and embedded devices.
- OpenVINO — Intel's toolkit for optimizing inference on Intel hardware.
If you are starting a new project, begin with PyTorch unless you have a specific reason not to. Its dominance in research means the largest ecosystem of pre-trained models, tutorials, and community support. Consider TensorFlow for production-heavy workflows with diverse deployment targets, and JAX for research into novel training algorithms or large-scale TPU training.
The "framework wars" have settled into ecosystem niches — PyTorch for research, TensorFlow for production, JAX for HPCDeeper InsightRather than one framework winning outright, the ML community has converged on a division of labor. PyTorch dominates research and rapid prototyping thanks to its Pythonic API and dynamic graphs. TensorFlow retains its grip on production deployment with its unmatched breadth of serving targets (mobile, web, embedded, cloud). JAX has carved out the high-performance computing niche, favored by Google DeepMind and teams training the largest foundation models on TPU pods. Understanding these niches helps teams pick the right tool for their specific context. Click to collapse
ML Framework
A software library that provides high-level APIs for defining, training, and deploying neural networks, handling tensor operations, automatic differentiation, and hardware acceleration.
ONNX
Open Neural Network Exchange, an open format for representing ML models that enables interoperability between different frameworks.
02 PyTorch: Flexibility and Research
PyTorch has become the dominant framework in ML research due to its intuitive Pythonic interface and dynamic computational graph (eager execution). Every operation executes immediately, making debugging as simple as inserting print statements or using standard Python debuggers. This define-by-run approach matches how researchers think about models.
Eager Execution
A programming model where operations execute immediately when called, building the computational graph dynamically at runtime. This allows standard Python debugging tools and enables architectures with data-dependent control flow.
Autograd: Automatic Differentiation
PyTorch's autograd system automatically records operations on tensors and computes gradients through reverse-mode automatic differentiation. The computational graph is rebuilt on every forward pass, enabling dynamic architectures that vary from input to input, such as recursive neural networks or models with data-dependent control flow.
```python
import torch

# Autograd example: compute gradients automatically
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x ** 2 + 3 * x   # y = x^2 + 3x
loss = y.sum()
loss.backward()      # Compute d(loss)/dx
print(x.grad)        # tensor([7., 9.]), i.e. 2x + 3
```

One of PyTorch's greatest strengths is that you can insert standard Python breakpoints (import pdb; pdb.set_trace()) or print statements anywhere in your model. Because operations execute immediately, you can inspect intermediate tensor values, shapes, and gradients during debugging, just like debugging any Python program.
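To make the "recording" idea concrete, here is a toy scalar reverse-mode autodiff in plain Python. It is a sketch only: the Value class is hypothetical, and PyTorch's autograd operates on tensors with a far more general tape. It reproduces the gradient from the example above (dy/dx = 2x + 3 at x = 2 and x = 3).

```python
# Toy reverse-mode autodiff: a minimal sketch of what an autograd tape records.
# Illustrative only; PyTorch's implementation is tensor-valued and far more general.

class Value:
    """A scalar that records the operations applied to it."""
    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents           # upstream Value nodes
        self._local_grads = local_grads   # d(self)/d(parent) for each parent

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other),
                     (other.data, self.data))

    def backward(self, upstream=1.0):
        # Accumulate this node's gradient, then apply the chain rule upstream.
        self.grad += upstream
        for parent, local in zip(self._parents, self._local_grads):
            parent.backward(upstream * local)

# y = x^2 + 3x, evaluated at x = 2 and x = 3; dy/dx = 2x + 3
grads = []
for v in (2.0, 3.0):
    x = Value(v)
    y = x * x + x * 3.0
    y.backward()
    grads.append(x.grad)
print(grads)  # [7.0, 9.0]
```

Each operation stores the local derivatives with respect to its inputs as it executes; backward then walks the recorded graph in reverse, multiplying and accumulating, which is exactly the structure loss.backward() exploits at tensor scale.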
From Research to Production: torch.compile
For production deployment, PyTorch offers TorchScript (a subset of Python that can be JIT compiled) and torch.export for ahead-of-time compilation. PyTorch 2.0 introduced torch.compile, which uses a graph capture and compilation approach to achieve significant speedups without requiring code changes. This bridges the gap between PyTorch's research flexibility and production performance.
Adding a single line — model = torch.compile(model) — can speed up training by 30-50% on supported hardware. Under the hood, PyTorch captures the computational graph using TorchDynamo, optimizes it with TorchInductor, and generates efficient kernel code. No changes to the model definition or training loop are needed.
The PyTorch Ecosystem
The PyTorch ecosystem includes Hugging Face Transformers for pre-trained models, PyTorch Lightning for training boilerplate reduction, and torchvision/torchaudio/torchtext for domain-specific functionality. This rich ecosystem, combined with strong community momentum, makes PyTorch the default choice for most new ML projects.
- Hugging Face Transformers — Pre-trained models and fine-tuning pipelines for NLP, vision, and multimodal tasks.
- PyTorch Lightning — Reduces training boilerplate and standardizes project structure.
- torchvision / torchaudio / torchtext — Domain-specific datasets, transforms, and model architectures.
- FSDP (Fully Sharded Data Parallel) — Native distributed training with memory-efficient parameter sharding.
PyTorch's dynamic graph construction means that the exact sequence of operations can vary between runs when using data-dependent control flow. This can make bit-for-bit reproducibility challenging. Use torch.manual_seed() and torch.use_deterministic_algorithms(True) when exact reproducibility is required.
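The same principle applies to any stochastic pipeline. A minimal stdlib sketch, using Python's random module as a stand-in for PyTorch's RNGs, shows why seeding pins down a run:

```python
import random

def noisy_sum(seed):
    """Simulate a run whose result depends on RNG state (e.g. dropout, shuffling)."""
    rng = random.Random(seed)   # isolated, explicitly seeded generator
    samples = [rng.gauss(0.0, 1.0) for _ in range(1000)]
    return sum(samples)

# Same seed -> identical result; different seed -> different result.
a = noisy_sum(seed=42)
b = noisy_sum(seed=42)
c = noisy_sum(seed=7)
print(a == b, a == c)  # True False
```

In real PyTorch code, torch.manual_seed() plays the role of the seed here, while torch.use_deterministic_algorithms(True) additionally forces nondeterministic GPU kernels onto deterministic code paths.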
Eager Execution
A programming model where operations execute immediately as they are called, enabling intuitive debugging but potentially sacrificing optimization opportunities.
torch.compile
PyTorch's graph capture and compilation system that automatically optimizes eager-mode code for significant performance improvements without code changes.
03 TensorFlow: Production and Scale
TensorFlow was designed from the ground up for production ML at Google scale. Its original design centered on static computational graphs that are defined once and then executed repeatedly, enabling aggressive compiler optimizations, distributed execution, and deployment across diverse hardware targets.
Static Computational Graph
A computation graph that is fully defined before any execution occurs. The entire program structure is known ahead of time, enabling the compiler to perform global optimizations such as operation fusion, constant folding, and optimal memory planning.
TensorFlow 2.0 and Keras
TensorFlow 2.0 adopted eager execution by default, matching PyTorch's developer experience. The tf.function decorator allows selectively tracing Python functions into optimized graphs, providing a middle ground between ease of use and performance. Keras, integrated as TensorFlow's high-level API, provides a user-friendly interface for common model architectures.
```python
import tensorflow as tf

# tf.function traces Python into an optimized graph
# (loss_fn and optimizer are assumed to be defined elsewhere)
@tf.function
def train_step(model, x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```

The @tf.function decorator provides graph-mode performance while writing eager-style code. However, it has limitations: Python side effects (print, file I/O) only execute during tracing, not on subsequent calls, and data-dependent Python control flow must use tf.cond and tf.while_loop for graph-compatible branching and looping.
The Deployment Ecosystem
TensorFlow's strongest advantage is its comprehensive deployment ecosystem. No other framework matches the breadth of deployment targets that TensorFlow supports.
| Component | Target | Use Case |
|---|---|---|
| TF Serving | Cloud / data center | Production model serving with versioning and A/B testing |
| TFLite | Mobile / embedded | On-device inference for Android, iOS, and microcontrollers |
| TensorFlow.js | Web browser | Client-side ML in JavaScript applications |
| TFX | MLOps pipeline | End-to-end ML pipeline orchestration and management |
| tf.distribute | Distributed training | Multi-GPU, multi-node, and TPU training strategies |
Table 7.2: TensorFlow's deployment ecosystem covers every major target platform.
Choose TensorFlow when your primary concern is deploying models to diverse targets (mobile, web, embedded, cloud). Its end-to-end pipeline tooling (TFX) and battle-tested serving infrastructure (TF Serving) are particularly valuable for large organizations with established production ML systems.
For large-scale training, TensorFlow offers robust distributed strategies through tf.distribute, supporting data parallelism, model parallelism, and parameter server architectures. Integration with TPUs provides access to Google's custom ML accelerators, which can offer significant cost and performance advantages for large training jobs.
While TensorFlow remains in production use, its share of new research papers has declined sharply since 2020 and continues to fall through 2026. Teams choosing TensorFlow should be aware that most cutting-edge models and techniques are released as PyTorch-first implementations, with delayed or community-maintained TensorFlow ports, if they exist at all.
Static Graph
A computational graph that is fully defined before execution, enabling compiler optimizations, distributed execution, and deployment to diverse targets.
TensorFlow Serving
A flexible, high-performance serving system for ML models designed for production environments, supporting model versioning and A/B testing.
04 JAX: Functional ML Computing
JAX brings a functional programming paradigm to ML computing, built on top of the XLA (Accelerated Linear Algebra) compiler. JAX programs are written as pure functions that transform data, which enables powerful composition of transformations like automatic differentiation, vectorization, parallelization, and JIT compilation.
Pure Function
A function that always produces the same output for the same input and has no side effects (no mutation of external state). Pure functions are the foundation of JAX's composable transformation model, enabling the compiler to reason about and aggressively optimize programs.
Composable Transformations
JAX's core transformations are composable and orthogonal. These transformations can be applied independently or chained together, enabling patterns that would be cumbersome in other frameworks.
| Transformation | Function | Purpose |
|---|---|---|
| jax.grad | Automatic differentiation | Computes gradients of scalar-valued functions |
| jax.vmap | Auto-vectorization | Maps a function over batch dimensions without explicit loops |
| jax.pmap | Parallelization | Distributes computation across multiple devices (GPUs/TPUs) |
| jax.jit | JIT compilation | Compiles functions via XLA for accelerated execution |
Table 7.3: JAX's four core function transformations.
```python
import jax
import jax.numpy as jnp

# Composable transformations: grad + jit + vmap
def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Compose: JIT-compile the gradient of the loss
fast_grad = jax.jit(jax.grad(loss_fn))

# Compose: vectorize over a batch of different param sets
batched_grad = jax.vmap(jax.grad(loss_fn), in_axes=(0, None, None))
```

Because JAX programs are pure functions without side effects, the XLA compiler can perform aggressive whole-program optimizations including operation fusion, memory planning, and automatic device placement. This often yields better performance than manually optimized code in eager-mode frameworks.
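The semantics of jax.vmap can be approximated in plain Python: conceptually, it maps a function over a chosen axis of one argument while holding the others fixed. A toy stdlib sketch (not how JAX implements it; real vmap rewrites the traced program into batched primitives instead of looping):

```python
# Toy vmap: map a function over axis 0 of its first argument.
# Conceptual sketch only; jax.vmap produces a single batched computation,
# not a Python loop.

def toy_vmap(fn):
    def batched(batch, *rest):
        return [fn(item, *rest) for item in batch]
    return batched

def loss(params, x, y):
    pred = sum(p * xi for p, xi in zip(params, x))
    return (pred - y) ** 2

x, y = [1.0, 2.0], 5.0
param_sets = [[1.0, 1.0], [1.0, 2.0], [2.0, 2.0]]

# Mirrors batched_grad's in_axes=(0, None, None): batch over params only.
batched_loss = toy_vmap(loss)
print(batched_loss(param_sets, x, y))  # [4.0, 0.0, 1.0]
```

The payoff in real JAX is that the loop disappears entirely: the compiler sees one batched dot product it can fuse and parallelize.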
The JAX Ecosystem
JAX's ecosystem is growing rapidly with libraries like Flax and Haiku for neural network modules, Optax for optimization, and JAX-based implementations of major model architectures. Google DeepMind has standardized on JAX for research, and it is increasingly used for large-scale training of foundation models.
- Flax — Google's neural network library for JAX with a focus on flexibility.
- Optax — Gradient processing and optimization library (SGD, Adam, etc.).
- Orbax — Checkpointing and serialization utilities for JAX models.
- Haiku — DeepMind's lightweight neural network library for JAX.
JAX's functional paradigm requires thinking differently about state management. Unlike PyTorch where model parameters are mutable object attributes, JAX requires explicitly passing parameters as function arguments. Developers accustomed to object-oriented frameworks often find this transition challenging, but the benefits in composability and compiler optimization are substantial.
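The functional state pattern can be illustrated without JAX at all. A stdlib sketch of explicit parameter passing, where an update returns new parameters instead of mutating an object (real JAX code would use pytrees of jnp arrays and jax.grad rather than this hand-written gradient):

```python
# Functional state management: parameters are passed in, and updates return
# NEW parameters rather than mutating in place. Sketch only.

def predict(params, x):
    return params["w"] * x + params["b"]

def sgd_step(params, x, y, lr=0.1):
    err = predict(params, x) - y
    grad = {"w": 2 * err * x, "b": 2 * err}   # d/dw, d/db of (pred - y)^2
    # Return a fresh params dict; the old one is untouched.
    return {k: params[k] - lr * grad[k] for k in params}

params = {"w": 0.0, "b": 0.0}
new_params = sgd_step(params, x=1.0, y=1.0)
print(params)      # {'w': 0.0, 'b': 0.0}  (unchanged)
print(new_params)  # {'w': 0.2, 'b': 0.2}
```

Because nothing is mutated, the same params can be reused, checkpointed, or fed through transformed versions of the step function, which is precisely what makes grad, jit, and vmap compose safely.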
JAX excels when you need to compose novel training algorithms, experiment with custom gradients, scale across large TPU pods, or need the compiler to aggressively optimize your code. It is particularly well-suited for research into new optimization methods and large-scale foundation model training.
XLA
Accelerated Linear Algebra, a domain-specific compiler that optimizes tensor computations for multiple hardware targets including GPUs and TPUs.
Function Transformations
JAX's composable operations (grad, vmap, pmap, jit) that transform pure functions to add capabilities like differentiation, vectorization, and parallelization.
05 Computational Graphs and Compilation
Computational graphs are the fundamental abstraction underlying all ML frameworks. They represent the sequence of mathematical operations that transform inputs into outputs, with nodes representing operations and edges representing data flow. The structure of this graph determines what optimizations are possible.
Computational Graph
A directed acyclic graph (DAG) where nodes represent mathematical operations (matmul, convolution, activation) and edges represent the flow of tensor data between operations. This graph is the foundation of automatic differentiation and compiler optimization in ML frameworks.
Static vs. Dynamic Graphs
Static graphs (TensorFlow 1.x, compiled modes) enable whole-program optimization: the compiler can see all operations before any execution occurs, allowing it to fuse operations, eliminate redundancy, optimize memory allocation, and plan device placement. The trade-off is reduced flexibility and harder debugging.
Dynamic graphs (PyTorch eager, TensorFlow eager) build the graph on-the-fly during execution, providing maximum flexibility for dynamic control flow and easy debugging. Modern frameworks bridge this gap through tracing-based compilation (torch.compile, tf.function) that captures dynamic graphs and compiles them for subsequent executions.
| Property | Static Graphs | Dynamic Graphs | Tracing-Based Compilation |
|---|---|---|---|
| Graph Construction | Define-then-run | Define-by-run | Trace-then-compile |
| Optimization Scope | Whole program | Per operation | Traced subgraphs |
| Debugging | Difficult (no Python access) | Easy (standard Python tools) | Moderate (debug eager, deploy compiled) |
| Dynamic Control Flow | Requires special ops | Native Python | Limited (graph breaks) |
| Examples | TF 1.x, JAX/XLA | PyTorch eager, TF eager | torch.compile, tf.function |
Table 7.4: Trade-offs between graph execution strategies.
Modern frameworks are converging on a hybrid approach: write code in eager mode for development and debugging, then use tracing-based compilation (torch.compile, tf.function, jax.jit) for production performance. This "eager-first, compile-later" pattern gives developers the best of both worlds.
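The trace-then-compile pattern can be caricatured in plain Python: run the function once under instrumented inputs to record the operations, then replay the recording on later calls. This is a loose conceptual analogue of tf.function and torch.compile (the names TracingNumber and trace_compile are invented for this sketch), and it also shows why Python side effects fire only during tracing:

```python
# Toy "trace then replay": record the ops a function performs on its first
# call, then replay the recorded sequence on later calls. A loose conceptual
# analogue of tf.function / torch.compile graph capture, nothing more.

class TracingNumber:
    """Stand-in tensor that records each op onto a shared tape."""
    def __init__(self, value, tape):
        self.value, self.tape = value, tape
    def __mul__(self, c):
        self.tape.append(("mul", c))
        return TracingNumber(self.value * c, self.tape)
    def __add__(self, c):
        self.tape.append(("add", c))
        return TracingNumber(self.value + c, self.tape)

def trace_compile(fn):
    tape = []
    def compiled(x):
        if not tape:  # first call: execute the Python body and record it
            return fn(TracingNumber(x, tape)).value
        out = x       # later calls: replay the tape, skipping Python entirely
        for op, c in tape:
            out = out * c if op == "mul" else out + c
        return out
    return compiled

@trace_compile
def f(x):
    print("tracing!")   # side effect: runs only while tracing
    return x * 3 + 1

print(f(2))  # prints "tracing!" then 7
print(f(5))  # prints just 16; the Python body (and its print) is skipped
```

The sketch also hints at the failure mode behind "graph breaks": if f branched on the value of x, the recorded tape would only be valid for inputs taking the traced branch, which is why real compilers must guard traces or fall back to eager execution.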
ML Compilers: The Emerging Layer
The trend toward ML compilers is reshaping the framework landscape. Projects like Apache TVM, MLIR, and Triton provide compiler infrastructure that can optimize ML workloads across diverse hardware. These compilers sit between frameworks and hardware, enabling a clean separation of concerns where framework developers focus on usability and compiler developers focus on performance.
- XLA (Accelerated Linear Algebra) — Google's ML compiler, used by JAX and TensorFlow for TPU and GPU targets.
- TorchInductor — PyTorch 2.0's default compiler backend, generating Triton kernels for GPU execution.
- Apache TVM — Open-source compiler stack for deploying ML to diverse hardware targets.
- MLIR (Multi-Level IR) — LLVM-based infrastructure for building domain-specific compilers, increasingly used for ML.
- Triton — OpenAI's language for writing custom GPU kernels in Python-like syntax.
When using torch.compile or tf.function, start by profiling your model to identify which operations dominate runtime. Graph compilation benefits operations with predictable shapes most. If your model has many dynamic shapes or Python-heavy control flow, you may see "graph breaks" that limit the speedup. Use framework profiling tools (torch.profiler, TensorBoard) to diagnose compilation issues.
Consider a sequence: y = ReLU(BatchNorm(Conv2d(x))). Without compilation, this requires three separate kernel launches and two intermediate memory writes. An ML compiler fuses these into a single kernel that reads x once, computes all three operations in registers, and writes the final y once. This eliminates two round-trips to GPU global memory, often yielding a 2-3x speedup for such sequences.
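The memory-traffic argument can be sketched in plain Python with elementwise stand-ins for the stages (a scale-and-shift for BatchNorm, max(0, .) for ReLU; real compilers fuse actual conv/BN/ReLU kernels):

```python
# Fusion sketch: the unfused version makes one pass over memory per stage
# and writes an intermediate buffer; the fused version makes a single pass.

def bn(xs, scale=2.0, shift=-3.0):
    return [scale * v + shift for v in xs]        # pass 1, writes a temporary

def relu(xs):
    return [max(0.0, v) for v in xs]              # pass 2, writes the output

def fused_bn_relu(xs, scale=2.0, shift=-3.0):
    # One pass: both stages computed per element, no intermediate buffer.
    return [max(0.0, scale * v + shift) for v in xs]

xs = [0.0, 1.0, 2.0, 3.0]
unfused = relu(bn(xs))
fused = fused_bn_relu(xs)
print(unfused)           # [0.0, 0.0, 1.0, 3.0]
print(fused == unfused)  # True, with one memory pass instead of two
```

On a GPU the "passes" are round-trips to global memory, which is why fusing a bandwidth-bound chain of elementwise ops recovers most of the speedup described above.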
Computational Graph
A directed graph representation of mathematical operations where nodes are operations and edges represent data flow, forming the basis of automatic differentiation and optimization.
Graph Compilation
The process of analyzing and optimizing a computational graph before execution, enabling operation fusion, memory optimization, and hardware-specific acceleration.
Key Takeaways
1. Framework choice has far-reaching implications for debugging, deployment, hardware support, and ecosystem access.
2. PyTorch dominates research with its flexible eager execution; TensorFlow leads in production deployment breadth.
3. JAX offers a unique functional approach with composable transformations but has a steeper learning curve.
4. The trend toward ML compilers (torch.compile, XLA, TVM) is bridging the gap between framework flexibility and hardware performance.
5. Computational graphs are the fundamental abstraction enabling automatic differentiation, optimization, and hardware-portable execution.