PyTorch Internals

reading ezyang’s blog post and source diving. this document is going to be long.

1. tensor ≠ data, tensor is a view

the most important thing to understand before anything else: a Tensor is not the data. it’s a description of how to read some data that lives somewhere else.

physically, RAM is 1D. always. a “matrix” doesn’t exist in memory — what exists is a flat array of bytes, and the tensor is a struct that tells you how to interpret that array as if it were 2D, 3D, whatever shape you want.

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

what actually lives in memory:

addr:   0x00  0x04  0x08  0x0C
value: [ 1.0 | 2.0 | 3.0 | 4.0 ]

that’s it. the “2×2 matrix” is a story being told by TensorImpl. the actual C++ object hierarchy:

Python: x
  └─► intrusive_ptr<TensorImpl>        (c10/core/TensorImpl.h)
        ├── sizes_:    [2, 2]
        ├── strides_:  [2, 1]
        ├── storage_offset_: 0
        ├── dtype_:    float32
        ├── key_set_:  {CPU, AutogradCPU, ...}  ← DispatchKeySet bitfield
        └── storage_
              └─► StorageImpl
                    ├── data_ptr_: 0x7f8abc...
                    ├── allocator: DefaultCPUAllocator
                    └── nbytes_:   16   (4 × sizeof(float))

multiple TensorImpls can point to the same StorageImpl. memory doesn’t duplicate. only metadata changes. this is the entire reason views exist and are O(1).

graph TD
    A["Tensor A<br/>shape [3,4]<br/>stride [4,1]<br/>offset 0"] --> S["StorageImpl<br/>data_ptr: 0x7f8...<br/>12 floats"]
    B["Tensor B = A.T<br/>shape [4,3]<br/>stride [1,4]<br/>offset 0"] --> S
    C["Tensor C = A[1,:]<br/>shape [4]<br/>stride [1]<br/>offset 4"] --> S

all three tensors share the same raw memory. zero copies.
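a minimal pure-Python sketch of that split, assuming a plain list can stand in for the StorageImpl buffer. `Storage` and `View` are hypothetical stand-ins, not PyTorch classes; they only show that A, A.T and A[1,:] are three metadata records over one buffer, so a write through one is visible through the others.

```python
class Storage:
    """Stand-in for StorageImpl: one flat 1D buffer."""
    def __init__(self, data):
        self.data = list(data)


class View:
    """Stand-in for TensorImpl: sizes, strides and offset into a shared Storage."""
    def __init__(self, storage, sizes, strides, offset=0):
        self.storage = storage
        self.sizes = tuple(sizes)
        self.strides = tuple(strides)
        self.offset = offset

    def _index(self, idx):
        if not isinstance(idx, tuple):       # allow C[2] as well as A[1, 2]
            idx = (idx,)
        # physical index = offset + sum(i_k * s_k)
        return self.offset + sum(i * s for i, s in zip(idx, self.strides))

    def __getitem__(self, idx):
        return self.storage.data[self._index(idx)]

    def __setitem__(self, idx, value):
        self.storage.data[self._index(idx)] = value


buf = Storage(float(v) for v in range(12))           # 12 floats, like a [3, 4] tensor
A = View(buf, sizes=(3, 4), strides=(4, 1))
B = View(buf, sizes=(4, 3), strides=(1, 4))           # A.T: strides swapped
C = View(buf, sizes=(4,), strides=(1,), offset=4)     # A[1, :]: row 1

A[1, 2] = 100.0          # write through A ...
print(B[2, 1], C[2])     # ... visible through B and C: 100.0 100.0
```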


2. strides — the math that makes zero-copy possible

for any N-dim tensor, the physical memory index of element $[i_0, i_1, \ldots, i_{N-1}]$ is:

\[\mathrm{physical\_idx} = \mathrm{offset} + \sum_{k=0}^{N-1} i_k \cdot s_k\]

where $s_k$ is the stride for dimension $k$.

for a contiguous row-major tensor of shape $[R, C]$, strides are $[C, 1]$, giving:

\[\mathrm{physical\_idx} = \text{offset} + \text{row} \cdot C + \text{col} \cdot 1\]

which is exactly the standard row-major (C-order) indexing formula for a 2D array. makes sense.

concrete example: shape [3, 4]

strides: [4, 1], offset: 0

element [2, 3]:
  idx = 0 + (2 × 4) + (3 × 1) = 11  ✓

physical layout:
  [0,0][0,1][0,2][0,3] | [1,0][1,1][1,2][1,3] | [2,0][2,1][2,2][2,3]
    0    1    2    3       4    5    6    7         8    9   10   11
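the same formula in a few lines of plain Python. `contiguous_strides` and `physical_index` are illustrative helper names, not PyTorch APIs. for a contiguous tensor, the stride of dim k is the product of all sizes to its right, and walking indices in row-major order visits memory with no gaps:

```python
from itertools import product
from math import prod


def contiguous_strides(shape):
    """Row-major strides: s_k = product of the sizes to the right of dim k."""
    return tuple(prod(shape[k + 1:]) for k in range(len(shape)))


def physical_index(idx, strides, offset=0):
    return offset + sum(i * s for i, s in zip(idx, strides))


shape = (3, 4)
strides = contiguous_strides(shape)        # (4, 1)
print(physical_index((2, 3), strides))     # 11

# row-major walk over all indices hits addresses 0, 1, 2, ... 11 in order
visited = [physical_index(idx, strides) for idx in product(*map(range, shape))]
print(visited == list(range(12)))          # True
```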

transpose is O(1) — proved

with original strides $[C, 1]$, element $[i, j]$ is at:

\[\text{idx} = i \cdot C + j \cdot 1\]

after swapping strides to $[1, C]$, element $[i, j]$ is at:

\[\text{idx} = i \cdot 1 + j \cdot C\]

so transposed[i, j] reads from the same address as original[j, i]. that’s exactly what a transpose means. no bytes moved, just swap two integers in the metadata.

x = torch.randn(3, 4)
y = x.T

print(x.data_ptr() == y.data_ptr())   # True — same raw pointer
print(x.stride())                      # (4, 1)
print(y.stride())                      # (1, 4), just swapped

column slice — non-trivial strides

x  # shape [3, 4], strides [4, 1]
col = x[:, 2]
# shape [3], stride [4], offset 2

elements of col sit at positions 2, 6, 10 in memory — not adjacent, gap of 4 between each. stride [4] encodes exactly that gap. no copy, just walk with a bigger step.
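the view metadata for a column slice is just arithmetic on the parent’s strides. a hypothetical helper that mirrors what x.select(dim, index) does to the metadata (x[:, 2] is the same as x.select(1, 2)); it only computes numbers, no data is touched:

```python
def select(sizes, strides, offset, dim, index):
    """View metadata for x.select(dim, index): drop `dim`, move the offset."""
    new_offset = offset + index * strides[dim]
    new_sizes = sizes[:dim] + sizes[dim + 1:]
    new_strides = strides[:dim] + strides[dim + 1:]
    return new_sizes, new_strides, new_offset


sizes, strides, offset = select((3, 4), (4, 1), 0, dim=1, index=2)
print(sizes, strides, offset)    # (3,) (4,) 2
positions = [offset + i * strides[0] for i in range(sizes[0])]
print(positions)                 # [2, 6, 10]
```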

stride 0 — the broadcasting trick

x = torch.randn(1, 4)
y = x.expand(3, 4)

print(y.shape)     # torch.Size([3, 4])
print(y.stride())  # (0, 1)  ← stride 0 in dim 0!

stride 0 means: moving one step in that dimension doesn’t advance the memory pointer at all. so all three “rows” of y are literally the same physical row of x. expand(1000, 4) is still O(1) regardless of how large you expand it. the indexing formula:

\[y[i, j] \to \text{offset} + i \cdot 0 + j \cdot 1 = \text{offset} + j\]

the row index $i$ vanishes. same address for every row.
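expand() is the same kind of metadata-only rewrite. a sketch with a hypothetical `expand` helper, assuming the usual rules: new leading dims may be added, and only size-1 dims may grow. stretched dims get stride 0:

```python
def expand(sizes, strides, new_sizes):
    """Metadata for expand(): new leading dims and stretched size-1 dims get stride 0."""
    pad = len(new_sizes) - len(sizes)
    sizes = (1,) * pad + tuple(sizes)
    strides = (0,) * pad + tuple(strides)
    out_strides = []
    for old, new, st in zip(sizes, new_sizes, strides):
        if old == new:
            out_strides.append(st)       # unchanged dim keeps its stride
        elif old == 1:
            out_strides.append(0)        # every step re-reads the same memory
        else:
            raise ValueError(f"cannot expand size {old} to {new}")
    return tuple(new_sizes), tuple(out_strides)


sizes, strides = expand((1, 4), (4, 1), (3, 4))    # x = randn(1, 4) has strides (4, 1)
print(strides)                                      # (0, 1)
rows = [[i * strides[0] + j * strides[1] for j in range(4)] for i in range(3)]
print(rows)                                         # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
```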

[!note] unlike NumPy, PyTorch does not support negative strides. x.flip(0) does not return a view: it allocates a new tensor and copies the data in reversed order, O(n). np.flip, by contrast, returns a negatively-strided view in O(1).


3. when the illusion breaks — .contiguous()

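most high-performance kernels want dense memory. vendor libraries (BLAS, cuBLAS, oneDNN/MKLDNN) accept only a few specific layouts, and hand-written kernels are fastest when consecutive elements are adjacent, because hardware moves memory in whole cache lines (64 bytes on typical x86 CPUs, 128 bytes on NVIDIA GPUs) and SIMD registers (up to 512 bits with AVX-512). if your elements aren’t adjacent, vectorization fails, the prefetcher can’t help, and most of every cache line you fetch is wasted.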

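so .contiguous() is the moment PyTorch gives up the lie and actually copies. kernels that can’t consume a given stride pattern call it for you internally:

x = torch.randn(1000, 1000)
y = x.T                # strides (1, 1000), non-contiguous

z = torch.mm(y, x)     # no copy: BLAS gemm takes a "transposed" flag, so a plain transpose passes straight through
w = x[:, ::2]          # strides (1000, 2): no dim has stride 1
z2 = torch.mm(w.T, x)  # no flag describes this layout, so mm copies w.T into a contiguous buffer first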

what .contiguous() does internally:

  1. allocate a fresh StorageImpl ($1000 \times 1000 \times 4$ bytes = 4MB)
  2. walk through y using strides $(1, 1000)$: moving along a row of y jumps 1000 elements in memory per step
  3. write each element into the new buffer in row-major order
  4. return new TensorImpl with standard strides $(1000, 1)$

full O(n) memory copy. you can verify:

y  = x.T
y2 = y.contiguous()

print(y.data_ptr() == y2.data_ptr())  # False — new allocation
print(y.is_contiguous())              # False
print(y2.is_contiguous())             # True
print(y2.stride())                    # (1000, 1)
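the four steps above as a pure-Python sketch, with a list standing in for the new allocation. `contiguous_copy` is a hypothetical helper that mirrors the idea only; ATen’s real copy kernel is vectorized and parallel:

```python
from itertools import product


def contiguous_copy(storage, sizes, strides, offset=0):
    """Sketch of .contiguous(): read through the old strides, write row-major."""
    out = []                                              # 1. fresh buffer
    for idx in product(*(range(n) for n in sizes)):       # 2. walk logical indices row-major
        src = offset + sum(i * s for i, s in zip(idx, strides))
        out.append(storage[src])                          # 3. gather into the new buffer
    new_strides, step = [], 1                             # 4. standard row-major strides
    for n in reversed(sizes):
        new_strides.append(step)
        step *= n
    return out, tuple(reversed(new_strides))


storage = list(range(6))                # x: shape (2, 3), strides (3, 1)
buf, new_strides = contiguous_copy(storage, sizes=(3, 2), strides=(1, 3))   # y = x.T
print(buf, new_strides)                 # [0, 3, 1, 4, 2, 5] (2, 1)
```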

[!warning] views keep the entire StorageImpl alive. this burns people constantly:

huge = torch.randn(10000, 10000, device='cuda')   # 10^8 floats × 4 B ≈ 400MB
tiny = huge[0, :5]                                 # "just 5 floats", but it's a view

del huge
# StorageImpl refcount still > 0 because tiny holds a reference.
# the full 400MB is still allocated. torch.cuda.empty_cache() won't help.

# fix: copy out what you need *before* dropping the big tensor
huge = torch.randn(10000, 10000, device='cuda')
tiny = huge[0, :5].clone()         # new StorageImpl, 5 elements only
del huge                           # refcount → 0, the 400MB block goes back to the caching allocator

.clone() is the simplest escape: it creates a fresh StorageImpl with only the elements you care about.
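the keep-alive behavior can be sketched with CPython’s own refcounting, using `weakref.ref` to observe when the buffer is actually freed. `Storage`, `View` and `clone` are hypothetical stand-ins (PyTorch’s refcount lives in `intrusive_ptr`, not in Python objects), and the immediate-free step assumes CPython, but the ownership pattern is the same:

```python
import weakref


class Storage:
    """Stand-in for StorageImpl: freed when its last strong reference goes away."""
    def __init__(self, n):
        self.data = [0.0] * n


class View:
    """Stand-in for a view: holds a strong reference to the *whole* Storage."""
    def __init__(self, storage, offset, length):
        self.storage, self.offset, self.length = storage, offset, length

    def clone(self):
        # copy only the elements this view covers into a brand-new Storage
        fresh = Storage(self.length)
        fresh.data[:] = self.storage.data[self.offset:self.offset + self.length]
        return View(fresh, 0, self.length)


huge = View(Storage(1_000_000), 0, 1_000_000)
probe = weakref.ref(huge.storage)      # observes the big buffer without owning it

tiny = View(huge.storage, 0, 5)        # "just 5 floats", but it pins the whole Storage
del huge
print(probe() is not None)             # True: tiny keeps the 1M-element buffer alive

tiny = tiny.clone()                    # rebinding drops the last reference to the big Storage
print(probe() is None)                 # True: CPython frees it immediately
```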


4. device × layout × dtype

every tensor lives in the Cartesian product of three axes. this is PyTorch’s extension model.

device — where the bytes physically live, determines the allocator:

| device | allocator | notes |
|---|---|---|
| cpu | DefaultCPUAllocator | basically an aligned malloc |
| cuda:0 | CUDACachingAllocator | custom pool, avoids repeated cudaMalloc |
| xla | XLA allocator | opaque to PyTorch |
| mps | MPS (Metal) allocator | Apple silicon |

layout — how the index formula works:

| layout | structure | use case |
|---|---|---|
| strided | sizes + strides (what we’ve been talking about) | default, dense |
| sparse_coo | (indices tensor, values tensor) pair | embedding tables, adjacency matrices |
| sparse_csr | crow_indices + col_indices + values | sparse matmul |
| mkldnn | vendor-opaque blocked/tiled format | Intel inference |

for sparse_coo: a tensor with $nnz$ nonzero elements stores indices of shape $[\text{ndim}, nnz]$ and values of shape $[nnz]$. element $[i, j]$ exists iff there exists some $k$ such that indices[:, k] == [i, j].

for sparse_csr: three flat arrays. crow_indices has $R+1$ entries where crow_indices[i] is the start of row $i$ in col_indices. row $i$ spans values[crow_indices[i] : crow_indices[i+1]]. better for row-access patterns.
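a small pure-Python sketch of both layouts for a 3×3 matrix. `to_coo` and `to_csr` are hypothetical helpers; for a 2D tensor they should correspond to what Tensor.to_sparse() and Tensor.to_sparse_csr() store, except that PyTorch keeps the arrays as tensors rather than lists:

```python
def to_coo(dense):
    """indices has shape [2, nnz] (row list, col list); values has shape [nnz]."""
    rows, cols, values = [], [], []
    for i, row in enumerate(dense):
        for j, v in enumerate(row):
            if v != 0:
                rows.append(i)
                cols.append(j)
                values.append(v)
    return [rows, cols], values


def to_csr(dense):
    """crow_indices has R+1 entries; row i is values[crow[i]:crow[i+1]]."""
    crow, col, values = [0], [], []
    for row in dense:
        for j, v in enumerate(row):
            if v != 0:
                col.append(j)
                values.append(v)
        crow.append(len(values))      # where the next row starts
    return crow, col, values


dense = [[0, 5, 0],
         [0, 0, 0],
         [7, 0, 9]]
indices, vals = to_coo(dense)          # [[0, 2, 2], [1, 0, 2]], [5, 7, 9]
crow, col, vals_csr = to_csr(dense)    # [0, 1, 1, 3], [1, 0, 2], [5, 7, 9]

# row 2 via CSR is one slice, no search over all nonzeros
print(list(zip(col[crow[2]:crow[3]], vals_csr[crow[2]:crow[3]])))   # [(0, 7), (2, 9)]
```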

dtype — what’s stored per element:

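| dtype | bytes | notes |
|---|---|---|
| float32 | 4 | IEEE 754 single precision, default |
| float16 | 2 | 5-bit exponent: reduced range and precision; inference, mixed-precision training |
| bfloat16 | 2 | same 8-bit exponent (same range) as float32, only 7 mantissa bits; better training stability than fp16 |
| int8 / qint8 | 1 | int8 is a plain integer; qint8 is quantized and carries a scale and zero point |
| bool | 1 | not 1 bit: a full byte per element |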
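a quick way to see the fp16 vs bfloat16 tradeoff with nothing but the struct module: bfloat16 is (up to rounding) the top 16 bits of a float32, and struct’s 'e' format is IEEE half precision. the helpers are hypothetical; they truncate for simplicity, whereas PyTorch’s float32→bfloat16 conversion rounds to nearest even:

```python
import struct


def to_bfloat16_trunc(x):
    """Keep sign + 8-bit exponent + top 7 mantissa bits of a float32 (truncation)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def to_float16(x):
    """Round-trip through IEEE half precision (5-bit exponent, 10-bit mantissa)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(to_bfloat16_trunc(3.0e38))       # still finite: float32's exponent range
print(to_float16(65504.0))             # 65504.0, the largest finite fp16 value
try:
    to_float16(1.0e5)                  # beyond fp16 range
except OverflowError as e:
    print("fp16 overflow:", e)

x = 1.0009765625                       # 1 + 2**-10
print(to_float16(x), to_bfloat16_trunc(x))   # 1.0009765625 1.0: bf16 has fewer mantissa bits
```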

if you’re extending PyTorch you pick one axis. new hardware = new device. new memory format = new layout. new number format = new dtype. you don’t get to touch the other two axes.


5. the dispatcher — how a Python call finds the right C++ kernel

when you write torch.add(a, b), PyTorch needs to resolve three independent questions:

  1. should this be recorded for autograd?
  2. is it CPU or CUDA (or XLA, or…)?
  3. is it float32 or int64 or bfloat16?

these are three separate mechanisms. the dispatcher handles (1) and (2). the dtype dispatch is a separate macro inside the kernel.

DispatchKeySet

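every TensorImpl has key_set_, a 64-bit bitfield (DispatchKeySet). roughly, each bit is one dispatch key (newer versions split the bits into "functionality" bits and "backend" bits, but the idea is the same). a CUDA tensor has both the CUDA bit and the Autograd bit set. in current PyTorch the Autograd bit is set on ordinary tensors whether or not requires_grad=True (tensors created under InferenceMode are the exception); the autograd kernel itself checks whether any input actually needs a gradient.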

DispatchKeySet for cuda float32, requires_grad=True (schematic, bit positions illustrative):
  ...
  bit  M: CUDA             → 1  ←
  ...
  bit  N: Autograd         → 1  ←   (N > M: Autograd has higher priority)
  ...
  everything else (BackendSelect, Python, ...) → 0

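the dispatcher takes the union of all input tensors’ key sets (adjusted by thread-local include/exclude sets), picks the highest-priority set bit, and calls the kernel registered for that key in the operator’s own dispatch table, conceptually table[operator_id][key].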

the dispatch table — conceptually:

graph LR
    call["torch.add(a, b)"] --> D["Dispatcher<br/>reads DispatchKeySet"]
    D -->|Autograd key| VAR["VariableType::add<br/>records AddBackward node<br/>unwraps tensors"]
    VAR -->|CUDA key| CUDA["CUDAType::add<br/>launches CUDA kernel"]
    VAR -->|CPU key| CPU["CPUType::add<br/>runs CPU kernel"]
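a toy model of that mechanism in a three-key world. `Key`, `dispatch` and the kernels are hypothetical; the real enum has many more keys, and the exclude mask is thread-local rather than passed as an argument. it shows the union of key sets, "highest set bit wins", and the autograd kernel excluding its own key before redispatching to the backend:

```python
from enum import IntEnum


class Key(IntEnum):
    # schematic priority order: a higher value wins (not PyTorch's real numbering)
    CPU = 0
    CUDA = 1
    Autograd = 2


def keyset(*keys):
    """Pack keys into an int bitfield, one bit per key."""
    bits = 0
    for k in keys:
        bits |= 1 << k
    return bits


def highest_key(bits):
    """Highest set bit = highest-priority key."""
    return Key(bits.bit_length() - 1)


add_table = {}   # per-operator table: Key -> kernel, filled in at "library load time"


def dispatch(table, excluded, *tensors):
    bits = 0
    for t in tensors:
        bits |= t["keys"]            # union of every input's key set
    bits &= ~excluded                # drop keys that are already being handled
    return table[highest_key(bits)](*tensors)


def add_cuda(a, b):
    # stands in for launching the CUDA kernel
    return {"keys": a["keys"], "data": [x + y for x, y in zip(a["data"], b["data"])]}


def add_autograd(a, b):
    # exclude our own key, redispatch to the backend, then record the grad_fn
    out = dispatch(add_table, keyset(Key.Autograd), a, b)
    out["grad_fn"] = "AddBackward0"
    return out


add_table[Key.CUDA] = add_cuda
add_table[Key.Autograd] = add_autograd

a = {"keys": keyset(Key.CUDA, Key.Autograd), "data": [1.0, 2.0]}
b = {"keys": keyset(Key.CUDA, Key.Autograd), "data": [10.0, 20.0]}
out = dispatch(add_table, 0, a, b)
print(out["data"], out["grad_fn"])   # [11.0, 22.0] AddBackward0
```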

why dynamic dispatch for device?

the CPU and CUDA kernels live in different shared libraries:

libcaffe2.so       ← CPU kernels
libcaffe2_gpu.so   ← CUDA kernels  (absent from CPU-only builds)

(those were the names when ezyang’s post was written; current builds ship libtorch_cpu.so and libtorch_cuda.so, same split.)

you can’t make a static call into a library that might not exist at link time. the call has to go through a function pointer that only gets registered when the library loads. this is why device dispatch is an indirect call.

why static (template) dispatch for dtype?

inside the kernel, device is known. dtype isn’t. rather than another vtable, PyTorch uses a compile-time template expansion:

AT_DISPATCH_ALL_TYPES(tensor.scalar_type(), "add_cpu", ([&] {
    // scalar_t is a compile-time type alias
    // this block gets compiled once per supported dtype
    scalar_t* ptr = tensor.data_ptr<scalar_t>();
    // math here
}));

this macro expands to a switch statement:

switch (tensor.scalar_type()) {
    case ScalarType::Float: {
        using scalar_t = float;
        /* your lambda, compiled for float */
        break;
    }
    case ScalarType::Double: {
        using scalar_t = double;
        /* your lambda, compiled for double */
        break;
    }
    // ...
}

not magic. just a switch with templates. dynamic for device (because separate .so files), static for dtype (because you want zero-cost type specialization once you’re inside the right library).
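Python has no templates, so this sketch only mirrors the control flow of AT_DISPATCH (a runtime dtype tag selects one specialization of a generic body, and unsupported dtypes are an error), not the zero-cost compile-time specialization. the `array` typecode plays the role of scalar_t; `dispatch_types` and `add_kernel` are hypothetical names:

```python
from array import array

SUPPORTED = {"f": "float", "d": "double", "q": "int64"}   # typecode -> dtype name


def dispatch_types(typecode, name, body):
    """Rough analog of AT_DISPATCH_*: one branch per supported dtype, error otherwise."""
    if typecode not in SUPPORTED:
        raise TypeError(f'"{name}" not implemented for {typecode!r}')
    return body(typecode)      # body(scalar_t) is the "lambda"


def add_kernel(a, b):
    def body(scalar_t):
        # the same generic body, "instantiated" for whichever scalar_t was chosen
        return array(scalar_t, (x + y for x, y in zip(a, b)))
    return dispatch_types(a.typecode, "add_cpu", body)


print(add_kernel(array("d", [1.5, 2.5]), array("d", [1.0, 1.0])))   # array('d', [2.5, 3.5])
print(add_kernel(array("q", [1, 2]), array("q", [3, 4])))           # array('q', [4, 6])
```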

full call path from Python to hardware:

flowchart TD
    PY["Python: torch.add(x, y)"] --> BIND
    BIND["THPVariable_add()\ntorch/csrc — auto-generated\nPythonArgParser extracts C++ args\nGIL released"] --> VAR
    VAR["VariableType::add()\ntorch/csrc/autograd\nunwraps Variables\ncreates AddBackward node\nsaves inputs for backward"] --> DISP
    DISP["Dispatcher\naten/src/ATen/core/dispatch/Dispatcher.h\nreads DispatchKeySet\nlooks up function pointer"] --> NAT
    NAT["at::native::add()\naten/src/ATen/native/BinaryOps.cpp\nTensorIterator setup\nbroadcasting + type promotion"] --> DT
    DT["AT_DISPATCH_* macro\ndtype switch\nfloat? int? bfloat16?"] --> KERN
    KERN["cpu_kernel() / gpu_kernel()\nvectorized AVX loops or CUDA threads\nactual math executes"]

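the Python binding, the autograd wrapper (VariableType) and the dispatcher registrations are auto-generated from native_functions.yaml (plus derivatives.yaml for backward formulas); the dispatcher itself is ordinary hand-written C++. you only write the bottom three boxes: the native function, its dtype switch, and the kernel.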


6. autograd — reverse-mode AD on a dynamic graph

PyTorch implements reverse-mode automatic differentiation. the math before the code.

why reverse-mode?

given $f: \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian $J \in \mathbb{R}^{m \times n}$ has $J_{ij} = \partial f_i / \partial x_j$.

  • forward-mode: computes $Jv$ for a tangent vector $v \in \mathbb{R}^n$. cost: one forward pass per input dimension. good when $n \ll m$.
  • reverse-mode: computes $v^T J$ for a cotangent $v \in \mathbb{R}^m$. cost: one backward pass per output dimension. good when $m \ll n$.

in neural networks: $n$ = millions of parameters, $m$ = 1 (scalar loss). reverse-mode wins by a factor of $n$. that’s why backprop is reverse-mode AD.

the $v^T J$ product is called a vector-Jacobian product (VJP). what PyTorch calls “gradients” are technically VJPs where $v$ = gradient of loss w.r.t. current output.
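to make the cost asymmetry concrete, a minimal forward-mode sketch with dual numbers. `Dual` is illustrative, not PyTorch’s forward-mode API (that lives in torch.autograd.forward_ad). getting the full gradient of a scalar function takes one forward pass per input; reverse mode gets all n entries from a single backward pass:

```python
class Dual:
    """Forward-mode AD value: f(x) plus one directional derivative (tangent)."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, o):
        o = o if isinstance(o, Dual) else Dual(o)
        return Dual(self.val + o.val, self.tan + o.tan)

    __radd__ = __add__

    def __mul__(self, o):
        o = o if isinstance(o, Dual) else Dual(o)
        return Dual(self.val * o.val, self.tan * o.val + self.val * o.tan)

    __rmul__ = __mul__


def f(xs):
    # f: R^n -> R, a "loss": sum of x_i^2 plus x_0 * x_1
    return sum(x * x for x in xs) + xs[0] * xs[1]


def grad_forward_mode(f, point):
    """One forward pass per input dimension: seed tangent e_j, read off df/dx_j."""
    grads = []
    for j in range(len(point)):
        xs = [Dual(v, 1.0 if k == j else 0.0) for k, v in enumerate(point)]
        grads.append(f(xs).tan)
    return grads, len(point)           # gradient and number of passes used


g, passes = grad_forward_mode(f, [1.0, 2.0, 3.0])
print(g, passes)   # [4.0, 5.0, 6.0] 3
```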

the dynamic graph — built during forward pass

x = torch.tensor([2.0], requires_grad=True)
w = torch.tensor([3.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)

y = x * w       # MulBackward0 created, saves (x, w)
z = y + b       # AddBackward0 created
loss = z.sum()  # SumBackward0 created

graph after forward:

graph LR
    X((x=2)) --> MUL["MulBackward0"]
    W((w=3)) --> MUL
    MUL --> |y=6| ADD["AddBackward0"]
    B((b=1)) --> ADD
    ADD --> |z=7| SUM["SumBackward0"]
    SUM --> LOSS((loss=7))

each node stores:

  • apply() — the VJP function for this op
  • next_edges — pointers to the grad_fn of the inputs (for graph traversal)
  • saved tensors — values needed to compute the VJP (e.g. MulBackward needs both x and w)

backward pass — unwinding the tape

loss.backward()

engine starts with $\bar{L} = 1$ (seed gradient) and calls each node’s apply():

\[\text{SumBackward: } \bar{z} = \bar{L} \cdot \mathbf{1} = [1.0]\]
\[\text{AddBackward: } \bar{y} = \bar{z} = 1.0, \quad \bar{b} = \bar{z} = 1.0\]
\[\text{MulBackward: } \bar{x} = \bar{y} \cdot w = 1.0 \cdot 3 = 3.0, \quad \bar{w} = \bar{y} \cdot x = 1.0 \cdot 2 = 2.0\]

verification: $\text{loss} = xw + b$, so $\partial L/\partial x = w = 3$ ✓, $\partial L/\partial w = x = 2$ ✓, $\partial L/\partial b = 1$ ✓

print(x.grad)   # tensor([3.])
print(w.grad)   # tensor([2.])
print(b.grad)   # tensor([1.])
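a micro reverse-mode engine in plain Python that reproduces those numbers. `Var`, `Node` and `backward` are hypothetical; PyTorch’s engine (torch/csrc/autograd/engine.cpp) tracks dependency counts and a ready queue instead of this recursive topological sort, and the `sum()` step is skipped here since it just passes the gradient through. the bookkeeping is the same idea: a VJP per node, saved inputs, edges to the inputs, and a seed gradient of 1:

```python
class Node:
    """One grad_fn: a VJP closure (holding its saved values) plus edges to its inputs."""
    def __init__(self, name, vjp, next_edges):
        self.name, self.vjp, self.next_edges = name, vjp, next_edges


class Var:
    def __init__(self, value):
        self.value, self.grad_fn, self.grad = value, None, 0.0

    def __mul__(self, other):
        a, b = self, other
        out = Var(a.value * b.value)
        # MulBackward0 saves both inputs: d(ab)/da = b, d(ab)/db = a
        out.grad_fn = Node("MulBackward0", lambda g: (g * b.value, g * a.value), (a, b))
        return out

    def __add__(self, other):
        out = Var(self.value + other.value)
        out.grad_fn = Node("AddBackward0", lambda g: (g, g), (self, other))
        return out

    def backward(self):
        # order nodes so each one runs only after all its consumers have sent it gradient
        order, seen = [], set()

        def visit(v):
            if id(v) in seen:
                return
            seen.add(id(v))
            if v.grad_fn is not None:
                for inp in v.grad_fn.next_edges:
                    visit(inp)
            order.append(v)

        visit(self)
        self.grad = 1.0                                   # seed: dL/dL = 1
        for v in reversed(order):
            if v.grad_fn is not None:
                for inp, g in zip(v.grad_fn.next_edges, v.grad_fn.vjp(v.grad)):
                    inp.grad += g                         # accumulate, like AccumulateGrad


x, w, b = Var(2.0), Var(3.0), Var(1.0)
loss = x * w + b
loss.backward()
print(x.grad, w.grad, b.grad)   # 3.0 2.0 1.0
```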

second-order gradients — re-entrant backward

you can differentiate through the backward pass:

x = torch.tensor([2.0], requires_grad=True)
y = x ** 3    # y = x³

dy_dx = torch.autograd.grad(y, x, create_graph=True)[0]
# dy/dx = 3x² = 3·4 = 12, but create_graph=True means
# the backward graph itself is a differentiable graph

d2y_dx2 = torch.autograd.grad(dy_dx, x)[0]
# d²y/dx² = 6x = 6·2 = 12
print(d2y_dx2)  # tensor([12.])  ✓

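note this isn’t re-entrant backward: the second grad() call starts after the first one has finished, and it works because create_graph=True made the first backward record its own operations as a differentiable graph. re-entrant backward is when a backward call is issued from inside a backward that is still running (e.g. the reentrant variant of torch.utils.checkpoint). PyTorch’s C++ autograd engine handles that case too, finishing the nested graph before the outer backward continues, without deadlocking on its ready queues.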

the VariableType layer

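historically Tensor and Variable were separate classes. they’ve since been merged, in C++ too (torch::autograd::Variable is now just an alias for at::Tensor), but the autograd wrapper layer still exists as the Autograd dispatch key, and its generated kernels still live in files named VariableType_*.cpp: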

flowchart LR
    OP["op(tensor_with_grad)"] --> VT
    VT["VariableType::op()\n1. unwrap — strip autograd meta\n2. call raw ATen kernel\n3. rewrap — attach grad_fn\n4. record saved tensors"] --> ATen["ATen kernel\n(no autograd knowledge)"]
    ATen --> VT
    VT --> OUT["output tensor\n.grad_fn = OpBackward"]

once you pass the unwrap step and drop into raw ATen, there’s no going back up during that call. unwrapping is one-directional.


7. TensorIterator — writing kernels without losing your mind

writing add(a, b) sounds trivial until you count the cases:

  • broadcasting: [3, 1] + [1, 4] → [3, 4]
  • type promotion: float32 + int32 → float32
  • non-contiguous inputs: transposed, strided views
  • output allocation: where does the result go?
  • vectorization: can we use AVX/SSE?
  • parallelism: OpenMP threads?

TensorIterator handles all of this. you write the scalar math, it handles the loop:

auto iter = TensorIteratorConfig()
    .add_output(result)
    .add_input(a)
    .add_input(b)
    .build();
// at build() time, TensorIterator has:
//   - computed output shape via broadcasting
//   - resolved type promotion
//   - analyzed contiguity of all tensors
//   - collapsed contiguous dims into flat loops

at::native::cpu_kernel(iter, [](float a, float b) -> float {
    return a + b;   // you write ONLY this
});

broadcasting math

NumPy/PyTorch broadcasting rules — align shapes from the right, pad 1s on the left, size-1 dims stretch:

\[[3, 1, 4] + [2, 4] \to [3, 1, 4] + [1, 2, 4] \to [3, 2, 4]\]

TensorIterator computes the output shape at build() time and handles the stride-0 trick for expanded dims automatically.
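a sketch of both halves: the shape rule, and the stride-0 trick that lets each input be read as the output shape. the helpers are hypothetical; PyTorch exposes the shape half directly as torch.broadcast_shapes:

```python
def broadcast_shapes(*shapes):
    """Align from the right, pad with 1s on the left, size-1 dims stretch."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for sizes in zip(*padded):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)


def broadcast_strides(shape, strides, out_shape):
    """Strides for reading `shape` as `out_shape`: stretched dims get stride 0."""
    pad = len(out_shape) - len(shape)
    shape = (1,) * pad + tuple(shape)
    strides = (0,) * pad + tuple(strides)
    return tuple(0 if n == 1 and m != 1 else s for n, m, s in zip(shape, out_shape, strides))


out = broadcast_shapes((3, 1, 4), (2, 4))
print(out)                                           # (3, 2, 4)
print(broadcast_strides((2, 4), (4, 1), out))        # (0, 4, 1)
print(broadcast_strides((3, 1, 4), (4, 4, 1), out))  # (4, 0, 1)
```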

dimension collapsing

if two adjacent dimensions can be walked as one in every tensor (one step in the outer dim lands exactly where a full sweep of the inner dim ends), TensorIterator merges them into one flat loop:

before:

for (int i = 0; i < 3; i++)
    for (int j = 0; j < 4; j++)
        out[i * 4 + j] = f(a[i * 4 + j]);   // flat buffers, strides (4, 1)

after collapsing:

for (int k = 0; k < 12; k++)   // 3×4 = 12
    out[k] = f(a[k]);

fewer loop overhead instructions, better CPU prefetching, trivially vectorizable.
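a simplified sketch of the collapsing rule. `coalesce` is hypothetical, and it skips the stride-based dim reordering TensorIterator does first. two adjacent dims merge only if, in every operand, one step of the outer dim equals a full sweep of the inner dim, which also hints at what blocks it: any operand, such as a transposed or sliced one, whose strides don’t line up:

```python
def coalesce(shape, strides_per_operand):
    """Merge adjacent dims (outer = d-1, inner = d) when, in every operand,
        stride[outer] == stride[inner] * shape[inner]
    Size-1 dims can always be merged. Returns (shape, per-operand strides)."""
    shape = list(shape)
    strides = [list(s) for s in strides_per_operand]
    d = len(shape) - 1
    while d > 0:
        outer, inner = d - 1, d
        mergeable = (shape[outer] == 1 or shape[inner] == 1 or
                     all(s[outer] == s[inner] * shape[inner] for s in strides))
        if mergeable:
            for s in strides:
                # keep the stride of whichever dim actually moves through memory
                if shape[inner] == 1:
                    s[inner] = s[outer]
                del s[outer]
            shape[inner] *= shape[outer]
            del shape[outer]
        d -= 1
    return shape, strides


print(coalesce((3, 4), [(4, 1), (4, 1)]))   # ([12], [[1], [1]]): one flat loop
print(coalesce((3, 4), [(4, 1), (1, 3)]))   # ([3, 4], [[4, 1], [1, 3]]): transposed input blocks it
print(coalesce((3, 4), [(4, 1), (8, 1)]))   # ([3, 4], [[4, 1], [8, 1]]): x[:, :4] of an 8-wide tensor blocks it
```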

vectorized kernel example

at::native::cpu_kernel_vec(
    iter,
    // scalar path — handles remainder elements
    [](float a, float b) -> float { return a + b; },
    // vectorized path — processes 8 floats at once (AVX)
    [](Vectorized<float> a, Vectorized<float> b) { return a + b; }
);

TensorIterator calls the vectorized path for as many full 8-element chunks as possible (when the inner loop is contiguous; otherwise it stays on the scalar path), then falls back to the scalar path for the remainder. kernels under native/cpu/ are also compiled several times for different instruction sets (e.g. a baseline build plus AVX2 and AVX-512 on x86), and DispatchStub picks one at runtime based on what the CPU supports.
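the chunk-plus-remainder loop shape, sketched in Python. the list comprehension standing in for one SIMD add is obviously not vectorized; the point is only the structure. `VEC = 8` assumes float32 in 256-bit AVX2 registers, and `vectorized_loop` is a hypothetical name:

```python
VEC = 8   # lanes per vector: 8 float32 in a 256-bit AVX2 register (assumption)


def vectorized_loop(a, b, scalar_op, vector_op):
    """Full VEC-wide chunks go through vector_op, the remainder through scalar_op."""
    n = len(a)
    out = [None] * n
    full = n - n % VEC                       # elements covered by whole vectors
    for i in range(0, full, VEC):
        out[i:i + VEC] = vector_op(a[i:i + VEC], b[i:i + VEC])
    for i in range(full, n):                 # scalar tail
        out[i] = scalar_op(a[i], b[i])
    return out, full // VEC, n - full


out, vec_iters, tail = vectorized_loop(
    list(range(20)), [1] * 20,
    scalar_op=lambda x, y: x + y,
    vector_op=lambda xs, ys: [x + y for x, y in zip(xs, ys)],   # stands in for one SIMD add
)
print(vec_iters, tail)   # 2 4
```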


8. CUDA caching allocator — why training doesn’t stall on every allocation

cudaMalloc is not like malloc. it’s expensive and it can synchronize:

  1. cudaFree waits for all pending work on the device, and device allocations act as an implicit synchronization point between streams
  2. it goes through the CUDA runtime → GPU driver → OS kernel
  3. typical latency: tens of microseconds to low milliseconds

a simple elementwise op on a small GPU tensor takes ~5-10 μs. if every torch.randn(..., device='cuda') cost even 50 μs, a training loop with thousands of allocations would be dominated by malloc overhead.

the fix: CUDACachingAllocator

flowchart TD
    ALLOC["tensor allocation request\n(e.g. 4MB)"] --> CACHE{"block in\nfree list?"}
    CACHE -->|yes| FAST["return cached block\nO(log n), no syscall\n~100ns"]
    CACHE -->|no| SLOW["cudaMalloc new block\n~50μs+ synchronizing call"]
    SLOW --> FAST
    FREE["tensor freed\n(del x or goes out of scope)"] --> POOL["return block to free list\nno cudaFree\n~100ns"]
    EMPTY["torch.cuda.empty_cache()"] --> CFREE["cudaFree all free list blocks\nreturns VRAM to OS"]

x = torch.randn(1000, 1000, device='cuda')   # from pool or cudaMalloc once
del x                                          # goes to free list, not cudaFree
y = torch.randn(1000, 1000, device='cuda')   # very likely reuses x's block

you can inspect the allocator state:

stats = torch.cuda.memory_stats()
# 'allocated_bytes.all.current'  — bytes in live tensors
# 'reserved_bytes.all.current'   — bytes held by allocator (free list + live)
# difference = free list = "yours but unused"

torch.cuda.memory_snapshot()  # per-block breakdown
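a toy model of the free-list idea and of the allocated-vs-reserved distinction. `ToyCachingAllocator` is hypothetical; the real allocator also rounds and splits blocks, keeps separate small and large pools, and is stream-aware. `device_mallocs` counts the slow path that would be cudaMalloc:

```python
import bisect


class ToyCachingAllocator:
    """Best-fit free list sorted by size; the 'device malloc' only happens on a miss."""

    def __init__(self):
        self.free = []              # sorted list of (size, block_id)
        self.device_mallocs = 0     # slow-path count (stands in for cudaMalloc)
        self.allocated = 0          # bytes in live tensors
        self.reserved = 0           # bytes held from the device (live + cached)
        self._next_id = 0

    def malloc(self, size):
        i = bisect.bisect_left(self.free, (size, -1))   # smallest cached block >= size
        if i < len(self.free):
            block = self.free.pop(i)                    # fast path: reuse
        else:
            self.device_mallocs += 1                    # slow path
            block = (size, self._next_id)
            self._next_id += 1
            self.reserved += size
        self.allocated += block[0]
        return block

    def release(self, block):
        self.allocated -= block[0]
        bisect.insort(self.free, block)                 # no cudaFree: keep it cached

    def empty_cache(self):
        self.reserved -= sum(size for size, _ in self.free)   # give cached blocks back
        self.free.clear()


alloc = ToyCachingAllocator()
x = alloc.malloc(4_000_000)
alloc.release(x)                                # del x
y = alloc.malloc(4_000_000)                     # reuses x's block
print(alloc.device_mallocs)                     # 1
alloc.release(y)
print(alloc.reserved - alloc.allocated)         # 4000000: cached, "yours but unused"
alloc.empty_cache()
print(alloc.reserved)                           # 0
```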

[!tip] torch.cuda.empty_cache() does NOT free memory used by live tensors. it only returns the free list to the driver. if you’re OOM, the culprit is live tensors (or fragmentation), not the cache: before raising OOM, the allocator already frees its cached blocks and retries. run torch.cuda.memory_summary() to see what’s actually allocated.


9. GIL and dispatch overhead

CUDA ops are async — when you call torch.matmul(a, b) on CUDA, the CPU launches the kernel and returns immediately without waiting for the GPU. so the GPU is doing compute while the CPU is already dispatching the next op.

the bottleneck isn’t the kernel. it’s the overhead of getting to the point of launching the kernel:

Python: z = x + y
  │
  ▼  acquire GIL
  ▼  C++ binding layer         ~1-5 μs
  ▼  Dispatcher lookup         ~200-500 ns
  ▼  VariableType recording    ~1 μs
  ▼  CUDA kernel launch        ~5-10 μs
  ▼  return to Python

for a big matmul this ~10 μs is irrelevant — the kernel runs for milliseconds. but many small ops:

# BAD: each iteration is several ops (index, mul, add, index-assign), so ~40k dispatch round-trips
for i in range(10000):
    result[i] = result[i] * 2.0 + 1.0
# overhead alone: ~40000 × several μs ≈ hundreds of ms

# GOOD: two dispatches (mul, add); C++ loops handle the 10k elements
result = result * 2.0 + 1.0
# overhead: ~tens of μs

torch.compile (PT2.0)

torch.compile attacks this systematically:

flowchart LR
    PY["Python model code"] --> DY["Dynamo\n(traces Python,\nrecords torch ops)"]
    DY --> FX["FX Graph\n(intermediate representation\nof all torch ops)"]
    FX --> IND["TorchInductor\n(compiles graph to\nTriton / C++ code)"]
    IND --> TRIT["Triton kernels\n(GPU) or C++ loops\n(CPU)"]

model = torch.compile(my_model)

# first call: traces + compiles (~seconds)
out = model(x)

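# subsequent calls (same input shapes, guards pass): compiled path
# Python → guard check → generated wrapper → fused Triton kernels
# the fused ops never go through the dispatcher or autograd one by one
out = model(x)

[!important] inside the compiled region, fused kernels are launched directly, so the individual ops they replace skip the dispatcher entirely. ops handed off to library kernels (e.g. matmuls via cuBLAS) still go through ATen, and the guard check plus generated wrapper still run on every call. on small-op-heavy models the speedup can be 2-5x, from both removed per-op overhead and kernel fusion cutting memory traffic; mode="reduce-overhead" adds CUDA graphs to trim launch overhead further.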


10. codebase map

pytorch/
├── torch/                       python surface
│   ├── nn/                      Module, Linear, Conv2d...
│   ├── optim/                   Adam, SGD...
│   └── csrc/                    C++ glue
│       ├── autograd/            backward engine — engine.cpp is the heart
│       ├── jit/                 TorchScript (Dynamo and Inductor are Python: torch/_dynamo/, torch/_inductor/)
│       └── api/                 LibTorch (C++ frontend)
│
├── aten/                        "A Tensor Library" — all the math
│   └── src/ATen/
│       ├── native/              start here. modern C++ kernels (C++17 today)
│       │   ├── cpu/             CPU-specific vectorized kernels
│       │   ├── cuda/            CUDA .cu files
│       │   └── *.cpp            device-agnostic logic
│       ├── (TH/, THC/, ...)     legacy pre-ATen C code; lived in aten/src/, since removed
│       └── core/                dispatcher implementation
│
└── c10/                         "Core" — zero NN knowledge, just memory + types
    └── core/
        ├── TensorImpl.h         the main struct
        ├── StorageImpl.h        the memory blob
        └── DispatchKeySet.h     dispatch key bitfield (the Dispatcher itself: aten/src/ATen/core/dispatch/)

search strategy:

  • find a kernel: grep -r "native::relu" aten/src/ATen/native/
  • find the backward formula: tools/autograd/derivatives.yaml
  • find generated bindings: not in repo, appears in torch/csrc/autograd/generated/ after building
  • find dispatcher registration: grep -r "TORCH_LIBRARY_IMPL" aten/

[!note] if you see a TH or THC prefix (in old code, blog posts, or git history), that’s legacy C code from the pre-ATen era. it lived in aten/src/TH/ and aten/src/THC/ and was compiled multiple times via #define scalar_t float, #define scalar_t int, etc., which is how dtype dispatch was done before templates. porting TH functions to ATen native (C++ + TensorIterator) used to be a classic contribution; that port has since been finished and the TH directories removed.


11. adding a new operator — complete workflow

to add scaled_abs(x, scale) → $\text{output}_i = \lvert x_i \rvert \cdot \text{scale}$:

step 1: declare in native_functions.yaml

- func: scaled_abs(Tensor self, float scale=1.0) -> Tensor
  variants: function, method
  dispatch:
    CPU: scaled_abs_cpu
    CUDA: scaled_abs_cuda

this auto-generates: Python binding, C++ header declaration, and dispatch table registration. you write zero binding code (docstrings still go in by hand, e.g. in torch/_torch_docs.py).

step 2: CPU kernel (aten/src/ATen/native/ScaledAbs.cpp)

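#include <cmath>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>

namespace at::native {

// note: only files under native/cpu/ are compiled once per ISA; in a plain .cpp
// like this one, the vector path is built for the baseline ISA only
Tensor scaled_abs_cpu(const Tensor& self, double scale) {
    TORCH_CHECK(self.is_floating_point(),
                "scaled_abs: expected float tensor, got ", self.dtype());

    auto out = at::empty_like(self);
    auto iter = TensorIteratorConfig()
        .add_output(out)
        .add_input(self)
        .build();

    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "scaled_abs_cpu", ([&] {
        const auto s = static_cast<scalar_t>(scale);
        cpu_kernel_vec(
            iter,
            [s](scalar_t x) -> scalar_t {
                return std::abs(x) * s;
            },
            [s](vec::Vectorized<scalar_t> x) {
                return x.abs() * vec::Vectorized<scalar_t>(s);  // SIMD path
            }
        );
    }));

    return out;
}

} // namespace at::native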

step 3: CUDA kernel (aten/src/ATen/native/cuda/ScaledAbs.cu)

Tensor scaled_abs_cuda(const Tensor& self, double scale) {
    TORCH_CHECK(self.is_floating_point(), "scaled_abs: expected float tensor");

    auto out = at::empty_like(self);
    auto iter = TensorIteratorConfig()
        .add_output(out).add_input(self).build();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half, ScalarType::BFloat16,
        self.scalar_type(), "scaled_abs_cuda", ([&] {
            gpu_kernel(iter, [scale] GPU_LAMBDA (scalar_t x) -> scalar_t {
                return ::abs(x) * static_cast<scalar_t>(scale);
            });
        })
    );

    return out;
}

step 4: backward formula (tools/autograd/derivatives.yaml)

- name: scaled_abs(Tensor self, float scale=1.0) -> Tensor
  self: grad * self.sign() * scale

derivative of $\lvert x \rvert \cdot s$ w.r.t. $x$ is $\text{sign}(x) \cdot s$. done.
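before trusting a hand-written backward formula, it’s worth checking it against finite differences. this is a scalar, pure-Python sketch of that check with hypothetical helper names; PyTorch’s real tool for it is torch.autograd.gradcheck, which does the same comparison on double-precision tensors:

```python
def scaled_abs(x, scale=1.0):
    return abs(x) * scale


def scaled_abs_backward(grad, x, scale):
    # mirrors the derivatives.yaml entry: grad * sign(self) * scale
    sign = (x > 0) - (x < 0)
    return grad * sign * scale


def numeric_grad(f, x, eps=1e-6):
    return (f(x + eps) - f(x - eps)) / (2 * eps)   # central difference


scale = 2.5
for x in (-3.0, -0.5, 0.7, 4.0):      # avoid x = 0, where |x| has a kink
    analytic = scaled_abs_backward(1.0, x, scale)
    numeric = numeric_grad(lambda t: scaled_abs(t, scale), x)
    assert abs(analytic - numeric) < 1e-6, (x, analytic, numeric)
print("backward formula matches finite differences")
```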

the codegen handles Python bindings, dispatch table entries, and autograd wiring. you wrote exactly the math and nothing else.


12. things that will ruin your workflow

[!warning] editing a .h file in ATen can cascade into recompiling 50%+ of the codebase, especially if CUDA files include it. a full CUDA rebuild is 1-2 hours on a good server. stick to .cpp/.cu files. set up ccache — it caches compiled objects and saves you when you accidentally touch a header.

other things:

  • building CUDA on a laptop: don’t. use a server with real CPUs and RAM
  • the CI gives feedback after 1-2 hours. set up local builds for anything iterative
  • grep and rg (ripgrep) are your navigation tools — the codebase is too big to hold in your head, you search your way around it

13. open questions

things I haven’t fully dug into yet:

  • how does Dynamo handle Python control flow? like if x.shape[0] > 0: in the model — does it recompile for different shapes or does it guard?
  • what does the Triton IR actually look like for a compiled matmul? is it similar to PTX or higher-level?
  • how does CUDACachingAllocator handle long-term fragmentation? if you allocate a lot of varied-size tensors over a training run, does the free list get fragmented?
  • when does TensorIterator decide NOT to collapse dimensions? what breaks the collapsing?
  • the autograd engine has a thread pool for the backward pass — how does it partition work across nodes in the graph? does it parallelize backward across independent subgraphs?
