Skip to content

Commit

Permalink
add multi_fruit haggling training scenario and add seeding of random numbers for reproducibility
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 681484872
Change-Id: I253ee976278a2a34f0766aa88d3ac723deef2436
  • Loading branch information
vezhnick authored and copybara-github committed Oct 2, 2024
1 parent 303ba6c commit b3cf457
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 33 deletions.
44 changes: 29 additions & 15 deletions examples/modular/environment/haggling.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class WorldConfig:
seller_base_reward_max: The maximum base reward for the seller.
num_games: The number of games to play.
num_main_players: The number of main players in the scenario.
random_seed: The random seed for the random number generator.
"""

year: int
Expand All @@ -98,6 +99,7 @@ class WorldConfig:
seller_base_reward_max: int = 2
num_games: int = 2
num_main_players: int = 3
random_seed: int = 42


def bargain_statements(
Expand Down Expand Up @@ -139,14 +141,17 @@ def get_shared_memories_and_context(premise: str) -> tuple[Sequence[str], str]:
return shared_memories, shared_context


def configure_player(name: str, gender: str, year: int, is_main: bool):
def configure_player(
name: str, gender: str, year: int, is_main: bool, rng: random.Random
):
"""Configure a player.
Args:
name: the name of the player
gender: the gender of the player
year: the year of the simulation to sample the age of the players
is_main: whether the player is a main character or not
rng: the random number generator to use
Returns:
config: the config for the player
Expand All @@ -165,9 +170,9 @@ def configure_player(name: str, gender: str, year: int, is_main: bool):
name=name,
gender=gender,
date_of_birth=datetime.datetime(
year=year - random.randint(25, 54),
month=random.randint(1, 12),
day=random.randint(1, 28),
year=year - rng.randint(25, 54),
month=rng.randint(1, 12),
day=rng.randint(1, 28),
),
context=(
f'{name} is a travelling merchant. Her business is buying and'
Expand All @@ -183,6 +188,7 @@ def configure_player(name: str, gender: str, year: int, is_main: bool):

def configure_players(
sampled_settings: WorldConfig,
rng: random.Random,
) -> tuple[
list[formative_memories.AgentConfig],
list[formative_memories.AgentConfig],
Expand All @@ -191,6 +197,7 @@ def configure_players(
Args:
sampled_settings: the sampled settings for the world configuration
rng: the random number generator to use
Returns:
main_player_configs: configs for the main characters
Expand All @@ -204,15 +211,21 @@ def configure_players(
name = names[i]
gender = sampled_settings.person_data[name]['gender']

config = configure_player(name, gender, sampled_settings.year, is_main=True)
config = configure_player(
name, gender, sampled_settings.year, is_main=True, rng=rng
)
player_configs.append(config)

for i in range(sampled_settings.num_supporting_players):
name = names[i + sampled_settings.num_main_players]
gender = sampled_settings.person_data[name]['gender']

config = configure_player(
name, gender, sampled_settings.year, is_main=False
name,
gender,
sampled_settings.year,
is_main=False,
rng=rng,
)

player_configs.append(config)
Expand Down Expand Up @@ -365,6 +378,7 @@ def configure_scenes(
main_player_configs: Sequence[formative_memories.AgentConfig],
supporting_player_configs: Sequence[formative_memories.AgentConfig],
start_time: datetime.datetime,
rng: random.Random,
sampled_settings: WorldConfig,
) -> tuple[
Sequence[scene_lib.SceneSpec],
Expand All @@ -381,6 +395,7 @@ def configure_scenes(
main_player_configs: configs for the main characters
supporting_player_configs: configs for the supporting characters
start_time: the start time of the simulation
rng: the random number generator to use
sampled_settings: the sampled settings for the world configuration
Returns:
Expand Down Expand Up @@ -408,12 +423,8 @@ def configure_scenes(

for i in range(sampled_settings.num_games * len(pairs)):

buyer_base_reward = random.randint(
sampled_settings.buyer_base_reward_min, 6
)
seller_base_reward = random.randint(
1, sampled_settings.seller_base_reward_max
)
buyer_base_reward = rng.randint(sampled_settings.buyer_base_reward_min, 6)
seller_base_reward = rng.randint(1, sampled_settings.seller_base_reward_max)

this_game_players = pairs[i % len(pairs)]

Expand All @@ -427,7 +438,7 @@ def configure_scenes(
[cfg for cfg in player_configs if cfg.name == player.name][0]
)

scene_opening = random.choice(list(sampled_settings.scene_visuals))
scene_opening = rng.choice(list(sampled_settings.scene_visuals))
scene_specs = {
'social': scene_lib.SceneTypeSpec(
name='day',
Expand Down Expand Up @@ -585,6 +596,8 @@ def __init__(
sampled_settings.num_main_players = num_main_players
sampled_settings.num_games = num_games

self._rng = random.Random(sampled_settings.random_seed)

start_time = datetime.datetime(
year=time_and_place_params.YEAR,
month=time_and_place_params.MONTH,
Expand Down Expand Up @@ -637,9 +650,9 @@ def __init__(
)

main_player_configs, supporting_player_configs = configure_players(
sampled_settings
sampled_settings, rng=self._rng
)
random.shuffle(main_player_configs)
self._rng.shuffle(main_player_configs)

tasks = {
config.name: functools.partial(
Expand Down Expand Up @@ -733,6 +746,7 @@ def __init__(
main_player_configs=main_player_configs,
supporting_player_configs=supporting_player_configs,
start_time=start_time,
rng=self._rng,
sampled_settings=sampled_settings,
)

Expand Down
35 changes: 24 additions & 11 deletions examples/modular/environment/haggling_multi_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class WorldConfig:
num_main_players: The number of main players in the scenario.
prices: The prices for the items.
items_for_sale: The items for sale in the scenario.
random_seed: The random seed to use for the scenario.
"""

year: int
Expand All @@ -102,6 +103,7 @@ class WorldConfig:
num_main_players: int = 3
prices: Sequence[int] = (1, 2, 3, 4, 5, 6)
items_for_sale: Sequence[str] = ('apple', 'banana', 'pear')
random_seed: int = 42


def bargain_statements(
Expand Down Expand Up @@ -143,14 +145,17 @@ def get_shared_memories_and_context(premise: str) -> tuple[Sequence[str], str]:
return shared_memories, shared_context


def configure_player(name: str, gender: str, year: int, is_main: bool):
def configure_player(
name: str, gender: str, year: int, is_main: bool, rng: random.Random
):
"""Configure a player.
Args:
name: the name of the player
gender: the gender of the player
year: the year of the simulation to sample the age of the players
is_main: whether the player is a main character or not
rng: the random number generator to use
Returns:
config: the config for the player
Expand All @@ -169,9 +174,9 @@ def configure_player(name: str, gender: str, year: int, is_main: bool):
name=name,
gender=gender,
date_of_birth=datetime.datetime(
year=year - random.randint(25, 54),
month=random.randint(1, 12),
day=random.randint(1, 28),
year=year - rng.randint(25, 54),
month=rng.randint(1, 12),
day=rng.randint(1, 28),
),
context=(
f'{name} is a travelling merchant. Her business is buying and'
Expand All @@ -187,6 +192,7 @@ def configure_player(name: str, gender: str, year: int, is_main: bool):

def configure_players(
sampled_settings: WorldConfig,
rng: random.Random,
) -> tuple[
list[formative_memories.AgentConfig],
list[formative_memories.AgentConfig],
Expand All @@ -195,6 +201,7 @@ def configure_players(
Args:
sampled_settings: the sampled settings for the world configuration
rng: the random number generator to use
Returns:
main_player_configs: configs for the main characters
Expand All @@ -208,15 +215,17 @@ def configure_players(
name = names[i]
gender = sampled_settings.person_data[name]['gender']

config = configure_player(name, gender, sampled_settings.year, is_main=True)
config = configure_player(
name, gender, sampled_settings.year, is_main=True, rng=rng
)
player_configs.append(config)

for i in range(sampled_settings.num_supporting_players):
name = names[i + sampled_settings.num_main_players]
gender = sampled_settings.person_data[name]['gender']

config = configure_player(
name, gender, sampled_settings.year, is_main=False
name, gender, sampled_settings.year, is_main=False, rng=rng
)

player_configs.append(config)
Expand Down Expand Up @@ -382,6 +391,7 @@ def configure_scenes(
supporting_player_configs: Sequence[formative_memories.AgentConfig],
start_time: datetime.datetime,
sampled_settings: WorldConfig,
rng: random.Random,
) -> tuple[
Sequence[scene_lib.SceneSpec],
list[game_master.GameMaster] | list[None],
Expand All @@ -398,6 +408,7 @@ def configure_scenes(
supporting_player_configs: configs for the supporting characters
start_time: the start time of the simulation
sampled_settings: the sampled settings for the world configuration
rng: the random number generator to use
Returns:
scenes: a sequence of scene specifications
Expand Down Expand Up @@ -425,11 +436,11 @@ def configure_scenes(
for i in range(sampled_settings.num_games * len(pairs)):

buyer_base_reward_per_item = {
item: random.randint(sampled_settings.buyer_base_reward_min, 6)
item: rng.randint(sampled_settings.buyer_base_reward_min, 6)
for item in sampled_settings.items_for_sale
}
seller_base_reward_per_item = {
item: random.randint(sampled_settings.seller_base_reward_max, 6)
item: rng.randint(sampled_settings.seller_base_reward_max, 6)
for item in sampled_settings.items_for_sale
}

Expand All @@ -455,7 +466,7 @@ def configure_scenes(
for item in sampled_settings.items_for_sale
])

scene_opening = random.choice(list(sampled_settings.scene_visuals))
scene_opening = rng.choice(list(sampled_settings.scene_visuals))
scene_specs = {
'social': scene_lib.SceneTypeSpec(
name='day',
Expand Down Expand Up @@ -614,6 +625,7 @@ def __init__(
sampled_settings.only_match_with_support = only_match_with_support
sampled_settings.num_main_players = num_main_players
sampled_settings.num_games = num_games
self._rng = random.Random(sampled_settings.random_seed)

start_time = datetime.datetime(
year=time_and_place_params.YEAR,
Expand Down Expand Up @@ -667,9 +679,9 @@ def __init__(
)

main_player_configs, supporting_player_configs = configure_players(
sampled_settings
sampled_settings, rng=self._rng
)
random.shuffle(main_player_configs)
self._rng.shuffle(main_player_configs)

tasks = {
config.name: functools.partial(
Expand Down Expand Up @@ -764,6 +776,7 @@ def __init__(
supporting_player_configs=supporting_player_configs,
start_time=start_time,
sampled_settings=sampled_settings,
rng=self._rng,
)

self._secondary_environments = choice_gms
Expand Down
8 changes: 5 additions & 3 deletions examples/modular/environment/modules/fruitville_haggling.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,21 @@
]


def sample_parameters():
def sample_parameters(seed: int | None = None):
"""Samples a set of parameters for the world configuration."""
seed = seed or random.getrandbits(63)

config = haggling.WorldConfig(
year=YEAR,
location="Fruitville",
premise=SCENARIO_PREMISE,
scene_visuals=VISUAL_SCENE_OPENINGS,
random_seed=seed,
)

all_names = list(MALE_NAMES) + list(FEMALE_NAMES)

random.shuffle(all_names)
rng = random.Random(config.random_seed)
rng.shuffle(all_names)
config.people = all_names

for _, name in enumerate(MALE_NAMES):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,22 @@
]


def sample_parameters():
def sample_parameters(seed: int | None = None):
"""Samples a set of parameters for the world configuration."""
seed = seed or random.getrandbits(63)

config = haggling_multi_item.WorldConfig(
year=YEAR,
location="Fruitville",
premise=SCENARIO_PREMISE,
scene_visuals=VISUAL_SCENE_OPENINGS,
random_seed=seed,
)

all_names = list(MALE_NAMES) + list(FEMALE_NAMES)

random.shuffle(all_names)
rng = random.Random(config.random_seed)
rng.shuffle(all_names)
config.people = all_names

for _, name in enumerate(MALE_NAMES):
Expand Down
7 changes: 5 additions & 2 deletions examples/modular/environment/modules/vegbrooke_haggling.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@
]


def sample_parameters():
def sample_parameters(seed: int | None = None):
"""Samples a set of parameters for the world configuration."""
seed = seed or random.getrandbits(63)

config = haggling.WorldConfig(
year=YEAR,
Expand All @@ -117,11 +118,13 @@ def sample_parameters():
scene_visuals=VISUAL_SCENE_OPENINGS,
buyer_base_reward_min=2,
seller_base_reward_max=5,
random_seed=seed,
)
rng = random.Random(config.random_seed)

all_names = list(MALE_NAMES) + list(FEMALE_NAMES)

random.shuffle(all_names)
rng.shuffle(all_names)
config.people = all_names

for _, name in enumerate(MALE_NAMES):
Expand Down
Loading

0 comments on commit b3cf457

Please sign in to comment.