File: stm.ml

package info (click to toggle)
cothreads 0.10-7
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 500 kB
  • sloc: ml: 1,963; makefile: 216
file content (385 lines) | stat: -rw-r--r-- 13,835 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
open Coordinator
open Libext

(* Tag stored as the first component of every tvid built by [tvar] and
   [new_tvar]. *)
let stm_magic : string = "STM2007MTS"

(* Identifier of a transactional variable: the triple
   (stm_magic, id of the creating thread, value of state.tvid_count), as
   built by [tvar] and [new_tvar].  [version] counts commits of a tvar and
   [value] is its type-erased content. *)
type tvid = string * int * int and version = int and value = Obj.t

(* Maps and sets keyed by tvid, ordered by polymorphic comparison.
   [Map_Make] comes from the opened Coordinator/Libext modules; this file
   uses it with operations beyond the standard Map interface
   ([patch_left], [diff], [merge], and the [patch] type). *)
module TvMap = Map_Make (struct type t = tvid let compare = Stdlib.compare end)
module TvSet = Set.Make (struct type t = tvid let compare = Stdlib.compare end)

(* Committed representation of a tvar: its version, its value, and the set
   of tvars that value was found to reference (see [reference]). *)
type tv_repr = {version: version; value: value; ref_to: TvSet.t}

(* Per-transaction log entry of one tvar.  Its shape records how the tvar
   was first touched in the current transaction:
   - created by [new_tvar]:    pre_version = None,   pre_value = None
   - first access was a read:  pre_version = Some _, pre_value = Some _
   - first access was a write: pre_version = Some _, pre_value = None
   [cur_value] holds the value written so far, if any. *)
type tv_log = {pre_version: version option; pre_value: value option;
               mutable cur_value: value option}

