Skip to content

Commit

Permalink
fix always return ByteTensor bug
Browse files Browse the repository at this point in the history
  • Loading branch information
BTNC committed Oct 5, 2016
1 parent 9aa079e commit a2eac63
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 7 deletions.
2 changes: 2 additions & 0 deletions cdefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 9,7 @@ tds_elem *tds_elem_new(void);
void tds_elem_free(tds_elem *elem);
uint32_t tds_elem_hashkey(tds_elem *elem);
int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2);
void tds_elem_set_subtype(tds_elem *elem, char subtype);
void tds_elem_set_number(tds_elem *elem, double num);
void tds_elem_set_boolean(tds_elem *elem, bool flag);
void tds_elem_set_string(tds_elem *elem, const char *str, size_t size);
Expand All @@ -20,6 21,7 @@ size_t tds_elem_get_string_size(tds_elem *elem);
void* tds_elem_get_pointer(tds_elem *elem);
tds_elem_pointer_free_ptrfunc tds_elem_get_pointer_free(tds_elem *elem);
char tds_elem_type(tds_elem *elem);
char tds_elem_subtype(tds_elem *elem);
void tds_elem_free_content(tds_elem *elem);
void tds_elem_set_nil(tds_elem *elem);
int tds_elem_isnil(tds_elem *elem);
Expand Down
24 changes: 18 additions & 6 deletions elem.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 5,19 @@ local C = tds.C
local elem = {}

local elem_ctypes = {}
local elem_ctypes_abbr2name = {}
local elem_ctypes_name2abbr = {}

function elem.type()
end

function elem.addctype(ttype, free_p, setfunc, getfunc)
function elem.addctype(ttype, free_p, setfunc, getfunc, abbr)
elem_ctypes[ttype] = setfunc
elem_ctypes[tonumber(ffi.cast('intptr_t', free_p))] = getfunc
if abbr then
elem_ctypes_abbr2name[abbr] = ttype
elem_ctypes_name2abbr[ttype] = abbr
end
end

function elem.set(celem, lelem)
Expand All @@ -29,6 35,10 @@ function elem.set(celem, lelem)
else
error(string.format('unsupported key/value type <%s> (set)', tname and tname or type(lelem)))
end
local abbr = elem_ctypes_name2abbr[tname]
if abbr then
C.tds_elem_set_subtype(celem, abbr)
end
end
end

Expand All @@ -45,11 55,12 @@ function elem.get(celem)
local value = ffi.string(C.tds_elem_get_string(celem), tonumber(C.tds_elem_get_string_size(celem)))
return value
elseif elemtype == 112 then--string.byte('p') then
local subtype = C.tds_elem_subtype(celem)
local lelem_p = C.tds_elem_get_pointer(celem)
local free_p = C.tds_elem_get_pointer_free(celem)
local getfunc = elem_ctypes[tonumber(ffi.cast('intptr_t', free_p))]
if getfunc then
local value = getfunc(lelem_p)
local value = getfunc(lelem_p, subtype)
return value
else
error('unsupported key/value type (get)')
Expand Down Expand Up @@ -110,20 121,21 @@ void THRealTensor_free(THRealTensor *self);
local THTensor_free = T[string.format('TH%sTensor_free', Real)]
local tensor_type_id = string.format('torch.%sTensor', Real)
elem.addctype(
string.format('torch.%sTensor', Real),
tensor_type_id,
THTensor_free,
function(lelem)
local lelem_p = lelem:cdata()
THTensor_retain(lelem_p)
return lelem_p, THTensor_free
end,
function(lelem_p)
function(lelem_p, subtype)
THTensor_retain(lelem_p)
local tensor_type_id = elem_ctypes_abbr2name[subtype]
local lelem = torch.pushudata(lelem_p, tensor_type_id)
return lelem
end
end,
string.byte(Real, 1)
)

end
end

