Skip to content

Commit

Permalink
add main structure FFM and FM
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiang.2533 committed Mar 16, 2020
0 parents commit e297344
Show file tree
Hide file tree
Showing 24 changed files with 150,725 additions and 0 deletions.
Empty file added __init__.py
Empty file.
76 changes: 76 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 1,76 @@
import tensorflow as tf

DATA_DIR = './data/adult/{}.csv'
MODEL_DIR = './checkpoint/{}/'

FEATURE_NAME =[
'age', 'workclass', 'fnlwgt', 'education', 'education_num',
'marital_status', 'occupation', 'relationship', 'race', 'gender',
'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
'income_bracket'
]

TARGET = 'income_bracket'

CSV_RECORD_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
[0], [0], [0], [''], ['']]
DTYPE ={}
for i ,j in enumerate(CSV_RECORD_DEFAULTS):
if j[0]=='':
DTYPE[FEATURE_NAME[i]] = tf.string
else :
DTYPE[FEATURE_NAME[i]] = tf.float32


MODEL_PARAMS = {
'batch_size':512,
'num_epochs':500,
'buffer_size':512
}

EMB_CONFIGS = {
'workclass':{
'hash_size':10,
'emb_size':5
},
'education':{
'hash_size':10,
'emb_size':5
},
'marital_status':{
'hash_size':10,
'emb_size':5
},
'occupation': {
'hash_size': 100,
'emb_size': 5
},
'relationship': {
'hash_size': 10,
'emb_size': 5
},
'race': {
'hash_size': 10,
'emb_size': 5
},
'gender': {
'hash_size': 10,
'emb_size': 2
},
'native_country':{
'hash_size':100,
'emb_size': 10
}
}

BUCKET_CONFIGS = {
'age':[18, 25, 30, 35, 40, 45, 50, 55, 60, 65],
'fnlwgt':[6*(10**4), 10**5, 1.3*(10**5), 1.5*(10**5), 1.7*(10**5), 1.9*(10**5),
2.1*(10**5), 2.5*(10**5), 3*(10**5)],
'education_num' : [7,8,10,11,13],
'hours_per_week':[25,35,40,45,55],
'capital_gain':[0,1],
'capital_loss':[0,1]
}


32,561 changes: 32,561 additions & 0 deletions data/adult/train.csv

Large diffs are not rendered by default.

16,281 changes: 16,281 additions & 0 deletions data/adult/valid.csv

Large diffs are not rendered by default.

68 changes: 68 additions & 0 deletions data/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 1,68 @@
"""Data Set Information:
Extraction was done by Barry Becker from the 1994 Census database. A set of reasonably clean records was extracted using the following conditions: ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0))
Prediction task is to determine whether a person makes over 50K a year.
Attribute Information:
Listing of attributes:
>50K, <=50K.
age: continuous.
workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
fnlwgt: continuous.
education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
education-num: continuous.
marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
sex: Female, Male.
capital-gain: continuous.
capital-loss: continuous.
hours-per-week: continuous.
native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
"""

from six.moves import urllib
import re

def data_loader():

DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.{}'
DATA_DIR = './data/adult/{}.csv'
files = {'data':'train',
'test':'valid'}

for ifile in files.keys():
url = DATA_URL.format(ifile)

f ,_ = urllib.request.urlretrieve(url)

data_dir = DATA_DIR.format(files[ifile])
with open(data_dir, 'w') as f_writer:
with open(f, 'r', encoding='UTF-8') as f_reader:
for line in f_reader:
line = line.strip()
line = re.sub(r'\s ',"", line)

if not line:
continue

if line[-1] == '.':
line = line[:-1]

line ='\n'
f_writer.write(line)


if __name__ == '__main__' :
data_loader()



df = pd.readc
Loading

0 comments on commit e297344

Please sign in to comment.