Skip to content

Commit

Permalink
Merge pull request #28 from BTNC/master
Browse files Browse the repository at this point in the history
fix always return ByteTensor bug && port to windows
  • Loading branch information
soumith authored Oct 12, 2016
2 parents 041d628 a2eac63 commit 3f16e88
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 8 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 15,9 @@ add_library(${PKGNAME} MODULE
"tds_vec.c"
"tds_atomic_counter.c"
)
SET_TARGET_PROPERTIES(${PKGNAME} PROPERTIES
PREFIX "lib"
IMPORT_PREFIX "lib")

IF(Torch_FOUND)
TARGET_LINK_LIBRARIES(${PKGNAME} TH)
Expand All @@ -27,3 30,8 @@ FILE(GLOB luafiles *.lua)

install(FILES ${luafiles}
DESTINATION ${LUA_PATH}/${PKGNAME})

if(MSVC)
set(CMAKE_MODULE_LINKER_FLAGS
"${CMAKE_MODULE_LINKER_FLAGS} /DEF:..\\tds.def")
endif()
2 changes: 2 additions & 0 deletions cdefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 9,7 @@ tds_elem *tds_elem_new(void);
void tds_elem_free(tds_elem *elem);
uint32_t tds_elem_hashkey(tds_elem *elem);
int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2);
void tds_elem_set_subtype(tds_elem *elem, char subtype);
void tds_elem_set_number(tds_elem *elem, double num);
void tds_elem_set_boolean(tds_elem *elem, bool flag);
void tds_elem_set_string(tds_elem *elem, const char *str, size_t size);
Expand All @@ -20,6 21,7 @@ size_t tds_elem_get_string_size(tds_elem *elem);
void* tds_elem_get_pointer(tds_elem *elem);
tds_elem_pointer_free_ptrfunc tds_elem_get_pointer_free(tds_elem *elem);
char tds_elem_type(tds_elem *elem);
char tds_elem_subtype(tds_elem *elem);
void tds_elem_free_content(tds_elem *elem);
void tds_elem_set_nil(tds_elem *elem);
int tds_elem_isnil(tds_elem *elem);
Expand Down
26 changes: 19 additions & 7 deletions elem.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 5,19 @@ local C = tds.C
local elem = {}

local elem_ctypes = {}
local elem_ctypes_abbr2name = {}
local elem_ctypes_name2abbr = {}

function elem.type()
end

function elem.addctype(ttype, free_p, setfunc, getfunc)
function elem.addctype(ttype, free_p, setfunc, getfunc, abbr)
elem_ctypes[ttype] = setfunc
elem_ctypes[tonumber(ffi.cast('intptr_t', free_p))] = getfunc
if abbr then
elem_ctypes_abbr2name[abbr] = ttype
elem_ctypes_name2abbr[ttype] = abbr
end
end

function elem.set(celem, lelem)
Expand All @@ -29,6 35,10 @@ function elem.set(celem, lelem)
else
error(string.format('unsupported key/value type <%s> (set)', tname and tname or type(lelem)))
end
local abbr = elem_ctypes_name2abbr[tname]
if abbr then
C.tds_elem_set_subtype(celem, abbr)
end
end
end

Expand All @@ -45,11 55,12 @@ function elem.get(celem)
local value = ffi.string(C.tds_elem_get_string(celem), tonumber(C.tds_elem_get_string_size(celem)))
return value
elseif elemtype == 112 then--string.byte('p') then
local subtype = C.tds_elem_subtype(celem)
local lelem_p = C.tds_elem_get_pointer(celem)
local free_p = C.tds_elem_get_pointer_free(celem)
local getfunc = elem_ctypes[tonumber(ffi.cast('intptr_t', free_p))]
if getfunc then
local value = getfunc(lelem_p)
local value = getfunc(lelem_p, subtype)
return value
else
error('unsupported key/value type (get)')
Expand All @@ -61,7 72,7 @@ end

-- torch specific
if pcall(require, 'torch') then
local T = ffi.C
local T = ffi.os ~= 'Windows' and ffi.C or ffi.load('TH')

elem.type = torch.typename

Expand Down Expand Up @@ -110,20 121,21 @@ void THRealTensor_free(THRealTensor *self);
local THTensor_free = T[string.format('TH%sTensor_free', Real)]
local tensor_type_id = string.format('torch.%sTensor', Real)
elem.addctype(
string.format('torch.%sTensor', Real),
tensor_type_id,
THTensor_free,
function(lelem)
local lelem_p = lelem:cdata()
THTensor_retain(lelem_p)
return lelem_p, THTensor_free
end,
function(lelem_p)
function(lelem_p, subtype)
THTensor_retain(lelem_p)
local tensor_type_id = elem_ctypes_abbr2name[subtype]
local lelem = torch.pushudata(lelem_p, tensor_type_id)
return lelem
end
end,
string.byte(Real, 1)
)

end
end