(* Commit summary sent to the coordinator: versions of the tvars read,
   new representations of the tvars written, and the thread's dirty set. *)
type commit_log =
    { read_log: version TvMap.t; write_log: tv_repr TvMap.t; dirty_log: TvSet.t }

(* Requests a client thread sends over [stm_portal]; the portal carried in
   each message is where the coordinator writes its answer. *)
type stm_msg =
  [ `Tvar of tvid * tv_repr * thread * bool portal          (* register a tvar *)
  | `Wait of version TvMap.t * thread * bool portal       (* block on read set *)
  | `Atom of thread * tv_repr TvMap.patch portal          (* fetch env patch *)
  | `Commit of commit_log * thread * bool portal          (* try to commit *)
  ]

(* Portal on which client threads send [stm_msg] requests; [stm_handle]
   serves it (see the [new_serv] registration at the end of the file). *)
let stm_portal : stm_msg portal = create_portal ()

(* Two representations are considered equal when their versions are equal;
   the values themselves are never compared. *)
let repr_eq r1 r2 = r1.version = r2.version
(* Type-erasing conversions between user values and the stored [value]
   representation.  [var_of_val] is unchecked: the caller must know the
   real type of the stored value. *)
let val_of_var v = Obj.repr v
let var_of_val v = Obj.obj v

(* A tvar handle is just its tvid; the type parameter is phantom (not used
   by the representation). *)
type 'a tvar = tvid

(* Client-side STM state:
   - env: local copy of committed representations, refreshed from
     coordinator patches in [state_reset];
   - log: the current transaction's log, keyed by tvid;
   - tvid_count: counter used to build fresh tvids;
   - layer: nesting depth of [atom_once] calls;
   - dirty: tvars this thread may still reference (new tvars, and tvars
     found inside committed values by [dirtirise]); sent to the coordinator
     on commit and pruned by [finaliser]. *)
type thr_state =
    { mutable env: tv_repr TvMap.t; mutable log: tv_log TvMap.t;
      mutable tvid_count: int; mutable layer: int; mutable dirty: TvSet.t; }

(* The mutable client-side STM state shared by every operation below:
   no known tvars, no log, not inside a transaction. *)
let state =
  { layer = 0;
    tvid_count = 0;
    env = TvMap.empty;
    log = TvMap.empty;
    dirty = TvSet.empty }

(* Apply the coordinator patch [diff] to the local environment and start a
   fresh, empty transaction log.  Must only be called outside any
   transaction (layer 0). *)
let state_reset diff =
  assert (state.layer = 0);
  let patched = TvMap.patch_left repr_eq state.env diff in
  state.env <- patched;
  state.log <- TvMap.empty


(* Shallow copy of a tvid tuple; only suitable for flat blocks such as tvid.
   NOTE(review): presumably used so that map/set keys never share the
   physical tuple returned to the user, which [tvar] registers with
   Gc.finalise -- confirm. *)
let copy (x:tvid) : tvid =
  let block = Obj.repr x in
  Obj.obj (Obj.dup block)
(* [TvMap.add] storing a fresh copy of the key. *)
let tvmap_add k data map = TvMap.add (copy k) data map
(* [TvSet.add] storing a fresh copy of the element. *)
let tvset_add k set = TvSet.add (copy k) set
(* GC finaliser attached to tvid values: when one becomes unreachable,
   forget it in the thread's dirty set. *)
let finaliser vl =
  let tv = var_of_val vl in
  state.dirty <- TvSet.remove tv state.dirty
(* Candidate tvars that a value produced now may reference: the dirty set,
   plus every tvar created in the current transaction, plus the [ref_to]
   set of every tvar whose first access in it was a read. *)
let suspicious () =
  let collect tv entry acc =
    match entry with
    | {pre_version = Some _; pre_value = Some _} ->
        let repr = TvMap.find tv state.env in
        TvSet.union repr.ref_to acc
    | {pre_version = None} -> TvSet.add tv acc
    | _ -> acc
  in
  TvMap.fold collect state.log state.dirty


(* An STM action is a suspended computation, run by applying it to (). *)
type 'a stm = unit -> 'a

let return v () = v

(* Run [t], then the action [f] builds from its result.  Nothing runs
   until the combined action is applied to (). *)
let bind t f () =
  let x = t () in
  f x ()

let ( >>= ) t f = bind t f

(* Run [s1] for its effects, discard its result, then run [s2]. *)
let ( >> ) s1 s2 () =
  let _ = s1 () in
  s2 ()

(* Subset of [candidates] referenced by value [v], as decided by Libext's
   [obj_refed_by] with structural equality.  NOTE(review): presumably
   "occurs somewhere inside [v]" -- confirm against obj_refed_by. *)
let reference v candidates =
  TvSet.filter (fun tv -> obj_refed_by (=) tv v) candidates

(* Create a tvar holding [v], outside any transaction.  The new tvar is
   registered with the coordinator (which must acknowledge with [true]),
   then recorded in the local environment and dirty set, and a GC finaliser
   is attached to the returned tvid. *)
let tvar v =
  if not !inited then init ();
  let me = self () in
  let my_id = id me in
  state.tvid_count <- state.tvid_count + 1;
  let tv = (stm_magic, my_id, state.tvid_count) in
  let refs = reference v (suspicious ()) in
  let repr = {version = 0; value = val_of_var v; ref_to = refs} in
  let acknowledged =
    demand_portal (fun p -> `Tvar (tv, repr, me, p)) stm_portal in
  assert acknowledged;
  state.env <- tvmap_add tv repr state.env;
  state.dirty <- tvset_add tv state.dirty;
  Gc.finalise finaliser (val_of_var tv);
  tv

(* Transactional tvar creation: allocate a fresh tvid and record it only in
   the transaction log (pre_version = None, value in cur_value).  The
   coordinator learns about it through the write log at commit time. *)
let new_tvar v () =
  let my_id = id (self ()) in
  state.tvid_count <- state.tvid_count + 1;
  let tv = (stm_magic, my_id, state.tvid_count) in
  let entry =
    { pre_version = None; pre_value = None;
      cur_value = Some (val_of_var v) } in
  state.log <- tvmap_add tv entry state.log;
  tv

(* Transactional read.  A tvar already in the log yields its written value
   if any, otherwise the value first read.  Otherwise the value comes from
   the local environment, and the read is logged with its version. *)
let read_tvar tv () =
  let raw =
    match TvMap.find tv state.log with
    | entry ->
        (match entry.cur_value, entry.pre_value with
         | Some v, _ | None, Some v -> v
         | None, None -> assert false)
    | exception Not_found ->
        let repr = TvMap.find tv state.env in
        let entry =
          { pre_version = Some repr.version;
            pre_value = Some repr.value; cur_value = None } in
        state.log <- tvmap_add tv entry state.log;
        repr.value
  in
  var_of_val raw

(* Transactional write of [v] into [tv].
   - If [tv] is already logged, its entry is replaced by a copy with the new
     [cur_value].
   - Otherwise a first-write entry is logged: pre_version = the local
     version, pre_value = None.
   The entry is replaced rather than mutated in place.  [save_state] only
   copies the maps, so an in-place update would leak into snapshots and
   survive the rollback done by [or_else]/[catch].  Like [read_tvar] and
   [new_tvar], the key is stored through [tvmap_add].
   Raises Not_found if [tv] is neither logged nor in the local env. *)
let write_tvar tv v = fun () ->
  let written = Some (val_of_var v) in
  let entry =
    match TvMap.find tv state.log with
    | log -> {log with cur_value = written}
    | exception Not_found ->
        let repr = TvMap.find tv state.env in
        { pre_version = Some repr.version; pre_value = None;
          cur_value = written }
  in
  state.log <- tvmap_add tv entry state.log

(* Send the coordinator the versions of every tvar whose first access in
   the current transaction was a read, and wait for its answer.
   [wait_handle] answers at once if one of them is already stale, otherwise
   when one of them is next committed.

   The request is bound to a name before asserting on the reply: the
   argument of [assert] is not evaluated at all under -noassert, which
   would silently skip the wait. *)
let wait = fun () ->
  let wait_tv = TvMap.fold
    (fun tv log map -> match log with
     | {pre_version = Some v; pre_value = Some _} -> TvMap.add tv v map
     | _ -> map)
    state.log TvMap.empty in
  let woken =
    demand_portal (fun p -> `Wait (wait_tv, self (), p)) stm_portal in
  assert woken

(* Control exceptions of a transaction body:
   - [Abort] makes [atom_once] return None.
   - [Retry b] makes [atom_once] restart the transaction, after [wait] when
     [b] is true. *)
exception Abort
exception Retry of bool

let abort () = raise Abort
let retry () = raise (Retry true)
let retry_now () = raise (Retry false)
(* Shallow snapshot of the thread state: a fresh record whose fields are
   the same values (maps are shared, not deep-copied). *)
let save_state st = { st with layer = st.layer }

(* Roll back the fields a nested computation may change: the nesting depth
   and the log.  The other fields are left as they are. *)
let restore_state st st_bak =
  st.layer <- st_bak.layer;
  st.log <- st_bak.log

(* [catch t f] runs [t].  STM control exceptions ([Retry], [Abort]) are
   re-raised untouched.  Any other exception [e] first restores the state
   saved before [t], then runs the handler [f e]. *)
let catch t f () =
  let snapshot = save_state state in
  match t () with
  | v -> v
  | exception ((Retry _ | Abort) as e) -> raise e
  | exception e ->
      restore_state state snapshot;
      f e ()

(* [or_else t1 t2]: run [t1]; if it raises [Abort] or [Retry], restore the
   state saved before it and run [t2] instead.  If [t2] also raises [Abort]
   or [Retry], the outcome combines both failures (see the match below).
   Any other exception from [t1] or [t2] propagates unchanged.
   Rollback is only as deep as [save_state]/[restore_state]: the log map
   and layer are restored, but a log entry mutated in place would not be. *)
let or_else t1 t2 = fun () ->
  let state_bak = save_state state in
  try t1 () with (Abort | Retry _) as e1 ->
    (* keep t1's final state: its reads may be needed for a later wait *)
    let state_bak_1 = save_state state in
    restore_state state state_bak;
    try t2 () with (Abort | Retry _) as e2 ->
      match e1, e2 with
      (* both aborted: back to the state before t1 *)
      | Abort, Abort -> restore_state state state_bak; raise Abort
      (* only t1 retries: reinstate t1's log so a subsequent [wait] sees
         the tvars t1 read *)
      | Retry b, Abort -> restore_state state state_bak_1; raise (Retry b)
      (* only t2 retries: t2's log is already in place *)
      | Abort, Retry b -> raise (Retry b)
      (* both retry: combine t2's log (state.log) with t1's
         (state_bak_1.log), taking t1's entry where t1 has a pre_value and
         t2 has none; wait only if both asked to wait.
         NOTE(review): exact key coverage depends on Libext's TvMap.merge --
         confirm. *)
      | Retry b1, Retry b2 ->
          let comb_log = TvMap.merge
            (fun k v1 v2 tbl -> match v1.pre_value,v2.pre_value with
             | None, Some _ -> TvMap.add k v2 tbl | _,_ -> tbl)
            state.log state_bak_1.log in
          restore_state state state_bak;
          state.log <- comb_log;
          raise (Retry (b1 && b2))
      (* unreachable: e1 and e2 are each Abort or Retry *)
      | _, _ -> assert false

(* For each candidate tvar in [susp], look through the sub-objects of [v]
   enumerated by Libext's [obj_iter].  Every sub-object structurally equal
   to the tvid puts the tvar in the dirty set and gets [finaliser]
   attached. *)
let dirtirise v susp =
  TvSet.iter
    (fun tv ->
       let target = val_of_var tv in
       obj_iter
         (fun o ->
            if o = target then begin
              state.dirty <- TvSet.add tv state.dirty;
              Gc.finalise finaliser o
            end)
         v)
    susp


(* Build the commit summary on the client to spare the coordinator work:
   - write_log: every logged tvar with a [cur_value], as a repr with
     placeholder version 0 (replaced in [commit_handle] for existing tvars)
     and the references of the new value among [susp];
   - read_log: the version of every tvar whose first access was a read;
   - dirty_log: the current dirty set. *)
let commit_log susp =
  let step tv entry (reads, writes) =
    let writes =
      match entry.cur_value with
      | None -> writes
      | Some v ->
          let repr = {version = 0; value = v; ref_to = reference v susp} in
          TvMap.add tv repr writes
    in
    let reads =
      match entry.pre_value, entry.pre_version with
      | Some _, Some ver -> TvMap.add tv ver reads
      | _ -> reads
    in
    (reads, writes)
  in
  let read, write = TvMap.fold step state.log (TvMap.empty, TvMap.empty) in
  {read_log = read; write_log = write; dirty_log = state.dirty}


(* Try to commit the current transaction, whose result is [v].  The dirty
   set is first extended with the tvars found inside [v], and the commit
   log is built before the request is sent.  Returns the coordinator's
   answer ([false] on a read conflict). *)
let commit v : bool =
  let candidates = suspicious () in
  dirtirise v candidates;
  let summary = commit_log candidates in
  demand_portal (fun p -> `Commit (summary, self (), p)) stm_portal


(* Run transaction [t] once.
   At the outermost level (layer 0):
   - the local environment is first brought up to date with a coordinator
     patch ([`Atom] request);
   - a successful result is committed.
   Nested calls just run [t].
   Returns [Some v] on success, [None] if the commit failed or [t] raised
   [Abort] at the outermost level.  [Retry b] restarts the transaction,
   calling [wait] first when [b] is true.  Other exceptions propagate.

   Only exceptions raised by [t] itself are handled here.  If the handler
   also covered [commit], an exception escaping [commit] would decrement
   [state.layer] a second time and leave it negative. *)
let rec atom_once t =
  if (not !inited) then init ();
  (if state.layer = 0 then
     let diff = demand_portal (fun p -> `Atom (self (), p)) stm_portal in
     state_reset diff);
  state.layer <- succ state.layer;
  match t () with
  | v ->
      state.layer <- pred state.layer;
      if state.layer > 0 || commit v then Some v else None
  | exception e ->
      state.layer <- pred state.layer;
      (match state.layer, e with
       | 0, Retry b -> if b then wait (); atom_once t
       | 0, Abort -> None
       | _, _ -> raise e)

(* Run [t] atomically, re-running it until an attempt returns a result. *)
let atom t =
  let rec attempt () =
    match atom_once t with
    | Some v -> v
    | None -> attempt ()
  in
  attempt ()


(* Root Service: state kept by the coordinator, and the handlers serving
   [stm_portal] and [root_portal] requests. *)

(* Per-tvar bookkeeping on the coordinator:
   - ref_by_tv: tvars whose committed value references this tvar;
   - ref_by_thr: threads whose reported dirty set contains this tvar;
   - tv_wait: answer slots of threads blocked in [wait_handle].  One slot
     is shared by all tvars a thread waits on and is reset to None once
     answered, so the thread is woken only once. *)
type tv_rec = 
    { mutable ref_by_tv: TvSet.t; 
      mutable ref_by_thr: ThreadSet.t;
      mutable tv_wait: bool portal option ref list }

(* Per-thread bookkeeping:
   - thr_env: the environment the thread was last synchronised to (the base
     from which [atom_handle] computes the next patch);
   - tv_dirty: the dirty set the thread last reported. *)
type thr_rec = 
    { mutable thr_env: tv_repr TvMap.t; 
      mutable tv_dirty: TvSet.t }

(* Whole coordinator state: committed representations, per-tvar records
   and per-thread records. *)
type stm_root = 
    { mutable root_env: tv_repr TvMap.t;
      mutable root_rec: tv_rec TvMap.t;
      mutable root_thr: thr_rec ThreadMap.t;
    }

(* The coordinator's STM state, initially empty. *)
let root =
  { root_thr = ThreadMap.empty;
    root_env = TvMap.empty;
    root_rec = TvMap.empty }

(* Fresh tvar record: no referring tvars or threads, nobody waiting. *)
let empty_tv_rec () =
  { tv_wait = []; ref_by_thr = ThreadSet.empty; ref_by_tv = TvSet.empty }

(* Fresh thread record, synchronised to the current root environment and
   holding no tvars. *)
let empty_thr_rec () =
  { tv_dirty = TvSet.empty; thr_env = root.root_env }


(* Primitive Tvar service: register tvar [tv], created by thread [thr] with
   representation [repr].  The tvar goes into the root environment, gets a
   record naming [thr] as holder, is added to [thr]'s environment and dirty
   set, and the request is acknowledged with [true]. *)
let tvar_handle tv repr thr p =
  let record = empty_tv_rec () in
  record.ref_by_thr <- ThreadSet.add thr record.ref_by_thr;
  root.root_env <- TvMap.add tv repr root.root_env;
  root.root_rec <- TvMap.add tv record root.root_rec;
  let owner = ThreadMap.find thr root.root_thr in
  owner.thr_env <- TvMap.add tv repr owner.thr_env;
  owner.tv_dirty <- TvSet.add tv owner.tv_dirty;
  write_portal true p

(* Primitive Wait service.  [wait_tv] maps each tvar the client read to the
   version it saw; [thr] is not used here.
   - If some tvar already has a newer committed version, or is no longer
     known to the root, answer [true] at once.
   - Otherwise queue one shared answer slot on every tvar and fill it with
     [p]; [commit_handle] answers it on the next write to any of them. *)
let wait_handle wait_tv thr p =
  let answer_port = ref None in
  let mark tv version =
    let tv_repr = TvMap.find tv root.root_env in
    if tv_repr.version > version then raise Break else
      let reco = TvMap.find tv root.root_rec in
      reco.tv_wait <- answer_port :: reco.tv_wait in
  try
    TvMap.iter mark wait_tv;
    answer_port := Some p
  with Break | Not_found ->
    (* Slots already queued stay None and are skipped by commit_handle.
       A missing tvar is treated like a changed one, matching commit_handle,
       which counts it as a conflict.  Letting Not_found escape would leave
       [p] unanswered. *)
    write_portal true p

(* Primitive Atom service: send thread [thr] the patch from the environment
   it was last synchronised to up to the current root environment, and
   record that it is now up to date. *)
let atom_handle thr p =
  let record = ThreadMap.find thr root.root_thr in
  let patch = TvMap.diff repr_eq record.thr_env root.root_env in
  record.thr_env <- root.root_env;
  write_portal patch p

(* Apply [op] to the set of threads holding [tv].
   Raises Not_found if [tv] has no record. *)
let opr_ref_by_thr op tv =
  let r = TvMap.find tv root.root_rec in
  r.ref_by_thr <- op r.ref_by_thr

(* Apply [op] to the set of tvars whose value references [tv].
   Raises Not_found if [tv] has no record. *)
let opr_ref_by_tv op tv =
  let r = TvMap.find tv root.root_rec in
  r.ref_by_tv <- op r.ref_by_tv

(* Primitive Commit service.  Validate the read set of a client transaction
   against the committed root environment and, if it is still current,
   install the transaction's writes:
   1. Conflict check: every version in [rl] must equal the one in root_env.
      A tvar missing from root_env also counts as a conflict.  On conflict,
      reply [false] and change nothing.
   2. Install each write of [wl].  An existing tvar gets version + 1, loses
      the reference edges of its old value, and wakes every thread waiting
      on it.  A tvar unknown to the root (created inside the transaction)
      is added as sent, with a fresh record.
   3. Add the reference edges of all new values.
   4. Replace the dirty set recorded for [thr] by [dl], updating
      ref_by_thr.
   5. Transitively collect tvars that lost a referrer and are now
      referenced by no tvar and no thread.
   Then reply [true]. *)
let commit_handle {read_log=rl; write_log=wl; dirty_log=dl} thr p =
  let conflict = try
    TvMap.iter
      (fun tv ver ->
         if (TvMap.find tv root.root_env).version <> ver then raise Break)
      rl;
    false
  with Break | Not_found -> true in
  if conflict then write_portal false p else begin
    (* Tvars that lost at least one referrer: candidates for collection. *)
    let ref_dec_set = ref TvSet.empty in
    (* The whole root_env must reach its final state before the reference
       and dirty sets are compared, otherwise they would be inconsistent. *)
    let () = TvMap.iter
      (fun tv repr ->
         (* Only the root_env lookup decides "new tvar"; a Not_found from
            any other lookup must not be mistaken for it. *)
         match TvMap.find tv root.root_env with
         | old_repr ->
             let new_repr = {repr with version = old_repr.version + 1} in
             TvSet.iter (opr_ref_by_tv (TvSet.remove tv)) old_repr.ref_to;
             ref_dec_set := TvSet.union !ref_dec_set old_repr.ref_to;
             (* New references cannot be recorded yet: not every written
                tvar has been installed. *)
             root.root_env <- TvMap.add tv new_repr root.root_env;
             (* The value changed: wake every thread waiting on this tvar.
                A slot is emptied once answered because it may also be
                queued on other tvars. *)
             let tv_rec = TvMap.find tv root.root_rec in
             List.iter (fun slot -> match !slot with
                        | Some waiter -> write_portal true waiter; slot := None
                        | None -> ()
                       ) tv_rec.tv_wait;
             tv_rec.tv_wait <- []
         | exception Not_found ->
             (* Only possible for a tvar created inside the transaction: a
                collected tvar cannot be written. *)
             root.root_env <- TvMap.add tv repr root.root_env;
             (* Create its record now so the reference updates below find
                it. *)
             root.root_rec <- TvMap.add tv (empty_tv_rec ()) root.root_rec
      ) wl in
    (* Record the reference edges of the new values. *)
    let () = TvMap.iter
      (fun tv repr ->
         TvSet.iter (opr_ref_by_tv (TvSet.add tv)) repr.ref_to;
         ref_dec_set := TvSet.diff !ref_dec_set repr.ref_to
      ) wl in
    (* Replace the dirty set recorded for [thr] by the one it just sent. *)
    let () =
      let thr_rec = ThreadMap.find thr root.root_thr in
      let to_remove = TvSet.diff thr_rec.tv_dirty dl in
      let to_add = TvSet.diff dl thr_rec.tv_dirty in
      TvSet.iter (opr_ref_by_thr (ThreadSet.remove thr)) to_remove;
      ref_dec_set := TvSet.union !ref_dec_set to_remove;
      TvSet.iter (opr_ref_by_thr (ThreadSet.add thr)) to_add;
      ref_dec_set := TvSet.diff !ref_dec_set to_add;
      thr_rec.tv_dirty <- dl in
    (* Transitive collection of unreferenced tvars.  The loop stops on the
       empty set instead of relying on [TvSet.max_elt] raising Not_found.
       A tvar without a record or repr (e.g. already collected) is skipped,
       rather than silently ending the whole collection. *)
    let rec gc tv_set =
      if not (TvSet.is_empty tv_set) then begin
        let tv = TvSet.max_elt tv_set in
        let tv_rest = TvSet.remove tv tv_set in
        match TvMap.find tv root.root_rec, TvMap.find tv root.root_env with
        | exception Not_found -> gc tv_rest
        | tv_rec, repr ->
            if TvSet.is_empty tv_rec.ref_by_tv
               && ThreadSet.is_empty tv_rec.ref_by_thr
            then begin
              root.root_rec <- TvMap.remove tv root.root_rec;
              root.root_env <- TvMap.remove tv root.root_env;
              TvSet.iter
                (fun target ->
                   try opr_ref_by_tv (TvSet.remove tv) target
                   with Not_found -> ())
                repr.ref_to;
              gc (TvSet.union repr.ref_to tv_rest)
            end
            else gc tv_rest
      end
    in
    gc !ref_dec_set;
    (* The thread's env record is deliberately left alone so that it keeps
       matching the client's old env.  The next [`Atom] request diffs it
       against root_env, updates it and sends the patch. *)
    write_portal true p
  end

(* Coordinator hook for thread lifecycle messages:
   - [`Create (parent, child, _)]: the child gets a copy of the parent's
     record (env and dirty set) and becomes a holder of those tvars.  If the
     parent is unknown (the first thread), or a Not_found arises while
     registering, the child gets a fresh record instead.
   - [`Delete (t, _)]: release every tvar held by [t] and forget [t].
   - Any other message is ignored. *)
let stm_extend_handle : root_msg -> unit = function
  | `Create (parent, child, _) ->
      (try
         let parent_rec = ThreadMap.find parent root.root_thr in
         let child_rec = { parent_rec with tv_dirty = parent_rec.tv_dirty } in
         root.root_thr <- ThreadMap.add child child_rec root.root_thr;
         TvSet.iter (opr_ref_by_thr (ThreadSet.add child)) child_rec.tv_dirty
       with Not_found ->
         root.root_thr <- ThreadMap.add child (empty_thr_rec ()) root.root_thr)
  | `Delete (t, _) ->
      let record = ThreadMap.find t root.root_thr in
      TvSet.iter (opr_ref_by_thr (ThreadSet.remove t)) record.tv_dirty;
      root.root_thr <- ThreadMap.remove t root.root_thr
  | _ -> ()

(* Dispatch each [stm_msg] to its primitive service. *)
let stm_handle : stm_msg -> unit = function
  | `Commit (summary, thr, p) -> commit_handle summary thr p
  | `Atom (thr, p) -> atom_handle thr p
  | `Wait (read_versions, thr, p) -> wait_handle read_versions thr p
  | `Tvar (tv, repr, thr, p) -> tvar_handle tv repr thr p

(* Hand both handlers to Coordinator's [new_serv]: the thread-lifecycle hook
   on [root_portal] first, then the STM request service on [stm_portal].
   Whatever [new_serv] returns is discarded. *)
let () = ignore (new_serv root_portal stm_extend_handle)
let () = ignore (new_serv stm_portal stm_handle)