Skip to content

Commit

Permalink
Fixed problem with transitions generation and added possibility to ge…
Browse files Browse the repository at this point in the history
…t a snapshot of the qmatrix.
  • Loading branch information
Pravez committed Jun 20, 2018
1 parent d865ba8 commit 4b73f6d
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 28 deletions.
43 changes: 26 additions & 17 deletions src/reimprove/algorithms/q/qagent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 2,7 @@ import {AbstractAgent} from "../abstract_agent";
import {AgentTrackingInformation, QAgentConfig} from "../agent_config";
import {QTransition} from "./qtransition";
import {QState, QStateData} from "./qstate";
import {QMatrix} from "./qmatrix";
import {GraphEdge, GraphNode, QMatrix} from "./qmatrix";

const DEFAULT_QAGENT_CONFIG: QAgentConfig = {
createMatrixDynamically: false,
Expand Down Expand Up @@ -47,6 47,7 @@ export class QAgent extends AbstractAgent {

config.actions.forEach(a => this.qmatrix.registerAction(a));
this.qmatrix.Actions.forEach(a => this.qmatrix.registerTransition(a.Name, this.currentState, null));
this.qmatrix.setStateAsInitial(this.currentState);
}
}

Expand All @@ -59,23 60,9 @@ export class QAgent extends AbstractAgent {
this.currentState = this.qmatrix.InitialState;
}