Expand Down
62 changes: 62 additions & 0 deletions tds.def
Original file line number Diff line number Diff line change
@@ -0,0 1,62 @@
LIBRARY tds
EXPORTS
;/* elem */
tds_elem_new
tds_elem_free
tds_elem_hashkey
tds_elem_isequal
tds_elem_set_subtype
tds_elem_set_number
tds_elem_set_boolean
tds_elem_set_string
tds_elem_set_pointer
tds_elem_get_number
tds_elem_get_boolean
tds_elem_get_string
tds_elem_get_string_size
tds_elem_get_pointer
tds_elem_get_pointer_free
tds_elem_type
tds_elem_subtype
tds_elem_free_content
tds_elem_set_nil
tds_elem_isnil

EXPORTS
; /* hash */
tds_hash_new
tds_hash_size
tds_hash_insert
tds_hash_search
tds_hash_remove
tds_hash_retain
tds_hash_free

EXPORTS
; /* hash iterator */
tds_hash_iterator_new
tds_hash_iterator_next
tds_hash_iterator_free

EXPORTS
; /* vec */
tds_vec_new
tds_vec_size
tds_vec_insert
tds_vec_set
tds_vec_get
tds_vec_remove
tds_vec_resize
tds_vec_sort
tds_vec_retain
tds_vec_free

EXPORTS
; /* atomic counter */
tds_has_atomic
tds_atomic_new
tds_atomic_inc
tds_atomic_get
tds_atomic_set
tds_atomic_retain
tds_atomic_free
18 changes: 17 additions & 1 deletion tds_elem.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 78,7 @@ uint32_t tds_elem_hashkey(tds_elem *elem)

int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2)
{
if(elem1->type != elem2->type)
if(elem1->type != elem2->type || elem1->subtype != elem1->subtype)
return 0;
switch(elem1->type) {
case 'n':
Expand All @@ -97,21 97,29 @@ int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2)
}
}

void tds_elem_set_subtype(tds_elem *elem, char subtype)
{
elem->subtype = subtype;
}

void tds_elem_set_number(tds_elem *elem, double num)
{
elem->type = 'n';
elem->subtype = 0;
elem->value.num = num;
}

void tds_elem_set_boolean(tds_elem *elem, bool flag)
{
elem->type = 'b';
elem->subtype = 0;
elem->value.flag = flag;
}

void tds_elem_set_string(tds_elem *elem, const char *str, size_t size)
{
elem->type = 's';
elem->subtype = 0;
elem->value.str.data = tds_malloc(size);
if(elem->value.str.data) {
memcpy(elem->value.str.data, str, size);
Expand All @@ -122,6 130,7 @@ void tds_elem_set_string(tds_elem *elem, const char *str, size_t size)
void tds_elem_set_pointer(tds_elem *elem, void *ptr, void (*free)(void*))
{
elem->type = 'p';
elem->subtype = 0;
elem->value.ptr.data = ptr;
elem->value.ptr.free = free;
}
Expand Down Expand Up @@ -161,18 170,25 @@ char tds_elem_type(tds_elem *elem)
return elem->type;
}

char tds_elem_subtype(tds_elem *elem)
{
return elem->subtype;
}

void tds_elem_free_content(tds_elem *elem)
{
if(elem->type == 's')
tds_free(elem->value.str.data);
if(elem->type == 'p' && elem->value.ptr.free)
elem->value.ptr.free(elem->value.ptr.data);
elem->type = 0;
elem->subtype = 0;
}

void tds_elem_set_nil(tds_elem *elem)
{
elem->type = 0;
elem->subtype = 0;
}

int tds_elem_isnil(tds_elem *elem)
Expand Down
3 changes: 3 additions & 0 deletions tds_elem.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 27,15 @@ typedef struct tds_elem_ {
} value;

char type;
char subtype;

} tds_elem;

tds_elem *tds_elem_new(void);
void tds_elem_free(tds_elem *elem);
uint32_t tds_elem_hashkey(tds_elem *elem);
int tds_elem_isequal(tds_elem *elem1, tds_elem *elem2);
void tds_elem_set_subtype(tds_elem *elem, char subtype);
void tds_elem_set_number(tds_elem *elem, double num);
void tds_elem_set_boolean(tds_elem *elem, bool flag);
void tds_elem_set_string(tds_elem *elem, const char *str, size_t size);
Expand All @@ -45,6 47,7 @@ size_t tds_elem_get_string_size(tds_elem *elem);
void* tds_elem_get_pointer(tds_elem *elem);
tds_elem_pointer_free_ptrfunc tds_elem_get_pointer_free(tds_elem *elem);
char tds_elem_type(tds_elem *elem);
char tds_elem_subtype(tds_elem *elem);
void tds_elem_free_content(tds_elem *elem);
void tds_elem_set_nil(tds_elem *elem);
int tds_elem_isnil(tds_elem *elem);
Expand Down

0 comments on commit 3f16e88

Please sign in to comment.