[TorchFX] Bias correction implementation #2882

Merged · 7 commits · Aug 30, 2024
15 changes: 15 additions & 0 deletions nncf/experimental/torch/fx/groups.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nncf.torch.graph.operator_metatypes as om
from nncf.torch.model_graph_manager import OPERATORS_WITH_BIAS_METATYPES

FX_OPERATORS_WITH_BIAS_METATYPES = tuple(OPERATORS_WITH_BIAS_METATYPES) + (om.PTLinearMetatype,)
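For orientation, a minimal sketch of how a metatype group like this is typically consumed: filtering is a plain tuple membership check against NNCFGraph nodes. This usage is illustrative, not part of the diff.

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFGraph
from nncf.experimental.torch.fx.groups import FX_OPERATORS_WITH_BIAS_METATYPES


def bias_eligible_nodes(nncf_graph: NNCFGraph):
    # Keep only the nodes whose metatype may carry a (possibly unfused) bias.
    return [n for n in nncf_graph.get_all_nodes() if n.metatype in FX_OPERATORS_WITH_BIAS_METATYPES]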
117 changes: 80 additions & 37 deletions nncf/experimental/torch/fx/model_transformer.py
@@ -10,11 +10,10 @@
# limitations under the License.

from collections import defaultdict
from typing import List
from typing import List, Set

import torch
import torch.fx
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
@@ -26,6 +25,8 @@
class FXModelTransformer(ModelTransformer):
"""
Applies transformations to a Torch FX model.
FXApplyTransformationCommands are applied in place;
PTModelExtractionCommands do not modify the input model.
"""

def __init__(self, model: torch.fx.GraphModule):
@@ -61,6 +62,31 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G
model.recompile()
return model

@staticmethod
def _traverse_graph(
input_nodes: List[torch.fx.Node],
stop_nodes: Set[str],
visited: Set[str],
) -> None:
"""
Traverses the graph starting from the given input nodes and stopping
at the stop nodes and at already visited nodes. As a result, the visited
container is updated with the names of all nodes reached during the traversal.

:param input_nodes: Nodes to start the traversal from.
:param stop_nodes: Names of the nodes at which the traversal stops.
:param visited: Set of the names of already visited nodes.
"""

while input_nodes:
in_node = input_nodes.pop()
if in_node.name in visited or in_node.name in stop_nodes:
continue

visited.add(in_node.name)
input_nodes.extend(in_node.all_input_nodes)
input_nodes.extend(list(in_node.users))
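A toy illustration of the traversal, calling the private helper directly purely for demonstration; node names such as "add" depend on how torch.fx names traced ops.

import torch
import torch.fx

from nncf.experimental.torch.fx.model_transformer import FXModelTransformer


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


gm = torch.fx.symbolic_trace(ToyModel())
start = [n for n in gm.graph.nodes if n.name == "add"]
visited = set()
# Walk in both directions from "add", stopping at the graph output.
FXModelTransformer._traverse_graph(start, stop_nodes={"output"}, visited=visited)
# visited now holds the names of all reached nodes, e.g. {"x", "add", "relu"}.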

@staticmethod
def _apply_model_extraction(
model: torch.fx.GraphModule,
@@ -75,46 +101,63 @@ def _apply_model_extraction(
more than one element, this function raises an assertion error.
:return: Returns a submodel extracted from the given model by the given transformation.
"""

transformation = transformations[-1]
assert len(transformation.input_node_names) == 1
assert transformation.input_node_names == transformation.output_node_names
node_name = transformation.input_node_names[0]
stop_nodes = set(transformation.input_node_names + transformation.output_node_names)
visited = set()

for node_name in transformation.input_node_names:
node = get_graph_node_by_name(model.graph, node_name)
visited.add(node.name)
target_inputs = node.all_input_nodes[1:]
if node.name not in transformation.output_node_names:
target_inputs += list(node.users)
FXModelTransformer._traverse_graph(target_inputs, stop_nodes, visited)

for node_name in transformation.output_node_names:
node = get_graph_node_by_name(model.graph, node_name)
visited.add(node.name)
if node.name not in transformation.input_node_names:
FXModelTransformer._traverse_graph(node.all_input_nodes, stop_nodes, visited)

extracted_graph = torch.fx.Graph()
value_remap = {}

def remap_fn(node: torch.fx.Node):
return value_remap.get(node) # noqa F821

tags = ["before", "extracted", "after"]
i = 0
for node in model.graph.nodes:
if node.name == node_name:
node.tag = tags[1]
weights = [node.all_input_nodes[1]]
while weights:
w_node = weights.pop()
assert w_node.tag in tags[0:2]
w_node.tag = tags[1]
weights.extend(w_node.all_input_nodes)
i = 2
if node.name not in visited or node.op == "output":
continue
node.tag = tags[i]

# TODO(dlyakhov): reduce memory consumption with
# a more optimal splitting implementation.
splitted_gm = split_by_tags(model, tags)

extracted_model = splitted_gm.extracted
graph: torch.fx.Graph = extracted_model.graph
# Check that the extracted model has inputs.
# It is possible to have two constant inputs
# for the target layer; in that case a placeholder
# is placed on the input port.
target_node = get_graph_node_by_name(graph, node_name)
input_node = target_node.all_input_nodes[0]
if input_node.op != "placeholder":
with graph.inserting_before(target_node):
new_input_node = graph.create_node(
"placeholder", "placeholder_node", (), {}, name="placeholder_graph_node"
value_remap[node] = extracted_graph.node_copy(node, remap_fn)
del value_remap

for input_name in transformation.input_node_names:
node_with_input = get_graph_node_by_name(extracted_graph, input_name)
with extracted_graph.inserting_before(node_with_input):
graph_input_name = input_name + "_input"
graph_input = extracted_graph.create_node(
op="placeholder",
target=graph_input_name,
name=graph_input_name,
)
target_node.replace_input_with(input_node, new_input_node)
extracted_model.graph.eliminate_dead_code()
return extracted_model

args = list(node_with_input.args)
args[0] = graph_input
node_with_input.args = tuple(args)

nodes_with_output = [get_graph_node_by_name(extracted_graph, name) for name in transformation.output_node_names]
last_node = list(extracted_graph.nodes)[-1]
with extracted_graph.inserting_after(last_node):
graph_output_name = "output"
extracted_graph.create_node(
"output",
graph_output_name,
(tuple(nodes_with_output),),
name=graph_output_name,
)

return torch.fx.GraphModule(model, extracted_graph)
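A hedged usage sketch of the rewritten extraction path: register a PTModelExtractionCommand whose input and output lists name the boundary nodes, then let the transformer copy the enclosed subgraph into a fresh GraphModule. "conv2d" is a placeholder node name, and gm stands for a previously captured torch.fx.GraphModule.

from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout

layout = PTTransformationLayout()
layout.register(PTModelExtractionCommand(["conv2d"], ["conv2d"]))
# The result is a new GraphModule with a placeholder per extraction input
# and a single output node; the input model is left unchanged.
submodel = FXModelTransformer(gm).transform(layout)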

@staticmethod
def _apply_transformation(
76 changes: 76 additions & 0 deletions nncf/experimental/torch/fx/model_utils.py
@@ -0,0 +1,76 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque

import torch.fx

from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.transformations import node_removal_transformation_builder
from nncf.torch.graph.operator_metatypes import QUANTIZE_NODE_TYPES
from nncf.torch.graph.transformations.commands import PTTargetPoint


def remove_fq_from_inputs(model: torch.fx.GraphModule, graph: NNCFGraph) -> torch.fx.GraphModule:
"""
This method removes the activation Fake Quantize nodes from the model.
It is needed for the subsequent bias shift calculation, which relies on quantized weights.

:param model: torch.fx.GraphModule instance.
:param graph: NNCFGraph instance.
:return: torch.fx.GraphModule instance without activation Fake Quantize nodes.
"""
transformation_layout = TransformationLayout()
model_transformer = ModelTransformerFactory.create(model)

seen_nodes = []
nodes_queue = deque(graph.get_input_nodes())
while nodes_queue:
current_node = nodes_queue.popleft()
current_node_name = current_node.node_name

if current_node_name in seen_nodes:
continue

seen_nodes.append(current_node_name)
if current_node.node_type in QUANTIZE_NODE_TYPES:
transformation = node_removal_transformation_builder(current_node, input_port_id=0)
transformation_layout.register(FXApplyTransformationCommand(transformation))
nodes_queue.extend(graph.get_next_nodes(current_node))

return model_transformer.transform(transformation_layout)
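Roughly how this helper slots into bias correction (a sketch; GraphConverter is assumed to be this backend's NNCFGraph builder, and quantized_gm a quantized GraphModule):

from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

nncf_graph = GraphConverter.create_nncf_graph(quantized_gm)
# Drop activation FQs so the bias shift is computed against quantized weights only.
fq_free_gm = remove_fq_from_inputs(quantized_gm, nncf_graph)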


_TARGET_TYPE_TO_FX_INS_TYPE_MAP = {
TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK,
TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
}


def get_target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
"""
Creates torch-specific target point.

:param target_type: Target point target type.
:param target_node_name: Target node name to use in the target point.
:param port_id: Target port id.
:return: Torch-specific target point.
"""
if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:
port_id = None
if target_type in _TARGET_TYPE_TO_FX_INS_TYPE_MAP:
target_type = _TARGET_TYPE_TO_FX_INS_TYPE_MAP[target_type]
return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)
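Illustrative calls (the node name "linear_0" is hypothetical): pre/post layer target types are remapped to the corresponding hook types, and post-layer points drop the port id.

from nncf.common.graph.transformations.commands import TargetType

tp = get_target_point(TargetType.PRE_LAYER_OPERATION, "linear_0", port_id=0)
# -> PTTargetPoint(OPERATOR_PRE_HOOK, "linear_0", input_port_id=0)
tp = get_target_point(TargetType.POST_LAYER_OPERATION, "linear_0", port_id=0)
# -> PTTargetPoint(OPERATOR_POST_HOOK, "linear_0", input_port_id=None)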
37 changes: 37 additions & 0 deletions nncf/experimental/torch/fx/node_utils.py
@@ -11,6 +11,12 @@

import torch.fx

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.experimental.torch.fx.groups import FX_OPERATORS_WITH_BIAS_METATYPES
from nncf.tensor import Tensor


# TODO(dlyakhov): Use torch.fx.graph.find_nodes method instead after
# torch version update (>= 2.4)
@@ -49,3 +55,34 @@ def get_tensor_constant_from_node(constant_node: torch.fx.Node, model: torch.fx.
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr


def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
"""
Returns True if the node has a bias, False otherwise.

:param node: Target node.
:param nncf_graph: Target nncf_graph.
:return: True if the node has a bias, False otherwise.
"""
# Assumes that all biases were unfused
if node.metatype in FX_OPERATORS_WITH_BIAS_METATYPES:
next_nodes = nncf_graph.get_next_nodes(node)
if len(next_nodes) != 1:
return False
return next_nodes[0].metatype in (om.PTAddMetatype,)
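Since biases are assumed to be unfused, the matched pattern is a bias-eligible op with exactly one consumer, where that consumer is the bias add. A restated sketch of the same check, for clarity only:

import nncf.torch.graph.operator_metatypes as om
from nncf.experimental.torch.fx.groups import FX_OPERATORS_WITH_BIAS_METATYPES


def has_unfused_bias(node, nncf_graph):
    # Pattern: op (conv/linear/...) -> add(op_out, bias_constant)
    consumers = nncf_graph.get_next_nodes(node)
    return (
        node.metatype in FX_OPERATORS_WITH_BIAS_METATYPES
        and len(consumers) == 1
        and consumers[0].metatype in (om.PTAddMetatype,)
    )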


def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
"""
Retrieves the bias value from the given node.

:param node: Target node.
:param nncf_graph: Target nncf_graph.
:param model: Target GraphModule.
:return: Bias value of the given node.
"""
bias_node = nncf_graph.get_next_nodes(node)[0]
# TODO(dlyakhov): make a node_name_vs_node map to speed up the process
graph_bias_node = get_graph_node_by_name(model.graph, bias_node.node_name)
return Tensor(get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))
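Putting the two helpers together (a sketch; gm and nncf_graph are assumed to be the captured GraphModule and its NNCFGraph):

for node in nncf_graph.get_all_nodes():
    if is_node_with_bias(node, nncf_graph):
        bias = get_bias_value(node, nncf_graph, gm)  # returns an nncf.tensor.Tensor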
2 changes: 0 additions & 2 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
@@ -58,8 +58,6 @@ def quantize_impl(
" Torch FX PTQ is an experimental feature, consider using Torch or OpenVino PTQ backends"
" in case of errors or a poor model performance."
)
if fast_bias_correction is False:
raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported")
if target_device == TargetDevice.CPU_SPR:
raise nncf.InternalError("target_device == CPU_SPR is not supported")
if mode is not None:
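With this guard removed, TorchFX quantization no longer rejects fast_bias_correction=False, so the full BiasCorrection algorithm this PR wires up can be requested. A hedged sketch; the capture step, model, example_input, and calibration_dataset are placeholders, not from the diff.

import nncf
import torch

captured = torch.export.export(model, (example_input,)).module()  # capture step is an assumption
quantized = nncf.quantize(
    captured,
    calibration_dataset,
    fast_bias_correction=False,  # full bias correction instead of the fast variant
)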