Skip to content

Commit

Permalink
fix: chance sampling in leduc
Browse files Browse the repository at this point in the history
  • Loading branch information
Egiob committed Aug 8, 2024
1 parent 04074a3 commit 67fe74a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
13 changes: 11 additions & 2 deletions cfrx/envs/leduc_poker/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 38,9 @@ class State(pgx.leduc_holdem.State):
chance_node=jnp.bool_(True),
)
chance_node: Bool[Array, ""] = jnp.bool_(False)
chance_prior: Float[Array, "..."] = jnp.ones(NUM_TOTAL_CARDS, dtype=int)
chance_prior: Float[Array, "..."] = (
jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS
)


def convert_info_state_to_idx(info_state: InfoState) -> jnp.ndarray:
Expand Down Expand Up @@ -138,6 140,13 @@ def get_action_mask(self, state: State) -> jax.Array:
def get_chance_mask(self, state: State) -> jax.Array:
return state.chance_prior > 0

def get_chance_probs(self, state: State) -> jax.Array:
return jnp.where(
(state.chance_prior != 0).any(),
state.chance_prior / state.chance_prior.sum(),
0,
)

def get_info_state(self, state: State) -> jax.Array:
return state.info_state

Expand Down Expand Up @@ -170,7 179,7 @@ def _init(self, rng: Shaped[PRNGKeyArray, "2"]) -> State:
_chips=env_state._chips,
_raise_count=env_state._raise_count,
info_state=info_state,
chance_prior=jnp.ones(NUM_TOTAL_CARDS, dtype=int),
chance_prior=jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS,
chance_node=jnp.bool_(True),
)

Expand Down
3 changes: 1 addition & 2 deletions cfrx/tree/traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 318,7 @@ def loop_fn(val: Tuple) -> Tuple:
use_behavior_policy=jnp.bool_(False),
)

chance_mask = env.get_chance_mask(parent_state)
chance_strategy = chance_mask[action] / chance_mask.sum()
chance_strategy = env.get_chance_probs(parent_state)[action]
# jax.debug.breakpoint()

action_prob = jnp.where(
Expand Down

0 comments on commit 67fe74a

Please sign in to comment.