Merge branch 'main' into u/xiaoyun/#2593
LittleLittleCloud committed May 8, 2024
2 parents 38c2b77 + f75103f commit ab61c0e
Showing 58 changed files with 1,271 additions and 448 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
@@ -400,7 +400,7 @@ jobs:
pip install "pytest-cov>=5"
- name: Install packages and dependencies for Transform Messages
run: |
pip install -e .
pip install -e '.[long-context]'
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
11 changes: 11 additions & 0 deletions README.md
@@ -266,6 +266,17 @@ In addition, you can find:
}
```

[AgentOptimizer](https://arxiv.org/pdf/2402.11359)

```
@article{zhang2024training,
title={Training Language Model Agents without Modifying Language Models},
author={Zhang, Shaokun and Zhang, Jieyu and Liu, Jiale and Song, Linxin and Wang, Chi and Krishna, Ranjay and Wu, Qingyun},
journal={ICML'24},
year={2024}
}
```

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: blue; font-weight: bold;">
↑ Back to Top ↑
68 changes: 68 additions & 0 deletions autogen/agentchat/contrib/capabilities/text_compressors.py
@@ -0,0 +1,68 @@
from typing import Any, Dict, Optional, Protocol

IMPORT_ERROR: Optional[Exception] = None
try:
import llmlingua
except ImportError:
IMPORT_ERROR = ImportError(
"LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
)
PromptCompressor = object
else:
from llmlingua import PromptCompressor


class TextCompressor(Protocol):
"""Defines a protocol for text compression to optimize agent interactions."""

def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
"""This method takes a string as input and returns a dictionary containing the compressed text and other
relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
"""
...


class LLMLingua:
"""Compresses text messages using LLMLingua for improved efficiency in processing and response generation.
NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
and the specific configurations used for the PromptCompressor.
"""

def __init__(
self,
prompt_compressor_kwargs: Dict = dict(
model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2=True,
device_map="cpu",
),
structured_compression: bool = False,
) -> None:
"""
Args:
prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2 set to True, and device_map set to "cpu".
            structured_compression (bool): A flag indicating whether to use structured compression. If True, the
                structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
                is used. Defaults to False.
Raises:
ImportError: If the llmlingua library is not installed.
"""
if IMPORT_ERROR:
raise IMPORT_ERROR

self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
self._compression_method = (
self._prompt_compressor.structured_compress_prompt
if structured_compression
else self._prompt_compressor.compress_prompt
)

def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
return self._compression_method([text], **compression_params)
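A minimal sketch of how these pieces might be used, assuming the long-context extra is installed; `HeadTruncationCompressor` and the `rate` keyword are illustrative assumptions, not part of this commit:

```
# Editor's sketch, not part of this commit: a hypothetical compressor that
# satisfies the TextCompressor protocol, plus a call to the new LLMLingua wrapper.
from typing import Any, Dict

from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua


class HeadTruncationCompressor:
    """Hypothetical example: keeps only the first `max_words` whitespace-delimited words."""

    def __init__(self, max_words: int = 64):
        self._max_words = max_words

    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
        words = text.split()
        # TextMessageCompressor reads these three keys to compute token savings,
        # so any protocol implementation should provide them.
        return {
            "compressed_prompt": " ".join(words[: self._max_words]),
            "origin_tokens": len(words),
            "compressed_tokens": min(len(words), self._max_words),
        }


# Using the wrapper added by this commit; requires the long-context extra
# (pip install 'pyautogen[long-context]'). The `rate` keyword is assumed to be
# forwarded to llmlingua's compress_prompt.
lingua = LLMLingua()
result = lingua.compress_text("some long prompt " * 200, rate=0.4)
print(result["origin_tokens"] - result["compressed_tokens"], "tokens saved")
```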
178 changes: 163 additions & 15 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -1,11 +1,15 @@
import copy
import json
import sys
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

import tiktoken
from termcolor import colored

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache

from .text_compressors import LLMLingua, TextCompressor


class MessageTransform(Protocol):
@@ -156,7 +160,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
assert self._min_tokens is not None

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not self._are_min_tokens_reached(messages):
if not _min_tokens_reached(messages, self._min_tokens):
return messages

temp_messages = copy.deepcopy(messages)
@@ -205,19 +209,6 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
return logs_str, True
return "No tokens were truncated.", False

def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
"""
        Returns True if no minimum token restriction applies: either the total number of tokens in the
        messages is greater than or equal to the `min_threshold_tokens`, or no minimum token threshold is set.
"""
if not self._min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= self._min_tokens

def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
@@ -268,7 +259,7 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int

return max_tokens if max_tokens is not None else sys.maxsize

def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
if min_tokens is None:
return 0
if min_tokens < 0:
@@ -278,6 +269,154 @@ def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
return min_tokens


class TextMessageCompressor:
"""A transform for compressing text messages in a conversation history.
It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
processing and response generation by downstream models.
"""

def __init__(
self,
text_compressor: Optional[TextCompressor] = None,
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
):
"""
Args:
text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
protocol. If None, it defaults to LLMLingua.
            min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
                than 0 if not None. If None, compression is applied to every message regardless of token count.
            compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
                dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
"""

if text_compressor is None:
text_compressor = LLMLingua()

self._validate_min_tokens(min_tokens)

self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
self._cache = cache

        # Cache the most recent token savings so get_logs can report them without recomputation.
self._recent_tokens_savings = 0

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies compression to messages in a conversation history based on the specified configuration.
        The function processes each message according to the `compression_params` and `min_tokens` settings, applying
the specified compression configuration and returning a new list of messages with reduced token counts
where possible.
Args:
messages (List[Dict]): A list of message dictionaries to be compressed.
Returns:
List[Dict]: A list of dictionaries with the message content compressed according to the configured
method and scope.
"""
# Make sure there is at least one message
if not messages:
return messages

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
return messages

total_savings = 0
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not isinstance(message.get("content"), (str, list)):
continue

if _is_content_text_empty(message["content"]):
continue

cached_content = self._cache_get(message["content"])
if cached_content is not None:
savings, compressed_content = cached_content
else:
savings, compressed_content = self._compress(message["content"])

self._cache_set(message["content"], compressed_content, savings)

message["content"] = compressed_content
total_savings += savings

self._recent_tokens_savings = total_savings
return processed_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
if self._recent_tokens_savings > 0:
return f"{self._recent_tokens_savings} tokens saved with text compression.", True
else:
return "No tokens saved with text compression.", False

def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
"""Compresses the given text or multimodal content using the specified compression method."""
if isinstance(content, str):
return self._compress_text(content)
elif isinstance(content, list):
return self._compress_multimodal(content)
else:
return 0, content

def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
tokens_saved = 0
for msg in content:
if "text" in msg:
savings, msg["text"] = self._compress_text(msg["text"])
tokens_saved += savings
return tokens_saved, content

def _compress_text(self, text: str) -> Tuple[int, str]:
"""Compresses the given text using the specified compression method."""
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

savings = 0
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

return savings, compressed_text["compressed_prompt"]

    def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
        if self._cache:
            cached_value = self._cache.get(self._cache_key(content))
            if cached_value:
                savings, compressed_json = cached_value
                # Compressed content is stored serialized (see _cache_set), so deserialize it on retrieval.
                return savings, json.loads(compressed_json)
        return None

def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, json.dumps(compressed_content))
self._cache.set(self._cache_key(content), value)

def _cache_key(self, content: Union[str, List[Dict]]) -> str:
return f"{json.dumps(content)}_{self._min_tokens}"

def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")


def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
@@ -286,3 +425,12 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False
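A sketch of how the new transform is intended to be attached to an agent through the existing `TransformMessages` capability; `llm_config` here is a placeholder and `min_tokens=256` an arbitrary threshold:

```
# Editor's sketch: wiring TextMessageCompressor into an agent via the
# TransformMessages capability. The llm_config values are placeholders.
import autogen
from autogen.agentchat.contrib.capabilities import transform_messages, transforms

llm_config = {"model": "gpt-4", "api_key": "sk-..."}  # placeholder

# Skip compression for short histories, where the compressor's overhead
# likely outweighs any token savings.
compressor = transforms.TextMessageCompressor(min_tokens=256)

capability = transform_messages.TransformMessages(transforms=[compressor])
assistant = autogen.AssistantAgent("assistant", llm_config=llm_config)
capability.add_to_agent(assistant)  # history is now compressed before each LLM call
```

Since compressed content is cached (keyed on the message content plus `min_tokens`), unchanged histories are not re-compressed on later turns.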
10 changes: 7 additions & 3 deletions autogen/agentchat/conversable_agent.py
@@ -7,7 +7,6 @@
import re
import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

from openai import BadRequestError
@@ -434,10 +433,15 @@ def reply_func_from_nested_chats(
reply_func_from_nested_chats = self._summary_from_nested_chats
if not callable(reply_func_from_nested_chats):
raise ValueError("reply_func_from_nested_chats must be a callable")
reply_func = partial(reply_func_from_nested_chats, chat_queue)

def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)

functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)

self.register_reply(
trigger,
reply_func,
wrapped_reply_func,
position,
kwargs.get("config"),
kwargs.get("reset_config"),
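The replacement of `functools.partial` with an explicit wrapper plus `functools.update_wrapper` presumably keeps the registered reply function introspectable: a `partial` object carries no `__name__` or docstring of its own. A standalone illustration, not code from this commit:

```
# Editor's illustration: why update_wrapper is preferred over partial here.
import functools


def reply_func_from_nested_chats(chat_queue, recipient, messages=None, sender=None, config=None):
    """Summarize a nested chat sequence."""
    return True, "summary"


bound = functools.partial(reply_func_from_nested_chats, [])
print(hasattr(bound, "__name__"))  # False: partial objects drop the name


def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
    return reply_func_from_nested_chats([], recipient, messages, sender, config)


functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
print(wrapped_reply_func.__name__)  # "reply_func_from_nested_chats"
print(wrapped_reply_func.__doc__)   # docstring is carried over as well
```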
3 changes: 2 additions & 1 deletion dotnet/Directory.Build.props
@@ -9,6 +9,7 @@
<Nullable>enable</Nullable>
<SignAssembly>True</SignAssembly>
<AssemblyOriginatorKeyFile>$(MSBuildThisFileDirectory)eng/opensource.snk</AssemblyOriginatorKeyFile>
<PublicKey>0024000004800000940000000602000000240000525341310004000001000100f1d038d0b85ae392ad72011df91e9343b0b5df1bb8080aa21b9424362d696919e0e9ac3a8bca24e283e10f7a569c6f443e1d4e3ebc84377c87ca5caa562e80f9932bf5ea91b7862b538e13b8ba91c7565cf0e8dfeccfea9c805ae3bda044170ecc7fc6f147aeeac422dd96aeb9eb1f5a5882aa650efe2958f2f8107d2038f2ab</PublicKey>
<CSNoWarn>CS1998;CS1591</CSNoWarn>
<NoWarn>$(NoWarn);$(CSNoWarn);NU5104</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
@@ -20,4 +21,4 @@
<PropertyGroup>
<RepoRoot>$(MSBuildThisFileDirectory)</RepoRoot>
</PropertyGroup>
</Project>
</Project>
@@ -19,7 +19,7 @@ public async Task ChatWithAnAgent(IStreamingAgent agent)

#region ChatWithAnAgent_GenerateStreamingReplyAsync
var textMessage = new TextMessage(Role.User, "Hello");
await foreach (var streamingReply in await agent.GenerateStreamingReplyAsync([message]))
await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([message]))
{
if (streamingReply is TextMessageUpdate update)
{
@@ -11,7 +11,7 @@ public async Task StreamingCallCodeSnippetAsync()
IStreamingAgent agent = default;
#region StreamingCallCodeSnippet
var helloTextMessage = new TextMessage(Role.User, "Hello");
var reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name);
await foreach (var message in reply)
{
@@ -24,7 +24,7 @@ await foreach (var message in reply)
#endregion StreamingCallCodeSnippet

#region StreamingCallWithFinalMessage
reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
TextMessage finalMessage = null;
await foreach (var message in reply)
{
