-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit eac17b9
Showing
13 changed files
with
17,805 additions
and
0 deletions.
There are no files selected for viewing
Submodule chinese-poetry
added at
7bc2c6
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 1,39 @@ | ||
#coding:utf-8 | ||
import sys | ||
import os | ||
import json | ||
import re | ||
|
||
def parseRawData(): | ||
rst = [] | ||
def sentenceParse(para): | ||
# para = "-181-村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡,生摘琵琶酸。(「琵琶」,嚴壽澄校《張祜詩集》云:疑「枇杷」之誤。)好是去塵俗,煙花長一欄。" | ||
result, number = re.subn("(.*)", "", para) | ||
result, number = re.subn("{.*}", "", result) | ||
result, number = re.subn("[\]\[]", "", result) | ||
return result.strip("[0123456789-]") | ||
|
||
def handleJson(file): | ||
# print file | ||
rst = [] | ||
data = json.loads(open(file).read()) | ||
for poetry in data: | ||
pdata = "" | ||
for sentence in poetry.get("paragraphs"): | ||
pdata = sentence | ||
pdata = sentenceParse(pdata) | ||
if pdata!="": | ||
rst.append(pdata) | ||
return rst | ||
# print sentenceParse("") | ||
data = [] | ||
src = './chinese-poetry/json/' | ||
for filename in os.listdir(src): | ||
if filename.startswith("poet.tang"): | ||
data.extend(handleJson(src filename)) | ||
return data | ||
|
||
if __name__=='__main__': | ||
data = parseRawData() | ||
for s in data: | ||
print s |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 1,28 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Variable | ||
import torch.nn.functional as F | ||
|
||
class PoetryModel(nn.Module): | ||
def __init__(self, vocab_size,hidden_dim): | ||
super(PoetryModel, self).__init__() | ||
self.hidden_dim = hidden_dim | ||
self.lstm = nn.LSTM(vocab_size, self.hidden_dim) | ||
self.linear1 = nn.Linear(self.hidden_dim,vocab_size) | ||
self.dropout = nn.Dropout(0.1) | ||
self.softmax = nn.LogSoftmax() | ||
|
||
def forward(self, input, hidden): | ||
length = input.size()[0] | ||
|
||
output, hidden = self.lstm(input.view(length, 1, -1), hidden) | ||
# print output.size() | ||
|
||
output = F.relu(self.linear1(output.view(length,-1))) | ||
output = self.dropout(output) | ||
output = self.softmax(output) | ||
return output, hidden | ||
|
||
def initHidden(self,length = 1): | ||
return (Variable(torch.zeros(length, 1, self.hidden_dim).cuda()), | ||
Variable(torch.zeros(length, 1, self.hidden_dim)).cuda()) |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 1,39 @@ | ||
#coding:utf-8 | ||
import torch | ||
import torch.autograd as autograd | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import dataHandler | ||
from model import PoetryModel | ||
from torch.autograd import Variable | ||
import cPickle as p | ||
from utils import * | ||
|
||
model = torch.load('poetry-gen.pt') | ||
max_length = 20 | ||
rFile =file('wordDic','r') | ||
|
||
word_to_ix = p.load(rFile) | ||
def invert_dict(d): | ||
return dict((v,k) for k,v in d.iteritems()) | ||
ix_to_word = invert_dict(word_to_ix) | ||
|
||
# Sample from a category and starting letter | ||
def sample(startWord='<START>'): | ||
input = make_one_hot_vec(startWord,word_to_ix) | ||
hidden = model.initHidden() | ||
output_name = ""; | ||
for i in range(max_length): | ||
output, hidden = model(input.cuda(), hidden) | ||
topv, topi = output.data.topk(1) | ||
topi = topi[0][0] | ||
w = ix_to_word[topi] | ||
if w == "<EOP>": | ||
break | ||
else: | ||
output_name = w | ||
input = make_one_hot_vec(w,word_to_ix) | ||
print output_name | ||
return output_name | ||
sample() | ||
# sample("床".decode('utf-8')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 1,114 @@ | ||
#coding:utf-8 | ||
import torch | ||
import torch.autograd as autograd | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import dataHandler | ||
from model import PoetryModel | ||
from torch.autograd import Variable | ||
from utils import * | ||
import cPickle as p | ||
|
||
data = dataHandler.parseRawData() | ||
|
||
|
||
word_to_ix = {} | ||
for sent in data: | ||
for word in sent: | ||
if word not in word_to_ix: | ||
word_to_ix[word] = len(word_to_ix) | ||
word_to_ix['<EOP>'] = len(word_to_ix) | ||
word_to_ix['<START>'] = len(word_to_ix) | ||
VOCAB_SIZE = len(word_to_ix) | ||
print "VOCAB_SIZE:",VOCAB_SIZE | ||
print "data_size",len(data) | ||
def toList(sen): | ||
rst = [] | ||
for s in sen: | ||
rst.append(s) | ||
return rst | ||
for i in range(len(data)): | ||
if data[i][0]=="-": | ||
print data[i] | ||
data[i] = toList(data[i]) | ||
|
||
# print data[i] | ||
data[i].append("<EOP>") | ||
|
||
|
||
p.dump(word_to_ix,file('wordDic','w')) | ||
|
||
# for s in data: | ||
# print i | ||
# i = 1 | ||
# transferData.append(make_one_hot_data(s,word_to_ix)) | ||
|
||
model = PoetryModel(len(word_to_ix),128); | ||
model.cuda() | ||
optimizer = optim.RMSprop(model.parameters(), lr=0.01) | ||
|
||
criterion = nn.NLLLoss() | ||
one_hot_var = {} | ||
one_hot_var_target = {} | ||
for w in word_to_ix: | ||
one_hot_var.setdefault(w,make_one_hot_vec(w,word_to_ix)) | ||
for w in word_to_ix: | ||
one_hot_var_target.setdefault(w, make_one_hot_vec_target(w,word_to_ix)) | ||
wordList = open('wordList','w') | ||
for w in word_to_ix: | ||
wordList.write(w.encode('utf-8')) | ||
wordList.close() | ||
|
||
epochNum = 10; | ||
TRAINSIZE = len(data) | ||
batch = 100 | ||
trainingIn = [] | ||
trainingOut = [] | ||
def makeForOneCase(s): | ||
tmpIn = [] | ||
tmpOut = [] | ||
for i in range(len(s) - 1): | ||
w = s[i] | ||
w_b = s[i - 1] if s > 0 else "<START>" | ||
tmpIn.append(one_hot_var[w_b]) | ||
tmpOut.append(one_hot_var_target[w]) | ||
return torch.cat(tmpIn),torch.cat(tmpOut) | ||
|
||
# for case in range(TRAINSIZE): | ||
# s = data[case] | ||
# makeForOneCase(s) | ||
|
||
|
||
print "start training" | ||
for epoch in range(epochNum): | ||
|
||
for batchIndex in range(int(TRAINSIZE/batch)): | ||
model.zero_grad() | ||
loss = 0 | ||
for case in range(batchIndex*batch,(batchIndex 1)*batch): | ||
|
||
s = data[case] | ||
|
||
|
||
hidden = model.initHidden() | ||
|
||
# for i in range(len(s)-1): | ||
# w = s[i] | ||
# w_b = s[i-1] if s>0 else "<START>" | ||
# output, hidden = model(one_hot_var[w_b], hidden) | ||
# loss = criterion(output,one_hot_var_target[w]) | ||
t,o = makeForOneCase(s) | ||
output, hidden = model(t.cuda(),hidden) | ||
loss = criterion(output, o.cuda()) | ||
loss = loss/batch | ||
loss.backward() | ||
print epoch,loss.data[0] | ||
optimizer.step() | ||
torch.save(model, 'poetry-gen.pt') | ||
|
||
|
||
# | ||
# for s in transferData: | ||
# for i in range(s.size()[0]): | ||
# print s[i] | ||
# sdasdas |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 1,11 @@ | ||
import torch | ||
import torch.autograd as autograd | ||
|
||
def make_one_hot_vec(word, word_to_ix): | ||
rst = torch.zeros(1,1,len(word_to_ix)) | ||
rst[0][0][word_to_ix[word]] = 1 | ||
return autograd.Variable(rst) | ||
|
||
def make_one_hot_vec_target(word, word_to_ix): | ||
rst = autograd.Variable(torch.LongTensor([word_to_ix[word]])) | ||
return rst |
Binary file not shown.
Oops, something went wrong.