torch.fx.experimental#
Created On: Feb 07, 2024 | Last Updated On: Jan 09, 2026
Warning
These APIs are experimental and subject to change without notice.
- class torch.fx.experimental.sym_node.DynamicInt(val)[source]#
User API for marking dynamic integers in torch.compile. Intended to be compatible with both compile and eager mode.
Example usage:

    fn = torch.compile(f)
    x = DynamicInt(4)
    fn(x)  # compiles x as a dynamic integer input; returns f(4)
torch.fx.experimental.sym_node#
torch.fx.experimental.symbolic_shapes#
- Controls how to perform symbol allocation for a dimension.
- For clients: the size at this dimension must be within 'vr' (which specifies a lower and upper bound, inclusive-inclusive) AND it must be non-negative and should not be 0 or 1 (but see NB below).
- For clients: no explicit constraint; the constraint is whatever is implicitly inferred by guards from tracing.
- Represents and decides various kinds of equality constraints between input sources.
- Data structure specifying how we should create symbols in
- Create symbols in
- Create symbols in
- The correct symbolic context for a given inner tensor of a traceable tensor subclass may differ from that of the outer symbolic context.
- Custom solver for a system of constraints on symbolic dimensions.
- Encapsulates all shape env settings that could potentially affect FakeTensor dispatch.
- This class is used in multi-graph compilation contexts where we generate multiple specialized graphs and dispatch to the appropriate one at runtime.
- Retrieve the hint for an int (based on the underlying real values as observed at runtime).
- Utility to check if the underlying object in a SymInt is a concrete value.
- Utility to check if the underlying object in a SymBool is a concrete value.
- Utility to check if the underlying object in a SymFloat is a concrete value.
- Faster version of bool(free_symbols(val)).
- Faster version of bool(free_unbacked_symbols(val)).
- Try to guard a; if a data-dependent error is encountered, just return True.
- Try to guard a; if a data-dependent error is encountered, just return False.
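The two guard-or-fallback entries above can be sketched in plain Python. Everything here (`DataDependentError`, `guard_or`, and the callable-based "evaluation") is a hypothetical stand-in for the real machinery in torch.fx.experimental.symbolic_shapes, which operates on symbolic expressions rather than callables:

```python
# Hypothetical sketch of the "guard or fallback" pattern: try to
# evaluate a boolean; if evaluation turns out to be data-dependent
# (undecidable without a runtime value), return a chosen default.

class DataDependentError(Exception):
    """Stand-in for a data-dependent evaluation failure."""

def guard_or(expr, fallback):
    # expr is a callable standing in for symbolic evaluation.
    try:
        return bool(expr())
    except DataDependentError:
        return fallback

def guard_or_true(expr):
    return guard_or(expr, True)

def guard_or_false(expr):
    return guard_or(expr, False)

# A decidable expression returns its actual value:
print(guard_or_true(lambda: 3 > 2))   # True

# An undecidable one returns the chosen fallback:
def unbacked():
    raise DataDependentError

print(guard_or_false(unbacked))       # False
```

The asymmetry between the two helpers lets callers pick which outcome is the safe default when a guard cannot be resolved.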
- Perform a guard on a symbolic boolean expression in a size-oblivious way.
- and, but for symbolic expressions, without bool casting.
- Like ==, but when run on a list/tuple, it will recursively test equality and use sym_and to join the results together, without guarding.
- or, but for symbolic expressions, without bool casting.
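The recursive equality described above can be sketched in plain Python; `sym_eq_sketch` is a hypothetical name, and Python's `and` stands in for sym_and (which would join results without bool-casting symbolic values):

```python
# Sketch of recursive structural equality: compare lists/tuples
# elementwise and join the per-element results with an "and"
# (modeled here by Python's `and`; the real helper uses sym_and
# so no guard is introduced on symbolic elements).

def sym_eq_sketch(a, b):
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        if len(a) != len(b):
            return False
        result = True
        for x, y in zip(a, b):
            result = result and sym_eq_sketch(x, y)  # stand-in for sym_and
        return result
    return a == b

print(sym_eq_sketch((1, 2, (3, 4)), (1, 2, (3, 4))))  # True
print(sym_eq_sketch((1, 2), (1, 3)))                  # False
```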
- Applies a constraint that the passed-in SymInt must lie between min and max, inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning that it can be used on unbacked SymInts).
- Given two SymInts, constrain them so that they must be equal.
- Canonicalize a boolean expression by transforming it into a lt / le inequality and moving all the non-constant terms to the rhs.
- Returns True if x can be simplified to a constant and is True.
- Returns True if x can be simplified to a constant and is False.
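The last two entries describe a "statically known" check: answer True only when the expression folds to the matching constant, and False whenever it still depends on a symbol. A pure-Python sketch of that semantics, using a toy expression tree in place of real symbolic expressions (all names here are hypothetical):

```python
# Toy expressions: bools are constants, strings are symbolic
# variables, and tuples are (op, *args) nodes.

def try_fold(expr, env):
    """Return the constant value of expr, or None if it depends on
    an unknown (symbolic) variable."""
    if isinstance(expr, bool):
        return expr
    if isinstance(expr, str):                # a symbolic variable
        return env.get(expr)                 # None when unknown
    op, *args = expr
    vals = [try_fold(a, env) for a in args]
    if any(v is None for v in vals):
        return None                          # not statically known
    if op == "and":
        return all(vals)
    if op == "not":
        return not vals[0]
    raise ValueError(op)

def statically_known_true(expr, env=None):
    return try_fold(expr, env or {}) is True

def statically_known_false(expr, env=None):
    return try_fold(expr, env or {}) is False

print(statically_known_true(("and", True, True)))   # True
print(statically_known_true(("and", True, "s0")))   # False: depends on s0
print(statically_known_false(("not", True)))        # True
```

Note that "returns False" does not mean the expression is false, only that it could not be decided statically.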
- User-code friendly utility to check if a value is static or dynamic.
- Test that two "meta" values (typically either Tensor or SymInt) have the same values, e.g., after retracing.
- After having run fake tensor propagation and produced an example_value result, traverse example_value looking for freshly bound unbacked symbols and record their paths for later.
- Suppose we are retracing a pre-existing FX graph that previously had fake tensor propagation (and therefore unbacked SymInts).
- When we do fake tensor prop, we oftentimes will allocate new unbacked symints.
- Helper function to determine if a node is trying to access a symbolic integer, such as size, stride, offset, or item.
- Converts a SymBool or bool to a SymInt or int without introducing guards.
- Find all nodes in an FX graph that bind sympy Symbols.
- Recursively collect all free symbols from a value.
- Like free_symbols, but filtered to only report unbacked symbols.
- Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float.
- Check if a given FX node is a symbol binding node.
torch.fx.experimental.proxy_tensor#
- Given a function f, return a new function which, when executed with valid arguments to f, returns an FX GraphModule representing the set of operations that were executed during the course of execution.
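Assuming this entry corresponds to torch.fx.experimental.proxy_tensor.make_fx, a minimal usage sketch:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x * 2 + 1

# Tracing f with an example input produces an FX GraphModule that
# records the operations executed during the call.
gm = make_fx(f)(torch.randn(3))
print(gm.graph)           # the recorded aten-level operations

out = gm(torch.ones(3))   # running the traced graph
```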
- Call into the currently active proxy tracing mode to do a SymInt/SymFloat/SymBool dispatch trace on a function that operates on these arguments.
- Return the currently active proxy tracing mode, or None if we are not currently tracing.
- Within this context manager, if you are doing make_fx tracing, we will thunkify all SymNode compute and avoid tracing it into the graph unless it is actually needed.
- Within a context, disable thunkification.
- Delays computation of f until it is called again; also caches the result.
- FX gets confused by varargs; de-confuse it.
torch.fx.experimental.optimization#
- Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
- For each node, if it's a module that can be preconverted into MKLDNN, then we do so and create a mapping to allow us to convert from the MKLDNN version of the module to the original.
- Performs a set of optimization passes to optimize a model for the purposes of inference.
- Removes all dropout layers from the module.
- Maps each module that's been changed with modules_to_mkldnn back to its original.
- This is a heuristic that can be passed into optimize_for_inference; it determines whether a subgraph should be run in MKL by checking if there are more than 2 nodes in it.
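Assuming the dropout-removal entry above corresponds to torch.fx.experimental.optimization.remove_dropout, a short sketch of its effect: once dropout is gone, train-mode and eval-mode outputs coincide.

```python
import torch
import torch.nn as nn
from torch.fx.experimental.optimization import remove_dropout

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.ReLU())
optimized = remove_dropout(model)   # FX-traces and strips Dropout nodes

x = torch.randn(2, 4)
y_train = optimized.train()(x)      # dropout removed: no randomness
y_eval = model.eval()(x)            # original with dropout disabled
# y_train and y_eval are identical
```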
torch.fx.experimental.recording#
torch.fx.experimental.unification.core#
- Replace variables of expression with substitution.

      >>> x, y = var(), var()
      >>> e = (1, x, (3, y))
      >>> s = {x: 2, y: 4}
      >>> reify(e, s)
      (1, 2, (3, 4))
      >>> e = {1: x, 3: (y, 5)}
      >>> reify(e, s)
      {1: 2, 3: (4, 5)}
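The substitution shown in the doctest can be sketched in pure Python; `Var` and `reify_sketch` are hypothetical stand-ins for the logic variables and reify in torch.fx.experimental.unification:

```python
# Walk a nested expression (tuples and dicts) and substitute
# logic variables using the mapping s.

class Var:
    """Stand-in for a unification logic variable."""

def reify_sketch(e, s):
    if isinstance(e, Var):
        # Substitute, then keep reifying in case s maps to another var.
        return reify_sketch(s[e], s) if e in s else e
    if isinstance(e, tuple):
        return tuple(reify_sketch(x, s) for x in e)
    if isinstance(e, dict):
        return {k: reify_sketch(v, s) for k, v in e.items()}
    return e

x, y = Var(), Var()
s = {x: 2, y: 4}
print(reify_sketch((1, x, (3, y)), s))     # (1, 2, (3, 4))
print(reify_sketch({1: x, 3: (y, 5)}, s))  # {1: 2, 3: (4, 5)}
```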
torch.fx.experimental.unification.unification_tools#
- Return a new dict with a new key-value pair.
- Return a new dict with a new, potentially nested, key-value pair.
- Return a new dict with the given key(s) removed.
- The first element in a sequence.
- Filter items in a dictionary by key.
- Apply a function to the keys of a dictionary.
- Merge a collection of dictionaries.
- Merge dictionaries and apply a function to combined values.
- Update a value in a (potentially) nested dictionary.
- Filter items in a dictionary by value.
- Apply a function to the values of a dictionary.
- Filter items in a dictionary by item.
- Apply a function to the items of a dictionary.
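These helpers mirror the familiar toolz-style dictionary utilities. A short sketch, assuming the toolz-compatible names and signatures (assoc, dissoc, merge, merge_with, valmap):

```python
from torch.fx.experimental.unification.unification_tools import (
    assoc, dissoc, merge, merge_with, valmap,
)

d = {"a": 1, "b": 2}
print(assoc(d, "c", 3))                      # new dict with pair added
print(dissoc(d, "b"))                        # new dict with key removed
print(merge({"a": 1}, {"b": 2}))             # combined dict
print(merge_with(sum, {"a": 1}, {"a": 10}))  # colliding values summed
print(valmap(lambda v: v * 10, d))           # function applied to values
# All helpers are functional: the input dict is never mutated.
print(d)                                     # {'a': 1, 'b': 2}
```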
torch.fx.experimental.migrate_gradual_types.transform_to_z3#
- Transforms an algebraic expression to z3 format; expr is either a dimension variable or an algebraic expression.
- Given a trace, generates constraints and transforms them to z3 format.
- Takes a node and a graph and generates two sets of constraints.
- Takes a dimension variable or a number and transforms it to a tuple according to our scheme. Parameters: dimension, the dimension to be transformed; counter, for variable tracking.
- Transforms tensor variables to a format understood by z3. Parameter: tensor, a tensor variable or a tensor type, potentially with variable dimensions.
- Given an IR and a node representing a conditional, evaluate the conditional and its negation. Parameters: tracer_root, the tracer root for module instances; node, the node to be evaluated.
torch.fx.experimental.migrate_gradual_types.constraint#
torch.fx.experimental.migrate_gradual_types.constraint_generator#
- Constraints that match the input to a size-3 tensor and switch the dimensions according to the rules of batch multiplication.
- The output shape differs from the input shape in the last dimension.
- We generate the constraint: input = output.
- We generate the exact constraints as we do for tensor addition, but we constrain the rank of this expression to be equal to len(n.args[1:]) so that only those cases get considered for the output.
- Similar to addition.
- Translates to inconsistent in gradual types.
- If the tensor is a scalar, we will skip it since we do not support scalars yet.
- We generate the constraint: input = output.
- Similar to reshape but with an extra condition on the strides.
- Can be considered as a sequence of two index-selects, so we generate constraints accordingly.
torch.fx.experimental.migrate_gradual_types.constraint_transformation#
- We are considering the possibility where one input has fewer dimensions than another input, so we apply padding to the broadcasted results.
- Generates constraints for the last two dimensions of a convolution or maxpool output. Parameters: constraint, a CalcConv or CalcMaxPool; d, the list of output dimensions.
- Create equality constraints for when no broadcasting occurs. Parameters: e1 and e2, the inputs; e11 and e12, the broadcasted inputs; d1 and d2, variables that store dimensions for e1 and e2; d11 and d12, variables that store dimensions for e11 and e12.
- Generate constraints to check if the target dimensions are divisible by the input dimensions. Parameters: target, the target dimensions; dim, the input dimensions.
- Transforms a constraint into a simpler constraint.
- Generate an equality of the form t = [a1, ..., an], then generate constraints that check if the given index is valid given this particular tensor size.
- When the index is a tuple, the output will be a tensor. TODO: we have to check if this is the case for all HF models.
- The constraints consider the given tensor size, check if the index is valid and, if so, generate a constraint for replacing the input dimension with the required dimension.
- Similar to a sequence of two index-selects.
- Given a list of dimensions, checks if an index is valid in the list.
- If the slice instances exceed the length of the dimensions, then this is a type error, so we return False.
- Generate constraints to check if the input dimensions are divisible by the target dimensions. Parameters: target, the target dimensions; dim, the input dimensions.
torch.fx.experimental.graph_gradual_typechecker#
- The input and output sizes should be the same except for the last two dimensions taken from the input, which represent width and height.
- For operations where the input shape is equal to the output shape.
- Given a BatchNorm2d instance and a node, check the following conditions:
  - the input type can be expanded to a size-4 tensor: t = (x_1, x_2, x_3, x_4)
  - the current node type can be expanded to a size-4 tensor: t' = (x_1', x_2', x_3', x_4')
  - t is consistent with t'
  - x_2 is consistent with the module's num_features
  - x_2' is consistent with the module's num_features
  Output type: the more precise type of t and t'.
- For calculating h_out and w_out according to the Conv2d documentation.
- The equality constraints are between the first dimensions of the input and output.
- Represents the output in terms of an algebraic expression w.r.t. the input when possible.
- For element-wise operations; handles broadcasting.
- Expand a type to the desired tensor dimension if possible; raise an error otherwise.
- For operations where the first two dimensions of the input and output shape are equal.
- We check that the dimensions for the transpose operation are within range of the tensor type of the node.
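The h_out/w_out calculation referenced above follows the output-size formula from the torch.nn.Conv2d documentation, sketched here as a plain function (the function name is illustrative):

```python
import math

def conv2d_out_dim(size, kernel, stride=1, padding=0, dilation=1):
    # h_out = floor((h_in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1)
    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

# A 3x3 kernel with stride 1 and padding 1 preserves a 32-pixel side:
print(conv2d_out_dim(32, 3, stride=1, padding=1))  # 32
# Stride 2 halves it:
print(conv2d_out_dim(32, 3, stride=2, padding=1))  # 16
```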