
DeepSeek's open-source week


DeepSeek open-sourced several libraries this week. While most of the techniques have appeared in earlier DeepSeek papers, it's fantastic to see the actual code released openly. It means we now have concrete examples to build scalable LLM pipelines.


Day 1: FlashMLA - Efficient Multi-Head Latent Attention

Before discussing FlashMLA, let's understand why Key-Value (KV) cache optimization is crucial for LLM inference.

When generating autoregressively, naïve decoding recomputes self-attention over the full prefix at each step: step $t$ costs $O(t^2)$, totaling $O(n^3)$ for $n$ tokens. KV caching stores the keys and values from prior steps so the new token only computes its own $K, V$ and attends once to the cached keys: step $t$ becomes $O(t)$, totaling $O(n^2)$. The trade-off is $O(n \cdot L \cdot H \cdot d_k)$ memory to store caches across layers and heads (for both $K$ and $V$).
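To get a feel for the memory side of this trade-off, here is a small back-of-the-envelope helper (the model dimensions below are illustrative, not any specific model's):

def kv_cache_bytes(n_tokens: int, n_layers: int, n_heads: int, d_head: int,
                   bytes_per_elem: int = 2) -> int:
    # 2x accounts for storing both K and V at every layer and head.
    return 2 * n_tokens * n_layers * n_heads * d_head * bytes_per_elem

# Example: a 32k-token context with 60 layers, 64 heads of dim 128, FP16 cache.
print(f"{kv_cache_bytes(32_768, 60, 64, 128) / 2**30:.0f} GiB per sequence")  # ~60 GiB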

To reduce the memory burden of KV caches, several approaches have been proposed. Unlike previous approaches like Multi-Query Attention (MQA) [1] or Group-Query Attention (GQA) [1] that reduce memory usage through head sharing, MLA [1] employs a clever low-rank latent compression technique that preserves the benefits of multiple attention heads.
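To make the latent-compression idea concrete, here is a minimal PyTorch sketch of the low-rank KV path (DeepSeek-V3-like dimensions are assumed for illustration; the decoupled RoPE branch and query compression of full MLA are omitted):

import torch
import torch.nn as nn

class LatentKV(nn.Module):
    """Sketch of MLA-style KV compression: cache a small latent per token, not full K/V."""
    def __init__(self, d_model: int = 7168, d_latent: int = 512,
                 n_heads: int = 128, d_head: int = 128):
        super().__init__()
        self.down = nn.Linear(d_model, d_latent, bias=False)           # compress to latent
        self.up_k = nn.Linear(d_latent, n_heads * d_head, bias=False)  # expand to per-head keys
        self.up_v = nn.Linear(d_latent, n_heads * d_head, bias=False)  # expand to per-head values

    def forward(self, h: torch.Tensor):
        c_kv = self.down(h)   # [batch, seq, d_latent] -- this is all that needs caching
        k = self.up_k(c_kv)   # [batch, seq, n_heads * d_head], recomputed on the fly
        v = self.up_v(c_kv)
        return c_kv, k, v

Caching only c_kv (512 values per token in this sketch) instead of full per-head keys and values is where the memory savings come from, while the up-projections preserve distinct per-head attention.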

FlashMLA is an efficient implementation of Multi-Head Latent Attention (MLA), designed for Hopper GPUs and optimized for variable-length sequences. It provides a CUDA kernel that achieves high performance: up to ~3 TB/s memory-bound throughput on Hopper-class GPUs. On H100 (HBM3 ~3.3 TB/s), this approaches the hardware limit. This is a significant improvement over previous attention implementations, like FlashAttention 2 and 3, which achieve bandwidth utilizations between 35 and 75% [1].

One interesting aspect of FlashMLA's implementation is double buffering for shared key/value tensors.

FLASH_ASSERT(params.k_ptr == params.v_ptr);  // Shared_KV
// ...

if (n_block % 2 == 1) {
    // Double buffer for sK
    constexpr int sK_offset = size(sK);
    tSrK.data() = tSrK.data() + sK_offset / 8;
    tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
  • FLASH_ASSERT(params.k_ptr == params.v_ptr) shows that Key and Value tensors share the same memory space -> halving the memory requirement.
  • if (n_block % 2 == 1) implements ping-pong buffering by simply toggling based on block number parity (while one buffer is being used for computation, the other can load data).
  • the division by 8 in both pointer adjustments (tSrK.data() + sK_offset / 8 and tOrVt.data() + sK_offset / 8) converts the element offset into the units used by Hopper's shared-memory matrix descriptors, which address memory in 16-byte chunks (8 BF16 elements), so accesses stay properly aligned after the buffer switch.
  • adjusting both Key and Value tensor pointers by the same offset keeps each key aligned with its corresponding value even after switching buffers.

FlashMLA supports BF16 and FP16 precision and employs a paged KV cache with a block size of 64 tokens, meaning each memory block stores the key-value data for 64 tokens. This paged approach divides the KV cache into fixed-size blocks that can be dynamically allocated and deallocated, enabling efficient handling of variable-length sequences [1]. The 64-token block size is a design choice of the Hopper-targeted kernels and can vary across implementations.
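The block-table bookkeeping behind a paged KV cache can be sketched in a few lines (a toy allocator for intuition, not FlashMLA's actual data structures):

BLOCK_SIZE = 64  # tokens per cache block, matching the block size discussed above

def blocks_needed(seq_len: int) -> int:
    return (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE

class BlockAllocator:
    """Hand out fixed-size cache blocks to variable-length sequences."""
    def __init__(self, num_blocks: int):
        self.free = list(range(num_blocks))
        self.block_table = {}  # seq_id -> list of physical block ids

    def allocate(self, seq_id: int, seq_len: int) -> None:
        self.block_table[seq_id] = [self.free.pop() for _ in range(blocks_needed(seq_len))]

    def release(self, seq_id: int) -> None:
        self.free.extend(self.block_table.pop(seq_id))

alloc = BlockAllocator(num_blocks=1024)
alloc.allocate(seq_id=0, seq_len=1000)    # a 1000-token sequence occupies 16 blocks
print(len(alloc.block_table[0]))          # 16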


Day 2: DeepEP - Expert Parallelism Communication Library

In MoE architectures [1, 1, 1], each input token is routed to a subset of the experts based on its characteristics. For example, DeepSeek-V3 [1] has 671 billion parameters, and each of its MoE layers contains 257 experts (1 shared expert and 256 routed experts). Each token activates 8 routed experts, but to manage communication costs, each token is limited to accessing experts on at most 4 nodes.

DeepEP is a communication library that optimizes communication patterns in MoE models, specifically the "dispatch" and "combine" operations, which route data between GPUs.

  • Dispatch: send token embeddings from their original GPU to the GPUs that host the selected experts.
  • Combine: After the experts process the tokens, gather the results back to their original GPUs.
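To pin down what these two operations do, here is a minimal single-process sketch of their semantics (DeepEP implements them as fused GPU all-to-all kernels; the names and shapes here are purely illustrative):

import torch

def dispatch(tokens: torch.Tensor, topk_idx: torch.Tensor, num_experts: int):
    """Route each token to every expert it was assigned to."""
    per_expert = [[] for _ in range(num_experts)]
    for t in range(tokens.shape[0]):
        for e in topk_idx[t].tolist():
            per_expert[e].append(tokens[t])
    return [torch.stack(buf) if buf else torch.empty(0, tokens.shape[1]) for buf in per_expert]

def combine(expert_outputs, topk_idx, topk_weights, num_tokens: int, hidden: int):
    """Gather expert outputs back to their source tokens, weighted by the router."""
    out = torch.zeros(num_tokens, hidden)
    cursor = [0] * len(expert_outputs)
    for t in range(num_tokens):
        for e, w in zip(topk_idx[t].tolist(), topk_weights[t].tolist()):
            out[t] += w * expert_outputs[e][cursor[e]]
            cursor[e] += 1
    return out

tokens = torch.randn(4, 8)
topk_idx = torch.tensor([[0, 1], [1, 2], [0, 2], [1, 3]])
topk_weights = torch.full((4, 2), 0.5)
dispatched = dispatch(tokens, topk_idx, num_experts=4)
# With identity "experts" and weights summing to 1, combine reproduces the inputs.
recombined = combine(dispatched, topk_idx, topk_weights, num_tokens=4, hidden=8)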

These communication steps can become a bottleneck: with $N$ GPUs, naive all-to-all routing tends toward $O(N^2)$ traffic, and if communication is not overlapped with compute, GPUs sit idle waiting for it.

To fix this, DeepEP relies on three simple tactics:

  • Route on the fastest links. Inside a node it uses NVLink; across nodes it uses RDMA. Each token touches at most four nodes to tame cross-node traffic (a toy simulation after this list illustrates the effect).
  • Pick the right kernel for the job. Training/prefill kernels mix NVLink and RDMA, saturate NVLink, and reach ~90% of 400 Gbps RDMA (≈43–47 GB/s). Inference uses RDMA‑only low‑latency kernels. As experts scale from 8 to 256, dispatch latency rises ~19%; combine stays ~318–369 μs with 80–92% RDMA use.
  • Spend most SMs on math. About 20 of 132 SMs handle communication; the rest stay free for compute.
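A toy Monte-Carlo comparison of unrestricted top-8 routing versus node-limited routing shows why the four-node cap matters (the 16-node, 16-experts-per-node layout is assumed only for illustration, and real routers are learned rather than random):

import random

NUM_NODES, EXPERTS_PER_NODE, TOPK = 16, 16, 8   # 256 routed experts, top-8 routing

def nodes_touched(expert_ids):
    return {e // EXPERTS_PER_NODE for e in expert_ids}

# Unrestricted: the top-8 experts can land on up to 8 different nodes.
free = [random.sample(range(NUM_NODES * EXPERTS_PER_NODE), TOPK) for _ in range(10_000)]
# Node-limited: first pick at most 4 nodes, then choose experts only within them.
limited = []
for _ in range(10_000):
    nodes = random.sample(range(NUM_NODES), 4)
    pool = [n * EXPERTS_PER_NODE + i for n in nodes for i in range(EXPERTS_PER_NODE)]
    limited.append(random.sample(pool, TOPK))

avg = lambda sel: sum(len(nodes_touched(s)) for s in sel) / len(sel)
print(f"avg destination nodes per token, unrestricted: {avg(free):.2f}")    # ~6.5
print(f"avg destination nodes per token, node-limited: {avg(limited):.2f}")  # <= 4

Fewer destination nodes per token means fewer (and larger) cross-node RDMA messages, which is exactly the traffic DeepEP tries to tame.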

They provide a PyTorch hook-based interface for overlapping communication and computation: the transfer is issued immediately, but the received tensors are only guaranteed to be ready once you explicitly call the returned hook, so the transfer proceeds in the background (without occupying any SMs) while other work runs.

import torch
from typing import Tuple

def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor,
                         num_max_dispatch_tokens_per_rank: int, num_experts: int):
    global _buffer

    # Do MoE dispatch, compatible with CUDA graph (but you may need to restore some buffer status once you replay)
    recv_hidden_states, recv_expert_count, handle, event, hook = \
        _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                     async_finish=False, return_recv_hook=True)
    # NOTES: the received tensor is not valid until you call `hook()`;
    # this is useful for double-batch overlapping and occupies no SMs.
    # If you don't want to overlap, set `return_recv_hook=False`.
    # Later, you can use their GEMM library (DeepGEMM) to do the computation with this specific format.
    return recv_hidden_states, recv_expert_count, handle, event, hook

def low_latency_combine(hidden_states: torch.Tensor,
                        topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):
    global _buffer
    # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
    combined_hidden_states, event_overlap, hook = \
        _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,
                                    async_finish=False, return_recv_hook=True)
    # NOTES: the same behavior as described in the dispatch kernel
    return combined_hidden_states, event_overlap, hook

Lastly, they introduce traffic isolation and adaptive routing to dynamically manage network paths and isolate critical communication flows. Traffic isolation segregates different workloads (normal kernels, low-latency kernels, and other workloads) across different virtual lanes to prevent interference. Adaptive routing dynamically distributes traffic across multiple network paths to avoid congestion.


Day 3: DeepGEMM - FP8 Matrix Multiplication

Next, we turn to a more fundamental operation: matrix multiplication. General Matrix Multiplication (GEMM) operations are the backbone of LLMs (or any NN, really).

DeepGEMM is a specialized library for FP8 matrix multiplication, written exclusively for Hopper GPUs. Despite its core kernel being only ~300 lines of CUDA code, it delivers impressive performance compared to expert-tuned kernels based on NVIDIA's CUTLASS library:

  • Up to 2.7× speedup for small or irregular matrices (particularly important for inference)
  • 1.0×-1.7× speedup for large, compute-bound matrices
  • 1.1×-1.2× speedup for MoE grouped GEMMs

Let's explore what makes this possible:

Fine-grained FP8 quantization

Used in the DeepSeek-V3 paper [1] and conceptually inspired by Rouhani et al. [1], DeepGEMM employs a fine-grained quantization strategy to address the limited dynamic range of FP8 precision, which often leads to overflow and underflow issues. Instead of applying a single global scaling factor, DeepGEMM applies scaling at a more granular level, significantly improving quantization accuracy and numerical stability.

DeepGEMM quantizes activations and weights using separate scaling factors at different granularities. Activations are quantized per token across groups of 128 channels (1 × 128 tiles), while weights are quantized per 128 input channels by 128 output channels (128 × 128 blocks).

Formally, given an input activation matrix $X \in \mathbb{R}^{M \times K}$ and a weight matrix $W \in \mathbb{R}^{K \times N}$, the quantization is performed as follows:

$$
\begin{align}
X_{m,(c:c+127)} &= s^X_{m,c} \cdot \text{FP8}\left(\frac{X_{m,(c:c+127)}}{s^X_{m,c}}\right), \\
W_{(c:c+127),(n:n+127)} &= s^W_{c,n} \cdot \text{FP8}\left(\frac{W_{(c:c+127),(n:n+127)}}{s^W_{c,n}}\right).
\end{align}
$$

Here, $s^X_{m,c}$ and $s^W_{c,n}$ are scaling factors stored in BF16 precision, corresponding to each activation tile and weight block, respectively.

The resulting FP8 GEMM operation is computed as:

$$
(XW)_{m,n} = \sum_{c=1}^{K} s^X_{m,c}\, s^W_{c,n} \cdot \text{FP8}\left(\frac{X_{m,c}}{s^X_{m,c}}\right) \cdot \text{FP8}\left(\frac{W_{c,n}}{s^W_{c,n}}\right)
$$

This fine-grained quantization approach improves numerical stability compared to naive FP8 quantization. First, by adapting scaling factors to smaller groups of elements, the method effectively accommodates activation outliers. Second, intermediate FP8 results are periodically promoted to FP32 precision on CUDA cores, mitigating accumulation errors.
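A reference implementation of the two scaling granularities is short in PyTorch (a sketch only: DeepGEMM does this in fused kernels, the dimensions are assumed to be multiples of 128, and 448 is the largest finite value of float8_e4m3fn):

import torch

def quantize_activations_1x128(x: torch.Tensor, group: int = 128):
    """Per-token scaling over groups of 128 channels (1 x 128 tiles)."""
    m, k = x.shape
    xg = x.view(m, k // group, group)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 448.0
    return (xg / scale).to(torch.float8_e4m3fn).view(m, k), scale.squeeze(-1)

def quantize_weights_128x128(w: torch.Tensor, block: int = 128):
    """Per-block scaling over 128 x 128 weight tiles."""
    k, n = w.shape
    wb = w.view(k // block, block, n // block, block)
    scale = wb.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4) / 448.0
    return (wb / scale).to(torch.float8_e4m3fn).view(k, n), scale.squeeze(1).squeeze(-1)

The matching GEMM then multiplies each FP8 tile product by $s^X_{m,c} s^W_{c,n}$ and accumulates in higher precision, as in the formula above.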

Mixture-of-Experts Support

DeepGEMM provides specialized support for Mixture-of-Experts (MoE) models, leveraging two optimized grouped GEMM implementations tailored to different stages of model usage. The first is a contiguous layout designed specifically for training and prefilling phases, where tokens assigned to each expert are concatenated for efficient processing. The second is a masked layout for inference scenarios, particularly when using CUDA graphs, allowing the model to process only valid tokens per expert efficiently.

According to evaluations presented in the DeepSeek-V3 paper [1], this specialized MoE support delivers a notable performance improvement, achieving a 1.1-1.2× speedup compared to expertly tuned CUTLASS implementations.

Unaligned Block Sizes

Unlike most GEMM libraries that use power-of-2 block sizes (128, 256), DeepGEMM supports unaligned block sizes like 112. This seemingly minor optimization dramatically improves GPU utilization for irregular matrix shapes common for inference. For example, with M=256, N=7168, this enables using 128 SMs instead of just 112.

The SM utilization can be calculated as:

$$
\text{SM Utilization} = \frac{\lceil M / \text{BLOCK\_M} \rceil \times \lceil N / \text{BLOCK\_N} \rceil}{\text{Total\_SMs}} \times 100\%
$$

With traditional aligned block sizes (BLOCK_M=128, BLOCK_N=128): $\frac{\lceil 256/128 \rceil \times \lceil 7168/128 \rceil}{132} \times 100\% = \frac{2 \times 56}{132} \times 100\% \approx 84.8\%$

With unaligned block sizes (BLOCK_M=128, BLOCK_N=112): $\frac{\lceil 256/128 \rceil \times \lceil 7168/112 \rceil}{132} \times 100\% = \frac{2 \times 64}{132} \times 100\% \approx 97.0\%$

This 12.2-percentage-point increase in SM utilization directly translates to performance gains.
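The same utilization arithmetic as code (132 SMs assumed, as on H100/H800-class parts):

from math import ceil

def sm_utilization(m: int, n: int, block_m: int, block_n: int, total_sms: int = 132) -> float:
    return ceil(m / block_m) * ceil(n / block_n) / total_sms * 100

print(sm_utilization(256, 7168, 128, 128))  # ~84.8 with aligned BLOCK_N=128 (112 blocks)
print(sm_utilization(256, 7168, 128, 112))  # ~97.0 with unaligned BLOCK_N=112 (128 blocks)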

Just-In-Time Compilation

DeepGEMM generates kernels tailored specifically to the dimensions of each matrix multiplication at runtime. Unlike traditional GEMM libraries, which pre-compile generic kernels designed to handle arbitrary matrix shapes and thus incur overhead from branches, loops, and additional logic, DeepGEMM's approach optimizes precisely for the given operation dimensions. By treating matrix dimensions (M, N, K) as compile-time constants, the compiler can fully unroll loops, eliminate boundary checks, and achieve optimal register allocation tailored exactly to the specific shape.

Additionally, DeepGEMM automatically determines the ideal parameters for each matrix shape, including optimal block dimensions—such as unaligned sizes like 112—to maximize GPU utilization, the best number of pipeline stages based on arithmetic intensity, and the most suitable Tensor Memory Accelerator (TMA) cluster size. This fine-grained optimization significantly enhances performance, particularly for small matrices commonly encountered during inference, where generic kernels typically waste computational cycles on unnecessary logic.
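The core idea of shape-specialized JIT can be illustrated with a toy template (not DeepGEMM's actual code generator): the matrix dimensions and block sizes are substituted into the kernel source as compile-time constants before compilation.

# Toy shape-specialized kernel generator: (M, N, K) and block sizes become
# constexpr values, so the compiler can unroll loops and drop bounds checks.
KERNEL_TEMPLATE = """
__global__ void gemm_kernel(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, float* C) {{
    constexpr int M = {M}, N = {N}, K = {K};
    constexpr int BLOCK_M = {BLOCK_M}, BLOCK_N = {BLOCK_N};
    // ... tiled loops over (M / BLOCK_M) x (N / BLOCK_N) with no runtime shape branches ...
}}
"""

def specialize(m: int, n: int, k: int, block_m: int = 128, block_n: int = 112) -> str:
    return KERNEL_TEMPLATE.format(M=m, N=n, K=k, BLOCK_M=block_m, BLOCK_N=block_n)

print(specialize(256, 7168, 4096))  # source for one (M, N, K)-specialized kernel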

Instruction-Level Optimization

At the lowest level, DeepGEMM optimizes CUDA instruction scheduling. Specifically, it modifies the yield and reuse bits of FFMA (Fused Multiply-Add) instructions directly within the compiled binary, a technique that enhances warp-level parallelism through precise instruction interleaving.

Each CUDA instruction includes control bits that influence scheduling behavior: the yield bit determines whether the Streaming Multiprocessor (SM) can yield the current warp after executing the instruction, while the reuse bit indicates whether registers can be immediately reused or must wait. By systematically adjusting these instruction bits based on patterns identified through binary analysis, DeepGEMM controls the execution schedule. This approach allows memory operations, such as loading matrix elements, to overlap seamlessly with compute operations performed by tensor cores, which in turn overlap with promotion operations involving CUDA core accumulation. Such precise timing is useful because GPU hardware contains separate units capable of executing these operations concurrently, but only if explicitly instructed to do so. The resulting synchronization and overlap of operations keep all GPU units actively engaged, eliminating idle waiting periods.


Day 4: Optimized Parallelism Strategies (DualPipe and EPLB)

Having optimized the core computational building blocks of LLMs, we now turn to the challenge of orchestrating these operations efficiently across multiple GPUs. Day 4 introduces two complementary techniques: DualPipe for efficient pipeline parallelism, and EPLB for balanced expert distribution.

DualPipe: Bidirectional Pipeline Parallelism

DualPipe introduces a bidirectional pipeline parallelism algorithm that achieves full overlap of forward and backward computation-communication phases, significantly reducing pipeline bubbles.

In traditional pipeline parallelism, you split a large NN across multiple GPUs, with each GPU handling a different segment of the model. The problem is that GPUs often end up waiting around, either for data from earlier stages during the forward pass or for gradients from later stages during backpropagation. These idle periods are what we call pipeline bubbles.

Example: Consider a model split across 2 GPUs, where GPU 2 can't start processing the forward pass of batch A until GPU 1 finishes its part. During backprop, the same issue happens in reverse—GPUs wait for gradients from downstream GPUs. These waiting periods create idle gaps, or "bubbles", in the pipeline. The more pipeline stages you have (like in very large models), the more bubbles you get, making scaling inefficient.

DualPipe tackles this issue by running micro-batches in both forward and backward directions simultaneously. This means that forward passes of new batches overlap with backward passes of previous batches. Effectively, it reduces idle time, by overlapping GPU-to-GPU communication with computation.

Figure: DualPipe scheduling with 8 pipeline-parallel ranks and 20 micro-batches running in two directions. Cells sharing a black border indicate overlapping computation and communication.

To overlap communication and computation, DualPipe breaks data transfers into smaller chunks ("micro-batch streaming"), allowing computation to start on the first chunks while later chunks are still transferring. By using multiple CUDA streams, communication and computation run asynchronously on separate GPU threads. Further, the backward pass gets split into two distinct parts: input-gradients calculate gradients to pass upstream, while weight-gradients calculates gradients for updating the current layer's parameters.

By separating these, input gradients can be sent upstream immediately, speeding up the backward pass in earlier pipeline stages.
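In PyTorch terms, the split corresponds to requesting the two gradient sets separately (a minimal sketch with a single linear layer, not DualPipe's implementation):

import torch

x = torch.randn(4, 16, requires_grad=True)   # activations arriving from the previous stage
w = torch.randn(16, 16, requires_grad=True)  # this stage's weights
y = x @ w
grad_out = torch.randn_like(y)               # gradient arriving from the next stage

# Input gradients: needed by the upstream stage immediately, so compute and send them first.
grad_x, = torch.autograd.grad(y, x, grad_out, retain_graph=True)
# Weight gradients: only needed for the local optimizer step, so they can be deferred
# and scheduled into otherwise-idle pipeline slots.
grad_w, = torch.autograd.grad(y, w, grad_out)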

Traditional pipelines require $F + B$ time per micro-batch (forward and backward run sequentially), while DualPipe approaches $\max(F, B)$ through parallel execution. The efficiency difference is substantial:

  • Traditional 1F1B: Bubble time = $(PP-1)(F+B)$
    • With $PP=8$: $7(F+B)$ time wasted
  • DualPipe: Bubble time = $(PP/2-1)(F\&B+B-3W)$
    • With $PP=8$: $3(F\&B+B-3W)$ time wasted

Where:

  • $PP$ denotes the number of pipeline-parallel stages
  • $F$ denotes the execution time of a forward chunk
  • $B$ denotes the execution time of a full backward chunk
  • $W$ denotes the execution time of a "backward for weights" chunk
  • $F\&B$ denotes the execution time of two mutually overlapped forward and backward chunks
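Plugging in illustrative timings (arbitrary units, chosen only to make the comparison concrete):

def bubble_1f1b(pp: int, f: float, b: float) -> float:
    return (pp - 1) * (f + b)

def bubble_dualpipe(pp: int, f_and_b: float, b: float, w: float) -> float:
    return (pp // 2 - 1) * (f_and_b + b - 3 * w)

# Example: PP = 8, F = B = 10, W = 3, F&B = 12.
print(bubble_1f1b(8, 10, 10))         # 140
print(bubble_dualpipe(8, 12, 10, 3))  # 3 * (12 + 10 - 9) = 39, roughly a 72% reduction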

DualPipe reduces pipeline bubbles by over 50% in many scenarios, directly improving training efficiency. However, these speedups require additional memory. While traditional 1F1B pipelines store activations for approximately $PP$ micro-batches in total, DualPipe needs to maintain activations for concurrent forward and backward passes.

The per-device activation storage increases from $PP$ to approximately $PP+1$ micro-batches. This trade-off works when training speed is the primary bottleneck and the memory increase remains manageable with distributed training.

DualPipe also offers a variant called DualPipeV, a streamlined V-shape schedule derived through a "cut-in-half" procedure originally introduced by Sea AI Lab. For the same number of pipeline stages, it halves the number of required devices from $PP$ to $PP/2$, with each device hosting two stages.

Figure: DualPipeV scheduling for 4 PP ranks (8 PP stages) and 10 micro-batches.

EPLB: Expert Parallelism Load Balancer

When using expert parallelism, one common challenge is evenly distributing the computational load across GPUs. The Expert Parallelism Load Balancer (EPLB) tackles this by duplicating heavily used experts.

EPLB provides two distinct strategies depending on the specific scenario:

  • Hierarchical load balancing: used when the number of server nodes evenly divides the number of expert groups. First spread the groups across nodes; inside each node, duplicate hot experts onto separate GPUs. This mirrors the group-limited expert routing of DeepSeek-V3 [1], which keeps a token's experts within a group to localize work, and keeping a group's experts on the same node cuts cross-node traffic. This works best during prefilling, when the expert-parallel group size is smaller.
  • Global load balancing: used when the hierarchical scheme isn't feasible. Duplicate experts regardless of group and spread the replicas across all GPUs (a toy sketch of this idea follows). This suits decoding, where the expert-parallel group size is larger.
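A toy version of the global strategy can be written as "replicate the hottest experts, then greedily pack replicas onto GPUs" (an illustration of the idea only, not EPLB's actual algorithm, which also handles the hierarchical, group-aware case):

import heapq

def replicate_and_pack(expert_load, num_replicas, num_gpus):
    load = dict(enumerate(expert_load))
    counts = {e: 1 for e in load}                 # replicas per expert (start with one each)
    for _ in range(num_replicas):
        # Always replicate whichever expert currently carries the most load per replica.
        hottest = max(load, key=lambda e: load[e] / counts[e])
        counts[hottest] += 1
    # Each replica carries an equal share of its expert's load; pack shares greedily.
    items = sorted(((load[e] / counts[e], e) for e in load for _ in range(counts[e])), reverse=True)
    heap = [(0.0, gpu, []) for gpu in range(num_gpus)]
    heapq.heapify(heap)
    for share, e in items:
        total, gpu, assigned = heapq.heappop(heap)   # least-loaded GPU so far
        heapq.heappush(heap, (total + share, gpu, assigned + [e]))
    return sorted(heap, key=lambda t: t[1])

for total, gpu, experts in replicate_and_pack([50, 30, 10, 10, 10, 10, 5, 5], 4, 4):
    print(f"GPU {gpu}: experts {experts}, load {total:.1f}")

Without replication, the single hottest expert (load 50) dominates one GPU; after replication its traffic is split across copies and the per-GPU loads end up much closer together.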

Here's an illustration showing how EPLB assigns experts across nodes and GPUs:

Figure: In this example, we have a two-layer MoE model, each layer containing 12 experts. By adding 4 redundant experts per layer, we end up with 16 expert replicas per layer. These replicas are distributed across 2 nodes, each with 4 GPUs, achieving balanced computational loads and minimizing communication overhead.


Day 5: 3FS and Smallpond - Data Infra

Fire-Flyer File System (3FS)

Traditional file systems like HDFS or Lustre are not well-suited for pretraining workflows because they rely on data locality, requiring compute nodes to be physically close to their data.

3FS addresses this via a fully disaggregated architecture. It aggregates the bandwidth and storage capacity of thousands of high-speed SSDs across hundreds of storage nodes, allowing compute nodes to access data quickly without worrying about exact data placement.

Figure: Read throughput over time on a 180-node 3FS cluster, peaking at ~7 TiB/s.

This exceptional performance makes 3FS ideal for high-throughput data access during training and inference.

Another key advantage of 3FS is its strong consistency model, powered by Chain Replication with Apportioned Queries (CRAQ). Unlike traditional chain replication, which bottlenecks reads at the tail node, CRAQ lets any node in the chain serve reads directly. Nodes maintain both fully committed ("clean") and in-progress ("dirty") data versions. Clean data can be served instantly, while dirty data triggers a quick version check with the tail node. This design ensures high throughput and strong consistency—essential for AI workloads where data inconsistencies can cause subtle bugs or degrade model performance.

Performance-wise, 3FS is really impressive:

  • Achieves around 6.6 TiB/s total read throughput on a 180-node cluster.
  • Provides up to 40 GiB/s per client node for KVCache lookups, ideal for inference workloads.

Figure: KVCache read throughput performance showing peak reads reaching up to 40 GiB/s per client node (dotted line) and average reads (solid line) over a 30-minute period. The consistent high peak performance demonstrates 3FS's capability to handle intensive inference workloads efficiently.

3FS is particularly useful for:

  • Dataloaders: Efficient random access to training samples across multiple nodes, simplifying data loading logic.
  • Checkpointing: Fast, parallel checkpointing essential for training large models.
  • KV Caching for Inference: Cost-effective alternative to DRAM-based caching, offering high throughput and larger storage capacity. This approach enables compute nodes to fetch cached key-value pairs on-the-fly from 3FS with minimal latency penalty, eliminating the need to store all data in RAM.

Suppose you are training a language model on a 10 TB text dataset. Without 3FS, you might have to shard this dataset across local disks on each training machine and ensure each GPU gets data from its local storage to avoid network slowdown. With 3FS, you can put all 10 TB into the file system and have every node read from it at will:

  • On each training machine (say you have 16 of them), you mount 3FS at /data.
  • In your training code, you load files from /data/my_corpus/shardX.txt (where X indexes a shard of the dataset). Each DataLoader worker opens a file via 3FS and reads a chunk of data; 3FS internally fetches that chunk from the appropriate storage server (possibly several in parallel if the file's chunks span servers).
  • If each server can deliver, say, 50 GB/s and you have 10 servers, an aggregate of 500 GB/s is available, so each of the 16 nodes could in theory pull more than 30 GB/s if needed. In practice, each node might only need 5-10 GB/s to keep its GPUs saturated, which 3FS provides easily.
  • Training proceeds without ever being starved for data. GPUs are busy 100% with compute rather than sometimes waiting for CPU/disk loading.

At checkpoint time, the trainer on each node saves its model partition to a file in /data/checkpoints/epoch10_rank7.ckpt. All nodes do this simultaneously. 3FS directs each write to perhaps different storage targets (to spread out load) and writes happen in parallel. The result is that you might checkpoint in, say, 30 seconds what used to take 5 minutes on an NFS.

Smallpond

Smallpond complements 3FS by offering a streamlined, distributed data processing framework built on DuckDB, an open-source, in-process analytical database optimized for complex queries.

If you've used frameworks like Apache Spark or Hadoop, you'll appreciate Smallpond's simplicity. It doesn't require managing long-running services or complex dependencies. Instead, it provides a high-level API built on Ray, making distributed data processing easy:

import smallpond

sp = smallpond.init()                               # initialize the smallpond session
df = sp.read_parquet("path/to/dataset/*.parquet")   # read a partitioned Parquet dataset
df = df.repartition(10)                             # split the work into 10 partitions
df = df.map("x + 1")                                # apply a simple transformation expression
df.write_parquet("path/to/output")                  # write the results back out

Day 6: DeepSeek-V3/R1 - Large-scale Cross-node Expert Parallelism and How to Operate at 545% Profit Margin

On the final day, DeepSeek published details about their inference system for DeepSeek-V3/R1, an unusually detailed look into a production inference system.

Figure: Architecture diagram of DeepSeek's online inference system.

  • Only 8 of the 256 routed experts fire per token in each MoE layer, so expert parallelism matters.
    • Prefill: Routed Expert EP32 + MLA/Shared Expert DP32 across 4 nodes; each GPU hosts 9 routed and 1 shared expert.
    • Decode: Routed Expert EP144 + MLA/Shared Expert DP144 across 18 nodes; each GPU hosts 2 routed experts.
  • FP8 for feed-forward matrix multiplications and dispatch; BF16 for MLA and combine.
  • Dual-batch overlap hides cross-node EP latency: split each batch into two microbatches and run them alternately so compute of one hides the other's communication.

In the decoding phase, the execution durations of different stages become inherently unbalanced. To address this, DeepSeek subdivides the attention layer into two distinct computational steps and integrates them into a carefully orchestrated five-stage pipeline to, again, overlap computation and communication.

DeepSeek also employs specialized load balancing strategies: the Prefill Load Balancer equalizes token counts and computational demands, the Decode Load Balancer manages disparities in KVCache usage, and the Expert-Parallel Load Balancer duplicates heavily loaded experts across GPUs to mitigate imbalance.

DeepSeek's inference system dynamically scales resources based on daily load patterns, maximizing GPU utilization during peak hours and reallocating resources during off-peak periods. This approach significantly enhances cost efficiency and performance. In a 24-hour statistical period, the combined peak node occupancy for V3 and R1 inference services reached 278 nodes, with an average occupancy of 226.75 nodes (each containing 8 H800 GPUs). Assuming the leasing cost of one H800 GPU is $2 per hour, the total daily cost amounts to $87,072.

Within a typical 24-hour period, the system efficiently processes substantial workloads:

  • Total input tokens: 608B, of which 342B tokens (56.3%) hit the on-disk KV cache
  • Total output tokens: 168B
  • Average output speed: 20-22 tokens per second
  • Average KVCache length per output token: 4,989 tokens
  • Average prefilling throughput per H800 node: ~73.7k tokens/s (including cache hits)
  • Average decoding throughput per H800 node: ~14.8k tokens/s

The punchline is that this lets DeepSeek achieve a cost profit margin of 545% while being roughly 96% cheaper than OpenAI's o1:

  • For input tokens with cache hit, R1 costs $0.14 per million
  • For input tokens with cache miss, R1 costs $0.55 per million, while o1 costs $15.00 per million
  • For output tokens, R1 costs $2.19 per million, while o1 costs $60.00 per million
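These numbers are enough to roughly reconstruct the headline margin (the token counts are the rounded figures above, so the result lands near, not exactly on, the reported 545%):

input_hit_b, input_miss_b, output_b = 342, 608 - 342, 168   # billions of tokens per day
price_per_m = {"hit": 0.14, "miss": 0.55, "out": 2.19}      # USD per million tokens (R1 pricing)

revenue = (input_hit_b * 1e3 * price_per_m["hit"]
           + input_miss_b * 1e3 * price_per_m["miss"]
           + output_b * 1e3 * price_per_m["out"])           # billions of tokens -> millions
cost = 226.75 * 8 * 2 * 24                                  # avg nodes * GPUs * $/GPU-hour * hours

print(f"revenue ~${revenue:,.0f}/day, cost ~${cost:,.0f}/day, margin ~{revenue / cost - 1:.0%}")
# revenue ~$562,100/day, cost ~$87,072/day, margin ~546%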

However, DeepSeek notes that their actual revenue is substantially lower than these theoretical calculations because:

  • DeepSeek-V3's pricing is significantly lower than R1
  • Only a subset of services are monetized (web and APP access remain free)
  • Automatic nighttime discounts are applied during off-peak hours