ML System · Part 2: System Design · Chapter 7 (CH.07, ~30 min)

AI Frameworks

Compares major ML frameworks including TensorFlow, PyTorch, and JAX. Analyzes the framework ecosystem, trade-offs, and systems-level implications.

Topics: TensorFlow, PyTorch, JAX, framework comparison, computational graphs

Also available at mlsysbook.ai.
  • Compare PyTorch, TensorFlow, and JAX across key dimensions like ease of use and performance
  • Explain the computational graph paradigms: eager execution vs. traced compilation
  • Evaluate framework selection criteria based on project requirements and team expertise
  • Describe the role of ONNX in framework interoperability and model portability
  • Analyze framework ecosystem maturity including tooling, community, and deployment support

01 The ML Framework Landscape

Machine learning frameworks provide the foundational software layer between model definitions and hardware execution. They abstract the complexity of tensor operations, automatic differentiation, GPU memory management, and distributed computation, allowing researchers and engineers to focus on model design rather than low-level implementation details.

Definition

ML Framework

A software library that provides high-level APIs for defining, training, and deploying neural networks, handling tensor operations, automatic differentiation, and hardware acceleration transparently.

The Big Three: PyTorch, TensorFlow, and JAX

The framework landscape has consolidated around three major players: PyTorch, TensorFlow, and JAX. Each embodies different design philosophies with distinct trade-offs. PyTorch prioritizes developer experience and flexibility. TensorFlow emphasizes production deployment and ecosystem completeness. JAX focuses on functional programming, composability, and high-performance computing.

Figure 7.1: The ML Framework Ecosystem. Positioning map of PyTorch, TensorFlow, JAX, Keras, ONNX Runtime, and Hugging Face along a research-focus-to-production-focus axis and a high-level-to-low-level axis; bubble size reflects relative community adoption.
Table 7.1: High-level comparison of the three dominant ML frameworks.

| Dimension | PyTorch | TensorFlow | JAX |
|---|---|---|---|
| Design Philosophy | Pythonic, eager-first | Production-first, graph-based | Functional, composable |
| Default Execution | Eager (dynamic graph) | Eager; tf.function for graph mode | Op-by-op by default; jax.jit compiles whole functions via XLA |
| Primary Strength | Research flexibility | Deployment ecosystem | Composable transformations |
| Developer | Meta (Facebook) | Google Brain | Google DeepMind |
| Community Size | Largest (research) | Large (industry) | Growing (research) |

Framework Consolidation

As of 2026, PyTorch dominates both research and increasingly production deployments, while TensorFlow retains legacy presence in large-scale serving. JAX has established itself as the framework of choice for large-scale training at Google DeepMind and beyond. Smaller frameworks like Caffe, Theano, and MXNet have been fully superseded.

Framework choice has profound implications for the entire ML stack. It influences debugging workflows, deployment options, hardware compatibility, community support, and the availability of pre-trained models and libraries. Organizations must consider not just the current feature set but the long-term trajectory and community momentum of each framework.

Specialized Frameworks and Runtimes

Beyond the big three, specialized frameworks serve specific domains. ONNX Runtime provides cross-framework inference optimization. TensorRT optimizes for NVIDIA GPU deployment. Core ML and TFLite Micro target mobile and embedded devices respectively. Understanding this ecosystem is essential for making informed infrastructure decisions.

  • ONNX Runtime — Cross-framework inference optimization via a standardized model format.
  • TensorRT — NVIDIA's high-performance inference optimizer for GPU deployment.
  • Core ML — Apple's framework for on-device ML on iOS, macOS, and Apple Silicon.
  • TFLite Micro — TensorFlow's runtime for microcontrollers and embedded devices.
  • OpenVINO — Intel's toolkit for optimizing inference on Intel hardware.
Choosing a Framework

If you are starting a new project, begin with PyTorch unless you have a specific reason not to. Its dominance in research means the largest ecosystem of pre-trained models, tutorials, and community support. Consider TensorFlow for production-heavy workflows with diverse deployment targets, and JAX for research into novel training algorithms or large-scale TPU training.

Deeper Insight

Rather than one framework winning outright, the ML community has converged on a division of labor. PyTorch dominates research and rapid prototyping thanks to its Pythonic API and dynamic graphs. TensorFlow retains its grip on production deployment with its unmatched breadth of serving targets (mobile, web, embedded, cloud). JAX has carved out the high-performance computing niche, favored by Google DeepMind and teams training the largest foundation models on TPU pods. Understanding these niches helps teams pick the right tool for their specific context.

ONNX

Open Neural Network Exchange, an open format for representing ML models that enables interoperability between different frameworks.

02 PyTorch: Flexibility and Research

PyTorch has become the dominant framework in ML research due to its intuitive Pythonic interface and dynamic computational graph (eager execution). Every operation executes immediately, making debugging as simple as inserting print statements or using standard Python debuggers. This define-by-run approach matches how researchers think about models.

Definition

Eager Execution

A programming model where operations execute immediately when called, building the computational graph dynamically at runtime. This allows standard Python debugging tools and enables architectures with data-dependent control flow.

Autograd: Automatic Differentiation

PyTorch's autograd system automatically records operations on tensors and computes gradients through reverse-mode automatic differentiation. The computational graph is rebuilt on every forward pass, enabling dynamic architectures that vary from input to input, such as recursive neural networks or models with data-dependent control flow.

python
import torch

# Autograd example: compute gradients automatically
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x ** 2 + 3 * x   # y = x^2 + 3x (elementwise)
loss = y.sum()
loss.backward()      # backpropagate: fills x.grad with d(loss)/dx
print(x.grad)        # tensor([7., 9.])  since d/dx = 2x + 3
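Because autograd rebuilds the graph on every forward pass, ordinary Python control flow that depends on tensor values is differentiated correctly. A minimal sketch, using a hypothetical function `repeated_square` whose loop count depends on its input:

```python
import torch

def repeated_square(x: torch.Tensor, threshold: float = 10.0) -> torch.Tensor:
    # The loop count depends on the *values* in x, so each call can record
    # a different chain of operations for autograd to differentiate.
    while x.abs().sum() < threshold:
        x = x * x + 1.0
    return x

a = torch.tensor([1.0], requires_grad=True)  # loops 3 times: 1 -> 2 -> 5 -> 26
b = torch.tensor([3.0], requires_grad=True)  # loops once: 3 -> 10

out_a = repeated_square(a)
out_b = repeated_square(b)
out_a.backward()
out_b.backward()
print(a.grad, b.grad)  # tensor([80.]) tensor([6.])
```

Input 1.0 records three square-and-add steps while input 3.0 records only one, yet `backward()` handles both, because each call built its own graph at run time.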
Debugging with Eager Execution

One of PyTorch's greatest strengths is that you can insert standard Python breakpoints (import pdb; pdb.set_trace()) or print statements anywhere in your model. Because operations execute immediately, you can inspect intermediate tensor values, shapes, and gradients during debugging — just like debugging any Python program.
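Besides breakpoints and print statements, forward hooks registered with `register_forward_hook` let you inspect intermediate activations without editing a model's `forward` method. A minimal sketch on a small, hypothetical `nn.Sequential` model; `record_shape` and `captured` are illustrative names:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured = {}

def record_shape(name):
    # A forward hook is called as hook(module, inputs, output) after the module runs.
    def hook(module, inputs, output):
        captured[name] = tuple(output.shape)
        print(f"{name}: shape={tuple(output.shape)}, mean={output.mean().item():.3f}")
    return hook

# Attach a hook to every child layer; keep the handles so hooks can be removed.
handles = [layer.register_forward_hook(record_shape(f"layer{i}"))
           for i, layer in enumerate(model)]

out = model(torch.randn(3, 4))  # hooks fire eagerly as each layer executes

for h in handles:
    h.remove()  # detach the hooks once debugging is done
```

Because execution is eager, each hook sees real tensors with concrete values at the moment the layer runs.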

From Research to Production: torch.compile

For production deployment, PyTorch has offered TorchScript (a subset of Python that can be JIT-compiled) and now provides torch.export for capturing a whole-program graph ahead of time. PyTorch 2.0 introduced torch.compile, which captures graphs from ordinary Python code and compiles them, often yielding significant speedups with little or no code change. This bridges the gap between PyTorch's research flexibility and production performance.

torch.compile in Action

Adding a single line, model = torch.compile(model), can speed up training by 30-50% on supported hardware for many models, although the gain depends on the model and hardware. Under the hood, TorchDynamo captures the computational graph from Python bytecode, TorchInductor optimizes it, and efficient kernel code is generated (Triton kernels on GPUs). In the common case, no changes to the model definition or training loop are needed.
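A minimal sketch of the one-line change, assuming a recent PyTorch 2.x release. It passes `backend="eager"` so that only TorchDynamo graph capture runs and the example works without a C++ or Triton toolchain; that backend gives no speedup by itself, and in practice you would omit the argument to get the default TorchInductor backend.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 4))

# backend="eager" exercises graph capture only (useful for debugging);
# drop it to compile with the default TorchInductor backend.
compiled = torch.compile(model, backend="eager")

x = torch.randn(8, 16)
with torch.no_grad():
    eager_out = model(x)
    compiled_out = compiled(x)  # first call triggers graph capture

print(torch.allclose(eager_out, compiled_out))  # True
```

The compiled wrapper shares parameters with the original module, so training code can call either one; only the first call pays the capture and compilation cost.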

The PyTorch Ecosystem

The PyTorch ecosystem includes Hugging Face Transformers for pre-trained models, PyTorch Lightning for training boilerplate reduction, and torchvision/torchaudio/torchtext for domain-specific functionality. This rich ecosystem, combined with strong community momentum, makes PyTorch the default choice for most new ML projects.

  • Hugging Face Transformers — Pre-trained models and fine-tuning pipelines for NLP, vision, and multimodal tasks.
  • PyTorch Lightning — Reduces training boilerplate and standardizes project structure.
  • torchvision / torchaudio / torchtext — Domain-specific datasets, transforms, and model architectures.
  • FSDP (Fully Sharded Data Parallel) — Native distributed training with memory-efficient parameter sharding.
Dynamic Graphs and Reproducibility

Bit-for-bit reproducibility in PyTorch depends on more than the model code: random initialization, data ordering, and some GPU kernels that use nondeterministic algorithms can all vary between runs, and with data-dependent control flow a small numerical difference can even change which operations are recorded in the dynamic graph. Use torch.manual_seed() and torch.use_deterministic_algorithms(True) when exact reproducibility is required.
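A minimal sketch of the two reproducibility switches, using a hypothetical helper `seeded_run`. It runs on the CPU; on CUDA, some operations additionally require the `CUBLAS_WORKSPACE_CONFIG` environment variable to be set before deterministic mode can be enforced.

```python
import torch

# Global switch: raise an error if an op has no deterministic implementation.
torch.use_deterministic_algorithms(True)

def seeded_run(seed: int) -> torch.Tensor:
    torch.manual_seed(seed)        # seeds the random number generators on all devices
    layer = torch.nn.Linear(4, 2)  # weight initialization draws from the seeded RNG
    x = torch.randn(3, 4)          # so does the synthetic input
    return layer(x)

first = seeded_run(0)
second = seeded_run(0)
print(torch.equal(first, second))  # True
```

Seeding fixes the random draws; the deterministic-algorithms flag covers the remaining source of run-to-run variation, kernels whose results depend on execution order.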

Eager Execution

A programming model where operations execute immediately as they are called, enabling intuitive debugging but potentially sacrificing optimization opportunities.

torch.compile

PyTorch's graph capture and compilation system that automatically optimizes eager-mode code for significant performance improvements without code changes.

03 TensorFlow: Production and Scale

TensorFlow was designed from the ground up for production ML at Google scale. Its original design centered on static computational graphs that are defined once and then executed repeatedly, enabling aggressive compiler optimizations, distributed execution, and deployment across diverse hardware targets.

Definition

Static Computational Graph

A computation graph that is fully defined before any execution occurs. The entire program structure is known ahead of time, enabling the compiler to perform global optimizations such as operation fusion, constant folding, and optimal memory planning.
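To make "defined before any execution" concrete, here is a toy define-then-run graph in plain Python. It is a hypothetical sketch, not TensorFlow's graph format: `Node`, `fold_constants`, and `run` are illustrative names. Because the whole graph exists before `run()` is called, a compile-time pass can precompute every subexpression that depends only on constants (constant folding).

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

@dataclass
class Node:
    op: str                                   # "const", "input", "add", or "mul"
    inputs: List["Node"] = field(default_factory=list)
    value: Optional[float] = None             # used by "const" nodes
    name: str = ""                            # used by "input" nodes

def apply(op: str, a: float, b: float) -> float:
    return a + b if op == "add" else a * b

def fold_constants(node: Node) -> Node:
    """Compile-time pass: collapse subgraphs whose inputs are all constants."""
    if node.op in ("const", "input"):
        return node
    args = [fold_constants(n) for n in node.inputs]
    if all(a.op == "const" for a in args):
        return Node("const", value=apply(node.op, args[0].value, args[1].value))
    return Node(node.op, args)

def run(node: Node, feeds: Dict[str, float]) -> float:
    """Run time: evaluate the already-defined graph with concrete inputs."""
    if node.op == "const":
        return node.value
    if node.op == "input":
        return feeds[node.name]
    a, b = (run(n, feeds) for n in node.inputs)
    return apply(node.op, a, b)

def count_ops(node: Node) -> int:
    if node.op in ("const", "input"):
        return 0
    return 1 + sum(count_ops(n) for n in node.inputs)

def const(v: float) -> Node:
    return Node("const", value=v)

# Define the whole program before running it: y = x * (2 * 3) + (4 + 1)
x = Node("input", name="x")
graph = Node("add", [Node("mul", [x, Node("mul", [const(2.0), const(3.0)])]),
                     Node("add", [const(4.0), const(1.0)])])

optimized = fold_constants(graph)                          # y = x * 6 + 5
print(count_ops(graph), count_ops(optimized))              # 4 2
print(run(graph, {"x": 2.0}), run(optimized, {"x": 2.0}))  # 17.0 17.0
```

An eager framework never holds the full program at once, so this kind of whole-graph rewrite is only possible after the graph has been captured.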

TensorFlow 2.0 and Keras

TensorFlow 2.0 adopted eager execution by default, matching PyTorch's developer experience. The tf.function decorator allows selectively tracing Python functions into optimized graphs, providing a middle ground between ease of use and performance. Keras, integrated as TensorFlow's high-level API, provides a user-friendly interface for common model architectures.

python
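import tensorflow as tf

# Minimal setup so the step runs: a tiny Keras model, loss, and optimizer
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# tf.function traces this Python function into an optimized graph
@tf.function
def train_step(model, x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

loss = train_step(model, tf.ones((4, 3)), tf.zeros((4, 1)))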
The tf.function Trade-off

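The @tf.function decorator provides graph-mode performance while you write eager-style code, but tracing changes some semantics. Python side effects (print, appending to a list, file I/O) run only while the function is being traced, not on later calls that reuse the graph; use tf.print for output that should appear on every call. Python control flow that depends on tensor values is converted by AutoGraph into tf.cond and tf.while_loop, while control flow on plain Python values is evaluated once at trace time and baked into the graph. Calling the function with new input shapes or dtypes triggers a retrace.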
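The trace-time behavior of `@tf.function` can be observed directly. A minimal sketch, assuming TensorFlow 2.x with AutoGraph enabled (the default); `scale` and `traces` are hypothetical names. The Python list is appended to only while tracing, `tf.print` runs on every call, and the tensor-dependent `if` is rewritten into a graph conditional.

```python
import tensorflow as tf

traces = []  # Python list: appended to only while tf.function is tracing

@tf.function
def scale(x):
    traces.append(x.shape)                 # Python side effect: trace time only
    tf.print("graph op: runs on every call")
    # Tensor-dependent `if`: AutoGraph converts this into tf.cond
    if tf.reduce_sum(x) > 0:
        y = x * 2.0
    else:
        y = -x
    return y

scale(tf.constant([1.0, 2.0]))         # first call with shape (2,): traces
scale(tf.constant([-3.0, -4.0]))       # same shape and dtype: reuses the graph
scale(tf.constant([1.0, 2.0, 3.0]))    # new shape (3,): traces again
print(len(traces))  # 2
```

Three calls produce two traces, which is why retracing on varying shapes is a common hidden cost in tf.function code.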

The Deployment Ecosystem

TensorFlow's strongest advantage is its comprehensive deployment ecosystem. No other framework matches the breadth of deployment targets that TensorFlow supports.

Table 7.2: TensorFlow's deployment ecosystem covers every major target platform.

| Component | Target | Use Case |
|---|---|---|
| TF Serving | Cloud / data center | Production model serving with versioning and A/B testing |
| TFLite | Mobile / embedded | On-device inference for Android, iOS, and microcontrollers |
| TensorFlow.js | Web browser | Client-side ML in JavaScript applications |
| TFX | MLOps pipeline | End-to-end ML pipeline orchestration and management |
| tf.distribute | Distributed training | Multi-GPU, multi-node, and TPU training strategies |

When TensorFlow Shines

Choose TensorFlow when your primary concern is deploying models to diverse targets (mobile, web, embedded, cloud). Its end-to-end pipeline tooling (TFX) and battle-tested serving infrastructure (TF Serving) are particularly valuable for large organizations with established production ML systems.

For large-scale training, TensorFlow offers robust distributed strategies through tf.distribute, supporting data parallelism, model parallelism, and parameter server architectures. Integration with TPUs provides access to Google's custom ML accelerators, which can offer significant cost and performance advantages for large training jobs.
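A minimal sketch of synchronous data-parallel training with `tf.distribute.MirroredStrategy`, assuming TensorFlow 2.x with `tf.keras`. The model size and synthetic data are placeholders. On a machine without GPUs the strategy runs a single CPU replica, so the same code runs unchanged from a laptop to a multi-GPU node.

```python
import tensorflow as tf

# MirroredStrategy replicates the model on every visible GPU of one machine
# (falling back to the CPU if none is found) and all-reduces the gradients.
strategy = tf.distribute.MirroredStrategy()
print("replicas:", strategy.num_replicas_in_sync)

# Variables (model weights, optimizer state) must be created inside scope().
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")

# Synthetic data; Keras splits each global batch of 16 across the replicas.
x = tf.random.normal((64, 8))
y = tf.random.normal((64, 1))
history = model.fit(x, y, batch_size=16, epochs=1, verbose=0)
```

Multi-node and TPU training follow the same pattern with a different strategy object (for example `MultiWorkerMirroredStrategy` or `TPUStrategy`), which is the main design appeal of the tf.distribute API.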

TensorFlow's Community Trajectory

While TensorFlow remains widely used in production, its share of new research papers has declined sharply since 2020 and continued to fall through 2026. Teams choosing TensorFlow should be aware that most cutting-edge models and techniques are released as PyTorch-first implementations, with delayed or community-maintained TensorFlow ports, if such ports exist at all.

Static Graph

A computational graph that is fully defined before execution, enabling compiler optimizations, distributed execution, and deployment to diverse targets.

TensorFlow Serving

A flexible, high-performance serving system for ML models designed for production environments, supporting model versioning and A/B testing.

04 JAX: Functional ML Computing

JAX brings a functional programming paradigm to ML computing and is built on top of the XLA (Accelerated Linear Algebra) compiler. JAX programs are written as pure functions that transform data, which enables powerful composition of transformations such as automatic differentiation, vectorization, parallelization, and JIT compilation.

Definition

Pure Function

A function that always produces the same output for the same input and has no side effects (no mutation of external state). Pure functions are the foundation of JAX's composable transformation model, enabling the compiler to reason about and aggressively optimize programs.
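The practical consequence of purity shows up under `jax.jit`: JAX traces a function once per input shape and dtype and then reuses the compiled version, so side effects inside it run only during tracing. A minimal sketch; `counter`, `impure_scale`, and `pure_scale` are hypothetical names.

```python
import jax
import jax.numpy as jnp

counter = {"calls": 0}  # external mutable state: this is what makes a function impure

def impure_scale(x):
    counter["calls"] += 1  # side effect: JAX only executes this while tracing
    return x * 2.0

def pure_scale(x, factor):
    # Output depends only on the arguments; nothing outside is read or mutated.
    return x * factor

jit_impure = jax.jit(impure_scale)
jit_impure(jnp.ones(3))
jit_impure(jnp.ones(3))  # cached compiled version: the side effect is skipped
print(counter["calls"])  # 1

jit_pure = jax.jit(pure_scale)
print(jit_pure(jnp.ones(3), 3.0))
```

The impure version silently undercounts, while the pure version behaves identically whether or not it is compiled, which is the property JAX's transformations rely on.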

Composable Transformations

JAX's core transformations are composable and orthogonal. These transformations can be applied independently or chained together, enabling patterns that would be cumbersome in other frameworks.

Table 7.3: JAX's four core function transformations.

| Transformation | Function | Purpose |
|---|---|---|
| jax.grad | Automatic differentiation | Computes gradients of scalar-valued functions |
| jax.vmap | Auto-vectorization | Maps a function over batch dimensions without explicit loops |
| jax.pmap | Parallelization | Distributes computation across multiple devices (GPUs/TPUs) |
| jax.jit | JIT compilation | Compiles functions via XLA for accelerated execution |

python
import jax
import jax.numpy as jnp

# Composable transformations: grad + jit + vmap
def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Compose: JIT-compile the gradient of the loss (w.r.t. params, argument 0)
fast_grad = jax.jit(jax.grad(loss_fn))

# Compose: vectorize over a batch of different param sets (axis 0 of params),
# sharing the same x and y across the batch
batched_grad = jax.vmap(jax.grad(loss_fn), in_axes=(0, None, None))

x = jnp.ones((4, 3))
y = jnp.zeros(4)
print(fast_grad(jnp.ones(3), x, y).shape)          # (3,)
print(batched_grad(jnp.ones((5, 3)), x, y).shape)  # (5, 3)
XLA Under the Hood

Because JAX programs are pure functions without side effects, the XLA compiler can perform aggressive whole-program optimizations, including operation fusion, memory planning, and automatic device placement. The result can match or exceed the performance of hand-optimized code in eager-mode frameworks.
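The compiler-facing view of a JAX program can be inspected directly. A minimal sketch, assuming a recent JAX release: `jax.make_jaxpr` prints the pure intermediate program (a jaxpr) that JAX traces from Python, and `jax.jit(...).lower(...).as_text()` shows the module handed to XLA (StableHLO text in recent versions). `scaled_tanh` is an arbitrary example function.

```python
import jax
import jax.numpy as jnp

def scaled_tanh(x):
    # Three elementwise ops (mul, tanh, add): a natural fusion candidate for XLA
    return jnp.tanh(x * 0.5) + 1.0

# The jaxpr: the side-effect-free program JAX traced from the Python code
print(jax.make_jaxpr(scaled_tanh)(jnp.ones(4)))

# The lowered module that XLA will compile for the current backend
lowered = jax.jit(scaled_tanh).lower(jnp.ones(4))
print(lowered.as_text()[:400])
```

Calling `lowered.compile().as_text()` additionally shows the optimized HLO after XLA's passes, where fused elementwise ops appear as fusion computations; that output varies by backend and version.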

The JAX Ecosystem

JAX's ecosystem is growing rapidly with libraries like Flax and Haiku for neural network modules, Optax for optimization, and JAX-based implementations of major model architectures. Google DeepMind has standardized on JAX for research, and it is increasingly used for large-scale training of foundation models.

  • Flax — Google's neural network library for JAX with a focus on flexibility.
  • Optax — Gradient processing and optimization library (SGD, Adam, etc.).
  • Orbax — Checkpointing and serialization utilities for JAX models.
  • Haiku — DeepMind's lightweight neural network library for JAX.
Steeper Learning Curve

JAX's functional paradigm requires thinking differently about state management. Unlike PyTorch where model parameters are mutable object attributes, JAX requires explicitly passing parameters as function arguments. Developers accustomed to object-oriented frameworks often find this transition challenging, but the benefits in composability and compiler optimization are substantial.
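A minimal sketch of that explicit-state style, with a hypothetical linear model and learning rate: parameters live in a plain dict (a pytree), the training step takes them as an argument, and it returns new parameters instead of mutating the old ones.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameter

# Parameters are an explicit pytree (a dict of arrays), not attributes of an object.
params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}

def predict(params, x):
    return x @ params["w"] + params["b"]

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)  # a pytree with the same structure as params
    # Build and return new parameters; the old arrays are never mutated.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

x = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
y = jnp.array([1.0, 2.0])
for _ in range(200):
    params = sgd_step(params, x, y)  # state is threaded through explicitly

print(loss(params, x, y))
```

Libraries such as Flax and Optax wrap this pattern, but the underlying contract stays the same: state goes in as arguments and comes back out as return values.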

When to Choose JAX

JAX excels when you need to compose novel training algorithms, experiment with custom gradients, scale across large TPU pods, or need the compiler to aggressively optimize your code. It is particularly well-suited for research into new optimization methods and large-scale foundation model training.
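Custom gradients are a good example of this composability. A minimal sketch following the gradient-clipping pattern from JAX's custom-derivatives documentation: `jax.custom_vjp` keeps the forward pass as the identity but clips the incoming gradient during the backward pass. The bounds and the factor 5.0 are arbitrary illustration values.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x  # forward pass: identity

def clip_gradient_fwd(lo, hi, x):
    return x, None  # nothing needs to be saved for the backward pass

def clip_gradient_bwd(lo, hi, _residuals, g):
    # Backward pass: clip the incoming cotangent instead of passing it through.
    return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def f(x):
    return 5.0 * clip_gradient(-1.0, 1.0, x)

print(jax.grad(f)(2.0))                  # 1.0 (clipped from 5.0)
print(jax.grad(lambda x: 5.0 * x)(2.0))  # 5.0 without clipping
```

Because `clip_gradient` is an ordinary function to the rest of JAX, it composes with `jax.jit`, `jax.vmap`, and higher-level training code without special handling.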

XLA

Accelerated Linear Algebra, a domain-specific compiler that optimizes tensor computations for multiple hardware targets including GPUs and TPUs.

Function Transformations

JAX's composable operations (grad, vmap, pmap, jit) that transform pure functions to add capabilities like differentiation, vectorization, and parallelization.

05 Computational Graphs and Compilation

Computational graphs are the fundamental abstraction underlying all ML frameworks. They represent the sequence of mathematical operations that transform inputs into outputs, with nodes representing operations and edges representing data flow. The structure of this graph determines what optimizations are possible.

Definition

Computational Graph

A directed acyclic graph (DAG) where nodes represent mathematical operations (matmul, convolution, activation) and edges represent the flow of tensor data between operations. This graph is the foundation of automatic differentiation and compiler optimization in ML frameworks.
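The DAG view is exactly what reverse-mode automatic differentiation walks. A toy scalar version in plain Python (the `Value` class is hypothetical and far simpler than any framework's autograd): each node stores its parents and the local derivative with respect to each, and `backward()` visits the graph in reverse topological order, applying the chain rule.

```python
import math

class Value:
    """A scalar node in a computational graph (toy sketch, not a framework API)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents          # incoming edges of the DAG
        self.local_grads = local_grads  # d(self)/d(parent) for each parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def tanh(self):
        t = math.tanh(self.data)
        return Value(t, (self,), (1.0 - t * t,))

    def backward(self):
        # Topologically order the DAG (inputs before outputs) ...
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for p in node.parents:
                    visit(p)
                order.append(node)

        visit(self)
        # ... then push gradients from the output back toward the inputs.
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in zip(node.parents, node.local_grads):
                parent.grad += local * node.grad

# y = tanh(w * x + b): nodes are ops (mul, add, tanh), edges carry values
w, x, b = Value(2.0), Value(3.0), Value(-6.0)
y = (w * x + b).tanh()
y.backward()
print(y.data, w.grad, x.grad, b.grad)  # 0.0 3.0 2.0 1.0
```

Gradients accumulate with `+=` because a node that feeds several consumers (a shared edge in the DAG) receives a contribution from each of them.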

Static vs. Dynamic Graphs

Static graphs (TensorFlow 1.x, compiled modes) enable whole-program optimization: the compiler can see all operations before any execution occurs, allowing it to fuse operations, eliminate redundancy, optimize memory allocation, and plan device placement. The trade-off is reduced flexibility and harder debugging.

Dynamic graphs (PyTorch eager, TensorFlow eager) build the graph on-the-fly during execution, providing maximum flexibility for dynamic control flow and easy debugging. Modern frameworks bridge this gap through tracing-based compilation (torch.compile, tf.function) that captures dynamic graphs and compiles them for subsequent executions.

Table 7.4: Trade-offs between graph execution strategies.

| Property | Static Graphs | Dynamic Graphs | Tracing-Based Compilation |
|---|---|---|---|
| Graph Construction | Define-then-run | Define-by-run | Trace-then-compile |
| Optimization Scope | Whole program | Per operation | Traced subgraphs |
| Debugging | Difficult (no Python access) | Easy (standard Python tools) | Moderate (debug eager, deploy compiled) |
| Dynamic Control Flow | Requires special ops | Native Python | Limited (graph breaks) |
| Examples | TF 1.x | PyTorch eager, TF eager | torch.compile, tf.function, jax.jit |
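The limited support for dynamic control flow under tracing can be illustrated with a toy tracer in plain Python. This is a hypothetical sketch, not how TorchDynamo, tf.function, or jax.jit are implemented: a `Tracer` object records arithmetic into a list of operations instead of computing values, so it cannot answer a Python `if` on a traced value. At that point a real tracing compiler must split the graph (a torch.compile graph break), rewrite the branch into a graph op (AutoGraph's tf.cond), or raise an error (jax.jit without lax.cond).

```python
class GraphBreak(Exception):
    """Raised when Python needs a concrete value the tracer does not have."""

class Tracer:
    def __init__(self, name, graph):
        self.name, self.graph = name, graph

    def _emit(self, op, other):
        out = Tracer(f"t{len(self.graph)}", self.graph)
        rhs = other.name if isinstance(other, Tracer) else repr(other)
        self.graph.append(f"{out.name} = {op}({self.name}, {rhs})")
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)

    def __bool__(self):
        # A data-dependent branch: the value is unknown at trace time.
        raise GraphBreak(f"cannot branch on traced value {self.name}")

def trace(fn):
    graph = []
    result = fn(Tracer("x", graph))
    graph.append(f"return {result.name}")
    return graph

def straight_line(x):
    return x * 2.0 + 1.0

def data_dependent(x):
    y = x * 2.0
    if y:  # Python `if` on a traced value: needs a concrete bool
        return y + 1.0
    return y

print("\n".join(trace(straight_line)))
# t0 = mul(x, 2.0)
# t1 = add(t0, 1.0)
# return t1
try:
    trace(data_dependent)
except GraphBreak as err:
    print("graph break:", err)
```

Straight-line code traces into a complete graph, while the branch forces the tracer to stop, which is why models with heavy Python control flow see smaller speedups from compilation.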

Quick Check

What is the primary advantage of dynamic computational graphs (eager execution)?

Answer: Dynamic computational graphs execute operations immediately as they are called, which means you can use standard Python debugging tools (print statements, breakpoints, pdb) to inspect intermediate values at any point. This makes prototyping and debugging dramatically easier compared to static graphs, where the graph must be fully defined before any execution occurs and debugging requires specialized tools.
The Convergence Trend

Modern frameworks are converging on a hybrid approach: write code in eager mode for development and debugging, then use tracing-based compilation (torch.compile, tf.function, jax.jit) for production performance. This "eager-first, compile-later" pattern gives developers the best of both worlds.

ML Compilers: The Emerging Layer

The trend toward ML compilers is reshaping the framework landscape. Projects like Apache TVM, MLIR, and Triton provide compiler infrastructure that can optimize ML workloads across diverse hardware. These compilers sit between frameworks and hardware, enabling a clean separation of concerns where framework developers focus on usability and compiler developers focus on performance.

  • XLA (Accelerated Linear Algebra) — Google's ML compiler, used by JAX and TensorFlow for TPU and GPU targets.
  • TorchInductor — PyTorch 2.0's default compiler backend, generating Triton kernels for GPU execution.
  • Apache TVM — Open-source compiler stack for deploying ML to diverse hardware targets.
  • MLIR (Multi-Level IR) — LLVM-based infrastructure for building domain-specific compilers, increasingly used for ML.
  • Triton — OpenAI's language for writing custom GPU kernels in Python-like syntax.
Graph Compilation in Practice

When using torch.compile or tf.function, start by profiling your model to identify which operations dominate runtime. Graph compilation benefits operations with predictable shapes most. If your model has many dynamic shapes or Python-heavy control flow, you may see "graph breaks" that limit the speedup. Use framework profiling tools (torch.profiler, TensorBoard) to diagnose compilation issues.
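A minimal profiling sketch, assuming a recent PyTorch release with the `torch.profiler` module. The model and input sizes are placeholders; on a GPU machine you would add `ProfilerActivity.CUDA` to the activities list and move the model and data to the device.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 10))
x = torch.randn(64, 256)

# Record operator-level timings for a few forward passes on the CPU.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.no_grad():
        for _ in range(5):
            model(x)

# Aggregate by operator name and list the most expensive operators first.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```

Profiling the eager model first shows which operators dominate, and therefore where compilation or fusion is likely to pay off.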

Operator Fusion by Compilers

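Consider the sequence y = ReLU(BatchNorm(Conv2d(x))). Without compilation, this requires three separate kernel launches and two intermediate tensors that are written to and then read back from GPU global memory. An ML compiler can fuse the sequence into a single kernel that reads x once, applies the normalization and activation while the convolution results are still in registers, and writes the final y once (at inference time, BatchNorm is commonly folded into the convolution weights and ReLU runs as the convolution kernel's epilogue). Eliminating these two round-trips to global memory can speed up such memory-bound sequences by roughly 2-3x, depending on tensor sizes and hardware.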
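The memory-traffic argument can be checked with back-of-the-envelope arithmetic. This sketch counts only activation reads and writes to global memory, assumes float32 tensors of a hypothetical shape 32 x 64 x 56 x 56, assumes the convolution preserves that shape, and ignores weights and caches. If the sequence is memory-bound, runtime scales roughly with this traffic.

```python
def activation_bytes(n: int, c: int, h: int, w: int, dtype_bytes: int = 4) -> int:
    """Bytes in one N x C x H x W activation tensor (float32 by default)."""
    return n * c * h * w * dtype_bytes

# Hypothetical shape; assume Conv2d preserves it, so x, both intermediates,
# and y all have the same size S.
S = activation_bytes(32, 64, 56, 56)

# Unfused: Conv2d, BatchNorm, and ReLU each read their input from global
# memory and write their output back (weights are ignored here).
unfused = 3 * (S + S)
# Fused: a single kernel reads x once and writes y once.
fused = S + S

MiB = 2**20
print(f"S = {S / MiB:.1f} MiB")                                    # S = 24.5 MiB
print(f"unfused = {unfused / MiB:.1f} MiB, fused = {fused / MiB:.1f} MiB")
print(f"traffic reduction = {unfused / fused:.1f}x")               # 3.0x
```

A 3x reduction in traffic is an upper bound on the speedup for this idealized case; real kernels also spend time on compute, launch overhead, and weight reads.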

Computational Graph

A directed graph representation of mathematical operations where nodes are operations and edges represent data flow, forming the basis of automatic differentiation and optimization.

Graph Compilation

The process of analyzing and optimizing a computational graph before execution, enabling operation fusion, memory optimization, and hardware-specific acceleration.

Key Takeaways

  1. Framework choice has far-reaching implications for debugging, deployment, hardware support, and ecosystem access.
  2. PyTorch dominates research with its flexible eager execution; TensorFlow leads in production deployment breadth.
  3. JAX offers a unique functional approach with composable transformations but has a steeper learning curve.
  4. The trend toward ML compilers (torch.compile, XLA, TVM) is bridging the gap between framework flexibility and hardware performance.
  5. Computational graphs are the fundamental abstraction enabling automatic differentiation, optimization, and hardware-portable execution.
