Use load_model()
certik committed Mar 21, 2023
1 parent aa202eb commit fa3d98b
Showing 1 changed file with 25 additions and 63 deletions.
driver.f90: 88 changes (25 additions & 63 deletions)
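The diff below replaces the inline binary reader in gpt2_driver2 with a single call to load_model(), which fills a model_t derived type with the model hyperparameters, weights, and tokenizer tables. For reference, here is a minimal sketch of what that type and routine plausibly look like, reconstructed from the fields and read order of the code this commit deletes; the module name model_sketch, the parameter sp, and the exact declarations are assumptions, not the repository's actual definitions.

! Hypothetical sketch only: the real model_t and load_model() are defined
! elsewhere in the repository. Fields and read order are inferred from the
! inline reader deleted in this commit.
module model_sketch
implicit none
integer, parameter :: sp = kind(1.0)

type :: model_t
    integer :: n_vocab, n_ctx, n_embd, n_layer, n_head
    integer :: byte_encoder(0:255)
    integer, allocatable :: decoder_idx(:), vocab_idx(:)
    character, allocatable :: decoder_txt(:), vocab_txt(:)
    real(sp), allocatable :: wte(:,:), wpe(:,:), &
        mlp_fc_w(:,:,:), mlp_fc_b(:,:), mlp_proj_w(:,:,:), mlp_proj_b(:,:), &
        attn_w(:,:,:), attn_b(:,:), attn_proj_w(:,:,:), attn_proj_b(:,:), &
        ln1_b(:,:), ln1_g(:,:), ln2_b(:,:), ln2_g(:,:), lnf_b(:), lnf_g(:)
end type model_t

contains

subroutine load_model(filename, m)
    ! Read the packed stream file: a header of sizes first, then all arrays,
    ! in the same order the deleted inline reader used.
    character(*), intent(in) :: filename
    type(model_t), intent(out) :: m
    integer :: u, n_decoder_idx, n_decoder_txt, n_vocab_idx, n_vocab_txt, n_byte_encoder
    open(newunit=u, file=filename, form="unformatted", access="stream", status="old")
    read(u) m%n_vocab, m%n_ctx, m%n_embd, m%n_layer, m%n_head, &
        n_decoder_idx, n_decoder_txt, n_vocab_idx, n_vocab_txt, n_byte_encoder
    if (n_byte_encoder /= 256) error stop "n_byte_encoder must be 256"
    allocate(m%wte(m%n_embd, m%n_vocab), m%wpe(m%n_embd, m%n_ctx), &
        m%mlp_fc_w(4*m%n_embd, m%n_embd, m%n_layer), m%mlp_fc_b(4*m%n_embd, m%n_layer), &
        m%mlp_proj_w(m%n_embd, 4*m%n_embd, m%n_layer), m%mlp_proj_b(m%n_embd, m%n_layer), &
        m%attn_w(3*m%n_embd, m%n_embd, m%n_layer), m%attn_b(3*m%n_embd, m%n_layer), &
        m%attn_proj_w(m%n_embd, m%n_embd, m%n_layer), m%attn_proj_b(m%n_embd, m%n_layer), &
        m%ln1_b(m%n_embd, m%n_layer), m%ln1_g(m%n_embd, m%n_layer), &
        m%ln2_b(m%n_embd, m%n_layer), m%ln2_g(m%n_embd, m%n_layer), &
        m%lnf_b(m%n_embd), m%lnf_g(m%n_embd), &
        m%decoder_idx(n_decoder_idx), m%decoder_txt(n_decoder_txt), &
        m%vocab_idx(n_vocab_idx), m%vocab_txt(n_vocab_txt))
    read(u) m%wte, m%wpe, m%mlp_fc_w, m%mlp_fc_b, m%mlp_proj_w, m%mlp_proj_b, &
        m%attn_w, m%attn_b, m%attn_proj_w, m%attn_proj_b, &
        m%ln1_b, m%ln1_g, m%ln2_b, m%ln2_g, m%lnf_b, m%lnf_g, &
        m%decoder_idx, m%decoder_txt, m%vocab_idx, m%vocab_txt, m%byte_encoder
    close(u)
end subroutine load_model

end module model_sketch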
@@ -95,20 +95,9 @@ subroutine gpt2_driver2(input_txt, n_tokens_to_generate, input, output)
character(*), intent(in) :: input_txt
integer, intent(in) :: n_tokens_to_generate
integer, allocatable, intent(out) :: input(:), output(:)
integer :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, &
n_decoder_idx, n_decoder_txt, &
n_vocab_idx, n_vocab_txt, n_byte_encoder
integer, allocatable :: decoder_idx(:), vocab_idx(:), byte_decoder(:)
integer :: byte_encoder(0:255)
real(sp), allocatable :: wte(:,:), wpe(:,:), &
mlp_fc_w(:,:,:), mlp_fc_b(:,:), &
mlp_proj_w(:,:,:), mlp_proj_b(:,:), &
attn_w(:,:,:), attn_b(:,:), &
attn_proj_w(:,:,:), attn_proj_b(:,:), &
ln1_b(:,:), ln1_g(:,:), &
ln2_b(:,:), ln2_g(:,:), &
lnf_b(:), lnf_g(:)
character, allocatable :: decoder_txt(:), vocab_txt(:)
type(model_t) :: m
integer, allocatable :: byte_decoder(:)
integer :: n_seq
character(:), allocatable :: output_txt
real(dp) :: t1, t2, t1o, t2o
integer :: u, i
@@ -117,51 +106,23 @@ subroutine gpt2_driver2(input_txt, n_tokens_to_generate, input, output)
! Load the model
print "(a)", "Loading the model..."
call cpu_time(t1)
open(newunit=u, file="model.dat", form="unformatted", access="stream", status="old")
!read(u) model_version
! fastGPT (digits look similar to the letters they represent)
! model_version /= 0xfa51697
read(u) n_vocab, n_ctx, n_embd, n_layer, n_head, n_decoder_idx, n_decoder_txt, &
n_vocab_idx, n_vocab_txt, n_byte_encoder
allocate(wte(n_embd,n_vocab), wpe(n_embd,n_ctx), &
mlp_fc_w(4*n_embd,n_embd,n_layer), mlp_fc_b(4*n_embd,n_layer), &
mlp_proj_w(n_embd,4*n_embd,n_layer), mlp_proj_b(n_embd,n_layer), &
attn_w(3*n_embd,n_embd,n_layer), attn_b(3*n_embd,n_layer), &
attn_proj_w(n_embd,n_embd,n_layer), attn_proj_b(n_embd,n_layer), &
ln1_b(n_embd,n_layer), ln1_g(n_embd,n_layer), &
ln2_b(n_embd,n_layer), ln2_g(n_embd,n_layer), &
lnf_b(n_embd), lnf_g(n_embd), &
decoder_idx(n_decoder_idx), decoder_txt(n_decoder_txt), &
vocab_idx(n_vocab_idx), vocab_txt(n_vocab_txt))
if (n_byte_encoder /= 256) error stop "n_byte_encoder must be 256"
read(u) wte, wpe, &
mlp_fc_w, mlp_fc_b, &
mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, &
attn_proj_w, attn_proj_b, &
ln1_b, ln1_g, &
ln2_b, ln2_g, &
lnf_b, lnf_g, &
decoder_idx, decoder_txt, &
vocab_idx, vocab_txt, &
byte_encoder
close(u)
call load_model("model.dat", m)
call cpu_time(t2)
print "(a,f8.3,a)", " done. Time:", t2-t1, "s"
print *
print "(a)", "Model parameters:"
print "(a,i6)", "n_vocab =", n_vocab
print "(a,i6)", "n_ctx =", n_ctx
print "(a,i6)", "n_embd =", n_embd
print "(a,i6)", "n_layer =", n_layer
print "(a,i6)", "n_head =", n_head
print "(a,i6)", "n_vocab =", m%n_vocab
print "(a,i6)", "n_ctx =", m%n_ctx
print "(a,i6)", "n_embd =", m%n_embd
print "(a,i6)", "n_layer =", m%n_layer
print "(a,i6)", "n_head =", m%n_head
print *

! Compute byte_decoder:
allocate(byte_decoder(0:maxval(byte_encoder)))
allocate(byte_decoder(0:maxval(m%byte_encoder)))
byte_decoder = 0
do i = 0, size(byte_encoder)-1
byte_decoder(byte_encoder(i)) = i
do i = 0, size(m%byte_encoder)-1
byte_decoder(m%byte_encoder(i)) = i
end do

print "(a)", "Input text"
@@ -170,8 +131,8 @@ subroutine gpt2_driver2(input_txt, n_tokens_to_generate, input, output)
print *
print "(a)", "Encoding: tokenizing input text into tokens (currently slow)..."
call cpu_time(t1)
input = encode(input_txt, decoder_idx, decoder_txt, vocab_idx, vocab_txt, &
byte_encoder)
input = encode(input_txt, m%decoder_idx, m%decoder_txt, m%vocab_idx, m%vocab_txt, &
m%byte_encoder)
call cpu_time(t2)
n_seq = size(input)
print "(a,f8.3,a)", " done. Time:", t2-t1, "s"
@@ -184,7 +145,7 @@ subroutine gpt2_driver2(input_txt, n_tokens_to_generate, input, output)
print "(1000(i6))", input
print *

if (n_seq + n_tokens_to_generate >= n_ctx) then
if (n_seq + n_tokens_to_generate >= m%n_ctx) then
print *, "The maximum sequence length of the model was surpassed."
print *, "Make the input and/or number of tokens to generate shorter."
error stop
@@ -193,7 +154,7 @@ subroutine gpt2_driver2(input_txt, n_tokens_to_generate, input, output)
print "(a)", "Decoded input as text:"
!print "(a)", decode(input, decoder_idx, decoder_txt, byte_decoder)
allocate(character(0) :: output_txt) ! Fix GFortran warning
output_txt = decode(input, decoder_idx, decoder_txt, byte_decoder)
output_txt = decode(input, m%decoder_idx, m%decoder_txt, byte_decoder)
print "(a)", output_txt
print *

@@ -206,21 +167,22 @@ subroutine gpt2_driver2(input_txt, n_tokens_to_generate, input, output)
call cpu_time(t1)
t1o = omp_get_wtime()
use_cache = .true.
output = generate(n_tokens_to_generate, n_vocab, n_ctx, size(input), n_embd, &
n_layer, n_head, &
output = generate(n_tokens_to_generate, m%n_vocab, m%n_ctx, size(input), &
m%n_embd, &
m%n_layer, m%n_head, &
input, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache, &
decoder_idx, decoder_txt, byte_decoder)
m%wte, m%wpe, &
m%mlp_fc_w, m%mlp_fc_b, m%mlp_proj_w, m%mlp_proj_b, &
m%attn_w, m%attn_b, m%attn_proj_w, m%attn_proj_b, &
m%ln1_g, m%ln1_b, m%ln2_g, m%ln2_b, m%lnf_g, m%lnf_b, use_cache, &
m%decoder_idx, m%decoder_txt, byte_decoder)
t2o = omp_get_wtime()
call cpu_time(t2)
print "(a,f8.3,a,f4.2,a)", " done. Time:", t2o-t1o, "s (", (t2-t1)/(t2o-t1o), "x)"
print *
print "(a)", "Output tokens:"
print "(1000(i6))", output
output_txt = decode(output, decoder_idx, decoder_txt, byte_decoder)
output_txt = decode(output, m%decoder_idx, m%decoder_txt, byte_decoder)
print *
print "(a)", "Decoded output as text:"
print "(a)", output_txt
