Skip to content

Commit

Permalink
Add migrations and refactor internal structs
Browse files Browse the repository at this point in the history
  • Loading branch information
kegsay committed May 17, 2024
1 parent 2cd9a81 commit b383ed0
Show file tree
Hide file tree
Showing 12 changed files with 439 additions and 217 deletions.
102 changes: 19 additions & 83 deletions internal/device_data.go
Original file line number Diff line number Diff line change
@@ -1,9 1,5 @@
package internal

import (
"sync"
)

const (
bitOTKCount int = iota
bitFallbackKeyTypes
Expand All @@ -18,105 14,45 @@ func isBitSet(n int, bit int) bool {
return val > 0
}

// DeviceData contains useful data for this user's device. This list can be expanded without prompting
// schema changes. These values are upserted into the database and persisted forever.
// DeviceData contains useful data for this user's device.
type DeviceData struct {
DeviceListChanges
DeviceKeyData
UserID string
DeviceID string
}

// This is calculated from device_lists table
type DeviceListChanges struct {
DeviceListChanged []string
DeviceListLeft []string
}

// This gets serialised as CBOR in device_data table
type DeviceKeyData struct {
// Contains the latest device_one_time_keys_count values.
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
OTKCounts MapStringInt `json:"otk"`
// Contains the latest device_unused_fallback_key_types value
// Set whenever this field arrives down the v2 poller, and it replaces what was previously there.
// If this is a nil slice this means no change. If this is an empty slice then this means the fallback key was used up.
FallbackKeyTypes []string `json:"fallback"`

DeviceLists DeviceLists `json:"dl"`

// bitset for which device data changes are present. They accumulate until they get swapped over
// when they get reset
ChangedBits int `json:"c"`

UserID string
DeviceID string
}

func (dd *DeviceData) SetOTKCountChanged() {
func (dd *DeviceKeyData) SetOTKCountChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitOTKCount)
}

