Commit
justdark committed Apr 9, 2017
0 parents commit eac17b9
Showing 13 changed files with 17,805 additions and 0 deletions.
1 change: 1 addition & 0 deletions chinese-poetry
Submodule chinese-poetry added at 7bc2c6
39 changes: 39 additions & 0 deletions dataHandler.py
@@ -0,0 +1,39 @@
#coding:utf-8
import sys
import os
import json
import re

def parseRawData():
    def sentenceParse(para):
        # para = "-181-村橋路不端，數里就迴湍。積壤連涇脉，高林上笋竿。早嘗甘蔗淡，生摘琵琶酸。（「琵琶」，嚴壽澄校《張祜詩集》云：疑「枇杷」之誤。）好是去塵俗，煙花長一欄。"
        # strip editorial notes in full-width parentheses and braces,
        # then stray brackets and leading/trailing digits and dashes
        result, number = re.subn(u"（.*）", "", para)
        result, number = re.subn(u"{.*}", "", result)
        result, number = re.subn(u"[\]\[]", "", result)
        return result.strip("[0123456789-]")

    def handleJson(file):
        # print file
        rst = []
        data = json.loads(open(file).read())
        for poetry in data:
            pdata = ""
            for sentence in poetry.get("paragraphs"):
                pdata += sentence
            pdata = sentenceParse(pdata)
            if pdata != "":
                rst.append(pdata)
        return rst

    # print sentenceParse("")
    data = []
    src = './chinese-poetry/json/'
    for filename in os.listdir(src):
        if filename.startswith("poet.tang"):
            data.extend(handleJson(src + filename))
    return data

if __name__ == '__main__':
    data = parseRawData()
    for s in data:
        print s
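
For context, parseRawData walks every poet.tang*.json file in the chinese-poetry submodule and returns a flat list of cleaned poem strings. A minimal sketch of how it might be called, assuming the submodule has been checked out so that ./chinese-poetry/json/ exists:

# hypothetical quick check, not part of this commit
import dataHandler

data = dataHandler.parseRawData()
print "poems:", len(data)   # one cleaned string per Tang poem
print data[0]               # annotations, braces and brackets stripped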
Binary file added dataHandler.pyc
Binary file not shown.
28 changes: 28 additions & 0 deletions model.py
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class PoetryModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim):
        super(PoetryModel, self).__init__()
        self.hidden_dim = hidden_dim
        # one-hot input of size vocab_size -> LSTM -> linear -> log-softmax
        self.lstm = nn.LSTM(vocab_size, self.hidden_dim)
        self.linear1 = nn.Linear(self.hidden_dim, vocab_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax()

    def forward(self, input, hidden):
        length = input.size()[0]
        output, hidden = self.lstm(input.view(length, 1, -1), hidden)
        # print output.size()
        output = F.relu(self.linear1(output.view(length, -1)))
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self, length=1):
        return (Variable(torch.zeros(length, 1, self.hidden_dim).cuda()),
                Variable(torch.zeros(length, 1, self.hidden_dim).cuda()))
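
The network reads a sequence of one-hot vectors shaped (length, 1, vocab_size) and returns one row of log-probabilities per step. A minimal shape check, assuming a CUDA device (initHidden allocates its state on the GPU) and a hypothetical 10-word vocabulary:

# hypothetical shape check, not part of this commit
import torch
from torch.autograd import Variable
from model import PoetryModel

m = PoetryModel(10, 128)                       # vocab_size=10, hidden_dim=128
m.cuda()
hidden = m.initHidden()
inp = Variable(torch.zeros(3, 1, 10).cuda())   # three one-hot steps
out, hidden = m(inp, hidden)
print out.size()                               # (3, 10): log-probabilities per step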
Binary file added model.pyc
Binary file not shown.
Binary file added poetry-gen.pt
Binary file not shown.
39 changes: 39 additions & 0 deletions sample.py
@@ -0,0 +1,39 @@
#coding:utf-8
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import dataHandler
from model import PoetryModel
from torch.autograd import Variable
import cPickle as p
from utils import *

model = torch.load('poetry-gen.pt')
max_length = 20
rFile = file('wordDic', 'r')
word_to_ix = p.load(rFile)

def invert_dict(d):
    return dict((v, k) for k, v in d.iteritems())

ix_to_word = invert_dict(word_to_ix)

# Sample greedily from the model, starting from a given word
def sample(startWord='<START>'):
    input = make_one_hot_vec(startWord, word_to_ix)
    hidden = model.initHidden()
    output_name = ""
    for i in range(max_length):
        output, hidden = model(input.cuda(), hidden)
        topv, topi = output.data.topk(1)
        topi = topi[0][0]
        w = ix_to_word[topi]
        if w == "<EOP>":
            break
        else:
            output_name += w
            input = make_one_hot_vec(w, word_to_ix)
    print output_name
    return output_name

sample()
# sample("床".decode('utf-8'))
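
Note that topk(1) makes the sampling greedy: the same start word always yields the same poem. A hypothetical variation, not in this commit, would draw from the predicted distribution instead, trading determinism for variety:

# inside the loop, replacing the two topk(1) lines:
probs = output.data.view(-1).exp().cpu()   # LogSoftmax output -> probabilities
topi = torch.multinomial(probs, 1)[0]      # draw one index at random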
114 changes: 114 additions & 0 deletions train.py
@@ -0,0 +1,114 @@
#coding:utf-8
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import dataHandler
from model import PoetryModel
from torch.autograd import Variable
from utils import *
import cPickle as p

data = dataHandler.parseRawData()

# build the vocabulary: every character gets an index, plus two markers
word_to_ix = {}
for sent in data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
word_to_ix['<EOP>'] = len(word_to_ix)
word_to_ix['<START>'] = len(word_to_ix)
VOCAB_SIZE = len(word_to_ix)
print "VOCAB_SIZE:", VOCAB_SIZE
print "data_size", len(data)

def toList(sen):
    rst = []
    for s in sen:
        rst.append(s)
    return rst

# turn each poem into a character list ending with <EOP>
for i in range(len(data)):
    if data[i][0] == "-":
        print data[i]
    data[i] = toList(data[i])
    data[i].append("<EOP>")

p.dump(word_to_ix, file('wordDic', 'w'))

model = PoetryModel(len(word_to_ix), 128)
model.cuda()
optimizer = optim.RMSprop(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# precompute the one-hot input and the index target for every word
one_hot_var = {}
one_hot_var_target = {}
for w in word_to_ix:
    one_hot_var.setdefault(w, make_one_hot_vec(w, word_to_ix))
for w in word_to_ix:
    one_hot_var_target.setdefault(w, make_one_hot_vec_target(w, word_to_ix))

wordList = open('wordList', 'w')
for w in word_to_ix:
    wordList.write(w.encode('utf-8'))
wordList.close()

epochNum = 10
TRAINSIZE = len(data)
batch = 100

def makeForOneCase(s):
    # pair every character (including <EOP>, so sampling can learn to stop)
    # with its predecessor, using <START> before the first one
    tmpIn = []
    tmpOut = []
    for i in range(len(s)):
        w = s[i]
        w_b = s[i - 1] if i > 0 else "<START>"
        tmpIn.append(one_hot_var[w_b])
        tmpOut.append(one_hot_var_target[w])
    return torch.cat(tmpIn), torch.cat(tmpOut)

print "start training"
for epoch in range(epochNum):
    for batchIndex in range(int(TRAINSIZE / batch)):
        model.zero_grad()
        loss = 0
        for case in range(batchIndex * batch, (batchIndex + 1) * batch):
            s = data[case]
            hidden = model.initHidden()
            t, o = makeForOneCase(s)
            output, hidden = model(t.cuda(), hidden)
            loss += criterion(output, o.cuda())
        loss = loss / batch
        loss.backward()
        print epoch, loss.data[0]
        optimizer.step()
    torch.save(model, 'poetry-gen.pt')
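
For illustration, here is what makeForOneCase builds for a hypothetical two-character poem (after <EOP> has been appended): each character becomes a target paired with its predecessor as input, starting from <START>.

# hypothetical illustration, not part of this commit
s = [u"床", u"前", "<EOP>"]
t, o = makeForOneCase(s)
# t: one-hot inputs for <START>, 床, 前 -> shape (3, 1, VOCAB_SIZE)
# o: target indices for 床, 前, <EOP>   -> shape (3,)
print t.size(), o.size()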
Binary file added train.pyc
Binary file not shown.
11 changes: 11 additions & 0 deletions utils.py
@@ -0,0 +1,11 @@
import torch
import torch.autograd as autograd

# one-hot row shaped (1, 1, vocab): the (seq, batch, input) layout nn.LSTM expects
def make_one_hot_vec(word, word_to_ix):
    rst = torch.zeros(1, 1, len(word_to_ix))
    rst[0][0][word_to_ix[word]] = 1
    return autograd.Variable(rst)

# class-index target for nn.NLLLoss
def make_one_hot_vec_target(word, word_to_ix):
    rst = autograd.Variable(torch.LongTensor([word_to_ix[word]]))
    return rst
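
The two helpers produce the two encodings the training loop needs: a one-hot Variable for the LSTM input and a bare class index for the loss. A small sketch with a hypothetical two-word vocabulary:

# hypothetical example, not part of this commit
from utils import make_one_hot_vec, make_one_hot_vec_target

word_to_ix = {'A': 0, 'B': 1}
v = make_one_hot_vec('B', word_to_ix)
print v.size()    # (1, 1, 2): one-hot input row for the LSTM
t = make_one_hot_vec_target('B', word_to_ix)
print t.data      # LongTensor([1]): the class index nn.NLLLoss consumes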
Binary file added utils.pyc
Binary file not shown.