pytorch/executorch



ExecuTorch

On-device AI inference powered by PyTorch


ExecuTorch is PyTorch's unified solution for deploying AI models on-device, from smartphones to microcontrollers, built for privacy, performance, and portability. It powers Meta's on-device AI across Instagram, WhatsApp, Quest 3, Ray-Ban Meta Smart Glasses, and more.

Deploy LLMs, vision, speech, and multimodal models with the same PyTorch APIs you already know, shortening the path from research to production with seamless model export, optimization, and deployment. No manual C++ rewrites. No format conversions. No vendor lock-in.


Why ExecuTorch?

  • 🔒 Native PyTorch Export: Direct export from PyTorch, with no .onnx, .tflite, or other intermediate format conversions. Model semantics are preserved.
  • ⚡ Production-Proven: Serves real-time on-device inference to billions of users at Meta.
  • 💾 Tiny Runtime: 50KB base footprint. Runs on everything from microcontrollers to high-end smartphones.
  • 🚀 12+ Hardware Backends: Open-source acceleration for Apple, Qualcomm, ARM, MediaTek, Vulkan, and more.
  • 🎯 One Export, Multiple Backends: Switch hardware targets with a single-line change and deploy the same model everywhere.

How It Works

ExecuTorch uses ahead-of-time (AOT) compilation to prepare PyTorch models for edge deployment:

  1. 🧩 Export: Capture your PyTorch model graph with torch.export().
  2. ⚙️ Compile: Quantize, optimize, and partition the graph to hardware backends, producing a .pte file.
  3. 🚀 Execute: Load the .pte file on-device with the lightweight C++ runtime.

Exported models are expressed in a standardized Core ATen operator set. Partitioners delegate supported subgraphs to specialized hardware (NPU/GPU); any remaining operators fall back to the CPU.
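To make the delegation idea concrete, here is a toy sketch of partitioning. It assumes the model has been flattened into a linear list of Core ATen operator names and that a hypothetical backend supports a fixed set of them. Real ExecuTorch partitioners operate on the exported graph and apply each backend's own support rules; the `partition` function, `Partition` class, and `SUPPORTED` set below are illustrative names only.

```python
from dataclasses import dataclass, field


@dataclass
class Partition:
    target: str                      # "delegate" (accelerator) or "cpu" (fallback)
    ops: list = field(default_factory=list)


def partition(ops, supported):
    """Split a linear op sequence into maximal runs, delegating runs of
    supported ops and leaving the rest for CPU fallback (toy model)."""
    partitions = []
    for op in ops:
        target = "delegate" if op in supported else "cpu"
        if partitions and partitions[-1].target == target:
            partitions[-1].ops.append(op)  # extend the current run
        else:
            partitions.append(Partition(target, [op]))  # start a new run
    return partitions


# Hypothetical backend that supports convolution, ReLU, and add, but not topk.
SUPPORTED = {"aten.convolution.default", "aten.relu.default", "aten.add.Tensor"}
ops = ["aten.convolution.default", "aten.relu.default", "aten.topk.default", "aten.add.Tensor"]
parts = partition(ops, SUPPORTED)
```

In this toy model each delegated run would become one backend subgraph, and every boundary between runs implies a hand-off between the accelerator and the CPU, which is one reason it pays to keep delegated regions large.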

Learn more: How ExecuTorch Works • Architecture Guide

Quick Start

Installation

pip install executorch

For platform-specific setup (Android, iOS, embedded systems), see the Quick Start documentation.

Export and Deploy in 3 Steps

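import torch
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Stand-in model for illustration; replace with your own torch.nn.Module
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))

# 1. Export your PyTorch model
model = MyModel().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)
exported_program = torch.export.export(model, example_inputs)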

# 2. Optimize for target hardware (switch backends with one line)
program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],  # XNNPACK (CPU); use CoreMLPartitioner() for iOS or QnnPartitioner() for Qualcomm
).to_executorch()

# 3. Save for deployment
with open("model.pte", "wb") as f:
    f.write(program.buffer)

# Test locally via ExecuTorch runtime's pybind API (optional)
from executorch.runtime import Runtime
runtime = Runtime.get()
method = runtime.load_program("model.pte").load_method("forward")
outputs = method.execute([torch.randn(1, 3, 224, 224)])

Run on Device

C++

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using namespace executorch::extension;

Module module("model.pte");
auto tensor = make_tensor_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
auto outputs = module.forward({tensor});

Swift (iOS)

import ExecuTorch

let module = Module(filePath: "model.pte")
let input = Tensor<Float>([1.0, 2.0, 3.0, 4.0])
let outputs: [Value] = try module.forward([input])

Kotlin (Android)

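import org.pytorch.executorch.EValue
import org.pytorch.executorch.Module
import org.pytorch.executorch.Tensor

val module = Module.load("model.pte")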
val inputTensor = Tensor.fromBlob(floatArrayOf(1.0f, 2.0f, 3.0f, 4.0f), longArrayOf(2, 2))
val outputs = module.forward(EValue.from(inputTensor))
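The C++ and Kotlin snippets pair the flat buffer 1.0, 2.0, 3.0, 4.0 with the shape 2x2. Assuming the default contiguous (row-major) layout, element [i][j] sits at flat position i * 2 + j. The helpers below (`contiguous_strides` and `element`, both hypothetical names, not part of any ExecuTorch API) sketch that mapping in Python.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, counted in elements."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dim backwards: each stride is the
    # product of all dimension sizes to its right.
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def element(data, shape, index):
    """Read one element from a flat row-major buffer by N-d index."""
    if len(index) != len(shape) or any(not 0 <= i < n for i, n in zip(index, shape)):
        raise IndexError(f"index {index} out of range for shape {shape}")
    strides = contiguous_strides(shape)
    return data[sum(i * s for i, s in zip(index, strides))]


# Same buffer and shape as the C++ and Kotlin snippets.
data = [1.0, 2.0, 3.0, 4.0]
shape = [2, 2]
```

The same rule explains the Quick Start input: a (1, 3, 224, 224) image tensor is one flat buffer of 150,528 floats, with the last (width) dimension varying fastest.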

LLM Example: Llama

Export Llama models using the export_llm script or Optimum-ExecuTorch:

# Using export_llm
python -m executorch.extension.llm.export.export_llm --model llama3_2 --output llama.pte

# Using Optimum-ExecuTorch
optimum-cli export executorch \
  --model meta-llama/Llama-3.2-1B \
  --task text-generation \
  --recipe xnnpack \
  --output_dir llama_model

Run on-device with the LLM runner API:

C++

#include <executorch/extension/llm/runner/text_llm_runner.h>

auto runner = create_llama_runner("llama.pte", "tiktoken.bin");
executorch::extension::llm::GenerationConfig config{
    .seq_len = 128, .temperature = 0.8f};
runner->generate("Hello, how are you?", config);

Swift (iOS)

let runner = TextRunner(modelPath: "llama.pte", tokenizerPath: "tiktoken.bin")
try runner.generate("Hello, how are you?", Config {
    $0.sequenceLength = 128
}) { token in
    print(token, terminator: "")
}

Kotlin (Android): API Docs • Demo App

val llmModule = LlmModule("llama.pte", "tiktoken.bin", 0.8f)
llmModule.load()
llmModule.generate("Hello, how are you?", 128, object : LlmCallback {
    override fun onResult(result: String) { print(result) }
    override fun onStats(stats: String) { }
})
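The runner snippets pass a sampling temperature (0.8 in the C++ and Kotlin examples). As a rough sketch of what that knob does, assuming standard temperature-scaled softmax sampling: logits are divided by the temperature before the softmax, so values below 1.0 sharpen the distribution and values above 1.0 flatten it. The functions below are hypothetical helpers, not the runner's own sampler, and treating a temperature of 0 as greedy argmax decoding is a convention of this sketch.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Convert raw logits to probabilities after dividing by `temperature`."""
    if temperature <= 0:
        raise ValueError("temperature must be positive; use greedy decoding for 0")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def sample_with_temperature(logits, temperature, rng):
    """Pick the next token id; temperature 0 means greedy argmax in this sketch."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


# Hypothetical logits for a 3-token vocabulary.
logits = [2.0, 1.0, 0.0]
next_token = sample_with_temperature(logits, 0.8, random.Random(0))
```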

For multimodal models (vision, audio), use the MultiModal runner API, which extends the LLM runner to accept image and audio inputs alongside text. See the Llava and Voxtral examples.

See examples/models/llama for the complete workflow, including quantization, mobile deployment, and advanced options.


Platform & Hardware Support

Supported backends by platform:

  • Android: XNNPACK, Vulkan, Qualcomm, MediaTek, Samsung Exynos
  • iOS: XNNPACK, MPS, CoreML (Neural Engine)
  • Linux / Windows: XNNPACK, OpenVINO, CUDA (experimental)
  • macOS: XNNPACK, MPS, Metal (experimental)
  • Embedded / MCU: XNNPACK, ARM Ethos-U, NXP, Cadence DSP

See Backend Documentation for detailed hardware requirements and optimization guides.

Production Deployments

ExecuTorch powers on-device AI at scale across Meta's family of apps, VR/AR devices, and partner deployments. View success stories →

Examples & Models

LLMs: Llama 3.2/3.1/3, Qwen 3, Phi-4-mini, LiquidAI LFM2

Multimodal: Llava (vision-language), Voxtral (audio-language)

Vision/Speech: MobileNetV2, DeepLabV3, Whisper

Resources: examples/ directory • executorch-examples out-of-tree demos • Optimum-ExecuTorch for Hugging Face models

Key Features

ExecuTorch provides advanced capabilities for production deployment:

  • Quantization: Built-in support via torchao for 8-bit, 4-bit, and dynamic quantization
  • Memory Planning: Optimize memory usage with ahead-of-time allocation strategies
  • Developer Tools: ETDump profiler, ETRecord inspector, and model debugger
  • Selective Build: Strip unused operators to minimize binary size
  • Custom Operators: Extend with domain-specific kernels
  • Dynamic Shapes: Support variable input sizes with bounded ranges
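Ahead-of-time memory planning can be pictured as packing intermediate tensors into one preallocated arena so that tensors whose lifetimes do not overlap share the same bytes. The sketch below uses a generic greedy first-fit heuristic (largest tensors first) over assumed lifetimes measured in operator indices. It illustrates the idea only and is not ExecuTorch's memory-planning pass; `TensorLife` and `plan_offsets` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorLife:
    name: str
    size: int       # bytes needed by the tensor
    first_use: int  # index of the operator that produces it
    last_use: int   # index of the last operator that reads it (inclusive)


def lifetimes_overlap(a, b):
    """True if two tensors are alive during at least one common operator."""
    return not (a.last_use < b.first_use or b.last_use < a.first_use)


def plan_offsets(tensors):
    """Greedy first-fit: place the largest tensors first, each at the lowest
    arena offset that does not collide with a live, already-placed tensor."""
    placed = []   # list of (offset, TensorLife)
    offsets = {}
    for t in sorted(tensors, key=lambda t: t.size, reverse=True):
        # Byte ranges that are occupied while `t` is alive, sorted by start.
        busy = sorted(
            (off, off + p.size) for off, p in placed if lifetimes_overlap(p, t)
        )
        offset = 0
        for start, end in busy:
            if offset + t.size <= start:
                break  # `t` fits in the gap before this occupied range
            offset = max(offset, end)
        offsets[t.name] = offset
        placed.append((offset, t))
    arena_size = max((offsets[t.name] + t.size for t in tensors), default=0)
    return offsets, arena_size


# Hypothetical three-operator chain: a feeds b, b feeds c.
tensors = [
    TensorLife("a", size=100, first_use=0, last_use=1),
    TensorLife("b", size=50, first_use=1, last_use=2),
    TensorLife("c", size=100, first_use=2, last_use=3),
]
offsets, arena_size = plan_offsets(tensors)
```

Here `c` reuses the bytes of `a`, because `a` is no longer needed once `c` is produced, so the arena needs 150 bytes instead of the 250 a naive one-buffer-per-tensor layout would use.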

See Advanced Topics for quantization techniques, custom backends, and compiler passes.
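For intuition about 8-bit quantization, here is the arithmetic of a generic per-tensor asymmetric (affine) int8 scheme: a float x maps to q = clamp(round(x / scale) + zero_point, -128, 127) and back to (q - zero_point) * scale. This is a textbook sketch, not torchao's implementation; torchao's schemes (for example, 4-bit weight quantization) choose their own granularity and parameters. The function names are hypothetical.

```python
def choose_qparams(x_min, x_max, qmin=-128, qmax=127):
    """Pick scale and zero_point so [x_min, x_max] spans the int8 range."""
    # Widen the range to include 0.0 so zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # all-zero input: any positive scale works
    zero_point = round(qmin - x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    # Python's round() is round-half-to-even; results are clamped to int8.
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(qvalues, scale, zero_point):
    return [(q - zero_point) * scale for q in qvalues]


values = [-1.0, 0.0, 0.5, 2.0]
scale, zero_point = choose_qparams(min(values), max(values))
q = quantize(values, scale, zero_point)
restored = dequantize(q, scale, zero_point)
```

For in-range values, the round-trip error is at most half a quantization step (scale / 2), which is the accuracy cost traded for 4x smaller storage than float32.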


Community & Contributing

We welcome contributions from the community!

  • 💬 GitHub Discussions: Ask questions and share ideas
  • 🎮 Discord: Chat with the team and community
  • 🐛 Issues: Report bugs or request features
  • 🤝 Contributing Guide: Guidelines and codebase structure

License

ExecuTorch is BSD licensed, as found in the LICENSE file.




Part of the PyTorch ecosystem

GitHub • Documentation