func (dd *DeviceData) SetFallbackKeysChanged() {
func (dd *DeviceKeyData) SetFallbackKeysChanged() {
dd.ChangedBits = setBit(dd.ChangedBits, bitFallbackKeyTypes)
}

func (dd *DeviceData) OTKCountChanged() bool {
func (dd *DeviceKeyData) OTKCountChanged() bool {
return isBitSet(dd.ChangedBits, bitOTKCount)
}
func (dd *DeviceData) FallbackKeysChanged() bool {
func (dd *DeviceKeyData) FallbackKeysChanged() bool {
return isBitSet(dd.ChangedBits, bitFallbackKeyTypes)
}

type UserDeviceKey struct {
UserID string
DeviceID string
}

type DeviceDataMap struct {
deviceDataMu *sync.Mutex
deviceDataMap map[UserDeviceKey]*DeviceData
Pos int64
}

func NewDeviceDataMap(startPos int64, devices []DeviceData) *DeviceDataMap {
ddm := &DeviceDataMap{
deviceDataMu: &sync.Mutex{},
deviceDataMap: make(map[UserDeviceKey]*DeviceData),
Pos: startPos,
}
for i, dd := range devices {
ddm.deviceDataMap[UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}] = &devices[i]
}
return ddm
}

func (d *DeviceDataMap) Get(userID, deviceID string) *DeviceData {
key := UserDeviceKey{
UserID: userID,
DeviceID: deviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
dd, ok := d.deviceDataMap[key]
if !ok {
return nil
}
return dd
}

func (d *DeviceDataMap) Update(dd DeviceData) DeviceData {
key := UserDeviceKey{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
d.deviceDataMu.Lock()
defer d.deviceDataMu.Unlock()
existing, ok := d.deviceDataMap[key]
if !ok {
existing = &DeviceData{
UserID: dd.UserID,
DeviceID: dd.DeviceID,
}
}
if dd.OTKCounts != nil {
existing.OTKCounts = dd.OTKCounts
}
if dd.FallbackKeyTypes != nil {
existing.FallbackKeyTypes = dd.FallbackKeyTypes
}
existing.DeviceLists = existing.DeviceLists.Combine(dd.DeviceLists)

d.deviceDataMap[key] = existing

return *existing
}
74 changes: 36 additions & 38 deletions state/device_data_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 15,9 @@ type DeviceDataRow struct {
ID int64 `db:"id"`
UserID string `db:"user_id"`
DeviceID string `db:"device_id"`
// This will contain internal.DeviceData serialised as JSON. It's stored in a single column as we don't
// This will contain internal.DeviceKeyData serialised as JSON. It's stored in a single column as we don't
// need to perform searches on this data.
Data []byte `db:"data"`
KeyData []byte `db:"data"`
}

type DeviceDataTable struct {
Expand Down Expand Up @@ -47,6 47,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
// This should only be called by the v3 HTTP APIs when servicing an E2EE extension request.
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// grab otk counts and fallback key types
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil {
Expand All @@ -56,32 57,38 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
}
return err
}
result = &internal.DeviceData{}
var keyData *internal.DeviceKeyData
// unmarshal to swap
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
decMode, err := opts.DecMode()
result.UserID = userID
result.DeviceID = deviceID
if keyData != nil {
result.DeviceKeyData = *keyData
}

deviceListChanges, err := t.deviceListTable.SelectTx(txn, userID, deviceID, swap)
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &result); err != nil {
return err
for targetUserID, targetState := range deviceListChanges {
switch targetState {
case internal.DeviceListChanged:
result.DeviceListChanged = append(result.DeviceListChanged, targetUserID)
case internal.DeviceListLeft:
result.DeviceListLeft = append(result.DeviceListLeft, targetUserID)
}
}
result.UserID = userID
result.DeviceID = deviceID
if !swap {
return nil // don't swap
}
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New

// swap over the fields
writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New
writeBack.DeviceLists.New = make(map[string]int)
writeBack := *keyData
writeBack.ChangedBits = 0

if reflect.DeepEqual(result, &writeBack) {
if reflect.DeepEqual(keyData, &writeBack) {
// The update to the DB would be a no-op; don't bother with it.
// This helps reduce write usage and the contention on the unique index for
// the device_data table.
Expand All @@ -99,45 106,36 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
return
}

func (t *DeviceDataTable) DeleteDevice(userID, deviceID string) error {
_, err := t.db.Exec(`DELETE FROM syncv3_device_data WHERE user_id = $1 AND device_id = $2`, userID, deviceID)
return err
}

// Upsert combines what is in the database for this user|device with the partial entry `dd`
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
func (t *DeviceDataTable) Upsert(dd *internal.DeviceData, deviceListChanges map[string]int) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// Update device lists
if err = t.deviceListTable.UpsertTx(txn, dd.UserID, dd.DeviceID, deviceListChanges); err != nil {
return err
}
// select what already exists
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
// unmarshal and combine
var tempDD internal.DeviceData
if len(row.Data) > 0 {
opts := cbor.DecOptions{
MaxMapPairs: 1000000000, // 1 billion :(
}
decMode, err := opts.DecMode()
if err != nil {
return err
}
if err = decMode.Unmarshal(row.Data, &tempDD); err != nil {
var keyData internal.DeviceKeyData
if len(row.KeyData) > 0 {
if err = cbor.Unmarshal(row.KeyData, &keyData); err != nil {
return err
}
}
if dd.FallbackKeyTypes != nil {
tempDD.FallbackKeyTypes = dd.FallbackKeyTypes
tempDD.SetFallbackKeysChanged()
keyData.FallbackKeyTypes = dd.FallbackKeyTypes
keyData.SetFallbackKeysChanged()
}
if dd.OTKCounts != nil {
tempDD.OTKCounts = dd.OTKCounts
tempDD.SetOTKCountChanged()
keyData.OTKCounts = dd.OTKCounts
keyData.SetOTKCountChanged()
}
tempDD.DeviceLists = tempDD.DeviceLists.Combine(dd.DeviceLists)

data, err := cbor.Marshal(tempDD)
data, err := cbor.Marshal(keyData)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit b383ed0

Please sign in to comment.