infer(data?: QStateData): QTransition {
infer(): QTransition {
const action = QAgent.bestAction(...this.currentState.Transitions);
this.previousTransition = this.currentState.takeAction(action.Action);
if(!this.previousTransition.To) {
let state: QState;
if(this.qmatrix.exists(data)) {
state = this.qmatrix.getStateFromData(data);
} else {
state = this.qmatrix.registerState(data);
this.qmatrix.Actions.forEach(a => this.qmatrix.registerTransition(a.Name, state, null));
}

this.previousTransition.To = state;
this.currentState = state;
} else {
this.currentState = this.previousTransition.To;
}

this.history.push(this.previousTransition);

Expand All @@ -86,13 73,31 @@ export class QAgent extends AbstractAgent {
return !this.currentState.Final;
}

learn(): void {
learn(data?: QStateData): void {
if (this.previousTransition) {
this.updateMatrix(data);
const reward = this.previousTransition.To.Reward - (this.lossOnAlreadyVisited && this.history.indexOf(this.previousTransition) !== this.history.length - 1 ? 1 : 0);
this.previousTransition.Q = reward this.AgentConfig.gamma * QAgent.bestAction(...this.previousTransition.To.Transitions).Q;
}
}

updateMatrix(data: QStateData) {
if(!this.previousTransition.To) {
let state: QState;
if(this.qmatrix.exists(data)) {
state = this.qmatrix.getStateFromData(data);
} else {
state = this.qmatrix.registerState(data);
this.qmatrix.Actions.forEach(a => this.qmatrix.registerTransition(a.Name, state, null));
}

this.qmatrix.updateTransition(this.previousTransition.Id, state);
this.currentState = state;
} else {
this.currentState = this.previousTransition.To;
}
}

finalState(reward: number, state?: QState): void {
this.qmatrix.setStateAsFinal(state ? state : this.currentState);
this.currentState.Reward = reward;
Expand Down Expand Up @@ -138,6 143,10 @@ export class QAgent extends AbstractAgent {
this.setAgentConfig(config);
}

getStatesGraph(): { nodes: GraphNode[]; edges: GraphEdge[] } {
return this.qmatrix.getGraphData();
}

reset(): void {

}
Expand Down
77 changes: 71 additions & 6 deletions src/reimprove/algorithms/q/qmatrix.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 2,20 @@ import {QAction, QActionData} from "./qaction";
import {QState, QStateData} from "./qstate";
import {QTransition} from "./qtransition";

export interface GraphNode {
id: number;
label: string;
color?: string;
}

export interface GraphEdge {
from: number;
to: number;
id: number;
label: string;
arrows: string;
color: string;
}

export class QMatrix {
private actions: Map<string, QAction>;
Expand All @@ -22,7 36,7 @@ export class QMatrix {
}

registerState(data: QStateData, reward: number = 0.): QState {
if(!this.hashFunction)
if (!this.hashFunction)
throw new Error("Unable to register a state without a hash function.");
if (this.states.has(this.hash(data)))
return this.states.get(this.hash(data));
Expand All @@ -33,14 47,29 @@ export class QMatrix {

registerTransition(action: string, from: QState, to: QState, oppositeActionName?: string): QTransition {
const qaction = this.action(action);
const transition = new QTransition(from, to, qaction);
from.setTransition(qaction, to);
if (oppositeActionName)
to.setTransition(this.action(oppositeActionName), from);

let transition = new QTransition(from, to, qaction);
from.setTransition(qaction, transition);
this.transitions.push(transition);

if (oppositeActionName) {
transition = new QTransition(to, from, qaction);
to.setTransition(this.action(oppositeActionName), transition);
this.transitions.push(transition);
}

return transition;
}

updateTransition(id: number, to: QState): QTransition | undefined {
const trans = this.transitions.find(t => t.Id === id);
if (trans) {
trans.To = to;
return trans;
}
return undefined;
}

action(name: string): QAction {
return this.actions.get(name);
}
Expand All @@ -59,7 88,7 @@ export class QMatrix {
return JSON.stringify(data);
}

getStateFromData(data: QStateData): QState | undefined{
getStateFromData(data: QStateData): QState | undefined {
return this.states.get(this.hash(data));
}

Expand Down Expand Up @@ -151,4 180,40 @@ export class QMatrix {
get Actions(): Array<QAction> {
return Array.from(this.actions.values());
}

getGraphData(): { nodes: GraphNode[], edges: GraphEdge[] } {
const nodes: GraphNode[] = this.States.map(s => ({
id: s.Id,
label: JSON.stringify(s.Data),
color: getColor(s.Reward)
}));
const edges: GraphEdge[] = this.transitions
.filter(t => t.To && t.From)
.map(t => ({
id: t.Id,
to: t.To.Id,
from: t.From.Id,
label: `${t.Q}-${t.Action.Name}`,
color: getColor(t.Q),
arrows: 'to'
}));

return {nodes: nodes, edges: edges};
}
}

function getColor(value: number) {
//value from 0 to 1
const hue = parseInt(((1 - value) * 120).toString(10));
const h = hue;
const s = 1;
const l = 0.5;

const c = (1 - Math.abs(2 * l - 1)) * s;
const x = c * (1 - Math.abs(h / 60 % 2 - 1));
const m = l - c / 2;

const values = hue < 60 ? [c, x, 0] : [x, c, 0];
const rgb = [(values[0] m) * 255, (values[1] m) * 255, (values[2] m) * 255];
return `rgb(${rgb[0]},${rgb[1]},${rgb[2]})`;
}
6 changes: 3 additions & 3 deletions src/reimprove/algorithms/q/qstate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 8,7 @@ export interface QStateData {
export class QState {
private transitions: Map<QAction, QTransition>;
private final: boolean;
private id: number;
private readonly id: number;

private static stateId: number = 0;

Expand All @@ -18,9 18,9 @@ export class QState {
this.id = QState.stateId ;
}

setTransition(action: QAction, to: QState): QTransition {
setTransition(action: QAction, transition: QTransition): QTransition {
if(!this.transitions.has(action) || this.transitions.get(action) === null)
return this.transitions.set(action, new QTransition(this, to, action)).get(action);
return this.transitions.set(action, transition).get(action);
return null;
}

Expand Down
6 changes: 6 additions & 0 deletions src/reimprove/algorithms/q/qtransition.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 4,13 @@ import {QState} from "./qstate";

export class QTransition {
private QValue: number;
private readonly id: number;

private static transitionId: number = 0;

constructor(private from: QState, private to: QState, private action: QAction) {
this.QValue = 0;
this.id = QTransition.transitionId ;
}

get Q() { return this.QValue; }
Expand All @@ -18,4 22,6 @@ export class QTransition {

set To(state: QState) { this.to = state; }
set From(state: QState) { this.from = state; }

get Id(): number { return this.id; }
}
29 changes: 27 additions & 2 deletions test/q.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 109,7 @@ describe("QLearning", () => {
});

while (qagent.isPerforming()) {
switch (qagent.infer(data).Action.Name) {
switch (qagent.infer().Action.Name) {
case "LEFT":
data.x -= data.x > 0 ? 1 : 0;
break;
Expand All @@ -119,7 119,7 @@ describe("QLearning", () => {
}


qagent.learn();
qagent.learn(data);
console.log(`State : ${qagent.CurrentState.Data.x}`);

if (data.x === 3 && data.y === 0)
Expand All @@ -128,4 128,29 @@ describe("QLearning", () => {

expect(data).to.be.deep.equal({x: 3, y: 0});
});

it("should produce a good graph output", () => {
const data = {x: 0, y: 0};
const gamma = 0.9;

qagent = new QAgent({
dataHash: hash,
initialState: data,
gamma: gamma,
createMatrixDynamically: true,
actions: ["LEFT", "RIGHT"]
});

let graph = qagent.getStatesGraph();

expect(graph.nodes.length).to.be.equal(1);
expect(graph.edges.length).to.be.equal(0);

qagent.infer();
qagent.learn({x:1, y:0});

graph = qagent.getStatesGraph();
expect(graph.nodes.length).to.be.equal(2);
expect(graph.edges.length).to.be.equal(1);
});
});

0 comments on commit 4b73f6d

Please sign in to comment.