Skip to content

Commit

Permalink
Merge pull request certik#33 from certik/tok1
Browse files Browse the repository at this point in the history
Print the words as they are generated
  • Loading branch information
certik authored Mar 19, 2023
2 parents b10860e aa7eb93 commit b9d2b28
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 6 additions & 2 deletions gpt2.f90
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ function generate(n_tokens_to_generate, &
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
- ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache) result(output)
+ ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache, &
+ decoder_idx, decoder_txt, byte_decoder) result(output)
integer, intent(in) :: n_vocab, n_ctx, n_seq, n_embd, n_layer, n_head, &
n_tokens_to_generate
integer, intent(in) :: input(n_seq)
Expand All @@ -230,6 +231,8 @@ function generate(n_tokens_to_generate, &
ln2_b(n_embd,n_layer), ln2_g(n_embd,n_layer), &
lnf_b(n_embd), lnf_g(n_embd)
logical, intent(in) :: use_cache
+ integer, intent(in) :: decoder_idx(:), byte_decoder(:)
+ character, intent(in) :: decoder_txt(:)
integer :: output(n_tokens_to_generate)
real(sp), allocatable :: logits(:,:)
integer :: i
Expand Down Expand Up @@ -260,11 +263,12 @@ function generate(n_tokens_to_generate, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_kv_cache, kv_cache(:,:n_seq2,:,:))
next_id = maxloc(logits(:,n_seq_x), dim=1)-1
- print *, i, next_id
+ write(*, fmt="(a)", advance="no") decode([next_id], decoder_idx, decoder_txt, byte_decoder)
input2 = [input2, next_id]
deallocate(logits)
end do
output = input2(n_seq+1:)
+ print *
end function

function c2s(x) result(y)
Expand Down
3 changes: 2 additions & 1 deletion main.f90
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ program gpt2
wte, wpe, &
mlp_fc_w, mlp_fc_b, mlp_proj_w, mlp_proj_b, &
attn_w, attn_b, attn_proj_w, attn_proj_b, &
- ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache)
+ ln1_g, ln1_b, ln2_g, ln2_b, lnf_g, lnf_b, use_cache, &
+ decoder_idx, decoder_txt, byte_decoder)
t2o = omp_get_wtime()
call cpu_time(t2)
print "(a,f8.3,a,f4.2,a)", " done. Time:", t2o-t1o, "s (", (t2-t1)/(t2o-t1o), "x)"
Expand Down

0 comments on commit b9d2b28

Please sign in to comment.