Expand Down
2 changes: 2 additions & 0 deletions tds.def
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 5,7 @@ tds_elem_new
tds_elem_free
tds_elem_hashkey
tds_elem_isequal
tds_elem_set_subtype
tds_elem_set_number
tds_elem_set_boolean
tds_elem_set_string
Expand All @@ -16,6 17,7 @@ tds_elem_get_string_size
tds_elem_get_pointer
tds_elem_get_pointer_free
tds_elem_type
tds_elem_subtype
tds_elem_free_content
tds_elem_set_nil
tds_elem_isnil
Expand Down
18 changes: 17 additions & 1 deletion tds_elem.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 78,7 @@ uint32_t tds_elem_hashkey(tds_elem *elem)

int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2)
{
if(elem1->type != elem2->type)
if(elem1->type != elem2->type || elem1->subtype != elem1->subtype)
return 0;
switch(elem1->type) {
case 'n':
Expand All @@ -97,21 97,29 @@ int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2)
}
}

void tds_elem_set_subtype(tds_elem *elem, char subtype)
{
elem->subtype = subtype;
}

void tds_elem_set_number(tds_elem *elem, double num)
{
elem->type = 'n';
elem->subtype = 0;
elem->value.num = num;
}

void tds_elem_set_boolean(tds_elem *elem, bool flag)
{
elem->type = 'b';
elem->subtype = 0;
elem->value.flag = flag;
}

void tds_elem_set_string(tds_elem *elem, const char *str, size_t size)
{
elem->type = 's';
elem->subtype = 0;
elem->value.str.data = tds_malloc(size);
if(elem->value.str.data) {
memcpy(elem->value.str.data, str, size);
Expand All @@ -122,6 130,7 @@ void tds_elem_set_string(tds_elem *elem, const char *str, size_t size)
void tds_elem_set_pointer(tds_elem *elem, void *ptr, void (*free)(void*))
{
elem->type = 'p';
elem->subtype = 0;
elem->value.ptr.data = ptr;
elem->value.ptr.free = free;
}
Expand Down Expand Up @@ -161,18 170,25 @@ char tds_elem_type(tds_elem *elem)
return elem->type;
}

char tds_elem_subtype(tds_elem *elem)
{
return elem->subtype;
}

void tds_elem_free_content(tds_elem *elem)
{
if(elem->type == 's')
tds_free(elem->value.str.data);
if(elem->type == 'p' && elem->value.ptr.free)
elem->value.ptr.free(elem->value.ptr.data);
elem->type = 0;
elem->subtype = 0;
}

void tds_elem_set_nil(tds_elem *elem)
{
elem->type = 0;
elem->subtype = 0;
}

int tds_elem_isnil(tds_elem *elem)
Expand Down
3 changes: 3 additions & 0 deletions tds_elem.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 27,15 @@ typedef struct tds_elem_ {
} value;

char type;
char subtype;

} tds_elem;

tds_elem *tds_elem_new(void);
void tds_elem_free(tds_elem *elem);
uint32_t tds_elem_hashkey(tds_elem *elem);
int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2);
void tds_elem_set_subtype(tds_elem *elem, char subtype);
void tds_elem_set_number(tds_elem *elem, double num);
void tds_elem_set_boolean(tds_elem *elem, bool flag);
void tds_elem_set_string(tds_elem *elem, const char *str, size_t size);
Expand All @@ -45,6 47,7 @@ size_t tds_elem_get_string_size(tds_elem *elem);
void* tds_elem_get_pointer(tds_elem *elem);
tds_elem_pointer_free_ptrfunc tds_elem_get_pointer_free(tds_elem *elem);
char tds_elem_type(tds_elem *elem);
char tds_elem_subtype(tds_elem *elem);
void tds_elem_free_content(tds_elem *elem);
void tds_elem_set_nil(tds_elem *elem);
int tds_elem_isnil(tds_elem *elem);
Expand Down

0 comments on commit a2eac63

Please sign in to comment.