Source code for executorch.exir.lowered_backend_module
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import operator
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program
from executorch.exir.graph_module import _get_submodule
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program
from executorch.exir.tracer import Value
from torch._export.exported_program import ExportedProgram
from torch._subclasses import FakeTensor
from torch.export.exported_program import (
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
    erase_nodes,
    fuse_as_graphmodule,
    insert_subgm,
    legalize_graph,
    NodeList,
    topo_sort,
)
class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blob created from backend.preprocess
    # A list of backend-specific objects with static metadata to configure the
    # "compilation" process.
    _compile_specs: List[CompileSpec]
    _original_module: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        super().__init__()
        self._original_module = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs

    @property
    def backend_id(self) -> str:
        """
        Returns the backend's name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess.
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to
        configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module.
        """
        return self._original_module

    # TODO(chenlai): consolidate the serialization config with the serialize_to_flatbuffer API
    def buffer(
        self,
        extract_segments: bool = False,
        segment_alignment: int = 4096,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.
        """
        out = _serialize_pte_binary(
            program=self.program(),
            extract_segments=extract_segments,
            segment_alignment=segment_alignment,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        )
        return out
    # TODO(chenlai): reconsider recapture instead of manually constructing the
    # program, because the metadata construction is done by hand here.
    def program(self, emit_stacktrace: bool = False) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        # Creates a new module based on the original module. The original module will
        # look something like the following:
        #
        # opcode         name                 target            args                                        kwargs
        # -------------  -------------------  ----------------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1            ()                                          {}
        # placeholder    arg1_1               arg1_1            ()                                          {}
        # call_function  aten_repeat_default  *                 (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *                 (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *                 (arg1_1, arg1_1)                            {}
        # output         output               output            ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # If the whole module is lowered, the resulting lowered module looks like:
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We remove all call_function nodes, insert a call_delegate node, insert
        # getitem nodes to collect the results of the call_delegate node, and
        # return the list of getitems as the output.
        lowered_exported_program = copy.deepcopy(self.original_module)

        # The real input nodes are the placeholders that are neither buffers nor parameters
        all_input_nodes = [
            node
            for node in lowered_exported_program.graph.nodes
            if (
                node.op == "placeholder"
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_buffers
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_parameters
            )
        ]

        output_node = [
            node for node in lowered_exported_program.graph.nodes if node.op == "output"
        ]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Clean up the graph before inserting the call_delegate node.

        # Remove the original output node
        lowered_exported_program.graph.erase_node(output_node[0])

        # Remove everything else except the inputs
        for node in reversed(lowered_exported_program.graph.nodes):
            if node.op != "placeholder":
                lowered_exported_program.graph.erase_node(node)

        # Find placeholders that are parameters or buffers and remove them from the main graph
        for node in lowered_exported_program.graph.nodes:
            if node.op == "placeholder" and (
                node.name in lowered_exported_program.graph_signature.inputs_to_buffers
                or node.name
                in lowered_exported_program.graph_signature.inputs_to_parameters
            ):
                lowered_exported_program.graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module into the graph module as an attribute
        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)

        # Insert a call_delegate node into the graph module, with arguments from the arg list
        delegate_node = lowered_exported_program.graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )
        # Get the output list. Since the output node is a tuple containing a list,
        # like ([aten_mul_tensor, aten_add_tensor],), we add some handling logic
        # to extract the list `[aten_mul_tensor, aten_add_tensor]` properly.
        original_output_nodes = [
            node for node in self.original_module.graph.nodes if node.op == "output"
        ][0].args[0]

        delegate_node.meta["spec"] = tuple(
            [make_spec(node.meta["val"]) for node in original_output_nodes]
        )

        # The getitem nodes that are going to be inserted into the lowered graph module
        getitem_nodes = []
        for i in range(len(original_output_nodes)):
            getitem_node = lowered_exported_program.graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_nodes.append(getitem_node)
        lowered_exported_program.graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        lowered_exported_program.graph.lint()

        # The user outputs will be the getitem nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data is consumed by the delegate, so it should be removed from the state dict.
        inputs_to_parameters = (
            lowered_exported_program.graph_signature.inputs_to_parameters
        )
        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in lowered_exported_program.graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double check that the ExportedProgram data (especially everything except the graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=lowered_exported_program.graph,
            graph_signature=ExportGraphSignature(
                input_specs=input_specs, output_specs=output_specs
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # None because all data are consumed by delegate
            range_constraints=lowered_exported_program.range_constraints,
            equality_constraints=lowered_exported_program.equality_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
        )
        exported_program = exported_program._transform(
            SpecPropPass(), MemoryPlanningPass("greedy")
        )
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program
    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        return executorch_call_delegate(self, *args)
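A `LoweredBackendModule` is normally produced by `to_backend` rather than constructed directly. Below is a minimal usage sketch, assuming the demo backend `BackendWithCompilerDemo` that ships with ExecuTorch's tests is registered; the exact capture API spelling varies across ExecuTorch versions, so treat this as illustrative rather than authoritative:

import torch
import executorch.exir as exir
from executorch.exir.backend.backend_api import to_backend

class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

# Capture the eager module into an EXIR edge program (API spelling is an
# assumption; adjust to the installed ExecuTorch version).
edge_ir = exir.capture(SinModule(), (torch.ones(1),)).to_edge()

# Lower the whole program to the (demo) backend. The result is a
# LoweredBackendModule wrapping the delegate blob produced by
# backend.preprocess.
lowered = to_backend("BackendWithCompilerDemo", edge_ir.exported_program, [])

print(lowered.backend_id)             # "BackendWithCompilerDemo"
print(type(lowered.processed_bytes))  # bytes

# program() builds the ExecuTorch program; buffer() serializes it so the
# ExecuTorch runtime can load it.
with open("sin.pte", "wb") as f:
    f.write(lowered.buffer())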
# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            with gm.graph.inserting_before(node):
                assert len(node.args) == 1
                outputs = node.args[0]
                if isinstance(outputs, torch.fx.Node):
                    val = outputs.meta.get("val")
                    if isinstance(val, list):
                        # If a list is returned, in some cases it is represented as a
                        # single node, like `split_copy_tensor`, but EXIR will return an
                        # opened-up list like `[getitem1, getitem2]`
                        outputs = [
                            torch.fx.Proxy(outputs)[i].node for i in range(len(val))
                        ]
            returns, out_spec = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return


def arrange_graph_placeholders(
    gm: torch.fx.GraphModule, owning_program: ExportedProgram
) -> torch.fx.GraphModule:
    """
    Modifies the graph of the given graph module with one that contains the same
    nodes as the original, but with placeholders ordered as (Params + Buffers)
    followed by (User Inputs).

    This is used by the delegate API, which disturbs the placeholder ordering when
    creating a submodule from partitioned nodes.

    Args:
        gm: The graph module that we want arranged
        owning_program: ExportedProgram that the submodule (gm) belongs to

    Returns:
        The graph module, arranged in place
    """
    new_graph = torch.fx.Graph()

    node_map = {}  # mapping of nodes from old graph to new graph

    graph_sign = owning_program.graph_signature

    # Add all placeholders into the graph first:
    param_nodes = []
    buffer_nodes = []
    input_nodes = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name in graph_sign.inputs_to_parameters:
            param_nodes.append(node)
        elif node.name in graph_sign.inputs_to_buffers:
            buffer_nodes.append(node)
        else:
            input_nodes.append(node)

    for param_node in param_nodes:
        new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
        node_map[param_node] = new_node
    for buffer_node in buffer_nodes:
        new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
        node_map[buffer_node] = new_node
    for input_node in input_nodes:
        new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
        node_map[input_node] = new_node

    # Now add all the other nodes in order
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue

        new_node = new_graph.node_copy(node, lambda x: node_map[x])
        node_map[node] = new_node

    # lint to ensure correctness
    new_graph.lint()

    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph

    return gm
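As a concrete illustration of the reordering (all names here are hypothetical; `sub_gm` stands for a partitioned submodule of `owning_program`):

# Sketch: a partitioner may emit placeholders interleaved as
# (user input, parameter, buffer); rearranging restores the canonical order.
before = [n.name for n in sub_gm.graph.nodes if n.op == "placeholder"]
# e.g. before == ["x", "linear_weight", "running_mean"]

sub_gm = arrange_graph_placeholders(sub_gm, owning_program)

after = [n.name for n in sub_gm.graph.nodes if n.op == "placeholder"]
# e.g. after == ["linear_weight", "running_mean", "x"]
# Parameters come first, then buffers, then user inputs; all other nodes
# keep their original relative order.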
# TODO Don't regenerate the new signature manually.
def _get_new_signature(
    original_program: ExportedProgram, gm: torch.fx.GraphModule
) -> Tuple[ExportGraphSignature, Dict[str, Union[torch.Tensor, torch.nn.Parameter]]]:
    old_signature = original_program.graph_signature

    input_specs = []
    output_specs = []
    new_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    new_state_dict = {}

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.name in old_signature.inputs_to_parameters:
                parameter_name = old_signature.inputs_to_parameters[node.name]
                # add the param to the graph signature
                input_specs.append(
                    InputSpec(
                        kind=InputKind.PARAMETER,
                        arg=TensorArgument(name=node.name),
                        target=parameter_name,
                    )
                )

                # add the param to the state dict
                new_state_dict[parameter_name] = original_program.state_dict[
                    parameter_name
                ]
            elif node.name in old_signature.inputs_to_buffers:
                buffer_name = old_signature.inputs_to_buffers[node.name]
                # add the buffer to the graph signature
                input_specs.append(
                    InputSpec(
                        kind=InputKind.BUFFER,
                        arg=TensorArgument(name=node.name),
                        target=buffer_name,
                    )
                )

                # add the buffer to new_state_dict
                new_state_dict[buffer_name] = original_program.state_dict[buffer_name]
            else:
                # not a param or buffer, so it is a user input
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )

        if node.op == "output":
            for output in node.all_input_nodes:
                output_specs.append(
                    OutputSpec(
                        kind=OutputKind.USER_OUTPUT,
                        arg=TensorArgument(name=output.name),
                        target=None,
                    )
                )

    return new_signature, new_state_dict


def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
) -> ExportedProgram:
    """
    Creates an ExportedProgram from the given submodule using the parameters and
    buffers from the top-level owning program.

    Args:
        submodule: submodule to create an exported program from
        owning_program: exported program containing the parameters and buffers used
            within the submodule

    Returns:
        The ExportedProgram created from submodule
    """
    # Arrange the submodule's placeholders in order
    submodule = arrange_graph_placeholders(submodule, owning_program)

    # Get the updated graph signature
    subgraph_signature, subgraph_state_dict = _get_new_signature(
        owning_program, submodule
    )

    return ExportedProgram(
        root=submodule,
        graph=submodule.graph,
        graph_signature=subgraph_signature,
        state_dict=subgraph_state_dict,
        range_constraints=copy.deepcopy(owning_program.range_constraints),
        equality_constraints=[],
        module_call_graph=[],
    )
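A short sketch of how this pairs with partitioning; the names `gm`, `tagged_nodes`, and `owning_program` are hypothetical, and `create_submodule_from_nodes` is defined just below:

# Carve the tagged nodes out of the owning program's graph module into a
# submodule, then re-wrap that submodule as a standalone ExportedProgram
# that carries only the parameters/buffers it actually uses.
sub_gm, call_node = create_submodule_from_nodes(gm, tagged_nodes, tag="backend_0")
sub_program = create_exported_program_from_submodule(sub_gm, owning_program)

# Only the weights referenced by the partition survive in the new state dict:
assert set(sub_program.state_dict) <= set(owning_program.state_dict)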
def create_submodule_from_nodes(
    gm: torch.fx.GraphModule,
    node_list: NodeList,
    tag: str,
    skip_legalize_graph: bool = False,
) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
    """
    Modifies the given graph module in-place to separate out the given nodes
    into a submodule. The given node_list should form a fully connected
    subgraph.

    Args:
        gm: The graph module that we want to partition
        node_list: A list of nodes that belong in the partition
        tag: Used to name the created submodule ("fused_" + tag)
        skip_legalize_graph: If True, skips the topological re-sort of the
            modified graph

    Returns:
        The submodule that has been partitioned, and the call_module node in the
        toplevel graph module calling the submodule
    """
    sorted_nodes = topo_sort(node_list)

    submodule_name = "fused_" + tag
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
        gm, sorted_nodes, submodule_name
    )

    _fixup_output_node(sub_gm)

    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module":
            if node.target == submodule_name:
                submodule_node = node
            else:
                raise RuntimeError(
                    f"The submodule created with nodes {node_list} did not form "
                    f"one fully contained subgraph. Check that these nodes form a "
                    f"fully contained graph. Partitioned graph: {gm.graph}."
                )

    if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
        # If the original output is a single tensor, it has been
        # pytree.tree_flatten-ed to be a singleton list, so we want to replace
        # all uses with a getitem call to the 0th index of the result
        with gm.graph.inserting_after(submodule_node):
            proxy_out = torch.fx.Proxy(submodule_node)[0].node  # type: ignore[index]
            submodule_node.replace_all_uses_with(proxy_out)
            proxy_out.meta["val"] = submodule_node.meta["val"]
            # Reset the args since they were overwritten in the previous line
            proxy_out.args = (submodule_node, 0)
    else:
        # fuse_as_graphmodule will automatically propagate the metadata of the
        # partition's last node to the getitem nodes that appear after the
        # call_module node. However, in the case of delegation we do not want
        # these getitem nodes to contain irrelevant previous metadata
        # (ex. source_fn, nn_module_stack)
        for user_node in submodule_node.users:
            user_node.meta.pop("nn_module_stack", None)
            user_node.meta.pop("source_fn_stack", None)

    erase_nodes(gm, sorted_nodes)

    # Topologically sort the original gm with the newly created sub_gm
    # TODO: T153794167 Get rid of support for skipping legalize_graph in
    # create_submodule_from_nodes once we transition to using fuse_by_partitions.
    if not skip_legalize_graph:
        legalize_graph(gm)

    # Get the call_module node
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == submodule_name:
            submodule_node = node
        elif node.op == "call_module":
            raise RuntimeError(
                f"The submodule created with nodes {node_list} did not form "
                f"one fully contained subgraph. Check that these nodes form a "
                f"fully contained graph. Partitioned graph: {gm.graph}."
            )

    assert (
        submodule_node is not None
    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"

    return sub_gm, submodule_node


def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Returns a list of lowered modules that are in the given graph (does not look
    into submodules). Specifically, the returned value is a list of tuples of
    (name of the lowered module stored in the graph module, the lowered module
    itself, and the fx node that calls this lowered module).
    """
    lowered_submodules = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            name, module, node = _get_submodule(graph_module, node, 0)
            assert isinstance(module, LoweredBackendModule)
            lowered_submodules.append((name, module, node))
    return lowered_submodules


def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns a list of the LoweredBackendModules in the given graph, i.e. the
    programs that were lowered by backend delegates.
    """
    lowered_programs = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            lowered_backend_module = getattr(graph_module, node.args[0].name)
            lowered_programs.append(lowered_backend_module)

    return lowered_programs
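Once `to_backend` has replaced partitions with `executorch_call_delegate` calls, the two helpers above can be used to inspect the result. A minimal sketch, assuming `graph_module` is the top-level graph module of a partially delegated program:

# Enumerate each delegated submodule along with the node that invokes it.
for name, lowered_module, call_node in get_lowered_submodules(graph_module):
    print(
        f"{name}: backend={lowered_module.backend_id}, "
        f"blob={len(lowered_module.processed_bytes)} bytes, "
        f"called by {call_node.name}"
    )

# Or, if only the LoweredBackendModule objects themselves are needed:
delegates = get_lowered_backend_modules(graph_module)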