Skip to content

Commit c0ce883

Browse files
authored
FEAT: Attack Identifier (#1364)
1 parent 4039c2d commit c0ce883

83 files changed

Lines changed: 1091 additions & 496 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ API Reference
272272
:toctree: _autosummary/
273273

274274
class_name_to_snake_case
275+
AttackIdentifier
275276
ConverterIdentifier
276277
Identifiable
277278
Identifier

pyrit/analytics/result_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats
6262
raise TypeError(f"Expected AttackResult, got {type(attack).__name__}: {attack!r}")
6363

6464
outcome = attack.outcome
65-
attack_type = attack.attack_identifier.get("type", "unknown")
65+
attack_type = attack.attack_identifier.class_name if attack.attack_identifier else "unknown"
6666

6767
if outcome == AttackOutcome.SUCCESS:
6868
overall_counts["successes"] += 1

pyrit/exceptions/exception_context.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from contextvars import ContextVar
1414
from dataclasses import dataclass, field
1515
from enum import Enum
16-
from typing import Any, Dict, Optional, Union
16+
from typing import Any, Optional
1717

18-
from pyrit.identifiers import Identifier
18+
from pyrit.identifiers import AttackIdentifier, Identifier
1919

2020

2121
class ComponentRole(Enum):
@@ -61,11 +61,11 @@ class ExecutionContext:
6161
# The attack strategy class name (e.g., "PromptSendingAttack")
6262
attack_strategy_name: Optional[str] = None
6363

64-
# The identifier from the attack strategy's get_identifier()
65-
attack_identifier: Optional[Dict[str, Any]] = None
64+
# The identifier for the attack strategy
65+
attack_identifier: Optional[AttackIdentifier] = None
6666

6767
# The identifier from the component's get_identifier() (target, scorer, etc.)
68-
component_identifier: Optional[Dict[str, Any]] = None
68+
component_identifier: Optional[Identifier] = None
6969

7070
# The objective target conversation ID if available
7171
objective_target_conversation_id: Optional[str] = None
@@ -192,8 +192,8 @@ def execution_context(
192192
*,
193193
component_role: ComponentRole,
194194
attack_strategy_name: Optional[str] = None,
195-
attack_identifier: Optional[Dict[str, Any]] = None,
196-
component_identifier: Optional[Union[Identifier, Dict[str, Any]]] = None,
195+
attack_identifier: Optional[AttackIdentifier] = None,
196+
component_identifier: Optional[Identifier] = None,
197197
objective_target_conversation_id: Optional[str] = None,
198198
objective: Optional[str] = None,
199199
) -> ExecutionContextManager:
@@ -203,9 +203,8 @@ def execution_context(
203203
Args:
204204
component_role: The role of the component being executed.
205205
attack_strategy_name: The name of the attack strategy class.
206-
attack_identifier: The identifier from attack.get_identifier().
206+
attack_identifier: The attack identifier.
207207
component_identifier: The identifier from component.get_identifier().
208-
Can be an Identifier object or a dict (legacy format).
209208
objective_target_conversation_id: The objective target conversation ID if available.
210209
objective: The attack objective if available.
211210
@@ -215,22 +214,15 @@ def execution_context(
215214
# Extract endpoint and component_name from component_identifier if available
216215
endpoint = None
217216
component_name = None
218-
component_id_dict: Optional[Dict[str, Any]] = None
219217
if component_identifier:
220-
if isinstance(component_identifier, Identifier):
221-
endpoint = getattr(component_identifier, "endpoint", None)
222-
component_name = component_identifier.class_name
223-
component_id_dict = component_identifier.to_dict()
224-
else:
225-
endpoint = component_identifier.get("endpoint")
226-
component_name = component_identifier.get("__type__")
227-
component_id_dict = component_identifier
218+
endpoint = getattr(component_identifier, "endpoint", None)
219+
component_name = component_identifier.class_name
228220

229221
context = ExecutionContext(
230222
component_role=component_role,
231223
attack_strategy_name=attack_strategy_name,
232224
attack_identifier=attack_identifier,
233-
component_identifier=component_id_dict,
225+
component_identifier=component_identifier,
234226
objective_target_conversation_id=objective_target_conversation_id,
235227
endpoint=endpoint,
236228
component_name=component_name,

pyrit/executor/attack/component/conversation_manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import logging
55
import uuid
66
from dataclasses import dataclass, field
7-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
88

99
from pyrit.common.utils import combine_dict
1010
from pyrit.executor.attack.component.prepended_conversation_config import (
1111
PrependedConversationConfig,
1212
)
13-
from pyrit.identifiers import TargetIdentifier
13+
from pyrit.identifiers import AttackIdentifier, TargetIdentifier
1414
from pyrit.memory import CentralMemory
1515
from pyrit.message_normalizer import ConversationContextNormalizer
1616
from pyrit.models import ChatMessageRole, Message, MessagePiece, Score
@@ -54,8 +54,8 @@ def get_adversarial_chat_messages(
5454
prepended_conversation: List[Message],
5555
*,
5656
adversarial_chat_conversation_id: str,
57-
attack_identifier: Dict[str, str],
58-
adversarial_chat_target_identifier: Union[TargetIdentifier, Dict[str, Any]],
57+
attack_identifier: AttackIdentifier,
58+
adversarial_chat_target_identifier: TargetIdentifier,
5959
labels: Optional[Dict[str, str]] = None,
6060
) -> List[Message]:
6161
"""
@@ -183,7 +183,7 @@ class ConversationManager:
183183
def __init__(
184184
self,
185185
*,
186-
attack_identifier: Dict[str, str],
186+
attack_identifier: AttackIdentifier,
187187
prompt_normalizer: Optional[PromptNormalizer] = None,
188188
):
189189
"""

pyrit/executor/attack/core/attack_strategy.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
StrategyEventData,
2121
StrategyEventHandler,
2222
)
23+
from pyrit.identifiers import AttackIdentifier, Identifiable
2324
from pyrit.memory.central_memory import CentralMemory
2425
from pyrit.models import (
2526
AttackOutcome,
@@ -224,7 +225,7 @@ def _log_attack_outcome(self, result: AttackResult) -> None:
224225
self._logger.info(message)
225226

226227

227-
class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], ABC):
228+
class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Identifiable[AttackIdentifier], ABC):
228229
"""
229230
Abstract base class for attack strategies.
230231
Defines the interface for executing attacks and handling results.
@@ -258,6 +259,54 @@ def __init__(
258259
)
259260
self._objective_target = objective_target
260261
self._params_type = params_type
262+
# Guard so subclasses that set converters before calling super() aren't clobbered
263+
if not hasattr(self, "_request_converters"):
264+
self._request_converters: list[Any] = []
265+
if not hasattr(self, "_response_converters"):
266+
self._response_converters: list[Any] = []
267+
268+
def _build_identifier(self) -> AttackIdentifier:
269+
"""
270+
Build the typed identifier for this attack strategy.
271+
272+
Captures the objective target, optional scorer, and converter pipeline.
273+
This is the *stable* strategy-level identifier that does not change
274+
between calls to ``execute_async``.
275+
276+
Returns:
277+
AttackIdentifier: The constructed identifier.
278+
"""
279+
# Get target identifier
280+
objective_target_identifier = self.get_objective_target().get_identifier()
281+
282+
# Get scorer identifier if present
283+
scorer_identifier = None
284+
scoring_config = self.get_attack_scoring_config()
285+
if scoring_config and scoring_config.objective_scorer:
286+
scorer_identifier = scoring_config.objective_scorer.get_identifier()
287+
288+
# Get request converter identifiers if present
289+
request_converter_ids = None
290+
if self._request_converters:
291+
request_converter_ids = [
292+
converter.get_identifier() for config in self._request_converters for converter in config.converters
293+
]
294+
295+
# Get response converter identifiers if present
296+
response_converter_ids = None
297+
if self._response_converters:
298+
response_converter_ids = [
299+
converter.get_identifier() for config in self._response_converters for converter in config.converters
300+
]
301+
302+
return AttackIdentifier(
303+
class_name=self.__class__.__name__,
304+
class_module=self.__class__.__module__,
305+
objective_target_identifier=objective_target_identifier,
306+
objective_scorer_identifier=scorer_identifier,
307+
request_converter_identifiers=request_converter_ids or None,
308+
response_converter_identifiers=response_converter_ids or None,
309+
)
261310

262311
@property
263312
def params_type(self) -> Type[AttackParameters]:
@@ -291,6 +340,15 @@ def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]:
291340
"""
292341
return None
293342

343+
def get_request_converters(self) -> list[Any]:
344+
"""
345+
Get request converter configurations used by this strategy.
346+
347+
Returns:
348+
list[Any]: The list of request PromptConverterConfiguration objects.
349+
"""
350+
return self._request_converters
351+
294352
@overload
295353
async def execute_async(
296354
self,

pyrit/executor/attack/multi_turn/tree_of_attacks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from pyrit.executor.attack.core.attack_strategy import AttackStrategy
3939
from pyrit.executor.attack.multi_turn import MultiTurnAttackContext
40+
from pyrit.identifiers import AttackIdentifier
4041
from pyrit.memory import CentralMemory
4142
from pyrit.models import (
4243
AttackOutcome,
@@ -267,7 +268,7 @@ def __init__(
267268
request_converters: List[PromptConverterConfiguration],
268269
response_converters: List[PromptConverterConfiguration],
269270
auxiliary_scorers: Optional[List[Scorer]],
270-
attack_id: dict[str, str],
271+
attack_id: AttackIdentifier,
271272
attack_strategy_name: str,
272273
memory_labels: Optional[dict[str, str]] = None,
273274
parent_id: Optional[str] = None,
@@ -289,7 +290,7 @@ def __init__(
289290
request_converters (List[PromptConverterConfiguration]): Converters for request normalization
290291
response_converters (List[PromptConverterConfiguration]): Converters for response normalization
291292
auxiliary_scorers (Optional[List[Scorer]]): Additional scorers for the response
292-
attack_id (dict[str, str]): Unique identifier for the attack.
293+
attack_id (AttackIdentifier): Unique identifier for the attack.
293294
attack_strategy_name (str): Name of the attack strategy for execution context.
294295
memory_labels (Optional[dict[str, str]]): Labels for memory storage.
295296
parent_id (Optional[str]): ID of the parent node, if this is a child node

pyrit/executor/attack/printer/console_printer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,8 @@ async def print_summary_async(self, result: AttackResult) -> None:
258258

259259
# Extract attack type name from attack_identifier
260260
attack_type = "Unknown"
261-
if isinstance(result.attack_identifier, dict) and "__type__" in result.attack_identifier:
262-
attack_type = result.attack_identifier["__type__"]
263-
elif isinstance(result.attack_identifier, str):
264-
attack_type = result.attack_identifier
261+
if result.attack_identifier:
262+
attack_type = result.attack_identifier.class_name
265263

266264
self._print_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN)
267265
self._print_colored(f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN)

pyrit/executor/attack/printer/markdown_printer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> List[str]:
493493
markdown_lines.append("|-------|-------|")
494494
markdown_lines.append(f"| **Objective** | {result.objective} |")
495495

496-
attack_type = result.attack_identifier.get("__type__", "Unknown")
496+
attack_type = result.attack_identifier.class_name if result.attack_identifier else "Unknown"
497497

498498
markdown_lines.append(f"| **Attack Type** | `{attack_type}` |")
499499
markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |")

pyrit/executor/benchmark/fairness_bias.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
PromptSendingAttack,
1818
)
1919
from pyrit.executor.core import Strategy, StrategyContext
20+
from pyrit.identifiers import AttackIdentifier
2021
from pyrit.memory import CentralMemory
2122
from pyrit.models import (
2223
AttackOutcome,
@@ -195,7 +196,10 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta
195196
conversation_id=str(uuid.UUID(int=0)),
196197
objective=context.generated_objective,
197198
outcome=AttackOutcome.FAILURE,
198-
attack_identifier=self.get_identifier(),
199+
attack_identifier=AttackIdentifier(
200+
class_name=self.__class__.__name__,
201+
class_module=self.__class__.__module__,
202+
),
199203
)
200204

201205
return last_attack_result

pyrit/executor/core/strategy.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -176,19 +176,6 @@ def __init__(
176176
default_values.get_non_required_value(env_var_name="GLOBAL_MEMORY_LABELS") or "{}"
177177
)
178178

179-
def get_identifier(self) -> Dict[str, str]:
180-
"""
181-
Get a serializable identifier for the strategy instance.
182-
183-
Returns:
184-
dict: A dictionary containing the type, module, and unique ID of the strategy.
185-
"""
186-
return {
187-
"__type__": self.__class__.__name__,
188-
"__module__": self.__class__.__module__,
189-
"id": str(self._id),
190-
}
191-
192179
def _register_event_handler(self, event_handler: StrategyEventHandler[StrategyContextT, StrategyResultT]) -> None:
193180
"""
194181
Register an event handler for strategy events.

0 commit comments

Comments
 (0)