My idea is to use my alectors library to parse the JSON as tokens, as standard LLM approaches do, but because it is a pure RL library, the process is a Markov decision process, not a Markov chain.

observation space #

The observation is naturally built from the JSON object, serialised into a prompt like

prompt = f"""
You have to find the pattern. You are given some clues.
Given {obj['train'][0]['input']} you get {obj['train'][0]['output']}.
Given {obj['train'][1]['input']} you get {obj['train'][1]['output']}.
Now based on {obj['test'][0]['input']}, what is the output?
"""

(This is obviously a simplified prompt, for brevity.)

In theory, since we know that the maximum grid size is 30x30, could we not make a wrapper and use a CNN, or even flatten the inputs and outputs and use simple traditional RL methods?
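As a minimal sketch of the flattening idea: pad every grid up to the maximum 30x30 size and flatten it into a fixed-length vector that a plain MLP (or any traditional RL method) can consume. The pad value -1 and the helper name `flatten_grid` are my own choices here, not part of alectors.

```python
MAX_SIZE = 30
PAD = -1  # arbitrary pad value, kept distinct from the valid colors 0-9

def flatten_grid(grid):
    """Pad a (rows x cols) grid of ints to 30x30 and flatten to length 900."""
    flat = []
    for r in range(MAX_SIZE):
        for c in range(MAX_SIZE):
            if r < len(grid) and c < len(grid[r]):
                flat.append(grid[r][c])
            else:
                flat.append(PAD)
    return flat

obs = flatten_grid([[1, 2], [3, 4]])  # length-900 observation vector
```

The same wrapper also gives a CNN-friendly view: reshape the 900-vector back to 30x30 and treat it as a single-channel image.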

action space #

The maximum grid size is 30x30, and each square can hold an integer from 0 to 9, inclusive. This means that in total there are 30x30x10=9000 possible actions, which is a huge action space¹. We can use it as is and hope for the best (which will massively increase the memory needs of the network), or we can come up with a clever workaround.
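To make the flat 9000-action space concrete, here is one possible encoding of (row, column, color) triples into a single action id and back. The layout (color varying fastest) is an arbitrary choice for illustration:

```python
N, COLORS = 30, 10  # 30x30 grid, colors 0-9

def encode(row, col, color):
    """Map a (row, col, color) triple to a flat action id in 0..8999."""
    return (row * N + col) * COLORS + color

def decode(action):
    """Invert encode(): recover (row, col, color) from a flat action id."""
    color = action % COLORS
    square = action // COLORS
    return square // N, square % N, color
```

The agent's network would then need a 9000-way output head, which is exactly the memory cost mentioned above.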

Since we are not zero-shotting, and we will have multiple steps before a final output, we can simply add more steps rather than more possible actions. A basic approach is to split each write into two distinct steps: 'pick grid square' and 'pick integer for chosen grid square'. The agent would first choose a square from a flat array of all possible squares, and then pick a color. To force the agent to pick a valid color, there are two distinct options: repeat the colors across the entirety of the action space, or heavily penalize the agent for picking an action outside the valid integer range during the 'pick integer' step.
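A rough sketch of this two-step split with the "repeat the colors across the action space" trick: the agent always picks from the same 900 actions, but during the color step only `action % 10` matters. The `interpret` helper is hypothetical, not part of alectors:

```python
N_SQUARES, N_COLORS = 900, 10  # 30x30 squares, colors 0-9

def interpret(step, action):
    """Give a meaning to a raw action id depending on the current step."""
    if step == "square":
        return divmod(action, 30)   # (row, col) of the chosen square
    elif step == "color":
        return action % N_COLORS    # colors 0-9 tiled across all 900 actions
```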

The above seems like a bad solution; the action space would still be too big (~900 at worst), and the fact that only 10 of those actions are valid integers per grid square means that during integer picking, the agent might have trouble exploring the action space effectively. Also, due to the curse of dimensionality, one should aim to shrink the action space as much as possible.

It would then make sense to split each write into three ministeps. We would have an action space of at most 30 actions, with three distinct steps: pick the row, then the column, and finally the integer for that row/column.

for ministep in ['row', 'column', 'color']:
    prompt_with_ministep = prompt + ministep
    action = agent.choose_action(prompt_with_ministep)
    # ...

(again, this is an oversimplification)
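The loop above can be fleshed out a little so it runs end to end. The stub agent below stands in for the real alectors agent, and the extra `n_actions` argument (restricting 'row'/'column' to 30 actions and 'color' to 10) is my own assumption about how the restriction could be wired in:

```python
import random

class StubAgent:
    """Placeholder for the real alectors agent; picks uniformly at random."""
    def choose_action(self, prompt, n_actions):
        return random.randrange(n_actions)

def predict_cell(agent, prompt):
    """Run the three ministeps and return the chosen row, column, and color."""
    picks = {}
    for ministep, n_actions in [("row", 30), ("column", 30), ("color", 10)]:
        picks[ministep] = agent.choose_action(prompt + ministep, n_actions)
    return picks

picks = predict_cell(StubAgent(), "You have to find the pattern. ")
```

Repeating `predict_cell` until the agent signals it is done (or a step budget runs out) would fill in the full output grid one square at a time.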


  1. Smaller than the vocab size that LLMs use (~3e5), but still too big to do anything with cheaply.